Files
yolov26_3d/tools/temporal_analysis/analyze_heading_debug.py
2026-06-24 09:35:46 +08:00

1157 lines
47 KiB
Python
Executable File

"""Analyze tracked heading_debug data and summarize temporal heading instability.
Usage examples:
python analyze_heading_debug.py --input /path/to/tracking.json
python analyze_heading_debug.py --input /path/to/combined_tracking.json --sources merge
python analyze_heading_debug.py --input /path/to/case_dir
python analyze_heading_debug.py --input /path/to/case_dir --json-name combined_tracking.json
"""
import argparse
import csv
import json
from collections import Counter, defaultdict
from pathlib import Path
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
# Maps a detection's integer `lane_assignment` to a source label. Consulted by
# extract_source_label() when the detection has no explicit 'source' field;
# unmapped values fall back to 'lane_<n>' there.
LANE_ASSIGNMENT_TO_SOURCE = {
0: 'roi0',
1: 'roi1',
2: 'merge',
}
# File names tried, in this order, by resolve_input_path() when the input path
# is a directory and no explicit JSON name was supplied.
DEFAULT_JSON_CANDIDATES = [
'merge.json',
'combined_tracking.json',
'roi0.json',
'roi1.json',
'tracking.json',
]
# NOTE(review): presumably the default upper edges handed to depth_bucket_name();
# the consumer and the unit are not visible in this part of the file - confirm.
DEFAULT_DEPTH_BUCKETS = [10.0, 20.0, 30.0, 50.0]
def parse_numeric_value(value):
    """Convert a JSON scalar to int/float when possible.

    Args:
        value: Any JSON scalar (number, string, None, ...).

    Returns:
        * ``None`` for ``None`` or a blank string.
        * The value unchanged when it already is an ``int``/``float``.
        * An ``int`` when the stripped text is an integer literal or a float
          literal with an integral value (``"3.0"`` -> ``3``).
        * A ``float`` for any other float literal.
        * The original ``value`` when the text is not numeric.
    """
    if value is None:
        return None
    if isinstance(value, (int, float)):
        return value
    value_str = str(value).strip()
    if not value_str:
        return None
    # Parse integer literals directly. Going through float() first silently
    # rounds integers above 2**53 (e.g. nanosecond timestamps or long ids).
    try:
        return int(value_str)
    except ValueError:
        pass
    try:
        numeric = float(value_str)
    except ValueError:
        # Not numeric at all: hand the caller back what it gave us.
        return value
    if numeric.is_integer():
        return int(numeric)
    return numeric
def safe_int(value):
    """Best-effort integer conversion.

    Args:
        value: Any JSON scalar.

    Returns:
        The value as ``int`` (floats are truncated by ``int()``), or ``None``
        when the value is missing, non-numeric, NaN, or infinite.
    """
    numeric = parse_numeric_value(value)
    if numeric is None:
        return None
    try:
        return int(numeric)
    except (TypeError, ValueError, OverflowError):
        # ValueError: non-numeric text or NaN; OverflowError: +/-infinity
        # (e.g. the strings "inf" or "1e400"), which previously escaped.
        return None
def safe_float(value):
    """Return ``value`` as a float, or ``None`` when it is not numeric.

    Parsing is delegated to parse_numeric_value(); only int/float results are
    converted, anything else (None, non-numeric text) yields ``None``.
    """
    parsed = parse_numeric_value(value)
    return float(parsed) if isinstance(parsed, (int, float)) else None
def safe_mean(values):
    """Arithmetic mean of ``values`` as a float, or ``None`` when empty."""
    if values:
        samples = np.asarray(values, dtype=np.float64)
        return float(np.mean(samples))
    return None
def safe_median(values):
    """Median of ``values`` as a float, or ``None`` when empty."""
    if values:
        samples = np.asarray(values, dtype=np.float64)
        return float(np.median(samples))
    return None
def safe_percentile(values, q):
    """``q``-th percentile (0-100) of ``values`` as a float, or ``None`` when empty."""
    if values:
        samples = np.asarray(values, dtype=np.float64)
        return float(np.percentile(samples, q))
    return None
def normalize_heading_debug(heading_debug):
    """Coerce a raw ``heading_debug`` dict to numeric Python values.

    Only fields that convert cleanly are kept: ``yaw_bin`` (int),
    ``yaw_delta`` and ``rot_y_decoded`` (float), and ``yaw_probs`` (list of
    floats, dropped entirely if any entry is not numeric).

    Returns:
        The normalized dict, or ``None`` when the input is not a dict or no
        field survived normalization.
    """
    if not isinstance(heading_debug, dict):
        return None

    result = {}
    scalar_fields = (
        ('yaw_bin', safe_int),
        ('yaw_delta', safe_float),
        ('rot_y_decoded', safe_float),
    )
    for key, convert in scalar_fields:
        converted = convert(heading_debug.get(key))
        if converted is not None:
            result[key] = converted

    raw_probs = heading_debug.get('yaw_probs')
    if isinstance(raw_probs, (list, tuple)):
        probs = [safe_float(entry) for entry in raw_probs]
        # A single non-numeric entry invalidates the whole probability vector.
        if all(prob is not None for prob in probs):
            result['yaw_probs'] = probs

    return result if result else None
def wrap_angle_diff(angle_curr, angle_prev):
    """Return ``angle_curr - angle_prev`` wrapped into [-pi, pi] (radians)."""
    delta = angle_curr - angle_prev
    # atan2(sin, cos) folds any multiple of 2*pi back into the principal range.
    wrapped = np.arctan2(np.sin(delta), np.cos(delta))
    return float(wrapped)
def resolve_input_path(input_path, json_name=None):
    """Turn a JSON file path or a case directory into a concrete JSON path.

    Args:
        input_path: Path to a JSON file, or to a directory containing one.
        json_name: Optional file name to look for inside the directory; when
            omitted, DEFAULT_JSON_CANDIDATES are tried in order.

    Returns:
        ``Path`` of the existing JSON file.

    Raises:
        FileNotFoundError: The path does not exist, or the directory holds
            none of the candidate files.
    """
    base = Path(input_path)
    if base.is_file():
        return base
    if not base.is_dir():
        raise FileNotFoundError(f'Input path does not exist: {base}')

    names = [json_name] if json_name else DEFAULT_JSON_CANDIDATES
    found = next((base / name for name in names if (base / name).is_file()), None)
    if found is not None:
        return found
    raise FileNotFoundError(
        f'No tracking JSON found in {base}. Tried: {", ".join(names)}'
    )
def load_tracking_data(input_json_path):
    """Read the tracking JSON and return its top-level list of frames.

    Raises:
        ValueError: The JSON document's top-level value is not a list.
    """
    with open(input_json_path, 'r', encoding='utf-8') as handle:
        frames = json.load(handle)
    if isinstance(frames, list):
        return frames
    raise ValueError(f'Expected a list of frames in {input_json_path}, got {type(frames).__name__}')
def infer_default_source_label(input_json_path):
    """Derive a default source label from the JSON file's stem.

    The lower-cased stem is returned when it is one of the known tracking
    file names; any other name maps to ``'input'``. Expects a ``Path``-like
    object exposing ``.stem``.
    """
    known_stems = ('roi0', 'roi1', 'merge', 'combined_tracking', 'tracking')
    name = input_json_path.stem.lower()
    return name if name in known_stems else 'input'
def extract_bbox(det):
    """Return the detection's 2D box as four floats ``[x1, y1, x2, y2]``.

    Reads ``det['bbox']`` and falls back to ``det['box2d']``. Returns ``None``
    when neither holds a list/tuple of at least four numeric entries.
    """
    raw = det.get('bbox')
    if raw is None:
        raw = det.get('box2d')
    if not (isinstance(raw, (list, tuple)) and len(raw) >= 4):
        return None

    corners = [safe_float(coord) for coord in raw[:4]]
    for corner in corners:
        if corner is None:
            return None
    return corners
def extract_object_3d(det):
    """Fetch the raw ``object_3d`` entry of a detection (``None`` if absent)."""
    payload = det.get('object_3d')
    return payload
def extract_yaw_final(det):
    """Read the final yaw from the detection's ``object_3d`` payload.

    Supports two payload shapes: a dict (key ``rotation_y``) or a sequence of
    at least seven entries (index 6). Anything else yields ``None``.
    """
    payload = extract_object_3d(det)
    if isinstance(payload, dict):
        return safe_float(payload.get('rotation_y'))
    if isinstance(payload, (list, tuple)) and len(payload) >= 7:
        return safe_float(payload[6])
    return None
def extract_depth_z(det):
    """Read the z component of the detection's ``object_3d`` payload.

    For a dict payload the value is ``location[2]``; for a sequence payload it
    is index 2. Returns ``None`` when the payload or the component is missing.
    """
    payload = extract_object_3d(det)
    if isinstance(payload, dict):
        position = payload.get('location')
        has_z = isinstance(position, (list, tuple)) and len(position) >= 3
        return safe_float(position[2]) if has_z else None
    if isinstance(payload, (list, tuple)) and len(payload) >= 3:
        return safe_float(payload[2])
    return None
def extract_class_id(det):
    """Return the integer class id, preferring ``class_id`` over ``type``."""
    for key in ('class_id', 'type'):
        parsed = safe_int(det.get(key))
        if parsed is not None:
            return parsed
    return None
def extract_confidence(det):
    """Return the detection score, preferring ``confidence`` over ``score``."""
    for key in ('confidence', 'score'):
        parsed = safe_float(det.get(key))
        if parsed is not None:
            return parsed
    return None
def extract_source_label(det, default_source):
    """Determine the source label of a detection.

    Priority: explicit ``source`` field, then ``lane_assignment`` mapped via
    LANE_ASSIGNMENT_TO_SOURCE (unmapped lanes become ``'lane_<n>'``), then
    ``default_source``.
    """
    explicit = det.get('source')
    if explicit is not None:
        return str(explicit)

    lane = safe_int(det.get('lane_assignment'))
    if lane is None:
        return default_source
    fallback = f'lane_{lane}'
    return LANE_ASSIGNMENT_TO_SOURCE.get(lane, fallback)
def extract_frame_id(det, image_name):
    """Return the frame id as a string.

    Uses ``det['frameId']`` when present; otherwise the last run of
    consecutive digits found in ``image_name``. Returns ``None`` when neither
    is available.
    """
    explicit = det.get('frameId')
    if explicit is not None:
        return str(explicit)

    runs = []
    current = ''
    for char in str(image_name or ''):
        if char.isdigit():
            current += char
        elif current:
            # A non-digit closes the digit run collected so far.
            runs.append(current)
            current = ''
    if current:
        runs.append(current)
    return runs[-1] if runs else None
def compute_prob_stats(yaw_probs):
    """Summarize a yaw probability vector.

    Returns:
        Tuple ``(top1, top2, margin, num_pos_bins)`` where ``margin`` is
        ``top1 - top2`` and ``num_pos_bins`` counts entries above 0.5.
        ``top2`` and ``margin`` are ``None`` for a single-entry vector; all
        four are ``None`` when ``yaw_probs`` is empty or missing.
    """
    if not yaw_probs:
        return None, None, None, None

    values = np.asarray(yaw_probs, dtype=np.float64)
    ascending = np.sort(values)
    best = float(ascending[-1])
    if len(ascending) > 1:
        runner_up = float(ascending[-2])
        gap = float(best - runner_up)
    else:
        runner_up = None
        gap = None
    above_half = int(np.count_nonzero(values > 0.5))
    return best, runner_up, gap, above_half
def build_flat_rows(tracking_data, default_source):
    """Flatten tracked detections that carry usable ``heading_debug`` data.

    Args:
        tracking_data: List of frame dicts, each with an optional
            ``image_name`` and a ``detections`` list.
        default_source: Source label used when a detection provides neither
            ``source`` nor ``lane_assignment``.

    Returns:
        Tuple ``(rows, stats)``. ``rows`` holds one flat dict per detection
        whose ``heading_debug`` normalized successfully; every row has the
        same keys, with ``None`` for unavailable values. ``stats`` counts
        frames, detections, and detections missing ``heading_debug`` or
        ``track_id``.
    """
    rows = []
    stats = {
        'total_frames': len(tracking_data),
        'total_detections': 0,
        'detections_with_heading_debug': 0,
        'detections_missing_heading_debug': 0,
        'detections_missing_track_id': 0,
    }
    for frame_index, frame in enumerate(tracking_data):
        image_name = frame.get('image_name') or f'frame_{frame_index:06d}'
        # A frame may carry `"detections": null`; treat that like "no
        # detections" instead of failing on iteration over None.
        detections = frame.get('detections') or []
        for det_index, det in enumerate(detections):
            stats['total_detections'] += 1
            heading_debug = normalize_heading_debug(det.get('heading_debug'))
            if heading_debug is None:
                stats['detections_missing_heading_debug'] += 1
                continue
            stats['detections_with_heading_debug'] += 1

            bbox = extract_bbox(det)
            if bbox:
                x1, y1, x2, y2 = bbox
                box_w = x2 - x1
                box_h = y2 - y1
                box_area = box_w * box_h
            else:
                x1 = y1 = x2 = y2 = None
                box_w = box_h = box_area = None

            yaw_probs = heading_debug.get('yaw_probs')
            p1, p2, margin, num_pos_bins = compute_prob_stats(yaw_probs)

            # Rows without a track id are still emitted (and counted); they
            # are dropped later when rows are grouped by track.
            track_id = safe_int(det.get('track_id'))
            if track_id is None:
                stats['detections_missing_track_id'] += 1

            yaw_final = extract_yaw_final(det)
            rot_y_decoded = heading_debug.get('rot_y_decoded')

            # Expose the first four probabilities as fixed columns.
            probs = yaw_probs or []
            prob_columns = {
                f'yaw_prob_{i}': probs[i] if len(probs) > i else None
                for i in range(4)
            }

            rows.append({
                'frame_index': frame_index,
                'image_name': str(image_name),
                'det_index': det_index,
                'track_id': track_id,
                'frameId': extract_frame_id(det, image_name),
                'timestamp': parse_numeric_value(det.get('timestamp')),
                'class_id': extract_class_id(det),
                'type_name': str(det.get('type_name', '')),
                'sub_cls': safe_int(det.get('sub_cls')),
                'confidence': extract_confidence(det),
                'source': extract_source_label(det, default_source),
                'lane_assignment': safe_int(det.get('lane_assignment')),
                'roi_id': safe_int(det.get('roi_id')),
                'face_cls': str(det.get('face_cls', '')),
                'cut_cls': safe_int(det.get('cut_cls')),
                'cut_cls_name': str(det.get('cut_cls_name', '')),
                'anchor': str(det.get('anchor', '')),
                'x1': x1,
                'y1': y1,
                'x2': x2,
                'y2': y2,
                'box_w': box_w,
                'box_h': box_h,
                'box_area': box_area,
                'depth_z': extract_depth_z(det),
                'yaw_final': yaw_final,
                # Prefer the final yaw; fall back to the decoded head output.
                'yaw_analysis': yaw_final if yaw_final is not None else rot_y_decoded,
                'yaw_bin': heading_debug.get('yaw_bin'),
                'yaw_delta': heading_debug.get('yaw_delta'),
                'rot_y_decoded': rot_y_decoded,
                **prob_columns,
                'yaw_prob_top1': p1,
                'yaw_prob_top2': p2,
                'margin': margin,
                'num_pos_bins': num_pos_bins,
            })
    return rows, stats
def filter_rows(rows, class_ids=None, sources=None):
    """Keep rows matching the requested class ids and sources.

    An empty or missing filter accepts every row for that criterion.
    """
    wanted_classes = set(class_ids or [])
    wanted_sources = set(sources or [])
    return [
        row
        for row in rows
        if (not wanted_classes or row['class_id'] in wanted_classes)
        and (not wanted_sources or row['source'] in wanted_sources)
    ]
def group_rows_by_track(rows, min_track_len):
    """Bucket rows per track id and drop tracks shorter than ``min_track_len``.

    Rows without a track id are ignored. Each kept track is ordered by
    ``(frame_index, det_index)``.
    """
    buckets = defaultdict(list)
    for row in rows:
        tid = row.get('track_id')
        if tid is not None:
            buckets[tid].append(row)

    long_enough = {}
    for tid, members in buckets.items():
        ordered = sorted(members, key=lambda entry: (entry['frame_index'], entry['det_index']))
        if len(ordered) >= min_track_len:
            long_enough[tid] = ordered
    return long_enough
def classify_transition(metrics, args):
    """Label one frame-to-frame transition of a track with a root cause.

    Args:
        metrics: Dict of per-transition measurements and boolean flags
            (``abs_dyaw``, ``raw_abs_dyaw``, ``abs_dyaw_delta``,
            ``yaw_bin_switched``, ``low_margin``, ...). Missing keys read as
            ``None``.
        args: Namespace providing the thresholds ``wrap_small_dyaw``,
            ``flip_180_dyaw``, ``large_dyaw`` and ``large_delta``.

    Returns:
        One of ``'insufficient_yaw'``, ``'flip_180'``, ``'wrap_around'``,
        ``'roi_sensitive'``, ``'yaw_bin_jitter'``, ``'yaw_delta_jitter'``,
        ``'visibility_ambiguity'``, ``'yaw_bin_switch'``,
        ``'other_instability'`` or ``'stable'``. Rules are evaluated in that
        order; the first match wins.
    """
    dyaw = metrics.get('abs_dyaw')
    if dyaw is None:
        return 'insufficient_yaw'

    raw_dyaw = metrics.get('raw_abs_dyaw')
    delta_change = metrics.get('abs_dyaw_delta')
    bin_changed = metrics.get('yaw_bin_switched')
    appearance_changed = metrics.get('face_switched') or metrics.get('cut_switched')

    # Raw difference beyond pi that collapses to a small wrapped difference:
    # the angle merely crossed the +/-pi seam.
    crossed_seam = (
        raw_dyaw is not None
        and raw_dyaw > np.pi
        and dyaw < args.wrap_small_dyaw
    )
    if dyaw >= args.flip_180_dyaw and (bin_changed or appearance_changed):
        return 'flip_180'
    if crossed_seam:
        return 'wrap_around'

    region_changed = metrics.get('source_switched') or metrics.get('roi_switched')
    if region_changed and dyaw >= args.large_dyaw:
        return 'roi_sensitive'

    if bin_changed and (metrics.get('low_margin') or metrics.get('multi_bin')):
        return 'yaw_bin_jitter'

    if not bin_changed and delta_change is not None:
        if delta_change >= args.large_delta and dyaw >= max(0.1, args.large_dyaw * 0.5):
            return 'yaw_delta_jitter'

    # Every remaining non-stable category requires a large yaw change.
    if not dyaw >= args.large_dyaw:
        return 'stable'

    weak_evidence = (
        metrics.get('low_score')
        or metrics.get('small_box')
        or metrics.get('far_depth')
        or appearance_changed
    )
    if weak_evidence:
        return 'visibility_ambiguity'
    if bin_changed:
        return 'yaw_bin_switch'
    return 'other_instability'
def build_transition_records(track_rows_by_id, args):
    """Create one record per consecutive row pair of every track.

    Args:
        track_rows_by_id: Mapping ``track_id -> ordered list of row dicts``.
        args: Namespace providing ``margin_threshold``,
            ``low_score_threshold``, ``small_box_area`` and
            ``far_depth_threshold``, plus whatever classify_transition reads.

    Returns:
        List of transition dicts holding previous/current values, the derived
        metrics and flags, and the ``issue`` label from classify_transition.
    """

    def present(prev_row, curr_row, key):
        # Non-None values of `key` across the two rows of a transition.
        return [v for v in (prev_row.get(key), curr_row.get(key)) if v is not None]

    def differs(prev_row, curr_row, key):
        return prev_row.get(key) != curr_row.get(key)

    records = []
    for track_id, rows in track_rows_by_id.items():
        for prev_row, curr_row in zip(rows, rows[1:]):
            prev_yaw = prev_row.get('yaw_analysis')
            curr_yaw = curr_row.get('yaw_analysis')
            if prev_yaw is not None and curr_yaw is not None:
                wrapped_dyaw = wrap_angle_diff(curr_yaw, prev_yaw)
                abs_dyaw = abs(wrapped_dyaw)
                raw_abs_dyaw = abs(curr_yaw - prev_yaw)
            else:
                wrapped_dyaw = abs_dyaw = raw_abs_dyaw = None

            prev_bin = prev_row.get('yaw_bin')
            curr_bin = curr_row.get('yaw_bin')
            same_yaw_bin = prev_bin is not None and curr_bin is not None and prev_bin == curr_bin

            # yaw_delta values are only compared when both rows share a bin.
            prev_delta = prev_row.get('yaw_delta')
            curr_delta = curr_row.get('yaw_delta')
            if same_yaw_bin and prev_delta is not None and curr_delta is not None:
                abs_dyaw_delta = abs(curr_delta - prev_delta)
            else:
                abs_dyaw_delta = None

            margins = present(prev_row, curr_row, 'margin')
            min_margin = min(margins) if margins else None
            pos_bins = present(prev_row, curr_row, 'num_pos_bins')
            max_num_pos_bins = max(pos_bins) if pos_bins else 0
            scores = present(prev_row, curr_row, 'confidence')
            min_score = min(scores) if scores else None
            areas = present(prev_row, curr_row, 'box_area')
            min_box_area = min(areas) if areas else None
            depths = present(prev_row, curr_row, 'depth_z')
            max_depth_z = max(depths) if depths else None

            metrics = {
                'abs_dyaw': abs_dyaw,
                'raw_abs_dyaw': raw_abs_dyaw,
                'abs_dyaw_delta': abs_dyaw_delta,
                'yaw_bin_switched': differs(prev_row, curr_row, 'yaw_bin'),
                'low_margin': min_margin is not None and min_margin < args.margin_threshold,
                'multi_bin': max_num_pos_bins >= 2,
                'source_switched': differs(prev_row, curr_row, 'source'),
                'roi_switched': differs(prev_row, curr_row, 'roi_id'),
                'face_switched': differs(prev_row, curr_row, 'face_cls'),
                'cut_switched': differs(prev_row, curr_row, 'cut_cls'),
                'low_score': min_score is not None and min_score < args.low_score_threshold,
                'small_box': min_box_area is not None and min_box_area < args.small_box_area,
                'far_depth': max_depth_z is not None and max_depth_z > args.far_depth_threshold,
            }
            issue = classify_transition(metrics, args)

            records.append({
                'track_id': track_id,
                'class_id': curr_row.get('class_id'),
                'type_name': curr_row.get('type_name'),
                'source': curr_row.get('source'),
                'prev_frame_index': prev_row.get('frame_index'),
                'curr_frame_index': curr_row.get('frame_index'),
                'frame_gap': curr_row.get('frame_index') - prev_row.get('frame_index'),
                'prev_image_name': prev_row.get('image_name'),
                'curr_image_name': curr_row.get('image_name'),
                'prev_timestamp': prev_row.get('timestamp'),
                'curr_timestamp': curr_row.get('timestamp'),
                'prev_yaw': prev_yaw,
                'curr_yaw': curr_yaw,
                'wrapped_dyaw': wrapped_dyaw,
                'abs_dyaw': abs_dyaw,
                'raw_abs_dyaw': raw_abs_dyaw,
                'prev_yaw_bin': prev_bin,
                'curr_yaw_bin': curr_bin,
                'yaw_bin_switched': metrics['yaw_bin_switched'],
                'same_yaw_bin': same_yaw_bin,
                'prev_yaw_delta': prev_delta,
                'curr_yaw_delta': curr_delta,
                'abs_dyaw_delta': abs_dyaw_delta,
                'prev_margin': prev_row.get('margin'),
                'curr_margin': curr_row.get('margin'),
                'min_margin': min_margin,
                'low_margin': metrics['low_margin'],
                'multi_bin': metrics['multi_bin'],
                'prev_confidence': prev_row.get('confidence'),
                'curr_confidence': curr_row.get('confidence'),
                'min_confidence': min_score,
                'prev_box_area': prev_row.get('box_area'),
                'curr_box_area': curr_row.get('box_area'),
                'min_box_area': min_box_area,
                'prev_depth_z': prev_row.get('depth_z'),
                'curr_depth_z': curr_row.get('depth_z'),
                'max_depth_z': max_depth_z,
                'prev_source': prev_row.get('source'),
                'curr_source': curr_row.get('source'),
                'source_switched': metrics['source_switched'],
                'prev_roi_id': prev_row.get('roi_id'),
                'curr_roi_id': curr_row.get('roi_id'),
                'roi_switched': metrics['roi_switched'],
                'face_switched': metrics['face_switched'],
                'cut_switched': metrics['cut_switched'],
                'issue': issue,
            })
    return records
def summarize_track(track_id, rows, transitions, args):
    """Condense one track into a flat, report-friendly dictionary.

    Args:
        track_id: Identifier of the track.
        rows: Ordered, non-empty list of row dicts belonging to the track.
        transitions: Transition records of this track.
        args: Namespace providing ``margin_threshold``.

    Returns:
        Dict with identity fields (taken from the first row), frame range,
        confidence/depth/box/margin statistics, yaw-change statistics,
        switch counters, per-issue counters, the dominant non-stable issue
        and a weighted ``instability_score``.
    """

    def present(items, key):
        # Values of `key` across `items`, skipping missing/None entries.
        return [item[key] for item in items if item.get(key) is not None]

    def count_flag(key):
        return sum(1 for item in transitions if item.get(key))

    head = rows[0]
    tail = rows[-1]

    abs_dyaws = present(transitions, 'abs_dyaw')
    abs_dyaw_deltas = present(transitions, 'abs_dyaw_delta')
    margins = present(rows, 'margin')
    confidences = present(rows, 'confidence')
    depths = present(rows, 'depth_z')
    box_areas = present(rows, 'box_area')

    issue_counts = Counter(item['issue'] for item in transitions)
    ignored_issues = {'stable', 'insufficient_yaw'}
    problem_counts = Counter({
        issue: count
        for issue, count in issue_counts.items()
        if issue not in ignored_issues
    })
    if problem_counts:
        dominant_issue = problem_counts.most_common(1)[0][0]
    else:
        dominant_issue = 'stable'

    transition_count = len(transitions)
    yaw_bin_switch_count = count_flag('yaw_bin_switched')
    switch_rate = yaw_bin_switch_count / max(transition_count, 1)

    if rows:
        low_margin_rows = sum(
            1 for row in rows
            if row.get('margin') is not None and row['margin'] < args.margin_threshold
        )
        low_margin_ratio = float(low_margin_rows) / len(rows)
        multi_bin_rows = sum(1 for row in rows if (row.get('num_pos_bins') or 0) >= 2)
        multi_bin_ratio = float(multi_bin_rows / len(rows))
    else:
        low_margin_ratio = None
        multi_bin_ratio = None

    p95_abs_dyaw = safe_percentile(abs_dyaws, 95)
    max_abs_dyaw = max(abs_dyaws) if abs_dyaws else None
    # Weighted blend of tail yaw change, bin switching and low-margin share.
    instability_score = (
        (p95_abs_dyaw or 0.0)
        + 0.5 * (max_abs_dyaw or 0.0)
        + 0.75 * switch_rate
        + 0.25 * (low_margin_ratio or 0.0)
    )

    return {
        'track_id': track_id,
        'class_id': head.get('class_id'),
        'type_name': head.get('type_name'),
        'source': head.get('source'),
        'roi_id': head.get('roi_id'),
        'trajectory_length': len(rows),
        'transition_count': transition_count,
        'frame_start': head.get('frame_index'),
        'frame_end': tail.get('frame_index'),
        'frame_span': tail.get('frame_index') - head.get('frame_index') + 1,
        'image_start': head.get('image_name'),
        'image_end': tail.get('image_name'),
        'confidence_mean': safe_mean(confidences),
        'confidence_min': min(confidences) if confidences else None,
        'depth_mean': safe_mean(depths),
        'depth_min': min(depths) if depths else None,
        'depth_max': max(depths) if depths else None,
        'box_area_mean': safe_mean(box_areas),
        'margin_mean': safe_mean(margins),
        'margin_min': min(margins) if margins else None,
        'low_margin_ratio': low_margin_ratio,
        'multi_bin_ratio': multi_bin_ratio,
        'median_abs_dyaw': safe_median(abs_dyaws),
        'p95_abs_dyaw': p95_abs_dyaw,
        'max_abs_dyaw': max_abs_dyaw,
        'median_abs_dyaw_delta_same_bin': safe_median(abs_dyaw_deltas),
        'max_abs_dyaw_delta_same_bin': max(abs_dyaw_deltas) if abs_dyaw_deltas else None,
        'yaw_bin_switch_count': yaw_bin_switch_count,
        'yaw_bin_switch_rate': float(switch_rate),
        'roi_switch_count': count_flag('roi_switched'),
        'source_switch_count': count_flag('source_switched'),
        'face_cls_switch_count': count_flag('face_switched'),
        'cut_cls_switch_count': count_flag('cut_switched'),
        'wrap_around_count': issue_counts.get('wrap_around', 0),
        'flip_180_count': issue_counts.get('flip_180', 0),
        'yaw_bin_jitter_count': issue_counts.get('yaw_bin_jitter', 0),
        'yaw_delta_jitter_count': issue_counts.get('yaw_delta_jitter', 0),
        'roi_sensitive_count': issue_counts.get('roi_sensitive', 0),
        'visibility_ambiguity_count': issue_counts.get('visibility_ambiguity', 0),
        # Plain bin switches are reported together with other instability.
        'other_instability_count': issue_counts.get('other_instability', 0) + issue_counts.get('yaw_bin_switch', 0),
        'dominant_issue': dominant_issue,
        'instability_score': instability_score,
    }
def build_track_summaries(track_rows_by_id, transitions, args):
    """Summarize every track and rank them by descending instability score."""
    per_track = defaultdict(list)
    for record in transitions:
        per_track[record['track_id']].append(record)

    reports = [
        summarize_track(tid, rows, per_track.get(tid, []), args)
        for tid, rows in track_rows_by_id.items()
    ]
    return sorted(reports, key=lambda report: report['instability_score'], reverse=True)
def depth_bucket_name(depth_z, depth_edges):
    """Map a depth value to a readable half-open bucket label.

    Args:
        depth_z: Depth value, or ``None`` when unknown.
        depth_edges: Ascending upper bucket edges; may be empty.

    Returns:
        ``'unknown'`` for a missing depth, ``'[lo, hi)'`` for the first edge
        the depth is below, otherwise ``'[last_edge, +inf)'``. With no edges
        at all every known depth maps to ``'[0, +inf)'``. Edges are shown
        truncated to integers.
    """
    if depth_z is None:
        return 'unknown'
    lower = 0.0
    for edge in depth_edges:
        if depth_z < edge:
            return f'[{int(lower)}, {int(edge)})'
        lower = edge
    # `lower` now holds the last edge, or 0.0 when no edges were supplied;
    # indexing depth_edges[-1] here would raise IndexError on an empty list.
    return f'[{int(lower)}, +inf)'
def compute_group_summary(transitions, group_field):
    """Aggregate transition metrics per distinct value of ``group_field``.

    Transitions whose group value is missing are collected under
    ``'unknown'``. Groups are emitted ordered by the string form of their key.
    """
    buckets = defaultdict(list)
    for record in transitions:
        label = record.get(group_field)
        buckets['unknown' if label is None else label].append(record)

    result = []
    for label in sorted(buckets, key=str):
        members = buckets[label]
        size = max(len(members), 1)
        dyaws = [m['abs_dyaw'] for m in members if m.get('abs_dyaw') is not None]
        switched = sum(1 for m in members if m.get('yaw_bin_switched'))
        low_margin = sum(1 for m in members if m.get('low_margin'))
        result.append({
            group_field: label,
            'transition_count': len(members),
            'median_abs_dyaw': safe_median(dyaws),
            'p95_abs_dyaw': safe_percentile(dyaws, 95),
            'max_abs_dyaw': max(dyaws) if dyaws else None,
            'yaw_bin_switch_rate': float(switched / size),
            'low_margin_ratio': float(low_margin / size),
        })
    return result
def compute_depth_summary(transitions, depth_edges):
    """Aggregate transition metrics per depth bucket.

    Each transition is copied and tagged with a ``depth_bucket`` label derived
    from its ``max_depth_z``; the input records are left untouched.
    """
    bucketed = [
        {**record, 'depth_bucket': depth_bucket_name(record.get('max_depth_z'), depth_edges)}
        for record in transitions
    ]
    return compute_group_summary(bucketed, 'depth_bucket')
def compute_issue_counts(transitions):
    """Tally transitions per ``issue`` label (``'unknown'`` when absent)."""
    tally = Counter()
    for record in transitions:
        tally[record.get('issue', 'unknown')] += 1
    return dict(tally)
def ensure_dir(path):
    """Make sure ``path`` exists as a directory (parents included).

    Returns:
        The directory as a ``Path``.
    """
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    return directory
def write_csv(rows, output_path, field_order=None):
    """Write a list of dicts as a UTF-8 CSV file with a header row.

    Args:
        rows: List of row dicts.
        output_path: Destination file; parent directories are created.
        field_order: Optional explicit column list. When ``None``, columns are
            the first row's keys followed by any extra keys found in later
            rows (sorted). Keys missing from a row are written as empty cells.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    if not rows:
        columns = list(field_order or [])
    else:
        columns = list(field_order or rows[0].keys())
        if field_order is None:
            every_key = {key for row in rows for key in row}
            columns.extend(sorted(every_key - set(columns)))

    with open(destination, 'w', encoding='utf-8', newline='') as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({key: row.get(key) for key in columns} for row in rows)
def write_json(data, output_path):
    """Serialize ``data`` as indented UTF-8 JSON, creating parent directories.

    Non-ASCII characters are written as-is rather than escaped.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, 'w', encoding='utf-8') as handle:
        json.dump(data, handle, ensure_ascii=False, indent=2)
def maybe_sample_xy(x_values, y_values, max_points):
    """Return x/y as arrays, thinned to ``max_points`` evenly spaced samples.

    Inputs with at most ``max_points`` entries are returned in full. The same
    indices are applied to both sequences so pairs stay aligned.
    """
    total = len(x_values)
    xs = np.asarray(x_values)
    ys = np.asarray(y_values)
    if total <= max_points:
        return xs, ys
    picks = np.linspace(0, total - 1, max_points).astype(np.int64)
    return xs[picks], ys[picks]
def plot_margin_vs_dyaw(transitions, output_path, args):
    """Save a scatter plot of min margin against absolute wrapped yaw change.

    Only transitions with both values present are plotted; the points are
    subsampled to ``args.max_scatter_points``. Threshold lines are drawn at
    ``args.margin_threshold`` and ``args.large_dyaw``.

    Returns:
        ``True`` when a figure was written, ``False`` when nothing was
        plottable.
    """
    pairs = [
        (item['min_margin'], item['abs_dyaw'])
        for item in transitions
        if item.get('min_margin') is not None and item.get('abs_dyaw') is not None
    ]
    if not pairs:
        return False

    margins = [pair[0] for pair in pairs]
    dyaws = [pair[1] for pair in pairs]
    margins, dyaws = maybe_sample_xy(margins, dyaws, args.max_scatter_points)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(margins, dyaws, s=10, alpha=0.2, color='#1f77b4', edgecolors='none')
    ax.axvline(args.margin_threshold, color='tab:red', linestyle='--', label=f'margin={args.margin_threshold:.2f}')
    ax.axhline(args.large_dyaw, color='tab:orange', linestyle='--', label=f'|Δyaw|={args.large_dyaw:.2f}')
    ax.set_xlabel('Min Margin')
    ax.set_ylabel('|Δyaw| (rad)')
    ax.set_title('Heading Change vs Min Margin')
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
    return True
def plot_bin_transition_matrix(transitions, output_path):
    """Save a heat map of previous-vs-current yaw_bin counts.

    Only bins 0-3 are counted; transitions with a missing or out-of-range bin
    are skipped.

    Returns:
        ``True`` when a figure was written, ``False`` when no transition
        contributed to the matrix.
    """
    num_bins = 4
    bins = range(num_bins)
    counts = np.zeros((num_bins, num_bins), dtype=np.int32)
    recorded = 0
    for item in transitions:
        before = item.get('prev_yaw_bin')
        after = item.get('curr_yaw_bin')
        if before is None or after is None:
            continue
        if 0 <= before < num_bins and 0 <= after < num_bins:
            counts[before, after] += 1
            recorded += 1
    if recorded == 0:
        return False

    fig, ax = plt.subplots(figsize=(6, 5))
    image = ax.imshow(counts, cmap='Blues')
    ax.set_xticks(bins)
    ax.set_yticks(bins)
    ax.set_xlabel('Current yaw_bin')
    ax.set_ylabel('Previous yaw_bin')
    ax.set_title('Yaw Bin Transition Matrix')
    # Annotate every cell with its raw count.
    for r in bins:
        for c in bins:
            ax.text(c, r, str(counts[r, c]), ha='center', va='center', color='black')
    fig.colorbar(image, ax=ax, shrink=0.85)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
    return True
def plot_bucket_summary(summary_rows, label_field, output_path, title):
    """Save a two-panel bar chart for a bucketed summary.

    The top panel shows median and p95 |Δyaw| per bucket, the bottom panel
    the yaw_bin switch rate and the low-margin ratio. Missing values are
    drawn as 0.

    Returns:
        ``True`` when a figure was written, ``False`` for an empty summary.
    """
    if not summary_rows:
        return False

    labels = [str(row[label_field]) for row in summary_rows]

    def column(key):
        return [row.get(key) or 0.0 for row in summary_rows]

    medians = column('median_abs_dyaw')
    p95s = column('p95_abs_dyaw')
    switch_rates = column('yaw_bin_switch_rate')
    low_margin_rates = column('low_margin_ratio')

    positions = np.arange(len(labels))
    bar_width = 0.35
    fig, axes = plt.subplots(2, 1, figsize=(max(8, len(labels) * 1.1), 8), sharex=True)
    top, bottom = axes[0], axes[1]

    top.bar(positions - bar_width / 2, medians, width=bar_width, label='median |Δyaw|')
    top.bar(positions + bar_width / 2, p95s, width=bar_width, label='p95 |Δyaw|')
    top.set_ylabel('Radians')
    top.set_title(title)
    top.grid(True, axis='y', alpha=0.3)
    top.legend()

    bottom.bar(positions - bar_width / 2, switch_rates, width=bar_width, label='yaw_bin switch rate')
    bottom.bar(positions + bar_width / 2, low_margin_rates, width=bar_width, label='low margin ratio')
    bottom.set_ylabel('Ratio')
    bottom.set_xticks(positions)
    bottom.set_xticklabels(labels, rotation=20, ha='right')
    bottom.grid(True, axis='y', alpha=0.3)
    bottom.legend()

    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
    return True
def plot_issue_distribution(issue_counts, output_path):
    """Save a bar chart of transition counts per root-cause label.

    Returns:
        ``True`` when a figure was written, ``False`` for empty counts.
    """
    if not issue_counts:
        return False

    names = list(issue_counts)
    heights = [issue_counts[name] for name in names]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(names, heights, color='#4c72b0')
    ax.set_ylabel('Transition Count')
    ax.set_title('Heading Transition Root-Cause Distribution')
    ax.grid(True, axis='y', alpha=0.3)
    ax.tick_params(axis='x', rotation=25)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
    return True
def plot_bad_track_series(track_summary, track_rows, output_path, args):
    """Save a four-panel time series plot for one unstable track.

    Panels (top to bottom): final vs decoded yaw, yaw_bin, yaw_delta, and
    margin together with confidence.

    Args:
        track_summary: Summary dict of the track; ``track_id``, ``class_id``,
            ``source`` and ``dominant_issue`` are used for the title.
        track_rows: Ordered row dicts of the track.
        output_path: Destination image path.
        args: Namespace providing ``margin_threshold``.
    """

    def series(key):
        # Rows built by build_flat_rows always contain every key, with None
        # for unavailable values, so a `.get(key, np.nan)` default never
        # applies. Map None to NaN explicitly so gaps reach matplotlib as
        # floats, the same way yaw_bin is handled.
        return [np.nan if row.get(key) is None else row.get(key) for row in track_rows]

    frames = [row['frame_index'] for row in track_rows]
    yaw_final = series('yaw_final')
    rot_y_decoded = series('rot_y_decoded')
    yaw_bin = series('yaw_bin')
    yaw_delta = series('yaw_delta')
    margin = series('margin')
    confidence = series('confidence')

    fig, axes = plt.subplots(4, 1, figsize=(12, 12), sharex=True)

    axes[0].plot(frames, yaw_final, marker='o', label='yaw_final', color='tab:blue', linewidth=1.8)
    axes[0].plot(frames, rot_y_decoded, marker='s', label='rot_y_decoded', color='tab:orange', linewidth=1.2, linestyle='--')
    # Reference lines at 0 and +/-pi.
    axes[0].axhline(0.0, color='gray', linestyle='--', alpha=0.5)
    axes[0].axhline(np.pi, color='gray', linestyle=':', alpha=0.5)
    axes[0].axhline(-np.pi, color='gray', linestyle=':', alpha=0.5)
    axes[0].set_ylabel('Yaw (rad)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    axes[0].set_title(
        f"Track {track_summary['track_id']} | class={track_summary['class_id']} | "
        f"source={track_summary['source']} | dominant={track_summary['dominant_issue']}"
    )

    axes[1].step(frames, yaw_bin, where='post', color='tab:green')
    # Colors need finite values, so missing bins are mapped to -1 for the colormap.
    axes[1].scatter(frames, yaw_bin, c=np.nan_to_num(np.asarray(yaw_bin), nan=-1), cmap='tab10', s=35)
    axes[1].set_ylabel('yaw_bin')
    axes[1].set_yticks([0, 1, 2, 3])
    axes[1].grid(True, alpha=0.3)

    axes[2].plot(frames, yaw_delta, marker='o', color='tab:purple', label='yaw_delta')
    axes[2].set_ylabel('yaw_delta')
    axes[2].grid(True, alpha=0.3)
    axes[2].legend()

    axes[3].plot(frames, margin, marker='o', color='tab:red', label='margin')
    axes[3].plot(frames, confidence, marker='s', color='tab:cyan', label='confidence')
    axes[3].axhline(args.margin_threshold, color='tab:red', linestyle='--', alpha=0.6)
    axes[3].set_ylabel('Margin / Score')
    axes[3].set_xlabel('Frame Index')
    axes[3].grid(True, alpha=0.3)
    axes[3].legend()

    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close(fig)
def format_float(value, digits=4):
    """Render ``value`` with ``digits`` decimals; ``None`` becomes ``'-'``."""
    return '-' if value is None else f'{float(value):.{digits}f}'
def markdown_table(rows, columns, headers=None, max_rows=None):
    """Render row dicts as a markdown table.

    Args:
        rows: List of row dicts.
        columns: Keys to read from each row, in display order.
        headers: Optional header texts; defaults to ``columns``.
        max_rows: Optional cap on the number of rendered rows.

    Returns:
        The table as a string, or ``'_None_'`` when there are no rows. Floats
        are rendered via format_float and missing values as ``'-'``.
    """
    if not rows:
        return '_None_'

    shown = rows if max_rows is None else rows[:max_rows]
    titles = headers or columns

    def render(value):
        if isinstance(value, float):
            return format_float(value)
        return '-' if value is None else str(value)

    def as_line(cells):
        return '| ' + ' | '.join(cells) + ' |'

    lines = [as_line(titles), as_line(['---'] * len(titles))]
    lines.extend(as_line([render(row.get(column)) for column in columns]) for row in shown)
    return '\n'.join(lines)
def build_report_markdown(summary, output_dir):
    """Compose the markdown analysis report from the summary dictionary.

    Args:
        summary: Summary mapping. The keys read here are ``input_json``,
            ``input_stats``, ``filtered_detection_count``, ``track_count``,
            ``transition_count``, ``aggregate_metrics``, ``issue_counts``,
            ``source_summary``, ``class_summary``, ``depth_summary``,
            ``bad_tracks`` and ``output_files``.
        output_dir: Output location, interpolated into the overview section.

    Returns:
        The whole report as a single newline-joined string.
    """
    metrics = summary['aggregate_metrics']

    # Most frequent issue first; ties are broken by issue name.
    ranked_issues = sorted(summary['issue_counts'].items(), key=lambda pair: (-pair[1], pair[0]))
    issue_rows = [{'issue': label, 'count': total} for label, total in ranked_issues]

    # The per-source / per-class / per-depth tables share every column but the first.
    stability_columns = ['transition_count', 'median_abs_dyaw', 'p95_abs_dyaw', 'yaw_bin_switch_rate', 'low_margin_ratio']
    stability_headers = ['transitions', 'median |Δyaw|', 'p95 |Δyaw|', 'bin switch', 'low margin']

    def stability_section(title, group_rows, key_column, key_header):
        table = markdown_table(
            group_rows,
            [key_column] + stability_columns,
            headers=[key_header] + stability_headers,
        )
        return [title, '', table, '']

    report = [
        '# Heading Debug Analysis Report',
        '',
        '## Overview',
        '',
        f"- Input: `{summary['input_json']}`",
        f"- Output directory: `{output_dir}`",
        f"- Total frames: {summary['input_stats']['total_frames']}",
        f"- Total detections: {summary['input_stats']['total_detections']}",
        f"- Detections with heading_debug: {summary['input_stats']['detections_with_heading_debug']}",
        f"- Filtered detections analyzed: {summary['filtered_detection_count']}",
        f"- Tracks analyzed: {summary['track_count']}",
        f"- Transitions analyzed: {summary['transition_count']}",
        '',
    ]
    report += [
        '## Aggregate Metrics',
        '',
        f"- median |Δyaw|: {format_float(metrics.get('median_abs_dyaw'))} rad",
        f"- p95 |Δyaw|: {format_float(metrics.get('p95_abs_dyaw'))} rad",
        f"- max |Δyaw|: {format_float(metrics.get('max_abs_dyaw'))} rad",
        f"- yaw_bin switch rate: {format_float(metrics.get('yaw_bin_switch_rate'))}",
        f"- low margin ratio: {format_float(metrics.get('low_margin_ratio'))}",
        '',
    ]
    report += [
        '## Root Cause Distribution',
        '',
        markdown_table(issue_rows, ['issue', 'count']),
        '',
    ]
    report += stability_section('## By Source', summary['source_summary'], 'source', 'source')
    report += stability_section('## By Class', summary['class_summary'], 'class_id', 'class')
    report += stability_section('## By Depth Bucket', summary['depth_summary'], 'depth_bucket', 'depth')
    report += [
        '## Top Unstable Tracks',
        '',
        markdown_table(
            summary['bad_tracks'][:10],
            ['track_id', 'class_id', 'source', 'trajectory_length', 'p95_abs_dyaw', 'yaw_bin_switch_rate', 'low_margin_ratio', 'dominant_issue'],
            headers=['track', 'class', 'source', 'len', 'p95 |Δyaw|', 'bin switch', 'low margin', 'dominant issue'],
        ),
        '',
    ]
    report += ['## Output Files', '']
    report.extend(f'- `{name}`' for name in summary['output_files'])
    report.append('')
    return '\n'.join(report)
def build_summary_dict(input_json_path, input_stats, rows, transitions, track_summaries, class_summary, source_summary, depth_summary, issue_counts, output_files):
    """Bundle every analysis result into one machine-readable dictionary.

    Aggregate metrics are derived from ``transitions``: statistics over the
    non-``None`` ``abs_dyaw`` values, plus the fraction of transitions whose
    ``yaw_bin_switched`` / ``low_margin`` entries are truthy. Each metric is
    ``None`` when there is no data to compute it from.

    Returns:
        A dict holding the input path (as a string), the counts, the aggregate
        metrics, the grouped summaries, at most the first 50 track summaries
        under ``bad_tracks``, and the list of output files.
    """
    abs_dyaws = []
    for record in transitions:
        dyaw = record.get('abs_dyaw')
        if dyaw is not None:
            abs_dyaws.append(dyaw)

    def flagged_ratio(flag):
        # Share of transitions carrying a truthy ``flag``; undefined without transitions.
        if not transitions:
            return None
        hits = sum(1 for record in transitions if record.get(flag))
        return float(hits / len(transitions))

    aggregate_metrics = {
        'median_abs_dyaw': safe_median(abs_dyaws),
        'p95_abs_dyaw': safe_percentile(abs_dyaws, 95),
        'max_abs_dyaw': max(abs_dyaws) if abs_dyaws else None,
        'yaw_bin_switch_rate': flagged_ratio('yaw_bin_switched'),
        'low_margin_ratio': flagged_ratio('low_margin'),
    }

    summary = {
        'input_json': str(input_json_path),
        'input_stats': input_stats,
        'filtered_detection_count': len(rows),
        'transition_count': len(transitions),
        'track_count': len(track_summaries),
        'aggregate_metrics': aggregate_metrics,
        'issue_counts': issue_counts,
        'class_summary': class_summary,
        'source_summary': source_summary,
        'depth_summary': depth_summary,
        # Slicing already clamps to the list length.
        'bad_tracks': track_summaries[:50],
        'output_files': output_files,
    }
    return summary
def parse_args():
    """Define the command-line interface and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description='Analyze heading_debug stability from tracked JSON outputs.')

    # One (flag, add_argument keyword arguments) entry per option, in help order.
    option_specs = (
        ('--input', dict(required=True, help='Tracking JSON file or a case directory. When a directory is given, merge.json is used by default.')),
        ('--json-name', dict(default=None, help='When --input is a directory, analyze this JSON filename inside it. Defaults to merge.json.')),
        ('--output-dir', dict(default=None, help='Output directory. Defaults to <input_stem>_heading_debug_analysis next to the input JSON.')),
        ('--class-ids', dict(nargs='*', type=int, default=None, help='Optional class filter, e.g. --class-ids 0 13')),
        ('--sources', dict(nargs='*', default=None, help='Optional source filter, e.g. --sources roi0 roi1 merge')),
        ('--min-track-len', dict(type=int, default=3, help='Minimum track length required for track-level analysis.')),
        ('--top-k-bad-tracks', dict(type=int, default=20, help='Number of most unstable tracks to export and plot.')),
        ('--margin-threshold', dict(type=float, default=0.10, help='Threshold used to mark low-margin bins.')),
        ('--large-dyaw', dict(type=float, default=0.35, help='Threshold in radians used to flag large heading changes.')),
        ('--large-delta', dict(type=float, default=0.20, help='Threshold in radians used to flag large yaw_delta changes.')),
        ('--wrap-small-dyaw', dict(type=float, default=0.35, help='Wrapped |Δyaw| threshold used to detect +pi/-pi wrap-around cases.')),
        ('--flip-180-dyaw', dict(type=float, default=2.50, help='Threshold in radians used to flag 180-degree flips.')),
        ('--low-score-threshold', dict(type=float, default=0.45, help='Threshold used in visibility/ambiguity heuristics.')),
        ('--small-box-area', dict(type=float, default=1600.0, help='Box-area threshold used in visibility/ambiguity heuristics.')),
        ('--far-depth-threshold', dict(type=float, default=30.0, help='Depth threshold used in visibility/ambiguity heuristics.')),
        ('--max-scatter-points', dict(type=int, default=50000, help='Maximum number of points kept in scatter plots.')),
        ('--no-plots', dict(action='store_true', help='Skip plot generation and only write tables/report.')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
def main():
    """Run the analysis end to end: load, filter, summarize, export, plot, report.

    Raises:
        ValueError: If no rows are left after the class/source filters.
    """
    args = parse_args()
    input_json_path = resolve_input_path(args.input, args.json_name)
    tracking_data = load_tracking_data(input_json_path)
    default_source = infer_default_source_label(input_json_path)

    # Without --output-dir, results go next to the input JSON.
    output_dir = ensure_dir(
        Path(args.output_dir)
        if args.output_dir is not None
        else input_json_path.parent / f'{input_json_path.stem}_heading_debug_analysis'
    )
    plots_dir = ensure_dir(output_dir / 'figures')
    bad_track_dir = ensure_dir(output_dir / 'fig_track_yaw_series')

    rows, input_stats = build_flat_rows(tracking_data, default_source)
    rows = filter_rows(rows, class_ids=args.class_ids, sources=args.sources)
    if not rows:
        raise ValueError('No detections with heading_debug remain after filtering.')

    track_rows_by_id = group_rows_by_track(rows, args.min_track_len)
    transitions = build_transition_records(track_rows_by_id, args)
    track_summaries = build_track_summaries(track_rows_by_id, transitions, args)
    class_summary = compute_group_summary(transitions, 'class_id')
    source_summary = compute_group_summary(transitions, 'source')
    depth_summary = compute_depth_summary(transitions, DEFAULT_DEPTH_BUCKETS)
    issue_counts = compute_issue_counts(transitions)

    # A negative --top-k-bad-tracks is clamped to "no tracks".
    bad_tracks = track_summaries[: max(0, args.top_k_bad_tracks)]

    # CSV exports, in the order they are written and listed.
    csv_exports = [
        ('heading_debug_flat.csv', rows),
        ('heading_debug_transition_flat.csv', transitions),
        ('heading_debug_track_summary.csv', track_summaries),
        ('heading_debug_bad_tracks.csv', bad_tracks),
        ('heading_debug_class_summary.csv', class_summary),
        ('heading_debug_source_summary.csv', source_summary),
        ('heading_debug_depth_summary.csv', depth_summary),
    ]
    summary_json_path = output_dir / 'heading_debug_summary.json'
    report_md_path = output_dir / 'heading_debug_analysis_report.md'

    output_files = []
    for csv_name, records in csv_exports:
        write_csv(records, output_dir / csv_name)
        output_files.append(csv_name)
    output_files.append(summary_json_path.name)
    output_files.append(report_md_path.name)

    if not args.no_plots:
        # Each job receives its target path; a figure is listed only when the
        # plotting function returns a truthy value.
        figure_jobs = [
            ('fig_margin_vs_dyaw.png', lambda path: plot_margin_vs_dyaw(transitions, path, args)),
            ('fig_bin_transition_matrix.png', lambda path: plot_bin_transition_matrix(transitions, path)),
            ('fig_stability_by_source.png', lambda path: plot_bucket_summary(source_summary, 'source', path, 'Heading Stability by Source')),
            ('fig_stability_by_class.png', lambda path: plot_bucket_summary(class_summary, 'class_id', path, 'Heading Stability by Class')),
            ('fig_stability_by_depth.png', lambda path: plot_bucket_summary(depth_summary, 'depth_bucket', path, 'Heading Stability by Depth')),
            ('fig_issue_distribution.png', lambda path: plot_issue_distribution(issue_counts, path)),
        ]
        for figure_name, render in figure_jobs:
            if render(plots_dir / figure_name):
                output_files.append(f'figures/{figure_name}')

        for summary_row in bad_tracks:
            track_id = summary_row['track_id']
            track_rows = track_rows_by_id.get(track_id)
            if track_rows:
                plot_path = bad_track_dir / f'track_{track_id}.png'
                plot_bad_track_series(summary_row, track_rows, plot_path, args)
                output_files.append(f'fig_track_yaw_series/{plot_path.name}')

    summary = build_summary_dict(
        input_json_path=input_json_path,
        input_stats=input_stats,
        rows=rows,
        transitions=transitions,
        track_summaries=track_summaries,
        class_summary=class_summary,
        source_summary=source_summary,
        depth_summary=depth_summary,
        issue_counts=issue_counts,
        output_files=output_files,
    )
    write_json(summary, summary_json_path)
    report_md_path.write_text(build_report_markdown(summary, output_dir), encoding='utf-8')

    print(f'Input JSON: {input_json_path}')
    print(f'Output directory: {output_dir}')
    print(f'Filtered detections analyzed: {len(rows)}')
    print(f'Tracks analyzed: {len(track_summaries)}')
    print(f'Transitions analyzed: {len(transitions)}')
    print('Generated files:')
    for generated in output_files:
        print(f' - {generated}')
# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()