Files

132 lines
4.6 KiB
Python
Raw Permalink Normal View History

"""DMS YOLO 引擎local 研发轨 / platform 平台轨。"""
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
from typing import Any
import yaml
WORKSPACE = Path(__file__).resolve().parents[2]
YOLO_ROOT = WORKSPACE / "algorithms/dms_yolo/code"
DATASET_ROOT = WORKSPACE / "datasets/dms"
def _resolve_yaml_key(task: str, submode: str | None = None) -> tuple[str, str]:
import sys
scripts = DATASET_ROOT / "scripts"
if str(scripts) not in sys.path:
sys.path.insert(0, str(scripts))
from task_registry import get_mode_config, resolve_task_id, train_yaml_key
reg = yaml.safe_load((DATASET_ROOT / "datasets.registry.yaml").read_text(encoding="utf-8"))
task, submode = resolve_task_id(task, submode)
mcfg = get_mode_config(task, submode, reg)
key = train_yaml_key(task, submode, reg)
return key, mcfg["type"]
def _latest_run(task: str, submode: str | None = None) -> tuple[str | None, str | None]:
yaml_key, typ = _resolve_yaml_key(task, submode)
mode_yolo = {"detect": "detect", "pose": "pose", "classify": "classify"}[typ]
runs = sorted((YOLO_ROOT / "runs" / mode_yolo).glob(f"{yaml_key}_*"), key=lambda p: p.stat().st_mtime, reverse=True)
if not runs:
return None, None
run_dir = runs[0].resolve()
best = run_dir / "weights" / "best.pt"
return str(run_dir), str(best) if best.is_file() else None
def train_local(task: str, mode: str = "full", config_overrides: dict | None = None, submode: str | None = None) -> dict[str, Any]:
"""研发轨train.sh不 refresh yaml_active不更新 candidate。"""
env = {**os.environ}
if submode:
env["SUBMODE"] = submode
proc = subprocess.run(
[str(DATASET_ROOT / "scripts" / "train.sh"), task, mode],
cwd=str(DATASET_ROOT),
capture_output=True,
text=True,
env=env,
)
if proc.returncode != 0:
raise RuntimeError(proc.stderr or proc.stdout)
run_dir, best_weights = _latest_run(task, submode)
return {
"ok": True,
"stdout": proc.stdout,
"stderr": proc.stderr,
"track": "local",
"command": f"{DATASET_ROOT / 'scripts' / 'train.sh'} {task} {mode}",
"run_dir": run_dir,
"best_weights": best_weights,
"note": "local 轨不写入 train_versions candidate",
}
def train_platform(task: str, mode: str = "full", submode: str | None = None) -> dict[str, Any]:
"""平台轨refresh yaml_active + train + 更新 candidate。"""
yaml_key, _ = _resolve_yaml_key(task, submode)
subprocess.check_call(
[sys.executable, str(DATASET_ROOT / "scripts/refresh_yaml.py"), "--task", task],
cwd=str(DATASET_ROOT),
)
env = {**os.environ}
if submode:
env["SUBMODE"] = submode
proc = subprocess.run(
[str(DATASET_ROOT / "scripts" / "train.sh"), task, mode],
cwd=str(DATASET_ROOT),
capture_output=True,
text=True,
env=env,
)
if proc.returncode != 0:
raise RuntimeError(proc.stderr or proc.stdout)
run_dir, best_weights = _latest_run(task, submode)
candidate = None
if best_weights:
candidate = best_weights
versions_path = DATASET_ROOT / "manifests/train_versions.yaml"
versions = yaml.safe_load(versions_path.read_text(encoding="utf-8"))
versions[yaml_key]["candidate"] = candidate
versions_path.write_text(yaml.dump(versions, allow_unicode=True, sort_keys=False), encoding="utf-8")
return {
"ok": True,
"stdout": proc.stdout,
"stderr": proc.stderr,
"track": "platform",
"command": f"{DATASET_ROOT / 'scripts' / 'train.sh'} {task} {mode}",
"run_dir": run_dir,
"best_weights": best_weights,
"candidate": candidate,
}
def eval_task(task: str, weights: Path | None = None) -> dict[str, Any]:
argv = [sys.executable, str(WORKSPACE / "as.py"), "eval", "dms", task, "--save-candidate"]
if weights:
argv.extend(["--weights", str(weights)])
proc = subprocess.run(argv, cwd=str(WORKSPACE), capture_output=True, text=True)
if proc.returncode != 0:
raise RuntimeError(proc.stderr or proc.stdout)
return {
"ok": True,
"stdout": proc.stdout,
"stderr": proc.stderr,
"weights": str(weights) if weights else None,
"command": " ".join(argv),
}
def visualize_task(task: str, weights: Path | None = None) -> dict[str, Any]:
"""复用 eval 流程产出可视化结果。"""
out = eval_task(task, weights=weights)
out["note"] = "可视化结果请查看 DMS 评估输出目录与 runs/detect 结果。"
return out