Files
yolov26_3d/tools/convert_gt_to_label/rewrite_visualization_archives.py
2026-06-24 09:35:46 +08:00

499 lines
16 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# coding: utf-8
import argparse
import copy
import hashlib
import json
import tarfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path, PurePosixPath
# Modality names accepted by --modalities; they are also the sub-directory
# names looked up under an archive root when discovering archives.
MODALITIES = ("2D", "3D")
# Default glob pattern used to discover archives inside a directory.
DEFAULT_ARCHIVE_GLOB = "part_*.tar.gz"
# Default file name of the JSON report written under the output root.
DEFAULT_REPORT_NAME = "rewrite_archives_report.json"
def parse_args():
    """Declare the command-line interface and parse the process arguments.

    The options are kept in a declarative table (flags + keyword arguments)
    and registered in order, so the generated ``--help`` output keeps the
    same ordering as the table.

    Returns:
        argparse.Namespace: parsed values for ``source_path``, ``output_root``,
        ``modalities``, ``archive_glob``, ``parts``, ``limit``, ``workers``,
        ``gzip_compresslevel``, ``checksum``, ``verify``, ``overwrite``,
        ``report_name`` and ``dry_run``.
    """
    description = (
        "将已有可视化归档中的成员路径改写为 part_xxxx/images/<filename>"
        "并输出到新的归档目录。"
    )
    cli = argparse.ArgumentParser(description=description)

    option_table = (
        (
            ("source_path",),
            dict(
                help="输入路径,可为 archive root、单个 modality 目录,或单个 tar/tar.gz 文件。",
            ),
        ),
        (
            ("output_root",),
            dict(
                nargs="?",
                default=None,
                help="输出目录。默认在 source_path 同级生成 <name>_images_prefixed。",
            ),
        ),
        (
            ("--modalities",),
            dict(
                default="2D,3D",
                help="当 source_path 为 archive root 时要处理的模态,逗号分隔,默认 2D,3D。",
            ),
        ),
        (
            ("--archive-glob",),
            dict(
                default=DEFAULT_ARCHIVE_GLOB,
                help="归档匹配模式,默认 part_*.tar.gz。",
            ),
        ),
        (
            ("--parts",),
            dict(
                default=None,
                help=(
                    "仅处理指定 part逗号分隔既支持 part_0001也支持 part_0001.tar.gz。"
                ),
            ),
        ),
        (
            ("--limit",),
            dict(
                type=int,
                default=None,
                help="最多处理多少个归档,便于抽样验证。",
            ),
        ),
        (
            ("--workers",),
            dict(
                type=int,
                default=1,
                help="并发处理归档数,默认 1。",
            ),
        ),
        (
            ("--gzip-compresslevel",),
            dict(
                type=int,
                default=1,
                help="输出 tar.gz 的 gzip 压缩级别,范围 0-9默认 1。",
            ),
        ),
        (
            ("--checksum",),
            dict(
                action="store_true",
                help="输出归档后计算 sha256。",
            ),
        ),
        (
            # Verification is on by default; --no-verify below switches it off.
            ("--verify",),
            dict(
                action="store_true",
                default=True,
                help="输出归档后校验成员数及路径前缀,默认开启。",
            ),
        ),
        (
            ("--no-verify",),
            dict(
                dest="verify",
                action="store_false",
                help="关闭输出归档校验。",
            ),
        ),
        (
            ("--overwrite",),
            dict(
                action="store_true",
                help="允许覆盖已有输出归档。",
            ),
        ),
        (
            ("--report-name",),
            dict(
                default=DEFAULT_REPORT_NAME,
                help=f"输出报告文件名,默认 {DEFAULT_REPORT_NAME}",
            ),
        ),
        (
            ("--dry-run",),
            dict(
                action="store_true",
                help="仅输出转换计划,不实际生成归档。",
            ),
        ),
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()
def now_str():
    """Return the current local time as ``YYYY-MM-DDTHH:MM:SS`` (no sub-seconds)."""
    current = datetime.now()
    # With microseconds zeroed, isoformat() omits the fractional part.
    return current.replace(microsecond=0).isoformat()
def normalize_modalities(raw_modalities):
    """Turn a comma-separated modality string into a de-duplicated tuple.

    Blank entries are ignored and the first-seen order is preserved.

    Args:
        raw_modalities: string such as ``"2D,3D"``.

    Returns:
        tuple[str, ...]: the unique modalities in input order.

    Raises:
        ValueError: if an entry is not in ``MODALITIES`` or nothing remains.
    """
    ordered_unique = {}
    for token in raw_modalities.split(","):
        name = token.strip()
        if name == "":
            continue
        if name not in MODALITIES:
            raise ValueError(f"unsupported modality: {name}")
        # A dict keeps insertion order, so it doubles as an ordered set.
        ordered_unique.setdefault(name, None)
    if not ordered_unique:
        raise ValueError("modalities 不能为空。")
    return tuple(ordered_unique)
def archive_suffix(path):
    """Return the archive extension (``.tar.gz`` or ``.tar``) of *path*.

    Args:
        path: path-like object exposing a ``name`` attribute.

    Raises:
        ValueError: if the file name ends with neither supported extension.
    """
    filename = path.name
    # ".tar.gz" is listed first to mirror the original check order.
    for extension in (".tar.gz", ".tar"):
        if filename.endswith(extension):
            return extension
    raise ValueError(f"unsupported archive suffix: {path}")
def strip_archive_suffix(name):
    """Remove a trailing ``.tar.gz`` or ``.tar`` from *name*; return it unchanged otherwise."""
    for extension in (".tar.gz", ".tar"):
        if name.endswith(extension):
            keep = len(name) - len(extension)
            return name[:keep]
    return name
def parse_parts_filter(raw_parts):
    """Parse the ``--parts`` value into a set of part names without archive suffix.

    Args:
        raw_parts: comma-separated string, or a falsy value for "no filter".

    Returns:
        set[str] | None: the normalized part names, or ``None`` when no
        usable entry was given.
    """
    if not raw_parts:
        return None
    trimmed = (piece.strip() for piece in raw_parts.split(","))
    wanted = {strip_archive_suffix(piece) for piece in trimmed if piece}
    # An empty result means "no filter", which callers expect as None.
    return wanted if wanted else None
def is_archive_file(path):
    """Return True when *path* is an existing regular file named ``*.tar`` or ``*.tar.gz``."""
    if not path.is_file():
        return False
    return path.name.endswith((".tar", ".tar.gz"))
def default_output_root(source_path, resolve_mode):
    """Derive the default output directory: a sibling of the source named ``<name>_images_prefixed``.

    Args:
        source_path: ``pathlib.Path`` given on the command line.
        resolve_mode: when ``"single_archive"``, the archive suffix is
            stripped from the name first; any other value uses the name as is.

    Returns:
        pathlib.Path: the output root next to the resolved source path.
    """
    resolved = source_path.resolve()
    stem = resolved.name
    if resolve_mode == "single_archive":
        stem = strip_archive_suffix(stem)
    return resolved.parent / f"{stem}_images_prefixed"
def resolve_archives(source_path, archive_glob, modalities):
    """Work out which archives *source_path* refers to.

    Three layouts are recognized, checked in this order:

    * a single archive file -> mode ``"single_archive"``;
    * a directory named after a modality that directly contains matching
      archives -> mode ``"modality_dir"``;
    * a directory whose modality sub-directories contain matching
      archives -> mode ``"archive_root"``.

    Args:
        source_path: ``pathlib.Path`` to inspect.
        archive_glob: glob pattern applied inside each candidate directory.
        modalities: modality sub-directory names to scan in archive-root mode.

    Returns:
        tuple: ``(resolve_mode, archive_paths, source_base_dir)`` where
        ``source_base_dir`` is the directory archive paths are relative to.

    Raises:
        ValueError: if a file is given that is not a supported archive.
        FileNotFoundError: if the path is not a directory or holds no archives.
    """
    root = source_path.resolve()

    def matching_archives(directory):
        candidates = directory.glob(archive_glob)
        return sorted(item for item in candidates if is_archive_file(item))

    if root.is_file():
        if is_archive_file(root):
            return "single_archive", [root], root.parent
        raise ValueError(f"source_path is not a supported archive file: {root}")

    if not root.is_dir():
        raise FileNotFoundError(f"source_path does not exist or is not a directory: {root}")

    top_level = matching_archives(root)
    if top_level and root.name in MODALITIES:
        # The modality directory itself was passed; its parent is the base.
        return "modality_dir", top_level, root.parent

    collected = []
    for modality in modalities:
        candidate_dir = root / modality
        if candidate_dir.is_dir():
            collected.extend(matching_archives(candidate_dir))

    if not collected:
        raise FileNotFoundError(
            f"no archives found under {root} with archive_glob={archive_glob}"
        )
    return "archive_root", collected, root
def filter_archives(archive_paths, parts_filter, limit):
    """Apply the optional part-name filter and the optional count limit.

    Args:
        archive_paths: candidate archive paths, already ordered.
        parts_filter: collection of part names (suffix stripped), or a falsy
            value to keep everything.
        limit: maximum number of archives to keep, or ``None`` for no limit.

    Returns:
        list: the selected archives.

    Raises:
        FileNotFoundError: if the selection ends up empty.
    """
    chosen = archive_paths
    if parts_filter:
        chosen = [
            candidate
            for candidate in chosen
            if strip_archive_suffix(candidate.name) in parts_filter
        ]
    if limit is not None:
        chosen = chosen[:limit]
    if chosen:
        return chosen
    raise FileNotFoundError("没有匹配到待处理归档,请检查 parts 或 archive_glob。")
def build_output_archive_path(archive_path, source_base_dir, output_root):
    """Mirror *archive_path*'s location relative to *source_base_dir* under *output_root*.

    Raises:
        ValueError: (from ``relative_to``) if the archive is not inside the base dir.
    """
    relative_location = archive_path.relative_to(source_base_dir)
    return output_root / relative_location
def build_archive_member_dir(archive_path):
    """Return the in-archive directory for members: ``<archive name without suffix>/images``."""
    part_name = strip_archive_suffix(archive_path.name)
    return "/".join((part_name, "images"))
def build_output_member_name(member_name, output_archive):
    """Map a source member path to its rewritten path inside *output_archive*.

    Only the final path component of *member_name* is kept; it is placed
    under the directory returned by ``build_archive_member_dir``.
    """
    member_dir = build_archive_member_dir(output_archive)
    base_name = PurePosixPath(member_name).name
    return member_dir + "/" + base_name
def open_output_tarfile(output_archive, archive_format, gzip_compresslevel):
    """Open *output_archive* for writing as a plain or gzip-compressed tar.

    Args:
        output_archive: destination path.
        archive_format: ``".tar.gz"`` or ``".tar"``; when falsy the format is
            derived from the file name via ``archive_suffix``.
        gzip_compresslevel: compression level, used only for ``.tar.gz``.

    Returns:
        tarfile.TarFile: the open archive; the caller must close it.
    """
    chosen_format = archive_format if archive_format else archive_suffix(output_archive)
    if chosen_format != ".tar.gz":
        return tarfile.open(output_archive, mode="w")
    return tarfile.open(
        output_archive,
        mode="w:gz",
        compresslevel=gzip_compresslevel,
    )
def compute_sha256(file_path):
    """Return the hex SHA-256 digest of the file at *file_path*, read in 1 MiB blocks."""
    hasher = hashlib.sha256()
    block_size = 1024 * 1024
    with open(file_path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for block in iter(lambda: stream.read(block_size), b""):
            hasher.update(block)
    return hasher.hexdigest()
def verify_rewritten_archive(archive_path, expected_member_count):
    """Check that a rewritten archive holds only correctly prefixed regular files.

    Args:
        archive_path: path of the rewritten archive; its name determines the
            expected ``<part>/images/`` prefix.
        expected_member_count: number of members that were written.

    Returns:
        dict: ``member_count`` and up to five ``sample_members`` names.

    Raises:
        ValueError: on a non-regular member, a member without the expected
            prefix, or a member-count mismatch.
    """
    required_prefix = f"{build_archive_member_dir(archive_path)}/"
    first_names = []
    seen = 0
    with tarfile.open(archive_path, mode="r:*") as archive:
        for entry in archive:
            if not entry.isfile():
                raise ValueError(
                    f"rewritten archive contains non-regular member: {archive_path} -> {entry.name}"
                )
            if not entry.name.startswith(required_prefix):
                raise ValueError(
                    f"rewritten archive member missing expected prefix {required_prefix}: "
                    f"{archive_path} -> {entry.name}"
                )
            seen += 1
            if len(first_names) < 5:
                first_names.append(entry.name)
    if seen != expected_member_count:
        raise ValueError(
            f"rewritten archive member count mismatch: {archive_path}, "
            f"expected={expected_member_count}, actual={seen}"
        )
    return {"member_count": seen, "sample_members": first_names}
def _clone_member_with_name(member, new_name):
    """Return a copy of *member* that will be written under *new_name*."""
    renamed = copy.copy(member)
    renamed.name = new_name
    # copy.copy is shallow: give the clone its own header dict so the
    # TarInfo owned by the input archive is never mutated.
    headers = dict(member.pax_headers)
    # When writing pax format, tarfile gives an existing "path" record
    # priority over TarInfo.name. Members that carried one in the source
    # (long or non-ASCII names) would keep their old path, so drop it.
    headers.pop("path", None)
    renamed.pax_headers = headers
    return renamed


def rewrite_archive(
    source_archive,
    output_archive,
    gzip_compresslevel,
    verify,
    checksum,
    overwrite,
    dry_run,
):
    """Copy one archive, renaming every member to ``<part>/images/<filename>``.

    The output is first written to ``<output name>.partial`` next to the
    destination and moved into place only after every member was copied, so
    a failed run never leaves a truncated archive under the final name.

    Args:
        source_archive: ``pathlib.Path`` of the archive to read.
        output_archive: ``pathlib.Path`` of the archive to create.
        gzip_compresslevel: gzip level used when the output is ``.tar.gz``.
        verify: re-open the output and validate member count and prefixes.
        checksum: compute the SHA-256 of the output archive.
        overwrite: allow replacing an existing output archive.
        dry_run: only report the planned conversion; nothing is written.

    Returns:
        dict: ``status`` (``"dry_run"`` or ``"rewritten"``), the resolved
        ``source_archive`` / ``output_archive`` paths, ``member_count``, up to
        five ``sample_members_before`` / ``sample_members_after`` names,
        ``output_sha256`` and ``verification`` (``None`` when not requested).

    Raises:
        FileExistsError: the output exists and *overwrite* is false.
        ValueError: a member is not a regular file, cannot be extracted, two
            members map to the same output name, or verification fails.
    """
    source_archive = source_archive.resolve()
    output_archive = output_archive.resolve()
    if output_archive.exists() and not overwrite:
        raise FileExistsError(f"output archive already exists: {output_archive}")

    samples_before = []
    samples_after = []
    if dry_run:
        return {
            "status": "dry_run",
            "source_archive": str(source_archive),
            "output_archive": str(output_archive),
            "member_count": None,
            "sample_members_before": samples_before,
            "sample_members_after": samples_after,
            "output_sha256": None,
            "verification": None,
        }

    output_archive.parent.mkdir(parents=True, exist_ok=True)
    partial_archive = output_archive.with_name(f"{output_archive.name}.partial")
    # Remove leftovers from an earlier interrupted run.
    if partial_archive.exists():
        partial_archive.unlink()

    # Maps each rewritten name to the source member that produced it, so a
    # collision can be reported with both original names.
    seen_output_names = {}
    member_count = 0
    # The format comes from the final name; the ".partial" name has no
    # recognizable archive suffix.
    output_archive_format = archive_suffix(output_archive)
    try:
        with tarfile.open(source_archive, mode="r:*") as input_tar, open_output_tarfile(
            partial_archive,
            output_archive_format,
            gzip_compresslevel,
        ) as output_tar:
            for member in input_tar:
                if not member.isfile():
                    raise ValueError(
                        f"only regular file members are supported: "
                        f"{source_archive} -> {member.name}"
                    )
                output_name = build_output_member_name(member.name, output_archive)
                if output_name in seen_output_names:
                    raise ValueError(
                        "duplicated output member name after rewrite: "
                        f"{output_name}, first={seen_output_names[output_name]}, "
                        f"duplicate={member.name}"
                    )
                seen_output_names[output_name] = member.name
                if len(samples_before) < 5:
                    samples_before.append(member.name)
                if len(samples_after) < 5:
                    samples_after.append(output_name)

                input_file = input_tar.extractfile(member)
                if input_file is None:
                    raise ValueError(
                        f"failed to extract file object from member: {source_archive} -> {member.name}"
                    )
                output_member = _clone_member_with_name(member, output_name)
                try:
                    output_tar.addfile(output_member, input_file)
                finally:
                    input_file.close()
                member_count += 1
        # Both archives are closed here; publish the finished file.
        partial_archive.replace(output_archive)
    finally:
        # Still present only when the copy failed before the rename above.
        if partial_archive.exists():
            partial_archive.unlink()

    verification = verify_rewritten_archive(output_archive, member_count) if verify else None
    output_sha256 = compute_sha256(output_archive) if checksum else None
    return {
        "status": "rewritten",
        "source_archive": str(source_archive),
        "output_archive": str(output_archive),
        "member_count": member_count,
        "sample_members_before": samples_before,
        "sample_members_after": samples_after,
        "output_sha256": output_sha256,
        "verification": verification,
    }
def rewrite_archives(archive_paths, source_base_dir, output_root, args):
    """Rewrite every selected archive, serially or on a thread pool.

    Args:
        archive_paths: source archives to process.
        source_base_dir: directory the source archives are relative to.
        output_root: directory that receives the mirrored output archives.
        args: parsed CLI namespace; ``workers``, ``gzip_compresslevel``,
            ``verify``, ``checksum``, ``overwrite`` and ``dry_run`` are read.

    Returns:
        list[dict]: one ``rewrite_archive`` result per input, in input order.
    """
    jobs = []
    for source in archive_paths:
        destination = build_output_archive_path(source, source_base_dir, output_root)
        jobs.append((source, destination))

    def run_job(job):
        source, destination = job
        return rewrite_archive(
            source,
            destination,
            args.gzip_compresslevel,
            args.verify,
            args.checksum,
            args.overwrite,
            args.dry_run,
        )

    job_total = len(jobs)
    if args.workers <= 1 or job_total <= 1:
        return [run_job(job) for job in jobs]

    # Futures complete in arbitrary order; slot each result by its job index.
    ordered = [None] * job_total
    with ThreadPoolExecutor(max_workers=min(args.workers, job_total)) as pool:
        pending = {pool.submit(run_job, job): slot for slot, job in enumerate(jobs)}
        for finished in as_completed(pending):
            ordered[pending[finished]] = finished.result()
    return ordered
def build_report(args, resolve_mode, source_path, output_root, archive_paths, results):
    """Assemble the JSON-serializable run report.

    The report records the generation time, the resolved source and output
    paths, the effective CLI options, how many archives were selected, and
    the per-archive results.

    Returns:
        dict: the report, with keys in a fixed order.
    """
    report = {"generated_at": now_str()}
    report["source_path"] = str(Path(source_path).resolve())
    report["output_root"] = str(Path(output_root).resolve())
    report["resolve_mode"] = resolve_mode
    report["modalities"] = list(normalize_modalities(args.modalities))
    report["archive_glob"] = args.archive_glob
    report["parts"] = sorted(parse_parts_filter(args.parts) or [])
    # These options are copied verbatim from the parsed arguments.
    for option in (
        "limit",
        "workers",
        "gzip_compresslevel",
        "checksum",
        "verify",
        "overwrite",
        "dry_run",
    ):
        report[option] = getattr(args, option)
    report["selected_archive_count"] = len(archive_paths)
    report["results"] = results
    return report
def write_json(path, data):
    """Write *data* to *path* as UTF-8 JSON (2-space indent, non-ASCII kept as is).

    Missing parent directories are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        json.dump(data, handle, ensure_ascii=False, indent=2)
def main():
    """Run the CLI: resolve inputs, rewrite the archives, write and print the report.

    Returns:
        int: ``0`` on success.

    Raises:
        ValueError: if ``--workers`` is not positive or
            ``--gzip-compresslevel`` is outside 0-9.
    """
    args = parse_args()
    if args.workers < 1:
        raise ValueError("workers 必须大于 0")
    if args.gzip_compresslevel < 0 or args.gzip_compresslevel > 9:
        raise ValueError("gzip-compresslevel 必须在 0 到 9 之间")

    source_path = Path(args.source_path).resolve()
    modalities = normalize_modalities(args.modalities)
    parts_filter = parse_parts_filter(args.parts)
    resolve_mode, archive_paths, source_base_dir = resolve_archives(
        source_path,
        args.archive_glob,
        modalities,
    )
    archive_paths = filter_archives(archive_paths, parts_filter, args.limit)

    if args.output_root:
        output_root = Path(args.output_root).resolve()
    else:
        output_root = default_output_root(source_path, resolve_mode)

    banner = "######################################################################"
    summary_lines = (
        "",
        banner,
        "# Rewrite visualization archives",
        banner,
        f"Source path : {source_path}",
        f"Output root : {output_root}",
        f"Resolve mode : {resolve_mode}",
        f"Selected count : {len(archive_paths)}",
        f"Workers : {args.workers}",
        f"Dry run : {args.dry_run}",
    )
    for line in summary_lines:
        print(line)

    results = rewrite_archives(archive_paths, source_base_dir, output_root, args)
    report = build_report(args, resolve_mode, source_path, output_root, archive_paths, results)
    report_path = output_root / args.report_name
    write_json(report_path, report)
    print(f"Report : {report_path}")

    for entry in results:
        print(
            f"[{entry['status']}] {entry['source_archive']} -> {entry['output_archive']}"
        )
        before_samples = entry["sample_members_before"]
        # Dry-run results carry no samples, so nothing extra is printed.
        if before_samples:
            print(f" before: {before_samples[:3]}")
            print(f" after : {entry['sample_members_after'][:3]}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)