266 lines
12 KiB
Python
Executable File
266 lines
12 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Ground 2D Detection Training Script for YOLO26
Supports single-GPU and multi-GPU distributed training with difficulty-based loss weighting.

Usage:
    # Single GPU
    python train_mono.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100 --device 0

    # Multi-GPU (DDP)
    python -m torch.distributed.run --nproc_per_node 4 train_mono.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100

    # Multi-GPU with specific devices
    CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node 4 train_mono.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Make the directory containing this script importable so that the
# `ultralytics` package imported below can be found when the script
# is launched from a different working directory.
FILE = Path(__file__).resolve()
ROOT = FILE.parent
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
|
|
|
|
from ultralytics import settings
|
|
from ultralytics.models.yolo.detect import GroundDetectionTrainer
|
|
from ultralytics.utils import LOGGER, colorstr
|
|
from ultralytics.utils.dist import get_distributed_run_timestamp
|
|
|
|
# Turn on TensorBoard logging in the Ultralytics settings at import time,
# before any trainer is constructed.
settings.update({"tensorboard": True})
|
|
|
|
def build_run_name(base_name: str, resume: str, world_size: int, rank: int) -> str:
    """Build a timestamped run name that stays consistent across distributed ranks.

    Args:
        base_name: Experiment name requested by the user; used as the prefix of
            a fresh run name and as the fallback when resuming.
        resume: Checkpoint path to resume from, or an empty string for a new run.
        world_size: Number of distributed processes, forwarded to
            ``get_distributed_run_timestamp``.
        rank: Rank of the current process, forwarded to
            ``get_distributed_run_timestamp``.

    Returns:
        When resuming from a ``.pt`` file, the name of the directory two levels
        above the checkpoint (e.g. ``exp7`` for ``runs/exp7/weights/last.pt``).
        If that directory has no usable name, or the checkpoint is not a ``.pt``
        file, ``base_name`` is returned. For a new run the result is
        ``"<base_name>_<timestamp>"``.
    """
    if resume:
        resume_path = Path(resume)
        if resume_path.suffix == ".pt" and len(resume_path.parents) >= 2:
            run_dir_name = resume_path.parents[1].name
            # Shallow paths such as "weights/last.pt" or "/weights/last.pt"
            # have "." or "/" as their grandparent, whose name is empty.
            # Never hand an empty run name to the trainer; fall through to
            # the base name instead.
            if run_dir_name:
                return run_dir_name
        return base_name

    timestamp = get_distributed_run_timestamp(base_name, world_size, rank)
    return f"{base_name}_{timestamp}"
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour existing callers rely on.

    Returns:
        argparse.Namespace: Parsed options, one attribute per flag below.

    Notes:
        ``--val``, ``--plots`` and ``--overlap_mask`` default to ``True``.
        Because they are ``store_true`` flags they could previously never be
        switched off; each now has a ``--no-<flag>`` counterpart that writes
        ``False`` to the same destination.
    """
    parser = argparse.ArgumentParser(description="Train YOLO26 on Ground 2D Detection Dataset")

    # Model and data
    parser.add_argument("--model", type=str, default="yolo26s.pt", help="Model path or name (yolo26n.pt, yolo26s.pt, etc.)")
    parser.add_argument("--data", type=str, default="ultralytics/cfg/datasets/mono2d_ground.yaml", help="Dataset YAML path")

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch", type=int, default=16, help="Global batch size across all ranks")
    parser.add_argument("--imgsz", type=int, default=704, help="Image size")
    parser.add_argument("--device", type=str, default="0,1,2,3", help="Device to use (e.g., 0 or 0,1,2,3 or cpu)")

    # Optimizer
    parser.add_argument("--optimizer", type=str, default="auto", choices=["SGD", "Adam", "AdamW", "auto"], help="Optimizer")
    parser.add_argument("--lr0", type=float, default=0.01, help="Initial learning rate")
    parser.add_argument("--lrf", type=float, default=0.01, help="Final learning rate factor")
    parser.add_argument("--momentum", type=float, default=0.937, help="SGD momentum/Adam beta1")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="Optimizer weight decay")

    # Loss weights
    parser.add_argument("--box", type=float, default=7.5, help="Box loss gain")
    parser.add_argument("--cls", type=float, default=0.5, help="Class loss gain")
    parser.add_argument("--dfl", type=float, default=1.5, help="DFL loss gain")

    # Augmentation
    parser.add_argument("--hsv_h", type=float, default=0.015, help="HSV-Hue augmentation")
    parser.add_argument("--hsv_s", type=float, default=0.7, help="HSV-Saturation augmentation")
    parser.add_argument("--hsv_v", type=float, default=0.4, help="HSV-Value augmentation")
    parser.add_argument("--degrees", type=float, default=0.0, help="Rotation augmentation (degrees)")
    parser.add_argument("--translate", type=float, default=0.1, help="Translation augmentation")
    parser.add_argument("--scale", type=float, default=0.5, help="Scale augmentation")
    parser.add_argument("--shear", type=float, default=0.0, help="Shear augmentation (degrees)")
    parser.add_argument("--perspective", type=float, default=0.0, help="Perspective augmentation")
    parser.add_argument("--flipud", type=float, default=0.0, help="Vertical flip probability")
    parser.add_argument("--fliplr", type=float, default=0.5, help="Horizontal flip probability")
    parser.add_argument("--mosaic", type=float, default=1.0, help="Mosaic augmentation probability")
    parser.add_argument("--mixup", type=float, default=0.0, help="Mixup augmentation probability")
    parser.add_argument("--copy_paste", type=float, default=0.0, help="Copy-paste augmentation probability")

    # Training settings
    parser.add_argument("--patience", type=int, default=100, help="Early stopping patience")
    parser.add_argument("--save_period", type=int, default=1, help="Save checkpoint every x epochs (-1 to disable)")
    parser.add_argument("--cache", type=str, default="", choices=["", "ram", "disk"], help="Cache images")
    parser.add_argument("--workers", type=int, default=16, help="Number of dataloader workers")
    parser.add_argument("--project", type=str, default=None, help="Project directory")
    parser.add_argument("--name", type=str, default="train_mono", help="Experiment name")
    parser.add_argument(
        "--exp_dir",
        type=str,
        default="",
        help="Exact experiment directory to use. Bypasses distributed run-name timestamp sync.",
    )
    parser.add_argument("--exist_ok", action="store_true", help="Overwrite existing experiment")
    parser.add_argument("--pretrained", action="store_true", help="Use pretrained model")
    parser.add_argument("--resume", type=str, default="", help="Resume training from checkpoint")
    parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
    parser.add_argument("--fraction", type=float, default=1.0, help="Dataset fraction to use")
    parser.add_argument("--profile", action="store_true", help="Profile ONNX and TensorRT speeds")
    parser.add_argument("--freeze", type=int, default=None, help="Freeze first n layers")
    parser.add_argument("--multi_scale", type=float, default=0.0, help="Multi-scale training range")

    # Validation
    # The positive flags are kept for backward compatibility; the --no-*
    # variants are the only way to actually turn these defaults off.
    parser.add_argument("--val", action="store_true", default=True, help="Validate during training")
    parser.add_argument("--no-val", dest="val", action="store_false", help="Skip validation during training")
    parser.add_argument("--plots", action="store_true", default=True, help="Save plots during training")
    parser.add_argument("--no-plots", dest="plots", action="store_false", help="Do not save plots during training")

    # Advanced
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--deterministic", action="store_true", help="Use deterministic algorithms")
    parser.add_argument("--single_cls", action="store_true", help="Train as single-class dataset")
    parser.add_argument("--rect", action="store_true", help="Rectangular training")
    parser.add_argument("--cos_lr", action="store_true", help="Use cosine learning rate scheduler")
    parser.add_argument("--close_mosaic", type=int, default=10, help="Disable mosaic augmentation for final epochs")
    parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing epsilon")
    parser.add_argument("--nbs", type=int, default=64, help="Nominal batch size")
    parser.add_argument("--overlap_mask", action="store_true", default=True, help="Masks should overlap during training")
    parser.add_argument(
        "--no-overlap_mask",
        dest="overlap_mask",
        action="store_false",
        help="Masks should not overlap during training",
    )
    parser.add_argument("--mask_ratio", type=int, default=4, help="Mask downsample ratio")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate")

    return parser.parse_args(argv)
|
|
|
|
|
|
def main():
    """Run one training session driven by the command line.

    Parses the CLI options, logs a banner and the effective configuration,
    resolves the run name / experiment directory, assembles the overrides
    dictionary, and hands it to ``GroundDetectionTrainer`` for training.

    Raises:
        ValueError: If both ``--resume`` and ``--exp_dir`` are supplied.
    """
    args = parse_args()

    # Banner
    rule = "=" * 80
    banner_style = ("bright_blue", "bold")
    LOGGER.info(colorstr(*banner_style, "\n" + rule))
    LOGGER.info(colorstr(*banner_style, "Ground 2D Detection Training with YOLO26"))
    LOGGER.info(colorstr(*banner_style, rule + "\n"))

    LOGGER.info(colorstr("bright_green", "✓ TensorBoard logging enabled"))

    # Distributed context as exposed through the launcher's environment variables
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", -1))
    local_rank = int(os.environ.get("LOCAL_RANK", -1))

    if world_size <= 1:
        LOGGER.info("Single GPU training")
    else:
        LOGGER.info(f"Distributed training: world_size={world_size}, rank={rank}, local_rank={local_rank}")

    if args.exp_dir and args.resume:
        raise ValueError("--exp_dir cannot be used with --resume because resume already determines the run directory.")

    # An explicit experiment directory dictates the run name; otherwise derive one.
    if args.exp_dir:
        exp_dir = str(Path(args.exp_dir).expanduser())
        run_name = Path(exp_dir).name
    else:
        exp_dir = ""
        run_name = build_run_name(args.name, args.resume, world_size, rank)

    # Options copied verbatim from the CLI namespace, split around the few
    # entries that need special handling so the key order stays unchanged.
    leading_keys = (
        "model", "data", "epochs", "batch", "imgsz", "device",
        "optimizer", "lr0", "lrf", "momentum", "weight_decay",
        "box", "cls", "dfl",
        "hsv_h", "hsv_s", "hsv_v", "degrees", "translate", "scale", "shear",
        "perspective", "flipud", "fliplr", "mosaic", "mixup", "copy_paste",
        "patience", "save_period",
    )
    trailing_keys = (
        "exist_ok", "pretrained", "amp", "fraction", "profile", "multi_scale",
        "val", "plots", "seed", "deterministic", "single_cls", "rect", "cos_lr",
        "close_mosaic", "label_smoothing", "nbs", "overlap_mask", "mask_ratio", "dropout",
    )

    overrides = {key: getattr(args, key) for key in leading_keys}
    overrides["cache"] = args.cache or False  # empty string means caching is off
    overrides["workers"] = args.workers
    overrides["project"] = args.project
    overrides["name"] = run_name
    overrides.update((key, getattr(args, key)) for key in trailing_keys)

    # Optional entries are only passed when the user supplied them
    if args.freeze is not None:
        overrides["freeze"] = args.freeze
    if args.resume:
        overrides["resume"] = args.resume
    if exp_dir:
        overrides["save_dir"] = exp_dir

    # Configuration summary
    summary = [
        f" Model: {args.model}",
        f" Data: {args.data}",
        f" Epochs: {args.epochs}",
        f" Global batch size: {args.batch}",
        f" Image size: {args.imgsz}",
        f" Device: {args.device if args.device else 'auto'}",
        f" Workers: {args.workers}",
        f" Cache: {args.cache if args.cache else 'False'}",
        f" Run name: {run_name}",
    ]
    if exp_dir:
        summary.append(f" Experiment dir: {exp_dir}")
    summary.extend(
        [
            f" AMP: {args.amp}",
            f" Optimizer: {args.optimizer}",
            f" Learning rate: {args.lr0} -> {args.lr0 * args.lrf}",
            f" Loss weights: box={args.box}, cls={args.cls}, dfl={args.dfl}",
            "",
        ]
    )

    LOGGER.info(colorstr("bright_cyan", "\nTraining Configuration:"))
    for line in summary:
        LOGGER.info(line)

    # Build the trainer and run it
    LOGGER.info(colorstr("bright_green", "Initializing GroundDetectionTrainer..."))
    trainer = GroundDetectionTrainer(overrides=overrides)

    LOGGER.info(colorstr("bright_green", "Starting training...\n"))
    trainer.train()

    # Completion banner
    LOGGER.info(colorstr(*banner_style, "\n" + rule))
    LOGGER.info(colorstr(*banner_style, "Training completed!"))
    LOGGER.info(colorstr(*banner_style, rule + "\n"))

    # Report where the artefacts were written, if the trainer exposes it
    if hasattr(trainer, "save_dir"):
        weights_dir = trainer.save_dir / "weights"
        LOGGER.info(f"Results saved to: {trainer.save_dir}")
        LOGGER.info(f"Best weights: {weights_dir / 'best.pt'}")
        LOGGER.info(f"Last weights: {weights_dir / 'last.pt'}")
|
|
|
|
|
|
# Script entry point: only start training when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|