from __future__ import annotations import argparse import glob import json import subprocess import sys from pathlib import Path from typing import Any FILE = Path(__file__).resolve() ROOT = FILE.parents[2] if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) from tools.pdcl_inference.run_batch_two_roi_infer import ( DEFAULT_OUTPUT_ROOT, parse_rawid_tasks, ) def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser( description="Select the latest N raw_ids for each scenario from an AEB manifest and invoke run_batch_two_roi_infer.py." ) parser.add_argument( "--aeb-json", type=str, default=str(FILE.parent / "aeb*.json"), help="AEB manifest JSON path or glob pattern", ) parser.add_argument( "--output-root", type=str, default=str(DEFAULT_OUTPUT_ROOT / "latest_two_rawids_per_scenario"), help="Output root forwarded to run_batch_two_roi_infer.py", ) parser.add_argument( "--num-rawids-per-scenario", type=int, default=2, help="Keep the latest N raw_ids for each scenario", ) parser.add_argument( "batch_args", nargs=argparse.REMAINDER, help="Extra args forwarded to run_batch_two_roi_infer.py; prefix them with '--'", ) return parser.parse_args(argv) def resolve_aeb_json(path_or_glob: str) -> Path: candidate = Path(path_or_glob) if candidate.is_file(): return candidate.resolve() matches = sorted(Path(path) for path in glob.glob(path_or_glob)) if not matches: raise FileNotFoundError(f"No AEB manifest matched: {path_or_glob}") return max(matches, key=lambda path: path.stat().st_mtime).resolve() def load_manifest(path: Path) -> dict[str, Any]: with path.open("r", encoding="utf-8") as file: payload = json.load(file) if not isinstance(payload, dict): raise ValueError(f"{path} 顶层必须是 JSON object,实际: {type(payload).__name__}") return payload def filter_manifest_by_rawids(payload: dict[str, Any], selected_rawids: set[str]) -> dict[str, Any]: scenarios = payload.get("scenarios", payload) if not isinstance(scenarios, dict): raise ValueError("AEB manifest 的 scenarios 字段必须是 dict") filtered_scenarios = {} for scenario_key, records in scenarios.items(): if not isinstance(records, list): continue filtered_records = [] for record in records: if not isinstance(record, dict): continue raw_id = str(record.get("rawid", "")).strip() if raw_id and raw_id in selected_rawids: filtered_records.append(record) if filtered_records: filtered_scenarios[scenario_key] = filtered_records summary = { "total_scenarios": len(filtered_scenarios), "total_rawids": sum(len(records) for records in filtered_scenarios.values()), "scenario_total_rawids": { scenario_key: len(records) for scenario_key, records in filtered_scenarios.items() }, } return { "summary": summary, "scenarios": filtered_scenarios, } def normalize_batch_args(batch_args: list[str]) -> list[str]: if batch_args and batch_args[0] == "--": return batch_args[1:] return batch_args def select_latest_rawids_per_scenario(tasks, num_rawids_per_scenario: int): selected_tasks = [] grouped_tasks = {} for task in tasks: grouped_tasks.setdefault(task.scenario_key, []).append(task) for scenario_key in sorted(grouped_tasks): scenario_tasks = sorted( grouped_tasks[scenario_key], key=lambda task: (task.cve_data or "", task.raw_id), reverse=True, )[:num_rawids_per_scenario] selected_tasks.extend(scenario_tasks) return selected_tasks def main(argv: list[str] | None = None) -> None: args = parse_args(argv) batch_args = normalize_batch_args(args.batch_args) manifest_path = resolve_aeb_json(args.aeb_json) tasks = parse_rawid_tasks(str(manifest_path)) if not tasks: raise ValueError(f"{manifest_path} 中没有可用于推理的 raw_id + clips 记录") if args.num_rawids_per_scenario <= 0: raise ValueError("--num-rawids-per-scenario 必须大于 0") selected_tasks = select_latest_rawids_per_scenario( tasks, num_rawids_per_scenario=args.num_rawids_per_scenario, ) selected_rawids = {task.raw_id for task in selected_tasks} if not selected_rawids: raise ValueError("没有选出任何 raw_id") output_root = Path(args.output_root).resolve() status_dir = output_root / "_status" status_dir.mkdir(parents=True, exist_ok=True) selected_manifest = filter_manifest_by_rawids(load_manifest(manifest_path), selected_rawids) selected_manifest_path = status_dir / "latest_two_rawids_per_scenario.json" with selected_manifest_path.open("w", encoding="utf-8") as file: json.dump(selected_manifest, file, indent=2, ensure_ascii=False) print(f"Selected manifest: {manifest_path}") print(f"Selected latest {args.num_rawids_per_scenario} raw_ids for each scenario:") for task in sorted(selected_tasks, key=lambda task: (task.scenario_key, task.cve_data or "", task.raw_id)): print( f" - scenario={task.scenario_key} raw_id={task.raw_id} " f"cve={task.cve_data or 'n/a'} clips={len(task.clips)}" ) print(f"Filtered manifest saved to: {selected_manifest_path}") batch_script = FILE.parent / "run_batch_two_roi_infer.py" cmd = [ sys.executable, str(batch_script), "--rawid-json", str(selected_manifest_path), "--output-root", str(output_root), *batch_args, ] print("Running:") print(" ".join(cmd)) subprocess.run(cmd, check=True) if __name__ == "__main__": main()