Files
yolov26_3d/tools/pdcl_inference/run_latest_two_aeb_rawids.py
2026-06-24 09:35:46 +08:00

182 lines
5.9 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import argparse
import glob
import json
import subprocess
import sys
from pathlib import Path
from typing import Any
# Absolute, symlink-resolved location of this script.
FILE = Path(__file__).resolve()
# Directory two levels above the folder that contains this script; the
# `tools.pdcl_inference` package imported below is resolved relative to it.
ROOT = FILE.parents[2]
# Make the `tools.*` imports work when the script is launched directly,
# without adding a duplicate entry if the root is already importable.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from tools.pdcl_inference.run_batch_two_roi_infer import (
DEFAULT_OUTPUT_ROOT,
parse_rawid_tasks,
)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line of this wrapper script.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        Namespace with ``aeb_json``, ``output_root``,
        ``num_rawids_per_scenario`` and ``batch_args`` (everything left over
        on the command line, captured via ``argparse.REMAINDER``).
    """
    cli = argparse.ArgumentParser(
        description="Select the latest N raw_ids for each scenario from an AEB manifest and invoke run_batch_two_roi_infer.py."
    )

    # Declared as data so every option is registered through one code path;
    # the order below is the order in which argparse sees the arguments.
    option_table = (
        (
            "--aeb-json",
            {
                "type": str,
                "default": str(FILE.parent / "aeb*.json"),
                "help": "AEB manifest JSON path or glob pattern",
            },
        ),
        (
            "--output-root",
            {
                "type": str,
                "default": str(DEFAULT_OUTPUT_ROOT / "latest_two_rawids_per_scenario"),
                "help": "Output root forwarded to run_batch_two_roi_infer.py",
            },
        ),
        (
            "--num-rawids-per-scenario",
            {
                "type": int,
                "default": 2,
                "help": "Keep the latest N raw_ids for each scenario",
            },
        ),
        (
            "batch_args",
            {
                "nargs": argparse.REMAINDER,
                "help": "Extra args forwarded to run_batch_two_roi_infer.py; prefix them with '--'",
            },
        ),
    )
    for option_name, option_kwargs in option_table:
        cli.add_argument(option_name, **option_kwargs)

    return cli.parse_args(argv)
def resolve_aeb_json(path_or_glob: str) -> Path:
    """Resolve a manifest argument to a single existing file.

    If ``path_or_glob`` names an existing regular file it is returned
    directly. Otherwise it is expanded as a glob pattern and the most
    recently modified matching *file* is returned.

    Args:
        path_or_glob: Literal file path or glob pattern.

    Returns:
        Resolved absolute path of the chosen manifest.

    Raises:
        FileNotFoundError: If no regular file matches.
    """
    candidate = Path(path_or_glob)
    if candidate.is_file():
        return candidate.resolve()

    # glob.glob also returns directories; a directory whose name happens to
    # match the pattern must never be picked as the "latest manifest".
    matches = sorted(
        path for path in map(Path, glob.glob(path_or_glob)) if path.is_file()
    )
    if not matches:
        raise FileNotFoundError(f"No AEB manifest matched: {path_or_glob}")

    # Sorting first keeps the choice deterministic when mtimes tie: max()
    # returns the first maximal element, i.e. the smallest path.
    return max(matches, key=lambda path: path.stat().st_mtime).resolve()
def load_manifest(path: Path) -> dict[str, Any]:
    """Read a UTF-8 JSON file and return its top-level object.

    Args:
        path: Location of the JSON manifest.

    Returns:
        The decoded JSON document, guaranteed to be a ``dict``.

    Raises:
        ValueError: If the top-level JSON value is not an object.
    """
    with path.open("r", encoding="utf-8") as handle:
        document = json.load(handle)

    if isinstance(document, dict):
        return document

    # Anything other than a JSON object (list, string, number, ...) is rejected.
    raise ValueError(f"{path} 顶层必须是 JSON object实际: {type(document).__name__}")
def filter_manifest_by_rawids(payload: dict[str, Any], selected_rawids: set[str]) -> dict[str, Any]:
    """Build a reduced manifest containing only the selected raw ids.

    The scenario mapping is taken from ``payload["scenarios"]`` when that key
    exists, otherwise ``payload`` itself is treated as the mapping. Scenario
    values that are not lists and records that are not dicts are skipped.
    A record is kept when its ``"rawid"`` value, converted to ``str`` and
    stripped, is non-empty and a member of ``selected_rawids``.

    Args:
        payload: Decoded manifest.
        selected_rawids: Raw ids to keep.

    Returns:
        ``{"summary": ..., "scenarios": ...}`` where ``summary`` holds the
        number of kept scenarios, the total number of kept records and the
        per-scenario record counts. Scenarios with no kept record are omitted.

    Raises:
        ValueError: If the scenario mapping is not a ``dict``.
    """
    scenarios = payload.get("scenarios", payload)
    if not isinstance(scenarios, dict):
        raise ValueError("AEB manifest 的 scenarios 字段必须是 dict")

    def is_selected(record: Any) -> bool:
        """Return True when the record carries one of the selected raw ids."""
        if not isinstance(record, dict):
            return False
        raw_id = str(record.get("rawid", "")).strip()
        return bool(raw_id) and raw_id in selected_rawids

    # NOTE(review): membership is tested against one global id set, so a
    # raw id that occurs in several scenarios is kept in each of them.
    kept: dict[str, list] = {}
    for name, records in scenarios.items():
        if isinstance(records, list):
            hits = [record for record in records if is_selected(record)]
            if hits:
                kept[name] = hits

    counts = {name: len(hits) for name, hits in kept.items()}
    return {
        "summary": {
            "total_scenarios": len(kept),
            "total_rawids": sum(counts.values()),
            "scenario_total_rawids": counts,
        },
        "scenarios": kept,
    }
def normalize_batch_args(batch_args: list[str]) -> list[str]:
    """Drop a leading ``"--"`` separator from the forwarded arguments.

    Args:
        batch_args: Leftover command-line arguments.

    Returns:
        ``batch_args`` without its first element when that element is
        ``"--"``; otherwise the very same object, unchanged.
    """
    starts_with_separator = bool(batch_args) and batch_args[0] == "--"
    return batch_args[1:] if starts_with_separator else batch_args
def select_latest_rawids_per_scenario(tasks, num_rawids_per_scenario: int):
    """Pick the top-ranked tasks of every scenario.

    Tasks are grouped by ``scenario_key``. Inside each group they are ordered
    descending by ``(cve_data or "", raw_id)`` and the first
    ``num_rawids_per_scenario`` entries are kept. Groups are emitted in
    ascending ``scenario_key`` order.

    Args:
        tasks: Iterable of objects exposing ``scenario_key``, ``cve_data``
            and ``raw_id`` attributes.
        num_rawids_per_scenario: Maximum number of tasks kept per scenario.

    Returns:
        Flat list of the kept tasks.
    """

    def recency_key(item):
        # A missing cve_data ranks as the empty string.
        return (item.cve_data or "", item.raw_id)

    by_scenario = {}
    for item in tasks:
        name = item.scenario_key
        if name not in by_scenario:
            by_scenario[name] = []
        by_scenario[name].append(item)

    picked = []
    for name in sorted(by_scenario):
        bucket = by_scenario[name]
        # Stable descending sort: ties keep their original input order.
        bucket.sort(key=recency_key, reverse=True)
        picked += bucket[:num_rawids_per_scenario]
    return picked
def main(argv: list[str] | None = None) -> None:
    """Select the latest raw ids per scenario and run the batch inference script.

    Steps: parse the CLI, resolve the manifest, parse its tasks, keep the top
    N tasks per scenario, write a filtered manifest to
    ``<output-root>/_status/latest_two_rawids_per_scenario.json`` and invoke
    ``run_batch_two_roi_infer.py`` (located next to this script) on it with
    the current interpreter.

    Args:
        argv: Argument list; ``None`` means ``sys.argv``.

    Raises:
        ValueError: If ``--num-rawids-per-scenario`` is not positive, the
            manifest yields no tasks, or nothing was selected.
        FileNotFoundError: If no manifest matches ``--aeb-json``.
        subprocess.CalledProcessError: If the batch script exits non-zero.
    """
    # Local import: only needed to render the command line for display.
    import shlex

    args = parse_args(argv)

    # Fail fast on a bad count before touching the filesystem or parsing
    # the manifest.
    if args.num_rawids_per_scenario <= 0:
        raise ValueError("--num-rawids-per-scenario 必须大于 0")

    batch_args = normalize_batch_args(args.batch_args)
    manifest_path = resolve_aeb_json(args.aeb_json)

    tasks = parse_rawid_tasks(str(manifest_path))
    if not tasks:
        raise ValueError(f"{manifest_path} 中没有可用于推理的 raw_id + clips 记录")

    selected_tasks = select_latest_rawids_per_scenario(
        tasks,
        num_rawids_per_scenario=args.num_rawids_per_scenario,
    )
    selected_rawids = {task.raw_id for task in selected_tasks}
    if not selected_rawids:
        raise ValueError("没有选出任何 raw_id")

    output_root = Path(args.output_root).resolve()
    status_dir = output_root / "_status"
    status_dir.mkdir(parents=True, exist_ok=True)

    # Re-read the raw manifest so the filtered copy keeps the original
    # record structure rather than the parsed task objects.
    selected_manifest = filter_manifest_by_rawids(load_manifest(manifest_path), selected_rawids)
    # The file name is fixed regardless of the N actually requested.
    selected_manifest_path = status_dir / "latest_two_rawids_per_scenario.json"
    with selected_manifest_path.open("w", encoding="utf-8") as file:
        json.dump(selected_manifest, file, indent=2, ensure_ascii=False)

    print(f"Selected manifest: {manifest_path}")
    print(f"Selected latest {args.num_rawids_per_scenario} raw_ids for each scenario:")
    for task in sorted(selected_tasks, key=lambda task: (task.scenario_key, task.cve_data or "", task.raw_id)):
        print(
            f"  - scenario={task.scenario_key} raw_id={task.raw_id} "
            f"cve={task.cve_data or 'n/a'} clips={len(task.clips)}"
        )
    print(f"Filtered manifest saved to: {selected_manifest_path}")

    batch_script = FILE.parent / "run_batch_two_roi_infer.py"
    cmd = [
        sys.executable,
        str(batch_script),
        "--rawid-json",
        str(selected_manifest_path),
        "--output-root",
        str(output_root),
        *batch_args,
    ]
    print("Running:")
    # Quote each argument so the echoed command stays unambiguous (and
    # copy-pasteable) when a path or forwarded argument contains spaces.
    print(" ".join(shlex.quote(part) for part in cmd))
    # Argument list, no shell: forwarded args are passed through verbatim.
    subprocess.run(cmd, check=True)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()