Files
yolov26_3d/tools/scripts_for_gt/visualization/visualize_batch.py
2026-06-24 09:35:46 +08:00

210 lines
7.6 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
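Batch visualization script.

Purpose: read multiple image frames together with their matching
ground-truth label files and generate a visualization for each frame.

Usage:
    python scripts_for_gt/visualization/visualize_batch.py \
        --image-dir /path/to/images \
        --label-dir /path/to/labels \
        --output output_dir \
        --roi 704 352 \
        --virtual-fx 500 \
        --max-samples 100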
"""
import argparse
from pathlib import Path
import sys
# Make the directory one level above this script's own folder importable:
# FILE.parents[0] is the folder containing this file, FILE.parents[1] is its
# parent. The path is appended only if it is not already on sys.path.
# NOTE(review): the original comment called this the "project root" - confirm
# that parents[1] is the intended directory for this file's current location.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from visualize_single_frame import (
DEFAULT_CLASS_NAMES,
format_target_classes,
resolve_target_class_ids,
visualize_single_frame,
)
def resolve_label_path(label_dir, image_stem, preferred_extension=None):
    """Find the label file that shares its stem with an image.

    The preferred extension (when given) is tried first, followed by the
    built-in fallbacks ``.txt`` and ``.json``. The first candidate that
    exists as a regular file is returned.

    Args:
        label_dir: Directory holding the label files (``str`` or ``Path``).
        image_stem: Image file name without its extension.
        preferred_extension: Optional extension to try first, with or
            without the leading dot.

    Returns:
        ``Path`` of the first matching label file, or ``None`` when no
        candidate exists.
    """
    # Accept plain strings as well as Path objects.
    label_dir = Path(label_dir)

    ordered_extensions = []
    if preferred_extension:
        if preferred_extension.startswith("."):
            ordered_extensions.append(preferred_extension)
        else:
            ordered_extensions.append(f".{preferred_extension}")
    ordered_extensions.extend((".txt", ".json"))

    # dict.fromkeys drops duplicates while keeping the priority order.
    for extension in dict.fromkeys(ordered_extensions):
        candidate_path = label_dir / f"{image_stem}{extension}"
        # is_file() rather than exists(): a directory that happens to be
        # named like a label file must not be reported as a label.
        if candidate_path.is_file():
            return candidate_path
    return None
def _find_calib_file(calib_root):
    """Return the ``camera4.json`` found under *calib_root*, or ``None``.

    Two directory layouts are accepted, checked in this order:
    ``<calib_root>/L2_calib/camera4.json`` and ``<calib_root>/camera4.json``.
    """
    candidates = (
        calib_root / "L2_calib" / "camera4.json",
        calib_root / "camera4.json",
    )
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def visualize_batch(
    image_dir,
    label_dir,
    output_dir,
    calib_dir=None,
    roi_size=None,
    virtual_fx=None,
    ori_img_size=None,
    max_samples=None,
    sample_interval=1,
    file_extension=".jpg",
    label_extension=None,
    target_class_ids=None,
):
    """Visualize a directory of images together with their label files.

    Images are collected from ``image_dir`` in sorted order, optionally
    sub-sampled and truncated, and each one is passed to
    ``visualize_single_frame`` along with its label file and, when one is
    found, a calibration file. A frame whose label file is missing, or
    whose visualization raises, is counted as a failure and skipped; the
    batch keeps going. A summary is printed at the end.

    Args:
        image_dir: Directory containing the images.
        label_dir: Directory containing the label files.
        output_dir: Directory handed to ``visualize_single_frame`` for output.
        calib_dir: Calibration directory holding ``L2_calib/camera4.json``
            or ``camera4.json``. When ``None``, ``<image_dir>/../calib`` is
            probed per image instead.
        roi_size: ROI size as ``(width, height)``.
        virtual_fx: Virtual focal length.
        ori_img_size: Original image size as ``(width, height)``.
        max_samples: Maximum number of images to process (applied after
            sampling); ``None`` means no limit.
        sample_interval: Keep every N-th image; values <= 1 keep all.
        file_extension: Image file suffix used in the glob pattern.
        label_extension: Label extension to try first; ``.txt`` and
            ``.json`` are always tried as fallbacks.
        target_class_ids: Class ids to keep; ``None`` disables filtering.
    """
    image_dir = Path(image_dir)
    label_dir = Path(label_dir)
    output_dir = Path(output_dir)

    # Collect images, then apply sub-sampling before the sample cap.
    image_files = sorted(image_dir.glob(f"*{file_extension}"))
    if sample_interval > 1:
        image_files = image_files[::sample_interval]
    if max_samples is not None:
        image_files = image_files[:max_samples]
    print(f"找到 {len(image_files)} 个图像文件(采样间隔={sample_interval})")

    names = DEFAULT_CLASS_NAMES.copy()
    if target_class_ids is not None:
        print(f"类别过滤: {format_target_classes(target_class_ids, names=names)}")

    # Extensions that resolve_label_path actually probes: the preferred one
    # first, then the .txt/.json fallbacks. Reported in the warning so the
    # message reflects every candidate that was checked.
    checked_extensions = []
    if label_extension:
        checked_extensions.append(
            label_extension if label_extension.startswith(".") else f".{label_extension}"
        )
    for fallback in (".txt", ".json"):
        if fallback not in checked_extensions:
            checked_extensions.append(fallback)

    # An explicit calibration directory is the same for every frame, so it
    # is resolved once instead of hitting the filesystem per image.
    explicit_calib_path = None
    if calib_dir is not None:
        explicit_calib_path = _find_calib_file(Path(calib_dir))

    success_count = 0
    error_count = 0
    total = len(image_files)

    for index, image_path in enumerate(image_files, start=1):
        print(f"\n[{index}/{total}] 处理: {image_path.name}")

        label_path = resolve_label_path(
            label_dir, image_path.stem, preferred_extension=label_extension
        )
        if label_path is None:
            print(f" 警告:未找到标签文件 {image_path.stem},已检查扩展名: {checked_extensions}")
            error_count += 1
            continue

        if calib_dir is not None:
            calib_path = explicit_calib_path
        else:
            # No directory given: look for a "calib" folder next to the
            # image folder.
            calib_path = _find_calib_file(image_path.parent.parent / "calib")
        if calib_path is None:
            print(" 警告未找到标定文件将仅进行2D可视化")

        try:
            visualize_single_frame(
                image_path=str(image_path),
                label_path=str(label_path),
                output_dir=str(output_dir),
                calib_path=str(calib_path) if calib_path else None,
                roi_size=roi_size,
                virtual_fx=virtual_fx,
                ori_img_size=ori_img_size,
                names=names,
                save_types='combined',  # only the combined image is saved
                target_class_ids=target_class_ids,
            )
        except Exception as e:
            # Per-frame boundary: one bad frame must not abort the batch.
            print(f" 错误:{e}")
            error_count += 1
        else:
            success_count += 1

    print("\n处理完成!")
    print(f" 成功: {success_count}")
    print(f" 失败: {error_count}")
    print(f" 输出目录: {output_dir}")
def _int_at_least(minimum):
    """Build an argparse ``type`` callable that accepts integers >= *minimum*."""

    def _parse(text):
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"无效的整数: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"取值必须 >= {minimum}, 实际为 {value}")
        return value

    return _parse


def parse_args(argv=None):
    """Parse the command-line options of the batch visualizer.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is what a call
            without arguments has always done.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse on invalid input, including a
            ``--sample-interval`` below 1 or a negative ``--max-samples``.
    """
    parser = argparse.ArgumentParser(description="批量可视化真值标签")
    parser.add_argument("--image-dir", type=str, required=True, help="图像目录")
    parser.add_argument("--label-dir", type=str, required=True, help="标签目录")
    parser.add_argument("--output", type=str, default="./gt_visualization_batch", help="输出目录")
    parser.add_argument("--calib-dir", type=str, default=None,
                        help="标定文件目录包含L2_calib/camera4.json")
    parser.add_argument("--roi", type=int, nargs=2, default=None, metavar=("WIDTH", "HEIGHT"),
                        help="ROI尺寸(宽 高),如: --roi 704 352")
    parser.add_argument("--virtual-fx", type=float, default=None,
                        help="虚拟焦距,用于深度归一化")
    parser.add_argument("--ori-img-size", type=int, nargs=2, default=None, metavar=("WIDTH", "HEIGHT"),
                        help="原始图像尺寸(宽 高)")
    # A negative cap would silently slice frames off the END of the list,
    # so reject it at the CLI boundary.
    parser.add_argument("--max-samples", type=_int_at_least(0), default=None,
                        help="最大处理样本数")
    # Intervals below 1 were silently treated as "process everything";
    # reject them so a typo is reported instead of ignored.
    parser.add_argument("--sample-interval", type=_int_at_least(1), default=1,
                        help="采样间隔,每隔 N 帧处理一帧默认1即处理全部")
    parser.add_argument("--file-extension", type=str, default=".jpg",
                        help="图像文件扩展名")
    parser.add_argument("--label-extension", type=str, default=None,
                        help="标签文件扩展名,默认自动在 .txt/.json 中查找")
    parser.add_argument("--classes", type=str, nargs="+", default=None,
                        help="仅可视化指定类别,支持类别名/类别ID可传多个值或逗号分隔")
    return parser.parse_args(argv)
def main():
    """Command-line entry point: parse the options and run the batch job."""
    cli = parse_args()

    # Turn the --classes tokens into class ids before anything else runs.
    class_filter = resolve_target_class_ids(cli.classes, names=DEFAULT_CLASS_NAMES)

    # The two size options are converted from their parsed form to tuples;
    # an omitted option stays None.
    if cli.roi:
        roi = tuple(cli.roi)
    else:
        roi = None
    if cli.ori_img_size:
        original_size = tuple(cli.ori_img_size)
    else:
        original_size = None

    batch_options = {
        "image_dir": cli.image_dir,
        "label_dir": cli.label_dir,
        "output_dir": cli.output,
        "calib_dir": cli.calib_dir,
        "roi_size": roi,
        "virtual_fx": cli.virtual_fx,
        "ori_img_size": original_size,
        "max_samples": cli.max_samples,
        "sample_interval": cli.sample_interval,
        "file_extension": cli.file_extension,
        "label_extension": cli.label_extension,
        "target_class_ids": class_filter,
    }
    visualize_batch(**batch_options)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()