Files
yolov26_3d/tools/convert_gt_to_label/count_visualization_batches.py

278 lines
8.4 KiB
Python
Raw Permalink Normal View History

2026-06-24 09:35:46 +08:00
#!/usr/bin/env python3
# coding: utf-8
import argparse
import json
import os
from datetime import datetime
from pathlib import Path
# Default visualization root directory; used as the default of --showpath-root.
DEFAULT_SHOWPATH_ROOT = "/data1/dongying/Mono3d/D4Q2/data_visualization_for_check_camera2"
# Default glob pattern for batch sub-directories; used as the default of --batch-glob.
DEFAULT_BATCH_GLOB = "batch_*"
# Default comma-separated image extensions; used as the default of --image-exts.
DEFAULT_IMAGE_EXTS = ".jpg,.jpeg,.png"
def parse_args():
    """Parse the command-line options of the batch image counter.

    Returns:
        The ``argparse.Namespace`` with ``showpath_root``, ``batch_glob``,
        ``image_exts``, ``check_pairs``, ``sample_mismatch_limit`` and
        ``output_json``.
    """
    cli = argparse.ArgumentParser(
        description="统计可视化结果根目录下每个 batch 的 2D/3D 图像数量。"
    )
    # Options are registered in this order, which is also the --help order.
    option_specs = (
        (
            "--showpath-root",
            {
                "default": DEFAULT_SHOWPATH_ROOT,
                "help": "可视化结果根目录,目录下应包含 batch_* 子目录。",
            },
        ),
        (
            "--batch-glob",
            {
                "default": DEFAULT_BATCH_GLOB,
                "help": "batch 目录匹配模式,默认 batch_*。",
            },
        ),
        (
            "--image-exts",
            {
                "default": DEFAULT_IMAGE_EXTS,
                "help": "参与统计的图片扩展名,逗号分隔,默认 .jpg,.jpeg,.png。",
            },
        ),
        (
            "--check-pairs",
            {
                "action": "store_true",
                "help": "额外校验 2D/3D 下的相对图片路径是否一一对应。",
            },
        ),
        (
            "--sample-mismatch-limit",
            {
                "type": int,
                "default": 10,
                "help": "配对校验失败时,最多展示多少条 only_in_2d/only_in_3d 样例。",
            },
        ),
        (
            "--output-json",
            {
                "default": None,
                "help": "可选,将统计结果写入 JSON 文件。",
            },
        ),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def normalize_image_exts(raw_exts):
    """Turn a comma-separated extension string into a normalized tuple.

    Each entry is stripped and lower-cased, empty entries are dropped, and a
    leading dot is added when missing.

    Args:
        raw_exts: Comma-separated extensions, e.g. ``".jpg, PNG"``.

    Returns:
        A sorted tuple of unique extensions, each starting with ``"."``.

    Raises:
        ValueError: If no non-empty extension remains.
    """
    cleaned = (piece.strip().lower() for piece in raw_exts.split(","))
    unique_exts = {
        piece if piece.startswith(".") else f".{piece}"
        for piece in cleaned
        if piece
    }
    if not unique_exts:
        raise ValueError("image-exts 不能为空。")
    return tuple(sorted(unique_exts))
def collect_image_stats(image_dir, image_exts, collect_relpaths=False):
    """Count image files below ``image_dir``, walking sub-directories.

    A file counts when its lower-cased suffix is contained in ``image_exts``.

    Args:
        image_dir: ``pathlib.Path`` of the directory to scan.
        image_exts: Container of lower-case extensions including the dot.
        collect_relpaths: When true, also gather the POSIX-style path of every
            counted file relative to ``image_dir``.

    Returns:
        ``(count, relpaths)`` where ``relpaths`` is a set of strings, or
        ``None`` when ``collect_relpaths`` is false. A missing directory
        yields a count of 0 and an empty set (or ``None``).
    """
    matched_relpaths = set() if collect_relpaths else None
    total = 0
    if image_dir.is_dir():
        for walk_root, _dirnames, names in os.walk(image_dir):
            hits = [
                name for name in names if Path(name).suffix.lower() in image_exts
            ]
            total += len(hits)
            if matched_relpaths is None:
                continue
            base = Path(walk_root)
            matched_relpaths.update(
                (base / name).relative_to(image_dir).as_posix() for name in hits
            )
    return total, matched_relpaths
def build_status(batch_summary, check_pairs):
    """Build the status string for one batch summary.

    Args:
        batch_summary: Dict with ``has_2d_dir``, ``has_3d_dir`` and
            ``count_diff``; ``only_in_2d_count`` / ``only_in_3d_count`` are
            read only when ``check_pairs`` is true.
        check_pairs: Whether the pair-mismatch counters should be inspected.

    Returns:
        ``"OK"`` when nothing is wrong, otherwise the issue labels joined
        with ``"+"`` in the order MISSING_2D, MISSING_3D, COUNT_DIFF,
        PAIR_DIFF.
    """
    # The pair counters are only read when pair checking is enabled, because
    # they may be absent from the summary otherwise.
    pair_mismatch = check_pairs and (
        batch_summary["only_in_2d_count"] > 0 or batch_summary["only_in_3d_count"] > 0
    )
    findings = (
        ("MISSING_2D", not batch_summary["has_2d_dir"]),
        ("MISSING_3D", not batch_summary["has_3d_dir"]),
        ("COUNT_DIFF", batch_summary["count_diff"] != 0),
        ("PAIR_DIFF", pair_mismatch),
    )
    return "+".join(label for label, triggered in findings if triggered) or "OK"
def analyze_batch(batch_dir, image_exts, check_pairs=False, sample_mismatch_limit=10):
    """Summarize the 2D/3D image counts of a single batch directory.

    Args:
        batch_dir: ``pathlib.Path`` of the batch; its ``2D`` and ``3D``
            sub-directories are scanned.
        image_exts: Extensions passed through to ``collect_image_stats``.
        check_pairs: When true, also compare the relative image paths found
            under ``2D`` and ``3D``.
        sample_mismatch_limit: Maximum number of unmatched paths kept per
            side as samples. Negative values are treated as 0.

    Returns:
        A dict with ``batch_name``, ``batch_dir``, ``has_2d_dir``,
        ``has_3d_dir``, ``count_2d``, ``count_3d``, ``count_diff`` and
        ``status``. With ``check_pairs`` it additionally holds
        ``only_in_2d_count``, ``only_in_3d_count``, ``only_in_2d_samples``
        and ``only_in_3d_samples``.
    """
    dir_2d = batch_dir / "2D"
    dir_3d = batch_dir / "3D"
    count_2d, relpaths_2d = collect_image_stats(
        dir_2d, image_exts, collect_relpaths=check_pairs
    )
    count_3d, relpaths_3d = collect_image_stats(
        dir_3d, image_exts, collect_relpaths=check_pairs
    )
    summary = {
        "batch_name": batch_dir.name,
        "batch_dir": str(batch_dir),
        "has_2d_dir": dir_2d.is_dir(),
        "has_3d_dir": dir_3d.is_dir(),
        "count_2d": count_2d,
        "count_3d": count_3d,
        "count_diff": count_2d - count_3d,
    }
    if check_pairs:
        # A negative limit would make the slices below mean "everything except
        # the last N entries" instead of "no samples", so clamp it to zero.
        sample_limit = max(sample_mismatch_limit, 0)
        only_in_2d = sorted(relpaths_2d - relpaths_3d)
        only_in_3d = sorted(relpaths_3d - relpaths_2d)
        summary.update(
            {
                "only_in_2d_count": len(only_in_2d),
                "only_in_3d_count": len(only_in_3d),
                "only_in_2d_samples": only_in_2d[:sample_limit],
                "only_in_3d_samples": only_in_3d[:sample_limit],
            }
        )
    summary["status"] = build_status(summary, check_pairs)
    return summary
def build_total_row(batch_summaries, check_pairs):
    """Aggregate all batch summaries into the ``TOTAL`` row.

    Args:
        batch_summaries: List of per-batch summary dicts.
        check_pairs: When true, the pair-mismatch counters are summed too.

    Returns:
        A dict with ``batch_name`` set to ``"TOTAL"``, the summed counts,
        their difference, and a ``status`` that is ``"OK"`` only if every
        batch status is ``"OK"`` (otherwise ``"HAS_ISSUES"``).
    """

    def column_total(key):
        return sum(summary[key] for summary in batch_summaries)

    total_2d = column_total("count_2d")
    total_3d = column_total("count_3d")
    totals = {
        "batch_name": "TOTAL",
        "count_2d": total_2d,
        "count_3d": total_3d,
        "count_diff": total_2d - total_3d,
    }
    if check_pairs:
        for key in ("only_in_2d_count", "only_in_3d_count"):
            totals[key] = column_total(key)
    every_batch_ok = all(summary["status"] == "OK" for summary in batch_summaries)
    totals["status"] = "OK" if every_batch_ok else "HAS_ISSUES"
    return totals
def print_summary_table(batch_summaries, total_row, check_pairs):
    """Print the per-batch rows and the total row as an aligned text table.

    Args:
        batch_summaries: List of per-batch summary dicts.
        total_row: The aggregated row, printed last.
        check_pairs: When true, the ``only_2d`` / ``only_3d`` columns are
            included.

    Every column is as wide as its widest cell or its title, cells are
    left-justified, and a key missing from a row renders as an empty cell.
    """
    layout = [
        ("batch_name", "batch"),
        ("count_2d", "2D"),
        ("count_3d", "3D"),
        ("count_diff", "diff"),
    ]
    if check_pairs:
        layout += [
            ("only_in_2d_count", "only_2d"),
            ("only_in_3d_count", "only_3d"),
        ]
    layout.append(("status", "status"))

    rows = batch_summaries + [total_row]
    # Render every cell once; the same text drives both width and output.
    cells = [[str(row.get(field, "")) for field, _ in layout] for row in rows]
    col_widths = [
        max(len(heading), *(len(line[idx]) for line in cells))
        for idx, (_, heading) in enumerate(layout)
    ]

    def render(values):
        return " ".join(
            value.ljust(width) for value, width in zip(values, col_widths)
        )

    print(render([heading for _, heading in layout]))
    print(render(["-" * width for width in col_widths]))
    for line in cells:
        print(render(line))
def build_report(args, batch_summaries, total_row, image_exts):
    """Assemble the report dict that ``main`` can dump as JSON.

    Args:
        args: Parsed options; ``showpath_root``, ``batch_glob``,
            ``check_pairs`` and ``sample_mismatch_limit`` are read.
        batch_summaries: List of per-batch summary dicts.
        total_row: The aggregated total row.
        image_exts: Iterable of the extensions that were counted.

    Returns:
        A dict holding the generation time (local time, second precision),
        the resolved root path, the run options, the batch count, the total
        row and the batch summaries.
    """
    timestamp = datetime.now().isoformat(timespec="seconds")
    resolved_root = Path(args.showpath_root).resolve()
    run_options = {
        "batch_glob": args.batch_glob,
        "image_exts": list(image_exts),
        "check_pairs": args.check_pairs,
        "sample_mismatch_limit": args.sample_mismatch_limit,
    }
    return {
        "generated_at": timestamp,
        "showpath_root": str(resolved_root),
        **run_options,
        "total_batches": len(batch_summaries),
        "total_summary": total_row,
        "batches": batch_summaries,
    }
def print_pair_mismatch_details(batch_summaries):
    """Print the unmatched-path details of every batch with pair mismatches.

    Args:
        batch_summaries: Per-batch summary dicts that contain the
            ``only_in_2d_count`` / ``only_in_3d_count`` counters and the
            matching ``*_samples`` lists.

    Nothing is printed when no batch has a positive counter.
    """

    def has_mismatch(summary):
        return summary["only_in_2d_count"] > 0 or summary["only_in_3d_count"] > 0

    # Filter eagerly so all batches are inspected before any output is made.
    flagged = list(filter(has_mismatch, batch_summaries))
    if flagged:
        print("\nPair mismatch details:")
    sample_fields = (
        ("only_in_2d", "only_in_2d_samples"),
        ("only_in_3d", "only_in_3d_samples"),
    )
    for summary in flagged:
        name = summary["batch_name"]
        missing_in_3d = summary["only_in_2d_count"]
        missing_in_2d = summary["only_in_3d_count"]
        print(f"- {name}: only_in_2d={missing_in_3d}, only_in_3d={missing_in_2d}")
        for label, field in sample_fields:
            samples = summary[field]
            if samples:
                print(f" {label} samples: {samples}")
def main():
    """Run the batch image statistics from the command line.

    Parses the options, analyzes every batch directory matching the glob
    under the root, prints the summary table (plus mismatch details when
    pair checking is on) and optionally writes the report as JSON.

    Returns:
        0 on success; 1 when the root directory does not exist or no batch
        directory matches.
    """
    args = parse_args()
    image_exts = normalize_image_exts(args.image_exts)

    showpath_root = Path(args.showpath_root)
    if not showpath_root.is_dir():
        print(f"可视化结果根目录不存在: {showpath_root}")
        return 1

    batch_dirs = [
        entry for entry in showpath_root.glob(args.batch_glob) if entry.is_dir()
    ]
    batch_dirs.sort()
    if not batch_dirs:
        print(
            f"未找到 batch 目录root={showpath_root}, batch_glob={args.batch_glob}"
        )
        return 1

    batch_summaries = []
    for batch_dir in batch_dirs:
        batch_summaries.append(
            analyze_batch(
                batch_dir,
                image_exts,
                check_pairs=args.check_pairs,
                sample_mismatch_limit=args.sample_mismatch_limit,
            )
        )

    total_row = build_total_row(batch_summaries, args.check_pairs)
    print_summary_table(batch_summaries, total_row, args.check_pairs)
    if args.check_pairs:
        print_pair_mismatch_details(batch_summaries)

    # The report is always built, even if it is not written to disk.
    report = build_report(args, batch_summaries, total_row, image_exts)
    if args.output_json:
        output_json = Path(args.output_json)
        # Create missing parent directories so the write below cannot fail on them.
        output_json.parent.mkdir(parents=True, exist_ok=True)
        with output_json.open("w", encoding="utf-8") as handle:
            json.dump(report, handle, ensure_ascii=False, indent=2)
        print(f"\n统计结果已写入: {output_json}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)