167 lines
5.6 KiB
Python
167 lines
5.6 KiB
Python
|
|
"""DMS 任务注册表:domain 分组、multi 任务(前向 detect+classify)、旧 ID 别名。"""
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
DOMAIN_LABELS = {
|
|||
|
|
"dms": "舱内 DMS",
|
|||
|
|
"forward": "前向 ADAS",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# 报表 / 旧目录名 -> (task, mode)
|
|||
|
|
REPORT_TASK_ALIASES: dict[str, tuple[str, str | None]] = {
|
|||
|
|
"isa": ("forward", "detect"),
|
|||
|
|
"isa_detect": ("forward", "detect"),
|
|||
|
|
"isa_class": ("forward", "classify"),
|
|||
|
|
"isa_class_0116": ("forward", "classify"),
|
|||
|
|
"dam_0417": ("dam", "batch_0417"),
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
LEGACY_TASK_ALIASES: dict[str, tuple[str, str | None]] = {
|
|||
|
|
"isa": ("forward", "detect"),
|
|||
|
|
"isa_class": ("forward", "classify"),
|
|||
|
|
"dam_0417": ("dam", "batch_0417"),
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_registry(reg: dict) -> dict[str, Any]:
|
|||
|
|
return reg.get("tasks") or {}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resolve_task_id(task: str, mode: str | None = None) -> tuple[str, str | None]:
|
|||
|
|
"""用户/历史 task ID -> (canonical_task, mode)。"""
|
|||
|
|
if task in LEGACY_TASK_ALIASES:
|
|||
|
|
t, m = LEGACY_TASK_ALIASES[task]
|
|||
|
|
return t, mode or m
|
|||
|
|
return task, mode
|
|||
|
|
|
|||
|
|
|
|||
|
|
def report_task_key(task: str, mode: str | None = None) -> str:
|
|||
|
|
"""catalog 报表 CSV 中的任务列名。"""
|
|||
|
|
t, m = resolve_task_id(task, mode)
|
|||
|
|
if t == "forward" and m == "detect":
|
|||
|
|
return "isa"
|
|||
|
|
if t == "forward" and m == "classify":
|
|||
|
|
return "isa_class"
|
|||
|
|
if t == "dam" and m == "batch_0516":
|
|||
|
|
return "dam"
|
|||
|
|
if t == "dam" and m == "batch_0417":
|
|||
|
|
return "dam_0417"
|
|||
|
|
return task
|
|||
|
|
|
|||
|
|
|
|||
|
|
def train_yaml_key(task: str, mode: str | None, reg: dict) -> str:
|
|||
|
|
"""manifests/yaml_active 与 train.sh 使用的文件名(不含 .yaml)。"""
|
|||
|
|
task, mode = resolve_task_id(task, mode)
|
|||
|
|
tcfg = load_registry(reg)[task]
|
|||
|
|
if tcfg.get("type") == "multi":
|
|||
|
|
if not mode:
|
|||
|
|
raise ValueError(f"任务 {task} 需指定 mode(detect / classify)")
|
|||
|
|
return f"{task}__{mode}"
|
|||
|
|
return task
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_mode_config(task: str, mode: str | None, reg: dict) -> dict[str, Any]:
|
|||
|
|
task, mode = resolve_task_id(task, mode)
|
|||
|
|
tcfg = load_registry(reg)[task]
|
|||
|
|
if tcfg.get("type") != "multi":
|
|||
|
|
return {**tcfg, "task": task, "mode": None}
|
|||
|
|
modes = tcfg.get("modes") or {}
|
|||
|
|
if not mode:
|
|||
|
|
raise ValueError(f"任务 {task} 需指定 mode")
|
|||
|
|
if mode not in modes:
|
|||
|
|
raise ValueError(f"未知 mode: {task}/{mode}")
|
|||
|
|
mcfg = dict(modes[mode])
|
|||
|
|
mcfg["task"] = task
|
|||
|
|
mcfg["mode"] = mode
|
|||
|
|
mcfg["task_dir"] = f"{tcfg.get('task_dir', task)}/{mcfg.get('subdir', mode)}"
|
|||
|
|
mcfg["domain"] = tcfg.get("domain")
|
|||
|
|
mcfg["label"] = mcfg.get("label") or tcfg.get("label")
|
|||
|
|
return mcfg
|
|||
|
|
|
|||
|
|
|
|||
|
|
def task_data_dir(pack_dir, task: str, mode: str | None, reg: dict):
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
mcfg = get_mode_config(task, mode, reg)
|
|||
|
|
return Path(pack_dir) / mcfg["task_dir"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def inbox_dir(root, task: str, mode: str | None, reg: dict):
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
task, mode = resolve_task_id(task, mode)
|
|||
|
|
tcfg = load_registry(reg)[task]
|
|||
|
|
if tcfg.get("type") == "multi":
|
|||
|
|
mcfg = tcfg["modes"][mode or ""]
|
|||
|
|
rel = mcfg.get("inbox") or f"inbox/{task}/{mode}"
|
|||
|
|
return Path(root) / rel
|
|||
|
|
return Path(root) / (tcfg.get("inbox") or f"inbox/{task}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def iter_catalog_tasks(reg: dict) -> list[tuple[str, dict[str, Any]]]:
|
|||
|
|
"""catalog 顶层任务列表。"""
|
|||
|
|
out: list[tuple[str, dict[str, Any]]] = []
|
|||
|
|
for task, tcfg in load_registry(reg).items():
|
|||
|
|
entry = {
|
|||
|
|
"domain": tcfg.get("domain", "dms"),
|
|||
|
|
"domain_label": DOMAIN_LABELS.get(tcfg.get("domain", "dms"), tcfg.get("domain", "dms")),
|
|||
|
|
"label": tcfg.get("label", task),
|
|||
|
|
"type": tcfg.get("type"),
|
|||
|
|
}
|
|||
|
|
if tcfg.get("type") == "multi":
|
|||
|
|
entry["modes"] = {}
|
|||
|
|
for mode, mcfg in (tcfg.get("modes") or {}).items():
|
|||
|
|
entry["modes"][mode] = {
|
|||
|
|
"label": mcfg.get("label", mode),
|
|||
|
|
"type": mcfg.get("type"),
|
|||
|
|
"nc": mcfg.get("nc"),
|
|||
|
|
"names": mcfg.get("names"),
|
|||
|
|
"packs": [],
|
|||
|
|
"class_counts": {},
|
|||
|
|
}
|
|||
|
|
else:
|
|||
|
|
entry["nc"] = tcfg.get("nc")
|
|||
|
|
entry["names"] = tcfg.get("names")
|
|||
|
|
entry["packs"] = []
|
|||
|
|
entry["class_counts"] = {}
|
|||
|
|
out.append((task, entry))
|
|||
|
|
return out
|
|||
|
|
|
|||
|
|
|
|||
|
|
def map_report_task(report_name: str) -> tuple[str, str | None]:
|
|||
|
|
if report_name in REPORT_TASK_ALIASES:
|
|||
|
|
return REPORT_TASK_ALIASES[report_name]
|
|||
|
|
return report_name, None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def task_defs_for_pending(reg: dict) -> dict[str, Any]:
|
|||
|
|
"""平台 pending API 的 task_defs。"""
|
|||
|
|
defs: dict[str, Any] = {}
|
|||
|
|
for task, tcfg in load_registry(reg).items():
|
|||
|
|
if tcfg.get("type") == "multi":
|
|||
|
|
defs[task] = {
|
|||
|
|
"type": "multi",
|
|||
|
|
"domain": tcfg.get("domain", "dms"),
|
|||
|
|
"label": tcfg.get("label", task),
|
|||
|
|
"modes": {
|
|||
|
|
m: {
|
|||
|
|
"type": mc.get("type"),
|
|||
|
|
"nc": mc.get("nc"),
|
|||
|
|
"names": mc.get("names"),
|
|||
|
|
"task_dir": f"{tcfg.get('task_dir', task)}/{mc.get('subdir', m)}",
|
|||
|
|
}
|
|||
|
|
for m, mc in (tcfg.get("modes") or {}).items()
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
else:
|
|||
|
|
defs[task] = {
|
|||
|
|
"type": tcfg.get("type"),
|
|||
|
|
"domain": tcfg.get("domain", "dms"),
|
|||
|
|
"label": tcfg.get("label", task),
|
|||
|
|
"nc": tcfg.get("nc"),
|
|||
|
|
"names": tcfg.get("names"),
|
|||
|
|
"task_dir": tcfg.get("task_dir", task),
|
|||
|
|
}
|
|||
|
|
return defs
|