# Source listing metadata: 182 lines, 5.9 KiB, Python, executable file.
from __future__ import annotations
|
||
|
||
import argparse
|
||
import glob
|
||
import json
|
||
import subprocess
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
# Absolute path of this script; also used to locate sibling files below.
FILE = Path(__file__).resolve()
# Third ancestor directory of this file (parents[0] is the script's own folder).
# Presumably the repository root that contains the `tools/pdcl_inference`
# package imported below — TODO confirm if the script is ever relocated.
ROOT = FILE.parents[2]
# Make `tools.*` importable when the script is executed directly by path.
# NOTE(review): `append` places ROOT last on sys.path, so an unrelated `tools`
# package found earlier on sys.path would shadow this repository's one.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# This import must stay after the sys.path adjustment above, which it relies on.
from tools.pdcl_inference.run_batch_two_roi_infer import (
    DEFAULT_OUTPUT_ROOT,
    parse_rawid_tasks,
)
|
||
|
||
|
||
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for this launcher.

    Args:
        argv: Arguments to parse. ``None`` lets argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        A namespace with ``aeb_json``, ``output_root``,
        ``num_rawids_per_scenario`` and ``batch_args``. ``batch_args`` is the
        raw ``argparse.REMAINDER`` list and may still start with ``"--"``;
        ``normalize_batch_args`` removes that separator.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Select the latest N raw_ids for each scenario from an AEB manifest "
            "and invoke run_batch_two_roi_infer.py."
        )
    )
    add = cli.add_argument

    # Either a concrete file path or a glob; defaults to aeb*.json beside this script.
    add(
        "--aeb-json",
        default=str(FILE.parent / "aeb*.json"),
        type=str,
        help="AEB manifest JSON path or glob pattern",
    )
    add(
        "--output-root",
        default=str(DEFAULT_OUTPUT_ROOT / "latest_two_rawids_per_scenario"),
        type=str,
        help="Output root forwarded to run_batch_two_roi_infer.py",
    )
    add(
        "--num-rawids-per-scenario",
        default=2,
        type=int,
        help="Keep the latest N raw_ids for each scenario",
    )
    # Whatever remains on the command line is captured verbatim for the batch script.
    add(
        "batch_args",
        nargs=argparse.REMAINDER,
        help="Extra args forwarded to run_batch_two_roi_infer.py; prefix them with '--'",
    )
    return cli.parse_args(argv)
|
||
|
||
|
||
def resolve_aeb_json(path_or_glob: str) -> Path:
    """Resolve the ``--aeb-json`` value to a single manifest file.

    If *path_or_glob* names an existing regular file, that file is returned
    (resolved to an absolute path). Otherwise the value is expanded with
    ``glob.glob`` and the most recently modified matching *regular file* is
    returned. Directories that happen to match the pattern are ignored.

    Raises:
        FileNotFoundError: if no regular file matches.
    """
    candidate = Path(path_or_glob)
    if candidate.is_file():
        return candidate.resolve()

    # Only regular files can be manifests. Without this filter a directory
    # matching the pattern could win on mtime and be returned as the manifest.
    matches = sorted(
        path for path in map(Path, glob.glob(path_or_glob)) if path.is_file()
    )
    if not matches:
        raise FileNotFoundError(f"No AEB manifest matched: {path_or_glob}")
    # Sorting first makes mtime ties deterministic: max() keeps the first
    # maximal element in iteration order.
    return max(matches, key=lambda path: path.stat().st_mtime).resolve()
|
||
|
||
|
||
def load_manifest(path: Path) -> dict[str, Any]:
    """Decode *path* as UTF-8 JSON and return its top-level object.

    Raises:
        ValueError: if the decoded top-level value is not a JSON object (dict).
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, dict):
        return payload
    raise ValueError(f"{path} 顶层必须是 JSON object,实际: {type(payload).__name__}")
|
||
|
||
|
||
def filter_manifest_by_rawids(payload: dict[str, Any], selected_rawids: set[str]) -> dict[str, Any]:
    """Build a reduced manifest holding only records whose ``rawid`` is selected.

    The scenario mapping is ``payload["scenarios"]`` when that key exists;
    otherwise *payload* itself is treated as the scenario mapping. The
    following are dropped:

    - scenario values that are not lists;
    - records that are not dicts;
    - records whose stripped ``rawid`` is empty or not in *selected_rawids*;
    - scenarios left with no records.

    NOTE(review): matching is by raw_id alone, not by (scenario, raw_id). If
    one raw_id appears under several scenarios, it is kept in all of them.

    Returns:
        ``{"summary": {"total_scenarios", "total_rawids",
        "scenario_total_rawids"}, "scenarios": {scenario_key: [record, ...]}}``.
        The counts in ``summary`` are record counts.

    Raises:
        ValueError: if the scenario mapping is not a dict.
    """
    scenario_map = payload.get("scenarios", payload)
    if not isinstance(scenario_map, dict):
        raise ValueError("AEB manifest 的 scenarios 字段必须是 dict")

    def _is_selected(record) -> bool:
        if not isinstance(record, dict):
            return False
        rawid = str(record.get("rawid", "")).strip()
        # An empty rawid is never selected, even if "" is in the set.
        return bool(rawid) and rawid in selected_rawids

    kept = {}
    for key, entries in scenario_map.items():
        if isinstance(entries, list):
            matched = [entry for entry in entries if _is_selected(entry)]
            if matched:
                kept[key] = matched

    per_scenario = {key: len(entries) for key, entries in kept.items()}
    return {
        "summary": {
            "total_scenarios": len(kept),
            "total_rawids": sum(per_scenario.values()),
            "scenario_total_rawids": per_scenario,
        },
        "scenarios": kept,
    }
|
||
|
||
|
||
def normalize_batch_args(batch_args: list[str]) -> list[str]:
    """Drop a single leading ``"--"`` separator from the forwarded arguments.

    Any other input, including an empty list, is returned unchanged.
    """
    has_separator = bool(batch_args) and batch_args[0] == "--"
    if not has_separator:
        return batch_args
    return batch_args[1:]
|
||
|
||
|
||
def select_latest_rawids_per_scenario(tasks, num_rawids_per_scenario: int):
    """Keep the newest ``num_rawids_per_scenario`` tasks from each scenario.

    Tasks are bucketed by ``task.scenario_key``. Within each bucket they are
    ordered newest-first by ``(task.cve_data or "", task.raw_id)``.

    This is a plain string comparison, so "newest" is only meaningful if
    ``cve_data`` sorts chronologically as text. NOTE(review): presumably an
    ISO-like date string; confirm. Ties on ``cve_data`` put the larger
    ``raw_id`` first, and a missing ``cve_data`` is treated as ``""``.

    Callers are expected to pass a positive count; ``main`` validates it.

    Returns:
        A flat list. Scenarios appear in ascending key order, and each
        contributes up to ``num_rawids_per_scenario`` tasks, newest first.
    """
    buckets = {}
    for item in tasks:
        key = item.scenario_key
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(item)

    def recency(item):
        return (item.cve_data or "", item.raw_id)

    return [
        item
        for key in sorted(buckets)
        for item in sorted(buckets[key], key=recency, reverse=True)[:num_rawids_per_scenario]
    ]
|
||
|
||
|
||
def main(argv: list[str] | None = None) -> None:
    """Select the newest raw_ids per scenario and run the batch inference script.

    Steps:
      1. Parse CLI arguments and validate ``--num-rawids-per-scenario``.
      2. Resolve the AEB manifest and parse its raw_id tasks.
      3. Keep the latest N tasks per scenario and write a filtered manifest to
         ``<output-root>/_status/latest_two_rawids_per_scenario.json``.
      4. Run ``run_batch_two_roi_infer.py`` (next to this script) on that
         manifest, forwarding any extra arguments given after ``--``.

    Raises:
        ValueError: for a non-positive N, a manifest with no usable tasks, or
            an empty selection.
        FileNotFoundError: if ``--aeb-json`` matches no manifest file.
        subprocess.CalledProcessError: if the batch script exits non-zero.
    """
    import shlex  # local: only used to render the command line for logging

    args = parse_args(argv)
    # Validate the cheap CLI input before resolving and parsing the manifest.
    if args.num_rawids_per_scenario <= 0:
        raise ValueError("--num-rawids-per-scenario 必须大于 0")
    batch_args = normalize_batch_args(args.batch_args)

    manifest_path = resolve_aeb_json(args.aeb_json)
    tasks = parse_rawid_tasks(str(manifest_path))
    if not tasks:
        raise ValueError(f"{manifest_path} 中没有可用于推理的 raw_id + clips 记录")

    selected_tasks = select_latest_rawids_per_scenario(
        tasks,
        num_rawids_per_scenario=args.num_rawids_per_scenario,
    )
    selected_rawids = {task.raw_id for task in selected_tasks}
    if not selected_rawids:
        raise ValueError("没有选出任何 raw_id")

    output_root = Path(args.output_root).resolve()
    status_dir = output_root / "_status"
    status_dir.mkdir(parents=True, exist_ok=True)

    # The filtered copy is built from the raw manifest JSON, so records keep
    # their original fields. The file name keeps "two" regardless of N; it is
    # left unchanged for compatibility.
    selected_manifest = filter_manifest_by_rawids(load_manifest(manifest_path), selected_rawids)
    selected_manifest_path = status_dir / "latest_two_rawids_per_scenario.json"
    with selected_manifest_path.open("w", encoding="utf-8") as file:
        json.dump(selected_manifest, file, indent=2, ensure_ascii=False)

    print(f"Selected manifest: {manifest_path}")
    print(f"Selected latest {args.num_rawids_per_scenario} raw_ids for each scenario:")
    for task in sorted(selected_tasks, key=lambda task: (task.scenario_key, task.cve_data or "", task.raw_id)):
        print(
            f"  - scenario={task.scenario_key} raw_id={task.raw_id} "
            f"cve={task.cve_data or 'n/a'} clips={len(task.clips)}"
        )
    print(f"Filtered manifest saved to: {selected_manifest_path}")

    batch_script = FILE.parent / "run_batch_two_roi_infer.py"
    cmd = [
        sys.executable,
        str(batch_script),
        "--rawid-json",
        str(selected_manifest_path),
        "--output-root",
        str(output_root),
        *batch_args,
    ]
    print("Running:")
    # Quote each part so the logged command can be pasted into a shell as-is.
    print(" ".join(shlex.quote(part) for part in cmd))
    # Flush first: the child writes straight to the inherited stdout file
    # descriptor. Without this, when stdout is a pipe or file, our still-buffered
    # lines above can end up after the child's output.
    sys.stdout.flush()
    subprocess.run(cmd, check=True)
|
||
|
||
|
||
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success; any
    # exception raised by main() propagates exactly as before.
    raise SystemExit(main())
|