""" Visualize tracking results by drawing 2D boxes on images. Usage: # 从图像文件夹可视化 python visualize_tracking_boxes.py \ --tracking runs/val_viz/exp/tracking.json \ --images path/to/images \ --output runs/val_viz/exp/tracked_images # 生成视频 python visualize_tracking_boxes.py \ --tracking runs/val_viz/exp/tracking.json \ --images path/to/images \ --output runs/val_viz/exp/tracked.mp4 \ --save-video # 只可视化前100帧 python visualize_tracking_boxes.py \ --tracking tracking.json \ --images images/ \ --output output/ \ --max-frames 100 # 按 frameId 对齐原始 camera4.bin,并导出单帧图片 python visualize_tracking_boxes.py \ --tracking merge.json \ --images /path/to/case/sigmastar.1/camera4.bin \ --output tracking_vis_raw \ --align-by-frame-id """ import argparse import json from pathlib import Path import re import sys import cv2 import numpy as np from tqdm import tqdm import colorsys PROJECT_ROOT = Path(__file__).resolve().parents[2] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from tools.model_inference.adapters.video_dir_inference_utils import ( iter_video_case_frames, read_video_frame_index, ) def generate_colors(n): """生成n个区分度高的颜色""" colors = [] for i in range(n): hue = i / n saturation = 0.9 value = 0.9 rgb = colorsys.hsv_to_rgb(hue, saturation, value) colors.append(tuple(int(x * 255) for x in rgb)) return colors def get_color_for_track(track_id, color_map): """为track_id分配一个持久的颜色""" if track_id not in color_map: # 使用track_id作为种子生成颜色 np.random.seed(track_id) color = ( np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255) ) color_map[track_id] = color return color_map[track_id] def should_visualize_detection(det, class_id_filter=None, track_ids_filter=None): """判断检测是否应该被可视化 Args: det: 检测结果字典 class_id_filter: 指定的类别ID,None表示所有类别 track_ids_filter: 指定的track ID列表,None表示所有ID Returns: bool: 是否应该可视化该检测 """ # 检查class_id过滤 if class_id_filter is not None and det.get('class_id') != class_id_filter: return False # 检查track_id过滤 if track_ids_filter is not None: track_id = det.get('track_id') if track_id not in track_ids_filter: return False return True def safe_int(value): """Best-effort integer conversion.""" if value is None: return None try: return int(value) except (TypeError, ValueError): return None def extract_frame_id_from_image_name(image_name): """Best-effort frame-id extraction from tracking image_name.""" stem = Path(str(image_name or "")).stem match = re.search(r"_(\d+)_(\d+)(?:_merged)?$", stem) if match is not None: return int(match.group(1)) match = re.search(r"_(\d+)(?:_merged)?$", stem) if match is not None: return int(match.group(1)) match = re.search(r"(\d+)", stem) if match is not None: return int(match.group(1)) return None def extract_tracking_frame_id(frame_data, fallback_idx=None): """Resolve one frame's frameId from detections, frame_info, or image_name.""" for det in frame_data.get("detections", []): for key in ("frameId", "frame_id"): frame_id = safe_int(det.get(key)) if frame_id is not None: return frame_id frame_info = frame_data.get("frame_info") if isinstance(frame_info, dict): for key in ("original_frame_id", "frame_id", "frameId"): frame_id = safe_int(frame_info.get(key)) if frame_id is not None: return frame_id frame_id = extract_frame_id_from_image_name(frame_data.get("image_name")) if frame_id is not None: return frame_id return fallback_idx def resolve_detection_class_name(det, class_id, class_names): """Prefer detection-provided type_name before falling back to static names.""" type_name = str(det.get("type_name") or "").strip() if type_name: return type_name return class_names.get(class_id, f"Class{class_id}") def build_output_frame_name(frame_data, fallback_idx): """Build a stable output filename for one visualized frame.""" frame_name = frame_data.get("image_name", f"frame_{fallback_idx:06d}") stem = Path(str(frame_name)).stem return f"{stem}.jpg" def draw_box_with_label(img, bbox, track_id, class_id, confidence, color, class_names, class_name_override=None): """在图像上绘制边界框和标签""" x1, y1, x2, y2 = map(int, bbox) # 绘制边界框 thickness = 2 cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness) # 准备标签文本 class_name = class_name_override or class_names.get(class_id, f'Class{class_id}') label = f'ID:{track_id} {class_name} {confidence:.2f}' # 计算标签背景大小 font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.6 font_thickness = 2 (text_width, text_height), baseline = cv2.getTextSize(label, font, font_scale, font_thickness) # 绘制标签背景 label_y1 = max(y1 - text_height - 10, 0) label_y2 = y1 cv2.rectangle(img, (x1, label_y1), (x1 + text_width + 10, label_y2), color, -1) # 绘制标签文本 text_x = x1 + 5 text_y = label_y1 + text_height + 5 cv2.putText(img, label, (text_x, text_y), font, font_scale, (255, 255, 255), font_thickness) return img def render_tracking_frame( img, frame_data, frame_idx, total_frames, color_map, class_names, trajectory_history=None, class_id_filter=None, track_ids_filter=None, ): """Render one tracking frame on top of an image.""" detections = frame_data.get('detections', []) visualized_count = 0 for det in detections: track_id = det.get('track_id') if track_id is None: continue if not should_visualize_detection(det, class_id_filter, track_ids_filter): continue visualized_count += 1 bbox = det['bbox'] class_id = det['class_id'] confidence = det['confidence'] color = get_color_for_track(track_id, color_map) class_name = resolve_detection_class_name(det, class_id, class_names) img = draw_box_with_label( img, bbox, track_id, class_id, confidence, color, class_names, class_name_override=class_name, ) if trajectory_history is not None: center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2) if track_id not in trajectory_history: trajectory_history[track_id] = [] trajectory_history[track_id].append(center) if len(trajectory_history[track_id]) > 1: points = np.array(trajectory_history[track_id], dtype=np.int32) for i in range(len(points) - 1): cv2.line(img, tuple(points[i]), tuple(points[i + 1]), color, 2) if class_id_filter is not None or track_ids_filter is not None: info_text = f"Frame: {frame_idx + 1}/{total_frames} | Showing: {visualized_count}/{len(detections)}" else: info_text = f"Frame: {frame_idx + 1}/{total_frames} | Tracks: {len(detections)}" cv2.putText(img, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2) return img, visualized_count def visualize_tracking_on_raw_video_by_frame_id( tracking_data, video_path, output_path, max_frames=None, show_trajectory=False, class_id_filter=None, track_ids_filter=None, ): """Visualize tracking results on raw camera4.bin frames using frameId alignment.""" video_path = Path(video_path) output_path = Path(output_path) output_path.mkdir(parents=True, exist_ok=True) class_names = { 0: 'Car', 1: 'Pedestrian', 2: 'Tricycle', 3: 'Cyclist' } color_map = {} trajectory_history = {} if show_trajectory else None selected_tracking_data = tracking_data[:max_frames] if max_frames else tracking_data frame_id_to_tracking = {} ordered_frame_ids = [] skipped_duplicates = [] unresolved_frames = [] for idx, frame_data in enumerate(selected_tracking_data): frame_id = extract_tracking_frame_id(frame_data, fallback_idx=idx) if frame_id is None: unresolved_frames.append(frame_data.get("image_name", f"frame_{idx:06d}")) continue if frame_id in frame_id_to_tracking: skipped_duplicates.append(frame_id) continue frame_id_to_tracking[frame_id] = (idx, frame_data) ordered_frame_ids.append(frame_id) if not frame_id_to_tracking: print("Error: no frameIds could be resolved from tracking data") return frame_index_payload = read_video_frame_index(video_path) total_targets = len(ordered_frame_ids) matched_frame_ids = set() print(f"Reading raw video with frameId alignment: {video_path}") print(f"Saving aligned images to: {output_path}") if unresolved_frames: print(f"Warning: skipped {len(unresolved_frames)} frame(s) without resolvable frameId") if skipped_duplicates: print(f"Warning: skipped duplicate frameId(s): {sorted(set(skipped_duplicates))[:10]}") progress = tqdm(total=total_targets, desc="Visualizing") for read_frame_index, frame_bgr, _, frame_info in iter_video_case_frames( video_path, frame_index_payload=frame_index_payload, frame_stride=1, max_frames=0, ): aligned_frame_id = safe_int(None if frame_info is None else frame_info.get("frame_id")) if aligned_frame_id is None: aligned_frame_id = read_frame_index if aligned_frame_id not in frame_id_to_tracking: continue tracking_idx, frame_data = frame_id_to_tracking[aligned_frame_id] rendered, _ = render_tracking_frame( frame_bgr.copy(), frame_data, tracking_idx, total_targets, color_map, class_names, trajectory_history=trajectory_history, class_id_filter=class_id_filter, track_ids_filter=track_ids_filter, ) output_file = output_path / build_output_frame_name(frame_data, tracking_idx) cv2.imwrite(str(output_file), rendered) matched_frame_ids.add(aligned_frame_id) progress.update(1) if len(matched_frame_ids) >= total_targets: break progress.close() missing_frame_ids = [frame_id for frame_id in ordered_frame_ids if frame_id not in matched_frame_ids] print(f"\n✓ Visualization complete!") print(f" Images saved to: {output_path}") print(f" Matched frames: {len(matched_frame_ids)}/{total_targets}") print(f" Total tracks visualized: {len(color_map)}") if missing_frame_ids: preview = ", ".join(str(frame_id) for frame_id in missing_frame_ids[:10]) suffix = "..." if len(missing_frame_ids) > 10 else "" print(f" Missing frameIds: {preview}{suffix}") def visualize_tracking_on_images(tracking_data, image_source, output_path, save_video=False, fps=25, max_frames=None, show_trajectory=False, class_id_filter=None, track_ids_filter=None, align_by_frame_id=False): """将跟踪结果可视化到图像上 Args: tracking_data: 跟踪结果数据(从tracking.json加载) image_source: 图像来源(文件夹路径或视频文件路径) output_path: 输出路径(文件夹或视频文件) save_video: 是否保存为视频 fps: 视频帧率 max_frames: 最大处理帧数 show_trajectory: 是否显示轨迹线 class_id_filter: 指定类别ID,None表示所有类别 track_ids_filter: 指定track ID集合,None表示所有ID align_by_frame_id: 是否按 frameId 对齐原始视频帧 """ image_source = Path(image_source) output_path = Path(output_path) if align_by_frame_id: if save_video: print("Error: --align-by-frame-id only supports image-directory output") return if not image_source.is_file(): print("Error: --align-by-frame-id requires a video file such as camera4.bin") return visualize_tracking_on_raw_video_by_frame_id( tracking_data, image_source, output_path, max_frames=max_frames, show_trajectory=show_trajectory, class_id_filter=class_id_filter, track_ids_filter=track_ids_filter, ) return # 类别名称 class_names = { 0: 'Car', 1: 'Pedestrian', 2: 'Tricycle', 3: 'Cyclist' } # 颜色映射 color_map = {} # 轨迹历史(用于绘制轨迹线) trajectory_history = {} if show_trajectory else None # 检查图像源类型 VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.bin'} if image_source.is_file() and image_source.suffix.lower() in VIDEO_EXTENSIONS: # 从视频读取 cap = cv2.VideoCapture(str(image_source)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) print(f"Reading from video: {image_source}") print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames") elif image_source.is_dir(): # 从图像文件夹读取 cap = None image_files = sorted(list(image_source.glob('*.jpg')) + list(image_source.glob('*.png')) + list(image_source.glob('*.jpeg'))) if not image_files: print(f"Error: No images found in {image_source}") return # 读取第一张图像获取尺寸 first_img = cv2.imread(str(image_files[0])) if first_img is None: print(f"Error: Cannot read image {image_files[0]}") return frame_height, frame_width = first_img.shape[:2] total_frames = len(image_files) print(f"Reading from image folder: {image_source}") print(f"Found {total_frames} images, size: {frame_width}x{frame_height}") else: print(f"Error: Invalid image source: {image_source}") return # 限制处理帧数 if max_frames: total_frames = min(total_frames, max_frames) tracking_data = tracking_data[:total_frames] # 准备输出 if save_video: if not str(output_path).endswith(('.mp4', '.avi')): output_path = output_path.with_suffix('.mp4') output_path.parent.mkdir(parents=True, exist_ok=True) fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height)) print(f"Saving video to: {output_path}") else: output_path.mkdir(parents=True, exist_ok=True) video_writer = None print(f"Saving images to: {output_path}") # 处理每一帧 print(f"\nProcessing {len(tracking_data)} frames...") for frame_idx, frame_data in enumerate(tqdm(tracking_data, desc="Visualizing")): # 读取图像 if cap is not None: ret, img = cap.read() if not ret: print(f"Warning: Cannot read frame {frame_idx}") break else: if frame_idx >= len(image_files): break img = cv2.imread(str(image_files[frame_idx])) if img is None: print(f"Warning: Cannot read image {image_files[frame_idx]}") continue img, _ = render_tracking_frame( img, frame_data, frame_idx, len(tracking_data), color_map, class_names, trajectory_history=trajectory_history, class_id_filter=class_id_filter, track_ids_filter=track_ids_filter, ) # 保存或写入视频 if video_writer: video_writer.write(img) else: frame_name = frame_data.get('image_name', f'frame_{frame_idx:06d}.jpg') if not frame_name.endswith(('.jpg', '.png', '.jpeg')): frame_name = f'{Path(frame_name).stem}.jpg' output_file = output_path / frame_name.replace('.png', '.jpg') cv2.imwrite(str(output_file), img) # 清理 if cap is not None: cap.release() if video_writer is not None: video_writer.release() print(f"\n✓ Visualization complete!") if save_video: print(f" Video saved to: {output_path}") else: print(f" Images saved to: {output_path}") print(f" Total tracks visualized: {len(color_map)}") def create_side_by_side_comparison(tracking_file1, tracking_file2, image_source, output_path, labels=None, max_frames=100): """创建并排对比可视化 Args: tracking_file1, tracking_file2: 两个跟踪结果文件 image_source: 图像源 output_path: 输出视频路径 labels: 两个方法的标签 max_frames: 最大处理帧数 """ if labels is None: labels = ['Method 1', 'Method 2'] # 加载跟踪数据 with open(tracking_file1, 'r') as f: data1 = json.load(f) with open(tracking_file2, 'r') as f: data2 = json.load(f) image_source = Path(image_source) output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) class_names = {0: 'Car', 1: 'Pedestrian', 2: 'Tricycle', 3: 'Cyclist'} color_map1, color_map2 = {}, {} # 读取图像 if image_source.is_dir(): image_files = sorted(list(image_source.glob('*.jpg')) + list(image_source.glob('*.png'))) first_img = cv2.imread(str(image_files[0])) h, w = first_img.shape[:2] else: cap = cv2.VideoCapture(str(image_source)) w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) cap.release() # 创建视频写入器(并排宽度加倍) fourcc = cv2.VideoWriter_fourcc(*'mp4v') video_writer = cv2.VideoWriter(str(output_path), fourcc, 25, (w * 2, h)) print(f"Creating side-by-side comparison video...") frames_to_process = min(len(data1), len(data2), max_frames) for frame_idx in tqdm(range(frames_to_process), desc="Processing"): # 读取原始图像 if image_source.is_dir(): img = cv2.imread(str(image_files[frame_idx])) else: cap = cv2.VideoCapture(str(image_source)) cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) ret, img = cap.read() cap.release() if not ret: break img1 = img.copy() img2 = img.copy() # 绘制第一个跟踪结果 for det in data1[frame_idx].get('detections', []): if 'track_id' in det: color = get_color_for_track(det['track_id'], color_map1) img1 = draw_box_with_label(img1, det['bbox'], det['track_id'], det['class_id'], det['confidence'], color, class_names) # 绘制第二个跟踪结果 for det in data2[frame_idx].get('detections', []): if 'track_id' in det: color = get_color_for_track(det['track_id'], color_map2) img2 = draw_box_with_label(img2, det['bbox'], det['track_id'], det['class_id'], det['confidence'], color, class_names) # 添加标签 cv2.putText(img1, labels[0], (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) cv2.putText(img2, labels[1], (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) # 并排拼接 combined = np.hstack([img1, img2]) video_writer.write(combined) video_writer.release() print(f"✓ Comparison video saved to: {output_path}") def main(): parser = argparse.ArgumentParser(description='Visualize tracking results on images') parser.add_argument('--tracking', type=str, required=True, help='Path to tracking.json file') parser.add_argument('--images', type=str, required=True, help='Path to images folder or video file') parser.add_argument('--output', type=str, required=True, help='Output path (folder or video file)') parser.add_argument('--save-video', action='store_true', help='Save as video instead of images') parser.add_argument('--fps', type=int, default=25, help='Video FPS (default: 25)') parser.add_argument('--max-frames', type=int, help='Maximum number of frames to process') parser.add_argument('--show-trajectory', action='store_true', help='Show trajectory lines') parser.add_argument('--compare', type=str, help='Second tracking.json for side-by-side comparison') parser.add_argument('--labels', type=str, nargs=2, default=['Method 1', 'Method 2'], help='Labels for comparison mode') parser.add_argument('--class-id', type=int, default=None, help='Filter by class ID (0=Car, 1=Pedestrian, 2=Cyclist, etc.). If not specified, all classes are shown.') parser.add_argument('--track-ids', type=int, nargs='+', default=None, help='Specific track IDs to visualize (e.g., --track-ids 1 2 3). If not specified, all tracks are shown.') parser.add_argument('--align-by-frame-id', action='store_true', help='Align tracking frames to the raw input video by frameId and export images only.') args = parser.parse_args() # 检查输入文件 tracking_path = Path(args.tracking) if not tracking_path.exists(): print(f"Error: Tracking file not found: {tracking_path}") return image_source = Path(args.images) if not image_source.exists(): print(f"Error: Image source not found: {image_source}") return # 加载跟踪数据 print(f"Loading tracking data from {tracking_path}") with open(tracking_path, 'r') as f: tracking_data = json.load(f) print(f"Loaded {len(tracking_data)} frames") # 对比模式 if args.compare: compare_path = Path(args.compare) if not compare_path.exists(): print(f"Error: Comparison file not found: {compare_path}") return output_path = Path(args.output) if not str(output_path).endswith('.mp4'): output_path = output_path.with_suffix('.mp4') create_side_by_side_comparison( tracking_path, compare_path, image_source, output_path, args.labels, args.max_frames or 100 ) else: # 单文件可视化 # 准备过滤参数 track_ids_set = set(args.track_ids) if args.track_ids else None # 打印过滤信息 if args.class_id is not None or args.track_ids: print(f"\nFiltering settings:") if args.class_id is not None: print(f" Class ID: {args.class_id}") if args.track_ids: print(f" Track IDs: {args.track_ids}") visualize_tracking_on_images( tracking_data, image_source, args.output, save_video=args.save_video, fps=args.fps, max_frames=args.max_frames, show_trajectory=args.show_trajectory, class_id_filter=args.class_id, track_ids_filter=track_ids_set, align_by_frame_id=args.align_by_frame_id, ) if __name__ == '__main__': main()