Files
yolov26_3d/train_mono3d.py

470 lines
25 KiB
Python
Raw Normal View History

2026-06-24 09:35:46 +08:00
#!/usr/bin/env python3
"""
Ground Joint 2D+3D Detection Training Script for YOLO26
Supports single-GPU and multi-GPU distributed training with 3D loss ramping.
Usage:
# Single GPU with pretrained 2D weights
python train_mono3d.py --model yolo26s --pretrained yolo26s-pretrain.pt --epochs 100 --device 0
# Multi-GPU (DDP) with pretrained 2D weights
python -m torch.distributed.run --nproc_per_node 4 train_mono3d.py --model yolo26s --pretrained yolo26s-pretrain.pt --epochs 100
# Train from scratch (no pretrained weights)
python train_mono3d.py --epochs 100 --device 0
# Resume from checkpoint
python train_mono3d.py --resume runs/detect/train_mono3d/weights/last.pt
"""
import argparse
import os
import re
import sys
import tempfile
from pathlib import Path
from typing import Any, Optional, Tuple
from uuid import uuid4
from torch.distributed.elastic.multiprocessing.errors import record
# Point TORCH_CACHE_DIR at ./torch_cache (relative to the current working directory) and make sure it exists.
# NOTE(review): PyTorch's documented hub cache variable is TORCH_HOME; confirm that TORCH_CACHE_DIR is actually
# read by the ultralytics fork or another dependency before relying on it.
# NOTE(review): this runs at import time, after `torch.distributed` has already been imported above.
os.environ['TORCH_CACHE_DIR'] = './torch_cache'
os.makedirs(os.environ['TORCH_CACHE_DIR'], exist_ok=True)
# Make the directory containing this script importable so the local (forked) `ultralytics` package is found.
# The path is appended, so any earlier sys.path entry providing `ultralytics` takes precedence. When the script
# is run directly, Python already places the script's directory first on sys.path, so this is usually a no-op.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from ultralytics import settings
from ultralytics.models.yolo.detect.train import Ground3DDetectionTrainer
from ultralytics.utils import LOGGER, YAML, colorstr
from ultralytics.utils.dist import get_distributed_run_timestamp
# Enable TensorBoard logging.
# This runs at import time, so it applies to every training run started from this script.
# NOTE(review): ultralytics usually persists `settings` to a user-level settings file, so this may also change
# the default for other ultralytics runs on this machine — confirm.
settings.update({"tensorboard": True})
# Keys that a resolved Ground3D dataset config must provide (each must be present and not None).
GROUND3D_REQUIRED_CONFIG_KEYS = (
    "path",
    "ori_img_size",
    "roi",
    "virtual_fx",
    "virtual_camera_prob",
    "crop_center_mode",
)
# Accepted values for the dataset's `crop_center_mode` field.
GROUND3D_VALID_CROP_CENTER_MODES = {"vxvy", "cxvy"}
# Shorthand model names ("yolo26n" ... "yolo26x") mapped to their single-letter scale.
GROUND3D_MODEL_VARIANTS = {"yolo26" + scale_letter: scale_letter for scale_letter in ("n", "s", "l", "m", "x")}
# Shared 3D architecture YAML; a shorthand variant selects only the scale applied to it.
GROUND3D_DEFAULT_MODEL_CFG = "ultralytics/cfg/models/26/yolo26-3d.yaml"
def _validate_ground3d_pair(cfg: dict[str, Any], key: str, context: str) -> None:
"""Validate that a Ground3D size-like config field is an integer pair."""
value = cfg.get(key)
if not isinstance(value, (list, tuple)) or len(value) != 2:
raise ValueError(f"{context} must define '{key}' as a 2-item list/tuple, but got: {value!r}")
if any(int(v) <= 0 for v in value):
raise ValueError(f"{context} must define positive values for '{key}', but got: {value!r}")
def _validate_ground3d_config(cfg: dict[str, Any], context: str) -> None:
    """Validate required Ground3D config fields before training starts.

    Args:
        cfg: Resolved dataset config mapping to validate.
        context: Human-readable description of the config source, prefixed to every error message.

    Raises:
        ValueError: If a required key is missing or None, ``path`` is blank, ``ori_img_size``/``roi`` are not
            positive integer pairs, ``virtual_fx`` is not a finite positive number, ``virtual_camera_prob``
            is not numeric, or ``crop_center_mode`` is not one of the supported mode strings.
    """
    missing = [key for key in GROUND3D_REQUIRED_CONFIG_KEYS if cfg.get(key) is None]
    if missing:
        missing_str = ", ".join(missing)
        raise ValueError(f"{context} is missing required Ground3D config field(s): {missing_str}.")
    if not str(cfg["path"]).strip():
        raise ValueError(f"{context} must define a non-empty 'path'.")
    _validate_ground3d_pair(cfg, "ori_img_size", context)
    _validate_ground3d_pair(cfg, "roi", context)

    virtual_fx = cfg["virtual_fx"]
    try:
        fx_value = float(virtual_fx)
    except (TypeError, ValueError, OverflowError) as e:
        raise ValueError(f"{context} must define a positive 'virtual_fx', but got: {virtual_fx!r}.") from e
    # The chained comparison rejects values <= 0, inf and NaN (every comparison with NaN is False).
    if not 0 < fx_value < float("inf"):
        raise ValueError(f"{context} must define a positive, finite 'virtual_fx', but got: {virtual_fx!r}.")

    try:
        float(cfg["virtual_camera_prob"])
    except (TypeError, ValueError, OverflowError) as e:
        raise ValueError(
            f"{context} must define 'virtual_camera_prob' as a numeric value, but got: {cfg['virtual_camera_prob']!r}."
        ) from e

    crop_center_mode = cfg["crop_center_mode"]
    # Check the type first: an unhashable value (e.g. a YAML list) would otherwise make the set lookup raise TypeError.
    if not isinstance(crop_center_mode, str) or crop_center_mode not in GROUND3D_VALID_CROP_CENTER_MODES:
        valid_modes = ", ".join(sorted(GROUND3D_VALID_CROP_CENTER_MODES))
        raise ValueError(
            f"{context} must define 'crop_center_mode' as one of {{{valid_modes}}}, but got: {crop_center_mode!r}."
        )
def build_run_name(base_name: str, resume: str, world_size: int, rank: int) -> str:
    """Build the run name used as the Ultralytics ``name`` override.

    When resuming from a ``.pt`` checkpoint laid out as ``<run_dir>/weights/<ckpt>.pt``, the run directory
    name (the checkpoint's grandparent) is reused. If that name cannot be determined, ``base_name`` is
    returned. For new runs, a timestamp from ``get_distributed_run_timestamp`` is appended to ``base_name``.
    That helper receives ``world_size``/``rank`` and is presumably what keeps the name identical across
    distributed ranks (not visible here).

    Args:
        base_name: Experiment name from ``--name``.
        resume: Checkpoint path from ``--resume``, or an empty string when not resuming.
        world_size: Number of distributed processes.
        rank: Global rank of this process (-1 when not distributed).

    Returns:
        The run name; never empty when ``base_name`` is non-empty.
    """
    if resume:
        resume_path = Path(resume)
        if resume_path.suffix == ".pt" and len(resume_path.parents) >= 2:
            # parents[1] may be "." or "/" (e.g. "weights/last.pt"), whose .name is "": fall back in that case.
            run_dir_name = resume_path.parents[1].name
            if run_dir_name:
                return run_dir_name
        return base_name
    timestamp = get_distributed_run_timestamp(base_name, world_size, rank)
    return f"{base_name}_{timestamp}"
def resolve_ground3d_model(model_arg: str) -> Tuple[str, Optional[str]]:
    """Map the ``--model`` argument to a model config path and an optional explicit scale.

    A bare shorthand such as ``"yolo26s"`` resolves to the shared 3D YAML plus its scale letter (``"s"``).
    Anything else is treated as a config path (with ``~`` expanded) and has no explicit scale.
    Surrounding whitespace is ignored.
    """
    candidate = model_arg.strip()
    if not re.fullmatch(r"yolo26[nslmx]", candidate):
        # Not a shorthand variant: treat it as a filesystem path to a model config.
        return str(Path(candidate).expanduser()), None
    return GROUND3D_DEFAULT_MODEL_CFG, GROUND3D_MODEL_VARIANTS[candidate]
def resolve_data_yaml_for_roi(data_path: str, roi_name: Optional[str]) -> Tuple[str, Optional[str]]:
    """Apply a named ROI preset from a dataset YAML and write the merged config to a temporary YAML file.

    Args:
        data_path: Dataset path from ``--data``; ``~`` is expanded.
        roi_name: ROI preset name from ``--roi``; when falsy, the YAML's ``default_roi`` is used instead.

    Returns:
        ``(path, roi)``. When no preset applies (non-YAML input, or neither ``roi_name`` nor ``default_roi``
        is set), this is the expanded input path and ``None``. Otherwise it is the generated temporary YAML
        path and the selected preset name.

    Raises:
        ValueError: If ``--roi`` is used with a non-YAML path, the YAML has no usable ``roi_configs``, the
            preset is unknown or not a mapping, or the merged config fails Ground3D validation.
    """
    source_path = Path(data_path).expanduser()
    data_path = str(source_path)

    if source_path.suffix not in (".yaml", ".yml"):
        # Non-YAML inputs are passed through untouched; an ROI preset cannot be applied to them.
        if roi_name:
            raise ValueError("--roi requires --data to point to a dataset YAML file.")
        return data_path, None

    data_cfg = YAML.load(data_path)
    selected_roi = roi_name or data_cfg.get("default_roi")
    if selected_roi is None:
        return data_path, None

    roi_configs = data_cfg.get("roi_configs")
    if not (isinstance(roi_configs, dict) and roi_configs):
        raise ValueError(f"Dataset YAML '{data_path}' does not define any 'roi_configs' presets.")
    if selected_roi not in roi_configs:
        raise ValueError(
            f"Unknown ROI preset '{selected_roi}' for dataset '{data_path}'. "
            f"Available presets: {', '.join(sorted(roi_configs))}."
        )

    roi_overrides = roi_configs[selected_roi]
    if not isinstance(roi_overrides, dict):
        raise ValueError(f"ROI preset '{selected_roi}' in '{data_path}' must be a mapping.")

    # Drop the preset bookkeeping keys, then layer the selected preset's values over the base config.
    resolved_cfg: dict[str, Any] = {
        cfg_key: cfg_value for cfg_key, cfg_value in data_cfg.items() if cfg_key not in ("default_roi", "roi_configs")
    }
    resolved_cfg.update(roi_overrides)
    _validate_ground3d_config(resolved_cfg, f"Resolved dataset YAML '{data_path}' with ROI preset '{selected_roi}'")

    # The pid + uuid4 suffix keeps files from concurrent processes (e.g. separate DDP ranks) distinct.
    # NOTE(review): these temporary files are never removed; later use of the saved args (e.g. resume) may point at them.
    unique_name = f"{source_path.stem}.{selected_roi}.{os.getpid()}.{uuid4().hex}{source_path.suffix}"
    resolved_path = Path(tempfile.gettempdir()) / "ultralytics_roi_configs" / unique_name
    header = "".join(
        (
            "# Auto-generated by train_mono3d.py\n",
            f"# Source dataset YAML: {source_path.resolve()}\n",
            f"# Selected ROI preset: {selected_roi}\n\n",
        )
    )
    YAML.save(file=resolved_path, data=resolved_cfg, header=header)
    return str(resolved_path), selected_roi
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse command line arguments for Ground3D training.

    Boolean switches that default to True (``--val``, ``--plots``, ``--overlap_mask``) use
    ``argparse.BooleanOptionalAction``, so they can be turned off with ``--no-val`` / ``--no-plots`` /
    ``--no-overlap_mask``. Passing the positive flag still works as before.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``, matching the previous
            zero-argument behavior; passing an explicit list is useful for tests and programmatic use.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train YOLO26 on Ground Joint 2D+3D Detection Dataset")

    # Model and data
    parser.add_argument(
        "--model",
        type=str,
        default=GROUND3D_DEFAULT_MODEL_CFG,
        help="Model config YAML path or shorthand variant (yolo26n/yolo26s/yolo26m/yolo26l/yolo26x).",
    )  # e.g. --model ultralytics/cfg/models/26/yolo26-3d.yaml, or --model yolo26s
    parser.add_argument("--data", type=str, default="ultralytics/cfg/datasets/mono3d_ground.yaml", help="Dataset YAML path")
    parser.add_argument(
        "--roi",
        type=str,
        default=None,
        help="Named ROI preset from the dataset YAML 'roi_configs' section (for example: roi0 or roi1)",
    )

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch", type=int, default=16, help="Batch size per GPU")
    parser.add_argument("--imgsz", type=str, default="704,352", help="Image size (int or w,h e.g. 704,352)")
    parser.add_argument("--device", type=str, default="0,1,2,3", help="Device to use (e.g., 0 or 0,1,2,3 or cpu)")

    # Optimizer
    parser.add_argument("--optimizer", type=str, default="auto", choices=["SGD", "Adam", "AdamW", "auto"], help="Optimizer")
    parser.add_argument("--lr0", type=float, default=0.01, help="Initial learning rate")
    parser.add_argument("--lrf", type=float, default=0.01, help="Final learning rate factor")
    parser.add_argument("--momentum", type=float, default=0.937, help="SGD momentum/Adam beta1")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="Optimizer weight decay")
    parser.add_argument("--warmup_epochs", type=float, default=3.0, help="Warmup epochs (2D-only stage if loss_3d_warmup_epochs is not set)")
    parser.add_argument("--warmup_momentum", type=float, default=0.8, help="Initial momentum during warmup")
    parser.add_argument("--warmup_bias_lr", type=float, default=0.1, help="Initial bias learning rate during warmup")
    parser.add_argument("--e2e_o2m_start", type=float, default=0.8, help="Initial one-to-many loss weight for end-to-end training")
    parser.add_argument("--e2e_o2m_final", type=float, default=0.1, help="Final one-to-many loss weight after decay")
    parser.add_argument(
        "--e2e_o2m_decay_epochs",
        type=float,
        default=None,
        help="Epochs used to decay one-to-many weight to the final value; defaults to epochs - 1",
    )
    parser.add_argument("--loss_3d_warmup_epochs", type=float, default=None, help="Epochs before enabling 3D loss; defaults to warmup_epochs")
    parser.add_argument("--loss_3d_ramp_epochs", type=float, default=10.0, help="Epochs used to ramp 3D loss weight to its max value")
    parser.add_argument("--loss_3d_weight_max", type=float, default=0.1, help="Maximum 3D loss weight multiplier")
    parser.add_argument(
        "--face_visibility_score_thresh",
        type=float,
        default=None,
        help="Override visible-face score threshold used by 3D face supervision and visible-face metrics",
    )

    # Loss weights
    parser.add_argument("--box", type=float, default=7.5, help="Box loss gain")
    parser.add_argument("--cls", type=float, default=0.5, help="Class loss gain")
    # NOTE: per the original author's note, yolo26-3d.yaml currently sets reg_max=1, which effectively disables
    # DFL; the flag is kept for compatibility.
    parser.add_argument("--dfl", type=float, default=1.5, help="DFL loss gain")

    # Augmentation
    parser.add_argument("--hsv_h", type=float, default=0.015, help="HSV-Hue augmentation")
    parser.add_argument("--hsv_s", type=float, default=0.7, help="HSV-Saturation augmentation")
    parser.add_argument("--hsv_v", type=float, default=0.4, help="HSV-Value augmentation")
    parser.add_argument("--degrees", type=float, default=0.0, help="Rotation augmentation (degrees)")
    parser.add_argument("--translate", type=float, default=0.1, help="Translation augmentation")
    parser.add_argument("--scale", type=float, default=0.5, help="Scale augmentation")
    parser.add_argument("--shear", type=float, default=0.0, help="Shear augmentation (degrees)")
    parser.add_argument("--perspective", type=float, default=0.0, help="Perspective augmentation")
    parser.add_argument("--flipud", type=float, default=0.0, help="Vertical flip probability")
    parser.add_argument("--fliplr", type=float, default=0.5, help="Horizontal flip probability")
    parser.add_argument("--mosaic", type=float, default=1.0, help="Mosaic augmentation probability")
    parser.add_argument("--mixup", type=float, default=0.0, help="Mixup augmentation probability")
    parser.add_argument("--copy_paste", type=float, default=0.0, help="Copy-paste augmentation probability")

    # Training settings
    # Early stopping: stop after this many epochs without validation improvement.
    parser.add_argument("--patience", type=int, default=100, help="Early stopping patience")
    parser.add_argument("--save_period", type=int, default=1, help="Save checkpoint every x epochs (default: every epoch, -1 to disable)")
    parser.add_argument("--cache", type=str, default="", choices=["", "ram", "disk"], help="Cache images")
    parser.add_argument("--workers", type=int, default=16, help="Number of dataloader workers")
    parser.add_argument("--project", type=str, default=None, help="Project directory")
    parser.add_argument("--name", type=str, default="train_mono3d", help="Experiment name")
    # When set, main() uses this directory for results (its basename becomes the run name) instead of project/name.
    parser.add_argument(
        "--exp_dir",
        type=str,
        default="",
        help="Exact experiment directory to use. Bypasses distributed run-name timestamp sync.",
    )
    parser.add_argument("--exist_ok", action="store_true", help="Overwrite existing experiment")
    parser.add_argument("--pretrained", type=str, default="", help="Pretrained 2D weights path (e.g. yolo26s-pretrain.pt). Scale extracted from filename.")
    parser.add_argument("--resume", type=str, default="", help="Resume training from checkpoint")
    parser.add_argument(
        "--strict-backbone-pretrained",
        dest="strict_backbone_pretrained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Fail startup if any backbone parameter tensor can not be transferred from --pretrained.",
    )
    parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
    # e.g. 0.1 trains on 10% of the dataset, handy for quick debugging.
    parser.add_argument("--fraction", type=float, default=1.0, help="Dataset fraction to use")
    parser.add_argument("--profile", action="store_true", help="Profile ONNX and TensorRT speeds")
    parser.add_argument("--freeze", type=int, default=None, help="Freeze first n layers")
    parser.add_argument("--multi_scale", type=float, default=0.0, help="Multi-scale training range")

    # Validation
    # These default to True. BooleanOptionalAction adds --no-val / --no-plots; with the previous
    # store_true + default=True they could never be turned off from the command line.
    parser.add_argument("--val", action=argparse.BooleanOptionalAction, default=True, help="Validate during training")
    parser.add_argument("--plots", action=argparse.BooleanOptionalAction, default=True, help="Save plots during training")
    # Per the original author's note: when batches carry `camera_mode`, validation skips virtual-camera samples
    # and computes metrics on ROI samples only.
    parser.add_argument("--roi_metrics_only", action="store_true", help="Compute validation metrics using ROI samples only")

    # Advanced
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--deterministic", action="store_true", help="Use deterministic algorithms")
    parser.add_argument("--single_cls", action="store_true", help="Train as single-class dataset")
    parser.add_argument("--rect", action="store_true", help="Rectangular training")
    parser.add_argument("--cos_lr", action="store_true", help="Use cosine learning rate scheduler")
    parser.add_argument("--close_mosaic", type=int, default=10, help="Disable mosaic augmentation for final epochs")
    parser.add_argument("--nbs", type=int, default=64, help="Nominal batch size")
    # Defaults to True; --no-overlap_mask disables it (previously it could not be turned off).
    parser.add_argument("--overlap_mask", action=argparse.BooleanOptionalAction, default=True, help="Masks should overlap during training")
    parser.add_argument("--mask_ratio", type=int, default=4, help="Mask downsample ratio")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate")
    return parser.parse_args(argv)
def _parse_imgsz(value: str) -> "int | list[int]":
    """Parse ``--imgsz``: ``"640"`` -> ``640`` and ``"704,352"`` -> ``[704, 352]``.

    Whitespace around each number is ignored.

    Raises:
        ValueError: If the value is not a single integer or exactly two comma-separated integers, or if any
            size is not positive.
    """
    parts = [part.strip() for part in str(value).split(",")]
    try:
        sizes = [int(part) for part in parts]
    except ValueError as e:
        raise ValueError(f"--imgsz must be an integer or a 'w,h' integer pair, but got: {value!r}") from e
    if len(sizes) > 2 or any(size <= 0 for size in sizes):
        raise ValueError(f"--imgsz must be one or two positive integers, but got: {value!r}")
    return sizes if len(sizes) == 2 else sizes[0]


# In distributed runs, torchelastic's `record` captures an uncaught exception in this process so the
# launcher can report which rank failed.
@record
def main():
    """Main training entry point.

    Parses CLI arguments, resolves the model/data/ROI configuration, assembles the Ultralytics trainer
    overrides, and runs ``Ground3DDetectionTrainer.train()``.

    Raises:
        ValueError: For an invalid ``--imgsz``, for ``--exp_dir`` combined with ``--resume``, or when
            model/ROI/dataset resolution fails.
    """
    args = parse_args()

    # Print banner
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Ground Joint 2D+3D Detection Training with YOLO26"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))

    # Confirm TensorBoard is enabled
    LOGGER.info(colorstr("bright_green", "TensorBoard logging enabled"))

    # Distributed launchers (e.g. torch.distributed.run) export these; plain `python` runs use the defaults.
    world_size = int(os.getenv("WORLD_SIZE", 1))
    rank = int(os.getenv("RANK", -1))
    local_rank = int(os.getenv("LOCAL_RANK", -1))
    if world_size > 1:
        LOGGER.info(f"Distributed training: world_size={world_size}, rank={rank}, local_rank={local_rank}")
    else:
        LOGGER.info("Single GPU training")

    # --model is either a 3D config YAML or a shorthand like "yolo26s" (shared 3D YAML + explicit scale "s").
    pretrained = args.pretrained if args.pretrained else False
    resolved_model_path, explicit_model_scale = resolve_ground3d_model(args.model)
    # Without an explicit scale, the scale is inferred from the pretrained filename (or the YAML default);
    # warn when the pretrained filename does not encode one either.
    if pretrained and explicit_model_scale is None and not re.fullmatch(r"yolo26[nslmx].*", Path(str(pretrained)).stem):
        LOGGER.warning(
            "Pretrained checkpoint name does not encode a YOLO26 scale. Pass '--model yolo26n/yolo26s/...' "
            "to lock the 3D architecture scale explicitly."
        )

    # "704,352" -> [704, 352], "640" -> 640; malformed values fail fast with a clear message.
    imgsz = _parse_imgsz(args.imgsz)

    # On resume the run directory comes from the checkpoint, so an explicit --exp_dir would conflict.
    if args.resume and args.exp_dir:
        raise ValueError("--exp_dir cannot be used with --resume because resume already determines the run directory.")
    exp_dir = str(Path(args.exp_dir).expanduser()) if args.exp_dir else ""
    # With --exp_dir the run name is its basename; otherwise build_run_name() derives it from --name/--resume.
    run_name = Path(exp_dir).name if exp_dir else build_run_name(args.name, args.resume, world_size, rank)
    # Applies the --roi preset (or the dataset's default_roi); the result may be a generated temporary YAML.
    resolved_data_path, selected_roi = resolve_data_yaml_for_roi(args.data, args.roi)

    # Build overrides dictionary
    overrides = {
        "model": resolved_model_path,
        "data": resolved_data_path,
        "epochs": args.epochs,
        "batch": args.batch,
        "imgsz": imgsz,
        "device": args.device,
        "optimizer": args.optimizer,
        "lr0": args.lr0,
        "lrf": args.lrf,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
        "warmup_epochs": args.warmup_epochs,
        "warmup_momentum": args.warmup_momentum,
        "warmup_bias_lr": args.warmup_bias_lr,
        # One-to-many loss weight schedule for end-to-end training (start -> final).
        "e2e_o2m_start": args.e2e_o2m_start,
        "e2e_o2m_final": args.e2e_o2m_final,
        # 3D loss weight ramp: grows toward loss_3d_weight_max over loss_3d_ramp_epochs.
        "loss_3d_ramp_epochs": args.loss_3d_ramp_epochs,
        "loss_3d_weight_max": args.loss_3d_weight_max,
        "box": args.box,
        "cls": args.cls,
        "dfl": args.dfl,
        "hsv_h": args.hsv_h,
        "hsv_s": args.hsv_s,
        "hsv_v": args.hsv_v,
        "degrees": args.degrees,
        "translate": args.translate,
        "scale": args.scale,
        "shear": args.shear,
        "perspective": args.perspective,
        "flipud": args.flipud,
        "fliplr": args.fliplr,
        "mosaic": args.mosaic,
        "mixup": args.mixup,
        "copy_paste": args.copy_paste,
        "patience": args.patience,
        "save_period": args.save_period,
        "cache": args.cache if args.cache else False,
        "workers": args.workers,
        "project": args.project,
        "name": run_name,
        "exist_ok": args.exist_ok,
        "pretrained": pretrained,
        "amp": args.amp,
        "fraction": args.fraction,
        "profile": args.profile,
        "multi_scale": args.multi_scale,
        "val": args.val,
        "plots": args.plots,
        "roi_metrics_only": args.roi_metrics_only,
        "seed": args.seed,
        "deterministic": args.deterministic,
        "single_cls": args.single_cls,
        "rect": args.rect,
        "cos_lr": args.cos_lr,
        "close_mosaic": args.close_mosaic,
        "nbs": args.nbs,
        "overlap_mask": args.overlap_mask,
        "mask_ratio": args.mask_ratio,
        "dropout": args.dropout,
    }

    # Optional overrides: only passed when set, so the trainer's own defaults apply otherwise.
    if args.freeze is not None:
        overrides["freeze"] = args.freeze
    if args.loss_3d_warmup_epochs is not None:
        overrides["loss_3d_warmup_epochs"] = args.loss_3d_warmup_epochs
    if args.e2e_o2m_decay_epochs is not None:
        overrides["e2e_o2m_decay_epochs"] = args.e2e_o2m_decay_epochs
    if args.face_visibility_score_thresh is not None:
        overrides["face_visibility_score_thresh"] = args.face_visibility_score_thresh

    # Add resume if specified
    if args.resume:
        overrides["resume"] = args.resume
    if exp_dir:
        # NOTE(review): stock ultralytics `get_cfg` drops a `save_dir` override unless the default config defines
        # it; confirm this fork honors it, otherwise --exp_dir only determines the run name.
        overrides["save_dir"] = exp_dir

    # Print configuration
    LOGGER.info(colorstr("bright_cyan", "\nTraining Configuration:"))
    LOGGER.info(f" Model arg: {args.model}")
    LOGGER.info(f" Model config: {resolved_model_path}")
    LOGGER.info(
        " Model scale: "
        f"{explicit_model_scale if explicit_model_scale else 'infer from pretrained filename or YAML default'}"
    )
    LOGGER.info(f" Pretrained weights: {pretrained if pretrained else 'None'}")
    LOGGER.info(f" Strict pretrained backbone load: {args.strict_backbone_pretrained}")
    LOGGER.info(f" Data: {args.data}")
    if selected_roi:
        LOGGER.info(f" ROI preset: {selected_roi}")
        LOGGER.info(f" Resolved data YAML: {resolved_data_path}")
    LOGGER.info(f" Epochs: {args.epochs}")
    LOGGER.info(f" Batch size: {args.batch}")
    LOGGER.info(f" Image size: {args.imgsz}")
    LOGGER.info(f" Device: {args.device if args.device else 'auto'}")
    LOGGER.info(f" Workers: {args.workers}")
    LOGGER.info(f" Cache: {args.cache if args.cache else 'False'}")
    LOGGER.info(f" Run name: {run_name}")
    LOGGER.info(f" Checkpoints: every {args.save_period} epoch(s)" if args.save_period > 0 else " Checkpoints: last.pt and best.pt only")
    if exp_dir:
        LOGGER.info(f" Experiment dir: {exp_dir}")
    LOGGER.info(f" AMP: {args.amp}")
    LOGGER.info(f" Optimizer: {args.optimizer}")
    LOGGER.info(f" Learning rate: {args.lr0} -> {args.lr0 * args.lrf}")
    LOGGER.info(f" Warmup: epochs={args.warmup_epochs}, momentum={args.warmup_momentum}, bias_lr={args.warmup_bias_lr}")
    LOGGER.info(
        " E2E one-to-many weight: "
        f"start={args.e2e_o2m_start}, final={args.e2e_o2m_final}, "
        f"decay_epochs={args.e2e_o2m_decay_epochs if args.e2e_o2m_decay_epochs is not None else 'epochs-1'}"
    )
    LOGGER.info(f" Loss weights: box={args.box}, cls={args.cls}, dfl={args.dfl}")
    # Mirrors the documented CLI default: the 3D warmup falls back to --warmup_epochs when not given.
    loss_3d_warmup = args.loss_3d_warmup_epochs if args.loss_3d_warmup_epochs is not None else args.warmup_epochs
    LOGGER.info(
        f" 3D loss: warmup={loss_3d_warmup} epochs, ramp={args.loss_3d_ramp_epochs} epochs, max={args.loss_3d_weight_max}"
    )
    LOGGER.info(
        " Face visibility score threshold: "
        f"{args.face_visibility_score_thresh if args.face_visibility_score_thresh is not None else 'default.yaml'}"
    )
    LOGGER.info("")

    # Initialize trainer
    LOGGER.info(colorstr("bright_green", "Initializing Ground3DDetectionTrainer..."))
    trainer = Ground3DDetectionTrainer(overrides=overrides)
    # Extra options attached after construction; presumably read by Ground3DDetectionTrainer when it builds and
    # loads the model (per the original notes, get_model() uses the scale to pick n/s/m/l/x).
    # NOTE(review): this assumes the model is built lazily inside train(), after these attributes are set.
    trainer.explicit_model_scale = explicit_model_scale
    trainer.strict_backbone_pretrained = args.strict_backbone_pretrained

    # Start training
    LOGGER.info(colorstr("bright_green", "Starting training...\n"))
    trainer.train()

    # Print completion message
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Training completed!"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))

    # Print results location
    if hasattr(trainer, "save_dir"):
        LOGGER.info(f"Results saved to: {trainer.save_dir}")
        LOGGER.info(f"Best weights: {trainer.save_dir / 'weights' / 'best.pt'}")
        LOGGER.info(f"Last weights: {trainer.save_dir / 'weights' / 'last.pt'}")
# Script entry point: run training only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()