470 lines
25 KiB
Python
Executable File
470 lines
25 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
Ground Joint 2D+3D Detection Training Script for YOLO26
Supports single-GPU and multi-GPU distributed training with 3D loss ramping.

Usage:
    # Single GPU with pretrained 2D weights
    python train_mono3d.py --model yolo26s --pretrained yolo26s-pretrain.pt --epochs 100 --device 0

    # Multi-GPU (DDP) with pretrained 2D weights
    python -m torch.distributed.run --nproc_per_node 4 train_mono3d.py --model yolo26s --pretrained yolo26s-pretrain.pt --epochs 100

    # Train from scratch (no pretrained weights)
    python train_mono3d.py --epochs 100 --device 0

    # Resume from checkpoint
    python train_mono3d.py --resume runs/detect/train_mono3d/weights/last.pt
"""
|
||
|
||
import argparse
|
||
import os
|
||
import re
|
||
import sys
|
||
import tempfile
|
||
from pathlib import Path
|
||
from typing import Any, Optional, Tuple
|
||
from uuid import uuid4
|
||
|
||
from torch.distributed.elastic.multiprocessing.errors import record
|
||
|
||
# Export TORCH_CACHE_DIR as a folder relative to the current working directory
# and create that folder if it does not exist yet.
os.environ.update(TORCH_CACHE_DIR='./torch_cache')
os.makedirs(os.environ.get('TORCH_CACHE_DIR'), exist_ok=True)
|
||
|
||
# Make the directory that contains this script importable, so that packages
# placed next to it (such as ultralytics) can be found.
FILE = Path(__file__).resolve()
ROOT = FILE.parent
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.append(str(ROOT))
|
||
|
||
from ultralytics import settings
|
||
from ultralytics.models.yolo.detect.train import Ground3DDetectionTrainer
|
||
from ultralytics.utils import LOGGER, YAML, colorstr
|
||
from ultralytics.utils.dist import get_distributed_run_timestamp
|
||
|
||
# Switch TensorBoard logging on via the Ultralytics settings object at import time.
_tensorboard_settings = {"tensorboard": True}
settings.update(_tensorboard_settings)
|
||
|
||
# Dataset-config fields that must be present (and not None) before training starts.
GROUND3D_REQUIRED_CONFIG_KEYS = (
    "path",
    "ori_img_size",
    "roi",
    "virtual_fx",
    "virtual_camera_prob",
    "crop_center_mode",
)
# Accepted values for the 'crop_center_mode' dataset field.
GROUND3D_VALID_CROP_CENTER_MODES = {"cxvy", "vxvy"}
# Shorthand model names ("yolo26n", "yolo26s", ...) mapped to their scale letter.
GROUND3D_MODEL_VARIANTS = {"yolo26" + letter: letter for letter in ("n", "s", "l", "m", "x")}
# Model YAML used when --model is omitted or given as a shorthand variant.
GROUND3D_DEFAULT_MODEL_CFG = "ultralytics/cfg/models/26/yolo26-3d.yaml"
|
||
|
||
|
||
def _validate_ground3d_pair(cfg: dict[str, Any], key: str, context: str) -> None:
|
||
"""Validate that a Ground3D size-like config field is an integer pair."""
|
||
value = cfg.get(key)
|
||
if not isinstance(value, (list, tuple)) or len(value) != 2:
|
||
raise ValueError(f"{context} must define '{key}' as a 2-item list/tuple, but got: {value!r}")
|
||
if any(int(v) <= 0 for v in value):
|
||
raise ValueError(f"{context} must define positive values for '{key}', but got: {value!r}")
|
||
|
||
|
||
def _validate_ground3d_config(cfg: dict[str, Any], context: str) -> None:
    """Validate required Ground3D config fields before training starts.

    Args:
        cfg: Resolved dataset configuration mapping.
        context: Human-readable description of the config, used as the error-message prefix.

    Raises:
        ValueError: If a required field is missing or None, 'path' is blank, 'ori_img_size' or 'roi'
            is not a valid pair, 'virtual_fx' is not a positive number, 'virtual_camera_prob' is not
            numeric, or 'crop_center_mode' is not one of the supported modes.
    """
    missing = [key for key in GROUND3D_REQUIRED_CONFIG_KEYS if key not in cfg or cfg[key] is None]
    if missing:
        missing_str = ", ".join(missing)
        raise ValueError(f"{context} is missing required Ground3D config field(s): {missing_str}.")

    if not str(cfg["path"]).strip():
        raise ValueError(f"{context} must define a non-empty 'path'.")

    _validate_ground3d_pair(cfg, "ori_img_size", context)
    _validate_ground3d_pair(cfg, "roi", context)

    # Guard the conversion the same way 'virtual_camera_prob' is guarded below, so a non-numeric
    # value is reported with the config context instead of a bare float() error.
    try:
        virtual_fx = float(cfg["virtual_fx"])
    except (TypeError, ValueError) as e:
        raise ValueError(
            f"{context} must define 'virtual_fx' as a numeric value, but got: {cfg['virtual_fx']!r}."
        ) from e
    # 'not > 0' (rather than '<= 0') also rejects NaN.
    if not virtual_fx > 0:
        raise ValueError(f"{context} must define a positive 'virtual_fx', but got: {cfg['virtual_fx']!r}.")

    try:
        float(cfg["virtual_camera_prob"])
    except (TypeError, ValueError) as e:
        raise ValueError(
            f"{context} must define 'virtual_camera_prob' as a numeric value, but got: {cfg['virtual_camera_prob']!r}."
        ) from e

    crop_center_mode = cfg["crop_center_mode"]
    # The isinstance check keeps unhashable values (e.g. a list) from raising TypeError in the set lookup.
    if not isinstance(crop_center_mode, str) or crop_center_mode not in GROUND3D_VALID_CROP_CENTER_MODES:
        valid_modes = ", ".join(sorted(GROUND3D_VALID_CROP_CENTER_MODES))
        raise ValueError(
            f"{context} must define 'crop_center_mode' as one of {{{valid_modes}}}, but got: {crop_center_mode!r}."
        )
|
||
|
||
|
||
def build_run_name(base_name: str, resume: str, world_size: int, rank: int) -> str:
    """Build a timestamped run name that stays consistent across distributed ranks.

    Args:
        base_name: Experiment name used as the prefix (or as the fallback when resuming).
        resume: Checkpoint path to resume from; empty when starting a fresh run.
        world_size: Number of distributed processes, forwarded to the timestamp helper.
        rank: Rank of this process, forwarded to the timestamp helper.

    Returns:
        When resuming from ``<run>/<subdir>/<file>.pt``, the name of the ``<run>`` directory;
        ``base_name`` when resuming from a path that does not yield a usable directory name;
        otherwise ``"{base_name}_{timestamp}"``.
    """
    if resume:
        resume_path = Path(resume)
        if resume_path.suffix == ".pt" and len(resume_path.parents) >= 2:
            run_dir_name = resume_path.parents[1].name
            # Short paths such as 'weights/last.pt' or '/a/last.pt' have '.' or '/' as the
            # grandparent, whose name is '' (or '..' for '../x/last.pt'); fall back to
            # base_name rather than returning an empty or invalid run name.
            if run_dir_name and run_dir_name != "..":
                return run_dir_name
        return base_name

    timestamp = get_distributed_run_timestamp(base_name, world_size, rank)
    return f"{base_name}_{timestamp}"
|
||
|
||
|
||
def resolve_ground3d_model(model_arg: str) -> Tuple[str, Optional[str]]:
    """Map the --model argument to a model config path and an optional explicit scale.

    A shorthand such as 'yolo26s' resolves to the shared 3D YAML together with its scale letter.
    Any other value is treated as a path (with '~' expanded) and carries no explicit scale.
    """
    candidate = model_arg.strip()
    is_shorthand = re.fullmatch(r"yolo26[nslmx]", candidate) is not None
    if not is_shorthand:
        return str(Path(candidate).expanduser()), None
    return GROUND3D_DEFAULT_MODEL_CFG, GROUND3D_MODEL_VARIANTS[candidate]
|
||
|
||
|
||
def resolve_data_yaml_for_roi(data_path: str, roi_name: Optional[str]) -> Tuple[str, Optional[str]]:
    """Apply the selected ROI preset to a dataset YAML and write the result to a temporary file.

    Args:
        data_path: Dataset path given on the command line ('~' is expanded).
        roi_name: Requested ROI preset name; when falsy, the YAML's 'default_roi' is used.

    Returns:
        ``(path, selected_roi)``. When ``data_path`` is not a .yaml/.yml file, or no preset is
        selected, the expanded input path and ``None`` are returned. Otherwise the path of the
        generated YAML and the name of the applied preset are returned.

    Raises:
        ValueError: If a preset is requested for a non-YAML path, the YAML has no usable
            'roi_configs', the preset is unknown or not a mapping, or the merged config fails
            Ground3D validation.
    """
    data_path = str(Path(data_path).expanduser())
    source_path = Path(data_path)

    # Non-YAML inputs are passed through untouched; a preset makes no sense for them.
    if source_path.suffix not in {".yaml", ".yml"}:
        if roi_name:
            raise ValueError("--roi requires --data to point to a dataset YAML file.")
        return data_path, None

    data_cfg = YAML.load(data_path)
    roi_configs = data_cfg.get("roi_configs")
    selected_roi = roi_name or data_cfg.get("default_roi")
    if selected_roi is None:
        return data_path, None

    if not (isinstance(roi_configs, dict) and roi_configs):
        raise ValueError(f"Dataset YAML '{data_path}' does not define any 'roi_configs' presets.")
    if selected_roi not in roi_configs:
        available = ", ".join(sorted(roi_configs))
        raise ValueError(
            f"Unknown ROI preset '{selected_roi}' for dataset '{data_path}'. Available presets: {available}."
        )

    roi_overrides = roi_configs[selected_roi]
    if not isinstance(roi_overrides, dict):
        raise ValueError(f"ROI preset '{selected_roi}' in '{data_path}' must be a mapping.")

    # Start from the dataset config without the preset bookkeeping keys, then layer the preset on top.
    resolved_cfg: dict[str, Any] = {
        cfg_key: cfg_value
        for cfg_key, cfg_value in data_cfg.items()
        if cfg_key not in ("default_roi", "roi_configs")
    }
    resolved_cfg.update(roi_overrides)
    _validate_ground3d_config(resolved_cfg, f"Resolved dataset YAML '{data_path}' with ROI preset '{selected_roi}'")

    # PID plus a random hex suffix keeps concurrently started processes from sharing a file name.
    file_name = f"{source_path.stem}.{selected_roi}.{os.getpid()}.{uuid4().hex}{source_path.suffix}"
    resolved_path = Path(tempfile.gettempdir()) / "ultralytics_roi_configs" / file_name

    header_lines = [
        "# Auto-generated by train_mono3d.py\n",
        f"# Source dataset YAML: {source_path.resolve()}\n",
        f"# Selected ROI preset: {selected_roi}\n\n",
    ]
    YAML.save(file=resolved_path, data=resolved_cfg, header="".join(header_lines))
    return str(resolved_path), selected_roi
|
||
|
||
|
||
def parse_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed options. The boolean options ``--val``, ``--plots`` and
        ``--overlap_mask`` default to True and can be switched off with their ``--no-<name>`` form.
    """
    # 'description' is the title shown at the top of the --help output.
    parser = argparse.ArgumentParser(description="Train YOLO26 on Ground Joint 2D+3D Detection Dataset")

    # Model and data
    # e.g. --model ultralytics/cfg/models/26/yolo26-3d.yaml  or  --model yolo26s
    parser.add_argument(
        "--model", type=str, default=GROUND3D_DEFAULT_MODEL_CFG,
        help="Model config YAML path or shorthand variant (yolo26n/yolo26s/yolo26m/yolo26l/yolo26x).",
    )
    parser.add_argument("--data", type=str, default="ultralytics/cfg/datasets/mono3d_ground.yaml", help="Dataset YAML path")
    parser.add_argument(
        "--roi",
        type=str,
        default=None,
        help="Named ROI preset from the dataset YAML 'roi_configs' section (for example: roi0 or roi1)",
    )

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch", type=int, default=16, help="Batch size per GPU")
    parser.add_argument("--imgsz", type=str, default="704,352", help="Image size (int or w,h e.g. 704,352)")
    parser.add_argument("--device", type=str, default="0,1,2,3", help="Device to use (e.g., 0 or 0,1,2,3 or cpu)")

    # Optimizer
    parser.add_argument("--optimizer", type=str, default="auto", choices=["SGD", "Adam", "AdamW", "auto"], help="Optimizer")
    parser.add_argument("--lr0", type=float, default=0.01, help="Initial learning rate")
    parser.add_argument("--lrf", type=float, default=0.01, help="Final learning rate factor")
    parser.add_argument("--momentum", type=float, default=0.937, help="SGD momentum/Adam beta1")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="Optimizer weight decay")
    parser.add_argument("--warmup_epochs", type=float, default=3.0, help="Warmup epochs (2D-only stage if loss_3d_warmup_epochs is not set)")
    parser.add_argument("--warmup_momentum", type=float, default=0.8, help="Initial momentum during warmup")
    parser.add_argument("--warmup_bias_lr", type=float, default=0.1, help="Initial bias learning rate during warmup")
    parser.add_argument("--e2e_o2m_start", type=float, default=0.8, help="Initial one-to-many loss weight for end-to-end training")
    parser.add_argument("--e2e_o2m_final", type=float, default=0.1, help="Final one-to-many loss weight after decay")
    parser.add_argument(
        "--e2e_o2m_decay_epochs",
        type=float,
        default=None,
        help="Epochs used to decay one-to-many weight to the final value; defaults to epochs - 1",
    )
    parser.add_argument("--loss_3d_warmup_epochs", type=float, default=None, help="Epochs before enabling 3D loss; defaults to warmup_epochs")
    parser.add_argument("--loss_3d_ramp_epochs", type=float, default=10.0, help="Epochs used to ramp 3D loss weight to its max value")
    parser.add_argument("--loss_3d_weight_max", type=float, default=0.1, help="Maximum 3D loss weight multiplier")
    parser.add_argument(
        "--face_visibility_score_thresh",
        type=float,
        default=None,
        help="Override visible-face score threshold used by 3D face supervision and visible-face metrics",
    )

    # Loss weights
    parser.add_argument("--box", type=float, default=7.5, help="Box loss gain")
    parser.add_argument("--cls", type=float, default=0.5, help="Class loss gain")
    # NOTE(review): the original author notes that yolo26-3d.yaml currently sets reg_max=1, so DFL
    # is effectively weakened/disabled; the option is kept for compatibility. Confirm against the YAML.
    parser.add_argument("--dfl", type=float, default=1.5, help="DFL loss gain")

    # Augmentation
    parser.add_argument("--hsv_h", type=float, default=0.015, help="HSV-Hue augmentation")
    parser.add_argument("--hsv_s", type=float, default=0.7, help="HSV-Saturation augmentation")
    parser.add_argument("--hsv_v", type=float, default=0.4, help="HSV-Value augmentation")
    parser.add_argument("--degrees", type=float, default=0.0, help="Rotation augmentation (degrees)")
    parser.add_argument("--translate", type=float, default=0.1, help="Translation augmentation")
    parser.add_argument("--scale", type=float, default=0.5, help="Scale augmentation")
    parser.add_argument("--shear", type=float, default=0.0, help="Shear augmentation (degrees)")
    parser.add_argument("--perspective", type=float, default=0.0, help="Perspective augmentation")
    parser.add_argument("--flipud", type=float, default=0.0, help="Vertical flip probability")
    parser.add_argument("--fliplr", type=float, default=0.5, help="Horizontal flip probability")
    parser.add_argument("--mosaic", type=float, default=1.0, help="Mosaic augmentation probability")
    parser.add_argument("--mixup", type=float, default=0.0, help="Mixup augmentation probability")
    parser.add_argument("--copy_paste", type=float, default=0.0, help="Copy-paste augmentation probability")

    # Training settings
    # Early stopping: wait at most this many epochs without validation improvement.
    parser.add_argument("--patience", type=int, default=100, help="Early stopping patience")
    parser.add_argument("--save_period", type=int, default=1, help="Save checkpoint every x epochs (default: every epoch, -1 to disable)")
    # Whether (and where) to cache images.
    parser.add_argument("--cache", type=str, default="", choices=["", "ram", "disk"], help="Cache images")
    parser.add_argument("--workers", type=int, default=16, help="Number of dataloader workers")
    # Project directory under which training results are saved.
    parser.add_argument("--project", type=str, default=None, help="Project directory")
    parser.add_argument("--name", type=str, default="train_mono3d", help="Experiment name")
    # When given, results go straight into this directory instead of the default project/name logic.
    parser.add_argument(
        "--exp_dir", type=str, default="",
        help="Exact experiment directory to use. Bypasses distributed run-name timestamp sync.",
    )
    parser.add_argument("--exist_ok", action="store_true", help="Overwrite existing experiment")
    parser.add_argument("--pretrained", type=str, default="", help="Pretrained 2D weights path (e.g. yolo26s-pretrain.pt). Scale extracted from filename.")
    parser.add_argument("--resume", type=str, default="", help="Resume training from checkpoint")
    parser.add_argument(
        "--strict-backbone-pretrained",
        dest="strict_backbone_pretrained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Fail startup if any backbone parameter tensor can not be transferred from --pretrained.",
    )
    parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
    # e.g. --fraction 0.1 uses only 10% of the data, which is handy for quick debugging.
    parser.add_argument("--fraction", type=float, default=1.0, help="Dataset fraction to use")
    parser.add_argument("--profile", action="store_true", help="Profile ONNX and TensorRT speeds")
    parser.add_argument("--freeze", type=int, default=None, help="Freeze first n layers")
    parser.add_argument("--multi_scale", type=float, default=0.0, help="Multi-scale training range")

    # Validation
    # BooleanOptionalAction keeps '--val' / '--plots' working and adds '--no-val' / '--no-plots';
    # with the previous store_true + default=True these options could never be turned off.
    parser.add_argument("--val", action=argparse.BooleanOptionalAction, default=True, help="Validate during training")
    # Whether to save training/validation plots.
    parser.add_argument("--plots", action=argparse.BooleanOptionalAction, default=True, help="Save plots during training")
    # NOTE(review): per the original author's note, when a batch carries 'camera_mode', validation
    # can skip virtual-camera samples and compute metrics on ROI samples only.
    parser.add_argument("--roi_metrics_only", action="store_true", help="Compute validation metrics using ROI samples only")

    # Advanced
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--deterministic", action="store_true", help="Use deterministic algorithms")
    # Train as a single-class problem (all classes merged into one).
    parser.add_argument("--single_cls", action="store_true", help="Train as single-class dataset")
    parser.add_argument("--rect", action="store_true", help="Rectangular training")
    parser.add_argument("--cos_lr", action="store_true", help="Use cosine learning rate scheduler")
    parser.add_argument("--close_mosaic", type=int, default=10, help="Disable mosaic augmentation for final epochs")
    parser.add_argument("--nbs", type=int, default=64, help="Nominal batch size")
    # Same fix as --val: allow '--no-overlap_mask' to switch the default-True flag off.
    parser.add_argument("--overlap_mask", action=argparse.BooleanOptionalAction, default=True, help="Masks should overlap during training")
    parser.add_argument("--mask_ratio", type=int, default=4, help="Mask downsample ratio")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate")

    return parser.parse_args()
|
||
|
||
|
||
def _parse_imgsz(imgsz_arg: str) -> "int | list[int]":
    """Convert the --imgsz string into an int ("640") or a two-item list ("704,352" -> [704, 352])."""
    try:
        dims = [int(part) for part in str(imgsz_arg).split(",")]
    except ValueError as e:
        raise ValueError(
            f"--imgsz must be an int or a 'w,h' pair of ints (e.g. 640 or 704,352), but got: {imgsz_arg!r}"
        ) from e
    if len(dims) > 2 or any(dim <= 0 for dim in dims):
        raise ValueError(
            f"--imgsz must be one positive int or a 'w,h' pair of positive ints, but got: {imgsz_arg!r}"
        )
    return dims if len(dims) == 2 else dims[0]


@record  # In DDP/distributed runs, records a failing worker's exception so the faulty rank is easier to locate.
def main():
    """Run Ground3D training end to end.

    Parses the command line, resolves the model / dataset / ROI configuration, assembles the
    Ultralytics trainer overrides, logs the effective configuration, and launches
    ``Ground3DDetectionTrainer``.

    Raises:
        ValueError: If --imgsz is malformed, --exp_dir is combined with --resume, or the
            dataset / ROI configuration is invalid.
    """
    args = parse_args()

    # Print banner
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Ground Joint 2D+3D Detection Training with YOLO26"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))

    # Confirm TensorBoard is enabled
    LOGGER.info(colorstr("bright_green", "TensorBoard logging enabled"))

    # Check for distributed training (these variables are set by the distributed launcher)
    world_size = int(os.getenv("WORLD_SIZE", 1))
    rank = int(os.getenv("RANK", -1))
    local_rank = int(os.getenv("LOCAL_RANK", -1))

    if world_size > 1:
        LOGGER.info(f"Distributed training: world_size={world_size}, rank={rank}, local_rank={local_rank}")
    else:
        LOGGER.info("Single GPU training")

    # For 3D training: --model can be the 3D config yaml or a shorthand variant like "yolo26s".
    pretrained = args.pretrained if args.pretrained else False
    # For a shorthand such as "yolo26s" this yields the shared 3D YAML path plus the scale "s".
    resolved_model_path, explicit_model_scale = resolve_ground3d_model(args.model)
    # Warn when pretrained weights are given but neither --model nor the checkpoint file name
    # states a YOLO26 scale.
    if pretrained and explicit_model_scale is None and not re.fullmatch(r"yolo26[nslmx].*", Path(str(pretrained)).stem):
        LOGGER.warning(
            "Pretrained checkpoint name does not encode a YOLO26 scale. Pass '--model yolo26n/yolo26s/...' "
            "to lock the 3D architecture scale explicitly."
        )

    # Parse imgsz: "704,352" -> [704, 352], "640" -> 640 (malformed values raise a descriptive ValueError)
    imgsz = _parse_imgsz(args.imgsz)

    # When resuming, the run directory is already determined by the checkpoint, so --exp_dir is rejected.
    if args.resume and args.exp_dir:
        raise ValueError("--exp_dir cannot be used with --resume because resume already determines the run directory.")

    exp_dir = str(Path(args.exp_dir).expanduser()) if args.exp_dir else ""
    # With --exp_dir the run name is that directory's name; otherwise build_run_name() derives it
    # (timestamped, and shared by all distributed ranks).
    run_name = Path(exp_dir).name if exp_dir else build_run_name(args.name, args.resume, world_size, rank)
    # e.g. "--data .../mono3d_ground.yaml --roi roi1" -> (path of the generated YAML, "roi1")
    resolved_data_path, selected_roi = resolve_data_yaml_for_roi(args.data, args.roi)

    # Build overrides dictionary
    overrides = {
        "model": resolved_model_path,
        "data": resolved_data_path,
        "epochs": args.epochs,
        "batch": args.batch,
        "imgsz": imgsz,
        "device": args.device,
        "optimizer": args.optimizer,
        "lr0": args.lr0,
        "lrf": args.lrf,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
        "warmup_epochs": args.warmup_epochs,
        "warmup_momentum": args.warmup_momentum,
        "warmup_bias_lr": args.warmup_bias_lr,
        "e2e_o2m_start": args.e2e_o2m_start,  # one-to-many loss weight for end-to-end training (default 0.8 -> 0.1)
        "e2e_o2m_final": args.e2e_o2m_final,
        "loss_3d_ramp_epochs": args.loss_3d_ramp_epochs,  # epochs used to ramp the 3D loss weight up to its max (default 10)
        "loss_3d_weight_max": args.loss_3d_weight_max,  # default 0.1
        "box": args.box,
        "cls": args.cls,
        "dfl": args.dfl,
        "hsv_h": args.hsv_h,
        "hsv_s": args.hsv_s,
        "hsv_v": args.hsv_v,
        "degrees": args.degrees,
        "translate": args.translate,
        "scale": args.scale,
        "shear": args.shear,
        "perspective": args.perspective,
        "flipud": args.flipud,
        "fliplr": args.fliplr,
        "mosaic": args.mosaic,
        "mixup": args.mixup,
        "copy_paste": args.copy_paste,
        "patience": args.patience,  # early-stopping patience in epochs
        "save_period": args.save_period,  # checkpoint interval in epochs
        "cache": args.cache if args.cache else False,
        "workers": args.workers,
        "project": args.project,
        "name": run_name,
        "exist_ok": args.exist_ok,
        "pretrained": pretrained,
        "amp": args.amp,
        "fraction": args.fraction,
        "profile": args.profile,
        "multi_scale": args.multi_scale,
        "val": args.val,
        "plots": args.plots,
        "roi_metrics_only": args.roi_metrics_only,
        "seed": args.seed,
        "deterministic": args.deterministic,
        "single_cls": args.single_cls,
        "rect": args.rect,
        "cos_lr": args.cos_lr,
        "close_mosaic": args.close_mosaic,
        "nbs": args.nbs,
        "overlap_mask": args.overlap_mask,
        "mask_ratio": args.mask_ratio,
        "dropout": args.dropout,
    }

    # Optional settings are only forwarded when given, so the trainer's own defaults apply otherwise.
    if args.freeze is not None:
        overrides["freeze"] = args.freeze

    if args.loss_3d_warmup_epochs is not None:
        overrides["loss_3d_warmup_epochs"] = args.loss_3d_warmup_epochs
    if args.e2e_o2m_decay_epochs is not None:
        overrides["e2e_o2m_decay_epochs"] = args.e2e_o2m_decay_epochs
    if args.face_visibility_score_thresh is not None:
        overrides["face_visibility_score_thresh"] = args.face_visibility_score_thresh

    # Add resume if specified
    if args.resume:
        overrides["resume"] = args.resume
    if exp_dir:
        overrides["save_dir"] = exp_dir

    # Print configuration
    LOGGER.info(colorstr("bright_cyan", "\nTraining Configuration:"))
    LOGGER.info(f" Model arg: {args.model}")
    LOGGER.info(f" Model config: {resolved_model_path}")
    LOGGER.info(
        " Model scale: "
        f"{explicit_model_scale if explicit_model_scale else 'infer from pretrained filename or YAML default'}"
    )
    LOGGER.info(f" Pretrained weights: {pretrained if pretrained else 'None'}")
    LOGGER.info(f" Strict pretrained backbone load: {args.strict_backbone_pretrained}")
    LOGGER.info(f" Data: {args.data}")
    if selected_roi:
        LOGGER.info(f" ROI preset: {selected_roi}")
        LOGGER.info(f" Resolved data YAML: {resolved_data_path}")
    LOGGER.info(f" Epochs: {args.epochs}")
    LOGGER.info(f" Batch size: {args.batch}")
    LOGGER.info(f" Image size: {args.imgsz}")
    LOGGER.info(f" Device: {args.device if args.device else 'auto'}")
    LOGGER.info(f" Workers: {args.workers}")
    LOGGER.info(f" Cache: {args.cache if args.cache else 'False'}")
    LOGGER.info(f" Run name: {run_name}")
    LOGGER.info(f" Checkpoints: every {args.save_period} epoch(s)" if args.save_period > 0 else " Checkpoints: last.pt and best.pt only")
    if exp_dir:
        LOGGER.info(f" Experiment dir: {exp_dir}")
    LOGGER.info(f" AMP: {args.amp}")
    LOGGER.info(f" Optimizer: {args.optimizer}")
    LOGGER.info(f" Learning rate: {args.lr0} -> {args.lr0 * args.lrf}")
    LOGGER.info(f" Warmup: epochs={args.warmup_epochs}, momentum={args.warmup_momentum}, bias_lr={args.warmup_bias_lr}")
    LOGGER.info(
        " E2E one-to-many weight: "
        f"start={args.e2e_o2m_start}, final={args.e2e_o2m_final}, "
        f"decay_epochs={args.e2e_o2m_decay_epochs if args.e2e_o2m_decay_epochs is not None else 'epochs-1'}"
    )
    LOGGER.info(f" Loss weights: box={args.box}, cls={args.cls}, dfl={args.dfl}")
    loss_3d_warmup = args.loss_3d_warmup_epochs if args.loss_3d_warmup_epochs is not None else args.warmup_epochs
    LOGGER.info(
        f" 3D loss: warmup={loss_3d_warmup} epochs, ramp={args.loss_3d_ramp_epochs} epochs, max={args.loss_3d_weight_max}"
    )
    LOGGER.info(
        " Face visibility score threshold: "
        f"{args.face_visibility_score_thresh if args.face_visibility_score_thresh is not None else 'default.yaml'}"
    )
    LOGGER.info("")

    # Initialize trainer
    LOGGER.info(colorstr("bright_green", "Initializing Ground3DDetectionTrainer..."))
    trainer = Ground3DDetectionTrainer(overrides=overrides)
    # NOTE(review): per the original author's notes, Ground3DDetectionTrainer.get_model() reads
    # 'explicit_model_scale' to pick the model scale (n/s/m/l/x), and the pretrained-weight loader
    # reads 'strict_backbone_pretrained'; neither consumer is visible in this file.
    trainer.explicit_model_scale = explicit_model_scale
    trainer.strict_backbone_pretrained = args.strict_backbone_pretrained

    # Start training
    LOGGER.info(colorstr("bright_green", "Starting training...\n"))
    trainer.train()

    # Print completion message
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Training completed!"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))

    # Print results location
    if hasattr(trainer, "save_dir"):
        LOGGER.info(f"Results saved to: {trainer.save_dir}")
        LOGGER.info(f"Best weights: {trainer.save_dir / 'weights' / 'best.pt'}")
        LOGGER.info(f"Last weights: {trainer.save_dir / 'weights' / 'last.pt'}")
|
||
|
||
|
||
# Script entry point: start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
|