Files
yolov26_3d/tools/scripts_for_gt/visualization/visualize_batch.py

210 lines
7.6 KiB
Python
Raw Normal View History

2026-06-24 09:35:46 +08:00
"""
批量可视化脚本
用途:批量读取多帧图像和对应的真值标签文件,生成可视化结果
使用方法:
python scripts_for_gt/visualization/visualize_batch.py \
--image-dir /path/to/images \
--label-dir /path/to/labels \
--output output_dir \
--roi 704 352 \
--virtual-fx 500 \
--max-samples 100
"""
import argparse
from pathlib import Path
import sys
# Append a parent directory of this file to sys.path so helper modules can be
# imported when the script is run directly.
# NOTE(review): FILE.parents[1] is two directory levels above this file; for
# tools/scripts_for_gt/visualization/visualize_batch.py that is scripts_for_gt/,
# not necessarily the project root -- confirm which directory is intended.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]
# Only append once, so re-imports do not keep growing sys.path.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from visualize_single_frame import (
DEFAULT_CLASS_NAMES,
format_target_classes,
resolve_target_class_ids,
visualize_single_frame,
)
def resolve_label_path(label_dir, image_stem, preferred_extension=None):
    """Find the label file that shares *image_stem* inside *label_dir*.

    The search order is: *preferred_extension* (if given; a missing leading
    dot is added), then ``.txt``, then ``.json``, skipping duplicates.

    Args:
        label_dir: Directory to search; joined with ``/`` so it must be a
            ``pathlib.Path``.
        image_stem: File name without extension (normally the image's stem).
        preferred_extension: Optional extension tried before the defaults.

    Returns:
        The first candidate ``Path`` that exists, or ``None``.
    """
    search_order = []
    if preferred_extension:
        if not preferred_extension.startswith("."):
            preferred_extension = "." + preferred_extension
        search_order.append(preferred_extension)
    search_order += [ext for ext in (".txt", ".json") if ext not in search_order]

    candidates = (label_dir / f"{image_stem}{ext}" for ext in search_order)
    return next((path for path in candidates if path.exists()), None)
def _label_search_extensions(label_extension):
    """Return the label extensions ``resolve_label_path`` tries, in its order.

    Mirrors that function's candidate list (preferred extension normalized to
    start with ".", then ".txt" and ".json", without duplicates) so the
    "label not found" warning reports what was actually searched.
    """
    extensions = []
    if label_extension:
        extensions.append(label_extension if label_extension.startswith(".") else f".{label_extension}")
    extensions += [ext for ext in (".txt", ".json") if ext not in extensions]
    return extensions


def _resolve_calib_path(calib_dir, image_dir):
    """Locate ``camera4.json`` for the whole batch, or return None.

    Under the calibration root, ``L2_calib/camera4.json`` is tried first and
    ``camera4.json`` second. The root is *calib_dir* when given, otherwise
    ``calib/`` beside *image_dir* (``image_dir.parent / "calib"``).
    """
    calib_root = Path(calib_dir) if calib_dir is not None else Path(image_dir).parent / "calib"
    for candidate in (calib_root / "L2_calib" / "camera4.json", calib_root / "camera4.json"):
        if candidate.exists():
            return candidate
    return None


def visualize_batch(
    image_dir,
    label_dir,
    output_dir,
    calib_dir=None,
    roi_size=None,
    virtual_fx=None,
    ori_img_size=None,
    max_samples=None,
    sample_interval=1,
    file_extension=".jpg",
    label_extension=None,
    target_class_ids=None,
):
    """Visualize ground-truth labels for a batch of images.

    Images matching ``*<file_extension>`` directly inside *image_dir* are
    sorted, sub-sampled every *sample_interval* frames, truncated to
    *max_samples*, and rendered one by one with ``visualize_single_frame``
    (only the combined image is saved). A frame with no label file, or whose
    rendering raises, is reported and counted as a failure; the batch continues.

    Args:
        image_dir: Directory containing the images.
        label_dir: Directory containing same-stem label files (.txt / .json).
        output_dir: Output directory passed to ``visualize_single_frame``.
        calib_dir: Calibration directory containing ``L2_calib/camera4.json``
            or ``camera4.json``. When None, ``<image_dir>/../calib`` is tried.
            If no calibration file exists, a warning is printed and
            ``calib_path=None`` is passed (2D-only visualization).
        roi_size: ROI size ``(width, height)``.
        virtual_fx: Virtual focal length.
        ori_img_size: Original image size ``(width, height)``.
        max_samples: Maximum number of frames to process; None means no limit.
        sample_interval: Process one frame every N frames; values <= 1 mean
            every frame.
        file_extension: Image file extension, e.g. ".jpg".
        label_extension: Preferred label extension; ".txt" and ".json" are
            still tried after it.
        target_class_ids: Set of class IDs to draw; None disables filtering.

    Returns:
        ``(success_count, error_count)``.
    """
    image_dir = Path(image_dir)
    label_dir = Path(label_dir)
    output_dir = Path(output_dir)

    image_files = sorted(image_dir.glob(f"*{file_extension}"))
    if sample_interval > 1:
        image_files = image_files[::sample_interval]
    if max_samples is not None:
        image_files = image_files[:max_samples]
    print(f"找到 {len(image_files)} 个图像文件(采样间隔={sample_interval})")

    names = DEFAULT_CLASS_NAMES.copy()
    if target_class_ids is not None:
        print(f"类别过滤: {format_target_classes(target_class_ids, names=names)}")

    # All images come from the same directory (non-recursive glob), so the
    # calibration file and the searched label extensions are identical for
    # every frame: resolve them once instead of per image. The existence check
    # applies to an explicit calib_dir too, so a missing file is never passed on.
    calib_path = None
    if image_files:
        calib_path = _resolve_calib_path(calib_dir, image_dir)
        if calib_path is None:
            print("警告:未找到标定文件,将仅进行2D可视化")
    calib_arg = str(calib_path) if calib_path is not None else None
    searched_extensions = _label_search_extensions(label_extension)

    success_count = 0
    error_count = 0
    total = len(image_files)
    for index, image_path in enumerate(image_files, start=1):
        print(f"\n[{index}/{total}] 处理: {image_path.name}")

        label_path = resolve_label_path(label_dir, image_path.stem, preferred_extension=label_extension)
        if label_path is None:
            print(f" 警告:未找到标签文件 {image_path.stem},已检查扩展名: {searched_extensions}")
            error_count += 1
            continue

        try:
            visualize_single_frame(
                image_path=str(image_path),
                label_path=str(label_path),
                output_dir=str(output_dir),
                calib_path=calib_arg,
                roi_size=roi_size,
                virtual_fx=virtual_fx,
                ori_img_size=ori_img_size,
                names=names,
                save_types='combined',  # only save the combined image
                target_class_ids=target_class_ids,
            )
        except Exception as exc:  # best-effort: one bad frame must not abort the batch
            # Include the type: str() of e.g. KeyError/AssertionError is often cryptic or empty.
            print(f" 错误:{type(exc).__name__}: {exc}")
            error_count += 1
        else:
            success_count += 1

    print("\n处理完成!")
    print(f" 成功: {success_count}")
    print(f" 失败: {error_count}")
    print(f" 输出目录: {output_dir}")
    return success_count, error_count
def _positive_int(text):
    """argparse ``type`` callback: parse *text* as an integer >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"需要正整数,收到: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"需要正整数,收到: {value}")
    return value


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse; None (the default) means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace``. ``--roi`` and ``--ori-img-size`` are two-int
        lists or None. ``--max-samples`` and ``--sample-interval`` must be
        integers >= 1; otherwise argparse exits with an error, instead of
        silently processing all frames or dropping trailing frames.
    """
    parser = argparse.ArgumentParser(description="批量可视化真值标签")
    parser.add_argument("--image-dir", type=str, required=True, help="图像目录")
    parser.add_argument("--label-dir", type=str, required=True, help="标签目录")
    parser.add_argument("--output", type=str, default="./gt_visualization_batch", help="输出目录")
    parser.add_argument("--calib-dir", type=str, default=None,
                        help="标定文件目录(包含L2_calib/camera4.json)")
    parser.add_argument("--roi", type=int, nargs=2, default=None, metavar=("WIDTH", "HEIGHT"),
                        help="ROI尺寸(宽 高),如: --roi 704 352")
    parser.add_argument("--virtual-fx", type=float, default=None,
                        help="虚拟焦距,用于深度归一化")
    parser.add_argument("--ori-img-size", type=int, nargs=2, default=None, metavar=("WIDTH", "HEIGHT"),
                        help="原始图像尺寸(宽 高)")
    parser.add_argument("--max-samples", type=_positive_int, default=None,
                        help="最大处理样本数")
    parser.add_argument("--sample-interval", type=_positive_int, default=1,
                        help="采样间隔,每隔 N 帧处理一帧(默认1,即处理全部)")
    parser.add_argument("--file-extension", type=str, default=".jpg",
                        help="图像文件扩展名")
    parser.add_argument("--label-extension", type=str, default=None,
                        help="标签文件扩展名,默认自动在 .txt/.json 中查找")
    parser.add_argument("--classes", type=str, nargs="+", default=None,
                        help="仅可视化指定类别,支持类别名/类别ID,可传多个值或逗号分隔")
    return parser.parse_args(argv)
def main():
    """Command-line entry point: parse arguments and run ``visualize_batch``."""
    def to_pair(values):
        # Convert argparse's nargs=2 list into a tuple; None when the option was omitted.
        return tuple(values) if values else None

    cli = parse_args()
    class_filter = resolve_target_class_ids(cli.classes, names=DEFAULT_CLASS_NAMES)
    batch_kwargs = dict(
        image_dir=cli.image_dir,
        label_dir=cli.label_dir,
        output_dir=cli.output,
        calib_dir=cli.calib_dir,
        roi_size=to_pair(cli.roi),
        virtual_fx=cli.virtual_fx,
        ori_img_size=to_pair(cli.ori_img_size),
        max_samples=cli.max_samples,
        sample_interval=cli.sample_interval,
        file_extension=cli.file_extension,
        label_extension=cli.label_extension,
        target_class_ids=class_filter,
    )
    visualize_batch(**batch_kwargs)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()