Files
yolov26_3d/tools/convert_gt_to_label/count_visualization_batches.py
2026-06-24 09:35:46 +08:00

278 lines
8.4 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# coding: utf-8
import argparse
import json
import os
from datetime import datetime
from pathlib import Path
# Default root directory of the visualization output; it is expected to
# contain one sub-directory per batch (see --showpath-root).
DEFAULT_SHOWPATH_ROOT = "/data1/dongying/Mono3d/D4Q2/data_visualization_for_check_camera2"
# Default glob pattern used to select batch directories under the root.
DEFAULT_BATCH_GLOB = "batch_*"
# Default comma-separated list of file extensions that count as images.
DEFAULT_IMAGE_EXTS = ".jpg,.jpeg,.png"
def parse_args(argv=None):
    """Parse the command-line options of the batch image counter.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the script entry
            point uses; passing a list allows the parser to be driven
            programmatically.

    Returns:
        argparse.Namespace with the attributes ``showpath_root``,
        ``batch_glob``, ``image_exts``, ``check_pairs``,
        ``sample_mismatch_limit`` and ``output_json``.
    """
    parser = argparse.ArgumentParser(
        description="统计可视化结果根目录下每个 batch 的 2D/3D 图像数量。"
    )
    parser.add_argument(
        "--showpath-root",
        default=DEFAULT_SHOWPATH_ROOT,
        help="可视化结果根目录,目录下应包含 batch_* 子目录。",
    )
    parser.add_argument(
        "--batch-glob",
        default=DEFAULT_BATCH_GLOB,
        help="batch 目录匹配模式,默认 batch_*。",
    )
    parser.add_argument(
        "--image-exts",
        default=DEFAULT_IMAGE_EXTS,
        help="参与统计的图片扩展名,逗号分隔,默认 .jpg,.jpeg,.png。",
    )
    parser.add_argument(
        "--check-pairs",
        action="store_true",
        help="额外校验 2D/3D 下的相对图片路径是否一一对应。",
    )
    parser.add_argument(
        "--sample-mismatch-limit",
        type=int,
        default=10,
        help="配对校验失败时,最多展示多少条 only_in_2d/only_in_3d 样例。",
    )
    parser.add_argument(
        "--output-json",
        default=None,
        help="可选,将统计结果写入 JSON 文件。",
    )
    # argv=None keeps the original behavior of parsing sys.argv[1:].
    return parser.parse_args(argv)
def normalize_image_exts(raw_exts):
    """Turn a comma-separated extension string into a canonical tuple.

    Each entry is stripped, lower-cased and given a leading dot when it has
    none; blank entries are ignored and duplicates collapse.

    Args:
        raw_exts: Comma-separated extensions, e.g. ``".jpg, PNG"``.

    Returns:
        Sorted tuple of unique, dot-prefixed, lower-case extensions.

    Raises:
        ValueError: If no non-blank extension remains.
    """
    cleaned = (piece.strip().lower() for piece in raw_exts.split(","))
    unique_exts = {
        token if token.startswith(".") else "." + token
        for token in cleaned
        if token
    }
    if not unique_exts:
        raise ValueError("image-exts 不能为空。")
    return tuple(sorted(unique_exts))
def collect_image_stats(image_dir, image_exts, collect_relpaths=False):
    """Count image files below a directory, optionally recording their paths.

    The directory is walked recursively with ``os.walk``; a file counts when
    its lower-cased suffix is contained in ``image_exts``.

    Args:
        image_dir: ``pathlib.Path`` of the directory to scan.
        image_exts: Collection of lower-case, dot-prefixed extensions.
        collect_relpaths: When true, also gather each matching file's path
            relative to ``image_dir`` in POSIX form.

    Returns:
        ``(count, relpaths)`` where ``relpaths`` is a set of strings when
        ``collect_relpaths`` is true and ``None`` otherwise. A missing
        directory yields a count of 0 (and an empty set or ``None``).
    """
    found = set() if collect_relpaths else None
    total = 0
    if image_dir.is_dir():
        for walk_root, _subdirs, names in os.walk(image_dir):
            base = Path(walk_root)
            matching = [
                name for name in names if Path(name).suffix.lower() in image_exts
            ]
            total += len(matching)
            if found is not None:
                found.update(
                    (base / name).relative_to(image_dir).as_posix()
                    for name in matching
                )
    return total, found
def build_status(batch_summary, check_pairs):
    """Derive the status string of one batch from its summary.

    Args:
        batch_summary: Dict with ``has_2d_dir``, ``has_3d_dir`` and
            ``count_diff``; ``only_in_2d_count`` / ``only_in_3d_count`` are
            read only when ``check_pairs`` is true.
        check_pairs: Whether the pair-mismatch counters should be evaluated.

    Returns:
        ``"OK"`` when nothing is wrong, otherwise the issue labels
        (``MISSING_2D``, ``MISSING_3D``, ``COUNT_DIFF``, ``PAIR_DIFF``)
        joined with ``+`` in that fixed order.
    """
    checks = (
        ("MISSING_2D", not batch_summary["has_2d_dir"]),
        ("MISSING_3D", not batch_summary["has_3d_dir"]),
        ("COUNT_DIFF", batch_summary["count_diff"] != 0),
        (
            "PAIR_DIFF",
            # Short-circuits so the pair counters are not required when
            # pair checking is disabled.
            bool(check_pairs)
            and not (
                batch_summary["only_in_2d_count"] <= 0
                and batch_summary["only_in_3d_count"] <= 0
            ),
        ),
    )
    labels = [label for label, failed in checks if failed]
    return "+".join(labels) or "OK"
def analyze_batch(batch_dir, image_exts, check_pairs=False, sample_mismatch_limit=10):
    """Summarize the image counts of one batch directory.

    Images are counted under ``batch_dir / "2D"`` and ``batch_dir / "3D"``.
    With ``check_pairs`` the relative image paths of both sides are also
    compared and the paths present on only one side are reported.

    Args:
        batch_dir: ``pathlib.Path`` of the batch directory.
        image_exts: Collection of lower-case, dot-prefixed image extensions.
        check_pairs: When true, add ``only_in_2d_*`` / ``only_in_3d_*``
            entries to the summary.
        sample_mismatch_limit: Maximum number of sample paths kept per side.
            Negative values are treated as 0; ``None`` keeps every path.

    Returns:
        Dict with ``batch_name``, ``batch_dir``, ``has_2d_dir``,
        ``has_3d_dir``, ``count_2d``, ``count_3d``, ``count_diff``, the
        optional pair entries and a final ``status`` string.
    """
    dir_2d = batch_dir / "2D"
    dir_3d = batch_dir / "3D"
    count_2d, relpaths_2d = collect_image_stats(
        dir_2d, image_exts, collect_relpaths=check_pairs
    )
    count_3d, relpaths_3d = collect_image_stats(
        dir_3d, image_exts, collect_relpaths=check_pairs
    )
    summary = {
        "batch_name": batch_dir.name,
        "batch_dir": str(batch_dir),
        "has_2d_dir": dir_2d.is_dir(),
        "has_3d_dir": dir_3d.is_dir(),
        "count_2d": count_2d,
        "count_3d": count_3d,
        "count_diff": count_2d - count_3d,
    }
    if check_pairs:
        # A negative limit would make the slices below mean "all but the
        # last N" instead of "at most N", so clamp it to zero.
        if sample_mismatch_limit is None:
            sample_limit = None
        else:
            sample_limit = max(sample_mismatch_limit, 0)
        only_in_2d = sorted(relpaths_2d - relpaths_3d)
        only_in_3d = sorted(relpaths_3d - relpaths_2d)
        summary.update(
            {
                "only_in_2d_count": len(only_in_2d),
                "only_in_3d_count": len(only_in_3d),
                "only_in_2d_samples": only_in_2d[:sample_limit],
                "only_in_3d_samples": only_in_3d[:sample_limit],
            }
        )
    summary["status"] = build_status(summary, check_pairs)
    return summary
def build_total_row(batch_summaries, check_pairs):
    """Aggregate all batch summaries into the ``TOTAL`` row.

    Args:
        batch_summaries: List of per-batch summary dicts.
        check_pairs: When true, the ``only_in_2d_count`` and
            ``only_in_3d_count`` columns are summed as well.

    Returns:
        Dict with ``batch_name`` set to ``"TOTAL"``, the summed counts,
        their difference and a ``status`` of ``"OK"`` or ``"HAS_ISSUES"``
        (the latter when any batch status is not ``"OK"``).
    """

    def column_sum(field):
        return sum(entry[field] for entry in batch_summaries)

    sum_2d = column_sum("count_2d")
    sum_3d = column_sum("count_3d")
    totals = {
        "batch_name": "TOTAL",
        "count_2d": sum_2d,
        "count_3d": sum_3d,
        "count_diff": sum_2d - sum_3d,
    }
    if check_pairs:
        for field in ("only_in_2d_count", "only_in_3d_count"):
            totals[field] = column_sum(field)
    every_batch_ok = all(entry["status"] == "OK" for entry in batch_summaries)
    totals["status"] = "OK" if every_batch_ok else "HAS_ISSUES"
    return totals
def print_summary_table(batch_summaries, total_row, check_pairs):
    """Print the per-batch summaries plus the total row as a text table.

    Every column is left-justified to the widest of its title and its
    cell values (measured with ``len``); a missing key renders as an
    empty cell.

    Args:
        batch_summaries: List of per-batch summary dicts.
        total_row: Summary dict printed as the last row.
        check_pairs: When true, the ``only_2d`` / ``only_3d`` columns are
            included.
    """
    pair_columns = (
        [("only_in_2d_count", "only_2d"), ("only_in_3d_count", "only_3d")]
        if check_pairs
        else []
    )
    layout = [
        ("batch_name", "batch"),
        ("count_2d", "2D"),
        ("count_3d", "3D"),
        ("count_diff", "diff"),
        *pair_columns,
        ("status", "status"),
    ]
    all_rows = batch_summaries + [total_row]
    headings = [heading for _, heading in layout]
    # Render every cell once; the grid always holds at least the total row.
    grid = [
        [str(record.get(field, "")) for field, _ in layout] for record in all_rows
    ]
    col_widths = [
        max(len(heading), *(len(line[idx]) for line in grid))
        for idx, heading in enumerate(headings)
    ]

    def render(cells):
        return " ".join(cell.ljust(width) for cell, width in zip(cells, col_widths))

    print(render(headings))
    print(" ".join("-" * width for width in col_widths))
    for line in grid:
        print(render(line))
def build_report(args, batch_summaries, total_row, image_exts):
    """Assemble the dict that is written out as the JSON report.

    Args:
        args: Parsed options; ``showpath_root``, ``batch_glob``,
            ``check_pairs`` and ``sample_mismatch_limit`` are read.
        batch_summaries: List of per-batch summary dicts.
        total_row: Aggregated total summary dict.
        image_exts: Iterable of the normalized image extensions.

    Returns:
        Dict holding the local generation time (ISO format, second
        precision), the resolved root path, the options used, the number
        of batches, the total row and the per-batch summaries.
    """
    timestamp = datetime.now().isoformat(timespec="seconds")
    resolved_root = Path(args.showpath_root).resolve()
    report = {"generated_at": timestamp}
    report["showpath_root"] = str(resolved_root)
    report["batch_glob"] = args.batch_glob
    report["image_exts"] = list(image_exts)
    report["check_pairs"] = args.check_pairs
    report["sample_mismatch_limit"] = args.sample_mismatch_limit
    report["total_batches"] = len(batch_summaries)
    report["total_summary"] = total_row
    report["batches"] = batch_summaries
    return report
def print_pair_mismatch_details(batch_summaries):
    """Print the 2D/3D pairing mismatches of every affected batch.

    Nothing is printed when no batch has a positive ``only_in_2d_count`` or
    ``only_in_3d_count``. Otherwise a heading is printed, followed by one
    line per affected batch and its non-empty sample lists.

    Args:
        batch_summaries: Per-batch summary dicts that contain the
            ``only_in_*_count`` and ``only_in_*_samples`` entries.
    """

    def has_mismatch(entry):
        return not (
            entry["only_in_2d_count"] <= 0 and entry["only_in_3d_count"] <= 0
        )

    flagged = list(filter(has_mismatch, batch_summaries))
    if flagged:
        print("\nPair mismatch details:")
    for entry in flagged:
        print(
            f"- {entry['batch_name']}: only_in_2d={entry['only_in_2d_count']}, "
            f"only_in_3d={entry['only_in_3d_count']}"
        )
        samples_2d = entry["only_in_2d_samples"]
        if samples_2d:
            print(f" only_in_2d samples: {samples_2d}")
        samples_3d = entry["only_in_3d_samples"]
        if samples_3d:
            print(f" only_in_3d samples: {samples_3d}")
def main():
    """Count 2D/3D images per batch, print the table and optionally dump JSON.

    Returns:
        Process exit code: 1 when the image extension list is empty, the
        root directory does not exist or no batch directory matches;
        0 otherwise. Batches with issues are reported in the table but do
        not change the exit code.
    """
    args = parse_args()
    try:
        image_exts = normalize_image_exts(args.image_exts)
    except ValueError as exc:
        # An empty extension list (e.g. --image-exts ",") is a usage error;
        # report it like the other failures instead of dumping a traceback.
        print(f"参数错误: {exc}")
        return 1
    showpath_root = Path(args.showpath_root)
    if not showpath_root.is_dir():
        print(f"可视化结果根目录不存在: {showpath_root}")
        return 1
    batch_dirs = sorted(
        path for path in showpath_root.glob(args.batch_glob) if path.is_dir()
    )
    if not batch_dirs:
        print(
            f"未找到 batch 目录: root={showpath_root}, batch_glob={args.batch_glob}"
        )
        return 1
    batch_summaries = [
        analyze_batch(
            batch_dir,
            image_exts,
            check_pairs=args.check_pairs,
            sample_mismatch_limit=args.sample_mismatch_limit,
        )
        for batch_dir in batch_dirs
    ]
    total_row = build_total_row(batch_summaries, args.check_pairs)
    print_summary_table(batch_summaries, total_row, args.check_pairs)
    if args.check_pairs:
        print_pair_mismatch_details(batch_summaries)
    if args.output_json:
        # The report is only needed for the JSON dump, so build it here.
        report = build_report(args, batch_summaries, total_row, image_exts)
        output_json = Path(args.output_json)
        output_json.parent.mkdir(parents=True, exist_ok=True)
        with open(output_json, "w", encoding="utf-8") as f:
            json.dump(report, f, ensure_ascii=False, indent=2)
        print(f"\n统计结果已写入: {output_json}")
    return 0
# Run main() only when the file is executed as a script; its return value
# is passed to SystemExit and so becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())