"""
批量可视化脚本

用途:批量读取多帧图像和对应的真值标签文件,生成可视化结果

使用方法:
    python scripts_for_gt/visualize_batch.py \
        --image-dir /path/to/images \
        --label-dir /path/to/labels \
        --output output_dir \
        --roi 704 352 \
        --virtual-fx 500 \
        --max-samples 100
"""
|
||
|
||
import argparse
|
||
from pathlib import Path
|
||
import sys
|
||
|
||
# Add the project root (the parent of the directory containing this file) to
# the import path, appending it only once even if this module is re-imported.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.append(str(ROOT))
|
||
|
||
from visualize_single_frame import (
|
||
DEFAULT_CLASS_NAMES,
|
||
format_target_classes,
|
||
resolve_target_class_ids,
|
||
visualize_single_frame,
|
||
)
|
||
|
||
|
||
def resolve_label_path(label_dir, image_stem, preferred_extension=None):
    """Find the label file that shares ``image_stem`` inside ``label_dir``.

    Candidates are ``<label_dir>/<image_stem><ext>``, tried in this order:
    ``preferred_extension`` first (if given; a leading dot is added when it
    is missing), then ``.txt`` and ``.json``, without duplicates.

    Args:
        label_dir: Directory holding the label files (``str`` or ``Path``).
        image_stem: Image file name without its extension.
        preferred_extension: Extension to try first, e.g. ``".json"`` or ``"json"``.

    Returns:
        The first candidate that is an existing regular file, or ``None``
        if no candidate matches.
    """
    # Accept plain strings as well as Path objects; ``str / str`` would raise.
    label_dir = Path(label_dir)

    candidate_extensions = []
    if preferred_extension:
        candidate_extensions.append(
            preferred_extension if preferred_extension.startswith(".") else f".{preferred_extension}"
        )
    for extension in (".txt", ".json"):
        if extension not in candidate_extensions:
            candidate_extensions.append(extension)

    for extension in candidate_extensions:
        candidate_path = label_dir / f"{image_stem}{extension}"
        # is_file() rather than exists(): a directory with a label-like name is not a label.
        if candidate_path.is_file():
            return candidate_path

    return None
|
||
|
||
|
||
def _checked_label_extensions(label_extension):
    """Return, in lookup order, the label extensions searched for one image.

    Same order as the label lookup: the (dot-normalised) preferred extension
    first, then ``.txt`` and ``.json`` without duplicates.
    """
    extensions = []
    if label_extension:
        extensions.append(label_extension if label_extension.startswith(".") else f".{label_extension}")
    for extension in (".txt", ".json"):
        if extension not in extensions:
            extensions.append(extension)
    return extensions


def _resolve_calib_path(calib_dir, image_dir):
    """Locate ``camera4.json`` for the batch, or return ``None`` if absent.

    The calibration root is ``calib_dir`` when given, otherwise the sibling
    ``calib`` directory of ``image_dir``. Under that root,
    ``L2_calib/camera4.json`` is preferred over ``camera4.json``.
    """
    calib_root = Path(calib_dir) if calib_dir is not None else image_dir.parent / "calib"
    for candidate in (calib_root / "L2_calib" / "camera4.json", calib_root / "camera4.json"):
        if candidate.is_file():
            return candidate
    return None


def visualize_batch(
    image_dir,
    label_dir,
    output_dir,
    calib_dir=None,
    roi_size=None,
    virtual_fx=None,
    ori_img_size=None,
    max_samples=None,
    sample_interval=1,
    file_extension=".jpg",
    label_extension=None,
    target_class_ids=None,
):
    """Visualize ground-truth labels for a batch of images.

    Images matching ``*<file_extension>`` directly under ``image_dir`` are
    sorted by path, thinned by ``sample_interval`` and truncated to
    ``max_samples``. For each image, the same-stem label file is looked up
    and ``visualize_single_frame`` writes the combined visualization into
    ``output_dir``. Frames with no label file, or whose visualization raises,
    are counted as failures and skipped.

    Args:
        image_dir: Image directory (not searched recursively).
        label_dir: Label directory.
        output_dir: Output directory passed to ``visualize_single_frame``.
        calib_dir: Calibration directory containing ``L2_calib/camera4.json``
            or ``camera4.json``. When ``None``, ``<image_dir>/../calib`` is
            searched with the same two layouts. If no calibration file is
            found, only 2D visualization is performed.
        roi_size: ROI size ``(width, height)``.
        virtual_fx: Virtual focal length.
        ori_img_size: Original image size ``(width, height)``.
        max_samples: Maximum number of frames to process (applied after sampling).
        sample_interval: Process one frame every N frames; values <= 1 process every frame.
        file_extension: Image file extension, e.g. ``".jpg"``.
        label_extension: Label extension to try first; ``.txt`` and ``.json``
            are always tried as fallbacks.
        target_class_ids: Collection of class IDs to keep; ``None`` disables filtering.
    """
    image_dir = Path(image_dir)
    label_dir = Path(label_dir)
    output_dir = Path(output_dir)

    # Sorted so frame order, and therefore sampling, is deterministic.
    image_files = sorted(image_dir.glob(f"*{file_extension}"))
    if sample_interval > 1:
        image_files = image_files[::sample_interval]
    if max_samples is not None:
        image_files = image_files[:max_samples]

    print(f"找到 {len(image_files)} 个图像文件(采样间隔={sample_interval})")

    names = DEFAULT_CLASS_NAMES.copy()
    if target_class_ids is not None:
        print(f"类别过滤: {format_target_classes(target_class_ids, names=names)}")

    # The glob is non-recursive, so every image shares the same parent
    # directory. The calibration lookup is therefore loop-invariant: resolve
    # it (and warn about it) once instead of once per frame.
    calib_path = _resolve_calib_path(calib_dir, image_dir)
    if calib_path is None:
        print(f" 警告:未找到标定文件,将仅进行2D可视化")
    else:
        print(f"标定文件: {calib_path}")

    # Report the extensions actually searched, including the .txt/.json fallbacks.
    checked_extensions = _checked_label_extensions(label_extension)

    success_count = 0
    error_count = 0
    total = len(image_files)

    for index, image_path in enumerate(image_files, start=1):
        print(f"\n[{index}/{total}] 处理: {image_path.name}")

        label_path = resolve_label_path(label_dir, image_path.stem, preferred_extension=label_extension)
        if label_path is None:
            print(f" 警告:未找到标签文件 {image_path.stem},已检查扩展名: {checked_extensions}")
            error_count += 1
            continue

        try:
            visualize_single_frame(
                image_path=str(image_path),
                label_path=str(label_path),
                output_dir=str(output_dir),
                calib_path=str(calib_path) if calib_path else None,
                roi_size=roi_size,
                virtual_fx=virtual_fx,
                ori_img_size=ori_img_size,
                names=names,
                save_types='combined',  # only save the combined image
                target_class_ids=target_class_ids,
            )
        except Exception as e:
            # Best-effort batch: report this frame's failure and continue with the next one.
            print(f" 错误:{type(e).__name__}: {e}")
            error_count += 1
        else:
            success_count += 1

    print(f"\n处理完成!")
    print(f" 成功: {success_count}")
    print(f" 失败: {error_count}")
    print(f" 输出目录: {output_dir}")
|
||
|
||
|
||
def _int_at_least(minimum):
    """Build an argparse ``type`` callable that accepts integers >= ``minimum``."""

    def _parse(text):
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"需要整数,实际为 {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"需要 >= {minimum} 的整数,实际为 {value}")
        return value

    return _parse


def parse_args():
    """Parse command-line arguments for the batch GT visualizer.

    ``--max-samples`` must be >= 0 and ``--sample-interval`` must be >= 1.
    A negative max-samples value would otherwise silently drop frames from
    the end (negative slice). A non-positive interval would silently behave
    like 1.
    """
    parser = argparse.ArgumentParser(description="批量可视化真值标签")
    parser.add_argument("--image-dir", type=str, required=True, help="图像目录")
    parser.add_argument("--label-dir", type=str, required=True, help="标签目录")
    parser.add_argument("--output", type=str, default="./gt_visualization_batch", help="输出目录")
    parser.add_argument("--calib-dir", type=str, default=None,
                        help="标定文件目录(包含L2_calib/camera4.json)")
    parser.add_argument("--roi", type=int, nargs=2, default=None, metavar=("WIDTH", "HEIGHT"),
                        help="ROI尺寸(宽 高),如: --roi 704 352")
    parser.add_argument("--virtual-fx", type=float, default=None,
                        help="虚拟焦距,用于深度归一化")
    parser.add_argument("--ori-img-size", type=int, nargs=2, default=None, metavar=("WIDTH", "HEIGHT"),
                        help="原始图像尺寸(宽 高)")
    parser.add_argument("--max-samples", type=_int_at_least(0), default=None,
                        help="最大处理样本数")
    parser.add_argument("--sample-interval", type=_int_at_least(1), default=1,
                        help="采样间隔,每隔 N 帧处理一帧(默认1,即处理全部)")
    parser.add_argument("--file-extension", type=str, default=".jpg",
                        help="图像文件扩展名")
    parser.add_argument("--label-extension", type=str, default=None,
                        help="标签文件扩展名,默认自动在 .txt/.json 中查找")
    parser.add_argument("--classes", type=str, nargs="+", default=None,
                        help="仅可视化指定类别,支持类别名/类别ID,可传多个值或逗号分隔")

    return parser.parse_args()
|
||
|
||
|
||
def main():
    """CLI entry point: parse arguments, resolve the class filter, run the batch."""
    cli_args = parse_args()
    class_filter = resolve_target_class_ids(cli_args.classes, names=DEFAULT_CLASS_NAMES)

    def as_optional_tuple(values):
        # nargs=2 options arrive as a list, or None when the flag is omitted.
        return tuple(values) if values else None

    visualize_batch(
        image_dir=cli_args.image_dir,
        label_dir=cli_args.label_dir,
        output_dir=cli_args.output,
        calib_dir=cli_args.calib_dir,
        roi_size=as_optional_tuple(cli_args.roi),
        virtual_fx=cli_args.virtual_fx,
        ori_img_size=as_optional_tuple(cli_args.ori_img_size),
        max_samples=cli_args.max_samples,
        sample_interval=cli_args.sample_interval,
        file_extension=cli_args.file_extension,
        label_extension=cli_args.label_extension,
        target_class_ids=class_filter,
    )


if __name__ == "__main__":
    main()
|