""" 批量可视化脚本 用途:批量读取多帧图像和对应的真值标签文件,生成可视化结果 使用方法: python scripts_for_gt/visualize_batch.py \ --image-dir /path/to/images \ --label-dir /path/to/labels \ --output output_dir \ --roi 704 352 \ --virtual-fx 500 \ --max-samples 100 """ import argparse from pathlib import Path import sys # Add project root to path FILE = Path(__file__).resolve() ROOT = FILE.parents[1] if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) from visualize_single_frame import ( DEFAULT_CLASS_NAMES, format_target_classes, resolve_target_class_ids, visualize_single_frame, ) def resolve_label_path(label_dir, image_stem, preferred_extension=None): """按同名文件自动查找标签,支持 .txt / .json。""" candidate_extensions = [] if preferred_extension: normalized_extension = preferred_extension if preferred_extension.startswith(".") else f".{preferred_extension}" candidate_extensions.append(normalized_extension) for extension in (".txt", ".json"): if extension not in candidate_extensions: candidate_extensions.append(extension) for extension in candidate_extensions: candidate_path = label_dir / f"{image_stem}{extension}" if candidate_path.exists(): return candidate_path return None def visualize_batch( image_dir, label_dir, output_dir, calib_dir=None, roi_size=None, virtual_fx=None, ori_img_size=None, max_samples=None, sample_interval=1, file_extension=".jpg", label_extension=None, target_class_ids=None, ): """批量可视化图像 Args: image_dir: 图像目录 label_dir: 标签目录 output_dir: 输出目录 calib_dir: 标定文件目录(包含L2_calib/camera4.json) roi_size: ROI尺寸 (width, height) virtual_fx: 虚拟焦距 ori_img_size: 原始图像尺寸 (width, height) max_samples: 最大处理样本数 sample_interval: 采样间隔,每隔 N 帧处理一帧 file_extension: 图像文件扩展名 label_extension: 标签文件扩展名,默认自动在 .txt/.json 中查找 target_class_ids: 目标类别ID集合,None表示不过滤 """ image_dir = Path(image_dir) label_dir = Path(label_dir) output_dir = Path(output_dir) # 获取所有图像文件 image_files = sorted(image_dir.glob(f"*{file_extension}")) if sample_interval > 1: image_files = image_files[::sample_interval] if max_samples is not None: image_files = image_files[:max_samples] print(f"找到 {len(image_files)} 个图像文件(采样间隔={sample_interval})") # 类别名称 names = DEFAULT_CLASS_NAMES.copy() if target_class_ids is not None: print(f"类别过滤: {format_target_classes(target_class_ids, names=names)}") success_count = 0 error_count = 0 for i, image_path in enumerate(image_files): print(f"\n[{i+1}/{len(image_files)}] 处理: {image_path.name}") # 查找对应的标签文件 label_path = resolve_label_path(label_dir, image_path.stem, preferred_extension=label_extension) if label_path is None: expected_extensions = [label_extension] if label_extension else [".txt", ".json"] print(f" 警告:未找到标签文件 {image_path.stem},已检查扩展名: {expected_extensions}") error_count += 1 continue # 查找对应的标定文件 if calib_dir is not None: # 兼容 L2_calib 子目录和直接放在 calib/ 下两种结构 calib_path = Path(calib_dir) / "L2_calib" / "camera4.json" if not calib_path.exists(): calib_path = Path(calib_dir) / "camera4.json" else: # 自动推断:依次尝试两种常见目录结构 base = image_path.parent.parent calib_path = base / "calib" / "L2_calib" / "camera4.json" if not calib_path.exists(): calib_path = base / "calib" / "camera4.json" if not calib_path.exists(): calib_path = None print(f" 警告:未找到标定文件,将仅进行2D可视化") try: # 可视化单帧 sample_output_dir = output_dir visualize_single_frame( image_path=str(image_path), label_path=str(label_path), output_dir=str(sample_output_dir), calib_path=str(calib_path) if calib_path else None, roi_size=roi_size, virtual_fx=virtual_fx, ori_img_size=ori_img_size, names=names, save_types='combined', # 只保存合成图 target_class_ids=target_class_ids, ) success_count += 1 except Exception as e: print(f" 错误:{e}") error_count += 1 print(f"\n处理完成!") print(f" 成功: {success_count}") print(f" 失败: {error_count}") print(f" 输出目录: {output_dir}") def parse_args(): parser = argparse.ArgumentParser(description="批量可视化真值标签") parser.add_argument("--image-dir", type=str, required=True, help="图像目录") parser.add_argument("--label-dir", type=str, required=True, help="标签目录") parser.add_argument("--output", type=str, default="./gt_visualization_batch", help="输出目录") parser.add_argument("--calib-dir", type=str, default=None, help="标定文件目录(包含L2_calib/camera4.json)") parser.add_argument("--roi", type=int, nargs=2, default=None, metavar=("WIDTH", "HEIGHT"), help="ROI尺寸(宽 高),如: --roi 704 352") parser.add_argument("--virtual-fx", type=float, default=None, help="虚拟焦距,用于深度归一化") parser.add_argument("--ori-img-size", type=int, nargs=2, default=None, metavar=("WIDTH", "HEIGHT"), help="原始图像尺寸(宽 高)") parser.add_argument("--max-samples", type=int, default=None, help="最大处理样本数") parser.add_argument("--sample-interval", type=int, default=1, help="采样间隔,每隔 N 帧处理一帧(默认1,即处理全部)") parser.add_argument("--file-extension", type=str, default=".jpg", help="图像文件扩展名") parser.add_argument("--label-extension", type=str, default=None, help="标签文件扩展名,默认自动在 .txt/.json 中查找") parser.add_argument("--classes", type=str, nargs="+", default=None, help="仅可视化指定类别,支持类别名/类别ID,可传多个值或逗号分隔") return parser.parse_args() def main(): args = parse_args() target_class_ids = resolve_target_class_ids(args.classes, names=DEFAULT_CLASS_NAMES) visualize_batch( image_dir=args.image_dir, label_dir=args.label_dir, output_dir=args.output, calib_dir=args.calib_dir, roi_size=tuple(args.roi) if args.roi else None, virtual_fx=args.virtual_fx, ori_img_size=tuple(args.ori_img_size) if args.ori_img_size else None, max_samples=args.max_samples, sample_interval=args.sample_interval, file_extension=args.file_extension, label_extension=args.label_extension, target_class_ids=target_class_ids, ) if __name__ == "__main__": main()