182 lines
5.9 KiB
Python
182 lines
5.9 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import glob
|
|||
|
|
import json
|
|||
|
|
import subprocess
|
|||
|
|
import sys
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
# Absolute, symlink-resolved location of this script.
FILE = Path(__file__).resolve()

# Third ancestor directory of this file (parents[0] is the containing folder).
# NOTE(review): presumably the repository root that holds the ``tools`` package
# imported below -- confirm against the actual checkout layout.
ROOT = FILE.parents[2]

# Make ROOT importable when the script is executed directly; appended (not
# prepended) so already-configured entries keep priority, and added only once.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
|
|||
|
|
|
|||
|
|
from tools.pdcl_inference.run_batch_two_roi_infer import (
|
|||
|
|
DEFAULT_OUTPUT_ROOT,
|
|||
|
|
parse_rawid_tasks,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of this wrapper script.

    Args:
        argv: Argument list to parse. ``None`` makes argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        Namespace exposing ``aeb_json``, ``output_root``,
        ``num_rawids_per_scenario`` and ``batch_args`` (every remaining
        command-line token, collected verbatim for forwarding).
    """
    # Defaults are derived from this script's folder and the batch script's
    # default output root.
    default_manifest_glob = str(FILE.parent / "aeb*.json")
    default_output_root = str(DEFAULT_OUTPUT_ROOT / "latest_two_rawids_per_scenario")

    cli = argparse.ArgumentParser(
        description="Select the latest N raw_ids for each scenario from an AEB manifest and invoke run_batch_two_roi_infer.py."
    )
    cli.add_argument(
        "--aeb-json",
        type=str,
        default=default_manifest_glob,
        help="AEB manifest JSON path or glob pattern",
    )
    cli.add_argument(
        "--output-root",
        type=str,
        default=default_output_root,
        help="Output root forwarded to run_batch_two_roi_infer.py",
    )
    cli.add_argument(
        "--num-rawids-per-scenario",
        type=int,
        default=2,
        help="Keep the latest N raw_ids for each scenario",
    )
    # REMAINDER swallows everything after the recognised options so it can be
    # handed to the batch script untouched.
    cli.add_argument(
        "batch_args",
        nargs=argparse.REMAINDER,
        help="Extra args forwarded to run_batch_two_roi_infer.py; prefix them with '--'",
    )
    return cli.parse_args(argv)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resolve_aeb_json(path_or_glob: str) -> Path:
    """Resolve a manifest path or glob pattern to one concrete manifest file.

    Args:
        path_or_glob: Either the path of an existing file or a glob pattern.

    Returns:
        The resolved absolute path. An existing file path is returned as-is;
        for a pattern, the matching regular file with the most recent
        modification time wins.

    Raises:
        FileNotFoundError: If the argument is not a file and the pattern
            matches no regular file.
    """
    direct = Path(path_or_glob)
    if direct.is_file():
        return direct.resolve()

    # A glob can also match directories (e.g. a folder named ``aeb_x.json``);
    # those are never valid manifests, so keep regular files only. Sorting
    # keeps the mtime tie-break deterministic (max() returns the first maximum).
    matches = sorted(
        path for path in map(Path, glob.glob(path_or_glob)) if path.is_file()
    )
    if not matches:
        raise FileNotFoundError(f"No AEB manifest matched: {path_or_glob}")
    return max(matches, key=lambda path: path.stat().st_mtime).resolve()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_manifest(path: Path) -> dict[str, Any]:
    """Read a UTF-8 JSON manifest and return its top-level object.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded top-level JSON object.

    Raises:
        ValueError: If the top-level JSON value is not an object (dict).
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, dict):
        return payload
    raise ValueError(f"{path} 顶层必须是 JSON object,实际: {type(payload).__name__}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def filter_manifest_by_rawids(payload: dict[str, Any], selected_rawids: set[str]) -> dict[str, Any]:
    """Build a reduced manifest holding only records whose rawid is selected.

    Args:
        payload: Manifest object. Its ``"scenarios"`` entry is used when
            present; otherwise the payload itself is treated as the
            scenario mapping.
        selected_rawids: Raw ids (as stripped strings) to keep.

    Returns:
        A dict with a freshly computed ``"summary"`` and the filtered
        ``"scenarios"`` mapping. Scenarios left without records are dropped.

    Raises:
        ValueError: If the scenario container is not a dict.
    """
    scenario_map = payload.get("scenarios", payload)
    if not isinstance(scenario_map, dict):
        raise ValueError("AEB manifest 的 scenarios 字段必须是 dict")

    def _is_selected(entry: Any) -> bool:
        # Non-dict entries are ignored; the rawid is compared as a stripped string.
        if not isinstance(entry, dict):
            return False
        raw_id = str(entry.get("rawid", "")).strip()
        return bool(raw_id) and raw_id in selected_rawids

    kept: dict[str, list] = {}
    for name, entries in scenario_map.items():
        # Values that are not lists (e.g. a "summary" dict) are skipped.
        if not isinstance(entries, list):
            continue
        matching = [entry for entry in entries if _is_selected(entry)]
        if matching:
            kept[name] = matching

    per_scenario = {name: len(entries) for name, entries in kept.items()}
    return {
        "summary": {
            "total_scenarios": len(kept),
            "total_rawids": sum(per_scenario.values()),
            "scenario_total_rawids": per_scenario,
        },
        "scenarios": kept,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def normalize_batch_args(batch_args: list[str]) -> list[str]:
    """Drop a single leading ``"--"`` separator from the forwarded arguments.

    Args:
        batch_args: Tokens collected for the batch script.

    Returns:
        The same list when it does not start with ``"--"``; otherwise a new
        list without that first element.
    """
    if not batch_args or batch_args[0] != "--":
        return batch_args
    return batch_args[1:]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def select_latest_rawids_per_scenario(tasks, num_rawids_per_scenario: int):
    """Keep the first N tasks of every scenario after a newest-first sort.

    Tasks are grouped by ``scenario_key``. Inside a group they are ordered
    descending by ``(cve_data or "", raw_id)``, so tasks without ``cve_data``
    sort last. Groups are emitted in ascending ``scenario_key`` order.

    Args:
        tasks: Iterable of objects exposing ``scenario_key``, ``cve_data``
            and ``raw_id``.
        num_rawids_per_scenario: Slice length applied to each sorted group.

    Returns:
        A flat list with the retained tasks of all scenarios.
    """
    by_scenario = {}
    for item in tasks:
        bucket = by_scenario.get(item.scenario_key)
        if bucket is None:
            bucket = by_scenario[item.scenario_key] = []
        bucket.append(item)

    def _recency(item):
        # A missing cve_data is treated as the empty string for ordering.
        return (item.cve_data or "", item.raw_id)

    picked = []
    for name in sorted(by_scenario):
        newest_first = sorted(by_scenario[name], key=_recency, reverse=True)
        picked.extend(newest_first[:num_rawids_per_scenario])
    return picked
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main(argv: list[str] | None = None) -> None:
    """Select the latest N raw_ids per scenario and run the batch inference script.

    Steps: parse options, resolve the AEB manifest, parse its tasks, keep the
    latest N per scenario, write the filtered manifest under
    ``<output-root>/_status/`` and invoke ``run_batch_two_roi_infer.py`` on it
    with the forwarded extra arguments.

    Args:
        argv: Command-line tokens; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--num-rawids-per-scenario`` is not positive, the
            manifest yields no tasks, or nothing is selected.
        FileNotFoundError: If no manifest matches ``--aeb-json``.
        subprocess.CalledProcessError: If the batch script exits non-zero.
    """
    # Local import: only needed to echo the command in a copy-pasteable form.
    import shlex

    args = parse_args(argv)

    # Validate the count first so a bad option fails before any manifest I/O.
    if args.num_rawids_per_scenario <= 0:
        raise ValueError("--num-rawids-per-scenario 必须大于 0")

    batch_args = normalize_batch_args(args.batch_args)

    manifest_path = resolve_aeb_json(args.aeb_json)
    tasks = parse_rawid_tasks(str(manifest_path))
    if not tasks:
        raise ValueError(f"{manifest_path} 中没有可用于推理的 raw_id + clips 记录")

    selected_tasks = select_latest_rawids_per_scenario(
        tasks,
        num_rawids_per_scenario=args.num_rawids_per_scenario,
    )
    selected_rawids = {task.raw_id for task in selected_tasks}
    if not selected_rawids:
        raise ValueError("没有选出任何 raw_id")

    output_root = Path(args.output_root).resolve()
    status_dir = output_root / "_status"
    status_dir.mkdir(parents=True, exist_ok=True)

    # Persist the reduced manifest; the batch script reads it via --rawid-json.
    selected_manifest = filter_manifest_by_rawids(load_manifest(manifest_path), selected_rawids)
    selected_manifest_path = status_dir / "latest_two_rawids_per_scenario.json"
    with selected_manifest_path.open("w", encoding="utf-8") as file:
        json.dump(selected_manifest, file, indent=2, ensure_ascii=False)

    print(f"Selected manifest: {manifest_path}")
    print(f"Selected latest {args.num_rawids_per_scenario} raw_ids for each scenario:")
    for task in sorted(selected_tasks, key=lambda task: (task.scenario_key, task.cve_data or "", task.raw_id)):
        print(
            f"  - scenario={task.scenario_key} raw_id={task.raw_id} "
            f"cve={task.cve_data or 'n/a'} clips={len(task.clips)}"
        )
    print(f"Filtered manifest saved to: {selected_manifest_path}")

    batch_script = FILE.parent / "run_batch_two_roi_infer.py"
    cmd = [
        sys.executable,
        str(batch_script),
        "--rawid-json",
        str(selected_manifest_path),
        "--output-root",
        str(output_root),
        *batch_args,
    ]
    print("Running:")
    # Quote each token so arguments containing spaces stay unambiguous in the log.
    print(" ".join(shlex.quote(part) for part in cmd))
    # The command is passed as a list (no shell); check=True surfaces failures.
    subprocess.run(cmd, check=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|