Files
yolov26_3d/tools/convert_gt_to_label/rewrite_visualization_archives.py

499 lines
16 KiB
Python
Raw Permalink Normal View History

2026-06-24 09:35:46 +08:00
#!/usr/bin/env python3
# coding: utf-8
import argparse
import copy
import hashlib
import json
import tarfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path, PurePosixPath
# Modality sub-directories that may appear under an archive root; also the
# only values accepted by --modalities.
MODALITIES = ("2D", "3D")
# Default glob used to discover source archives inside a directory.
DEFAULT_ARCHIVE_GLOB = "part_*.tar.gz"
# Default file name of the JSON report written into the output root.
DEFAULT_REPORT_NAME = "rewrite_archives_report.json"
def parse_args():
    """Parse command-line arguments for the archive rewrite tool.

    Returns:
        argparse.Namespace with ``source_path``, ``output_root`` (None when
        omitted), ``modalities``, ``archive_glob``, ``parts``, ``limit``,
        ``workers``, ``gzip_compresslevel``, ``checksum``, ``verify``,
        ``overwrite``, ``report_name`` and ``dry_run``.
    """
    # Each entry is (names/flags, keyword options) handed verbatim to
    # add_argument, in registration order (which also fixes --help order).
    argument_specs = (
        (
            ("source_path",),
            {"help": "输入路径,可为 archive root、单个 modality 目录,或单个 tar/tar.gz 文件。"},
        ),
        (
            ("output_root",),
            {
                "nargs": "?",
                "default": None,
                "help": "输出目录。默认在 source_path 同级生成 <name>_images_prefixed。",
            },
        ),
        (
            ("--modalities",),
            {
                "default": "2D,3D",
                "help": "当 source_path 为 archive root 时要处理的模态,逗号分隔,默认 2D,3D。",
            },
        ),
        (
            ("--archive-glob",),
            {"default": DEFAULT_ARCHIVE_GLOB, "help": "归档匹配模式,默认 part_*.tar.gz。"},
        ),
        (
            ("--parts",),
            {
                "default": None,
                "help": "仅处理指定 part逗号分隔既支持 part_0001也支持 part_0001.tar.gz。",
            },
        ),
        (
            ("--limit",),
            {"type": int, "default": None, "help": "最多处理多少个归档,便于抽样验证。"},
        ),
        (
            ("--workers",),
            {"type": int, "default": 1, "help": "并发处理归档数,默认 1。"},
        ),
        (
            ("--gzip-compresslevel",),
            {
                "type": int,
                "default": 1,
                "help": "输出 tar.gz 的 gzip 压缩级别,范围 0-9默认 1。",
            },
        ),
        (
            ("--checksum",),
            {"action": "store_true", "help": "输出归档后计算 sha256。"},
        ),
        # Verification is on by default; --verify exists only for explicitness,
        # --no-verify (same dest) is the switch that actually turns it off.
        (
            ("--verify",),
            {
                "action": "store_true",
                "default": True,
                "help": "输出归档后校验成员数及路径前缀,默认开启。",
            },
        ),
        (
            ("--no-verify",),
            {"dest": "verify", "action": "store_false", "help": "关闭输出归档校验。"},
        ),
        (
            ("--overwrite",),
            {"action": "store_true", "help": "允许覆盖已有输出归档。"},
        ),
        (
            ("--report-name",),
            {
                "default": DEFAULT_REPORT_NAME,
                "help": f"输出报告文件名,默认 {DEFAULT_REPORT_NAME}",
            },
        ),
        (
            ("--dry-run",),
            {"action": "store_true", "help": "仅输出转换计划,不实际生成归档。"},
        ),
    )
    parser = argparse.ArgumentParser(
        description=(
            "将已有可视化归档中的成员路径改写为 part_xxxx/images/<filename>"
            "并输出到新的归档目录。"
        )
    )
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
def now_str():
    """Return the current local (naive) time as ``YYYY-MM-DDTHH:MM:SS``."""
    # Dropping microseconds makes isoformat() emit whole seconds only.
    current = datetime.now().replace(microsecond=0)
    return current.isoformat()
def normalize_modalities(raw_modalities):
    """Parse a comma-separated modality list into a de-duplicated tuple.

    Blank items are ignored and first-occurrence order is kept.

    Raises:
        ValueError: if an item is not in ``MODALITIES`` or nothing remains.
    """
    ordered = []
    for token in raw_modalities.split(","):
        modality = token.strip()
        if modality and modality not in MODALITIES:
            raise ValueError(f"unsupported modality: {modality}")
        if modality and modality not in ordered:
            ordered.append(modality)
    if not ordered:
        raise ValueError("modalities 不能为空。")
    return tuple(ordered)
def archive_suffix(path):
    """Return ``".tar.gz"`` or ``".tar"`` according to ``path.name``.

    Raises:
        ValueError: if the name ends with neither suffix.
    """
    # ".tar.gz" is tested first so the longer, more specific suffix wins.
    for suffix in (".tar.gz", ".tar"):
        if path.name.endswith(suffix):
            return suffix
    raise ValueError(f"unsupported archive suffix: {path}")
def strip_archive_suffix(name):
    """Remove a trailing ``.tar.gz`` or ``.tar`` from ``name``; else return it unchanged."""
    for suffix in (".tar.gz", ".tar"):
        if name.endswith(suffix):
            return name[: len(name) - len(suffix)]
    return name
def parse_parts_filter(raw_parts):
    """Turn ``--parts`` into a set of part names without archive suffixes.

    Returns None when ``raw_parts`` is empty/None or contains only blank items,
    meaning "no filtering".
    """
    if not raw_parts:
        return None
    part_names = {
        strip_archive_suffix(token.strip())
        for token in raw_parts.split(",")
        if token.strip()
    }
    return part_names or None
def is_archive_file(path):
    """Return True if ``path`` is an existing regular file named ``*.tar`` or ``*.tar.gz``."""
    if not path.is_file():
        return False
    return path.name.endswith((".tar", ".tar.gz"))
def default_output_root(source_path, resolve_mode):
    """Return the default output directory ``<name>_images_prefixed`` next to the source.

    In ``single_archive`` mode ``<name>`` is the archive file name without its
    ``.tar``/``.tar.gz`` suffix; otherwise it is the source directory name.
    """
    resolved = source_path.resolve()
    stem = strip_archive_suffix(resolved.name) if resolve_mode == "single_archive" else resolved.name
    return resolved.parent.joinpath(stem + "_images_prefixed")
def resolve_archives(source_path, archive_glob, modalities):
    """Locate source archives and classify what ``source_path`` points at.

    Returns:
        ``(resolve_mode, archive_paths, source_base_dir)`` where resolve_mode is:

        * ``"single_archive"``: ``source_path`` is one archive file; base is its parent.
        * ``"modality_dir"``: ``source_path`` is a directory named like a modality
          ("2D"/"3D") that directly holds archives matching ``archive_glob``;
          base is its parent. ``modalities`` is not consulted in this mode.
        * ``"archive_root"``: archives are gathered from ``<source_path>/<modality>``
          for each requested modality (sorted within each directory, modalities in
          the given order); base is ``source_path``.

    Raises:
        ValueError: ``source_path`` is a file that is not a supported archive.
        FileNotFoundError: ``source_path`` is missing, or no archives were found.
    """
    source_path = source_path.resolve()

    def matching_archives(directory):
        # Only real .tar/.tar.gz files count, whatever the glob matched.
        return sorted(
            candidate
            for candidate in directory.glob(archive_glob)
            if is_archive_file(candidate)
        )

    if source_path.is_file():
        if is_archive_file(source_path):
            return "single_archive", [source_path], source_path.parent
        raise ValueError(f"source_path is not a supported archive file: {source_path}")
    if not source_path.is_dir():
        raise FileNotFoundError(f"source_path does not exist or is not a directory: {source_path}")

    direct_archives = matching_archives(source_path)
    if source_path.name in MODALITIES and direct_archives:
        return "modality_dir", direct_archives, source_path.parent

    collected = []
    for modality in modalities:
        candidate_dir = source_path / modality
        if candidate_dir.is_dir():
            collected += matching_archives(candidate_dir)
    if not collected:
        raise FileNotFoundError(
            f"no archives found under {source_path} with archive_glob={archive_glob}"
        )
    return "archive_root", collected, source_path
def filter_archives(archive_paths, parts_filter, limit):
    """Apply the ``--parts`` filter and ``--limit`` cap to the discovered archives.

    Args:
        archive_paths: archive paths in processing order.
        parts_filter: set of part names without archive suffix, or None/empty
            to keep every archive.
        limit: maximum number of archives to keep (taken from the front), or
            None for no cap.

    Returns:
        A new list with the selected archive paths.

    Raises:
        ValueError: if ``limit`` is negative. Without this check a negative
            slice bound would silently drop archives from the END of the list
            (``--limit -1`` would skip the last archive).
        FileNotFoundError: if nothing is left after filtering.
    """
    if limit is not None and limit < 0:
        raise ValueError(f"limit 必须大于等于 0: {limit}")
    selected = list(archive_paths)
    if parts_filter:
        selected = [
            path for path in selected if strip_archive_suffix(path.name) in parts_filter
        ]
    if limit is not None:
        selected = selected[:limit]
    if not selected:
        raise FileNotFoundError("没有匹配到待处理归档,请检查 parts 或 archive_glob。")
    return selected
def build_output_archive_path(archive_path, source_base_dir, output_root):
    """Mirror ``archive_path``'s location relative to ``source_base_dir`` under ``output_root``.

    Raises ValueError (from ``relative_to``) if ``archive_path`` is not inside
    ``source_base_dir``.
    """
    relative_location = archive_path.relative_to(source_base_dir)
    return output_root.joinpath(relative_location)
def build_archive_member_dir(archive_path):
    """Return ``<archive name without .tar/.tar.gz>/images``, the in-archive member directory."""
    part_name = strip_archive_suffix(archive_path.name)
    return "/".join((part_name, "images"))
def build_output_member_name(member_name, output_archive):
    """Map a source member name to ``<output part>/images/<basename>``.

    Directory components of ``member_name`` are discarded, so distinct members
    sharing a basename map to the same output name; callers must detect that.
    """
    basename = PurePosixPath(member_name).name
    return "/".join((build_archive_member_dir(output_archive), basename))
def open_output_tarfile(output_archive, archive_format, gzip_compresslevel):
    """Open ``output_archive`` for writing and return the ``TarFile``.

    The format is ``archive_format`` when truthy, otherwise it is derived from
    ``output_archive``'s suffix. ``".tar.gz"`` opens a gzip stream at
    ``gzip_compresslevel``; any other value opens a plain, uncompressed tar.
    """
    resolved_format = archive_format if archive_format else archive_suffix(output_archive)
    if resolved_format != ".tar.gz":
        return tarfile.open(output_archive, mode="w")
    return tarfile.open(output_archive, mode="w:gz", compresslevel=gzip_compresslevel)
def compute_sha256(file_path):
    """Return the hex SHA-256 of the file at ``file_path``, read in 1 MiB chunks."""
    hasher = hashlib.sha256()
    chunk_size = 1024 * 1024
    with open(file_path, "rb") as stream:
        # iter(callable, sentinel) stops once read() returns b"" at EOF.
        for block in iter(lambda: stream.read(chunk_size), b""):
            hasher.update(block)
    return hasher.hexdigest()
def verify_rewritten_archive(archive_path, expected_member_count):
    """Re-read a rewritten archive and check its members.

    Every member must be a regular file whose name starts with
    ``<archive name without suffix>/images/`` (derived from ``archive_path``'s
    file name), and the member count must equal ``expected_member_count``.

    Returns:
        dict with ``member_count`` and up to five ``sample_members`` names.

    Raises:
        ValueError: on a non-regular member, a member outside the expected
            prefix, or a member-count mismatch.
    """
    prefix = build_archive_member_dir(archive_path) + "/"
    sample_names = []
    total = 0
    with tarfile.open(archive_path, mode="r:*") as archive:
        for entry in archive:
            if not entry.isfile():
                raise ValueError(
                    f"rewritten archive contains non-regular member: {archive_path} -> {entry.name}"
                )
            if not entry.name.startswith(prefix):
                raise ValueError(
                    f"rewritten archive member missing expected prefix {prefix}: "
                    f"{archive_path} -> {entry.name}"
                )
            total += 1
            if len(sample_names) < 5:
                sample_names.append(entry.name)
    if total != expected_member_count:
        raise ValueError(
            f"rewritten archive member count mismatch: {archive_path}, "
            f"expected={expected_member_count}, actual={total}"
        )
    return {"member_count": total, "sample_members": sample_names}
def build_output_tarinfo(member, output_name):
    """Return a copy of the regular-file ``member`` renamed to ``output_name``.

    ``member`` itself is not modified. Besides renaming, the copy is normalised
    so that the tar writer really emits ``output_name``:

    * ``tarfile`` gives existing pax header records priority over ``TarInfo``
      attributes when writing. A ``path`` record inherited from the source
      (present for long or non-ASCII names in PAX archives) would silently
      restore the OLD name, and an inherited ``size`` record could disagree
      with ``member.size``. Both are dropped; the writer regenerates them only
      when needed.
    * GNU sparse bookkeeping (``GNU.sparse.*`` records, sparse type flag) is
      dropped. The data copied via ``extractfile`` is the fully expanded file,
      which must be written as a plain regular file.
    """
    output_member = copy.copy(member)
    output_member.name = output_name
    output_member.type = tarfile.REGTYPE
    output_member.sparse = None
    output_member.pax_headers = {
        key: value
        for key, value in member.pax_headers.items()
        if key not in ("path", "size") and not key.startswith("GNU.sparse.")
    }
    return output_member


def rewrite_archive(
    source_archive,
    output_archive,
    gzip_compresslevel,
    verify,
    checksum,
    overwrite,
    dry_run,
):
    """Rewrite one archive so every member lives under ``<part>/images/<basename>``.

    The new archive is first written under its final file name inside a hidden
    sibling staging directory (``.<name>.partial/``). It is verified and hashed
    there, then atomically moved onto ``output_archive``. An interrupted run or a
    failed verification therefore never leaves a bad archive at the final path
    (and never clobbers the source when rewriting in place).

    Args:
        source_archive: path of the ``.tar``/``.tar.gz`` archive to read.
        output_archive: destination path; its suffix selects the output format
            and its name determines the member prefix.
        gzip_compresslevel: gzip level used when the output is ``.tar.gz``.
        verify: re-read the new archive and check names/count before publishing.
        checksum: include the SHA-256 of the new archive in the result.
        overwrite: allow replacing an existing ``output_archive``.
        dry_run: only report the planned source/output pair; write nothing.

    Returns:
        Result dict: ``status`` ("dry_run" or "rewritten"), both paths,
        ``member_count``, up to five sample names before/after, ``output_sha256``
        and ``verification`` (None when not computed).

    Raises:
        FileExistsError: ``output_archive`` exists and ``overwrite`` is false
            (checked in dry-run mode too).
        ValueError: a non-regular or unextractable member, two members that map
            to the same output name, or a failed verification.
    """
    source_archive = source_archive.resolve()
    output_archive = output_archive.resolve()
    if output_archive.exists() and not overwrite:
        raise FileExistsError(f"output archive already exists: {output_archive}")
    samples_before = []
    samples_after = []
    if dry_run:
        return {
            "status": "dry_run",
            "source_archive": str(source_archive),
            "output_archive": str(output_archive),
            "member_count": None,
            "sample_members_before": samples_before,
            "sample_members_after": samples_after,
            "output_sha256": None,
            "verification": None,
        }

    # Resolve the format before touching the filesystem so a bad suffix leaves nothing behind.
    output_archive_format = archive_suffix(output_archive)
    output_archive.parent.mkdir(parents=True, exist_ok=True)
    # The staged file keeps the final file name because the expected member
    # prefix (checked by verification) is derived from the archive file name.
    staging_dir = output_archive.parent / f".{output_archive.name}.partial"
    staging_dir.mkdir(exist_ok=True)
    partial_archive = staging_dir / output_archive.name
    if partial_archive.exists():
        partial_archive.unlink()  # leftover from an interrupted earlier run

    seen_output_names = {}  # output name -> first source member name
    member_count = 0
    try:
        with tarfile.open(source_archive, mode="r:*") as input_tar:
            with open_output_tarfile(
                partial_archive,
                output_archive_format,
                gzip_compresslevel,
            ) as output_tar:
                for member in input_tar:
                    if not member.isfile():
                        raise ValueError(
                            f"only regular file members are supported: "
                            f"{source_archive} -> {member.name}"
                        )
                    output_name = build_output_member_name(member.name, output_archive)
                    # Flattening to basenames can merge distinct members; refuse that.
                    if output_name in seen_output_names:
                        raise ValueError(
                            "duplicated output member name after rewrite: "
                            f"{output_name}, first={seen_output_names[output_name]}, "
                            f"duplicate={member.name}"
                        )
                    seen_output_names[output_name] = member.name
                    if len(samples_before) < 5:
                        samples_before.append(member.name)
                    if len(samples_after) < 5:
                        samples_after.append(output_name)
                    input_file = input_tar.extractfile(member)
                    if input_file is None:
                        raise ValueError(
                            f"failed to extract file object from member: {source_archive} -> {member.name}"
                        )
                    try:
                        output_tar.addfile(build_output_tarinfo(member, output_name), input_file)
                    finally:
                        input_file.close()
                    member_count += 1
        # Both tar objects are closed here, so the staged file is complete.
        verification = (
            verify_rewritten_archive(partial_archive, member_count) if verify else None
        )
        output_sha256 = compute_sha256(partial_archive) if checksum else None
        partial_archive.replace(output_archive)
    finally:
        if partial_archive.exists():
            partial_archive.unlink()
        try:
            staging_dir.rmdir()
        except OSError:
            pass  # best effort: keep the directory if something else is inside it

    return {
        "status": "rewritten",
        "source_archive": str(source_archive),
        "output_archive": str(output_archive),
        "member_count": member_count,
        "sample_members_before": samples_before,
        "sample_members_after": samples_after,
        "output_sha256": output_sha256,
        "verification": verification,
    }
def rewrite_archives(archive_paths, source_base_dir, output_root, args):
    """Rewrite every archive in ``archive_paths``, serially or with a thread pool.

    Each archive is written to its mirrored location under ``output_root``.
    Results are returned in the same order as ``archive_paths``. The first
    failure is re-raised. In parallel mode, archives that have not started yet
    are cancelled at that point, so both modes stop scheduling new work after
    an error (already running archives are allowed to finish).

    Args:
        archive_paths: source archives to process.
        source_base_dir: directory the output layout is mirrored from.
        output_root: destination root directory.
        args: parsed CLI namespace; reads ``workers``, ``gzip_compresslevel``,
            ``verify``, ``checksum``, ``overwrite`` and ``dry_run``.
    """
    tasks = [
        (archive_path, build_output_archive_path(archive_path, source_base_dir, output_root))
        for archive_path in archive_paths
    ]

    def run_task(task):
        # Single place that maps CLI options onto rewrite_archive's parameters.
        source_archive, output_archive = task
        return rewrite_archive(
            source_archive,
            output_archive,
            args.gzip_compresslevel,
            args.verify,
            args.checksum,
            args.overwrite,
            args.dry_run,
        )

    if args.workers <= 1 or len(tasks) <= 1:
        return [run_task(task) for task in tasks]

    results = [None] * len(tasks)
    with ThreadPoolExecutor(max_workers=min(args.workers, len(tasks))) as executor:
        future_map = {
            executor.submit(run_task, task): index for index, task in enumerate(tasks)
        }
        try:
            for future in as_completed(future_map):
                results[future_map[future]] = future.result()
        except BaseException:
            # Fail fast like the serial path: drop queued archives (running ones
            # cannot be cancelled and finish before the executor exits).
            for pending in future_map:
                pending.cancel()
            raise
    return results
def build_report(args, resolve_mode, source_path, output_root, archive_paths, results):
    """Assemble the JSON-serialisable run report.

    The effective modalities and parts are re-derived from the raw CLI strings
    in ``args``. ``parts`` is a sorted list, empty when no filter was given.
    """
    return dict(
        generated_at=now_str(),
        source_path=str(Path(source_path).resolve()),
        output_root=str(Path(output_root).resolve()),
        resolve_mode=resolve_mode,
        modalities=list(normalize_modalities(args.modalities)),
        archive_glob=args.archive_glob,
        parts=sorted(parse_parts_filter(args.parts) or []),
        limit=args.limit,
        workers=args.workers,
        gzip_compresslevel=args.gzip_compresslevel,
        checksum=args.checksum,
        verify=args.verify,
        overwrite=args.overwrite,
        dry_run=args.dry_run,
        selected_archive_count=len(archive_paths),
        results=results,
    )
def write_json(path, data):
    """Write ``data`` to ``path`` as 2-space-indented UTF-8 JSON, creating parent dirs.

    Non-ASCII text is written verbatim (``ensure_ascii=False``) rather than escaped.
    """
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    with path.open(mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=2, ensure_ascii=False)
def main():
    """Command-line entry point.

    Validates options, resolves and filters the source archives, rewrites them,
    writes the JSON report to ``<output_root>/<report_name>`` (in dry-run mode
    too) and prints a summary.

    Returns:
        0 on success.

    Raises:
        ValueError: ``--workers`` is not positive or ``--gzip-compresslevel`` is
            outside 0-9, plus anything raised while resolving or rewriting.
    """
    args = parse_args()
    if args.workers <= 0:
        raise ValueError("workers 必须大于 0")
    if args.gzip_compresslevel < 0 or args.gzip_compresslevel > 9:
        raise ValueError("gzip-compresslevel 必须在 0 到 9 之间")

    source_path = Path(args.source_path).resolve()
    resolve_mode, archive_paths, source_base_dir = resolve_archives(
        source_path,
        args.archive_glob,
        normalize_modalities(args.modalities),
    )
    archive_paths = filter_archives(archive_paths, parse_parts_filter(args.parts), args.limit)
    if args.output_root:
        output_root = Path(args.output_root).resolve()
    else:
        output_root = default_output_root(source_path, resolve_mode)

    rule = "######################################################################"
    print("")
    print(rule)
    print("# Rewrite visualization archives")
    print(rule)
    summary = (
        ("Source path", source_path),
        ("Output root", output_root),
        ("Resolve mode", resolve_mode),
        ("Selected count", len(archive_paths)),
        ("Workers", args.workers),
        ("Dry run", args.dry_run),
    )
    for label, value in summary:
        print(f"{label} : {value}")

    results = rewrite_archives(archive_paths, source_base_dir, output_root, args)
    report_path = output_root / args.report_name
    write_json(
        report_path,
        build_report(args, resolve_mode, source_path, output_root, archive_paths, results),
    )
    print(f"Report : {report_path}")

    for item in results:
        print(f"[{item['status']}] {item['source_archive']} -> {item['output_archive']}")
        before = item["sample_members_before"]
        # Dry-run results carry no samples, so nothing extra is printed for them.
        if before:
            print(f" before: {before[:3]}")
            print(f" after : {item['sample_members_after'][:3]}")
    return 0
if __name__ == "__main__":
    # main() returns the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)