Files
HSAP/datasets/dms/scripts/train.sh

104 lines
3.6 KiB
Bash
Raw Normal View History

#!/usr/bin/env bash
# train.sh <task> [full|continue] — 读 datasets.registry.yaml
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
DATASET_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
YOLO26_ROOT="${YOLO26_ROOT:-$(cd "$DATASET_ROOT/../Code/yolo26_rknn_ultralytics-main" 2>/dev/null && pwd || echo "")}"
# 优先使用 dms_yolo26 环境
if [[ -z "${CONDA_DEFAULT_ENV:-}" || "${CONDA_DEFAULT_ENV}" != "dms_yolo26" ]]; then
if [[ -f "${HOME}/miniconda3/etc/profile.d/conda.sh" ]]; then
source "${HOME}/miniconda3/etc/profile.d/conda.sh"
conda activate dms_yolo26 2>/dev/null || true
fi
fi
TASK="${1:?用法: $0 <task> [full|continue]}"
TRAIN_MODE="${2:-full}"
REG="$DATASET_ROOT/datasets.registry.yaml"
VERSIONS="$DATASET_ROOT/manifests/train_versions.yaml"
SUBMODE="${SUBMODE:-}"
read -r YAML_KEY TYPE MODE MODEL EPOCHS LR0 IMGSZ RUN_SUFFIX <<< "$(python3 - <<PY
import sys
from pathlib import Path
import yaml
sys.path.insert(0, str(Path("$DATASET_ROOT/scripts")))
from task_registry import get_mode_config, resolve_task_id, train_yaml_key
reg = yaml.safe_load(Path("$REG").read_text())
task, sub = resolve_task_id("$TASK", "$SUBMODE" or None)
mcfg = get_mode_config(task, sub, reg)
typ = mcfg["type"]
yaml_key = train_yaml_key(task, sub, reg)
print(yaml_key, end=" ")
train_mode = "$TRAIN_MODE" if "$TRAIN_MODE" in ("full", "continue") else reg.get("train", {}).get("mode", "full")
t = reg.get("train", {}).get(typ, reg.get("train_defaults", {}).get(typ, {}))
if train_mode == "continue":
model = t.get("warm_start") or "null"
epochs = t.get("epochs_continue", t.get("epochs_increment", 50))
lr0 = t.get("lr0_continue", t.get("lr0", 0.001))
suffix = "continue"
else:
model = t.get("model", "yolo26n.pt")
epochs = t.get("epochs", 100)
lr0 = t.get("lr0", 0.01)
suffix = "full"
imgsz = t.get("imgsz", 224 if typ == "classify" else 640)
mode = {"detect": "detect", "pose": "pose", "classify": "classify"}.get(typ, "detect")
print(typ, mode, model, epochs, lr0, imgsz, suffix)
PY
)"
YAML="$DATASET_ROOT/manifests/yaml_active/${YAML_KEY}.yaml"
if [[ ! -f "$YAML" ]]; then
echo "找不到 yaml: $YAML(请先 refresh_yaml.py"
exit 1
fi
# continue 模式warm_start 为空则读 train_versions.yaml
if [[ "$TRAIN_MODE" == "continue" && ( "$MODEL" == "null" || "$MODEL" == "None" || -z "$MODEL" ) ]]; then
MODEL=$(python3 - <<PY 2>/dev/null || true
import yaml
from pathlib import Path
p = Path("$VERSIONS")
if p.is_file():
v = yaml.safe_load(p.read_text()) or {}
c = v.get("$YAML_KEY", {}).get("current")
if c: print(c)
PY
)
fi
if [[ "$TRAIN_MODE" == "continue" && ( -z "$MODEL" || "$MODEL" == "null" ) ]]; then
echo "continue 模式需要 registry.train.<type>.warm_start 或 manifests/train_versions.yaml 中的 current"
exit 1
fi
RUN_NAME="${YAML_KEY}_${RUN_SUFFIX}_$(date +%Y%m%d)"
echo "task=$TASK submode=$SUBMODE yaml_key=$YAML_KEY type=$TYPE yolo_mode=$MODE train_mode=$TRAIN_MODE"
echo "data=$YAML"
echo "model=$MODEL epochs=$EPOCHS lr0=$LR0 imgsz=$IMGSZ name=$RUN_NAME"
if [[ -z "$YOLO26_ROOT" || ! -d "$YOLO26_ROOT" ]]; then
echo "请设置 YOLO26_ROOT 或安装到 ../Code/yolo26_rknn_ultralytics-main"
echo " cd \$YOLO26_ROOT"
echo " yolo $MODE train data=$YAML model=$MODEL epochs=$EPOCHS lr0=$LR0 imgsz=$IMGSZ project=runs/${MODE} name=$RUN_NAME"
exit 0
fi
cd "$YOLO26_ROOT"
yolo "$MODE" train \
data="$YAML" \
model="$MODEL" \
epochs="$EPOCHS" \
lr0="$LR0" \
imgsz="$IMGSZ" \
project="runs/${MODE}" \
name="$RUN_NAME"
BEST="runs/${MODE}/${RUN_NAME}/weights/best.pt"
echo "完成: $BEST"
echo "请更新 manifests/train_versions.yaml 中 $YAML_KEY.current = $BEST"