686 lines
24 KiB
Python
Executable File
686 lines
24 KiB
Python
Executable File
"""
|
||
Visualize tracking results by drawing 2D boxes on images.
|
||
|
||
Usage:
|
||
# 从图像文件夹可视化
|
||
python visualize_tracking_boxes.py \
|
||
--tracking runs/val_viz/exp/tracking.json \
|
||
--images path/to/images \
|
||
--output runs/val_viz/exp/tracked_images
|
||
|
||
# 生成视频
|
||
python visualize_tracking_boxes.py \
|
||
--tracking runs/val_viz/exp/tracking.json \
|
||
--images path/to/images \
|
||
--output runs/val_viz/exp/tracked.mp4 \
|
||
--save-video
|
||
|
||
# 只可视化前100帧
|
||
python visualize_tracking_boxes.py \
|
||
--tracking tracking.json \
|
||
--images images/ \
|
||
--output output/ \
|
||
--max-frames 100
|
||
|
||
# 按 frameId 对齐原始 camera4.bin,并导出单帧图片
|
||
python visualize_tracking_boxes.py \
|
||
--tracking merge.json \
|
||
--images /path/to/case/sigmastar.1/camera4.bin \
|
||
--output tracking_vis_raw \
|
||
--align-by-frame-id
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
from pathlib import Path
|
||
import re
|
||
import sys
|
||
import cv2
|
||
import numpy as np
|
||
from tqdm import tqdm
|
||
import colorsys
|
||
|
||
# Put the repository root on sys.path so the project-local ``tools`` package
# imported below resolves when this file is executed directly as a script.
# parents[0] is the directory holding this file, so parents[2] is two levels above it.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if all(entry != str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))
|
||
|
||
from tools.model_inference.adapters.video_dir_inference_utils import (
|
||
iter_video_case_frames,
|
||
read_video_frame_index,
|
||
)
|
||
|
||
|
||
def generate_colors(n):
    """Return ``n`` well-separated colours as 3-tuples of 0-255 ints.

    Hues are spread evenly around the HSV colour wheel (hue ``k / n``) with
    saturation and value fixed at 0.9. Channel order is the ``(r, g, b)``
    order produced by ``colorsys.hsv_to_rgb``.
    """
    return [
        tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(step / n, 0.9, 0.9))
        for step in range(n)
    ]
|
||
|
||
|
||
def get_color_for_track(track_id, color_map):
    """Return a persistent colour for ``track_id``, caching it in ``color_map``.

    The colour is derived deterministically from the track id, so the same id
    gets the same colour on every run. Each channel is drawn from ``[50, 255)``
    so boxes are never near-black.

    Args:
        track_id: Track identifier. Integers (including numpy integers) seed
            the generator directly (wrapped into ``[0, 2**32)``); any other
            value is seeded with the CRC32 of its string form.
        color_map: Dict used as a colour cache; mutated in place.

    Returns:
        tuple[int, int, int]: The colour assigned to ``track_id``.
    """
    if track_id not in color_map:
        # Use a private RandomState instead of np.random.seed(): reseeding the
        # global generator would clobber NumPy's global random state for the
        # rest of the process. RandomState(seed) produces the same sequence the
        # legacy global seeding did, so colours are unchanged for valid ids.
        if isinstance(track_id, (int, np.integer)):
            # RandomState only accepts seeds in [0, 2**32 - 1]; wrap so that
            # negative or very large ids no longer raise ValueError.
            seed = int(track_id) % (2 ** 32)
        else:
            import zlib  # stable across processes, unlike hash() on str
            seed = zlib.crc32(str(track_id).encode("utf-8"))
        rng = np.random.RandomState(seed)
        color_map[track_id] = tuple(int(rng.randint(50, 255)) for _ in range(3))
    return color_map[track_id]
|
||
|
||
|
||
def should_visualize_detection(det, class_id_filter=None, track_ids_filter=None):
    """Decide whether one detection passes the active filters.

    Args:
        det: Detection dict; only its ``class_id`` and ``track_id`` keys are read.
        class_id_filter: Class id to keep, or None to keep every class.
        track_ids_filter: Collection of track ids to keep, or None to keep all ids.

    Returns:
        bool: True when the detection satisfies both filters.
    """
    class_ok = class_id_filter is None or det.get('class_id') == class_id_filter
    if not class_ok:
        return False
    return track_ids_filter is None or det.get('track_id') in track_ids_filter
|
||
|
||
|
||
def safe_int(value):
    """Best-effort conversion of ``value`` to ``int``.

    Returns None instead of raising when ``value`` is None or cannot be
    converted. This covers non-numeric strings, unsupported types, ``NaN`` and
    infinities. Floats are truncated by ``int()``. Numeric strings must be
    integral, because ``int("3.0")`` is rejected.
    """
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError comes from int(float("inf")); json.load accepts the
        # non-standard ``Infinity`` literal, so it can appear in tracking files.
        return None
|
||
|
||
|
||
def extract_frame_id_from_image_name(image_name):
    """Best-effort frameId parsing from a tracking ``image_name``.

    Only the file stem is inspected. Patterns are tried from most to least
    specific, and the first match wins:

    1. ``..._<a>_<b>``, optionally followed by ``_merged``, gives ``a``.
    2. ``..._<a>``, optionally followed by ``_merged``, gives ``a``.
    3. Otherwise, the first run of digits anywhere in the stem is used.

    Returns:
        int | None: The parsed id, or None when the stem contains no digits.
    """
    stem = Path(str(image_name or "")).stem
    candidate_patterns = (
        r"_(\d+)_(\d+)(?:_merged)?$",
        r"_(\d+)(?:_merged)?$",
        r"(\d+)",
    )
    for pattern in candidate_patterns:
        found = re.search(pattern, stem)
        if found:
            return int(found.group(1))
    return None
|
||
|
||
|
||
def extract_tracking_frame_id(frame_data, fallback_idx=None):
    """Resolve the frameId of one tracking frame.

    Sources are consulted in priority order, and the first integer found wins:

    1. ``frameId`` / ``frame_id`` on any detection, in detection order.
    2. ``original_frame_id`` / ``frame_id`` / ``frameId`` in ``frame_info``,
       only when ``frame_info`` is a dict.
    3. Digits parsed from ``image_name``.

    Args:
        frame_data: One entry of the tracking JSON list.
        fallback_idx: Value returned when none of the sources yields an id.

    Returns:
        int | None: The resolved frameId, or ``fallback_idx``.
    """
    detection_ids = (
        safe_int(det.get(key))
        for det in frame_data.get("detections", [])
        for key in ("frameId", "frame_id")
    )
    for candidate in detection_ids:
        if candidate is not None:
            return candidate

    info = frame_data.get("frame_info")
    if isinstance(info, dict):
        for key in ("original_frame_id", "frame_id", "frameId"):
            candidate = safe_int(info.get(key))
            if candidate is not None:
                return candidate

    parsed = extract_frame_id_from_image_name(frame_data.get("image_name"))
    return fallback_idx if parsed is None else parsed
|
||
|
||
|
||
def resolve_detection_class_name(det, class_id, class_names):
    """Return the display name for a detection.

    A non-blank ``type_name`` carried by the detection (whitespace-stripped)
    takes precedence. Otherwise ``class_names`` is consulted, falling back to
    ``Class<id>`` for unknown ids.
    """
    provided = str(det.get("type_name") or "").strip()
    return provided or class_names.get(class_id, f"Class{class_id}")
|
||
|
||
|
||
def build_output_frame_name(frame_data, fallback_idx):
    """Build a stable ``<stem>.jpg`` output filename for one visualized frame.

    Directory components and the extension of ``image_name`` are dropped.
    When ``image_name`` is missing, None, or has an empty stem, the name
    ``frame_<idx>.jpg`` is used instead, with the index zero-padded to 6
    digits. This keeps such frames from all collapsing onto one shared
    ``None.jpg`` / ``.jpg`` file.

    Args:
        frame_data: One tracking frame dict; only ``image_name`` is read.
        fallback_idx: Integer index used to build the fallback name.

    Returns:
        str: A bare filename (no directory) ending in ``.jpg``.
    """
    fallback_stem = f"frame_{fallback_idx:06d}"
    image_name = frame_data.get("image_name")
    stem = "" if image_name is None else Path(str(image_name)).stem
    return f"{stem or fallback_stem}.jpg"
|
||
|
||
|
||
def draw_box_with_label(img, bbox, track_id, class_id, confidence, color, class_names, class_name_override=None):
    """Draw one bounding box plus an ``ID:<track> <class> <conf>`` label on ``img``.

    The label sits on a filled background just above the box's top edge. When
    there is not enough room above the box (it touches the top of the image),
    the label is placed just inside the box instead. This keeps the text on
    its background and inside the image.

    Args:
        img: Image drawn on with OpenCV drawing functions.
        bbox: ``(x1, y1, x2, y2)`` box corners; each is converted with ``int()``.
        track_id: Track id shown in the label.
        class_id: Class id, looked up in ``class_names`` when no override is given.
        confidence: Score shown with two decimals.
        color: Colour of the box and the label background.
        class_names: Mapping of class id to display name.
        class_name_override: Optional name that takes precedence over ``class_names``.

    Returns:
        The same ``img`` object.
    """
    x1, y1, x2, y2 = map(int, bbox)

    # Bounding box.
    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

    class_name = class_name_override or class_names.get(class_id, f'Class{class_id}')
    label = f'ID:{track_id} {class_name} {confidence:.2f}'

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    font_thickness = 2
    (text_width, text_height), _baseline = cv2.getTextSize(label, font, font_scale, font_thickness)

    # Background is the text height plus 5 px padding above and below.
    label_height = text_height + 10
    if y1 - label_height >= 0:
        label_y1, label_y2 = y1 - label_height, y1
    else:
        # Not enough space above the box: clamping the top to row 0 would
        # leave a background shorter than the text, so draw inside the box.
        label_y1 = max(y1, 0)
        label_y2 = label_y1 + label_height
    cv2.rectangle(img, (x1, label_y1), (x1 + text_width + 10, label_y2), color, -1)

    # White label text, vertically centred on the background.
    cv2.putText(img, label, (x1 + 5, label_y1 + text_height + 5), font, font_scale,
                (255, 255, 255), font_thickness)

    return img
|
||
|
||
|
||
def render_tracking_frame(
    img,
    frame_data,
    frame_idx,
    total_frames,
    color_map,
    class_names,
    trajectory_history=None,
    class_id_filter=None,
    track_ids_filter=None,
):
    """Draw the visible tracked detections of one frame, plus a status line.

    Detections without a ``track_id``, or rejected by the class / track-id
    filters, are skipped. When ``trajectory_history`` is a dict, each drawn
    track's bbox centre is appended to it. The accumulated centres are then
    joined with 2 px line segments.

    Args:
        img: Image to draw on.
        frame_data: Tracking frame dict; its ``detections`` list is read.
        frame_idx: Zero-based index, shown 1-based in the status text.
        total_frames: Denominator shown in the status text.
        color_map: Per-track colour cache shared across frames.
        class_names: Mapping of class id to display name, used as a fallback.
        trajectory_history: Optional ``track_id -> [centre, ...]`` dict,
            updated in place; None disables trajectories.
        class_id_filter: Optional class id filter.
        track_ids_filter: Optional collection of track ids to keep.

    Returns:
        tuple: ``(img, number_of_detections_drawn)``.
    """
    detections = frame_data.get('detections', [])
    drawn = 0

    for det in detections:
        tid = det.get('track_id')
        if tid is None or not should_visualize_detection(det, class_id_filter, track_ids_filter):
            continue

        drawn += 1
        box = det['bbox']
        cls = det['class_id']
        score = det['confidence']
        track_color = get_color_for_track(tid, color_map)
        display_name = resolve_detection_class_name(det, cls, class_names)

        img = draw_box_with_label(
            img,
            box,
            tid,
            cls,
            score,
            track_color,
            class_names,
            class_name_override=display_name,
        )

        if trajectory_history is None:
            continue
        path = trajectory_history.setdefault(tid, [])
        path.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
        if len(path) > 1:
            pts = np.array(path, dtype=np.int32)
            for start, end in zip(pts[:-1], pts[1:]):
                cv2.line(img, tuple(start), tuple(end), track_color, 2)

    filtering = class_id_filter is not None or track_ids_filter is not None
    if filtering:
        info_text = f"Frame: {frame_idx + 1}/{total_frames} | Showing: {drawn}/{len(detections)}"
    else:
        info_text = f"Frame: {frame_idx + 1}/{total_frames} | Tracks: {len(detections)}"
    cv2.putText(img, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    return img, drawn
|
||
|
||
|
||
def visualize_tracking_on_raw_video_by_frame_id(
    tracking_data,
    video_path,
    output_path,
    max_frames=None,
    show_trajectory=False,
    class_id_filter=None,
    track_ids_filter=None,
):
    """Render tracking results onto raw video frames matched by frameId.

    Each tracking frame is keyed by the frameId resolved through
    ``extract_tracking_frame_id``. Raw frames are streamed with
    ``iter_video_case_frames``. A raw frame is rendered only when its
    ``frame_info["frame_id"]`` (or, if absent, its read index) matches a
    tracking frame. Each tracked frameId is rendered once, as a JPEG in
    ``output_path``.

    Args:
        tracking_data: List of tracking frame dicts (the loaded tracking JSON).
        video_path: Raw video file, e.g. ``camera4.bin``.
        output_path: Directory for the rendered images; created if missing.
        max_frames: When truthy, only the first ``max_frames`` tracking frames are used.
        show_trajectory: Draw per-track centre trajectories.
        class_id_filter: Optional class id filter.
        track_ids_filter: Optional collection of track ids to keep.
    """
    video_path = Path(video_path)
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    class_names = {
        0: 'Car',
        1: 'Pedestrian',
        2: 'Tricycle',
        3: 'Cyclist'
    }
    color_map = {}
    trajectory_history = {} if show_trajectory else None

    selected_tracking_data = tracking_data[:max_frames] if max_frames else tracking_data

    # Index tracking frames by frameId; the first occurrence of an id wins.
    frame_id_to_tracking = {}
    ordered_frame_ids = []
    skipped_duplicates = []
    unresolved_frames = []
    for idx, frame_data in enumerate(selected_tracking_data):
        # With fallback_idx=idx, a frame whose id cannot be parsed is keyed by
        # its list position, so the None check below is only a safeguard.
        frame_id = extract_tracking_frame_id(frame_data, fallback_idx=idx)
        if frame_id is None:
            unresolved_frames.append(frame_data.get("image_name", f"frame_{idx:06d}"))
            continue
        if frame_id in frame_id_to_tracking:
            skipped_duplicates.append(frame_id)
            continue
        frame_id_to_tracking[frame_id] = (idx, frame_data)
        ordered_frame_ids.append(frame_id)

    if not frame_id_to_tracking:
        print("Error: no frameIds could be resolved from tracking data")
        return

    frame_index_payload = read_video_frame_index(video_path)
    total_targets = len(ordered_frame_ids)
    # The on-image "Frame: i/N" text uses tracking_idx, a position in
    # selected_tracking_data, so N must be that list's length. total_targets
    # is smaller whenever duplicate frameIds were skipped.
    display_total = len(selected_tracking_data)
    matched_frame_ids = set()
    failed_writes = []

    print(f"Reading raw video with frameId alignment: {video_path}")
    print(f"Saving aligned images to: {output_path}")
    if unresolved_frames:
        print(f"Warning: skipped {len(unresolved_frames)} frame(s) without resolvable frameId")
    if skipped_duplicates:
        print(f"Warning: skipped duplicate frameId(s): {sorted(set(skipped_duplicates))[:10]}")

    # Context manager closes the progress bar even if decoding/rendering raises.
    with tqdm(total=total_targets, desc="Visualizing") as progress:
        for read_frame_index, frame_bgr, _, frame_info in iter_video_case_frames(
            video_path,
            frame_index_payload=frame_index_payload,
            frame_stride=1,
            max_frames=0,  # presumably 0 = no limit in iter_video_case_frames — confirm
        ):
            aligned_frame_id = safe_int(None if frame_info is None else frame_info.get("frame_id"))
            if aligned_frame_id is None:
                aligned_frame_id = read_frame_index

            # Render each tracked frameId once. If the raw stream repeats an id,
            # rendering it again would overwrite the image, append duplicate
            # trajectory points and push the progress bar past its total.
            if aligned_frame_id not in frame_id_to_tracking or aligned_frame_id in matched_frame_ids:
                continue

            tracking_idx, frame_data = frame_id_to_tracking[aligned_frame_id]
            rendered, _ = render_tracking_frame(
                frame_bgr.copy(),
                frame_data,
                tracking_idx,
                display_total,
                color_map,
                class_names,
                trajectory_history=trajectory_history,
                class_id_filter=class_id_filter,
                track_ids_filter=track_ids_filter,
            )
            output_file = output_path / build_output_frame_name(frame_data, tracking_idx)
            if not cv2.imwrite(str(output_file), rendered):
                failed_writes.append(output_file)
            matched_frame_ids.add(aligned_frame_id)
            progress.update(1)

            # Stop decoding once every tracked frameId has been rendered.
            if len(matched_frame_ids) >= total_targets:
                break

    missing_frame_ids = [frame_id for frame_id in ordered_frame_ids if frame_id not in matched_frame_ids]
    print(f"\n✓ Visualization complete!")
    print(f" Images saved to: {output_path}")
    print(f" Matched frames: {len(matched_frame_ids)}/{total_targets}")
    print(f" Total tracks visualized: {len(color_map)}")
    if failed_writes:
        print(f" Failed to write {len(failed_writes)} image(s), first: {failed_writes[0]}")
    if missing_frame_ids:
        preview = ", ".join(str(frame_id) for frame_id in missing_frame_ids[:10])
        suffix = "..." if len(missing_frame_ids) > 10 else ""
        print(f" Missing frameIds: {preview}{suffix}")
|
||
|
||
|
||
def visualize_tracking_on_images(tracking_data, image_source, output_path,
                                 save_video=False, fps=25, max_frames=None,
                                 show_trajectory=False, class_id_filter=None,
                                 track_ids_filter=None, align_by_frame_id=False):
    """Render tracking results onto images or a video, aligned by list index.

    Tracking frame ``i`` is drawn on the ``i``-th input frame. For a folder,
    that is the ``i``-th file of the sorted ``*.jpg`` + ``*.png`` + ``*.jpeg``
    listing. For a video file, it is the ``i``-th frame decoded by OpenCV.

    Args:
        tracking_data: Tracking results (list loaded from tracking.json).
        image_source: Image folder or video file path.
        output_path: Output folder (image mode) or video file (video mode).
        save_video: Write a single video instead of per-frame images.
        fps: Frame rate of the output video.
        max_frames: Maximum number of frames to process; a falsy value means no limit.
        show_trajectory: Draw per-track centre trajectories.
        class_id_filter: Only draw this class id; None draws all classes.
        track_ids_filter: Only draw these track ids; None draws all ids.
        align_by_frame_id: Delegate to
            ``visualize_tracking_on_raw_video_by_frame_id`` to match raw
            video frames by frameId (image output only).

    In image mode, frames are written as ``<image_name stem>.jpg``, or as
    ``frame_<idx>.jpg`` when the frame has no usable ``image_name``. This is
    the same naming the frameId-aligned mode uses.
    """
    image_source = Path(image_source)
    output_path = Path(output_path)

    if align_by_frame_id:
        if save_video:
            print("Error: --align-by-frame-id only supports image-directory output")
            return
        if not image_source.is_file():
            print("Error: --align-by-frame-id requires a video file such as camera4.bin")
            return
        visualize_tracking_on_raw_video_by_frame_id(
            tracking_data,
            image_source,
            output_path,
            max_frames=max_frames,
            show_trajectory=show_trajectory,
            class_id_filter=class_id_filter,
            track_ids_filter=track_ids_filter,
        )
        return

    # Class names.
    class_names = {
        0: 'Car',
        1: 'Pedestrian',
        2: 'Tricycle',
        3: 'Cyclist'
    }
    # Per-track colour cache.
    color_map = {}
    # Per-track centre history used to draw trajectory lines.
    trajectory_history = {} if show_trajectory else None

    # Determine the kind of image source.
    VIDEO_EXTENSIONS = {'.mp4', '.avi', '.mov', '.mkv', '.bin'}
    cap = None
    image_files = []
    if image_source.is_file() and image_source.suffix.lower() in VIDEO_EXTENSIONS:
        cap = cv2.VideoCapture(str(image_source))
        if not cap.isOpened():
            print(f"Error: Cannot open video: {image_source}")
            return
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Reading from video: {image_source}")
        print(f"Video info: {frame_width}x{frame_height}, {total_frames} frames")
    elif image_source.is_dir():
        image_files = sorted(list(image_source.glob('*.jpg')) +
                             list(image_source.glob('*.png')) +
                             list(image_source.glob('*.jpeg')))
        if not image_files:
            print(f"Error: No images found in {image_source}")
            return
        # Read the first image to get the frame size.
        first_img = cv2.imread(str(image_files[0]))
        if first_img is None:
            print(f"Error: Cannot read image {image_files[0]}")
            return
        frame_height, frame_width = first_img.shape[:2]
        total_frames = len(image_files)
        print(f"Reading from image folder: {image_source}")
        print(f"Found {total_frames} images, size: {frame_width}x{frame_height}")
    else:
        print(f"Error: Invalid image source: {image_source}")
        return

    # Limit the number of processed frames.
    if max_frames:
        # CAP_PROP_FRAME_COUNT can be 0 or negative when a container does not
        # report its length; clamping by it would then drop tracking frames.
        total_frames = min(total_frames, max_frames) if total_frames > 0 else max_frames
        tracking_data = tracking_data[:total_frames]

    video_writer = None
    try:
        # Prepare the output.
        if save_video:
            if not str(output_path).endswith(('.mp4', '.avi')):
                output_path = output_path.with_suffix('.mp4')
            output_path.parent.mkdir(parents=True, exist_ok=True)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))
            if not video_writer.isOpened():
                print(f"Error: Cannot open video writer: {output_path}")
                return
            print(f"Saving video to: {output_path}")
        else:
            output_path.mkdir(parents=True, exist_ok=True)
            print(f"Saving images to: {output_path}")

        # Process each frame.
        print(f"\nProcessing {len(tracking_data)} frames...")
        for frame_idx, frame_data in enumerate(tqdm(tracking_data, desc="Visualizing")):
            if cap is not None:
                ret, img = cap.read()
                if not ret:
                    print(f"Warning: Cannot read frame {frame_idx}")
                    break
            else:
                if frame_idx >= len(image_files):
                    break
                img = cv2.imread(str(image_files[frame_idx]))
                if img is None:
                    print(f"Warning: Cannot read image {image_files[frame_idx]}")
                    continue

            img, _ = render_tracking_frame(
                img,
                frame_data,
                frame_idx,
                len(tracking_data),
                color_map,
                class_names,
                trajectory_history=trajectory_history,
                class_id_filter=class_id_filter,
                track_ids_filter=track_ids_filter,
            )

            if video_writer is not None:
                height, width = img.shape[:2]
                if (width, height) != (frame_width, frame_height):
                    # VideoWriter may silently drop frames whose size differs
                    # from the size it was opened with, so conform them first.
                    img = cv2.resize(img, (frame_width, frame_height))
                video_writer.write(img)
            else:
                output_file = output_path / build_output_frame_name(frame_data, frame_idx)
                if not cv2.imwrite(str(output_file), img):
                    print(f"Warning: Cannot write image {output_file}")
    finally:
        # Release capture/writer even if rendering raises.
        if cap is not None:
            cap.release()
        if video_writer is not None:
            video_writer.release()

    print(f"\n✓ Visualization complete!")
    if save_video:
        print(f" Video saved to: {output_path}")
    else:
        print(f" Images saved to: {output_path}")
    print(f" Total tracks visualized: {len(color_map)}")
|
||
|
||
|
||
def _draw_tracks_for_comparison(img, frame_data, color_map, class_names):
    """Draw every detection of ``frame_data`` that has a non-None ``track_id``."""
    for det in frame_data.get('detections', []):
        track_id = det.get('track_id')
        if track_id is None:
            continue
        color = get_color_for_track(track_id, color_map)
        # Prefer the detection's own type_name, as render_tracking_frame does.
        class_name = resolve_detection_class_name(det, det['class_id'], class_names)
        img = draw_box_with_label(img, det['bbox'], track_id, det['class_id'],
                                  det['confidence'], color, class_names,
                                  class_name_override=class_name)
    return img


def create_side_by_side_comparison(tracking_file1, tracking_file2, image_source,
                                   output_path, labels=None, max_frames=100, fps=25):
    """Write a video showing two tracking results side by side.

    Frame ``i`` of both tracking files is drawn on the ``i``-th input frame.
    The left half shows ``tracking_file1`` and the right half
    ``tracking_file2``; each half has its own colour map and caption.

    Args:
        tracking_file1, tracking_file2: Paths to the two tracking JSON files.
        image_source: Image folder (sorted ``*.jpg``/``*.png``/``*.jpeg``) or video file.
        output_path: Output video path; parent directories are created.
        labels: Two captions; defaults to ``['Method 1', 'Method 2']``.
        max_frames: Upper bound on the number of frames compared.
        fps: Output video frame rate (default 25).
    """
    if labels is None:
        labels = ['Method 1', 'Method 2']

    # Load both tracking results.
    with open(tracking_file1, 'r', encoding='utf-8') as f:
        data1 = json.load(f)
    with open(tracking_file2, 'r', encoding='utf-8') as f:
        data2 = json.load(f)

    image_source = Path(image_source)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    class_names = {0: 'Car', 1: 'Pedestrian', 2: 'Tricycle', 3: 'Cyclist'}
    color_map1, color_map2 = {}, {}

    cap = None
    image_files = []
    if image_source.is_dir():
        image_files = sorted(list(image_source.glob('*.jpg')) +
                             list(image_source.glob('*.png')) +
                             list(image_source.glob('*.jpeg')))
        if not image_files:
            print(f"Error: No images found in {image_source}")
            return
        first_img = cv2.imread(str(image_files[0]))
        if first_img is None:
            print(f"Error: Cannot read image {image_files[0]}")
            return
        h, w = first_img.shape[:2]
    else:
        # Open once and read sequentially: re-opening and seeking for every
        # frame is slow, and seeking is not frame-accurate for many codecs.
        cap = cv2.VideoCapture(str(image_source))
        if not cap.isOpened():
            print(f"Error: Cannot open video: {image_source}")
            return
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Output is twice as wide: both renderings next to each other.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w * 2, h))

    print(f"Creating side-by-side comparison video...")
    frames_to_process = min(len(data1), len(data2), max_frames)
    if image_files:
        # Never index past the available images.
        frames_to_process = min(frames_to_process, len(image_files))

    try:
        for frame_idx in tqdm(range(frames_to_process), desc="Processing"):
            if cap is None:
                img = cv2.imread(str(image_files[frame_idx]))
                if img is None:
                    print(f"Warning: Cannot read image {image_files[frame_idx]}")
                    continue
            else:
                ret, img = cap.read()
                if not ret:
                    break
            if img.shape[:2] != (h, w):
                # Keep every half the size the writer was opened with.
                img = cv2.resize(img, (w, h))

            img1 = _draw_tracks_for_comparison(img.copy(), data1[frame_idx], color_map1, class_names)
            img2 = _draw_tracks_for_comparison(img.copy(), data2[frame_idx], color_map2, class_names)

            cv2.putText(img1, labels[0], (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.putText(img2, labels[1], (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            video_writer.write(np.hstack([img1, img2]))
    finally:
        if cap is not None:
            cap.release()
        video_writer.release()

    print(f"✓ Comparison video saved to: {output_path}")
|
||
|
||
|
||
def main():
    """Parse command-line arguments and run the requested visualization.

    Returns:
        int | None: 1 when an input path is missing, so the process exits
        non-zero; otherwise None.
    """
    parser = argparse.ArgumentParser(description='Visualize tracking results on images')
    parser.add_argument('--tracking', type=str, required=True, help='Path to tracking.json file')
    parser.add_argument('--images', type=str, required=True, help='Path to images folder or video file')
    parser.add_argument('--output', type=str, required=True, help='Output path (folder or video file)')
    parser.add_argument('--save-video', action='store_true', help='Save as video instead of images')
    parser.add_argument('--fps', type=int, default=25, help='Video FPS (default: 25)')
    parser.add_argument('--max-frames', type=int, help='Maximum number of frames to process')
    parser.add_argument('--show-trajectory', action='store_true', help='Show trajectory lines')
    parser.add_argument('--compare', type=str, help='Second tracking.json for side-by-side comparison')
    parser.add_argument('--labels', type=str, nargs=2, default=['Method 1', 'Method 2'],
                        help='Labels for comparison mode')
    parser.add_argument('--class-id', type=int, default=None,
                        help='Filter by class ID (0=Car, 1=Pedestrian, 2=Tricycle, 3=Cyclist). If not specified, all classes are shown.')
    parser.add_argument('--track-ids', type=int, nargs='+', default=None,
                        help='Specific track IDs to visualize (e.g., --track-ids 1 2 3). If not specified, all tracks are shown.')
    parser.add_argument('--align-by-frame-id', action='store_true',
                        help='Align tracking frames to the raw input video by frameId and export images only.')

    args = parser.parse_args()

    # Validate inputs.
    tracking_path = Path(args.tracking)
    if not tracking_path.exists():
        print(f"Error: Tracking file not found: {tracking_path}")
        return 1

    image_source = Path(args.images)
    if not image_source.exists():
        print(f"Error: Image source not found: {image_source}")
        return 1

    # Load tracking data.
    print(f"Loading tracking data from {tracking_path}")
    with open(tracking_path, 'r', encoding='utf-8') as f:
        tracking_data = json.load(f)

    print(f"Loaded {len(tracking_data)} frames")

    if args.compare:
        # Side-by-side comparison mode.
        compare_path = Path(args.compare)
        if not compare_path.exists():
            print(f"Error: Comparison file not found: {compare_path}")
            return 1

        output_path = Path(args.output)
        if not str(output_path).endswith('.mp4'):
            output_path = output_path.with_suffix('.mp4')

        create_side_by_side_comparison(
            tracking_path, compare_path, image_source,
            output_path, args.labels, args.max_frames or 100
        )
    else:
        # Single-file visualization.
        track_ids_set = set(args.track_ids) if args.track_ids else None

        if args.class_id is not None or args.track_ids:
            print(f"\nFiltering settings:")
            if args.class_id is not None:
                print(f" Class ID: {args.class_id}")
            if args.track_ids:
                print(f" Track IDs: {args.track_ids}")

        visualize_tracking_on_images(
            tracking_data,
            image_source,
            args.output,
            save_video=args.save_video,
            fps=args.fps,
            max_frames=args.max_frames,
            show_trajectory=args.show_trajectory,
            class_id_filter=args.class_id,
            track_ids_filter=track_ids_set,
            align_by_frame_id=args.align_by_frame_id,
        )
    return None


if __name__ == '__main__':
    # Propagate main()'s status: None exits 0, 1 signals a missing input.
    sys.exit(main())
|