Files
yolov26_3d/tools/pdcl_inference/run_latest_two_aeb_rawids.py

182 lines
5.9 KiB
Python
Raw Normal View History

2026-06-24 09:35:46 +08:00
from __future__ import annotations
import argparse
import glob
import json
import subprocess
import sys
from pathlib import Path
from typing import Any
# Absolute path of this script, with symlinks resolved.
FILE = Path(__file__).resolve()
# Two directories above the directory that contains this file. The project
# import that follows (`tools.pdcl_inference...`) requires the directory that
# holds the `tools` package to be importable.
ROOT = FILE.parents[2]
# Allow running this file directly as a script: register ROOT on sys.path once,
# appended at the end so existing entries keep priority.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from tools.pdcl_inference.run_batch_two_roi_infer import (
DEFAULT_OUTPUT_ROOT,
parse_rawid_tasks,
)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line of this wrapper script.

    Args:
        argv: Argument list to parse. ``None`` lets argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        Namespace carrying ``aeb_json``, ``output_root``,
        ``num_rawids_per_scenario`` and ``batch_args`` (everything left over
        on the command line, collected verbatim via ``argparse.REMAINDER``).
    """
    cli = argparse.ArgumentParser(
        description="Select the latest N raw_ids for each scenario from an AEB manifest and invoke run_batch_two_roi_infer.py."
    )

    # Manifest location: by default any aeb*.json sitting next to this script.
    default_manifest_glob = str(FILE.parent / "aeb*.json")
    cli.add_argument(
        "--aeb-json",
        type=str,
        default=default_manifest_glob,
        help="AEB manifest JSON path or glob pattern",
    )

    # Output directory handed through to the batch script.
    default_output_root = str(DEFAULT_OUTPUT_ROOT / "latest_two_rawids_per_scenario")
    cli.add_argument(
        "--output-root",
        type=str,
        default=default_output_root,
        help="Output root forwarded to run_batch_two_roi_infer.py",
    )

    cli.add_argument(
        "--num-rawids-per-scenario",
        type=int,
        default=2,
        help="Keep the latest N raw_ids for each scenario",
    )

    # Anything after the known options is passed on untouched.
    cli.add_argument(
        "batch_args",
        nargs=argparse.REMAINDER,
        help="Extra args forwarded to run_batch_two_roi_infer.py; prefix them with '--'",
    )

    return cli.parse_args(argv)
def resolve_aeb_json(path_or_glob: str) -> Path:
    """Resolve a manifest argument to one concrete file path.

    Args:
        path_or_glob: Either the path of an existing file or a glob pattern.

    Returns:
        The resolved path of the file itself when ``path_or_glob`` names an
        existing file; otherwise the resolved path of the most recently
        modified regular file matching the pattern.

    Raises:
        FileNotFoundError: If no regular file matches ``path_or_glob``.
    """
    candidate = Path(path_or_glob)
    if candidate.is_file():
        return candidate.resolve()

    # glob.glob also returns directories whose names match the pattern; a
    # directory can never be a manifest, so keep regular files only. Sorting
    # first keeps the choice deterministic when modification times tie.
    matches = [
        path
        for path in sorted(Path(match) for match in glob.glob(path_or_glob))
        if path.is_file()
    ]
    if not matches:
        raise FileNotFoundError(f"No AEB manifest matched: {path_or_glob}")
    return max(matches, key=lambda path: path.stat().st_mtime).resolve()
def load_manifest(path: Path) -> dict[str, Any]:
    """Read a UTF-8 JSON file and return its top-level object.

    Args:
        path: Location of the JSON manifest.

    Returns:
        The decoded JSON document.

    Raises:
        ValueError: If the top-level JSON value is not an object (dict).
    """
    with path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if isinstance(payload, dict):
        return payload
    raise ValueError(f"{path} 顶层必须是 JSON object实际: {type(payload).__name__}")
def filter_manifest_by_rawids(payload: dict[str, Any], selected_rawids: set[str]) -> dict[str, Any]:
    """Build a reduced manifest containing only the selected raw ids.

    The scenario table is ``payload["scenarios"]`` when that key exists,
    otherwise ``payload`` itself. Scenario values that are not lists and
    records that are not dicts are ignored. A record is kept when its
    ``"rawid"`` value, converted to ``str`` and stripped, is non-empty and is
    a member of ``selected_rawids``. Scenarios left without records are
    dropped.

    NOTE: matching is by rawid alone, not by (scenario, rawid); a rawid that
    appears under several scenarios is kept in each of them.

    Args:
        payload: Decoded manifest.
        selected_rawids: Raw ids to keep.

    Returns:
        ``{"summary": {...}, "scenarios": {...}}`` where the summary holds the
        number of remaining scenarios, the total number of remaining records
        and the per-scenario record counts.

    Raises:
        ValueError: If the scenario table is not a dict.
    """
    scenarios = payload.get("scenarios", payload)
    if not isinstance(scenarios, dict):
        raise ValueError("AEB manifest 的 scenarios 字段必须是 dict")

    def _is_selected(record: Any) -> bool:
        # Only dict records carrying a non-blank, selected "rawid" survive.
        if not isinstance(record, dict):
            return False
        raw_id = str(record.get("rawid", "")).strip()
        return bool(raw_id) and raw_id in selected_rawids

    kept = {}
    for scenario_key, records in scenarios.items():
        if isinstance(records, list):
            survivors = [record for record in records if _is_selected(record)]
            if survivors:
                kept[scenario_key] = survivors

    per_scenario_counts = {scenario_key: len(records) for scenario_key, records in kept.items()}
    return {
        "summary": {
            "total_scenarios": len(kept),
            "total_rawids": sum(per_scenario_counts.values()),
            "scenario_total_rawids": per_scenario_counts,
        },
        "scenarios": kept,
    }
def normalize_batch_args(batch_args: list[str]) -> list[str]:
    """Strip a single leading ``--`` separator from the forwarded arguments.

    Args:
        batch_args: Leftover command-line arguments.

    Returns:
        ``batch_args`` without its first element when that element is
        ``"--"``; otherwise ``batch_args`` itself, unchanged.
    """
    has_separator = bool(batch_args) and batch_args[0] == "--"
    return batch_args[1:] if has_separator else batch_args
def select_latest_rawids_per_scenario(tasks, num_rawids_per_scenario: int):
    """Pick the top N tasks of every scenario.

    Tasks are grouped by ``task.scenario_key``. Within a group they are
    ordered descending by ``(task.cve_data or "", task.raw_id)`` and the first
    ``num_rawids_per_scenario`` are taken, so "latest" means "largest under
    that ordering" (presumably ``cve_data`` is a sortable version/date
    string; confirm against the task producer). Groups are emitted in
    ascending ``scenario_key`` order.

    Args:
        tasks: Iterable of objects exposing ``scenario_key``, ``cve_data``
            and ``raw_id`` attributes.
        num_rawids_per_scenario: Maximum number of tasks kept per scenario.

    Returns:
        Flat list of the selected tasks.
    """
    by_scenario = {}
    for task in tasks:
        by_scenario.setdefault(task.scenario_key, []).append(task)

    def _recency(task):
        # A missing cve_data sorts as the empty string, i.e. as the oldest.
        return (task.cve_data or "", task.raw_id)

    return [
        task
        for scenario_key in sorted(by_scenario)
        for task in sorted(by_scenario[scenario_key], key=_recency, reverse=True)[
            :num_rawids_per_scenario
        ]
    ]
def main(argv: list[str] | None = None) -> None:
    """Select the latest N raw ids per scenario and run the batch inference script.

    Steps: parse the CLI, resolve and parse the AEB manifest, keep the top N
    tasks per scenario, write a filtered manifest to
    ``<output-root>/_status/latest_two_rawids_per_scenario.json``, then run
    ``run_batch_two_roi_infer.py`` (located next to this file) on that
    filtered manifest with any extra forwarded arguments.

    Args:
        argv: Argument list; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--num-rawids-per-scenario`` is not positive, the
            manifest yields no tasks, or nothing is selected.
        FileNotFoundError: If no manifest file matches ``--aeb-json``.
        subprocess.CalledProcessError: If the batch script exits non-zero.
    """
    import shlex  # local import: only used to render the command for the log

    args = parse_args(argv)

    # Fail fast on a bad N, before the manifest is located or parsed.
    if args.num_rawids_per_scenario <= 0:
        raise ValueError("--num-rawids-per-scenario 必须大于 0")

    batch_args = normalize_batch_args(args.batch_args)
    manifest_path = resolve_aeb_json(args.aeb_json)
    tasks = parse_rawid_tasks(str(manifest_path))
    if not tasks:
        raise ValueError(f"{manifest_path} 中没有可用于推理的 raw_id + clips 记录")

    selected_tasks = select_latest_rawids_per_scenario(
        tasks,
        num_rawids_per_scenario=args.num_rawids_per_scenario,
    )
    selected_rawids = {task.raw_id for task in selected_tasks}
    if not selected_rawids:
        raise ValueError("没有选出任何 raw_id")

    # The filtered manifest is stored under <output-root>/_status so the batch
    # script can consume it through --rawid-json.
    output_root = Path(args.output_root).resolve()
    status_dir = output_root / "_status"
    status_dir.mkdir(parents=True, exist_ok=True)
    selected_manifest = filter_manifest_by_rawids(load_manifest(manifest_path), selected_rawids)
    selected_manifest_path = status_dir / "latest_two_rawids_per_scenario.json"
    with selected_manifest_path.open("w", encoding="utf-8") as file:
        json.dump(selected_manifest, file, indent=2, ensure_ascii=False)

    print(f"Selected manifest: {manifest_path}")
    print(f"Selected latest {args.num_rawids_per_scenario} raw_ids for each scenario:")
    for task in sorted(selected_tasks, key=lambda task: (task.scenario_key, task.cve_data or "", task.raw_id)):
        print(
            f"  - scenario={task.scenario_key} raw_id={task.raw_id} "
            f"cve={task.cve_data or 'n/a'} clips={len(task.clips)}"
        )
    print(f"Filtered manifest saved to: {selected_manifest_path}")

    batch_script = FILE.parent / "run_batch_two_roi_infer.py"
    cmd = [
        sys.executable,
        str(batch_script),
        "--rawid-json",
        str(selected_manifest_path),
        "--output-root",
        str(output_root),
        *batch_args,
    ]
    print("Running:")
    # shlex.join quotes arguments containing spaces or shell metacharacters,
    # so the echoed line is unambiguous and can be pasted into a POSIX shell.
    print(shlex.join(cmd))
    # Executed as an argument list (no shell); a non-zero exit raises.
    subprocess.run(cmd, check=True)
if __name__ == "__main__":
main()