单目3D初始代码

This commit is contained in:
zhao.zhu
2026-06-24 09:35:46 +08:00
commit 04a5895b6b
1153 changed files with 340700 additions and 0 deletions

265
train_mono2d.py Executable file
View File

@@ -0,0 +1,265 @@
#!/usr/bin/env python3
"""
Ground 2D Detection Training Script for YOLO26
Supports single-GPU and multi-GPU distributed training with difficulty-based loss weighting.
Usage:
# Single GPU
python train_mono.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100 --device 0
# Multi-GPU (DDP)
python -m torch.distributed.run --nproc_per_node 4 train_mono.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100
# Multi-GPU with specific devices
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node 4 train_mono.py --model yolo26n.pt --data ultralytics/cfg/datasets/mono2d_ground.yaml --epochs 100
"""
import argparse
import os
import sys
from pathlib import Path
# Add ultralytics to path if needed
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
from ultralytics import settings
from ultralytics.models.yolo.detect import GroundDetectionTrainer
from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.dist import get_distributed_run_timestamp
# Enable TensorBoard logging
settings.update({'tensorboard': True})
def build_run_name(base_name: str, resume: str, world_size: int, rank: int) -> str:
"""Build a timestamped run name that stays consistent across distributed ranks."""
if resume:
resume_path = Path(resume)
if resume_path.suffix == ".pt" and len(resume_path.parents) >= 2:
return resume_path.parents[1].name
return base_name
timestamp = get_distributed_run_timestamp(base_name, world_size, rank)
return f"{base_name}_{timestamp}"
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Train YOLO26 on Ground 2D Detection Dataset")
# Model and data
parser.add_argument("--model", type=str, default="yolo26s.pt", help="Model path or name (yolo26n.pt, yolo26s.pt, etc.)")
parser.add_argument("--data", type=str, default="ultralytics/cfg/datasets/mono2d_ground.yaml", help="Dataset YAML path")
# Training hyperparameters
parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
parser.add_argument("--batch", type=int, default=16, help="Global batch size across all ranks")
parser.add_argument("--imgsz", type=int, default=704, help="Image size")
parser.add_argument("--device", type=str, default="0,1,2,3", help="Device to use (e.g., 0 or 0,1,2,3 or cpu)")
# Optimizer
parser.add_argument("--optimizer", type=str, default="auto", choices=["SGD", "Adam", "AdamW", "auto"], help="Optimizer")
parser.add_argument("--lr0", type=float, default=0.01, help="Initial learning rate")
parser.add_argument("--lrf", type=float, default=0.01, help="Final learning rate factor")
parser.add_argument("--momentum", type=float, default=0.937, help="SGD momentum/Adam beta1")
parser.add_argument("--weight_decay", type=float, default=0.0005, help="Optimizer weight decay")
# Loss weights
parser.add_argument("--box", type=float, default=7.5, help="Box loss gain")
parser.add_argument("--cls", type=float, default=0.5, help="Class loss gain")
parser.add_argument("--dfl", type=float, default=1.5, help="DFL loss gain")
# Augmentation
parser.add_argument("--hsv_h", type=float, default=0.015, help="HSV-Hue augmentation")
parser.add_argument("--hsv_s", type=float, default=0.7, help="HSV-Saturation augmentation")
parser.add_argument("--hsv_v", type=float, default=0.4, help="HSV-Value augmentation")
parser.add_argument("--degrees", type=float, default=0.0, help="Rotation augmentation (degrees)")
parser.add_argument("--translate", type=float, default=0.1, help="Translation augmentation")
parser.add_argument("--scale", type=float, default=0.5, help="Scale augmentation")
parser.add_argument("--shear", type=float, default=0.0, help="Shear augmentation (degrees)")
parser.add_argument("--perspective", type=float, default=0.0, help="Perspective augmentation")
parser.add_argument("--flipud", type=float, default=0.0, help="Vertical flip probability")
parser.add_argument("--fliplr", type=float, default=0.5, help="Horizontal flip probability")
parser.add_argument("--mosaic", type=float, default=1.0, help="Mosaic augmentation probability")
parser.add_argument("--mixup", type=float, default=0.0, help="Mixup augmentation probability")
parser.add_argument("--copy_paste", type=float, default=0.0, help="Copy-paste augmentation probability")
# Training settings
parser.add_argument("--patience", type=int, default=100, help="Early stopping patience")
parser.add_argument("--save_period", type=int, default=1, help="Save checkpoint every x epochs (-1 to disable)")
parser.add_argument("--cache", type=str, default="", choices=["", "ram", "disk"], help="Cache images")
parser.add_argument("--workers", type=int, default=16, help="Number of dataloader workers")
parser.add_argument("--project", type=str, default=None, help="Project directory")
parser.add_argument("--name", type=str, default="train_mono", help="Experiment name")
parser.add_argument(
"--exp_dir",
type=str,
default="",
help="Exact experiment directory to use. Bypasses distributed run-name timestamp sync.",
)
parser.add_argument("--exist_ok", action="store_true", help="Overwrite existing experiment")
parser.add_argument("--pretrained", action="store_true", help="Use pretrained model")
parser.add_argument("--resume", type=str, default="", help="Resume training from checkpoint")
parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
parser.add_argument("--fraction", type=float, default=1.0, help="Dataset fraction to use")
parser.add_argument("--profile", action="store_true", help="Profile ONNX and TensorRT speeds")
parser.add_argument("--freeze", type=int, default=None, help="Freeze first n layers")
parser.add_argument("--multi_scale", type=float, default=0.0, help="Multi-scale training range")
# Validation
parser.add_argument("--val", action="store_true", default=True, help="Validate during training")
parser.add_argument("--plots", action="store_true", default=True, help="Save plots during training")
# Advanced
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--deterministic", action="store_true", help="Use deterministic algorithms")
parser.add_argument("--single_cls", action="store_true", help="Train as single-class dataset")
parser.add_argument("--rect", action="store_true", help="Rectangular training")
parser.add_argument("--cos_lr", action="store_true", help="Use cosine learning rate scheduler")
parser.add_argument("--close_mosaic", type=int, default=10, help="Disable mosaic augmentation for final epochs")
parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing epsilon")
parser.add_argument("--nbs", type=int, default=64, help="Nominal batch size")
parser.add_argument("--overlap_mask", action="store_true", default=True, help="Masks should overlap during training")
parser.add_argument("--mask_ratio", type=int, default=4, help="Mask downsample ratio")
parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate")
return parser.parse_args()
def main():
"""Main training function."""
args = parse_args()
# Print banner
LOGGER.info(colorstr("bright_blue", "bold", "\n" + "="*80))
LOGGER.info(colorstr("bright_blue", "bold", "Ground 2D Detection Training with YOLO26"))
LOGGER.info(colorstr("bright_blue", "bold", "="*80 + "\n"))
# Confirm TensorBoard is enabled
LOGGER.info(colorstr("bright_green", "✓ TensorBoard logging enabled"))
# Check for distributed training
world_size = int(os.getenv("WORLD_SIZE", 1))
rank = int(os.getenv("RANK", -1))
local_rank = int(os.getenv("LOCAL_RANK", -1))
if world_size > 1:
LOGGER.info(f"Distributed training: world_size={world_size}, rank={rank}, local_rank={local_rank}")
else:
LOGGER.info("Single GPU training")
if args.resume and args.exp_dir:
raise ValueError("--exp_dir cannot be used with --resume because resume already determines the run directory.")
exp_dir = str(Path(args.exp_dir).expanduser()) if args.exp_dir else ""
run_name = Path(exp_dir).name if exp_dir else build_run_name(args.name, args.resume, world_size, rank)
# Build overrides dictionary
overrides = {
"model": args.model,
"data": args.data,
"epochs": args.epochs,
"batch": args.batch,
"imgsz": args.imgsz,
"device": args.device,
"optimizer": args.optimizer,
"lr0": args.lr0,
"lrf": args.lrf,
"momentum": args.momentum,
"weight_decay": args.weight_decay,
"box": args.box,
"cls": args.cls,
"dfl": args.dfl,
"hsv_h": args.hsv_h,
"hsv_s": args.hsv_s,
"hsv_v": args.hsv_v,
"degrees": args.degrees,
"translate": args.translate,
"scale": args.scale,
"shear": args.shear,
"perspective": args.perspective,
"flipud": args.flipud,
"fliplr": args.fliplr,
"mosaic": args.mosaic,
"mixup": args.mixup,
"copy_paste": args.copy_paste,
"patience": args.patience,
"save_period": args.save_period,
"cache": args.cache if args.cache else False,
"workers": args.workers,
"project": args.project,
"name": run_name,
"exist_ok": args.exist_ok,
"pretrained": args.pretrained,
"amp": args.amp,
"fraction": args.fraction,
"profile": args.profile,
"multi_scale": args.multi_scale,
"val": args.val,
"plots": args.plots,
"seed": args.seed,
"deterministic": args.deterministic,
"single_cls": args.single_cls,
"rect": args.rect,
"cos_lr": args.cos_lr,
"close_mosaic": args.close_mosaic,
"label_smoothing": args.label_smoothing,
"nbs": args.nbs,
"overlap_mask": args.overlap_mask,
"mask_ratio": args.mask_ratio,
"dropout": args.dropout,
}
# Add freeze if specified
if args.freeze is not None:
overrides["freeze"] = args.freeze
# Add resume if specified
if args.resume:
overrides["resume"] = args.resume
if exp_dir:
overrides["save_dir"] = exp_dir
# Print configuration
LOGGER.info(colorstr("bright_cyan", "\nTraining Configuration:"))
LOGGER.info(f" Model: {args.model}")
LOGGER.info(f" Data: {args.data}")
LOGGER.info(f" Epochs: {args.epochs}")
LOGGER.info(f" Global batch size: {args.batch}")
LOGGER.info(f" Image size: {args.imgsz}")
LOGGER.info(f" Device: {args.device if args.device else 'auto'}")
LOGGER.info(f" Workers: {args.workers}")
LOGGER.info(f" Cache: {args.cache if args.cache else 'False'}")
LOGGER.info(f" Run name: {run_name}")
if exp_dir:
LOGGER.info(f" Experiment dir: {exp_dir}")
LOGGER.info(f" AMP: {args.amp}")
LOGGER.info(f" Optimizer: {args.optimizer}")
LOGGER.info(f" Learning rate: {args.lr0} -> {args.lr0 * args.lrf}")
LOGGER.info(f" Loss weights: box={args.box}, cls={args.cls}, dfl={args.dfl}")
LOGGER.info("")
# Initialize trainer
LOGGER.info(colorstr("bright_green", "Initializing GroundDetectionTrainer..."))
trainer = GroundDetectionTrainer(overrides=overrides)
# Start training
LOGGER.info(colorstr("bright_green", "Starting training...\n"))
trainer.train()
# Print completion message
LOGGER.info(colorstr("bright_blue", "bold", "\n" + "="*80))
LOGGER.info(colorstr("bright_blue", "bold", "Training completed!"))
LOGGER.info(colorstr("bright_blue", "bold", "="*80 + "\n"))
# Print results location
if hasattr(trainer, "save_dir"):
LOGGER.info(f"Results saved to: {trainer.save_dir}")
LOGGER.info(f"Best weights: {trainer.save_dir / 'weights' / 'best.pt'}")
LOGGER.info(f"Last weights: {trainer.save_dir / 'weights' / 'last.pt'}")
if __name__ == "__main__":
main()