"""Analyze tracked heading_debug data and summarize temporal heading instability. Usage examples: python analyze_heading_debug.py --input /path/to/tracking.json python analyze_heading_debug.py --input /path/to/combined_tracking.json --sources merge python analyze_heading_debug.py --input /path/to/case_dir python analyze_heading_debug.py --input /path/to/case_dir --json-name combined_tracking.json """ import argparse import csv import json from collections import Counter, defaultdict from pathlib import Path import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np LANE_ASSIGNMENT_TO_SOURCE = { 0: 'roi0', 1: 'roi1', 2: 'merge', } DEFAULT_JSON_CANDIDATES = [ 'merge.json', 'combined_tracking.json', 'roi0.json', 'roi1.json', 'tracking.json', ] DEFAULT_DEPTH_BUCKETS = [10.0, 20.0, 30.0, 50.0] def parse_numeric_value(value): """Convert JSON scalar strings to int/float when possible.""" if value is None: return None if isinstance(value, (int, float)): return value value_str = str(value).strip() if not value_str: return None try: numeric = float(value_str) except ValueError: return value if numeric.is_integer(): return int(numeric) return numeric def safe_int(value): """Best-effort integer conversion.""" numeric = parse_numeric_value(value) if numeric is None: return None try: return int(numeric) except (TypeError, ValueError): return None def safe_float(value): """Best-effort float conversion.""" numeric = parse_numeric_value(value) if isinstance(numeric, (int, float)): return float(numeric) return None def safe_mean(values): """Return mean for a non-empty numeric list.""" if not values: return None return float(np.mean(np.asarray(values, dtype=np.float64))) def safe_median(values): """Return median for a non-empty numeric list.""" if not values: return None return float(np.median(np.asarray(values, dtype=np.float64))) def safe_percentile(values, q): """Return percentile for a non-empty numeric list.""" if not values: return None return 
def normalize_heading_debug(heading_debug):
    """Return a copy of heading_debug with numeric fields coerced.

    Only fields that convert cleanly are kept. ``yaw_probs`` is kept only
    when it is a list/tuple whose entries all convert to float. Returns None
    for non-dict input or when no field survives.
    """
    if not isinstance(heading_debug, dict):
        return None

    cleaned = {}
    scalar_fields = (
        ('yaw_bin', safe_int),
        ('yaw_delta', safe_float),
        ('rot_y_decoded', safe_float),
    )
    for key, convert in scalar_fields:
        converted = convert(heading_debug.get(key))
        if converted is not None:
            cleaned[key] = converted

    raw_probs = heading_debug.get('yaw_probs')
    if isinstance(raw_probs, (list, tuple)):
        probs = [safe_float(entry) for entry in raw_probs]
        # A single unparsable entry invalidates the whole vector.
        if all(prob is not None for prob in probs):
            cleaned['yaw_probs'] = probs

    return cleaned if cleaned else None


def wrap_angle_diff(angle_curr, angle_prev):
    """Return angle_curr - angle_prev in radians, wrapped to [-pi, pi]."""
    delta = angle_curr - angle_prev
    return float(np.arctan2(np.sin(delta), np.cos(delta)))


def resolve_input_path(input_path, json_name=None):
    """Turn a JSON file path or a case directory into a concrete JSON path.

    A file path is returned as-is. For a directory, ``json_name`` is looked
    up when given, otherwise each name in DEFAULT_JSON_CANDIDATES is tried in
    order. Raises FileNotFoundError when the path does not exist or no
    candidate file is found.
    """
    path = Path(input_path)
    if path.is_file():
        return path
    if not path.is_dir():
        raise FileNotFoundError(f'Input path does not exist: {path}')

    candidates = [json_name] if json_name else DEFAULT_JSON_CANDIDATES
    existing = (path / name for name in candidates if (path / name).is_file())
    found = next(existing, None)
    if found is not None:
        return found

    raise FileNotFoundError(
        f'No tracking JSON found in {path}. Tried: {", ".join(candidates)}'
    )
def load_tracking_data(input_json_path):
    """Read the tracking JSON and return its top-level list of frames.

    Raises ValueError when the top-level JSON value is not a list.
    """
    with open(input_json_path, 'r', encoding='utf-8') as handle:
        frames = json.load(handle)
    if isinstance(frames, list):
        return frames
    raise ValueError(f'Expected a list of frames in {input_json_path}, got {type(frames).__name__}')


def infer_default_source_label(input_json_path):
    """Derive the fallback source label from the JSON file stem."""
    known_stems = {'roi0', 'roi1', 'merge', 'combined_tracking', 'tracking'}
    stem = input_json_path.stem.lower()
    return stem if stem in known_stems else 'input'


def extract_bbox(det):
    """Return the first four bbox coordinates as floats, or None.

    Reads ``bbox`` and falls back to ``box2d``; any non-numeric coordinate
    makes the whole box invalid.
    """
    raw_box = det.get('bbox')
    if raw_box is None:
        raw_box = det.get('box2d')
    if not isinstance(raw_box, (list, tuple)) or len(raw_box) < 4:
        return None

    coords = []
    for component in raw_box[:4]:
        number = safe_float(component)
        if number is None:
            return None
        coords.append(number)
    return coords


def extract_object_3d(det):
    """Return the raw ``object_3d`` entry of a detection (None if absent)."""
    payload = det.get('object_3d')
    return payload


def extract_yaw_final(det):
    """Return the final yaw stored in object_3d, or None.

    Supports the dict form (``rotation_y`` key) and the sequence form
    (index 6 of a sequence with at least seven entries).
    """
    payload = extract_object_3d(det)
    if isinstance(payload, dict):
        return safe_float(payload.get('rotation_y'))
    if isinstance(payload, (list, tuple)) and len(payload) >= 7:
        return safe_float(payload[6])
    return None


def extract_depth_z(det):
    """Return the z component stored in object_3d, or None.

    Supports the dict form (third entry of ``location``) and the sequence
    form (index 2).
    """
    payload = extract_object_3d(det)
    if isinstance(payload, dict):
        location = payload.get('location')
        has_z = isinstance(location, (list, tuple)) and len(location) >= 3
        return safe_float(location[2]) if has_z else None
    if isinstance(payload, (list, tuple)) and len(payload) >= 3:
        return safe_float(payload[2])
    return None


def extract_class_id(det):
    """Return ``class_id`` as an int, falling back to ``type``."""
    primary = safe_int(det.get('class_id'))
    return primary if primary is not None else safe_int(det.get('type'))


def extract_confidence(det):
    """Return ``confidence`` as a float, falling back to ``score``."""
    primary = safe_float(det.get('confidence'))
    return primary if primary is not None else safe_float(det.get('score'))
def extract_source_label(det, default_source):
    """Extract source label, supporting combined_tracking lane assignments.

    Priority: explicit ``source`` field, then ``lane_assignment`` mapped
    through LANE_ASSIGNMENT_TO_SOURCE (unknown lanes become ``lane_<n>``),
    then ``default_source``.
    """
    source = det.get('source')
    if source is not None:
        return str(source)
    lane_assignment = safe_int(det.get('lane_assignment'))
    if lane_assignment is not None:
        return LANE_ASSIGNMENT_TO_SOURCE.get(lane_assignment, f'lane_{lane_assignment}')
    return default_source


def extract_frame_id(det, image_name):
    """Extract frameId as a string when present.

    Falls back to the last run of digits found in ``image_name``; returns
    None when neither is available.
    """
    frame_id = det.get('frameId')
    if frame_id is not None:
        return str(frame_id)
    image_name = str(image_name or '')
    # Replace every non-digit with a space so split() yields the digit runs.
    digits = ''.join(ch if ch.isdigit() else ' ' for ch in image_name).split()
    if digits:
        return digits[-1]
    return None


def compute_prob_stats(yaw_probs):
    """Compute top-k probability statistics from yaw_probs.

    Returns:
        ``(top1, top2, margin, num_pos_bins)`` where ``margin`` is
        ``top1 - top2`` and ``num_pos_bins`` counts entries above 0.5.
        All four are None for empty input; ``top2`` and ``margin`` are None
        when there is only one entry.
    """
    if not yaw_probs:
        return None, None, None, None
    probs = np.asarray(yaw_probs, dtype=np.float64)
    sorted_probs = np.sort(probs)[::-1]
    p1 = float(sorted_probs[0])
    p2 = float(sorted_probs[1]) if len(sorted_probs) > 1 else None
    margin = float(p1 - p2) if p2 is not None else None
    num_pos_bins = int(np.sum(probs > 0.5))
    return p1, p2, margin, num_pos_bins


def build_flat_rows(tracking_data, default_source):
    """Flatten tracked detections that contain heading_debug.

    Args:
        tracking_data: list of frame dicts, each optionally holding
            ``image_name`` and a ``detections`` list.
        default_source: source label used when a detection carries neither
            ``source`` nor ``lane_assignment``.

    Returns:
        ``(rows, stats)``: one flat dict per detection with usable
        heading_debug, plus counters describing the input.
    """
    rows = []
    stats = {
        'total_frames': len(tracking_data),
        'total_detections': 0,
        'detections_with_heading_debug': 0,
        'detections_missing_heading_debug': 0,
        'detections_missing_track_id': 0,
    }

    for frame_index, frame in enumerate(tracking_data):
        image_name = frame.get('image_name') or f'frame_{frame_index:06d}'
        # The key may be present with an explicit null; treat that the same
        # as a frame without detections instead of failing on iteration.
        detections = frame.get('detections') or []
        for det_index, det in enumerate(detections):
            stats['total_detections'] += 1
            heading_debug = normalize_heading_debug(det.get('heading_debug'))
            if heading_debug is None:
                stats['detections_missing_heading_debug'] += 1
                continue
            stats['detections_with_heading_debug'] += 1

            bbox = extract_bbox(det)
            x1, y1, x2, y2 = bbox if bbox else (None, None, None, None)
            box_w = (x2 - x1) if bbox else None
            box_h = (y2 - y1) if bbox else None
            box_area = (box_w * box_h) if box_w is not None and box_h is not None else None

            yaw_probs = heading_debug.get('yaw_probs')
            p1, p2, margin, num_pos_bins = compute_prob_stats(yaw_probs)

            # Rows without a track id are still exported; they are only
            # excluded later when grouping by track.
            track_id = safe_int(det.get('track_id'))
            if track_id is None:
                stats['detections_missing_track_id'] += 1

            yaw_final = extract_yaw_final(det)
            rot_y_decoded = heading_debug.get('rot_y_decoded')

            rows.append({
                'frame_index': frame_index,
                'image_name': str(image_name),
                'det_index': det_index,
                'track_id': track_id,
                'frameId': extract_frame_id(det, image_name),
                'timestamp': parse_numeric_value(det.get('timestamp')),
                'class_id': extract_class_id(det),
                'type_name': str(det.get('type_name', '')),
                'sub_cls': safe_int(det.get('sub_cls')),
                'confidence': extract_confidence(det),
                'source': extract_source_label(det, default_source),
                'lane_assignment': safe_int(det.get('lane_assignment')),
                'roi_id': safe_int(det.get('roi_id')),
                'face_cls': str(det.get('face_cls', '')),
                'cut_cls': safe_int(det.get('cut_cls')),
                'cut_cls_name': str(det.get('cut_cls_name', '')),
                'anchor': str(det.get('anchor', '')),
                'x1': x1,
                'y1': y1,
                'x2': x2,
                'y2': y2,
                'box_w': box_w,
                'box_h': box_h,
                'box_area': box_area,
                'depth_z': extract_depth_z(det),
                'yaw_final': yaw_final,
                # Analysis prefers the final yaw and falls back to the
                # decoded network output.
                'yaw_analysis': yaw_final if yaw_final is not None else rot_y_decoded,
                'yaw_bin': heading_debug.get('yaw_bin'),
                'yaw_delta': heading_debug.get('yaw_delta'),
                'rot_y_decoded': rot_y_decoded,
                'yaw_prob_0': yaw_probs[0] if yaw_probs and len(yaw_probs) > 0 else None,
                'yaw_prob_1': yaw_probs[1] if yaw_probs and len(yaw_probs) > 1 else None,
                'yaw_prob_2': yaw_probs[2] if yaw_probs and len(yaw_probs) > 2 else None,
                'yaw_prob_3': yaw_probs[3] if yaw_probs and len(yaw_probs) > 3 else None,
                'yaw_prob_top1': p1,
                'yaw_prob_top2': p2,
                'margin': margin,
                'num_pos_bins': num_pos_bins,
            })

    return rows, stats
def filter_rows(rows, class_ids=None, sources=None):
    """Keep rows matching the optional class-id and source whitelists.

    An empty or missing whitelist disables that filter.
    """
    wanted_classes = set(class_ids or [])
    wanted_sources = set(sources or [])
    return [
        row
        for row in rows
        if (not wanted_classes or row['class_id'] in wanted_classes)
        and (not wanted_sources or row['source'] in wanted_sources)
    ]


def group_rows_by_track(rows, min_track_len):
    """Bucket rows by track id, ordered in time, dropping short tracks.

    Rows without a track id are ignored. Each track is sorted by
    (frame_index, det_index); only tracks with at least ``min_track_len``
    rows are returned.
    """
    buckets = defaultdict(list)
    for row in rows:
        tid = row.get('track_id')
        if tid is not None:
            buckets[tid].append(row)

    long_tracks = {}
    for tid, members in buckets.items():
        members.sort(key=lambda member: (member['frame_index'], member['det_index']))
        if len(members) >= min_track_len:
            long_tracks[tid] = members
    return long_tracks


def classify_transition(metrics, args):
    """Pick the root-cause label for one consecutive-frame transition.

    Rules are evaluated in priority order and the first match wins:
    insufficient_yaw, flip_180, wrap_around, roi_sensitive, yaw_bin_jitter,
    yaw_delta_jitter, visibility_ambiguity, yaw_bin_switch,
    other_instability, stable.
    """
    lookup = metrics.get
    abs_dyaw = lookup('abs_dyaw')
    if abs_dyaw is None:
        return 'insufficient_yaw'

    raw_abs_dyaw = lookup('raw_abs_dyaw')
    delta_change = lookup('abs_dyaw_delta')
    bin_switched = lookup('yaw_bin_switched')
    appearance_switched = lookup('face_switched') or lookup('cut_switched')

    # Near-pi jump that coincides with a discrete label change.
    if abs_dyaw >= args.flip_180_dyaw and (bin_switched or appearance_switched):
        return 'flip_180'

    # Large raw difference that is small once wrapped: a +pi/-pi crossing.
    if raw_abs_dyaw is not None and raw_abs_dyaw > np.pi and abs_dyaw < args.wrap_small_dyaw:
        return 'wrap_around'

    if (lookup('source_switched') or lookup('roi_switched')) and abs_dyaw >= args.large_dyaw:
        return 'roi_sensitive'

    if bin_switched and (lookup('low_margin') or lookup('multi_bin')):
        return 'yaw_bin_jitter'

    if (
        not bin_switched
        and delta_change is not None
        and delta_change >= args.large_delta
        and abs_dyaw >= max(0.1, args.large_dyaw * 0.5)
    ):
        return 'yaw_delta_jitter'

    is_large = abs_dyaw >= args.large_dyaw
    poor_visibility = (
        lookup('low_score') or lookup('small_box') or lookup('far_depth') or appearance_switched
    )
    if is_large and poor_visibility:
        return 'visibility_ambiguity'
    if bin_switched and is_large:
        return 'yaw_bin_switch'
    if is_large:
        return 'other_instability'
    return 'stable'
def build_transition_records(track_rows_by_id, args):
    """Create one record per pair of consecutive rows in every track.

    Each record holds the previous/current values, the derived change
    metrics, the boolean flags fed to classify_transition, and the
    resulting ``issue`` label.
    """

    def non_null_pair(before, after, key):
        """Collect the values of *key* from both rows that are not None."""
        return [value for value in (before.get(key), after.get(key)) if value is not None]

    records = []
    for track_id, rows in track_rows_by_id.items():
        for before, after in zip(rows, rows[1:]):
            yaw_before = before.get('yaw_analysis')
            yaw_after = after.get('yaw_analysis')
            if yaw_before is None or yaw_after is None:
                wrapped_dyaw = abs_dyaw = raw_abs_dyaw = None
            else:
                wrapped_dyaw = wrap_angle_diff(yaw_after, yaw_before)
                abs_dyaw = abs(wrapped_dyaw)
                raw_abs_dyaw = abs(yaw_after - yaw_before)

            bin_before = before.get('yaw_bin')
            bin_after = after.get('yaw_bin')
            same_yaw_bin = (
                bin_before is not None
                and bin_after is not None
                and bin_before == bin_after
            )

            # yaw_delta is only comparable while the bin is unchanged.
            delta_before = before.get('yaw_delta')
            delta_after = after.get('yaw_delta')
            abs_dyaw_delta = None
            if same_yaw_bin and delta_before is not None and delta_after is not None:
                abs_dyaw_delta = abs(delta_after - delta_before)

            margins = non_null_pair(before, after, 'margin')
            pos_bin_counts = non_null_pair(before, after, 'num_pos_bins')
            scores = non_null_pair(before, after, 'confidence')
            areas = non_null_pair(before, after, 'box_area')
            depths = non_null_pair(before, after, 'depth_z')

            min_margin = min(margins) if margins else None
            max_num_pos_bins = max(pos_bin_counts) if pos_bin_counts else 0
            min_score = min(scores) if scores else None
            min_box_area = min(areas) if areas else None
            max_depth_z = max(depths) if depths else None

            metrics = {
                'abs_dyaw': abs_dyaw,
                'raw_abs_dyaw': raw_abs_dyaw,
                'abs_dyaw_delta': abs_dyaw_delta,
                'yaw_bin_switched': bin_before != bin_after,
                'low_margin': min_margin is not None and min_margin < args.margin_threshold,
                'multi_bin': max_num_pos_bins >= 2,
                'source_switched': before.get('source') != after.get('source'),
                'roi_switched': before.get('roi_id') != after.get('roi_id'),
                'face_switched': before.get('face_cls') != after.get('face_cls'),
                'cut_switched': before.get('cut_cls') != after.get('cut_cls'),
                'low_score': min_score is not None and min_score < args.low_score_threshold,
                'small_box': min_box_area is not None and min_box_area < args.small_box_area,
                'far_depth': max_depth_z is not None and max_depth_z > args.far_depth_threshold,
            }
            issue = classify_transition(metrics, args)

            records.append({
                'track_id': track_id,
                'class_id': after.get('class_id'),
                'type_name': after.get('type_name'),
                'source': after.get('source'),
                'prev_frame_index': before.get('frame_index'),
                'curr_frame_index': after.get('frame_index'),
                'frame_gap': after.get('frame_index') - before.get('frame_index'),
                'prev_image_name': before.get('image_name'),
                'curr_image_name': after.get('image_name'),
                'prev_timestamp': before.get('timestamp'),
                'curr_timestamp': after.get('timestamp'),
                'prev_yaw': yaw_before,
                'curr_yaw': yaw_after,
                'wrapped_dyaw': wrapped_dyaw,
                'abs_dyaw': abs_dyaw,
                'raw_abs_dyaw': raw_abs_dyaw,
                'prev_yaw_bin': bin_before,
                'curr_yaw_bin': bin_after,
                'yaw_bin_switched': metrics['yaw_bin_switched'],
                'same_yaw_bin': same_yaw_bin,
                'prev_yaw_delta': delta_before,
                'curr_yaw_delta': delta_after,
                'abs_dyaw_delta': abs_dyaw_delta,
                'prev_margin': before.get('margin'),
                'curr_margin': after.get('margin'),
                'min_margin': min_margin,
                'low_margin': metrics['low_margin'],
                'multi_bin': metrics['multi_bin'],
                'prev_confidence': before.get('confidence'),
                'curr_confidence': after.get('confidence'),
                'min_confidence': min_score,
                'prev_box_area': before.get('box_area'),
                'curr_box_area': after.get('box_area'),
                'min_box_area': min_box_area,
                'prev_depth_z': before.get('depth_z'),
                'curr_depth_z': after.get('depth_z'),
                'max_depth_z': max_depth_z,
                'prev_source': before.get('source'),
                'curr_source': after.get('source'),
                'source_switched': metrics['source_switched'],
                'prev_roi_id': before.get('roi_id'),
                'curr_roi_id': after.get('roi_id'),
                'roi_switched': metrics['roi_switched'],
                'face_switched': metrics['face_switched'],
                'cut_switched': metrics['cut_switched'],
                'issue': issue,
            })
    return records
def summarize_track(track_id, rows, transitions, args):
    """Condense one track and its transitions into a flat report row.

    ``rows`` must be non-empty. The result includes identity fields taken
    from the first row, per-row statistics, per-transition statistics,
    per-issue counts, the dominant non-stable issue, and a weighted
    ``instability_score`` used for ranking.
    """

    def row_values(key):
        """Values of *key* across rows, skipping None."""
        return [row[key] for row in rows if row.get(key) is not None]

    def transition_values(key):
        """Values of *key* across transitions, skipping None."""
        return [item[key] for item in transitions if item.get(key) is not None]

    def count_flagged(key):
        """Number of transitions whose *key* flag is truthy."""
        return sum(1 for item in transitions if item.get(key))

    head = rows[0]
    tail = rows[-1]

    abs_dyaws = transition_values('abs_dyaw')
    abs_dyaw_deltas = transition_values('abs_dyaw_delta')
    margins = row_values('margin')
    confidences = row_values('confidence')
    depths = row_values('depth_z')
    box_areas = row_values('box_area')

    issue_counts = Counter(item['issue'] for item in transitions)
    ignored_issues = {'stable', 'insufficient_yaw'}
    unstable_counts = Counter({
        name: total for name, total in issue_counts.items() if name not in ignored_issues
    })
    if unstable_counts:
        dominant_issue = unstable_counts.most_common(1)[0][0]
    else:
        dominant_issue = 'stable'

    transition_count = len(transitions)
    yaw_bin_switch_count = count_flagged('yaw_bin_switched')
    roi_switch_count = count_flagged('roi_switched')
    source_switch_count = count_flagged('source_switched')
    face_switch_count = count_flagged('face_switched')
    cut_switch_count = count_flagged('cut_switched')

    if rows:
        low_margin_hits = sum(1 for value in margins if value < args.margin_threshold)
        multi_bin_hits = sum(1 for row in rows if (row.get('num_pos_bins') or 0) >= 2)
        low_margin_ratio = float(low_margin_hits) / len(rows)
        multi_bin_ratio = float(multi_bin_hits / len(rows))
    else:
        low_margin_ratio = None
        multi_bin_ratio = None

    p95_abs_dyaw = safe_percentile(abs_dyaws, 95)
    max_abs_dyaw = max(abs_dyaws) if abs_dyaws else None
    switch_fraction = yaw_bin_switch_count / max(transition_count, 1)

    # Ranking score: tail and peak yaw change plus bin-switch and
    # low-margin frequency.
    instability_score = (
        (p95_abs_dyaw or 0.0)
        + 0.5 * (max_abs_dyaw or 0.0)
        + 0.75 * switch_fraction
        + 0.25 * (low_margin_ratio or 0.0)
    )

    return {
        'track_id': track_id,
        'class_id': head.get('class_id'),
        'type_name': head.get('type_name'),
        'source': head.get('source'),
        'roi_id': head.get('roi_id'),
        'trajectory_length': len(rows),
        'transition_count': transition_count,
        'frame_start': head.get('frame_index'),
        'frame_end': tail.get('frame_index'),
        'frame_span': tail.get('frame_index') - head.get('frame_index') + 1,
        'image_start': head.get('image_name'),
        'image_end': tail.get('image_name'),
        'confidence_mean': safe_mean(confidences),
        'confidence_min': min(confidences) if confidences else None,
        'depth_mean': safe_mean(depths),
        'depth_min': min(depths) if depths else None,
        'depth_max': max(depths) if depths else None,
        'box_area_mean': safe_mean(box_areas),
        'margin_mean': safe_mean(margins),
        'margin_min': min(margins) if margins else None,
        'low_margin_ratio': low_margin_ratio,
        'multi_bin_ratio': multi_bin_ratio,
        'median_abs_dyaw': safe_median(abs_dyaws),
        'p95_abs_dyaw': p95_abs_dyaw,
        'max_abs_dyaw': max_abs_dyaw,
        'median_abs_dyaw_delta_same_bin': safe_median(abs_dyaw_deltas),
        'max_abs_dyaw_delta_same_bin': max(abs_dyaw_deltas) if abs_dyaw_deltas else None,
        'yaw_bin_switch_count': yaw_bin_switch_count,
        'yaw_bin_switch_rate': float(switch_fraction),
        'roi_switch_count': roi_switch_count,
        'source_switch_count': source_switch_count,
        'face_cls_switch_count': face_switch_count,
        'cut_cls_switch_count': cut_switch_count,
        'wrap_around_count': issue_counts.get('wrap_around', 0),
        'flip_180_count': issue_counts.get('flip_180', 0),
        'yaw_bin_jitter_count': issue_counts.get('yaw_bin_jitter', 0),
        'yaw_delta_jitter_count': issue_counts.get('yaw_delta_jitter', 0),
        'roi_sensitive_count': issue_counts.get('roi_sensitive', 0),
        'visibility_ambiguity_count': issue_counts.get('visibility_ambiguity', 0),
        # Plain bin switches are reported together with other instability.
        'other_instability_count': issue_counts.get('other_instability', 0) + issue_counts.get('yaw_bin_switch', 0),
        'dominant_issue': dominant_issue,
        'instability_score': instability_score,
    }
def build_track_summaries(track_rows_by_id, transitions, args):
    """Build per-track summaries for all analyzed tracks.

    Returns the summaries sorted by ``instability_score``, most unstable
    first.
    """
    transitions_by_track = defaultdict(list)
    for transition in transitions:
        transitions_by_track[transition['track_id']].append(transition)

    summaries = []
    for track_id, rows in track_rows_by_id.items():
        summaries.append(summarize_track(track_id, rows, transitions_by_track.get(track_id, []), args))

    summaries.sort(key=lambda item: item['instability_score'], reverse=True)
    return summaries


def _format_bucket_edge(value):
    """Render a bucket edge without a trailing ``.0`` and without truncation."""
    value = float(value)
    return str(int(value)) if value.is_integer() else f'{value:g}'


def depth_bucket_name(depth_z, depth_edges):
    """Map a depth value to a readable bucket label.

    Buckets are half-open intervals ``[lower, upper)`` built from the sorted
    ``depth_edges``, starting at 0 and ending with ``[last, +inf)``.
    Missing or NaN depth maps to ``'unknown'``; an empty edge list yields the
    single bucket ``'[0, +inf)'``.
    """
    # NaN is the only value that differs from itself; it would otherwise fail
    # every ``<`` comparison and be reported in the farthest bucket.
    if depth_z is None or depth_z != depth_z:
        return 'unknown'
    lower = 0.0
    for edge in sorted(depth_edges):
        if depth_z < edge:
            return f'[{_format_bucket_edge(lower)}, {_format_bucket_edge(edge)})'
        lower = edge
    return f'[{_format_bucket_edge(lower)}, +inf)'


def _group_sort_key(key):
    """Order numeric group keys numerically, ahead of string keys."""
    if isinstance(key, (int, float)) and not isinstance(key, bool):
        return (0, key, '')
    return (1, 0, str(key))


def compute_group_summary(transitions, group_field):
    """Aggregate transition metrics by a grouping field.

    Transitions whose ``group_field`` is None are grouped under
    ``'unknown'``. Numeric keys (e.g. class ids) are ordered numerically so
    that 2 precedes 10; string keys keep lexicographic order.
    """
    grouped = defaultdict(list)
    for transition in transitions:
        key = transition.get(group_field)
        if key is None:
            key = 'unknown'
        grouped[key].append(transition)

    summaries = []
    for key, items in sorted(grouped.items(), key=lambda kv: _group_sort_key(kv[0])):
        abs_dyaws = [item['abs_dyaw'] for item in items if item.get('abs_dyaw') is not None]
        summaries.append({
            group_field: key,
            'transition_count': len(items),
            'median_abs_dyaw': safe_median(abs_dyaws),
            'p95_abs_dyaw': safe_percentile(abs_dyaws, 95),
            'max_abs_dyaw': max(abs_dyaws) if abs_dyaws else None,
            'yaw_bin_switch_rate': float(sum(1 for item in items if item.get('yaw_bin_switched')) / max(len(items), 1)),
            'low_margin_ratio': float(sum(1 for item in items if item.get('low_margin')) / max(len(items), 1)),
        })
    return summaries
def compute_depth_summary(transitions, depth_edges):
    """Summarize transitions per depth bucket.

    Each transition is copied and tagged with a ``depth_bucket`` label
    derived from its ``max_depth_z`` before grouping.
    """
    bucketed = [
        {**transition, 'depth_bucket': depth_bucket_name(transition.get('max_depth_z'), depth_edges)}
        for transition in transitions
    ]
    return compute_group_summary(bucketed, 'depth_bucket')


def compute_issue_counts(transitions):
    """Tally the ``issue`` label of every transition into a plain dict."""
    tally = Counter()
    for item in transitions:
        tally[item.get('issue', 'unknown')] += 1
    return dict(tally)


def ensure_dir(path):
    """Make sure *path* exists as a directory and return it as a Path."""
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    return directory


def write_csv(rows, output_path, field_order=None):
    """Dump a list of dicts as CSV, creating parent directories.

    Columns come from ``field_order`` when given. Otherwise they are the
    keys of the first row followed by any additional keys from later rows in
    sorted order. Missing values are written as empty cells.
    """
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    if not rows:
        columns = list(field_order or [])
    else:
        columns = list(field_order or rows[0].keys())
        if field_order is None:
            every_key = {key for row in rows for key in row.keys()}
            columns.extend(sorted(every_key - set(columns)))

    with open(target, 'w', encoding='utf-8', newline='') as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({key: row.get(key) for key in columns} for row in rows)


def write_json(data, output_path):
    """Serialize *data* as indented UTF-8 JSON, creating parent directories."""
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, 'w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=2, ensure_ascii=False)
def maybe_sample_xy(x_values, y_values, max_points):
    """Return x/y as arrays, thinned to ``max_points`` evenly spaced samples.

    Inputs with at most ``max_points`` entries are returned in full.
    """
    total = len(x_values)
    xs = np.asarray(x_values)
    ys = np.asarray(y_values)
    if total <= max_points:
        return xs, ys
    picks = np.linspace(0, total - 1, max_points).astype(np.int64)
    return xs[picks], ys[picks]


def plot_margin_vs_dyaw(transitions, output_path, args):
    """Save a scatter of min margin against |wrapped yaw change|.

    Only transitions with both values are plotted. Returns False without
    writing a file when there is nothing to plot, True otherwise.
    """
    pairs = [
        (item['min_margin'], item['abs_dyaw'])
        for item in transitions
        if item.get('min_margin') is not None and item.get('abs_dyaw') is not None
    ]
    if not pairs:
        return False

    x_values = [pair[0] for pair in pairs]
    y_values = [pair[1] for pair in pairs]
    x_values, y_values = maybe_sample_xy(x_values, y_values, args.max_scatter_points)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(x_values, y_values, s=10, alpha=0.2, color='#1f77b4', edgecolors='none')
    # Reference lines for the two thresholds used by the classifier.
    ax.axvline(args.margin_threshold, color='tab:red', linestyle='--', label=f'margin={args.margin_threshold:.2f}')
    ax.axhline(args.large_dyaw, color='tab:orange', linestyle='--', label=f'|Δyaw|={args.large_dyaw:.2f}')
    ax.set_xlabel('Min Margin')
    ax.set_ylabel('|Δyaw| (rad)')
    ax.set_title('Heading Change vs Min Margin')
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
    return True


def plot_bin_transition_matrix(transitions, output_path):
    """Save a 4x4 heatmap counting previous-to-current yaw_bin transitions.

    Transitions with a missing bin or a bin outside 0..3 are ignored.
    Returns False without writing a file when nothing was counted.
    """
    counts = np.zeros((4, 4), dtype=np.int32)
    for item in transitions:
        src = item.get('prev_yaw_bin')
        dst = item.get('curr_yaw_bin')
        if src is None or dst is None:
            continue
        if 0 <= src < 4 and 0 <= dst < 4:
            counts[src, dst] += 1

    if not counts.any():
        return False

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(counts, cmap='Blues')
    ax.set_xticks(range(4))
    ax.set_yticks(range(4))
    ax.set_xlabel('Current yaw_bin')
    ax.set_ylabel('Previous yaw_bin')
    ax.set_title('Yaw Bin Transition Matrix')
    # Annotate every cell with its count.
    for row, col in np.ndindex(4, 4):
        ax.text(col, row, str(counts[row, col]), ha='center', va='center', color='black')
    fig.colorbar(im, ax=ax, shrink=0.85)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
    return True
def plot_bucket_summary(summary_rows, label_field, output_path, title):
    """Save a two-panel bar chart of per-bucket yaw-change statistics.

    The top panel shows median and p95 |Δyaw|; the bottom panel shows the
    yaw_bin switch rate and low-margin ratio. Missing values are drawn as 0.
    Returns False without writing a file when ``summary_rows`` is empty.
    """
    if not summary_rows:
        return False

    def column(key):
        """Values of *key* per bucket, with None/0 rendered as 0.0."""
        return [row.get(key) or 0.0 for row in summary_rows]

    labels = [str(row[label_field]) for row in summary_rows]
    median_values = column('median_abs_dyaw')
    p95_values = column('p95_abs_dyaw')
    switch_rates = column('yaw_bin_switch_rate')
    low_margin_rates = column('low_margin_ratio')

    x = np.arange(len(labels))
    width = 0.35
    left = x - width / 2
    right = x + width / 2

    fig, axes = plt.subplots(2, 1, figsize=(max(8, len(labels) * 1.1), 8), sharex=True)
    top, bottom = axes[0], axes[1]

    top.bar(left, median_values, width=width, label='median |Δyaw|')
    top.bar(right, p95_values, width=width, label='p95 |Δyaw|')
    top.set_ylabel('Radians')
    top.set_title(title)
    top.grid(True, axis='y', alpha=0.3)
    top.legend()

    bottom.bar(left, switch_rates, width=width, label='yaw_bin switch rate')
    bottom.bar(right, low_margin_rates, width=width, label='low margin ratio')
    bottom.set_ylabel('Ratio')
    bottom.set_xticks(x)
    bottom.set_xticklabels(labels, rotation=20, ha='right')
    bottom.grid(True, axis='y', alpha=0.3)
    bottom.legend()

    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
    return True


def plot_issue_distribution(issue_counts, output_path):
    """Save a bar chart of transition counts per root-cause label.

    Returns False without writing a file when ``issue_counts`` is empty.
    """
    if not issue_counts:
        return False

    labels = list(issue_counts)
    values = list(issue_counts.values())

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(labels, values, color='#4c72b0')
    ax.set_ylabel('Transition Count')
    ax.set_title('Heading Transition Root-Cause Distribution')
    ax.grid(True, axis='y', alpha=0.3)
    ax.tick_params(axis='x', rotation=25)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
    return True
def plot_bad_track_series(track_summary, track_rows, output_path, args):
    """Generate a multi-panel plot for one unstable track.

    Panels (top to bottom): final vs decoded yaw, yaw_bin, yaw_delta, and
    margin together with confidence. The figure is written to
    ``output_path`` and always closed, even if saving fails.

    Args:
        track_summary: summary dict providing ``track_id``, ``class_id``,
            ``source`` and ``dominant_issue`` for the title.
        track_rows: time-ordered flat rows of the track.
        output_path: image file to write.
        args: namespace providing ``margin_threshold``.
    """

    def series(key):
        """Values of *key* per row with missing entries as NaN.

        Rows usually carry the key with a None value, so a ``.get`` default
        would never apply; convert explicitly so gaps are plotted as NaN.
        """
        values = []
        for row in track_rows:
            value = row.get(key)
            values.append(np.nan if value is None else value)
        return values

    frames = [row['frame_index'] for row in track_rows]
    yaw_final = series('yaw_final')
    rot_y_decoded = series('rot_y_decoded')
    yaw_bin = series('yaw_bin')
    yaw_delta = series('yaw_delta')
    margin = series('margin')
    confidence = series('confidence')

    fig, axes = plt.subplots(4, 1, figsize=(12, 12), sharex=True)
    try:
        axes[0].plot(frames, yaw_final, marker='o', label='yaw_final', color='tab:blue', linewidth=1.8)
        axes[0].plot(frames, rot_y_decoded, marker='s', label='rot_y_decoded', color='tab:orange', linewidth=1.2, linestyle='--')
        # Guide lines at 0 and +/-pi make wrap-arounds and flips visible.
        axes[0].axhline(0.0, color='gray', linestyle='--', alpha=0.5)
        axes[0].axhline(np.pi, color='gray', linestyle=':', alpha=0.5)
        axes[0].axhline(-np.pi, color='gray', linestyle=':', alpha=0.5)
        axes[0].set_ylabel('Yaw (rad)')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        axes[0].set_title(
            f"Track {track_summary['track_id']} | class={track_summary['class_id']} | "
            f"source={track_summary['source']} | dominant={track_summary['dominant_issue']}"
        )

        axes[1].step(frames, yaw_bin, where='post', color='tab:green')
        # Missing bins are colored as -1 so they do not share a color with a
        # real bin.
        axes[1].scatter(frames, yaw_bin, c=np.nan_to_num(np.asarray(yaw_bin), nan=-1), cmap='tab10', s=35)
        axes[1].set_ylabel('yaw_bin')
        axes[1].set_yticks([0, 1, 2, 3])
        axes[1].grid(True, alpha=0.3)

        axes[2].plot(frames, yaw_delta, marker='o', color='tab:purple', label='yaw_delta')
        axes[2].set_ylabel('yaw_delta')
        axes[2].grid(True, alpha=0.3)
        axes[2].legend()

        axes[3].plot(frames, margin, marker='o', color='tab:red', label='margin')
        axes[3].plot(frames, confidence, marker='s', color='tab:cyan', label='confidence')
        axes[3].axhline(args.margin_threshold, color='tab:red', linestyle='--', alpha=0.6)
        axes[3].set_ylabel('Margin / Score')
        axes[3].set_xlabel('Frame Index')
        axes[3].grid(True, alpha=0.3)
        axes[3].legend()

        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even when plotting or saving raises.
        plt.close(fig)
def format_float(value, digits=4):
    """Render a number with fixed decimals for tables; None becomes '-'."""
    return '-' if value is None else f'{float(value):.{digits}f}'


def markdown_table(rows, columns, headers=None, max_rows=None):
    """Render dict rows as a markdown table.

    Returns ``'_None_'`` for empty input. Floats are formatted through
    format_float, None is shown as '-', everything else via ``str``.
    ``headers`` defaults to ``columns``; ``max_rows`` truncates the rows.
    """
    if not rows:
        return '_None_'
    if max_rows is not None:
        rows = rows[:max_rows]
    titles = headers or columns

    def render(value):
        """Format one cell value."""
        if isinstance(value, float):
            return format_float(value)
        return '-' if value is None else str(value)

    def as_line(cells):
        """Join cells into one markdown table line."""
        return '| ' + ' | '.join(cells) + ' |'

    table = [as_line(titles), as_line(['---'] * len(titles))]
    table.extend(as_line([render(row.get(column)) for column in columns]) for row in rows)
    return '\n'.join(table)


def build_report_markdown(summary, output_dir):
    """Assemble the markdown report text from the summary dictionary.

    Sections: overview, aggregate metrics, root-cause distribution,
    per-source / per-class / per-depth tables, the ten most unstable tracks,
    and the list of output files.
    """
    agg = summary['aggregate_metrics']
    stats = summary['input_stats']

    ranked_issues = sorted(summary['issue_counts'].items(), key=lambda pair: (-pair[1], pair[0]))
    issue_rows = [{'issue': name, 'count': total} for name, total in ranked_issues]

    # The three grouped tables share their metric columns and headers.
    metric_columns = ['transition_count', 'median_abs_dyaw', 'p95_abs_dyaw', 'yaw_bin_switch_rate', 'low_margin_ratio']
    metric_headers = ['transitions', 'median |Δyaw|', 'p95 |Δyaw|', 'bin switch', 'low margin']

    def grouped_table(rows, key_column, key_header):
        """Render one grouped summary table."""
        return markdown_table(rows, [key_column] + metric_columns, headers=[key_header] + metric_headers)

    report = ['# Heading Debug Analysis Report', '']

    report += [
        '## Overview',
        '',
        f"- Input: `{summary['input_json']}`",
        f"- Output directory: `{output_dir}`",
        f"- Total frames: {stats['total_frames']}",
        f"- Total detections: {stats['total_detections']}",
        f"- Detections with heading_debug: {stats['detections_with_heading_debug']}",
        f"- Filtered detections analyzed: {summary['filtered_detection_count']}",
        f"- Tracks analyzed: {summary['track_count']}",
        f"- Transitions analyzed: {summary['transition_count']}",
        '',
    ]

    report += [
        '## Aggregate Metrics',
        '',
        f"- median |Δyaw|: {format_float(agg.get('median_abs_dyaw'))} rad",
        f"- p95 |Δyaw|: {format_float(agg.get('p95_abs_dyaw'))} rad",
        f"- max |Δyaw|: {format_float(agg.get('max_abs_dyaw'))} rad",
        f"- yaw_bin switch rate: {format_float(agg.get('yaw_bin_switch_rate'))}",
        f"- low margin ratio: {format_float(agg.get('low_margin_ratio'))}",
        '',
    ]

    report += ['## Root Cause Distribution', '', markdown_table(issue_rows, ['issue', 'count']), '']
    report += ['## By Source', '', grouped_table(summary['source_summary'], 'source', 'source'), '']
    report += ['## By Class', '', grouped_table(summary['class_summary'], 'class_id', 'class'), '']
    report += ['## By Depth Bucket', '', grouped_table(summary['depth_summary'], 'depth_bucket', 'depth'), '']

    report += [
        '## Top Unstable Tracks',
        '',
        markdown_table(
            summary['bad_tracks'][:10],
            ['track_id', 'class_id', 'source', 'trajectory_length', 'p95_abs_dyaw', 'yaw_bin_switch_rate', 'low_margin_ratio', 'dominant_issue'],
            headers=['track', 'class', 'source', 'len', 'p95 |Δyaw|', 'bin switch', 'low margin', 'dominant issue'],
        ),
        '',
    ]

    report += ['## Output Files', '']
    report.extend(f'- `{name}`' for name in summary['output_files'])
    report.append('')
    return '\n'.join(report)
def build_summary_dict(input_json_path, input_stats, rows, transitions, track_summaries, class_summary, source_summary, depth_summary, issue_counts, output_files):
    """Bundle all analysis results into one JSON-serializable dictionary.

    Aggregate metrics are computed over every transition; the two rate
    metrics are None when there are no transitions. ``bad_tracks`` keeps at
    most the first 50 track summaries.
    """
    dyaw_magnitudes = [item['abs_dyaw'] for item in transitions if item.get('abs_dyaw') is not None]

    if transitions:
        denominator = max(len(transitions), 1)
        switched = sum(1 for item in transitions if item.get('yaw_bin_switched'))
        low_margin = sum(1 for item in transitions if item.get('low_margin'))
        switch_rate = float(switched / denominator)
        low_margin_ratio = float(low_margin / denominator)
    else:
        switch_rate = None
        low_margin_ratio = None

    aggregate_metrics = {
        'median_abs_dyaw': safe_median(dyaw_magnitudes),
        'p95_abs_dyaw': safe_percentile(dyaw_magnitudes, 95),
        'max_abs_dyaw': max(dyaw_magnitudes) if dyaw_magnitudes else None,
        'yaw_bin_switch_rate': switch_rate,
        'low_margin_ratio': low_margin_ratio,
    }

    return {
        'input_json': str(input_json_path),
        'input_stats': input_stats,
        'filtered_detection_count': len(rows),
        'transition_count': len(transitions),
        'track_count': len(track_summaries),
        'aggregate_metrics': aggregate_metrics,
        'issue_counts': issue_counts,
        'class_summary': class_summary,
        'source_summary': source_summary,
        'depth_summary': depth_summary,
        'bad_tracks': track_summaries[:50],
        'output_files': output_files,
    }
[item['abs_dyaw'] for item in transitions if item.get('abs_dyaw') is not None] aggregate_metrics = { 'median_abs_dyaw': safe_median(abs_dyaws), 'p95_abs_dyaw': safe_percentile(abs_dyaws, 95), 'max_abs_dyaw': max(abs_dyaws) if abs_dyaws else None, 'yaw_bin_switch_rate': float(sum(1 for item in transitions if item.get('yaw_bin_switched')) / max(len(transitions), 1)) if transitions else None, 'low_margin_ratio': float(sum(1 for item in transitions if item.get('low_margin')) / max(len(transitions), 1)) if transitions else None, } return { 'input_json': str(input_json_path), 'input_stats': input_stats, 'filtered_detection_count': len(rows), 'transition_count': len(transitions), 'track_count': len(track_summaries), 'aggregate_metrics': aggregate_metrics, 'issue_counts': issue_counts, 'class_summary': class_summary, 'source_summary': source_summary, 'depth_summary': depth_summary, 'bad_tracks': track_summaries[: min(50, len(track_summaries))], 'output_files': output_files, } def parse_args(): """Parse CLI arguments.""" parser = argparse.ArgumentParser(description='Analyze heading_debug stability from tracked JSON outputs.') parser.add_argument('--input', required=True, help='Tracking JSON file or a case directory. When a directory is given, merge.json is used by default.') parser.add_argument('--json-name', default=None, help='When --input is a directory, analyze this JSON filename inside it. Defaults to merge.json.') parser.add_argument('--output-dir', default=None, help='Output directory. Defaults to _heading_debug_analysis next to the input JSON.') parser.add_argument('--class-ids', nargs='*', type=int, default=None, help='Optional class filter, e.g. --class-ids 0 13') parser.add_argument('--sources', nargs='*', default=None, help='Optional source filter, e.g. 
--sources roi0 roi1 merge') parser.add_argument('--min-track-len', type=int, default=3, help='Minimum track length required for track-level analysis.') parser.add_argument('--top-k-bad-tracks', type=int, default=20, help='Number of most unstable tracks to export and plot.') parser.add_argument('--margin-threshold', type=float, default=0.10, help='Threshold used to mark low-margin bins.') parser.add_argument('--large-dyaw', type=float, default=0.35, help='Threshold in radians used to flag large heading changes.') parser.add_argument('--large-delta', type=float, default=0.20, help='Threshold in radians used to flag large yaw_delta changes.') parser.add_argument('--wrap-small-dyaw', type=float, default=0.35, help='Wrapped |Δyaw| threshold used to detect +pi/-pi wrap-around cases.') parser.add_argument('--flip-180-dyaw', type=float, default=2.50, help='Threshold in radians used to flag 180-degree flips.') parser.add_argument('--low-score-threshold', type=float, default=0.45, help='Threshold used in visibility/ambiguity heuristics.') parser.add_argument('--small-box-area', type=float, default=1600.0, help='Box-area threshold used in visibility/ambiguity heuristics.') parser.add_argument('--far-depth-threshold', type=float, default=30.0, help='Depth threshold used in visibility/ambiguity heuristics.') parser.add_argument('--max-scatter-points', type=int, default=50000, help='Maximum number of points kept in scatter plots.') parser.add_argument('--no-plots', action='store_true', help='Skip plot generation and only write tables/report.') return parser.parse_args() def main(): args = parse_args() input_json_path = resolve_input_path(args.input, args.json_name) tracking_data = load_tracking_data(input_json_path) default_source = infer_default_source_label(input_json_path) if args.output_dir is None: output_dir = input_json_path.parent / f'{input_json_path.stem}_heading_debug_analysis' else: output_dir = Path(args.output_dir) output_dir = ensure_dir(output_dir) plots_dir = 
ensure_dir(output_dir / 'figures') bad_track_dir = ensure_dir(output_dir / 'fig_track_yaw_series') rows, input_stats = build_flat_rows(tracking_data, default_source) rows = filter_rows(rows, class_ids=args.class_ids, sources=args.sources) if not rows: raise ValueError('No detections with heading_debug remain after filtering.') track_rows_by_id = group_rows_by_track(rows, args.min_track_len) transitions = build_transition_records(track_rows_by_id, args) track_summaries = build_track_summaries(track_rows_by_id, transitions, args) class_summary = compute_group_summary(transitions, 'class_id') source_summary = compute_group_summary(transitions, 'source') depth_summary = compute_depth_summary(transitions, DEFAULT_DEPTH_BUCKETS) issue_counts = compute_issue_counts(transitions) flat_csv_path = output_dir / 'heading_debug_flat.csv' transition_csv_path = output_dir / 'heading_debug_transition_flat.csv' track_summary_csv_path = output_dir / 'heading_debug_track_summary.csv' bad_tracks_csv_path = output_dir / 'heading_debug_bad_tracks.csv' class_summary_csv_path = output_dir / 'heading_debug_class_summary.csv' source_summary_csv_path = output_dir / 'heading_debug_source_summary.csv' depth_summary_csv_path = output_dir / 'heading_debug_depth_summary.csv' summary_json_path = output_dir / 'heading_debug_summary.json' report_md_path = output_dir / 'heading_debug_analysis_report.md' write_csv(rows, flat_csv_path) write_csv(transitions, transition_csv_path) write_csv(track_summaries, track_summary_csv_path) bad_tracks = track_summaries[: max(0, args.top_k_bad_tracks)] write_csv(bad_tracks, bad_tracks_csv_path) write_csv(class_summary, class_summary_csv_path) write_csv(source_summary, source_summary_csv_path) write_csv(depth_summary, depth_summary_csv_path) output_files = [ flat_csv_path.name, transition_csv_path.name, track_summary_csv_path.name, bad_tracks_csv_path.name, class_summary_csv_path.name, source_summary_csv_path.name, depth_summary_csv_path.name, summary_json_path.name, 
report_md_path.name, ] if not args.no_plots: if plot_margin_vs_dyaw(transitions, plots_dir / 'fig_margin_vs_dyaw.png', args): output_files.append('figures/fig_margin_vs_dyaw.png') if plot_bin_transition_matrix(transitions, plots_dir / 'fig_bin_transition_matrix.png'): output_files.append('figures/fig_bin_transition_matrix.png') if plot_bucket_summary(source_summary, 'source', plots_dir / 'fig_stability_by_source.png', 'Heading Stability by Source'): output_files.append('figures/fig_stability_by_source.png') if plot_bucket_summary(class_summary, 'class_id', plots_dir / 'fig_stability_by_class.png', 'Heading Stability by Class'): output_files.append('figures/fig_stability_by_class.png') if plot_bucket_summary(depth_summary, 'depth_bucket', plots_dir / 'fig_stability_by_depth.png', 'Heading Stability by Depth'): output_files.append('figures/fig_stability_by_depth.png') if plot_issue_distribution(issue_counts, plots_dir / 'fig_issue_distribution.png'): output_files.append('figures/fig_issue_distribution.png') for summary_row in bad_tracks: track_id = summary_row['track_id'] track_rows = track_rows_by_id.get(track_id) if not track_rows: continue plot_path = bad_track_dir / f'track_{track_id}.png' plot_bad_track_series(summary_row, track_rows, plot_path, args) output_files.append(f'fig_track_yaw_series/{plot_path.name}') summary = build_summary_dict( input_json_path=input_json_path, input_stats=input_stats, rows=rows, transitions=transitions, track_summaries=track_summaries, class_summary=class_summary, source_summary=source_summary, depth_summary=depth_summary, issue_counts=issue_counts, output_files=output_files, ) write_json(summary, summary_json_path) report_md = build_report_markdown(summary, output_dir) report_md_path.write_text(report_md, encoding='utf-8') print(f'Input JSON: {input_json_path}') print(f'Output directory: {output_dir}') print(f'Filtered detections analyzed: {len(rows)}') print(f'Tracks analyzed: {len(track_summaries)}') print(f'Transitions 
analyzed: {len(transitions)}') print('Generated files:') for name in output_files: print(f' - {name}') if __name__ == '__main__': main()