Files
yolov26_3d/tools/temporal_analysis/visualize_tracking_boxes.py
2026-06-24 09:35:46 +08:00

686 lines
24 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Visualize tracking results by drawing 2D boxes on images.
Usage:
# 从图像文件夹可视化
python visualize_tracking_boxes.py \
--tracking runs/val_viz/exp/tracking.json \
--images path/to/images \
--output runs/val_viz/exp/tracked_images
# 生成视频
python visualize_tracking_boxes.py \
--tracking runs/val_viz/exp/tracking.json \
--images path/to/images \
--output runs/val_viz/exp/tracked.mp4 \
--save-video
# 只可视化前100帧
python visualize_tracking_boxes.py \
--tracking tracking.json \
--images images/ \
--output output/ \
--max-frames 100
# 按 frameId 对齐原始 camera4.bin并导出单帧图片
python visualize_tracking_boxes.py \
--tracking merge.json \
--images /path/to/case/sigmastar.1/camera4.bin \
--output tracking_vis_raw \
--align-by-frame-id
"""
import argparse
import json
from pathlib import Path
import re
import sys
import cv2
import numpy as np
from tqdm import tqdm
import colorsys
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from tools.model_inference.adapters.video_dir_inference_utils import (
iter_video_case_frames,
read_video_frame_index,
)
def generate_colors(n):
    """Return ``n`` well-separated colors as 3-tuples of integers.

    Hues are spaced evenly around the HSV wheel (``i / n``) while saturation
    and value are both fixed at 0.9. Each converted channel is scaled by 255
    and truncated with ``int``.
    """
    saturation = value = 0.9
    return [
        tuple(
            int(channel * 255)
            for channel in colorsys.hsv_to_rgb(step / n, saturation, value)
        )
        for step in range(n)
    ]
def get_color_for_track(track_id, color_map):
    """Return the persistent color for ``track_id``, creating it on first use.

    The color is derived deterministically from the track id, so the same id
    gets the same color across runs. A private ``RandomState`` is used instead
    of ``np.random.seed`` so the process-wide NumPy random stream is left
    untouched; for ids in ``[0, 2**32)`` the resulting color is identical to
    what seeding the global generator with the id would produce.

    Args:
        track_id: Hashable track identifier. Integer-like ids are reduced
            modulo 2**32 to form the seed; anything else is hashed via CRC32
            of its string form.
        color_map: Dict cache mapping track id -> color; updated in place.

    Returns:
        tuple: Three plain ``int`` channel values, each in ``[50, 255)``.
    """
    if track_id not in color_map:
        # Local import: only needed for the non-integer fallback below.
        import zlib

        try:
            # Legacy NumPy seeds must lie in [0, 2**32); wrap negative/huge ids.
            seed = int(track_id) % (2 ** 32)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric ids (e.g. strings): stable hash, unlike built-in hash().
            seed = zlib.crc32(str(track_id).encode("utf-8"))
        rng = np.random.RandomState(seed)
        color_map[track_id] = tuple(int(rng.randint(50, 255)) for _ in range(3))
    return color_map[track_id]
def should_visualize_detection(det, class_id_filter=None, track_ids_filter=None):
    """Decide whether a detection passes the active class/track filters.

    Args:
        det: Detection dict; ``class_id`` and ``track_id`` are read with ``get``.
        class_id_filter: Required class id, or None to accept every class.
        track_ids_filter: Container of accepted track ids, or None to accept
            every track.

    Returns:
        bool: True when the detection should be drawn.
    """
    # Class filter: a mismatch rejects immediately.
    if class_id_filter is not None:
        if det.get('class_id') != class_id_filter:
            return False

    # Track filter: with no filter everything that reached here is accepted.
    if track_ids_filter is None:
        return True
    return det.get('track_id') in track_ids_filter
def safe_int(value):
    """Best-effort integer conversion.

    Args:
        value: Any object; typically an int, float, or numeric string.

    Returns:
        ``int(value)`` when the conversion succeeds, otherwise None. None,
        non-numeric strings, unsupported types, NaN and infinities all
        yield None instead of raising.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float('inf')); ValueError also covers NaN.
        return None
def extract_frame_id_from_image_name(image_name):
    """Best-effort frame-id extraction from a tracking ``image_name``.

    Only the file stem is inspected. The patterns are tried from most to
    least specific and the first capture group of the first match wins:

    1. ``..._<id>_<number>`` (optionally followed by ``_merged``)
    2. ``..._<id>`` (optionally followed by ``_merged``)
    3. the first run of digits anywhere in the stem

    Returns:
        The frame id as an int, or None when the stem contains no digits.
    """
    stem = Path(str(image_name or "")).stem
    patterns = (
        r"_(\d+)_(\d+)(?:_merged)?$",
        r"_(\d+)(?:_merged)?$",
        r"(\d+)",
    )
    for pattern in patterns:
        hit = re.search(pattern, stem)
        if hit:
            return int(hit.group(1))
    return None
def extract_tracking_frame_id(frame_data, fallback_idx=None):
    """Resolve one frame's frameId from detections, frame_info, or image_name.

    Sources are consulted in this order and the first integer-convertible
    value is returned:

    1. ``frameId`` / ``frame_id`` of each detection, in list order
    2. ``original_frame_id`` / ``frame_id`` / ``frameId`` of ``frame_info``
       (only when ``frame_info`` is a dict)
    3. digits parsed out of ``image_name``

    Args:
        frame_data: Per-frame tracking dict.
        fallback_idx: Value returned when none of the sources yields an id.

    Returns:
        The resolved frame id, or ``fallback_idx``.
    """
    # 1) ids carried by the detections themselves
    for detection in frame_data.get("detections", []):
        for candidate in (safe_int(detection.get(key)) for key in ("frameId", "frame_id")):
            if candidate is not None:
                return candidate

    # 2) ids stored in the frame-level metadata
    frame_info = frame_data.get("frame_info")
    if isinstance(frame_info, dict):
        for key in ("original_frame_id", "frame_id", "frameId"):
            candidate = safe_int(frame_info.get(key))
            if candidate is not None:
                return candidate

    # 3) digits embedded in the image file name, else the caller's fallback
    from_name = extract_frame_id_from_image_name(frame_data.get("image_name"))
    return fallback_idx if from_name is None else from_name
def resolve_detection_class_name(det, class_id, class_names):
    """Pick the display name for a detection.

    A non-blank ``type_name`` on the detection wins (stringified and
    stripped); otherwise the static ``class_names`` mapping is consulted,
    with ``Class<class_id>`` as the last resort.
    """
    provided = det.get("type_name")
    cleaned = str(provided or "").strip()
    if not cleaned:
        return class_names.get(class_id, f"Class{class_id}")
    return cleaned
def build_output_frame_name(frame_data, fallback_idx):
    """Build a stable ``.jpg`` output filename for one visualized frame.

    The stem of ``frame_data['image_name']`` is reused (directories and the
    original extension are dropped). When the name is missing, None, empty,
    or has an empty stem, ``frame_<fallback_idx:06d>`` is used instead so
    that such frames do not all collapse onto the same file name.

    Args:
        frame_data: Per-frame tracking dict.
        fallback_idx: Integer index used to synthesize a name when needed.

    Returns:
        str: File name ending in ``.jpg``.
    """
    frame_name = frame_data.get("image_name") or f"frame_{fallback_idx:06d}"
    stem = Path(str(frame_name)).stem or f"frame_{fallback_idx:06d}"
    return f"{stem}.jpg"
def draw_box_with_label(img, bbox, track_id, class_id, confidence, color, class_names, class_name_override=None):
    """Draw one bounding box and its ``ID / class / confidence`` label on ``img``.

    The label is a filled band in the box color with white text. It normally
    sits directly above the box; when there is not enough room above (box
    near the top of the image) the band is placed just inside the box top so
    the text always has a background.

    Args:
        img: Image to draw on (modified in place).
        bbox: Four values ``x1, y1, x2, y2``; each is converted with ``int``.
        track_id: Track identifier shown in the label.
        class_id: Key into ``class_names``.
        confidence: Numeric score, formatted with two decimals.
        color: Color passed to the OpenCV drawing calls.
        class_names: Mapping of class id -> display name.
        class_name_override: If truthy, used instead of the ``class_names`` lookup.

    Returns:
        The same ``img`` object.
    """
    x1, y1, x2, y2 = map(int, bbox)

    # Bounding box outline
    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

    # Label text
    class_name = class_name_override or class_names.get(class_id, f'Class{class_id}')
    label = f'ID:{track_id} {class_name} {confidence:.2f}'

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    font_thickness = 2
    (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, font_thickness)

    # Label band: above the box when it fits, otherwise inside the box top.
    band_height = text_height + 10
    label_y1 = y1 - band_height
    if label_y1 < 0:
        label_y1 = max(y1, 0)
    label_y2 = label_y1 + band_height
    cv2.rectangle(img, (x1, label_y1), (x1 + text_width + 10, label_y2), color, -1)

    # Text baseline sits 5 px above the bottom of the band.
    cv2.putText(img, label, (x1 + 5, label_y1 + text_height + 5),
                font, font_scale, (255, 255, 255), font_thickness)
    return img
def render_tracking_frame(
    img,
    frame_data,
    frame_idx,
    total_frames,
    color_map,
    class_names,
    trajectory_history=None,
    class_id_filter=None,
    track_ids_filter=None,
):
    """Render one tracking frame on top of an image.

    Detections without a ``track_id``, rejected by the filters, or without a
    4-value ``bbox`` are skipped. ``class_id`` and ``confidence`` are
    optional: a missing class falls back to the ``type_name`` / ``Class<id>``
    naming, and a missing confidence is displayed as 0.00.

    Args:
        img: Image to draw on (modified in place).
        frame_data: Per-frame tracking dict with a ``detections`` list.
        frame_idx: Zero-based frame index, shown one-based in the overlay.
        total_frames: Total frame count shown in the overlay.
        color_map: Dict cache of track id -> color, updated in place.
        class_names: Mapping of class id -> display name.
        trajectory_history: Optional dict of track id -> list of box centers;
            when given, centers are appended and the track polyline is drawn.
        class_id_filter: Optional class id to restrict drawing to.
        track_ids_filter: Optional container of track ids to restrict drawing to.

    Returns:
        tuple: ``(img, visualized_count)`` where ``visualized_count`` is the
        number of detections actually drawn.
    """
    detections = frame_data.get('detections', [])
    visualized_count = 0
    for det in detections:
        track_id = det.get('track_id')
        if track_id is None:
            continue
        if not should_visualize_detection(det, class_id_filter, track_ids_filter):
            continue
        bbox = det.get('bbox')
        if bbox is None or len(bbox) != 4:
            # Nothing drawable for this detection.
            continue
        visualized_count += 1

        class_id = det.get('class_id')
        confidence = det.get('confidence')
        if confidence is None:
            confidence = 0.0
        color = get_color_for_track(track_id, color_map)
        class_name = resolve_detection_class_name(det, class_id, class_names)
        img = draw_box_with_label(
            img,
            bbox,
            track_id,
            class_id,
            confidence,
            color,
            class_names,
            class_name_override=class_name,
        )

        if trajectory_history is not None:
            center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
            trail = trajectory_history.setdefault(track_id, [])
            trail.append(center)
            if len(trail) > 1:
                # One open polyline through all recorded centers.
                points = np.array(trail, dtype=np.int32).reshape(-1, 1, 2)
                cv2.polylines(img, [points], False, color, 2)

    if class_id_filter is not None or track_ids_filter is not None:
        info_text = f"Frame: {frame_idx + 1}/{total_frames} | Showing: {visualized_count}/{len(detections)}"
    else:
        info_text = f"Frame: {frame_idx + 1}/{total_frames} | Tracks: {len(detections)}"
    cv2.putText(img, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                0.8, (0, 255, 0), 2)
    return img, visualized_count
def visualize_tracking_on_raw_video_by_frame_id(
    tracking_data,
    video_path,
    output_path,
    max_frames=None,
    show_trajectory=False,
    class_id_filter=None,
    track_ids_filter=None,
):
    """Visualize tracking results on raw camera4.bin frames using frameId alignment.

    Each tracking frame is keyed by its resolved frameId. The raw video is
    then iterated once; every video frame whose id has a tracking entry is
    rendered and written as a JPEG into ``output_path``. Iteration stops as
    soon as every tracking frame has been matched.

    Args:
        tracking_data: List of per-frame tracking dicts.
        video_path: Path to the raw video file.
        output_path: Directory for the rendered images (created if missing).
        max_frames: If truthy, only the first ``max_frames`` tracking frames
            are considered.
        show_trajectory: Draw per-track center trajectories when True.
        class_id_filter: Optional class id to restrict drawing to.
        track_ids_filter: Optional container of track ids to restrict drawing to.
    """
    video_path = Path(video_path)
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    class_names = {
        0: 'Car',
        1: 'Pedestrian',
        2: 'Tricycle',
        3: 'Cyclist'
    }
    color_map = {}
    trajectory_history = {} if show_trajectory else None
    selected_tracking_data = tracking_data[:max_frames] if max_frames else tracking_data

    # Index tracking frames by frameId, keeping first occurrences only.
    frame_id_to_tracking = {}
    ordered_frame_ids = []
    skipped_duplicates = []
    unresolved_frames = []
    for idx, frame_data in enumerate(selected_tracking_data):
        # NOTE(review): with fallback_idx=idx the result is never None, so
        # frames without any id are aligned by their list index — confirm
        # this is intended.
        frame_id = extract_tracking_frame_id(frame_data, fallback_idx=idx)
        if frame_id is None:
            unresolved_frames.append(frame_data.get("image_name", f"frame_{idx:06d}"))
            continue
        if frame_id in frame_id_to_tracking:
            skipped_duplicates.append(frame_id)
            continue
        frame_id_to_tracking[frame_id] = (idx, frame_data)
        ordered_frame_ids.append(frame_id)

    if not frame_id_to_tracking:
        print("Error: no frameIds could be resolved from tracking data")
        return

    frame_index_payload = read_video_frame_index(video_path)
    total_targets = len(ordered_frame_ids)
    matched_frame_ids = set()
    failed_writes = []

    print(f"Reading raw video with frameId alignment: {video_path}")
    print(f"Saving aligned images to: {output_path}")
    if unresolved_frames:
        print(f"Warning: skipped {len(unresolved_frames)} frame(s) without resolvable frameId")
    if skipped_duplicates:
        print(f"Warning: skipped duplicate frameId(s): {sorted(set(skipped_duplicates))[:10]}")

    progress = tqdm(total=total_targets, desc="Visualizing")
    try:
        for read_frame_index, frame_bgr, _, frame_info in iter_video_case_frames(
            video_path,
            frame_index_payload=frame_index_payload,
            frame_stride=1,
            max_frames=0,
        ):
            # Prefer the id reported with the frame; fall back to the read index.
            aligned_frame_id = safe_int(None if frame_info is None else frame_info.get("frame_id"))
            if aligned_frame_id is None:
                aligned_frame_id = read_frame_index
            if aligned_frame_id not in frame_id_to_tracking:
                continue

            tracking_idx, frame_data = frame_id_to_tracking[aligned_frame_id]
            rendered, _ = render_tracking_frame(
                frame_bgr.copy(),
                frame_data,
                tracking_idx,
                total_targets,
                color_map,
                class_names,
                trajectory_history=trajectory_history,
                class_id_filter=class_id_filter,
                track_ids_filter=track_ids_filter,
            )
            output_file = output_path / build_output_frame_name(frame_data, tracking_idx)
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(str(output_file), rendered):
                failed_writes.append(output_file.name)
            matched_frame_ids.add(aligned_frame_id)
            progress.update(1)
            if len(matched_frame_ids) >= total_targets:
                break
    finally:
        progress.close()

    missing_frame_ids = [frame_id for frame_id in ordered_frame_ids if frame_id not in matched_frame_ids]
    print("\n✓ Visualization complete!")
    print(f" Images saved to: {output_path}")
    print(f" Matched frames: {len(matched_frame_ids)}/{total_targets}")
    print(f" Total tracks visualized: {len(color_map)}")
    if missing_frame_ids:
        preview = ", ".join(str(frame_id) for frame_id in missing_frame_ids[:10])
        suffix = "..." if len(missing_frame_ids) > 10 else ""
        print(f" Missing frameIds: {preview}{suffix}")
    if failed_writes:
        print(f" Warning: failed to write {len(failed_writes)} image(s), e.g. {failed_writes[0]}")
def visualize_tracking_on_images(tracking_data, image_source, output_path,
                                 save_video=False, fps=25, max_frames=None,
                                 show_trajectory=False, class_id_filter=None,
                                 track_ids_filter=None, align_by_frame_id=False):
    """Draw tracking results onto frames and save them as images or a video.

    Frames are paired with tracking entries by position: the i-th tracking
    entry is drawn on the i-th video frame or the i-th image (sorted by
    path). With ``align_by_frame_id`` the work is delegated to
    ``visualize_tracking_on_raw_video_by_frame_id`` instead.

    Args:
        tracking_data: Tracking results (list of per-frame dicts loaded from
            tracking.json).
        image_source: Image source (folder path or video file path).
        output_path: Output location (folder, or video file with ``save_video``).
        save_video: Write a video instead of individual images.
        fps: Output video frame rate.
        max_frames: Maximum number of frames to process.
        show_trajectory: Whether to draw trajectory lines.
        class_id_filter: Class id to show; None shows every class.
        track_ids_filter: Collection of track ids to show; None shows all.
        align_by_frame_id: Align tracking frames to raw video frames by frameId.
    """
    image_source = Path(image_source)
    output_path = Path(output_path)

    if align_by_frame_id:
        if save_video:
            print("Error: --align-by-frame-id only supports image-directory output")
            return
        if not image_source.is_file():
            print("Error: --align-by-frame-id requires a video file such as camera4.bin")
            return
        visualize_tracking_on_raw_video_by_frame_id(
            tracking_data,
            image_source,
            output_path,
            max_frames=max_frames,
            show_trajectory=show_trajectory,
            class_id_filter=class_id_filter,
            track_ids_filter=track_ids_filter,
        )
        return

    # Class names
    class_names = {
        0: 'Car',
        1: 'Pedestrian',
        2: 'Tricycle',
        3: 'Cyclist'
    }
    # Track id -> color cache
    color_map = {}
    # Trajectory history (used to draw trajectory lines)
    trajectory_history = {} if show_trajectory else None

    # Determine the kind of image source
    VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.bin'}
    cap = None
    image_files = []
    if image_source.is_file() and image_source.suffix.lower() in VIDEO_EXTENSIONS:
        # Read from a video
        cap = cv2.VideoCapture(str(image_source))
        if not cap.isOpened():
            print(f"Error: Cannot open video: {image_source}")
            cap.release()
            return
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Reading from video: {image_source}")
        print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames")
    elif image_source.is_dir():
        # Read from an image folder
        image_files = sorted(list(image_source.glob('*.jpg')) +
                             list(image_source.glob('*.png')) +
                             list(image_source.glob('*.jpeg')))
        if not image_files:
            print(f"Error: No images found in {image_source}")
            return
        # Read the first image to learn the frame size
        first_img = cv2.imread(str(image_files[0]))
        if first_img is None:
            print(f"Error: Cannot read image {image_files[0]}")
            return
        frame_height, frame_width = first_img.shape[:2]
        total_frames = len(image_files)
        print(f"Reading from image folder: {image_source}")
        print(f"Found {total_frames} images, size: {frame_width}x{frame_height}")
    else:
        print(f"Error: Invalid image source: {image_source}")
        return

    video_writer = None
    try:
        # Limit the number of processed frames. Some containers report a
        # non-positive frame count; in that case only max_frames applies and
        # the read loop below stops at the real end of the stream.
        if max_frames:
            total_frames = min(total_frames, max_frames) if total_frames > 0 else max_frames
            tracking_data = tracking_data[:total_frames]

        # Prepare the output
        if save_video:
            if not str(output_path).endswith(('.mp4', '.avi')):
                output_path = output_path.with_suffix('.mp4')
            output_path.parent.mkdir(parents=True, exist_ok=True)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))
            if not video_writer.isOpened():
                print(f"Error: Cannot open video writer for {output_path}")
                return
            print(f"Saving video to: {output_path}")
        else:
            output_path.mkdir(parents=True, exist_ok=True)
            print(f"Saving images to: {output_path}")

        # Process every frame
        print(f"\nProcessing {len(tracking_data)} frames...")
        for frame_idx, frame_data in enumerate(tqdm(tracking_data, desc="Visualizing")):
            # Read the image
            if cap is not None:
                ret, img = cap.read()
                if not ret:
                    print(f"Warning: Cannot read frame {frame_idx}")
                    break
            else:
                if frame_idx >= len(image_files):
                    break
                img = cv2.imread(str(image_files[frame_idx]))
                if img is None:
                    print(f"Warning: Cannot read image {image_files[frame_idx]}")
                    continue

            img, _ = render_tracking_frame(
                img,
                frame_data,
                frame_idx,
                len(tracking_data),
                color_map,
                class_names,
                trajectory_history=trajectory_history,
                class_id_filter=class_id_filter,
                track_ids_filter=track_ids_filter,
            )

            # Append to the video or save as an image
            if video_writer is not None:
                video_writer.write(img)
            else:
                # A missing/None/empty image_name falls back to an index-based name.
                frame_name = str(frame_data.get('image_name') or f'frame_{frame_idx:06d}.jpg')
                if not frame_name.endswith(('.jpg', '.png', '.jpeg')):
                    frame_name = f'{Path(frame_name).stem}.jpg'
                output_file = output_path / frame_name.replace('.png', '.jpg')
                if not cv2.imwrite(str(output_file), img):
                    print(f"Warning: Cannot write image {output_file}")
    finally:
        # Release OpenCV handles even when rendering fails midway.
        if cap is not None:
            cap.release()
        if video_writer is not None:
            video_writer.release()

    print("\n✓ Visualization complete!")
    if save_video:
        print(f" Video saved to: {output_path}")
    else:
        print(f" Images saved to: {output_path}")
    print(f" Total tracks visualized: {len(color_map)}")
def create_side_by_side_comparison(tracking_file1, tracking_file2, image_source,
                                   output_path, labels=None, max_frames=100):
    """Create a side-by-side comparison video of two tracking results.

    Each source frame is duplicated; the left copy gets the detections of
    the first tracking file and the right copy those of the second. Both
    copies are titled and concatenated horizontally. A video source is
    opened once and read sequentially.

    Args:
        tracking_file1, tracking_file2: Paths to the two tracking JSON files.
        image_source: Image folder or video file providing the frames.
        output_path: Output video path.
        labels: Two titles, one per method (default ``Method 1`` / ``Method 2``).
        max_frames: Maximum number of frames to process; None means no limit.
    """
    if labels is None:
        labels = ['Method 1', 'Method 2']

    # Load tracking data
    with open(tracking_file1, 'r', encoding='utf-8') as f:
        data1 = json.load(f)
    with open(tracking_file2, 'r', encoding='utf-8') as f:
        data2 = json.load(f)

    image_source = Path(image_source)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    class_names = {0: 'Car', 1: 'Pedestrian', 2: 'Tricycle', 3: 'Cyclist'}
    color_map1, color_map2 = {}, {}

    # Open the frame source and learn the frame size
    cap = None
    image_files = []
    if image_source.is_dir():
        image_files = sorted(list(image_source.glob('*.jpg')) +
                             list(image_source.glob('*.png')) +
                             list(image_source.glob('*.jpeg')))
        if not image_files:
            print(f"Error: No images found in {image_source}")
            return
        first_img = cv2.imread(str(image_files[0]))
        if first_img is None:
            print(f"Error: Cannot read image {image_files[0]}")
            return
        h, w = first_img.shape[:2]
    else:
        cap = cv2.VideoCapture(str(image_source))
        if not cap.isOpened():
            print(f"Error: Cannot open video: {image_source}")
            cap.release()
            return
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    frames_to_process = min(len(data1), len(data2))
    if max_frames is not None:
        frames_to_process = min(frames_to_process, max_frames)
    if cap is None:
        # Never index past the available images.
        frames_to_process = min(frames_to_process, len(image_files))

    video_writer = None
    try:
        # Video writer (double width for the side-by-side layout)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(str(output_path), fourcc, 25, (w * 2, h))
        if not video_writer.isOpened():
            print(f"Error: Cannot open video writer for {output_path}")
            return

        print("Creating side-by-side comparison video...")
        for frame_idx in tqdm(range(frames_to_process), desc="Processing"):
            # Read the source frame
            if cap is None:
                img = cv2.imread(str(image_files[frame_idx]))
                if img is None:
                    print(f"Warning: Cannot read image {image_files[frame_idx]}")
                    continue
            else:
                ret, img = cap.read()
                if not ret:
                    break

            panels = []
            for data, color_map, title in (
                (data1, color_map1, labels[0]),
                (data2, color_map2, labels[1]),
            ):
                panel = img.copy()
                for det in data[frame_idx].get('detections', []):
                    track_id = det.get('track_id')
                    if track_id is None:
                        continue
                    color = get_color_for_track(track_id, color_map)
                    panel = draw_box_with_label(panel, det['bbox'], track_id,
                                                det['class_id'], det['confidence'],
                                                color, class_names)
                # Method title
                cv2.putText(panel, title, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                panels.append(panel)

            # Concatenate left and right panels
            video_writer.write(np.hstack(panels))
    finally:
        if cap is not None:
            cap.release()
        if video_writer is not None:
            video_writer.release()

    print(f"✓ Comparison video saved to: {output_path}")
def main():
    """Command-line entry point.

    Parses the arguments, validates that the tracking file and image source
    exist, loads the tracking JSON, and then runs either the side-by-side
    comparison (``--compare``) or the single-result visualization. Problems
    with the inputs are reported with an ``Error:`` message and the function
    returns without doing any work.
    """
    parser = argparse.ArgumentParser(description='Visualize tracking results on images')
    parser.add_argument('--tracking', type=str, required=True, help='Path to tracking.json file')
    parser.add_argument('--images', type=str, required=True, help='Path to images folder or video file')
    parser.add_argument('--output', type=str, required=True, help='Output path (folder or video file)')
    parser.add_argument('--save-video', action='store_true', help='Save as video instead of images')
    parser.add_argument('--fps', type=int, default=25, help='Video FPS (default: 25)')
    parser.add_argument('--max-frames', type=int, help='Maximum number of frames to process')
    parser.add_argument('--show-trajectory', action='store_true', help='Show trajectory lines')
    parser.add_argument('--compare', type=str, help='Second tracking.json for side-by-side comparison')
    parser.add_argument('--labels', type=str, nargs=2, default=['Method 1', 'Method 2'],
                        help='Labels for comparison mode')
    # The ids listed here mirror the class-name table used by the renderers.
    parser.add_argument('--class-id', type=int, default=None,
                        help='Filter by class ID (0=Car, 1=Pedestrian, 2=Tricycle, 3=Cyclist). '
                             'If not specified, all classes are shown.')
    parser.add_argument('--track-ids', type=int, nargs='+', default=None,
                        help='Specific track IDs to visualize (e.g., --track-ids 1 2 3). '
                             'If not specified, all tracks are shown.')
    parser.add_argument('--align-by-frame-id', action='store_true',
                        help='Align tracking frames to the raw input video by frameId and export images only.')
    args = parser.parse_args()

    # Validate the inputs
    tracking_path = Path(args.tracking)
    if not tracking_path.exists():
        print(f"Error: Tracking file not found: {tracking_path}")
        return
    image_source = Path(args.images)
    if not image_source.exists():
        print(f"Error: Image source not found: {image_source}")
        return

    # Load the tracking data (JSON is read explicitly as UTF-8)
    print(f"Loading tracking data from {tracking_path}")
    with open(tracking_path, 'r', encoding='utf-8') as f:
        tracking_data = json.load(f)
    print(f"Loaded {len(tracking_data)} frames")

    if args.compare:
        # Comparison mode
        compare_path = Path(args.compare)
        if not compare_path.exists():
            print(f"Error: Comparison file not found: {compare_path}")
            return
        output_path = Path(args.output)
        if not str(output_path).endswith('.mp4'):
            output_path = output_path.with_suffix('.mp4')
        create_side_by_side_comparison(
            tracking_path, compare_path, image_source,
            output_path, args.labels, args.max_frames or 100
        )
    else:
        # Single-file visualization: prepare the filter arguments
        track_ids_set = set(args.track_ids) if args.track_ids else None

        # Report the active filters
        if args.class_id is not None or args.track_ids:
            print("\nFiltering settings:")
            if args.class_id is not None:
                print(f" Class ID: {args.class_id}")
            if args.track_ids:
                print(f" Track IDs: {args.track_ids}")

        visualize_tracking_on_images(
            tracking_data,
            image_source,
            args.output,
            save_video=args.save_video,
            fps=args.fps,
            max_frames=args.max_frames,
            show_trajectory=args.show_trajectory,
            class_id_filter=args.class_id,
            track_ids_filter=track_ids_set,
            align_by_frame_id=args.align_by_frame_id,
        )


if __name__ == '__main__':
    main()