# yolov26_3d/tools/temporal_analysis/visualize_tracking_boxes.py
# (repository viewer metadata: 686 lines, 24 KiB, Python; last modified 2026-06-24 09:35:46 +08:00)
r"""
Visualize tracking results by drawing 2D boxes on images.

Usage:
    # Visualize onto a folder of images
    python visualize_tracking_boxes.py \
        --tracking runs/val_viz/exp/tracking.json \
        --images path/to/images \
        --output runs/val_viz/exp/tracked_images

    # Render a video instead of individual images
    python visualize_tracking_boxes.py \
        --tracking runs/val_viz/exp/tracking.json \
        --images path/to/images \
        --output runs/val_viz/exp/tracked.mp4 \
        --save-video

    # Visualize only the first 100 frames
    python visualize_tracking_boxes.py \
        --tracking tracking.json \
        --images images/ \
        --output output/ \
        --max-frames 100

    # Align with the raw camera4.bin by frameId and export one image per frame
    python visualize_tracking_boxes.py \
        --tracking merge.json \
        --images /path/to/case/sigmastar.1/camera4.bin \
        --output tracking_vis_raw \
        --align-by-frame-id
"""
import argparse
import json
from pathlib import Path
import re
import sys
import cv2
import numpy as np
from tqdm import tqdm
import colorsys
# Repository root: parents[2] of this file, i.e. the directory that contains
# `tools/`. It is put at the front of sys.path so the
# `tools.model_inference...` import below resolves when this file is run
# directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from tools.model_inference.adapters.video_dir_inference_utils import (
iter_video_case_frames,
read_video_frame_index,
)
def generate_colors(n):
    """Return a list of ``n`` well-separated colour tuples.

    Hues are spaced evenly around the HSV wheel (hue ``k / n`` for
    ``k = 0..n-1``) at a fixed saturation and value of 0.9. Each channel is
    scaled to 0-255 and truncated to ``int``.

    Note: ``colorsys`` yields RGB order, whereas OpenCV drawing functions
    interpret colour tuples as BGR.
    """
    return [
        tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(k / n, 0.9, 0.9))
        for k in range(n)
    ]
def get_color_for_track(track_id, color_map):
    """Return a persistent colour for ``track_id``, caching it in ``color_map``.

    The colour is derived deterministically from the track id, so a track keeps
    the same colour in every frame and across runs. Each channel is drawn from
    ``[50, 255)`` so boxes are never too dark to see.

    Args:
        track_id: Track identifier; also the cache key. Integer-convertible ids
            seed the RNG with ``int(track_id) % 2**32``, which accepts negative
            or very large ids. Other values are seeded with the CRC32 of
            ``str(track_id)``.
        color_map: Dict mapping ``track_id -> colour``; updated in place.

    Returns:
        tuple[int, int, int]: the colour tuple (used by OpenCV as BGR).
    """
    if track_id not in color_map:
        import zlib

        try:
            seed = int(track_id) % (2 ** 32)
        except (TypeError, ValueError, OverflowError):
            seed = zlib.crc32(str(track_id).encode("utf-8"))
        # Use a private RandomState rather than np.random.seed(): seeding the
        # global NumPy RNG would silently reset random state for every other
        # user of np.random in the process. RandomState(seed).randint produces
        # the same sequence the global-seed version did, so colours for
        # existing integer track ids are unchanged.
        rng = np.random.RandomState(seed)
        color_map[track_id] = tuple(int(rng.randint(50, 255)) for _ in range(3))
    return color_map[track_id]
def should_visualize_detection(det, class_id_filter=None, track_ids_filter=None):
    """Decide whether one detection passes the active filters.

    Args:
        det: Detection dict. ``class_id`` and ``track_id`` are read with
            ``.get``, so missing keys compare as ``None``.
        class_id_filter: Class id to keep, or ``None`` to keep every class.
        track_ids_filter: Container of track ids to keep, or ``None`` to keep
            every track.

    Returns:
        bool: ``True`` when the detection should be drawn.
    """
    # The class check short-circuits, so the track lookup only happens for
    # detections of the wanted class.
    if class_id_filter is not None and det.get('class_id') != class_id_filter:
        return False
    return track_ids_filter is None or det.get('track_id') in track_ids_filter
def safe_int(value):
    """Best-effort integer conversion.

    Returns ``int(value)`` when that succeeds, otherwise ``None``. ``None``
    maps to ``None``, and floats are truncated toward zero as ``int()`` does.
    Non-finite floats (``inf``/``nan``, which ``json`` accepts as
    ``Infinity``/``NaN``) also yield ``None`` instead of raising.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # int(float('inf')) raises OverflowError, not ValueError.
        return None
def extract_frame_id_from_image_name(image_name):
    """Best-effort frameId extraction from a tracking ``image_name``.

    The file stem is tried against three patterns, most specific first; the
    first match wins:

    1. ``..._<frame>_<n>`` (optionally followed by ``_merged``): returns ``<frame>``;
    2. ``..._<frame>`` (optionally followed by ``_merged``): returns ``<frame>``;
    3. the first run of digits anywhere in the stem.

    Returns ``None`` for an empty/``None`` name or a stem without digits.
    """
    stem = Path(str(image_name or "")).stem
    patterns = (
        r"_(\d+)_(\d+)(?:_merged)?$",
        r"_(\d+)(?:_merged)?$",
        r"(\d+)",
    )
    for pattern in patterns:
        found = re.search(pattern, stem)
        if found is not None:
            return int(found.group(1))
    return None
def extract_tracking_frame_id(frame_data, fallback_idx=None):
    """Resolve the frameId of one tracking frame.

    Sources are tried in this order; the first convertible value wins:

    1. ``frameId`` then ``frame_id`` on each detection, detections in order;
    2. ``original_frame_id``, ``frame_id``, ``frameId`` in ``frame_info``
       (only when ``frame_info`` is a dict);
    3. digits parsed from ``image_name``;
    4. ``fallback_idx``.

    Values are converted with ``safe_int``; unconvertible values are skipped.
    """
    def _first_int(mapping, keys):
        # First key whose value converts to int, or None.
        for key in keys:
            parsed = safe_int(mapping.get(key))
            if parsed is not None:
                return parsed
        return None

    for det in frame_data.get("detections", []):
        resolved = _first_int(det, ("frameId", "frame_id"))
        if resolved is not None:
            return resolved

    info = frame_data.get("frame_info")
    if isinstance(info, dict):
        resolved = _first_int(info, ("original_frame_id", "frame_id", "frameId"))
        if resolved is not None:
            return resolved

    resolved = extract_frame_id_from_image_name(frame_data.get("image_name"))
    return fallback_idx if resolved is None else resolved
def resolve_detection_class_name(det, class_id, class_names):
    """Return the display name for a detection's class.

    A non-blank ``type_name`` on the detection (surrounding whitespace
    stripped) takes precedence. Otherwise ``class_names[class_id]`` is used,
    falling back to ``"Class<id>"`` for ids missing from ``class_names``.
    """
    explicit = str(det.get("type_name") or "").strip()
    return explicit if explicit else class_names.get(class_id, f"Class{class_id}")
def build_output_frame_name(frame_data, fallback_idx):
    """Build the output JPEG filename for one visualized frame.

    Uses the stem of ``frame_data["image_name"]``, dropping any directory
    part and extension, and adds a ``.jpg`` suffix. When ``image_name`` is
    missing, ``None`` or yields an empty stem, falls back to
    ``frame_<fallback_idx:06d>.jpg``. This avoids names like ``None.jpg`` or
    the hidden file ``.jpg``.
    """
    image_name = frame_data.get("image_name")
    stem = Path(str(image_name)).stem if image_name is not None else ""
    if not stem:
        stem = f"frame_{fallback_idx:06d}"
    return f"{stem}.jpg"
def draw_box_with_label(img, bbox, track_id, class_id, confidence, color, class_names, class_name_override=None):
    """Draw one bounding box plus its ``ID:<id> <class> <conf>`` label.

    Args:
        img: Image array; drawn on in place and returned.
        bbox: ``(x1, y1, x2, y2)`` corner coordinates; truncated to ``int``.
        track_id: Track id shown in the label.
        class_id: Class id looked up in ``class_names`` when no override is given.
        confidence: Detection score, shown with two decimals. It is omitted from
            the label if it cannot be converted to ``float`` (e.g. ``None``).
        color: Colour tuple for the box and the label background.
        class_names: Mapping ``class_id -> name``.
        class_name_override: Optional name that takes precedence over ``class_names``.

    Returns:
        The same ``img`` object with the box and label drawn.
    """
    x1, y1, x2, y2 = map(int, bbox)
    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

    class_name = class_name_override or class_names.get(class_id, f'Class{class_id}')
    try:
        conf_text = f' {float(confidence):.2f}'
    except (TypeError, ValueError):
        conf_text = ''
    label = f'ID:{track_id} {class_name}{conf_text}'

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    font_thickness = 2
    pad = 5
    (text_width, text_height), _baseline = cv2.getTextSize(label, font, font_scale, font_thickness)
    label_height = text_height + 2 * pad

    if y1 - label_height >= 0:
        # Normal case: the label sits just above the box's top edge.
        label_y1, label_y2 = y1 - label_height, y1
    else:
        # No room above the box (it touches the image top). Clamping only the
        # background to [0, y1] would leave the text spilling past it, so draw
        # the full-height label just inside the box instead.
        label_y1 = max(y1, 0)
        label_y2 = label_y1 + label_height

    cv2.rectangle(img, (x1, label_y1), (x1 + text_width + 2 * pad, label_y2), color, -1)
    cv2.putText(img, label, (x1 + pad, label_y1 + text_height + pad),
                font, font_scale, (255, 255, 255), font_thickness)
    return img
def render_tracking_frame(
    img,
    frame_data,
    frame_idx,
    total_frames,
    color_map,
    class_names,
    trajectory_history=None,
    class_id_filter=None,
    track_ids_filter=None,
):
    """Draw every tracked, filter-passing detection of one frame onto ``img``.

    Detections without a ``track_id`` are skipped, and so are detections
    rejected by ``should_visualize_detection``. Each remaining detection is
    drawn with its per-track colour from ``color_map``. When
    ``trajectory_history`` is a dict, the box centre is appended to that
    track's history and the polyline of all centres recorded so far is drawn.
    A status line is written at the top-left: the frame counter, plus either
    the detection count or, when filters are active, drawn/total.

    Args:
        img: Image to draw on.
        frame_data: Tracking dict for this frame. Its ``detections`` entries
            need ``bbox``, ``class_id`` and ``confidence``.
        frame_idx: Zero-based index shown as ``frame_idx + 1`` in the counter.
        total_frames: Denominator of the frame counter.
        color_map: Per-track colour cache, updated in place.
        class_names: Mapping ``class_id -> name``.
        trajectory_history: Optional dict ``track_id -> list of centres``,
            updated in place.
        class_id_filter: Class id to keep, or ``None`` for all.
        track_ids_filter: Track ids to keep, or ``None`` for all.

    Returns:
        tuple: ``(img, drawn)``, where ``drawn`` is the number of detections drawn.
    """
    detections = frame_data.get('detections', [])
    drawn = 0

    for det in detections:
        tid = det.get('track_id')
        if tid is None or not should_visualize_detection(det, class_id_filter, track_ids_filter):
            continue
        drawn += 1

        box = det['bbox']
        cls = det['class_id']
        score = det['confidence']
        track_color = get_color_for_track(tid, color_map)
        display_name = resolve_detection_class_name(det, cls, class_names)
        img = draw_box_with_label(
            img, box, tid, cls, score, track_color, class_names,
            class_name_override=display_name,
        )

        if trajectory_history is None:
            continue
        path = trajectory_history.setdefault(tid, [])
        path.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
        if len(path) > 1:
            pts = np.array(path, dtype=np.int32)
            for start, end in zip(pts[:-1], pts[1:]):
                cv2.line(img, tuple(start), tuple(end), track_color, 2)

    filtering = class_id_filter is not None or track_ids_filter is not None
    if filtering:
        tail = f"Showing: {drawn}/{len(detections)}"
    else:
        tail = f"Tracks: {len(detections)}"
    status = f"Frame: {frame_idx + 1}/{total_frames} | {tail}"
    cv2.putText(img, status, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    return img, drawn
def visualize_tracking_on_raw_video_by_frame_id(
    tracking_data,
    video_path,
    output_path,
    max_frames=None,
    show_trajectory=False,
    class_id_filter=None,
    track_ids_filter=None,
):
    """Render tracking results onto raw video frames matched by frameId.

    1. Each tracking frame's frameId is resolved with
       ``extract_tracking_frame_id``, falling back to its list index.
    2. The raw video is decoded with ``iter_video_case_frames``.
    3. Every decoded frame whose frameId (or read index, when ``frame_info``
       carries no ``frame_id``) matches a tracking frame is rendered and
       written as a JPEG into ``output_path``.

    Only the first tracking frame for a given frameId is used, and each
    frameId is rendered at most once. Decoding stops once every target has
    been rendered.

    Args:
        tracking_data: List of per-frame tracking dicts.
        video_path: Raw video file (e.g. ``camera4.bin``).
        output_path: Output directory; created if missing.
        max_frames: If truthy, only the first ``max_frames`` tracking frames are used.
        show_trajectory: Draw per-track centre trajectories.
        class_id_filter: Only draw this class id (``None`` = all classes).
        track_ids_filter: Only draw these track ids (``None`` = all tracks).
    """
    video_path = Path(video_path)
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    class_names = {
        0: 'Car',
        1: 'Pedestrian',
        2: 'Tricycle',
        3: 'Cyclist'
    }
    color_map = {}
    trajectory_history = {} if show_trajectory else None
    selected_tracking_data = tracking_data[:max_frames] if max_frames else tracking_data

    # frameId -> (index in tracking_data, ordinal among unique targets, frame dict)
    frame_id_to_tracking = {}
    ordered_frame_ids = []
    skipped_duplicates = []
    unresolved_frames = []
    for idx, frame_data in enumerate(selected_tracking_data):
        frame_id = extract_tracking_frame_id(frame_data, fallback_idx=idx)
        if frame_id is None:
            # NOTE(review): unreachable while fallback_idx=idx is passed above.
            # Frames without a frameId are currently aligned by list position.
            unresolved_frames.append(frame_data.get("image_name", f"frame_{idx:06d}"))
            continue
        if frame_id in frame_id_to_tracking:
            # Keep the first occurrence; later frames with the same frameId are dropped.
            skipped_duplicates.append(frame_id)
            continue
        frame_id_to_tracking[frame_id] = (idx, len(ordered_frame_ids), frame_data)
        ordered_frame_ids.append(frame_id)

    if not frame_id_to_tracking:
        print("Error: no frameIds could be resolved from tracking data")
        return

    frame_index_payload = read_video_frame_index(video_path)
    total_targets = len(ordered_frame_ids)
    matched_frame_ids = set()
    print(f"Reading raw video with frameId alignment: {video_path}")
    print(f"Saving aligned images to: {output_path}")
    if unresolved_frames:
        print(f"Warning: skipped {len(unresolved_frames)} frame(s) without resolvable frameId")
    if skipped_duplicates:
        print(f"Warning: skipped duplicate frameId(s): {sorted(set(skipped_duplicates))[:10]}")

    # Context manager closes the progress bar even if decoding/rendering raises.
    with tqdm(total=total_targets, desc="Visualizing") as progress:
        for read_frame_index, frame_bgr, _, frame_info in iter_video_case_frames(
            video_path,
            frame_index_payload=frame_index_payload,
            frame_stride=1,
            max_frames=0,
        ):
            # Prefer the frame_id reported in frame_info; otherwise use the read index.
            aligned_frame_id = safe_int(None if frame_info is None else frame_info.get("frame_id"))
            if aligned_frame_id is None:
                aligned_frame_id = read_frame_index
            # Skip frames without a tracking target, and frameIds already
            # rendered: a repeated frameId in the stream would otherwise
            # double-append trajectories and overrun the progress bar.
            if aligned_frame_id not in frame_id_to_tracking or aligned_frame_id in matched_frame_ids:
                continue
            tracking_idx, target_ordinal, frame_data = frame_id_to_tracking[aligned_frame_id]
            # The on-image "Frame: k/N" counter uses the ordinal among unique
            # targets, so k never exceeds N when duplicates were skipped. The
            # raw tracking index is still used for fallback filenames.
            rendered, _ = render_tracking_frame(
                frame_bgr.copy(),
                frame_data,
                target_ordinal,
                total_targets,
                color_map,
                class_names,
                trajectory_history=trajectory_history,
                class_id_filter=class_id_filter,
                track_ids_filter=track_ids_filter,
            )
            output_file = output_path / build_output_frame_name(frame_data, tracking_idx)
            if not cv2.imwrite(str(output_file), rendered):
                print(f"Warning: failed to write {output_file}")
            matched_frame_ids.add(aligned_frame_id)
            progress.update(1)
            if len(matched_frame_ids) >= total_targets:
                break

    missing_frame_ids = [frame_id for frame_id in ordered_frame_ids if frame_id not in matched_frame_ids]
    print("\n✓ Visualization complete!")
    print(f" Images saved to: {output_path}")
    print(f" Matched frames: {len(matched_frame_ids)}/{total_targets}")
    print(f" Total tracks visualized: {len(color_map)}")
    if missing_frame_ids:
        preview = ", ".join(str(frame_id) for frame_id in missing_frame_ids[:10])
        suffix = "..." if len(missing_frame_ids) > 10 else ""
        print(f" Missing frameIds: {preview}{suffix}")
def visualize_tracking_on_images(tracking_data, image_source, output_path,
                                 save_video=False, fps=25, max_frames=None,
                                 show_trajectory=False, class_id_filter=None,
                                 track_ids_filter=None, align_by_frame_id=False):
    """Draw tracking results onto images or video frames.

    Tracking frame ``i`` is paired by position with the ``i``-th source frame:
    the ``i``-th image in sorted order, or the ``i``-th decoded video frame.
    With ``align_by_frame_id`` the work is delegated to
    ``visualize_tracking_on_raw_video_by_frame_id`` instead.

    Args:
        tracking_data: Per-frame tracking results (list loaded from tracking.json).
        image_source: Image folder (``*.jpg``/``*.png``/``*.jpeg``) or video file.
        output_path: Output folder for images, or the output video file when
            ``save_video`` is set.
        save_video: Write one video instead of individual images.
        fps: Frame rate of the output video.
        max_frames: Maximum number of frames to process (falsy = no limit).
        show_trajectory: Draw track trajectories.
        class_id_filter: Only draw this class id (``None`` = all classes).
        track_ids_filter: Set of track ids to draw (``None`` = all tracks).
        align_by_frame_id: Align to a raw video by frameId (image output only).
    """
    image_source = Path(image_source)
    output_path = Path(output_path)

    if align_by_frame_id:
        if save_video:
            print("Error: --align-by-frame-id only supports image-directory output")
            return
        if not image_source.is_file():
            print("Error: --align-by-frame-id requires a video file such as camera4.bin")
            return
        visualize_tracking_on_raw_video_by_frame_id(
            tracking_data,
            image_source,
            output_path,
            max_frames=max_frames,
            show_trajectory=show_trajectory,
            class_id_filter=class_id_filter,
            track_ids_filter=track_ids_filter,
        )
        return

    class_names = {
        0: 'Car',
        1: 'Pedestrian',
        2: 'Tricycle',
        3: 'Cyclist'
    }
    color_map = {}
    # Per-track list of box centres; only kept when trajectories are drawn.
    trajectory_history = {} if show_trajectory else None

    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.bin'}
    cap = None
    image_files = []
    if image_source.is_file() and image_source.suffix.lower() in video_extensions:
        cap = cv2.VideoCapture(str(image_source))
        if not cap.isOpened():
            cap.release()
            print(f"Error: Cannot open video {image_source}")
            return
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Reading from video: {image_source}")
        print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames")
    elif image_source.is_dir():
        image_files = sorted(list(image_source.glob('*.jpg')) +
                             list(image_source.glob('*.png')) +
                             list(image_source.glob('*.jpeg')))
        if not image_files:
            print(f"Error: No images found in {image_source}")
            return
        # The first image defines the output size.
        first_img = cv2.imread(str(image_files[0]))
        if first_img is None:
            print(f"Error: Cannot read image {image_files[0]}")
            return
        frame_height, frame_width = first_img.shape[:2]
        total_frames = len(image_files)
        print(f"Reading from image folder: {image_source}")
        print(f"Found {total_frames} images, size: {frame_width}x{frame_height}")
    else:
        print(f"Error: Invalid image source: {image_source}")
        return

    if max_frames:
        # Some containers report a frame count <= 0 when it is unknown. Only
        # clamp against a positive count; otherwise the slice below would be
        # [:0] or [:-1] and silently drop frames.
        total_frames = min(total_frames, max_frames) if total_frames > 0 else max_frames
        tracking_data = tracking_data[:total_frames]

    video_writer = None
    try:
        if save_video:
            if not str(output_path).endswith(('.mp4', '.avi')):
                output_path = output_path.with_suffix('.mp4')
            output_path.parent.mkdir(parents=True, exist_ok=True)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))
            if not video_writer.isOpened():
                print(f"Error: Cannot open video writer for {output_path}")
                return
            print(f"Saving video to: {output_path}")
        else:
            output_path.mkdir(parents=True, exist_ok=True)
            print(f"Saving images to: {output_path}")

        print(f"\nProcessing {len(tracking_data)} frames...")
        for frame_idx, frame_data in enumerate(tqdm(tracking_data, desc="Visualizing")):
            if cap is not None:
                ret, img = cap.read()
                if not ret:
                    print(f"Warning: Cannot read frame {frame_idx}")
                    break
            else:
                if frame_idx >= len(image_files):
                    break
                img = cv2.imread(str(image_files[frame_idx]))
                if img is None:
                    print(f"Warning: Cannot read image {image_files[frame_idx]}")
                    continue

            img, _ = render_tracking_frame(
                img,
                frame_data,
                frame_idx,
                len(tracking_data),
                color_map,
                class_names,
                trajectory_history=trajectory_history,
                class_id_filter=class_id_filter,
                track_ids_filter=track_ids_filter,
            )

            if video_writer is not None:
                # VideoWriter silently drops frames whose size differs from the
                # size it was opened with. Resize after drawing so boxes stay
                # aligned with the source pixels.
                if img.shape[:2] != (frame_height, frame_width):
                    img = cv2.resize(img, (frame_width, frame_height))
                video_writer.write(img)
            else:
                # Shared naming helper: stem of image_name + ".jpg", with a
                # fallback when the name is missing.
                output_file = output_path / build_output_frame_name(frame_data, frame_idx)
                if not cv2.imwrite(str(output_file), img):
                    print(f"Warning: failed to write {output_file}")
    finally:
        # Release handles even when rendering raises or we return early.
        if cap is not None:
            cap.release()
        if video_writer is not None:
            video_writer.release()

    print("\n✓ Visualization complete!")
    if save_video:
        print(f" Video saved to: {output_path}")
    else:
        print(f" Images saved to: {output_path}")
    print(f" Total tracks visualized: {len(color_map)}")
def create_side_by_side_comparison(tracking_file1, tracking_file2, image_source,
                                   output_path, labels=None, max_frames=100, fps=25):
    """Render two tracking results next to each other into one video.

    Frame ``i`` of both tracking files is drawn on the ``i``-th source frame:
    the ``i``-th sorted image of a folder, or the ``i``-th decoded video frame.
    The left half shows ``tracking_file1`` and the right half ``tracking_file2``.

    Args:
        tracking_file1, tracking_file2: Paths to the two tracking JSON files.
        image_source: Image folder (``*.jpg``/``*.png``/``*.jpeg``) or video file.
        output_path: Output video path; its parent directory is created.
        labels: Captions for the left/right halves
            (default ``['Method 1', 'Method 2']``).
        max_frames: Maximum number of frames to render.
        fps: Frame rate of the output video (default 25, as before).
    """
    if labels is None:
        labels = ['Method 1', 'Method 2']

    with open(tracking_file1, 'r', encoding='utf-8') as f:
        data1 = json.load(f)
    with open(tracking_file2, 'r', encoding='utf-8') as f:
        data2 = json.load(f)

    image_source = Path(image_source)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    class_names = {0: 'Car', 1: 'Pedestrian', 2: 'Tricycle', 3: 'Cyclist'}
    color_map1, color_map2 = {}, {}

    cap = None
    image_files = []
    if image_source.is_dir():
        image_files = sorted(list(image_source.glob('*.jpg')) +
                             list(image_source.glob('*.png')) +
                             list(image_source.glob('*.jpeg')))
        if not image_files:
            print(f"Error: No images found in {image_source}")
            return
        first_img = cv2.imread(str(image_files[0]))
        if first_img is None:
            print(f"Error: Cannot read image {image_files[0]}")
            return
        h, w = first_img.shape[:2]
    else:
        # Open the video once and read it sequentially. Reopening and seeking
        # per frame is slow, and seeking can be inexact for some codecs.
        cap = cv2.VideoCapture(str(image_source))
        if not cap.isOpened():
            cap.release()
            print(f"Error: Cannot open video {image_source}")
            return
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    frames_to_process = min(len(data1), len(data2), max_frames)
    if cap is None:
        # Never index past the available images.
        frames_to_process = min(frames_to_process, len(image_files))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Side-by-side layout: output is twice as wide as a source frame.
    video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w * 2, h))
    print("Creating side-by-side comparison video...")
    try:
        for frame_idx in tqdm(range(frames_to_process), desc="Processing"):
            if cap is None:
                img = cv2.imread(str(image_files[frame_idx]))
                if img is None:
                    print(f"Warning: Cannot read image {image_files[frame_idx]}")
                    continue
            else:
                ret, img = cap.read()
                if not ret:
                    break
            if img.shape[:2] != (h, w):
                # Keep both halves at the writer's size; mismatched frames are dropped.
                img = cv2.resize(img, (w, h))

            panels = []
            for data, color_map, label in ((data1, color_map1, labels[0]),
                                           (data2, color_map2, labels[1])):
                panel = img.copy()
                for det in data[frame_idx].get('detections', []):
                    track_id = det.get('track_id')
                    # Same rule as the single-file path: untracked detections are not drawn.
                    if track_id is None:
                        continue
                    color = get_color_for_track(track_id, color_map)
                    class_name = resolve_detection_class_name(det, det['class_id'], class_names)
                    panel = draw_box_with_label(panel, det['bbox'], track_id, det['class_id'],
                                                det['confidence'], color, class_names,
                                                class_name_override=class_name)
                cv2.putText(panel, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                panels.append(panel)
            video_writer.write(np.hstack(panels))
    finally:
        if cap is not None:
            cap.release()
        video_writer.release()
    print(f"✓ Comparison video saved to: {output_path}")
def main():
    """Command-line entry point.

    Loads ``--tracking``. With ``--compare``, it renders a side-by-side
    comparison video; otherwise it visualizes the single tracking file onto
    ``--images``. Exits with status 1 (via ``sys.exit``) when the tracking
    file, the image source or the comparison file does not exist, so shell
    scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(description='Visualize tracking results on images')
    parser.add_argument('--tracking', type=str, required=True, help='Path to tracking.json file')
    parser.add_argument('--images', type=str, required=True, help='Path to images folder or video file')
    parser.add_argument('--output', type=str, required=True, help='Output path (folder or video file)')
    parser.add_argument('--save-video', action='store_true', help='Save as video instead of images')
    parser.add_argument('--fps', type=int, default=25, help='Video FPS (default: 25)')
    parser.add_argument('--max-frames', type=int, help='Maximum number of frames to process')
    parser.add_argument('--show-trajectory', action='store_true', help='Show trajectory lines')
    parser.add_argument('--compare', type=str, help='Second tracking.json for side-by-side comparison')
    parser.add_argument('--labels', type=str, nargs=2, default=['Method 1', 'Method 2'],
                        help='Labels for comparison mode')
    # Help text matches the class_names tables used when drawing.
    parser.add_argument('--class-id', type=int, default=None,
                        help='Filter by class ID (0=Car, 1=Pedestrian, 2=Tricycle, 3=Cyclist). '
                             'If not specified, all classes are shown.')
    parser.add_argument('--track-ids', type=int, nargs='+', default=None,
                        help='Specific track IDs to visualize (e.g., --track-ids 1 2 3). If not specified, all tracks are shown.')
    parser.add_argument('--align-by-frame-id', action='store_true',
                        help='Align tracking frames to the raw input video by frameId and export images only.')
    args = parser.parse_args()

    # Validate inputs.
    tracking_path = Path(args.tracking)
    if not tracking_path.exists():
        print(f"Error: Tracking file not found: {tracking_path}")
        sys.exit(1)
    image_source = Path(args.images)
    if not image_source.exists():
        print(f"Error: Image source not found: {image_source}")
        sys.exit(1)

    print(f"Loading tracking data from {tracking_path}")
    with open(tracking_path, 'r', encoding='utf-8') as f:
        tracking_data = json.load(f)
    print(f"Loaded {len(tracking_data)} frames")

    if args.compare:
        compare_path = Path(args.compare)
        if not compare_path.exists():
            print(f"Error: Comparison file not found: {compare_path}")
            sys.exit(1)
        if (args.class_id is not None or args.track_ids or args.show_trajectory
                or args.save_video or args.align_by_frame_id):
            print("Warning: --class-id/--track-ids/--show-trajectory/--save-video/"
                  "--align-by-frame-id are ignored in --compare mode")
        output_path = Path(args.output)
        if not str(output_path).endswith('.mp4'):
            output_path = output_path.with_suffix('.mp4')
        create_side_by_side_comparison(
            tracking_path, compare_path, image_source,
            output_path, args.labels, args.max_frames or 100
        )
    else:
        track_ids_set = set(args.track_ids) if args.track_ids else None
        if args.class_id is not None or args.track_ids:
            print("\nFiltering settings:")
            if args.class_id is not None:
                print(f" Class ID: {args.class_id}")
            if args.track_ids:
                print(f" Track IDs: {args.track_ids}")
        visualize_tracking_on_images(
            tracking_data,
            image_source,
            args.output,
            save_video=args.save_video,
            fps=args.fps,
            max_frames=args.max_frames,
            show_trajectory=args.show_trajectory,
            class_id_filter=args.class_id,
            track_ids_filter=track_ids_set,
            align_by_frame_id=args.align_by_frame_id,
        )
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()