单目3D初始代码
This commit is contained in:
265
train_mono2d.py
Executable file
265
train_mono2d.py
Executable file
@@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Ground 2D Detection Training Script for YOLO26
|
||||
Supports single-GPU and multi-GPU distributed training with difficulty-based loss weighting.
|
||||
|
||||
Usage:
|
||||
# Single GPU
|
||||
python train_mono2d.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100 --device 0
|
||||
|
||||
# Multi-GPU (DDP)
|
||||
python -m torch.distributed.run --nproc_per_node 4 train_mono2d.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100
|
||||
|
||||
# Multi-GPU with specific devices
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node 4 train_mono2d.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the directory that holds this script is on the import path, so an
# `ultralytics` package located next to it can be imported below.
FILE = Path(__file__).resolve()
ROOT = FILE.parent
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))
|
||||
|
||||
from ultralytics import settings
|
||||
from ultralytics.models.yolo.detect import GroundDetectionTrainer
|
||||
from ultralytics.utils import LOGGER, colorstr
|
||||
from ultralytics.utils.dist import get_distributed_run_timestamp
|
||||
|
||||
# Enable TensorBoard logging for every run started through this script.
# NOTE(review): this executes at import time; whether the change is persisted
# beyond the current process depends on ultralytics' settings object — confirm.
settings.update({'tensorboard': True})
|
||||
|
||||
def build_run_name(base_name: str, resume: str, world_size: int, rank: int) -> str:
    """Build the run name used for the experiment directory.

    When resuming, the name of the checkpoint's run directory is reused so that
    training continues under the same name. Otherwise a timestamp obtained from
    ``get_distributed_run_timestamp`` is appended to ``base_name``.

    Args:
        base_name: Experiment name requested by the user.
        resume: Checkpoint path to resume from; empty string for a fresh run.
        world_size: Number of distributed processes (forwarded to the timestamp helper).
        rank: Global rank of this process (forwarded to the timestamp helper).

    Returns:
        The run name. Falls back to ``base_name`` when resuming from a path
        that does not reveal a usable run directory name.
    """
    if resume:
        resume_path = Path(resume)
        if resume_path.suffix == ".pt" and len(resume_path.parents) >= 2:
            # Checkpoints are written to <run_dir>/weights/<file>.pt, so the
            # run directory is two levels above the checkpoint file.
            run_dir_name = resume_path.parents[1].name
            # Short paths ("weights/last.pt", "/weights/last.pt") yield an empty
            # name and "../weights/last.pt" yields ".."; neither is a valid run
            # name, so fall through to base_name instead of returning it.
            if run_dir_name and run_dir_name != "..":
                return run_dir_name
        return base_name

    # NOTE(review): presumably the helper returns the same timestamp on every
    # rank so all processes agree on the directory — confirm in ultralytics.utils.dist.
    timestamp = get_distributed_run_timestamp(base_name, world_size, rank)
    return f"{base_name}_{timestamp}"
|
||||
|
||||
|
||||
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command line arguments for the training script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the behavior of a plain ``parse_args()`` call.

    Returns:
        Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Train YOLO26 on Ground 2D Detection Dataset")

    # Model and data
    parser.add_argument("--model", type=str, default="yolo26s.pt", help="Model path or name (yolo26n.pt, yolo26s.pt, etc.)")
    parser.add_argument("--data", type=str, default="ultralytics/cfg/datasets/mono2d_ground.yaml", help="Dataset YAML path")

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
    parser.add_argument("--batch", type=int, default=16, help="Global batch size across all ranks")
    parser.add_argument("--imgsz", type=int, default=704, help="Image size")
    parser.add_argument("--device", type=str, default="0,1,2,3", help="Device to use (e.g., 0 or 0,1,2,3 or cpu)")

    # Optimizer
    parser.add_argument("--optimizer", type=str, default="auto", choices=["SGD", "Adam", "AdamW", "auto"], help="Optimizer")
    parser.add_argument("--lr0", type=float, default=0.01, help="Initial learning rate")
    parser.add_argument("--lrf", type=float, default=0.01, help="Final learning rate factor")
    parser.add_argument("--momentum", type=float, default=0.937, help="SGD momentum/Adam beta1")
    parser.add_argument("--weight_decay", type=float, default=0.0005, help="Optimizer weight decay")

    # Loss weights
    parser.add_argument("--box", type=float, default=7.5, help="Box loss gain")
    parser.add_argument("--cls", type=float, default=0.5, help="Class loss gain")
    parser.add_argument("--dfl", type=float, default=1.5, help="DFL loss gain")

    # Augmentation
    parser.add_argument("--hsv_h", type=float, default=0.015, help="HSV-Hue augmentation")
    parser.add_argument("--hsv_s", type=float, default=0.7, help="HSV-Saturation augmentation")
    parser.add_argument("--hsv_v", type=float, default=0.4, help="HSV-Value augmentation")
    parser.add_argument("--degrees", type=float, default=0.0, help="Rotation augmentation (degrees)")
    parser.add_argument("--translate", type=float, default=0.1, help="Translation augmentation")
    parser.add_argument("--scale", type=float, default=0.5, help="Scale augmentation")
    parser.add_argument("--shear", type=float, default=0.0, help="Shear augmentation (degrees)")
    parser.add_argument("--perspective", type=float, default=0.0, help="Perspective augmentation")
    parser.add_argument("--flipud", type=float, default=0.0, help="Vertical flip probability")
    parser.add_argument("--fliplr", type=float, default=0.5, help="Horizontal flip probability")
    parser.add_argument("--mosaic", type=float, default=1.0, help="Mosaic augmentation probability")
    parser.add_argument("--mixup", type=float, default=0.0, help="Mixup augmentation probability")
    parser.add_argument("--copy_paste", type=float, default=0.0, help="Copy-paste augmentation probability")

    # Training settings
    parser.add_argument("--patience", type=int, default=100, help="Early stopping patience")
    parser.add_argument("--save_period", type=int, default=1, help="Save checkpoint every x epochs (-1 to disable)")
    parser.add_argument("--cache", type=str, default="", choices=["", "ram", "disk"], help="Cache images")
    parser.add_argument("--workers", type=int, default=16, help="Number of dataloader workers")
    parser.add_argument("--project", type=str, default=None, help="Project directory")
    parser.add_argument("--name", type=str, default="train_mono", help="Experiment name")
    parser.add_argument(
        "--exp_dir",
        type=str,
        default="",
        help="Exact experiment directory to use. Bypasses distributed run-name timestamp sync.",
    )
    parser.add_argument("--exist_ok", action="store_true", help="Overwrite existing experiment")
    parser.add_argument("--pretrained", action="store_true", help="Use pretrained model")
    parser.add_argument("--resume", type=str, default="", help="Resume training from checkpoint")
    parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
    parser.add_argument("--fraction", type=float, default=1.0, help="Dataset fraction to use")
    parser.add_argument("--profile", action="store_true", help="Profile ONNX and TensorRT speeds")
    parser.add_argument("--freeze", type=int, default=None, help="Freeze first n layers")
    parser.add_argument("--multi_scale", type=float, default=0.0, help="Multi-scale training range")

    # Validation
    # These flags default to True, so a store_true option alone could never turn
    # them off; each gets a paired --no-* switch writing False to the same dest.
    parser.add_argument("--val", action="store_true", default=True, help="Validate during training")
    parser.add_argument("--no-val", dest="val", action="store_false", default=True, help="Skip validation during training")
    parser.add_argument("--plots", action="store_true", default=True, help="Save plots during training")
    parser.add_argument("--no-plots", dest="plots", action="store_false", default=True, help="Do not save plots during training")

    # Advanced
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--deterministic", action="store_true", help="Use deterministic algorithms")
    parser.add_argument("--single_cls", action="store_true", help="Train as single-class dataset")
    parser.add_argument("--rect", action="store_true", help="Rectangular training")
    parser.add_argument("--cos_lr", action="store_true", help="Use cosine learning rate scheduler")
    parser.add_argument("--close_mosaic", type=int, default=10, help="Disable mosaic augmentation for final epochs")
    parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing epsilon")
    parser.add_argument("--nbs", type=int, default=64, help="Nominal batch size")
    parser.add_argument("--overlap_mask", action="store_true", default=True, help="Masks should overlap during training")
    parser.add_argument(
        "--no-overlap_mask",
        dest="overlap_mask",
        action="store_false",
        default=True,
        help="Masks should not overlap during training",
    )
    parser.add_argument("--mask_ratio", type=int, default=4, help="Mask downsample ratio")
    parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate")

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
# Keys of the trainer overrides dictionary, in the order they are emitted.
# All are copied straight from the parsed arguments except "cache" and "name",
# which _build_overrides computes.
_OVERRIDE_KEYS = (
    "model",
    "data",
    "epochs",
    "batch",
    "imgsz",
    "device",
    "optimizer",
    "lr0",
    "lrf",
    "momentum",
    "weight_decay",
    "box",
    "cls",
    "dfl",
    "hsv_h",
    "hsv_s",
    "hsv_v",
    "degrees",
    "translate",
    "scale",
    "shear",
    "perspective",
    "flipud",
    "fliplr",
    "mosaic",
    "mixup",
    "copy_paste",
    "patience",
    "save_period",
    "cache",
    "workers",
    "project",
    "name",
    "exist_ok",
    "pretrained",
    "amp",
    "fraction",
    "profile",
    "multi_scale",
    "val",
    "plots",
    "seed",
    "deterministic",
    "single_cls",
    "rect",
    "cos_lr",
    "close_mosaic",
    "label_smoothing",
    "nbs",
    "overlap_mask",
    "mask_ratio",
    "dropout",
)


def _build_overrides(args, run_name, exp_dir):
    """Translate parsed CLI arguments into the overrides dict handed to the trainer."""
    computed = {
        # An empty --cache value means "no caching", passed on as False.
        "cache": args.cache if args.cache else False,
        "name": run_name,
    }
    overrides = {
        key: computed[key] if key in computed else getattr(args, key)
        for key in _OVERRIDE_KEYS
    }

    # Optional entries are only added when the user supplied them.
    if args.freeze is not None:
        overrides["freeze"] = args.freeze
    if args.resume:
        overrides["resume"] = args.resume
    if exp_dir:
        overrides["save_dir"] = exp_dir
    return overrides


def _log_configuration(args, run_name, exp_dir):
    """Log a human-readable summary of the effective training configuration."""
    LOGGER.info(colorstr("bright_cyan", "\nTraining Configuration:"))
    LOGGER.info(f" Model: {args.model}")
    LOGGER.info(f" Data: {args.data}")
    LOGGER.info(f" Epochs: {args.epochs}")
    LOGGER.info(f" Global batch size: {args.batch}")
    LOGGER.info(f" Image size: {args.imgsz}")
    LOGGER.info(f" Device: {args.device if args.device else 'auto'}")
    LOGGER.info(f" Workers: {args.workers}")
    LOGGER.info(f" Cache: {args.cache if args.cache else 'False'}")
    LOGGER.info(f" Run name: {run_name}")
    if exp_dir:
        LOGGER.info(f" Experiment dir: {exp_dir}")
    LOGGER.info(f" AMP: {args.amp}")
    LOGGER.info(f" Optimizer: {args.optimizer}")
    LOGGER.info(f" Learning rate: {args.lr0} -> {args.lr0 * args.lrf}")
    LOGGER.info(f" Loss weights: box={args.box}, cls={args.cls}, dfl={args.dfl}")
    LOGGER.info("")


def main():
    """Parse the CLI, configure a GroundDetectionTrainer and run training.

    Raises:
        ValueError: If both --resume and --exp_dir are given.
    """
    args = parse_args()

    # Print banner
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Ground 2D Detection Training with YOLO26"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))

    # TensorBoard is switched on at module import via settings.update(...)
    LOGGER.info(colorstr("bright_green", "✓ TensorBoard logging enabled"))

    # Distributed launchers export these variables; the defaults describe a
    # plain `python` invocation.
    world_size = int(os.getenv("WORLD_SIZE", 1))
    rank = int(os.getenv("RANK", -1))
    local_rank = int(os.getenv("LOCAL_RANK", -1))

    if world_size > 1:
        LOGGER.info(f"Distributed training: world_size={world_size}, rank={rank}, local_rank={local_rank}")
    else:
        # WORLD_SIZE being unset says nothing about the device: it may be cpu,
        # one GPU, or a multi-GPU list, so report what was actually requested.
        LOGGER.info(f"Single-process launch (device={args.device if args.device else 'auto'})")

    if args.resume and args.exp_dir:
        raise ValueError("--exp_dir cannot be used with --resume because resume already determines the run directory.")

    # An explicit experiment directory also dictates the run name.
    exp_dir = str(Path(args.exp_dir).expanduser()) if args.exp_dir else ""
    run_name = Path(exp_dir).name if exp_dir else build_run_name(args.name, args.resume, world_size, rank)

    overrides = _build_overrides(args, run_name, exp_dir)
    _log_configuration(args, run_name, exp_dir)

    # Initialize trainer
    LOGGER.info(colorstr("bright_green", "Initializing GroundDetectionTrainer..."))
    trainer = GroundDetectionTrainer(overrides=overrides)

    # Start training
    LOGGER.info(colorstr("bright_green", "Starting training...\n"))
    trainer.train()

    # Print completion message
    LOGGER.info(colorstr("bright_blue", "bold", "\n" + "=" * 80))
    LOGGER.info(colorstr("bright_blue", "bold", "Training completed!"))
    LOGGER.info(colorstr("bright_blue", "bold", "=" * 80 + "\n"))

    # Print results location
    if hasattr(trainer, "save_dir"):
        weights_dir = trainer.save_dir / "weights"
        LOGGER.info(f"Results saved to: {trainer.save_dir}")
        LOGGER.info(f"Best weights: {weights_dir / 'best.pt'}")
        LOGGER.info(f"Last weights: {weights_dir / 'last.pt'}")
|
||||
|
||||
|
||||
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user