Files
yolov26_3d/train_mono3d.py
2026-06-24 09:35:46 +08:00

470 lines
25 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Ground Joint 2D+3D Detection Training Script for YOLO26
Supports single-GPU and multi-GPU distributed training with 3D loss ramping.
Usage:
# Single GPU with pretrained 2D weights
python train_mono3d.py --model yolo26s --pretrained yolo26s-pretrain.pt --epochs 100 --device 0
# Multi-GPU (DDP) with pretrained 2D weights
python -m torch.distributed.run --nproc_per_node 4 train_mono3d.py --model yolo26s --pretrained yolo26s-pretrain.pt --epochs 100
# Train from scratch (no pretrained weights)
python train_mono3d.py --epochs 100 --device 0
# Resume from checkpoint
python train_mono3d.py --resume runs/detect/train_mono3d/weights/last.pt
"""
import argparse
import os
import re
import sys
import tempfile
from pathlib import Path
from typing import Any, Optional, Tuple
from uuid import uuid4
from torch.distributed.elastic.multiprocessing.errors import record
# Point TORCH_CACHE_DIR at a folder relative to the current working directory
# and create it up front. This runs at import time, before the ultralytics
# imports below.
# NOTE(review): which component reads TORCH_CACHE_DIR is not visible in this file — confirm.
os.environ['TORCH_CACHE_DIR'] = './torch_cache'
os.makedirs(os.environ['TORCH_CACHE_DIR'], exist_ok=True)
# Add ultralytics to path if needed: the directory containing this script is
# appended to sys.path so the imports that follow can resolve from it.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from ultralytics import settings
from ultralytics.models.yolo.detect.train import Ground3DDetectionTrainer
from ultralytics.utils import LOGGER, YAML, colorstr
from ultralytics.utils.dist import get_distributed_run_timestamp
# Enable TensorBoard logging through the Ultralytics settings object.
# This is a module-level side effect applied as soon as the script is imported.
settings.update({"tensorboard": True})
# Fields a resolved Ground3D dataset config must define (present and not None).
GROUND3D_REQUIRED_CONFIG_KEYS = (
    "path",
    "ori_img_size",
    "roi",
    "virtual_fx",
    "virtual_camera_prob",
    "crop_center_mode",
)
# Accepted values for the 'crop_center_mode' config field.
GROUND3D_VALID_CROP_CENTER_MODES = {"cxvy", "vxvy"}
# Shorthand model names mapped to their scale letter, e.g. "yolo26s" -> "s".
GROUND3D_MODEL_VARIANTS = {"yolo26" + letter: letter for letter in ("n", "s", "l", "m", "x")}
# Shared 3D model YAML used as the default --model and for every shorthand variant.
GROUND3D_DEFAULT_MODEL_CFG = "ultralytics/cfg/models/26/yolo26-3d.yaml"
def _validate_ground3d_pair(cfg: dict[str, Any], key: str, context: str) -> None:
"""Validate that a Ground3D size-like config field is an integer pair."""
value = cfg.get(key)
if not isinstance(value, (list, tuple)) or len(value) != 2:
raise ValueError(f"{context} must define '{key}' as a 2-item list/tuple, but got: {value!r}")
if any(int(v) <= 0 for v in value):
raise ValueError(f"{context} must define positive values for '{key}', but got: {value!r}")
def _validate_ground3d_config(cfg: dict[str, Any], context: str) -> None:
    """Validate required Ground3D config fields before training starts.

    Checks, in order: all required keys are present and not None, 'path' is non-empty,
    'ori_img_size' and 'roi' are pairs of positive integers, 'virtual_fx' is a positive
    number, 'virtual_camera_prob' is numeric, and 'crop_center_mode' is a known mode.

    Args:
        cfg: Resolved dataset config mapping.
        context: Human-readable description of the config, used as the error-message prefix.

    Raises:
        ValueError: On the first check that fails.
    """
    missing = [key for key in GROUND3D_REQUIRED_CONFIG_KEYS if key not in cfg or cfg[key] is None]
    if missing:
        missing_str = ", ".join(missing)
        raise ValueError(f"{context} is missing required Ground3D config field(s): {missing_str}.")
    if not str(cfg["path"]).strip():
        raise ValueError(f"{context} must define a non-empty 'path'.")
    _validate_ground3d_pair(cfg, "ori_img_size", context)
    _validate_ground3d_pair(cfg, "roi", context)
    # Guard the conversion the same way as 'virtual_camera_prob' below, so a non-numeric
    # value is reported with context rather than as a bare TypeError/ValueError.
    try:
        virtual_fx = float(cfg["virtual_fx"])
    except (TypeError, ValueError) as e:
        raise ValueError(
            f"{context} must define 'virtual_fx' as a numeric value, but got: {cfg['virtual_fx']!r}."
        ) from e
    # 'not (x > 0)' also rejects NaN, which a plain 'x <= 0' comparison would let through.
    if not virtual_fx > 0:
        raise ValueError(f"{context} must define a positive 'virtual_fx', but got: {cfg['virtual_fx']!r}.")
    try:
        float(cfg["virtual_camera_prob"])
    except (TypeError, ValueError) as e:
        raise ValueError(
            f"{context} must define 'virtual_camera_prob' as a numeric value, but got: {cfg['virtual_camera_prob']!r}."
        ) from e
    crop_center_mode = cfg["crop_center_mode"]
    # The isinstance check keeps unhashable values (e.g. a YAML list) from raising
    # TypeError inside the set membership test.
    if not isinstance(crop_center_mode, str) or crop_center_mode not in GROUND3D_VALID_CROP_CENTER_MODES:
        valid_modes = ", ".join(sorted(GROUND3D_VALID_CROP_CENTER_MODES))
        raise ValueError(
            f"{context} must define 'crop_center_mode' as one of {{{valid_modes}}}, but got: {crop_center_mode!r}."
        )
def build_run_name(base_name: str, resume: str, world_size: int, rank: int) -> str:
    """Build the run name used for the experiment directory.

    When resuming, the name is taken from the checkpoint path: for
    ``<run>/weights/last.pt`` the result is ``<run>``. If the checkpoint path is too
    shallow to contain a run directory (``last.pt``, ``weights/last.pt``,
    ``/weights/last.pt``, ``../weights/last.pt``) or does not end in ``.pt``,
    ``base_name`` is returned unchanged.

    For a fresh run the name is ``f"{base_name}_{timestamp}"``, with the timestamp
    obtained from ``get_distributed_run_timestamp(base_name, world_size, rank)``.
    NOTE(review): that helper presumably synchronizes the timestamp across ranks;
    its implementation is not visible here — confirm.

    Args:
        base_name: Experiment name supplied by the user.
        resume: Checkpoint path to resume from, or an empty string for a fresh run.
        world_size: Number of distributed processes.
        rank: Rank of the current process.

    Returns:
        The run name; never an empty string as long as ``base_name`` is non-empty.
    """
    if resume:
        resume_path = Path(resume)
        if resume_path.suffix == ".pt" and len(resume_path.parents) >= 2:
            run_dir_name = resume_path.parents[1].name
            # parents[1] can be '.', '/' (both have an empty name) or '..' for shallow
            # paths; none of those is a usable run name, so fall back to base_name.
            if run_dir_name and run_dir_name != "..":
                return run_dir_name
        return base_name
    timestamp = get_distributed_run_timestamp(base_name, world_size, rank)
    return f"{base_name}_{timestamp}"
def resolve_ground3d_model(model_arg: str) -> Tuple[str, Optional[str]]:
    """Turn the --model argument into a ``(config_path, scale)`` pair.

    A shorthand such as ``"yolo26s"`` maps to the shared 3D model YAML together with
    its scale letter. Anything else is treated as a path: it is returned with ``~``
    expanded and a scale of ``None``.

    Args:
        model_arg: Raw --model value; surrounding whitespace is ignored.

    Returns:
        Tuple of the model config path and the explicit scale letter (or ``None``).
    """
    name = model_arg.strip()
    shorthand = re.fullmatch(r"yolo26[nslmx]", name)
    if shorthand is None:
        # Not a shorthand variant: treat the argument as a filesystem path.
        config_path = Path(name).expanduser()
        return str(config_path), None
    return GROUND3D_DEFAULT_MODEL_CFG, GROUND3D_MODEL_VARIANTS[name]
def resolve_data_yaml_for_roi(data_path: str, roi_name: Optional[str]) -> Tuple[str, Optional[str]]:
    """Write a copy of the dataset YAML with one ROI preset merged in.

    The preset is ``roi_name`` if given, otherwise the YAML's ``default_roi``. Its
    entries from ``roi_configs`` override the top-level keys; ``default_roi`` and
    ``roi_configs`` themselves are dropped from the result. The merged config is
    validated and saved under ``<tempdir>/ultralytics_roi_configs`` with a file name
    made unique by the process id and a random UUID. This function does not delete
    the generated file.

    Args:
        data_path: Dataset YAML path (``~`` is expanded) or any other dataset reference.
        roi_name: Name of the ROI preset to apply, or ``None``/empty to use the default.

    Returns:
        ``(path, selected_roi)``. When ``data_path`` is not a ``.yaml``/``.yml`` file, or
        no preset is selected, the expanded input path and ``None`` are returned.

    Raises:
        ValueError: If ``roi_name`` is given for a non-YAML ``data_path``, the YAML has no
            usable ``roi_configs``, the preset is unknown or not a mapping, or the merged
            config fails Ground3D validation.
    """
    data_path = str(Path(data_path).expanduser())
    source_path = Path(data_path)

    # Non-YAML inputs are passed through untouched; a preset cannot be applied to them.
    if source_path.suffix not in {".yaml", ".yml"}:
        if not roi_name:
            return data_path, None
        raise ValueError("--roi requires --data to point to a dataset YAML file.")

    data_cfg = YAML.load(data_path)
    presets = data_cfg.get("roi_configs")
    selected_roi = roi_name or data_cfg.get("default_roi")
    if selected_roi is None:
        return data_path, None

    if not (isinstance(presets, dict) and presets):
        raise ValueError(f"Dataset YAML '{data_path}' does not define any 'roi_configs' presets.")
    if selected_roi not in presets:
        available = ", ".join(sorted(presets))
        raise ValueError(
            f"Unknown ROI preset '{selected_roi}' for dataset '{data_path}'. Available presets: {available}."
        )
    preset_values = presets[selected_roi]
    if not isinstance(preset_values, dict):
        raise ValueError(f"ROI preset '{selected_roi}' in '{data_path}' must be a mapping.")

    # Start from the top-level config without the preset bookkeeping keys, then let the
    # preset's values override matching keys.
    merged_cfg: dict[str, Any] = {
        field: setting for field, setting in dict(data_cfg).items() if field not in ("default_roi", "roi_configs")
    }
    merged_cfg.update(preset_values)
    _validate_ground3d_config(merged_cfg, f"Resolved dataset YAML '{data_path}' with ROI preset '{selected_roi}'")

    output_dir = Path(tempfile.gettempdir()) / "ultralytics_roi_configs"
    output_path = output_dir / f"{source_path.stem}.{selected_roi}.{os.getpid()}.{uuid4().hex}{source_path.suffix}"
    banner = "".join(
        (
            "# Auto-generated by train_mono3d.py\n",
            f"# Source dataset YAML: {source_path.resolve()}\n",
            f"# Selected ROI preset: {selected_roi}\n\n",
        )
    )
    YAML.save(file=output_path, data=merged_cfg, header=banner)
    return str(output_path), selected_roi
def parse_args() -> argparse.Namespace:
    """Parse command line arguments for Ground3D training.

    Boolean options that default to True (``--val``, ``--plots``, ``--overlap_mask``,
    ``--strict-backbone-pretrained``) use ``argparse.BooleanOptionalAction`` so each can
    be switched off with its ``--no-...`` counterpart.

    Returns:
        The parsed arguments, read from ``sys.argv``.
    """
    # 'description' is shown as the header of the --help output.
    parser = argparse.ArgumentParser(description="Train YOLO26 on Ground Joint 2D+3D Detection Dataset")

    # Model and data
    # Examples: --model ultralytics/cfg/models/26/yolo26-3d.yaml   or   --model yolo26s
    parser.add_argument(
        "--model",
        type=str,
        default=GROUND3D_DEFAULT_MODEL_CFG,
        help="Model config YAML path or shorthand variant (yolo26n/yolo26s/yolo26m/yolo26l/yolo26x).",
    )
    parser.add_argument("--data", type=str, default="ultralytics/cfg/datasets/mono3d_ground.yaml", help="Dataset YAML path")
    parser.add_argument(
        "--roi",
        type=str,
        default=None,
        help="Named ROI preset from the dataset YAML 'roi_configs' section (for example: roi0 or roi1)",
    )

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch", type=int, default=16, help="Batch size per GPU")
    parser.add_argument("--imgsz", type=str, default="704,352", help="Image size (int or w,h e.g. 704,352)")
    parser.add_argument("--device", type=str, default="0,1,2,3", help="Device to use (e.g., 0 or 0,1,2,3 or cpu)")

    # Optimizer
    parser.add_argument("--optimizer", type=str, default="auto", choices=["SGD", "Adam", "AdamW", "auto"], help="Optimizer")
    parser.add_argument("--lr0", type=float, default=0.01, help="Initial learning rate")
    parser.add_argument("--lrf", type=float, default=0.01, help="Final learning rate factor")
    parser.add_argument("--momentum", type=float, default=0.937, help="SGD momentum/Adam beta1")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="Optimizer weight decay")
    parser.add_argument("--warmup_epochs", type=float, default=3.0, help="Warmup epochs (2D-only stage if loss_3d_warmup_epochs is not set)")
    parser.add_argument("--warmup_momentum", type=float, default=0.8, help="Initial momentum during warmup")
    parser.add_argument("--warmup_bias_lr", type=float, default=0.1, help="Initial bias learning rate during warmup")
    parser.add_argument("--e2e_o2m_start", type=float, default=0.8, help="Initial one-to-many loss weight for end-to-end training")
    parser.add_argument("--e2e_o2m_final", type=float, default=0.1, help="Final one-to-many loss weight after decay")
    parser.add_argument(
        "--e2e_o2m_decay_epochs",
        type=float,
        default=None,
        help="Epochs used to decay one-to-many weight to the final value; defaults to epochs - 1",
    )
    parser.add_argument("--loss_3d_warmup_epochs", type=float, default=None, help="Epochs before enabling 3D loss; defaults to warmup_epochs")
    parser.add_argument("--loss_3d_ramp_epochs", type=float, default=10.0, help="Epochs used to ramp 3D loss weight to its max value")
    parser.add_argument("--loss_3d_weight_max", type=float, default=0.1, help="Maximum 3D loss weight multiplier")
    parser.add_argument(
        "--face_visibility_score_thresh",
        type=float,
        default=None,
        help="Override visible-face score threshold used by 3D face supervision and visible-face metrics",
    )

    # Loss weights
    parser.add_argument("--box", type=float, default=7.5, help="Box loss gain")
    parser.add_argument("--cls", type=float, default=0.5, help="Class loss gain")
    # DFL loss weight. Per the original author's note, the current yolo26-3d.yaml uses
    # reg_max=1, so DFL is effectively weakened/disabled; the option is kept for compatibility.
    parser.add_argument("--dfl", type=float, default=1.5, help="DFL loss gain")

    # Augmentation
    parser.add_argument("--hsv_h", type=float, default=0.015, help="HSV-Hue augmentation")
    parser.add_argument("--hsv_s", type=float, default=0.7, help="HSV-Saturation augmentation")
    parser.add_argument("--hsv_v", type=float, default=0.4, help="HSV-Value augmentation")
    parser.add_argument("--degrees", type=float, default=0.0, help="Rotation augmentation (degrees)")
    parser.add_argument("--translate", type=float, default=0.1, help="Translation augmentation")
    parser.add_argument("--scale", type=float, default=0.5, help="Scale augmentation")
    parser.add_argument("--shear", type=float, default=0.0, help="Shear augmentation (degrees)")
    parser.add_argument("--perspective", type=float, default=0.0, help="Perspective augmentation")
    parser.add_argument("--flipud", type=float, default=0.0, help="Vertical flip probability")
    parser.add_argument("--fliplr", type=float, default=0.5, help="Horizontal flip probability")
    parser.add_argument("--mosaic", type=float, default=1.0, help="Mosaic augmentation probability")
    parser.add_argument("--mixup", type=float, default=0.0, help="Mixup augmentation probability")
    parser.add_argument("--copy_paste", type=float, default=0.0, help="Copy-paste augmentation probability")

    # Training settings
    # Early-stopping patience: wait at most this many epochs without validation improvement.
    parser.add_argument("--patience", type=int, default=100, help="Early stopping patience")
    parser.add_argument("--save_period", type=int, default=1, help="Save checkpoint every x epochs (default: every epoch, -1 to disable)")
    # Whether to cache images ("" means no caching).
    parser.add_argument("--cache", type=str, default="", choices=["", "ram", "disk"], help="Cache images")
    parser.add_argument("--workers", type=int, default=16, help="Number of dataloader workers")
    # Project directory under which training results are saved.
    parser.add_argument("--project", type=str, default=None, help="Project directory")
    parser.add_argument("--name", type=str, default="train_mono3d", help="Experiment name")
    # When given, results are saved directly in this directory instead of going through
    # the default project/name logic.
    parser.add_argument(
        "--exp_dir",
        type=str,
        default="",
        help="Exact experiment directory to use. Bypasses distributed run-name timestamp sync.",
    )
    parser.add_argument("--exist_ok", action="store_true", help="Overwrite existing experiment")
    parser.add_argument("--pretrained", type=str, default="", help="Pretrained 2D weights path (e.g. yolo26s-pretrain.pt). Scale extracted from filename.")
    parser.add_argument("--resume", type=str, default="", help="Resume training from checkpoint")
    parser.add_argument(
        "--strict-backbone-pretrained",
        dest="strict_backbone_pretrained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Fail startup if any backbone parameter tensor can not be transferred from --pretrained.",
    )
    parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
    # A value below 1.0 trains on a subset of the data, e.g. 0.1 for quick debugging.
    parser.add_argument("--fraction", type=float, default=1.0, help="Dataset fraction to use")
    parser.add_argument("--profile", action="store_true", help="Profile ONNX and TensorRT speeds")
    parser.add_argument("--freeze", type=int, default=None, help="Freeze first n layers")
    parser.add_argument("--multi_scale", type=float, default=0.0, help="Multi-scale training range")

    # Validation
    # BooleanOptionalAction keeps '--val' / '--plots' working and adds '--no-val' /
    # '--no-plots'; with store_true and default=True these could never be disabled.
    parser.add_argument("--val", action=argparse.BooleanOptionalAction, default=True, help="Validate during training")
    parser.add_argument("--plots", action=argparse.BooleanOptionalAction, default=True, help="Save plots during training")
    # Per the original author's note: when a batch carries camera_mode, validation can skip
    # virtual-camera samples and compute metrics over ROI samples only.
    parser.add_argument("--roi_metrics_only", action="store_true", help="Compute validation metrics using ROI samples only")

    # Advanced
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--deterministic", action="store_true", help="Use deterministic algorithms")
    # Train as a single class, i.e. merge all categories into one.
    parser.add_argument("--single_cls", action="store_true", help="Train as single-class dataset")
    parser.add_argument("--rect", action="store_true", help="Rectangular training")
    parser.add_argument("--cos_lr", action="store_true", help="Use cosine learning rate scheduler")
    parser.add_argument("--close_mosaic", type=int, default=10, help="Disable mosaic augmentation for final epochs")
    parser.add_argument("--nbs", type=int, default=64, help="Nominal batch size")
    # Same fix as --val/--plots: allow '--no-overlap_mask' to turn the default-True flag off.
    parser.add_argument("--overlap_mask", action=argparse.BooleanOptionalAction, default=True, help="Masks should overlap during training")
    parser.add_argument("--mask_ratio", type=int, default=4, help="Mask downsample ratio")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate")
    return parser.parse_args()
@record  # In DDP / distributed training, records a more complete traceback when a child process fails, which helps locate the failing rank.
def main():
    """Main training function.

    Flow: parse command line arguments -> resolve model / data / ROI configuration ->
    assemble the Ultralytics trainer overrides -> launch Ground3DDetectionTrainer training.

    Raises:
        ValueError: If --exp_dir is combined with --resume, or if model/data resolution
            rejects the supplied arguments.
    """
    args = parse_args()
    # Print banner
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Ground Joint 2D+3D Detection Training with YOLO26"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))
    # Confirm TensorBoard is enabled
    LOGGER.info(colorstr("bright_green", "TensorBoard logging enabled"))
    # Check for distributed training (environment variables set by the distributed launcher)
    world_size = int(os.getenv("WORLD_SIZE", 1))
    rank = int(os.getenv("RANK", -1))
    local_rank = int(os.getenv("LOCAL_RANK", -1))
    if world_size > 1:
        LOGGER.info(f"Distributed training: world_size={world_size}, rank={rank}, local_rank={local_rank}")
    else:
        LOGGER.info("Single GPU training")
    # For 3D training: --model can be the 3D config yaml or a shorthand variant like "yolo26s".
    pretrained = args.pretrained if args.pretrained else False
    # For a shorthand such as "yolo26s", resolve_ground3d_model() returns the shared 3D YAML
    # (GROUND3D_DEFAULT_MODEL_CFG) together with the explicit scale "s".
    resolved_model_path, explicit_model_scale = resolve_ground3d_model(args.model)
    # Warn when pretrained weights are used but neither --model (no explicit yolo26n/s/l/m/x)
    # nor the pretrained file name reveals the scale.
    if pretrained and explicit_model_scale is None and not re.fullmatch(r"yolo26[nslmx].*", Path(str(pretrained)).stem):
        LOGGER.warning(
            "Pretrained checkpoint name does not encode a YOLO26 scale. Pass '--model yolo26n/yolo26s/...' "
            "to lock the 3D architecture scale explicitly."
        )
    # Parse imgsz: "704,352" -> [704, 352], "640" -> 640
    if isinstance(args.imgsz, str) and "," in args.imgsz:
        imgsz = [int(x) for x in args.imgsz.split(",")]  # --imgsz 704,352 --> imgsz = [704, 352]
    else:
        imgsz = int(args.imgsz)
    # When resuming, the run directory is already determined by the checkpoint, so a
    # manually specified --exp_dir is rejected.
    if args.resume and args.exp_dir:
        raise ValueError("--exp_dir cannot be used with --resume because resume already determines the run directory.")
    exp_dir = str(Path(args.exp_dir).expanduser()) if args.exp_dir else ""
    # With --exp_dir the run name is that directory's name; otherwise build_run_name()
    # derives it (from the resume checkpoint, or base name plus a timestamp shared across ranks).
    run_name = Path(exp_dir).name if exp_dir else build_run_name(args.name, args.resume, world_size, rank)
    # Example: --data ultralytics/cfg/datasets/mono3d_ground.yaml --roi roi1 yields the path of
    # a generated YAML with the preset applied, and selected_roi == "roi1".
    resolved_data_path, selected_roi = resolve_data_yaml_for_roi(args.data, args.roi)
    # Build overrides dictionary
    overrides = {
        "model": resolved_model_path,
        "data": resolved_data_path,
        "epochs": args.epochs,
        "batch": args.batch,
        "imgsz": imgsz,
        "device": args.device,
        "optimizer": args.optimizer,
        "lr0": args.lr0,
        "lrf": args.lrf,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
        "warmup_epochs": args.warmup_epochs,
        "warmup_momentum": args.warmup_momentum,
        "warmup_bias_lr": args.warmup_bias_lr,
        "e2e_o2m_start": args.e2e_o2m_start,  # One-to-many loss weight for YOLO26 end-to-end training (CLI defaults: 0.8 -> 0.1).
        "e2e_o2m_final": args.e2e_o2m_final,
        "loss_3d_ramp_epochs": args.loss_3d_ramp_epochs,  # Per the original note: 3D loss is ramped up towards its max rather than applied at full weight from the start (CLI default 10).
        "loss_3d_weight_max": args.loss_3d_weight_max,  # CLI default 0.1
        "box": args.box,
        "cls": args.cls,
        "dfl": args.dfl,
        "hsv_h": args.hsv_h,
        "hsv_s": args.hsv_s,
        "hsv_v": args.hsv_v,
        "degrees": args.degrees,
        "translate": args.translate,
        "scale": args.scale,
        "shear": args.shear,
        "perspective": args.perspective,
        "flipud": args.flipud,
        "fliplr": args.fliplr,
        "mosaic": args.mosaic,
        "mixup": args.mixup,
        "copy_paste": args.copy_paste,
        "patience": args.patience,  # Early-stopping patience in epochs.
        "save_period": args.save_period,  # Save a checkpoint every this many epochs.
        "cache": args.cache if args.cache else False,
        "workers": args.workers,
        "project": args.project,
        "name": run_name,
        "exist_ok": args.exist_ok,
        "pretrained": pretrained,
        "amp": args.amp,
        "fraction": args.fraction,
        "profile": args.profile,
        "multi_scale": args.multi_scale,
        "val": args.val,
        "plots": args.plots,
        "roi_metrics_only": args.roi_metrics_only,
        "seed": args.seed,
        "deterministic": args.deterministic,
        "single_cls": args.single_cls,
        "rect": args.rect,
        "cos_lr": args.cos_lr,
        "close_mosaic": args.close_mosaic,
        "nbs": args.nbs,
        "overlap_mask": args.overlap_mask,
        "mask_ratio": args.mask_ratio,
        "dropout": args.dropout,
    }
    # Optional overrides are only added when the user supplied them, so the trainer's own
    # defaults apply otherwise.
    # Add freeze if specified
    if args.freeze is not None:
        overrides["freeze"] = args.freeze
    if args.loss_3d_warmup_epochs is not None:
        overrides["loss_3d_warmup_epochs"] = args.loss_3d_warmup_epochs
    if args.e2e_o2m_decay_epochs is not None:
        overrides["e2e_o2m_decay_epochs"] = args.e2e_o2m_decay_epochs
    if args.face_visibility_score_thresh is not None:
        overrides["face_visibility_score_thresh"] = args.face_visibility_score_thresh
    # Add resume if specified
    if args.resume:
        overrides["resume"] = args.resume
    if exp_dir:
        overrides["save_dir"] = exp_dir
    # Print configuration
    LOGGER.info(colorstr("bright_cyan", "\nTraining Configuration:"))
    LOGGER.info(f" Model arg: {args.model}")
    LOGGER.info(f" Model config: {resolved_model_path}")
    LOGGER.info(
        " Model scale: "
        f"{explicit_model_scale if explicit_model_scale else 'infer from pretrained filename or YAML default'}"
    )
    LOGGER.info(f" Pretrained weights: {pretrained if pretrained else 'None'}")
    LOGGER.info(f" Strict pretrained backbone load: {args.strict_backbone_pretrained}")
    LOGGER.info(f" Data: {args.data}")
    if selected_roi:
        LOGGER.info(f" ROI preset: {selected_roi}")
        LOGGER.info(f" Resolved data YAML: {resolved_data_path}")
    LOGGER.info(f" Epochs: {args.epochs}")
    LOGGER.info(f" Batch size: {args.batch}")
    LOGGER.info(f" Image size: {args.imgsz}")
    LOGGER.info(f" Device: {args.device if args.device else 'auto'}")
    LOGGER.info(f" Workers: {args.workers}")
    LOGGER.info(f" Cache: {args.cache if args.cache else 'False'}")
    LOGGER.info(f" Run name: {run_name}")
    LOGGER.info(f" Checkpoints: every {args.save_period} epoch(s)" if args.save_period > 0 else " Checkpoints: last.pt and best.pt only")
    if exp_dir:
        LOGGER.info(f" Experiment dir: {exp_dir}")
    LOGGER.info(f" AMP: {args.amp}")
    LOGGER.info(f" Optimizer: {args.optimizer}")
    LOGGER.info(f" Learning rate: {args.lr0} -> {args.lr0 * args.lrf}")
    LOGGER.info(f" Warmup: epochs={args.warmup_epochs}, momentum={args.warmup_momentum}, bias_lr={args.warmup_bias_lr}")
    LOGGER.info(
        " E2E one-to-many weight: "
        f"start={args.e2e_o2m_start}, final={args.e2e_o2m_final}, "
        f"decay_epochs={args.e2e_o2m_decay_epochs if args.e2e_o2m_decay_epochs is not None else 'epochs-1'}"
    )
    LOGGER.info(f" Loss weights: box={args.box}, cls={args.cls}, dfl={args.dfl}")
    # For display only: mirror the documented fallback of loss_3d_warmup_epochs to warmup_epochs.
    loss_3d_warmup = args.loss_3d_warmup_epochs if args.loss_3d_warmup_epochs is not None else args.warmup_epochs
    LOGGER.info(
        f" 3D loss: warmup={loss_3d_warmup} epochs, ramp={args.loss_3d_ramp_epochs} epochs, max={args.loss_3d_weight_max}"
    )
    LOGGER.info(
        " Face visibility score threshold: "
        f"{args.face_visibility_score_thresh if args.face_visibility_score_thresh is not None else 'default.yaml'}"
    )
    LOGGER.info("")
    # Initialize trainer
    LOGGER.info(colorstr("bright_green", "Initializing Ground3DDetectionTrainer..."))
    trainer = Ground3DDetectionTrainer(overrides=overrides)  # Hands control to the Ultralytics training framework, which performs its own initialization.
    # Attach the explicit model scale to the trainer. Per the original note,
    # Ground3DDetectionTrainer.get_model() later uses it to pick the model size (n/s/m/l/x) — TODO confirm.
    trainer.explicit_model_scale = explicit_model_scale
    # Attach the strict pretrained-load switch to the trainer; per the original note it is
    # consulted later when the pretrained weights are loaded — TODO confirm.
    trainer.strict_backbone_pretrained = args.strict_backbone_pretrained
    # Start training
    LOGGER.info(colorstr("bright_green", "Starting training...\n"))
    trainer.train()
    # Print completion message
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Training completed!"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))
    # Print results location
    if hasattr(trainer, "save_dir"):
        LOGGER.info(f"Results saved to: {trainer.save_dir}")
        LOGGER.info(f"Best weights: {trainer.save_dir / 'weights' / 'best.pt'}")
        LOGGER.info(f"Last weights: {trainer.save_dir / 'weights' / 'last.pt'}")
# Script entry point: start training only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()