470 lines
25 KiB
Python
470 lines
25 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
Ground Joint 2D+3D Detection Training Script for YOLO26
|
|||
|
|
Supports single-GPU and multi-GPU distributed training with 3D loss ramping.
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
# Single GPU with pretrained 2D weights
|
|||
|
|
python train_mono3d.py --model yolo26s --pretrained yolo26s-pretrain.pt --epochs 100 --device 0
|
|||
|
|
|
|||
|
|
# Multi-GPU (DDP) with pretrained 2D weights
|
|||
|
|
python -m torch.distributed.run --nproc_per_node 4 train_mono3d.py --model yolo26s --pretrained yolo26s-pretrain.pt --epochs 100
|
|||
|
|
|
|||
|
|
# Train from scratch (no pretrained weights)
|
|||
|
|
python train_mono3d.py --epochs 100 --device 0
|
|||
|
|
|
|||
|
|
# Resume from checkpoint
|
|||
|
|
python train_mono3d.py --resume runs/detect/train_mono3d/weights/last.pt
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import os
|
|||
|
|
import re
|
|||
|
|
import sys
|
|||
|
|
import tempfile
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Any, Optional, Tuple
|
|||
|
|
from uuid import uuid4
|
|||
|
|
|
|||
|
|
from torch.distributed.elastic.multiprocessing.errors import record
|
|||
|
|
|
|||
|
|
# Use a torch cache directory relative to the current working directory, and create it up front.
os.environ["TORCH_CACHE_DIR"] = "./torch_cache"
Path(os.environ["TORCH_CACHE_DIR"]).mkdir(parents=True, exist_ok=True)

# Add ultralytics to path if needed: make the directory containing this script importable.
FILE = Path(__file__).resolve()
ROOT = FILE.parent
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
|
|||
|
|
|
|||
|
|
from ultralytics import settings
|
|||
|
|
from ultralytics.models.yolo.detect.train import Ground3DDetectionTrainer
|
|||
|
|
from ultralytics.utils import LOGGER, YAML, colorstr
|
|||
|
|
from ultralytics.utils.dist import get_distributed_run_timestamp
|
|||
|
|
|
|||
|
|
# Enable TensorBoard logging for runs launched by this script.
# NOTE(review): `settings` is the global Ultralytics settings object, so this update presumably
# persists beyond this process (Ultralytics stores settings in a user settings file) — confirm.
settings.update({"tensorboard": True})
|
|||
|
|
|
|||
|
|
# Fields every resolved Ground3D dataset config must define (checked by _validate_ground3d_config).
GROUND3D_REQUIRED_CONFIG_KEYS = (
    "path",
    "ori_img_size",
    "roi",
    "virtual_fx",
    "virtual_camera_prob",
    "crop_center_mode",
)

# Accepted values for the 'crop_center_mode' config field.
GROUND3D_VALID_CROP_CENTER_MODES = {"vxvy", "cxvy"}

# Shorthand --model names mapped to their scale letter, e.g. "yolo26s" -> "s".
GROUND3D_MODEL_VARIANTS = {"yolo26" + letter: letter for letter in "nslmx"}

# Shared 3D model YAML used for every shorthand variant (and as the --model default).
GROUND3D_DEFAULT_MODEL_CFG = "ultralytics/cfg/models/26/yolo26-3d.yaml"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _validate_ground3d_pair(cfg: dict[str, Any], key: str, context: str) -> None:
|
|||
|
|
"""Validate that a Ground3D size-like config field is an integer pair."""
|
|||
|
|
value = cfg.get(key)
|
|||
|
|
if not isinstance(value, (list, tuple)) or len(value) != 2:
|
|||
|
|
raise ValueError(f"{context} must define '{key}' as a 2-item list/tuple, but got: {value!r}")
|
|||
|
|
if any(int(v) <= 0 for v in value):
|
|||
|
|
raise ValueError(f"{context} must define positive values for '{key}', but got: {value!r}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _validate_ground3d_config(cfg: dict[str, Any], context: str) -> None:
    """Validate required Ground3D config fields before training starts.

    Every key in ``GROUND3D_REQUIRED_CONFIG_KEYS`` must be present and not None; the individual fields
    are then checked for shape/type/range.

    Args:
        cfg: Resolved dataset config (after any ROI preset overrides have been applied).
        context: Human-readable description of the config source, used in error messages.

    Raises:
        ValueError: If a required field is missing or holds an invalid value.
    """
    import math  # only needed for the finiteness check on virtual_fx

    missing = [key for key in GROUND3D_REQUIRED_CONFIG_KEYS if key not in cfg or cfg[key] is None]
    if missing:
        missing_str = ", ".join(missing)
        raise ValueError(f"{context} is missing required Ground3D config field(s): {missing_str}.")

    if not str(cfg["path"]).strip():
        raise ValueError(f"{context} must define a non-empty 'path'.")

    _validate_ground3d_pair(cfg, "ori_img_size", context)
    _validate_ground3d_pair(cfg, "roi", context)

    try:
        virtual_fx = float(cfg["virtual_fx"])
    except (TypeError, ValueError) as e:
        raise ValueError(
            f"{context} must define 'virtual_fx' as a numeric value, but got: {cfg['virtual_fx']!r}."
        ) from e
    # A plain `<= 0` test lets NaN through (all NaN comparisons are False), so also require finiteness.
    if not math.isfinite(virtual_fx) or virtual_fx <= 0:
        raise ValueError(f"{context} must define a finite, positive 'virtual_fx', but got: {cfg['virtual_fx']!r}.")

    try:
        float(cfg["virtual_camera_prob"])
    except (TypeError, ValueError) as e:
        raise ValueError(
            f"{context} must define 'virtual_camera_prob' as a numeric value, but got: {cfg['virtual_camera_prob']!r}."
        ) from e

    crop_center_mode = cfg["crop_center_mode"]
    # Check the type first: a set-membership test on an unhashable value (e.g. a list) raises TypeError.
    if not isinstance(crop_center_mode, str) or crop_center_mode not in GROUND3D_VALID_CROP_CENTER_MODES:
        valid_modes = ", ".join(sorted(GROUND3D_VALID_CROP_CENTER_MODES))
        raise ValueError(
            f"{context} must define 'crop_center_mode' as one of {{{valid_modes}}}, but got: {crop_center_mode!r}."
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_run_name(base_name: str, resume: str, world_size: int, rank: int) -> str:
    """Build a timestamped run name that stays consistent across distributed ranks.

    Args:
        base_name: Experiment name (``--name``).
        resume: Checkpoint path from ``--resume``; empty string when not resuming.
        world_size: Number of distributed processes (from ``WORLD_SIZE``).
        rank: Global rank of this process (from ``RANK``; -1 when not launched distributed).

    Returns:
        When resuming from a ``.pt`` file laid out as ``<run_dir>/<subdir>/<file>.pt`` (e.g.
        ``runs/detect/train_x/weights/last.pt``), the name of ``<run_dir>``. When resuming from any other
        path, ``base_name``. Otherwise ``f"{base_name}_{timestamp}"``, using the timestamp from
        ``get_distributed_run_timestamp``.
    """
    if resume:
        resume_path = Path(resume)
        if resume_path.suffix == ".pt" and len(resume_path.parents) >= 2:
            run_dir_name = resume_path.parents[1].name
            # For short paths like "weights/last.pt" or "/weights/last.pt", parents[1] is "." or "/",
            # whose name is "". An empty run name is never usable, so fall back to base_name.
            if run_dir_name:
                return run_dir_name
        return base_name

    timestamp = get_distributed_run_timestamp(base_name, world_size, rank)
    return f"{base_name}_{timestamp}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resolve_ground3d_model(model_arg: str) -> Tuple[str, Optional[str]]:
    """Resolve the ``--model`` value to a model config path plus an optional explicit scale.

    Shorthand variants (the keys of ``GROUND3D_MODEL_VARIANTS``, e.g. ``"yolo26s"``) resolve to the
    shared 3D YAML (``GROUND3D_DEFAULT_MODEL_CFG``) and their scale letter. Any other value is treated as
    a config path (``~`` expanded) with no explicit scale.

    Args:
        model_arg: Raw ``--model`` value; surrounding whitespace is ignored.

    Returns:
        ``(model_cfg_path, scale)``, where ``scale`` is ``None`` unless a shorthand variant was given.

    Raises:
        ValueError: If ``model_arg`` is empty or whitespace-only (``Path("")`` would silently become ``"."``).
    """
    model_arg = model_arg.strip()
    if not model_arg:
        raise ValueError("--model must be a model config YAML path or a shorthand variant, not an empty string.")
    # Look up the variant table directly so the accepted shorthands cannot drift from GROUND3D_MODEL_VARIANTS.
    scale = GROUND3D_MODEL_VARIANTS.get(model_arg)
    if scale is not None:
        return GROUND3D_DEFAULT_MODEL_CFG, scale
    return str(Path(model_arg).expanduser()), None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resolve_data_yaml_for_roi(data_path: str, roi_name: Optional[str]) -> Tuple[str, Optional[str]]:
    """Resolve a dataset YAML to a temporary file with the selected ROI preset applied.

    The preset is ``roi_name`` if given, otherwise the YAML's ``default_roi``. When a preset is selected:
    - its mapping from ``roi_configs`` is merged (shallow ``dict.update``) over the top-level config;
    - the ``default_roi`` and ``roi_configs`` keys are dropped;
    - the result is validated and written to a uniquely named file under
      ``<tempdir>/ultralytics_roi_configs``.

    Args:
        data_path: Dataset path from ``--data`` (``~`` is expanded).
        roi_name: Preset name from ``--roi``, or None/empty to use the YAML's ``default_roi``.

    Returns:
        ``(path, selected_roi)``: the dataset path to hand to the trainer and the applied preset name.
        Returns ``(data_path, None)`` when ``data_path`` is not a ``.yaml``/``.yml`` file or no preset is
        selected.

    Raises:
        ValueError: If ``roi_name`` is given for a non-YAML dataset, the YAML top level is not a mapping,
            the preset is unknown or not a mapping, or the resolved config fails Ground3D validation.
    """
    data_path = str(Path(data_path).expanduser())
    source_path = Path(data_path)
    if source_path.suffix not in {".yaml", ".yml"}:
        # Only dataset YAMLs can carry a roi_configs section.
        if roi_name:
            raise ValueError("--roi requires --data to point to a dataset YAML file.")
        return data_path, None

    data_cfg = YAML.load(data_path)
    if not isinstance(data_cfg, dict):
        # e.g. a YAML whose top level is a list; `.get` below would otherwise fail with AttributeError.
        raise ValueError(f"Dataset YAML '{data_path}' must contain a mapping at the top level.")
    roi_configs = data_cfg.get("roi_configs")
    default_roi = data_cfg.get("default_roi")
    # An explicit --roi takes precedence over the YAML's default_roi.
    selected_roi = roi_name or default_roi

    if selected_roi is None:
        return data_path, None
    if not isinstance(roi_configs, dict) or not roi_configs:
        raise ValueError(f"Dataset YAML '{data_path}' does not define any 'roi_configs' presets.")
    if selected_roi not in roi_configs:
        # str() each name so presets with non-string YAML keys (e.g. ints) can still be sorted and joined.
        available = ", ".join(sorted(str(name) for name in roi_configs))
        raise ValueError(
            f"Unknown ROI preset '{selected_roi}' for dataset '{data_path}'. Available presets: {available}."
        )

    roi_overrides = roi_configs[selected_roi]
    if not isinstance(roi_overrides, dict):
        raise ValueError(f"ROI preset '{selected_roi}' in '{data_path}' must be a mapping.")

    resolved_cfg: dict[str, Any] = dict(data_cfg)
    resolved_cfg.pop("default_roi", None)
    resolved_cfg.pop("roi_configs", None)
    resolved_cfg.update(roi_overrides)
    _validate_ground3d_config(resolved_cfg, f"Resolved dataset YAML '{data_path}' with ROI preset '{selected_roi}'")

    # PID + uuid keep concurrent processes (e.g. several DDP ranks) from writing to the same file.
    resolved_path = (
        Path(tempfile.gettempdir())
        / "ultralytics_roi_configs"
        / f"{source_path.stem}.{selected_roi}.{os.getpid()}.{uuid4().hex}{source_path.suffix}"
    )
    header = (
        "# Auto-generated by train_mono3d.py\n"
        f"# Source dataset YAML: {source_path.resolve()}\n"
        f"# Selected ROI preset: {selected_roi}\n\n"
    )
    YAML.save(file=resolved_path, data=resolved_cfg, header=header)
    return str(resolved_path), selected_roi
|
|||
|
|
|
|||
|
|
|
|||
|
|
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            matching the previous zero-argument behavior.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train YOLO26 on Ground Joint 2D+3D Detection Dataset")

    # Model and data
    parser.add_argument(
        "--model",
        type=str,
        default=GROUND3D_DEFAULT_MODEL_CFG,
        help="Model config YAML path or shorthand variant (yolo26n/yolo26s/yolo26m/yolo26l/yolo26x).",
    )  # e.g. --model ultralytics/cfg/models/26/yolo26-3d.yaml or --model yolo26s
    parser.add_argument("--data", type=str, default="ultralytics/cfg/datasets/mono3d_ground.yaml", help="Dataset YAML path")
    parser.add_argument(
        "--roi",
        type=str,
        default=None,
        help="Named ROI preset from the dataset YAML 'roi_configs' section (for example: roi0 or roi1)",
    )

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch", type=int, default=16, help="Batch size per GPU")
    parser.add_argument("--imgsz", type=str, default="704,352", help="Image size (int or w,h e.g. 704,352)")
    parser.add_argument("--device", type=str, default="0,1,2,3", help="Device to use (e.g., 0 or 0,1,2,3 or cpu)")

    # Optimizer
    parser.add_argument("--optimizer", type=str, default="auto", choices=["SGD", "Adam", "AdamW", "auto"], help="Optimizer")
    parser.add_argument("--lr0", type=float, default=0.01, help="Initial learning rate")
    parser.add_argument("--lrf", type=float, default=0.01, help="Final learning rate factor")
    parser.add_argument("--momentum", type=float, default=0.937, help="SGD momentum/Adam beta1")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="Optimizer weight decay")
    parser.add_argument("--warmup_epochs", type=float, default=3.0, help="Warmup epochs (2D-only stage if loss_3d_warmup_epochs is not set)")
    parser.add_argument("--warmup_momentum", type=float, default=0.8, help="Initial momentum during warmup")
    parser.add_argument("--warmup_bias_lr", type=float, default=0.1, help="Initial bias learning rate during warmup")
    parser.add_argument("--e2e_o2m_start", type=float, default=0.8, help="Initial one-to-many loss weight for end-to-end training")
    parser.add_argument("--e2e_o2m_final", type=float, default=0.1, help="Final one-to-many loss weight after decay")
    parser.add_argument(
        "--e2e_o2m_decay_epochs",
        type=float,
        default=None,
        help="Epochs used to decay one-to-many weight to the final value; defaults to epochs - 1",
    )
    parser.add_argument("--loss_3d_warmup_epochs", type=float, default=None, help="Epochs before enabling 3D loss; defaults to warmup_epochs")
    parser.add_argument("--loss_3d_ramp_epochs", type=float, default=10.0, help="Epochs used to ramp 3D loss weight to its max value")
    parser.add_argument("--loss_3d_weight_max", type=float, default=0.1, help="Maximum 3D loss weight multiplier")
    parser.add_argument(
        "--face_visibility_score_thresh",
        type=float,
        default=None,
        help="Override visible-face score threshold used by 3D face supervision and visible-face metrics",
    )

    # Loss weights
    parser.add_argument("--box", type=float, default=7.5, help="Box loss gain")
    parser.add_argument("--cls", type=float, default=0.5, help="Class loss gain")
    # NOTE(author): the current yolo26-3d.yaml uses reg_max=1, so DFL is effectively weakened/disabled;
    # the flag is kept for compatibility.
    parser.add_argument("--dfl", type=float, default=1.5, help="DFL loss gain")

    # Augmentation
    parser.add_argument("--hsv_h", type=float, default=0.015, help="HSV-Hue augmentation")
    parser.add_argument("--hsv_s", type=float, default=0.7, help="HSV-Saturation augmentation")
    parser.add_argument("--hsv_v", type=float, default=0.4, help="HSV-Value augmentation")
    parser.add_argument("--degrees", type=float, default=0.0, help="Rotation augmentation (degrees)")
    parser.add_argument("--translate", type=float, default=0.1, help="Translation augmentation")
    parser.add_argument("--scale", type=float, default=0.5, help="Scale augmentation")
    parser.add_argument("--shear", type=float, default=0.0, help="Shear augmentation (degrees)")
    parser.add_argument("--perspective", type=float, default=0.0, help="Perspective augmentation")
    parser.add_argument("--flipud", type=float, default=0.0, help="Vertical flip probability")
    parser.add_argument("--fliplr", type=float, default=0.5, help="Horizontal flip probability")
    parser.add_argument("--mosaic", type=float, default=1.0, help="Mosaic augmentation probability")
    parser.add_argument("--mixup", type=float, default=0.0, help="Mixup augmentation probability")
    parser.add_argument("--copy_paste", type=float, default=0.0, help="Copy-paste augmentation probability")

    # Training settings
    # Early-stopping patience: the number of epochs to wait without validation improvement.
    parser.add_argument("--patience", type=int, default=100, help="Early stopping patience")
    parser.add_argument("--save_period", type=int, default=1, help="Save checkpoint every x epochs (default: every epoch, -1 to disable)")
    parser.add_argument("--cache", type=str, default="", choices=["", "ram", "disk"], help="Cache images")  # "" = no caching
    parser.add_argument("--workers", type=int, default=16, help="Number of dataloader workers")
    parser.add_argument("--project", type=str, default=None, help="Project directory")  # root directory for run outputs
    parser.add_argument("--name", type=str, default="train_mono3d", help="Experiment name")
    # If given, results are saved directly into this directory instead of the default project/name layout.
    parser.add_argument(
        "--exp_dir",
        type=str,
        default="",
        help="Exact experiment directory to use. Bypasses distributed run-name timestamp sync.",
    )
    parser.add_argument("--exist_ok", action="store_true", help="Overwrite existing experiment")
    parser.add_argument("--pretrained", type=str, default="", help="Pretrained 2D weights path (e.g. yolo26s-pretrain.pt). Scale extracted from filename.")
    parser.add_argument("--resume", type=str, default="", help="Resume training from checkpoint")
    parser.add_argument(
        "--strict-backbone-pretrained",
        dest="strict_backbone_pretrained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Fail startup if any backbone parameter tensor can not be transferred from --pretrained.",
    )
    parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
    # e.g. --fraction 0.1 trains on 10% of the data, which is handy for quick debugging.
    parser.add_argument("--fraction", type=float, default=1.0, help="Dataset fraction to use")
    parser.add_argument("--profile", action="store_true", help="Profile ONNX and TensorRT speeds")
    parser.add_argument("--freeze", type=int, default=None, help="Freeze first n layers")
    parser.add_argument("--multi_scale", type=float, default=0.0, help="Multi-scale training range")

    # Validation
    # BooleanOptionalAction keeps "--val"/"--plots" working and adds "--no-val"/"--no-plots";
    # the previous store_true + default=True combination could never be switched off.
    parser.add_argument("--val", action=argparse.BooleanOptionalAction, default=True, help="Validate during training")
    parser.add_argument("--plots", action=argparse.BooleanOptionalAction, default=True, help="Save plots during training")
    # Per the author: when batches carry camera_mode, validation skips virtual-camera samples and
    # computes metrics on ROI samples only.
    parser.add_argument("--roi_metrics_only", action="store_true", help="Compute validation metrics using ROI samples only")

    # Advanced
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--deterministic", action="store_true", help="Use deterministic algorithms")
    parser.add_argument("--single_cls", action="store_true", help="Train as single-class dataset")  # merge all classes into one
    parser.add_argument("--rect", action="store_true", help="Rectangular training")
    parser.add_argument("--cos_lr", action="store_true", help="Use cosine learning rate scheduler")
    parser.add_argument("--close_mosaic", type=int, default=10, help="Disable mosaic augmentation for final epochs")
    parser.add_argument("--nbs", type=int, default=64, help="Nominal batch size")
    # Same fix as --val: "--no-overlap_mask" is now available to disable it.
    parser.add_argument("--overlap_mask", action=argparse.BooleanOptionalAction, default=True, help="Masks should overlap during training")
    parser.add_argument("--mask_ratio", type=int, default=4, help="Mask downsample ratio")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate")

    return parser.parse_args(argv)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_imgsz(value):
    """Parse the ``--imgsz`` string: ``"704,352"`` -> ``[704, 352]``, ``"640"`` -> ``640``.

    Raises:
        ValueError: If ``value`` (or any comma-separated part of it) is not an integer.
    """
    try:
        if isinstance(value, str) and "," in value:
            return [int(part) for part in value.split(",")]
        return int(value)
    except ValueError as e:
        # Re-raise with the offending CLI value instead of a bare "invalid literal for int()".
        raise ValueError(
            f"--imgsz must be an integer or comma-separated integers like 704,352, but got: {value!r}"
        ) from e


def _build_overrides(
    args: argparse.Namespace,
    *,
    model: str,
    data: str,
    imgsz,
    name: str,
    pretrained,
    save_dir: str,
) -> dict[str, Any]:
    """Translate parsed CLI args into the ``overrides`` dict passed to ``Ground3DDetectionTrainer``.

    These keys are only added when the corresponding option was given:
    - freeze
    - loss_3d_warmup_epochs
    - e2e_o2m_decay_epochs
    - face_visibility_score_thresh
    - resume
    - save_dir

    When omitted, their defaults come from elsewhere (see the corresponding ``--help`` texts).
    """
    overrides = {
        "model": model,
        "data": data,
        "epochs": args.epochs,
        "batch": args.batch,
        "imgsz": imgsz,
        "device": args.device,
        "optimizer": args.optimizer,
        "lr0": args.lr0,
        "lrf": args.lrf,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
        "warmup_epochs": args.warmup_epochs,
        "warmup_momentum": args.warmup_momentum,
        "warmup_bias_lr": args.warmup_bias_lr,
        "e2e_o2m_start": args.e2e_o2m_start,  # one-to-many loss weight schedule for YOLO26 end-to-end training
        "e2e_o2m_final": args.e2e_o2m_final,
        "loss_3d_ramp_epochs": args.loss_3d_ramp_epochs,  # 3D loss weight ramps up to loss_3d_weight_max
        "loss_3d_weight_max": args.loss_3d_weight_max,
        "box": args.box,
        "cls": args.cls,
        "dfl": args.dfl,
        "hsv_h": args.hsv_h,
        "hsv_s": args.hsv_s,
        "hsv_v": args.hsv_v,
        "degrees": args.degrees,
        "translate": args.translate,
        "scale": args.scale,
        "shear": args.shear,
        "perspective": args.perspective,
        "flipud": args.flipud,
        "fliplr": args.fliplr,
        "mosaic": args.mosaic,
        "mixup": args.mixup,
        "copy_paste": args.copy_paste,
        "patience": args.patience,  # early-stopping patience in epochs
        "save_period": args.save_period,  # save a checkpoint every N epochs
        "cache": args.cache if args.cache else False,  # "" (no caching) -> False
        "workers": args.workers,
        "project": args.project,
        "name": name,
        "exist_ok": args.exist_ok,
        "pretrained": pretrained,
        "amp": args.amp,
        "fraction": args.fraction,
        "profile": args.profile,
        "multi_scale": args.multi_scale,
        "val": args.val,
        "plots": args.plots,
        "roi_metrics_only": args.roi_metrics_only,
        "seed": args.seed,
        "deterministic": args.deterministic,
        "single_cls": args.single_cls,
        "rect": args.rect,
        "cos_lr": args.cos_lr,
        "close_mosaic": args.close_mosaic,
        "nbs": args.nbs,
        "overlap_mask": args.overlap_mask,
        "mask_ratio": args.mask_ratio,
        "dropout": args.dropout,
    }

    if args.freeze is not None:
        overrides["freeze"] = args.freeze
    if args.loss_3d_warmup_epochs is not None:
        overrides["loss_3d_warmup_epochs"] = args.loss_3d_warmup_epochs
    if args.e2e_o2m_decay_epochs is not None:
        overrides["e2e_o2m_decay_epochs"] = args.e2e_o2m_decay_epochs
    if args.face_visibility_score_thresh is not None:
        overrides["face_visibility_score_thresh"] = args.face_visibility_score_thresh
    if args.resume:
        overrides["resume"] = args.resume
    if save_dir:
        overrides["save_dir"] = save_dir
    return overrides


@record  # records child-process exceptions (torch elastic) so the failing rank is easier to identify
def main():
    """Main training function.

    Steps:
    1. Parse CLI args.
    2. Resolve the model, data and ROI configuration.
    3. Build the Ultralytics trainer overrides.
    4. Run ``Ground3DDetectionTrainer`` training.
    """
    args = parse_args()

    # Print banner
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Ground Joint 2D+3D Detection Training with YOLO26"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))

    # Confirm TensorBoard is enabled
    LOGGER.info(colorstr("bright_green", "TensorBoard logging enabled"))

    # Distributed launch info (environment variables set by torch.distributed.run)
    world_size = int(os.getenv("WORLD_SIZE", 1))
    rank = int(os.getenv("RANK", -1))
    local_rank = int(os.getenv("LOCAL_RANK", -1))

    if world_size > 1:
        LOGGER.info(f"Distributed training: world_size={world_size}, rank={rank}, local_rank={local_rank}")
    else:
        # WORLD_SIZE=1 only means no torchrun launch; --device may still list several GPUs
        # (the default is "0,1,2,3"), so don't claim single-GPU training here.
        LOGGER.info(f"Single-process launch (device={args.device})")

    # For 3D training: --model can be the 3D config YAML or a shorthand variant like "yolo26s",
    # which resolves to the shared 3D YAML plus explicit scale "s".
    pretrained = args.pretrained if args.pretrained else False
    resolved_model_path, explicit_model_scale = resolve_ground3d_model(args.model)
    # Warn when neither --model nor the pretrained filename pins the YOLO26 scale.
    if pretrained and explicit_model_scale is None and not re.fullmatch(r"yolo26[nslmx].*", Path(str(pretrained)).stem):
        LOGGER.warning(
            "Pretrained checkpoint name does not encode a YOLO26 scale. Pass '--model yolo26n/yolo26s/...' "
            "to lock the 3D architecture scale explicitly."
        )

    # Parse imgsz: "704,352" -> [704, 352], "640" -> 640
    imgsz = _parse_imgsz(args.imgsz)

    # On resume, the checkpoint already determines the run directory, so --exp_dir would conflict.
    if args.resume and args.exp_dir:
        raise ValueError("--exp_dir cannot be used with --resume because resume already determines the run directory.")

    exp_dir = str(Path(args.exp_dir).expanduser()) if args.exp_dir else ""
    # An explicit exp_dir names the run after its directory. Otherwise build_run_name derives the name,
    # kept consistent across DDP ranks.
    run_name = Path(exp_dir).name if exp_dir else build_run_name(args.name, args.resume, world_size, rank)
    # e.g. --roi roi1 yields a temporary YAML with the roi1 preset applied and selected_roi == "roi1".
    resolved_data_path, selected_roi = resolve_data_yaml_for_roi(args.data, args.roi)

    overrides = _build_overrides(
        args,
        model=resolved_model_path,
        data=resolved_data_path,
        imgsz=imgsz,
        name=run_name,
        pretrained=pretrained,
        save_dir=exp_dir,
    )

    # Print configuration
    LOGGER.info(colorstr("bright_cyan", "\nTraining Configuration:"))
    LOGGER.info(f" Model arg: {args.model}")
    LOGGER.info(f" Model config: {resolved_model_path}")
    LOGGER.info(
        " Model scale: "
        f"{explicit_model_scale if explicit_model_scale else 'infer from pretrained filename or YAML default'}"
    )
    LOGGER.info(f" Pretrained weights: {pretrained if pretrained else 'None'}")
    LOGGER.info(f" Strict pretrained backbone load: {args.strict_backbone_pretrained}")
    LOGGER.info(f" Data: {args.data}")
    if selected_roi:
        LOGGER.info(f" ROI preset: {selected_roi}")
        LOGGER.info(f" Resolved data YAML: {resolved_data_path}")
    LOGGER.info(f" Epochs: {args.epochs}")
    LOGGER.info(f" Batch size: {args.batch}")
    LOGGER.info(f" Image size: {args.imgsz}")
    LOGGER.info(f" Device: {args.device if args.device else 'auto'}")
    LOGGER.info(f" Workers: {args.workers}")
    LOGGER.info(f" Cache: {args.cache if args.cache else 'False'}")
    LOGGER.info(f" Run name: {run_name}")
    LOGGER.info(f" Checkpoints: every {args.save_period} epoch(s)" if args.save_period > 0 else " Checkpoints: last.pt and best.pt only")
    if exp_dir:
        LOGGER.info(f" Experiment dir: {exp_dir}")
    LOGGER.info(f" AMP: {args.amp}")
    LOGGER.info(f" Optimizer: {args.optimizer}")
    LOGGER.info(f" Learning rate: {args.lr0} -> {args.lr0 * args.lrf}")
    LOGGER.info(f" Warmup: epochs={args.warmup_epochs}, momentum={args.warmup_momentum}, bias_lr={args.warmup_bias_lr}")
    LOGGER.info(
        " E2E one-to-many weight: "
        f"start={args.e2e_o2m_start}, final={args.e2e_o2m_final}, "
        f"decay_epochs={args.e2e_o2m_decay_epochs if args.e2e_o2m_decay_epochs is not None else 'epochs-1'}"
    )
    LOGGER.info(f" Loss weights: box={args.box}, cls={args.cls}, dfl={args.dfl}")
    loss_3d_warmup = args.loss_3d_warmup_epochs if args.loss_3d_warmup_epochs is not None else args.warmup_epochs
    LOGGER.info(
        f" 3D loss: warmup={loss_3d_warmup} epochs, ramp={args.loss_3d_ramp_epochs} epochs, max={args.loss_3d_weight_max}"
    )
    LOGGER.info(
        " Face visibility score threshold: "
        f"{args.face_visibility_score_thresh if args.face_visibility_score_thresh is not None else 'default.yaml'}"
    )
    LOGGER.info("")

    # Initialize trainer
    LOGGER.info(colorstr("bright_green", "Initializing Ground3DDetectionTrainer..."))
    trainer = Ground3DDetectionTrainer(overrides=overrides)
    # Attach settings the trainer reads later: the explicit n/s/m/l/x scale (for model construction)
    # and the strict backbone-transfer check (for pretrained loading).
    # NOTE(review): these are set after construction and are not part of `overrides`. If the trainer
    # re-launches itself for multi-GPU DDP (no torchrun), they are presumably not propagated to the
    # workers. Prefer launching multi-GPU runs via torch.distributed.run — confirm.
    trainer.explicit_model_scale = explicit_model_scale
    trainer.strict_backbone_pretrained = args.strict_backbone_pretrained

    # Start training
    LOGGER.info(colorstr("bright_green", "Starting training...\n"))
    trainer.train()

    # Print completion message
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Training completed!"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))

    # Print results location
    if hasattr(trainer, "save_dir"):
        LOGGER.info(f"Results saved to: {trainer.save_dir}")
        LOGGER.info(f"Best weights: {trainer.save_dir / 'weights' / 'best.pt'}")
        LOGGER.info(f"Last weights: {trainer.save_dir / 'weights' / 'last.pt'}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point: used both for direct runs and under `python -m torch.distributed.run`.
if __name__ == "__main__":
    main()
|