794 lines
30 KiB
Python
Executable File
794 lines
30 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import random
|
|
import shutil
|
|
import sys
|
|
from collections import Counter, defaultdict
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
from convert_txt_to_json import convert_txt_to_json
|
|
|
|
try:
|
|
import yaml # type: ignore
|
|
except ImportError: # pragma: no cover - optional dependency fallback
|
|
yaml = None
|
|
|
|
try:
|
|
from ruamel.yaml import YAML # type: ignore
|
|
except ImportError: # pragma: no cover - optional dependency fallback
|
|
YAML = None
|
|
|
|
from ultralytics.data.ground3d_augment import parse_ground_3d_label_file
|
|
|
|
|
|
@dataclass(frozen=True)
class FrameRecord:
    """Lightweight frame summary used for balanced subset mining.

    One record is built per split entry whose label file and derived image file
    both exist. Class information is restricted to the requested class ids.

    Note: the dataclass is frozen, but ``instance_count_per_class`` is a plain
    ``dict``; the generated ``__hash__`` hashes every field, so hashing an
    instance (set membership, dict key) raises ``TypeError``.
    """

    # Position of the entry in the (possibly subsampled) split entry list.
    source_index: int
    # Resolved directory of the split file the entry was read from.
    list_root: str
    # Label path exactly as written in the split file.
    label_entry: str
    # Line written to the manifest when assets are not exported (the resolved label path).
    manifest_entry: str
    # Resolved label .txt path.
    label_path: str
    # Resolved image path derived from the label path.
    image_path: str
    # Sorted mapped class ids present in the frame (requested classes only).
    present_classes: tuple[int, ...]
    # Display names for ``present_classes``, in the same order.
    present_class_names: tuple[str, ...]
    # Instance count per present class id.
    instance_count_per_class: dict[int, int]
    # Total number of requested-class instances in the frame.
    num_objects: int
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace. Dashed option names become underscored attributes
        (``--target-frames`` -> ``target_frames``).
    """
    parser = argparse.ArgumentParser(
        description="Mine a class-balanced Ground3D evaluation subset from a train/val GT list."
    )
    add = parser.add_argument

    # --- input selection -------------------------------------------------
    add("--data", type=str, required=True, help="Dataset yaml, e.g. ultralytics/cfg/datasets/mono3d_ground.yaml")
    add("--split", type=str, default="val", help="Split key in data yaml, e.g. train or val.")
    add("--target-frames", type=int, default=300, help="Target number of frames to select.")

    # --- output naming ---------------------------------------------------
    add(
        "--output-dir",
        type=str,
        default="",
        help="Directory used for mined outputs. Defaults to <data-yaml-dir>/mined_subsets.",
    )
    add(
        "--output-prefix",
        type=str,
        default="",
        help="Optional filename prefix. Defaults to balanced_<split>_<selected_frames>.",
    )

    # --- selection behaviour ---------------------------------------------
    add("--seed", type=int, default=20260401, help="Random seed used for tie-breaking.")
    add(
        "--classes",
        type=int,
        nargs="*",
        default=None,
        help="Optional mapped class ids to balance. Defaults to all classes defined by class_map.",
    )
    add(
        "--include-empty",
        action="store_true",
        help="Allow frames without any selected classes into the candidate pool.",
    )
    add(
        "--max-entries",
        type=int,
        default=0,
        help="Optional cap on scanned split entries for debugging. 0 scans the whole split.",
    )
    add(
        "--sample-selection",
        choices=("head", "random"),
        default="head",
        help="How to subsample split entries when --max-entries is set.",
    )

    # --- asset export ----------------------------------------------------
    add(
        "--skip-export-assets",
        action="store_true",
        help="Do not copy selected image and label files into the output directory.",
    )
    add(
        "--skip-json-convert",
        action="store_true",
        help="Do not convert exported label txt files to JSON format.",
    )
    add(
        "--image-width",
        type=int,
        default=1920,
        help="Image width used for label coordinate denormalization when converting to JSON (default: 1920).",
    )
    add(
        "--image-height",
        type=int,
        default=1080,
        help="Image height used for label coordinate denormalization when converting to JSON (default: 1080).",
    )

    return parser.parse_args()
|
|
|
|
|
|
def load_yaml(path: str | Path) -> dict[str, Any]:
    """Load a YAML document that must contain a mapping at its top level.

    PyYAML (``yaml.safe_load``) is used when installed; otherwise ``ruamel.yaml``
    in ``typ="safe"`` mode. An empty document yields ``{}``.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed top-level mapping.

    Raises:
        RuntimeError: Neither PyYAML nor ruamel.yaml is importable.
        ValueError: The document's top level is not a mapping (e.g. a list or scalar).
    """
    path = Path(path)
    if yaml is not None:
        with path.open("r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
    elif YAML is not None:
        yaml_loader = YAML(typ="safe")
        with path.open("r", encoding="utf-8") as file:
            data = yaml_loader.load(file)
    else:
        raise RuntimeError("Neither PyYAML nor ruamel.yaml is available. Please install one YAML parser.")

    data = data or {}
    # Callers use .get(...) on the result; reject non-mapping documents with a clear message
    # instead of letting an AttributeError surface later.
    if not isinstance(data, dict):
        raise ValueError(f"Expected a YAML mapping at the top level of {path}, got {type(data).__name__}.")
    return data
|
|
|
|
|
|
def resolve_dataset_root(data_yaml: Path, dataset_root: str | None) -> Path | None:
    """Resolve the ``path`` entry of a data yaml to an absolute directory.

    Args:
        data_yaml: Path of the data yaml the value came from.
        dataset_root: Raw ``path`` value (may be falsy).

    Returns:
        ``None`` for a falsy value. An absolute value is returned resolved. A
        relative value is tried against the yaml's directory first, then against
        the current working directory. The first existing candidate wins;
        otherwise the yaml-relative candidate is returned.
    """
    if not dataset_root:
        return None

    root = Path(str(dataset_root))
    if root.is_absolute():
        return root.resolve()

    yaml_relative = (data_yaml.parent / root).resolve()
    if yaml_relative.exists():
        return yaml_relative

    cwd_relative = root.resolve()
    return cwd_relative if cwd_relative.exists() else yaml_relative
|
|
|
|
|
|
def resolve_data_paths(data_yaml: Path, dataset_root: Path | None, value: Any) -> list[Path]:
    """Resolve a split entry of the data yaml (a path or a list/tuple of paths).

    Args:
        data_yaml: Path of the data yaml (used for error messages and as a fallback base).
        dataset_root: Resolved dataset root, or ``None``.
        value: The raw split value; lists/tuples are flattened recursively.

    Returns:
        Resolved paths in input order. A relative path is tried against
        *dataset_root* (when given), then the yaml's directory. The first
        existing candidate wins; otherwise the first candidate is used.

    Raises:
        ValueError: *value* (or any nested item) is ``None``.
    """
    if value is None:
        raise ValueError(f"Missing split path in {data_yaml}")

    if isinstance(value, (list, tuple)):
        return [
            resolved
            for item in value
            for resolved in resolve_data_paths(data_yaml, dataset_root, item)
        ]

    raw = Path(str(value))
    if raw.is_absolute():
        return [raw.resolve()]

    bases = ([dataset_root] if dataset_root is not None else []) + [data_yaml.parent]
    options = [(base / raw).resolve() for base in bases]
    existing = next((option for option in options if option.exists()), None)
    return [existing if existing is not None else options[0]]
|
|
|
|
|
|
def load_split_entries(
    split_files: list[Path],
    max_entries: int = 0,
    sample_selection: str = "head",
    sample_seed: int = 20260401,
) -> list[tuple[str, str]]:
    """Read label entries from one or more split list files.

    Blank lines and ``#`` comment lines are skipped. Each kept entry is paired
    with the resolved directory of the list file it came from.

    Args:
        split_files: Split list files to read, in order.
        max_entries: When > 0 and smaller than the number of entries, cap the result.
        sample_selection: ``"random"`` draws a seeded random subset (kept in file
            order); any other value keeps the first *max_entries* entries.
        sample_seed: Seed for the ``"random"`` selection.

    Returns:
        ``(list_root, entry)`` tuples.

    Raises:
        ValueError: An entry does not end in ``.txt``.
        FileNotFoundError: No usable entries were found.
    """
    entries: list[tuple[str, str]] = []
    for split_file in split_files:
        list_root = str(split_file.parent.resolve())
        with split_file.open("r", encoding="utf-8") as handle:
            for raw_line in handle:
                text = raw_line.strip()
                if not text or text.startswith("#"):
                    continue
                if Path(text).suffix.lower() != ".txt":
                    raise ValueError(f"Ground3D split entries must point to label .txt files, but got: {text}")
                entries.append((list_root, text))

    if not entries:
        raise FileNotFoundError(f"No usable entries found in split files: {', '.join(str(p) for p in split_files)}")

    if max_entries <= 0 or len(entries) <= max_entries:
        return entries

    if sample_selection == "random":
        picked = random.Random(int(sample_seed)).sample(range(len(entries)), int(max_entries))
        return [entries[position] for position in sorted(picked)]
    return entries[:max_entries]
|
|
|
|
|
|
def label_rel_to_image_rel(label_path: Path) -> Path:
    """Map a label path to its image path.

    The last ``labels`` path component (if any) is replaced by ``images`` and the
    suffix is replaced by ``.png``; e.g. ``case/labels/0.txt`` -> ``case/images/0.png``.
    """
    parts = list(label_path.parts)
    # Scan from the end so only the innermost ``labels`` component is renamed.
    for position in range(len(parts) - 1, -1, -1):
        if parts[position] == "labels":
            parts[position] = "images"
            break
    return Path(*parts).with_suffix(".png")
|
|
|
|
|
|
def entry_to_label_file(entry: tuple[str, str]) -> Path:
    """Return the resolved label path for a ``(list_root, label_entry)`` pair.

    Relative entries are interpreted relative to the directory of the split file
    (*list_root*); absolute entries are used as-is.
    """
    list_root, label_entry = entry
    candidate = Path(label_entry)
    if not candidate.is_absolute():
        candidate = Path(list_root) / candidate
    return candidate.resolve()
|
|
|
|
|
|
def entry_to_image_file(entry: tuple[str, str], image_root: Path | None) -> Path:
    """Return the resolved image path for a ``(list_root, label_entry)`` pair.

    The image path is derived from the label entry via ``label_rel_to_image_rel``.
    A relative result is joined onto *image_root*.

    NOTE(review): unlike ``entry_to_label_file``, the relative path is resolved
    against *image_root*, not the split file's directory. The two agree only when
    the split file sits in the image root — confirm that layout is assumed.

    Raises:
        ValueError: The derived image path is relative and *image_root* is ``None``.
    """
    image_rel = label_rel_to_image_rel(Path(entry[1]))
    if image_rel.is_absolute():
        return image_rel.resolve()
    if image_root is None:
        raise ValueError("Dataset image root is required to resolve relative Ground3D image paths.")
    return (image_root / image_rel).resolve()
|
|
|
|
|
|
def infer_class_name_map(class_map: dict[str, int]) -> dict[int, str]:
    """Invert a ``raw name -> class id`` map into ``class id -> display name``.

    When several raw names map to the same id, the first one in iteration order
    is kept. The result is ordered by class id.
    """
    names: dict[int, str] = {}
    for raw_name, class_id in class_map.items():
        key = int(class_id)
        if key not in names:
            names[key] = str(raw_name)
    return dict(sorted(names.items()))
|
|
|
|
|
|
def build_frame_records(
    entries: list[tuple[str, str]],
    image_root: Path | None,
    class_map: dict[str, int],
    difficulty_weights: list[float],
    face_3d_classes: set[int],
    complete_3d_classes: set[int],
    class_name_map: dict[int, str],
    requested_classes: set[int],
    include_empty: bool,
) -> tuple[list[FrameRecord], dict[str, int]]:
    """Scan split entries and build one ``FrameRecord`` per usable frame.

    An entry is skipped when its label file is missing, when its derived image
    file is missing, or (unless *include_empty*) when the label contains no
    instance of *requested_classes*. Labels are parsed with
    ``parse_ground_3d_label_file``; only the 2D class ids are used here.

    Returns:
        ``(records, stats)``. ``stats`` has keys ``total_entries``,
        ``missing_labels``, ``missing_images``, ``empty_frames`` and ``kept_frames``.
    """
    records: list[FrameRecord] = []
    missing_labels = 0
    missing_images = 0
    empty_frames = 0

    for source_index, entry in enumerate(entries):
        # Both paths are derived before any existence check, so an unresolvable
        # image path raises even when the label itself is missing.
        label_path = entry_to_label_file(entry)
        image_path = entry_to_image_file(entry, image_root)
        if not label_path.is_file():
            missing_labels += 1
            continue
        if not image_path.is_file():
            missing_images += 1
            continue

        lb_2d, _ = parse_ground_3d_label_file(
            str(label_path),
            class_map,
            difficulty_weights,
            face_3d_classes,
            complete_3d_classes,
        )
        # NOTE(review): lb_2d["cls"] is assumed to be array-like (it supports
        # .reshape/.tolist, presumably numpy) — confirm against ground3d_augment.
        frame_classes = (int(value) for value in lb_2d["cls"].reshape(-1).tolist())
        instance_counts = Counter(value for value in frame_classes if value in requested_classes)
        present_classes = tuple(sorted(instance_counts))

        if not (present_classes or include_empty):
            empty_frames += 1
            continue

        records.append(
            FrameRecord(
                source_index=source_index,
                list_root=entry[0],
                label_entry=entry[1],
                manifest_entry=str(label_path),
                label_path=str(label_path),
                image_path=str(image_path),
                present_classes=present_classes,
                present_class_names=tuple(class_name_map.get(value, str(value)) for value in present_classes),
                instance_count_per_class=dict(instance_counts),
                num_objects=int(sum(instance_counts.values())),
            )
        )

    stats = {
        "total_entries": len(entries),
        "missing_labels": missing_labels,
        "missing_images": missing_images,
        "empty_frames": empty_frames,
        "kept_frames": len(records),
    }
    return records, stats
|
|
|
|
|
|
def summarize_records(records: list[FrameRecord]) -> tuple[Counter[int], Counter[int]]:
    """Count, per class id, how many frames contain it and how many instances it has.

    Returns:
        ``(frame_counts, instance_counts)``. The first counts frames whose
        ``present_classes`` include the class; the second sums
        ``instance_count_per_class`` over those frames.
    """
    frames_with_class: Counter[int] = Counter()
    instances_of_class: Counter[int] = Counter()
    for record in records:
        frames_with_class.update(record.present_classes)
        for class_id in record.present_classes:
            instances_of_class[class_id] += int(record.instance_count_per_class.get(class_id, 0))
    return frames_with_class, instances_of_class
|
|
|
|
|
|
def build_candidate_rankings(
    records: list[FrameRecord],
    active_classes: list[int],
    available_frame_counts: Counter[int],
    seed: int,
) -> tuple[dict[int, list[int]], list[int]]:
    """Rank record indices globally and per active class.

    Every record gets a priority, highest first:

    1. rarity: the sum of ``1 / frame_count`` over the active classes it contains;
    2. the number of active classes it contains;
    3. ``num_objects``.

    Remaining ties are broken by a seeded random permutation.

    Returns:
        ``(candidates_by_class, fill_order)``. ``candidates_by_class`` maps each
        active class to the ranked indices of records containing it (a
        ``defaultdict(list)``); ``fill_order`` ranks all record indices.
    """
    permutation = list(range(len(records)))
    random.Random(seed).shuffle(permutation)
    tie_rank = {index: rank for rank, index in enumerate(permutation)}

    active = set(active_classes)
    sort_keys: dict[int, tuple[float, int, int, int]] = {}
    candidates_by_class: dict[int, list[int]] = defaultdict(list)

    for index, record in enumerate(records):
        covered = [class_id for class_id in record.present_classes if class_id in active]
        rarity = sum(1.0 / max(int(available_frame_counts[class_id]), 1) for class_id in covered)
        # Negated so that an ascending sort puts the most valuable frame first.
        sort_keys[index] = (-rarity, -len(covered), -record.num_objects, tie_rank[index])
        for class_id in covered:
            candidates_by_class[class_id].append(index)

    for ranked in candidates_by_class.values():
        ranked.sort(key=sort_keys.__getitem__)
    fill_order = sorted(range(len(records)), key=sort_keys.__getitem__)
    return candidates_by_class, fill_order
|
|
|
|
|
|
def select_balanced_subset(
    records: list[FrameRecord],
    target_frames: int,
    active_classes: list[int],
    seed: int,
) -> tuple[list[FrameRecord], Counter[int]]:
    """Greedily select up to *target_frames* records with balanced class coverage.

    Phase 1 repeats the following until the target is reached or every active
    class is exhausted. It picks the non-exhausted class with the fewest selected
    frames (ties: fewer available frames, then lower id). It then takes that
    class's best-ranked unselected candidate (see ``build_candidate_rankings``).
    Phase 2 tops up from the global fill order.

    Returns:
        ``(selected_records, selected_frame_counts)``. The counts include every
        present class of every selected record. When the target is at least the
        number of records, all records are returned unchanged.
    """
    if target_frames <= 0 or not records:
        return [], Counter()

    if target_frames >= len(records):
        all_counts, _ = summarize_records(records)
        return list(records), all_counts

    available_frame_counts, _ = summarize_records(records)
    candidates_by_class, fill_order = build_candidate_rankings(records, active_classes, available_frame_counts, seed)

    picked: list[int] = []
    picked_set: set[int] = set()
    selected_frame_counts: Counter[int] = Counter()

    def take(index: int) -> None:
        # Record a selection and update per-class frame tallies.
        picked.append(index)
        picked_set.add(index)
        for class_id in records[index].present_classes:
            selected_frame_counts[class_id] += 1

    # Per-class read position into its ranked candidate list.
    cursor = {class_id: 0 for class_id in active_classes}
    exhausted: set[int] = set()

    while len(picked) < target_frames and len(exhausted) < len(active_classes):
        focus = min(
            (class_id for class_id in active_classes if class_id not in exhausted),
            key=lambda class_id: (selected_frame_counts[class_id], available_frame_counts[class_id], class_id),
        )
        queue = candidates_by_class.get(focus, [])
        position = cursor[focus]
        # Skip candidates already taken on behalf of another class.
        while position < len(queue) and queue[position] in picked_set:
            position += 1
        if position >= len(queue):
            cursor[focus] = position
            exhausted.add(focus)
            continue
        cursor[focus] = position + 1
        take(queue[position])

    for index in fill_order:
        if len(picked) >= target_frames:
            break
        if index not in picked_set:
            take(index)

    return [records[index] for index in picked], selected_frame_counts
|
|
|
|
|
|
def write_manifest(path: Path, records: list[FrameRecord], meta: dict[str, Any]) -> None:
    """Write the subset manifest: ``#`` metadata header lines, then one entry per line.

    Entries come from ``meta["manifest_entries"]`` when present and not ``None``
    (the exported, output-relative label paths). Otherwise each record's
    ``manifest_entry`` is used. The header lines are comments, which
    ``load_split_entries`` skips when the manifest is read back.
    """
    with path.open("w", encoding="utf-8") as file:
        for key in ("generated_at", "data_yaml", "split", "target_frames", "selected_frames"):
            file.write(f"# {key}: {meta[key]}\n")
        entries = meta.get("manifest_entries")
        if entries is None:
            entries = [record.manifest_entry for record in records]
        file.writelines(f"{entry}\n" for entry in entries)
|
|
|
|
|
|
def infer_export_rel_label_path(record: FrameRecord) -> Path:
    """Choose the label path, relative to the export root, for *record*.

    - A relative split entry is used unchanged.
    - For an absolute entry, the resolved label path is cut at its last
      ``labels`` component, e.g. ``/x/case/labels/0.txt`` -> ``labels/0.txt``.
    - Without any ``labels`` component, only the file name is kept.

    NOTE(review): for absolute entries, directories above ``labels`` (such as
    the case directory) are dropped, so same-named frames from different cases
    map to the same export path.
    """
    entry = Path(record.label_entry)
    if not entry.is_absolute():
        return entry

    resolved = Path(record.label_path)
    parts = resolved.parts
    if "labels" not in parts:
        return Path(resolved.name)
    start = max(position for position, part in enumerate(parts) if part == "labels")
    return Path(*parts[start:])
|
|
|
|
|
|
def normalize_export_subpath(path: Path, anchor: str) -> Path:
    """Strip a leading *anchor* component from *path*; otherwise return *path* unchanged.

    Example: ``normalize_export_subpath(Path("labels/case/0.txt"), "labels")`` -> ``case/0.txt``.
    """
    starts_with_anchor = path.parts[:1] == (anchor,)
    return path.relative_to(anchor) if starts_with_anchor else path
|
|
|
|
|
|
def _calib_destination(case_dir: Path, rel_image_path: Path, image_root: Path | None, calib_root: Path) -> Path:
    """Map a discovered case directory to its ``calib`` destination under *calib_root*."""
    if image_root is not None:
        try:
            rel_case = case_dir.relative_to(image_root.resolve())
        except ValueError:
            pass  # Case directory is outside image_root; fall back to rel_image_path.
        else:
            return calib_root / rel_case / "calib"

    rel_parts = rel_image_path.parts
    if "images" not in rel_parts:
        return calib_root / case_dir.name / "calib"
    last_images = len(rel_parts) - 1 - rel_parts[::-1].index("images")
    return calib_root.joinpath(*rel_parts[:last_images], "calib")


def infer_calib_dir_tasks(image_path: Path, rel_image_path: Path, image_root: Path | None, calib_root: Path) -> list[tuple[Path, Path]]:
    """Return (src_calib_dir, dst_calib_dir) pairs for the case that owns *image_path*.

    Starting at the image's parent directory and moving towards the filesystem
    root, the first directory that has a ``calib/`` sub-directory is taken as the
    case directory. At most one pair is returned; an empty list means no
    ``calib/`` directory was found.

    The destination is ``calib_root/<case path>/calib``. The case path is taken,
    in order of preference, from:

    * the case directory relative to *image_root*, when it lies under it;
    * the components of *rel_image_path* before its last ``images`` component
      (``calib_root/calib`` when ``images`` is the first component);
    * otherwise, the case directory's own name.

    NOTE(review): the upward walk is not bounded by *image_root*, so a ``calib/``
    directory above the dataset root can be picked up — confirm that is intended.
    """
    case_dir = image_path.resolve().parent
    while case_dir != case_dir.parent:
        src_calib_dir = case_dir / "calib"
        if src_calib_dir.is_dir():
            return [(src_calib_dir, _calib_destination(case_dir, rel_image_path, image_root, calib_root))]
        case_dir = case_dir.parent
    return []
|
|
|
|
|
|
def export_selected_assets(
    output_dir: Path,
    records: list[FrameRecord],
    image_root: Path | None,
    skip_json_convert: bool = False,
    image_width: int = 1920,
    image_height: int = 1080,
) -> dict[str, Any]:
    """Copy the selected frames into *output_dir* and optionally convert labels to JSON.

    Output layout: ``output_dir/<case_path>/labels|images|calib|labels_json/...``.
    For each record the label and image are copied (``shutil.copy2``). The owning
    case's ``calib/`` directory is copied once (see ``infer_calib_dir_tasks``).
    Unless *skip_json_convert*, the copied label is also converted with
    ``convert_txt_to_json``.

    Destinations are planned for every record before anything is copied. Two
    different source labels mapping to one destination would otherwise silently
    overwrite each other and leave duplicate manifest entries.

    Returns:
        Output roots, ``calib_files`` (number of calib *directories* copied),
        ``json_files`` and ``manifest_entries``. Manifest entries are the exported
        label paths relative to *output_dir*.

    Raises:
        ValueError: Two different source label files map to the same exported label path.
    """
    # Pass 1: plan destinations and detect collisions before touching the filesystem.
    plans: list[tuple[FrameRecord, Path, Path, Path, Path]] = []
    label_sources: dict[Path, str] = {}
    for record in records:
        rel_label_path = infer_export_rel_label_path(record)
        rel_image_path = label_rel_to_image_rel(rel_label_path)
        export_label_rel = normalize_export_subpath(rel_label_path, "labels")
        dst_label_path = output_dir / export_label_rel
        dst_image_path = output_dir / normalize_export_subpath(rel_image_path, "images")

        previous_source = label_sources.setdefault(dst_label_path, record.label_path)
        # The same source listed twice is a harmless re-copy; distinct sources would clobber data.
        if previous_source != record.label_path:
            raise ValueError(
                f"Export path collision: {previous_source} and {record.label_path} both map to {dst_label_path}."
            )
        plans.append((record, rel_image_path, export_label_rel, dst_label_path, dst_image_path))

    # Pass 2: copy assets in selection order.
    exported_manifest_entries: list[str] = []
    copied_calib_targets: set[Path] = set()
    copied_calib_count = 0
    converted_json_count = 0

    for record, rel_image_path, export_label_rel, dst_label_path, dst_image_path in plans:
        dst_label_path.parent.mkdir(parents=True, exist_ok=True)
        dst_image_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(record.label_path, dst_label_path)
        shutil.copy2(record.image_path, dst_image_path)

        if not skip_json_convert:
            # Mirror the label path, swapping the innermost ``labels`` dir for ``labels_json``.
            json_parts = list(export_label_rel.parts)
            if "labels" in json_parts:
                json_parts[len(json_parts) - 1 - json_parts[::-1].index("labels")] = "labels_json"
            dst_json_path = (output_dir / Path(*json_parts)).with_suffix(".json")
            dst_json_path.parent.mkdir(parents=True, exist_ok=True)
            convert_txt_to_json(str(dst_label_path), str(dst_json_path), image_width, image_height)
            converted_json_count += 1

        calib_tasks = infer_calib_dir_tasks(Path(record.image_path), rel_image_path, image_root, output_dir)
        for src_calib_dir, dst_calib_dir in calib_tasks:
            if dst_calib_dir in copied_calib_targets:
                continue
            copied_calib_targets.add(dst_calib_dir)
            if dst_calib_dir.exists():
                # Already present (e.g. from an earlier run); copytree would fail on an existing target.
                continue
            dst_calib_dir.parent.mkdir(parents=True, exist_ok=True)
            shutil.copytree(src_calib_dir, dst_calib_dir)
            copied_calib_count += 1

        exported_manifest_entries.append(str(dst_label_path.relative_to(output_dir)))

    return {
        "output_root": str(output_dir),
        "labels_root": str(output_dir),
        "images_root": str(output_dir),
        "calib_root": str(output_dir),
        "labels_json_root": str(output_dir) if not skip_json_convert else None,
        "calib_files": copied_calib_count,
        "json_files": converted_json_count,
        "manifest_entries": exported_manifest_entries,
    }
|
|
|
|
|
|
def write_csv(path: Path, records: list[FrameRecord]) -> None:
    """Write one CSV row per selected record, ranked from 1 in selection order.

    Class ids and names are space-separated; per-class instance counts are
    serialised as a JSON object with sorted keys.
    """
    columns = [
        "selected_rank",
        "source_index",
        "label_entry",
        "label_path",
        "image_path",
        "num_objects",
        "present_class_ids",
        "present_class_names",
        "instance_count_per_class",
    ]
    with path.open("w", encoding="utf-8", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=columns)
        writer.writeheader()
        writer.writerows(
            {
                "selected_rank": rank,
                "source_index": record.source_index,
                "label_entry": record.label_entry,
                "label_path": record.label_path,
                "image_path": record.image_path,
                "num_objects": record.num_objects,
                "present_class_ids": " ".join(map(str, record.present_classes)),
                "present_class_names": " ".join(record.present_class_names),
                "instance_count_per_class": json.dumps(record.instance_count_per_class, ensure_ascii=False, sort_keys=True),
            }
            for rank, record in enumerate(records, start=1)
        )
|
|
|
|
|
|
def build_stats_payload(
    args: argparse.Namespace,
    data_yaml: Path,
    split_files: list[Path],
    image_root: Path | None,
    class_name_map: dict[int, str],
    requested_classes: list[int],
    scan_stats: dict[str, int],
    candidate_records: list[FrameRecord],
    selected_records: list[FrameRecord],
    manifest_path: Path,
    csv_path: Path,
    stats_path: Path,
    exported_assets: dict[str, Any] | None,
) -> dict[str, Any]:
    """Assemble the JSON-serialisable stats document written next to the manifest.

    The document includes the run configuration, scan statistics and output
    locations. It also has per-class candidate/selected frame and instance counts
    for every requested class. ``generated_at`` is a UTC ISO-8601 timestamp.
    """
    candidate_frames, candidate_instances = summarize_records(candidate_records)
    selected_frames, selected_instances = summarize_records(selected_records)
    # Requested classes that occur in at least one candidate frame.
    active_classes = [class_id for class_id in requested_classes if candidate_frames[class_id] > 0]

    def from_assets(key: str, fallback: Any) -> Any:
        # Export info is only consulted when a non-empty export result exists.
        return exported_assets.get(key) if exported_assets else fallback

    outputs = {
        "manifest": str(manifest_path),
        "frame_csv": str(csv_path),
        "stats_json": str(stats_path),
        "images_dir": from_assets("images_root", None),
        "labels_dir": from_assets("labels_root", None),
        "labels_json_dir": from_assets("labels_json_root", None),
        "calib_dir": from_assets("calib_root", None),
        "calib_files": from_assets("calib_files", 0),
        "json_files": from_assets("json_files", 0),
    }

    per_class: dict[str, dict[str, Any]] = {}
    for class_id in requested_classes:
        per_class[str(class_id)] = {
            "name": class_name_map.get(class_id, str(class_id)),
            "candidate_frame_count": int(candidate_frames[class_id]),
            "selected_frame_count": int(selected_frames[class_id]),
            "candidate_instance_count": int(candidate_instances[class_id]),
            "selected_instance_count": int(selected_instances[class_id]),
        }

    return {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "data_yaml": str(data_yaml),
        "split": str(args.split),
        "split_files": [str(path) for path in split_files],
        "image_root": None if image_root is None else str(image_root),
        "target_frames": int(args.target_frames),
        "selected_frames": len(selected_records),
        "candidate_frames": len(candidate_records),
        "seed": int(args.seed),
        "include_empty": bool(args.include_empty),
        "requested_classes": requested_classes,
        "active_classes": active_classes,
        "scan_stats": scan_stats,
        "outputs": outputs,
        "per_class": per_class,
    }
|
|
|
|
|
|
def print_summary(
    requested_classes: list[int],
    class_name_map: dict[int, str],
    candidate_records: list[FrameRecord],
    selected_records: list[FrameRecord],
    stats_payload: dict[str, Any],
) -> None:
    """Print scan statistics and per-class candidate/selected frame counts to stdout."""
    candidate_frame_counts, _ = summarize_records(candidate_records)
    selected_frame_counts, _ = summarize_records(selected_records)
    scan = stats_payload["scan_stats"]

    print(f"Scanned candidate frames: {len(candidate_records)}")
    print(f"Selected frames: {len(selected_records)} / target {stats_payload['target_frames']}")
    scan_fields = " ".join(
        f"{key}={scan[key]}" for key in ("total_entries", "missing_labels", "missing_images", "empty_frames")
    )
    print("Scan stats: " + scan_fields)

    print("Per-class selected frame counts:")
    for class_id in requested_classes:
        name = class_name_map.get(class_id, str(class_id))
        candidate_count = candidate_frame_counts[class_id]
        selected_count = selected_frame_counts[class_id]
        print(f" class {class_id:>2} ({name:<20}) candidate={candidate_count:>5} selected={selected_count:>4}")
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point for Ground3D balanced subset mining.

    Pipeline:
      1. Parse CLI args and load the dataset yaml (``path``, split key, ``class_map``, ...).
      2. Read the split list(s) and scan every entry into a ``FrameRecord``.
      3. Greedily select a class-balanced subset of at most ``--target-frames`` frames.
      4. Unless ``--skip-export-assets``, copy images/labels/calib (and JSON labels) into the output dir.
      5. Write the manifest, per-frame CSV and stats JSON, then print a summary.

    Raises:
        ValueError: ``--target-frames`` is not positive, ``class_map`` is missing, or
            ``--classes`` contains ids that are not in ``class_map``.
        RuntimeError: No candidate frames remain, or no requested class occurs in the split.
    """
    args = parse_args()
    # A non-positive target can only produce an empty manifest; fail before scanning anything.
    if int(args.target_frames) <= 0:
        raise ValueError(f"--target-frames must be a positive integer, got {args.target_frames}.")

    data_yaml = Path(args.data).resolve()
    data_cfg = load_yaml(data_yaml)

    dataset_root = resolve_dataset_root(data_yaml, data_cfg.get("path"))
    split_files = resolve_data_paths(data_yaml, dataset_root, data_cfg.get(args.split))
    # Relative image paths are resolved against the dataset root (None when the yaml has no `path`).
    image_root = dataset_root

    class_map = {str(key): int(value) for key, value in (data_cfg.get("class_map") or {}).items()}
    if not class_map:
        raise ValueError(f"`class_map` is required in {data_yaml} for Ground3D subset mining.")
    class_name_map = infer_class_name_map(class_map)
    all_class_ids = sorted(set(class_name_map))
    # `--classes` given without ids parses to []; treat it like omitting the flag
    # (all classes), as the help text promises, instead of filtering out every frame.
    requested_source = args.classes if args.classes else all_class_ids
    requested_classes = sorted({int(class_id) for class_id in requested_source})
    unknown_classes = [class_id for class_id in requested_classes if class_id not in class_name_map]
    if unknown_classes:
        raise ValueError(f"Requested classes are not present in class_map: {unknown_classes}")

    difficulty_weights = [float(value) for value in data_cfg.get("difficulty_weights", [1.0, 1.0, 1.0, 1.0])]
    face_3d_classes = set(int(value) for value in data_cfg.get("face_3d_classes", []))
    complete_3d_classes = set(int(value) for value in data_cfg.get("complete_3d_classes", []))

    entries = load_split_entries(
        split_files=split_files,
        max_entries=int(args.max_entries),
        sample_selection=str(args.sample_selection),
        sample_seed=int(args.seed),
    )
    candidate_records, scan_stats = build_frame_records(
        entries=entries,
        image_root=image_root,
        class_map=class_map,
        difficulty_weights=difficulty_weights,
        face_3d_classes=face_3d_classes,
        complete_3d_classes=complete_3d_classes,
        class_name_map=class_name_map,
        requested_classes=set(requested_classes),
        include_empty=bool(args.include_empty),
    )
    if not candidate_records:
        raise RuntimeError("No candidate frames remain after filtering missing assets and empty labels.")

    # Only classes that actually occur among candidates take part in balancing.
    active_classes = [
        class_id for class_id in requested_classes if any(class_id in record.present_classes for record in candidate_records)
    ]
    if not active_classes and not args.include_empty:
        raise RuntimeError("No requested classes were found in the candidate split.")

    selected_records, _ = select_balanced_subset(
        records=candidate_records,
        target_frames=min(int(args.target_frames), len(candidate_records)),
        active_classes=active_classes,
        seed=int(args.seed),
    )

    output_dir = Path(args.output_dir).resolve() if args.output_dir else (data_yaml.parent / "mined_subsets").resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    output_prefix = args.output_prefix or f"balanced_{args.split}_{len(selected_records)}"
    manifest_path = output_dir / f"{output_prefix}.txt"
    csv_path = output_dir / f"{output_prefix}_frames.csv"
    stats_path = output_dir / f"{output_prefix}_stats.json"

    exported_assets = None
    if not args.skip_export_assets:
        exported_assets = export_selected_assets(
            output_dir,
            selected_records,
            image_root,
            skip_json_convert=bool(args.skip_json_convert),
            image_width=int(args.image_width),
            image_height=int(args.image_height),
        )

    stats_payload = build_stats_payload(
        args=args,
        data_yaml=data_yaml,
        split_files=split_files,
        image_root=image_root,
        class_name_map=class_name_map,
        requested_classes=requested_classes,
        scan_stats=scan_stats,
        candidate_records=candidate_records,
        selected_records=selected_records,
        manifest_path=manifest_path,
        csv_path=csv_path,
        stats_path=stats_path,
        exported_assets=exported_assets,
    )
    # When assets were exported, the manifest lists output-relative label paths
    # so it can be loaded from the output directory.
    write_manifest(
        manifest_path,
        selected_records,
        meta={
            "generated_at": stats_payload["generated_at"],
            "data_yaml": str(data_yaml),
            "split": args.split,
            "target_frames": int(args.target_frames),
            "selected_frames": len(selected_records),
            "manifest_entries": exported_assets.get("manifest_entries") if exported_assets else None,
        },
    )
    write_csv(csv_path, selected_records)
    with stats_path.open("w", encoding="utf-8") as file:
        json.dump(stats_payload, file, ensure_ascii=False, indent=2, sort_keys=False)
        file.write("\n")

    print_summary(
        requested_classes=requested_classes,
        class_name_map=class_name_map,
        candidate_records=candidate_records,
        selected_records=selected_records,
        stats_payload=stats_payload,
    )
    print(f"Manifest written to: {manifest_path}")
    print(f"Frame csv written to: {csv_path}")
    print(f"Stats written to: {stats_path}")
    if exported_assets:
        output_root = exported_assets["output_root"]
        print(f"Assets exported to: {output_root}")
        print(" images/labels/calib/labels_json nested under each case directory")
        if exported_assets.get("labels_json_root"):
            print(f" Labels JSON converted: {exported_assets['json_files']} files")
        print(f" Calib dirs copied: {exported_assets['calib_files']}")


if __name__ == "__main__":
    main()
|