Files
HSAP/datasets/dms/scripts/task_registry.py

167 lines
5.6 KiB
Python
Raw Normal View History

"""DMS 任务注册表：domain 分组、multi 任务（前向 detect+classify）、旧 ID 别名。"""
from __future__ import annotations
from typing import Any
# Human-readable label for each task domain, keyed by the registry "domain" value
# (in-cabin DMS vs. forward-facing ADAS).
DOMAIN_LABELS = dict(
    dms="舱内 DMS",
    forward="前向 ADAS",
)

# Report column / legacy directory name -> (canonical task, mode).
# Consulted by map_report_task().
REPORT_TASK_ALIASES: dict[str, tuple[str, str | None]] = dict(
    isa=("forward", "detect"),
    isa_detect=("forward", "detect"),
    isa_class=("forward", "classify"),
    isa_class_0116=("forward", "classify"),
    dam_0417=("dam", "batch_0417"),
)

# Historical user-facing task ID -> (canonical task, fallback mode).
# Applied by resolve_task_id(); an explicit mode from the caller takes precedence.
LEGACY_TASK_ALIASES: dict[str, tuple[str, str | None]] = dict(
    isa=("forward", "detect"),
    isa_class=("forward", "classify"),
    dam_0417=("dam", "batch_0417"),
)
def load_registry(reg: dict) -> dict[str, Any]:
    """Return the registry's ``tasks`` mapping, or a fresh ``{}`` when it is missing or empty."""
    tasks = reg.get("tasks")
    if not tasks:
        return {}
    return tasks
def resolve_task_id(task: str, mode: str | None = None) -> tuple[str, str | None]:
    """Map a user-supplied or historical task ID to ``(canonical_task, mode)``.

    IDs listed in ``LEGACY_TASK_ALIASES`` are rewritten to their canonical task.
    The alias's mode is used only when the caller's ``mode`` is falsy.
    Any other ID is returned unchanged together with the given ``mode``.
    """
    alias = LEGACY_TASK_ALIASES.get(task)
    if alias is None:
        return task, mode
    canonical, fallback_mode = alias
    return canonical, mode or fallback_mode
def report_task_key(task: str, mode: str | None = None) -> str:
    """Return the task column name used in the catalog report CSV.

    The ``(task, mode)`` pair is resolved through ``resolve_task_id`` first. When the
    resolved pair has no dedicated report column, the task ID exactly as passed in by
    the caller is returned, not the resolved canonical ID.
    """
    canonical, resolved_mode = resolve_task_id(task, mode)
    # (canonical task, mode, report column)
    known_columns = (
        ("forward", "detect", "isa"),
        ("forward", "classify", "isa_class"),
        ("dam", "batch_0516", "dam"),
        ("dam", "batch_0417", "dam_0417"),
    )
    for want_task, want_mode, column in known_columns:
        if canonical == want_task and resolved_mode == want_mode:
            return column
    return task
def train_yaml_key(task: str, mode: str | None, reg: dict) -> str:
    """Return the YAML file stem (without ``.yaml``) used by manifests/yaml_active and train.sh.

    Legacy task IDs are resolved first via ``resolve_task_id``. Single tasks use the
    bare task ID. Multi-mode tasks use ``"<task>__<mode>"``.

    Raises:
        KeyError: ``task`` (after alias resolution) is not in the registry.
        ValueError: ``task`` is multi-mode and ``mode`` is missing or is not one of
            the task's configured modes.
    """
    task, mode = resolve_task_id(task, mode)
    tcfg = load_registry(reg)[task]
    if tcfg.get("type") != "multi":
        return task
    modes = tcfg.get("modes") or {}
    if not mode:
        # List the task's real modes: hard-coding "detect / classify" was wrong for
        # multi tasks other than forward (e.g. dam's batch_* modes).
        choices = " / ".join(map(str, modes))
        raise ValueError(f"任务 {task} 需指定 mode ({choices})")
    if mode not in modes:
        # Same check and message as get_mode_config. Otherwise a typo would produce a
        # yaml key with no matching file.
        raise ValueError(f"未知 mode: {task}/{mode}")
    return f"{task}__{mode}"
def get_mode_config(task: str, mode: str | None, reg: dict) -> dict[str, Any]:
    """Return the effective config for a task (and mode for multi-mode tasks).

    Legacy task IDs are resolved first via ``resolve_task_id``. The result is always
    a new dict carrying ``task``, ``mode`` and ``task_dir``.

    - Single task: a copy of the task config with ``mode=None``. ``task_dir`` defaults
      to the task ID, matching the multi-mode branch and ``task_defs_for_pending``.
    - Multi-mode task: a copy of the mode config, augmented with the parent's
      ``domain`` and ``task_dir = "<parent task_dir>/<mode subdir>"``. If the mode has
      no ``label``, the parent's ``label`` is used.

    Raises:
        KeyError: ``task`` (after alias resolution) is not in the registry.
        ValueError: ``task`` is multi-mode and ``mode`` is missing or unknown.
    """
    task, mode = resolve_task_id(task, mode)
    tcfg = load_registry(reg)[task]
    if tcfg.get("type") != "multi":
        # Fill in task_dir so task_data_dir(), which indexes mcfg["task_dir"], also
        # works for single tasks whose config has no explicit task_dir.
        return {
            **tcfg,
            "task_dir": tcfg.get("task_dir", task),
            "task": task,
            "mode": None,
        }
    modes = tcfg.get("modes") or {}
    if not mode:
        raise ValueError(f"任务 {task} 需指定 mode")
    if mode not in modes:
        raise ValueError(f"未知 mode: {task}/{mode}")
    mcfg = dict(modes[mode])  # copy so the registry itself is never mutated
    mcfg["task"] = task
    mcfg["mode"] = mode
    mcfg["task_dir"] = f"{tcfg.get('task_dir', task)}/{mcfg.get('subdir', mode)}"
    mcfg["domain"] = tcfg.get("domain")
    mcfg["label"] = mcfg.get("label") or tcfg.get("label")
    return mcfg
def task_data_dir(pack_dir, task: str, mode: str | None, reg: dict):
    """Return ``Path(pack_dir) / <task_dir>`` for the resolved task/mode.

    ``task_dir`` comes from ``get_mode_config``, which also performs alias
    resolution and mode validation.
    """
    from pathlib import Path

    relative = get_mode_config(task, mode, reg)["task_dir"]
    return Path(pack_dir).joinpath(relative)
def inbox_dir(root, task: str, mode: str | None, reg: dict):
    """Return the inbox directory for a task (and mode, for multi-mode tasks).

    The path is ``Path(root) / rel``, where ``rel`` is the config's ``inbox`` entry
    when set. Otherwise ``rel`` is ``inbox/<task>/<mode>`` for multi-mode tasks and
    ``inbox/<task>`` for single tasks. Legacy task IDs are resolved first via
    ``resolve_task_id``.

    Raises:
        KeyError: ``task`` (after alias resolution) is not in the registry.
        ValueError: ``task`` is multi-mode and ``mode`` is missing or unknown
            (same messages as ``get_mode_config``).
    """
    from pathlib import Path

    task, mode = resolve_task_id(task, mode)
    tcfg = load_registry(reg)[task]
    if tcfg.get("type") != "multi":
        return Path(root) / (tcfg.get("inbox") or f"inbox/{task}")
    # Validate explicitly. The previous tcfg["modes"][mode or ""] raised an opaque
    # KeyError('') for a missing mode, and a KeyError for an unknown mode or no "modes".
    modes = tcfg.get("modes") or {}
    if not mode:
        raise ValueError(f"任务 {task} 需指定 mode")
    if mode not in modes:
        raise ValueError(f"未知 mode: {task}/{mode}")
    rel = modes[mode].get("inbox") or f"inbox/{task}/{mode}"
    return Path(root) / rel
def iter_catalog_tasks(reg: dict) -> list[tuple[str, dict[str, Any]]]:
    """Build the catalog's top-level task list as ``(task_id, entry)`` pairs.

    Each entry carries domain, domain label, label and type. Single tasks also get
    ``nc``/``names`` and empty ``packs``/``class_counts`` placeholders. Multi-mode
    tasks get the same per-mode fields under ``modes``. The placeholders are
    presumably filled in later by the catalog builder (not visible here).
    """
    catalog: list[tuple[str, dict[str, Any]]] = []
    for task_id, cfg in load_registry(reg).items():
        domain = cfg.get("domain", "dms")
        entry: dict[str, Any] = {
            "domain": domain,
            # Unknown domains fall back to the raw domain string as their label.
            "domain_label": DOMAIN_LABELS.get(domain, domain),
            "label": cfg.get("label", task_id),
            "type": cfg.get("type"),
        }
        if cfg.get("type") != "multi":
            entry.update(
                nc=cfg.get("nc"),
                names=cfg.get("names"),
                packs=[],
                class_counts={},
            )
        else:
            entry["modes"] = {
                mode_id: {
                    "label": mode_cfg.get("label", mode_id),
                    "type": mode_cfg.get("type"),
                    "nc": mode_cfg.get("nc"),
                    "names": mode_cfg.get("names"),
                    "packs": [],
                    "class_counts": {},
                }
                for mode_id, mode_cfg in (cfg.get("modes") or {}).items()
            }
        catalog.append((task_id, entry))
    return catalog
def map_report_task(report_name: str) -> tuple[str, str | None]:
    """Map a report column / legacy directory name to ``(task, mode)``.

    Names not listed in ``REPORT_TASK_ALIASES`` pass through unchanged with ``mode=None``.
    """
    return REPORT_TASK_ALIASES.get(report_name, (report_name, None))
def task_defs_for_pending(reg: dict) -> dict[str, Any]:
    """Build the ``task_defs`` payload for the platform's pending API.

    Every task gets ``type``, ``domain`` (default ``"dms"``) and ``label`` (default:
    the task ID). Single tasks add ``nc``, ``names`` and ``task_dir`` (default: the
    task ID). Multi-mode tasks add ``modes``, where each mode has ``type``, ``nc``,
    ``names`` and ``task_dir = "<task task_dir>/<mode subdir or mode ID>"``.
    """
    defs: dict[str, Any] = {}
    for task_id, cfg in load_registry(reg).items():
        kind = cfg.get("type")
        base_dir = cfg.get("task_dir", task_id)
        spec: dict[str, Any] = {
            "type": kind,
            "domain": cfg.get("domain", "dms"),
            "label": cfg.get("label", task_id),
        }
        if kind == "multi":
            spec["modes"] = {
                mode_id: {
                    "type": mode_cfg.get("type"),
                    "nc": mode_cfg.get("nc"),
                    "names": mode_cfg.get("names"),
                    "task_dir": f"{base_dir}/{mode_cfg.get('subdir', mode_id)}",
                }
                for mode_id, mode_cfg in (cfg.get("modes") or {}).items()
            }
        else:
            spec["nc"] = cfg.get("nc")
            spec["names"] = cfg.get("names")
            spec["task_dir"] = base_dir
        defs[task_id] = spec
    return defs