Files
yolov26_3d/docs/yolo26_training_flow.md
2026-06-24 09:35:46 +08:00

194 lines
5.1 KiB
Markdown
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# YOLO26 / Ground3D 训练流程
## Overview
本文描述当前仓库里训练主干的真实流程,并额外指出 Ground3D 分支相对通用检测训练的差异。对应代码主要在:
- `ultralytics/engine/trainer.py`
- `ultralytics/models/yolo/detect/train.py`
- `ultralytics/data/dataset.py`
- `ultralytics/utils/loss.py`
- `train_mono3d.py`
## 通用主干流程
### 1. Trainer 初始化
`DetectionTrainer` / `Ground3DDetectionTrainer` 最终都会先走 `BaseTrainer.__init__()`
1. `get_cfg()` 合并默认配置与 overrides
2. 处理 `resume`
3. 选择 device
4. 初始化随机种子
5. 创建 `save_dir` / `weights` 目录
6. 初始化 callbacks
7. 调用 `get_dataset()` 读取数据集元信息
### 2. 进入训练
`trainer.train()` 的逻辑是:
- 多卡且非已分布式启动时:生成 DDP 命令并拉起子进程
- 否则:直接进入 `_do_train()`
### 3. `_setup_train()`
训练正式开始前会做这些事:
1. `setup_model()`
2. `model.to(device)`
3. `set_model_attributes()`
4. `attempt_compile()`,若开启 `compile`
5. 冻结指定层
6. 检查 AMP
7. DDP 包装
8. `check_imgsz()`
9.`batch < 1` 且支持 AutoBatch则估计 batch size
10. `_build_train_pipeline()`
11. `get_validator()`
12. 初始化 EMA
13. 可选画 label 图
14. 初始化 early stopping
15. `resume_training()`
### 4. `_build_train_pipeline()`
当前代码里的真实行为是:
1. 构建 train dataloader
2. 构建 val dataloader
- 大多数任务用 `batch_size // 4`
- `obb` 例外,沿用 train batch size
3. 根据 `nbs` 计算梯度累积步数
4. `build_optimizer()`
- `optimizer=auto` 时,`iterations > 10000``MuSGD`,否则选 `AdamW`
5. `_setup_scheduler()`
- `cos_lr=True` 用 cosine
- 否则用 linear decay
### 5. 主训练循环
每个 epoch 的主流程:
1. `scheduler.step()`
2. `model.train()`
3. 若到 `close_mosaic` 阶段,调用 `_close_dataloader_mosaic()`
4. 遍历 train loader
5. warmup 阶段动态调 lr / momentum / accumulate
6. `preprocess_batch(batch)`
7. `preds = model(batch["img"])`
8. `loss, loss_items = model.loss(batch, preds)`
9. 反向传播
10. 满足累积步数后 `optimizer_step()`
11. 记录进度条和 callbacks
epoch 末尾会:
1. 跑 validation
2.`results.csv`
3. early stopping 检查
4. 保存 `last.pt` / `best.pt`
5. 到训练结束后做 `final_eval()`
## Ground3D 分支差异
### 1. 模型与 loss
Ground3D 使用的是:
- model: `Ground3DDetectionModel`
- trainer: `Ground3DDetectionTrainer`
- validator: `Ground3DDetectionValidator`
- loss:
- `end2end=True` 时为 `E2EGround3DLoss`
- 否则为 `v8Detection3DLoss`
这和通用检测的 `E2ELoss / v8DetectionLoss` 不是一回事。
### 2. 数据集
Ground3D 训练不是普通 `YOLODataset`,而是 `YOLOGround3DDataset`
它的关键特征:
- 样本来自 GT list 文件,不是直接扫图片目录
- `get_image_and_label()` 内部显式执行 ROI crop 或 virtual-camera 变换
- 训练时会携带:
- `labels_3d`
- `calib`
- `camera_mode`
- `difficulty_levels`
- 预计算 edge GT
### 3. 增强
Ground3D 当前不走标准几何增强。
`YOLOGround3DDataset.build_transforms()` 只追加:
1. `RandomHSV`
2. `Format`
因此:
- mosaic / mixup / random perspective / random flip / letterbox 不参与 Ground3D 训练增强
- `close_mosaic` 仍然存在于 trainer 主干里,但对 Ground3D 数据集本身基本没有实际几何增强可关闭
### 4. 预处理
`GroundDetectionTrainer.preprocess_batch()` / `Ground3DDetectionTrainer.preprocess_batch()` 的真实行为是:
- 先把 tensor 移到 device
- 图像做 `/ 256`
- 若启用 `multi_scale`,对 batch 图像做双线性 resize
注意这里不是 `/ 255`
### 5. batch size
Ground3D 显式禁用了 AutoBatch
- `Ground3DDetectionTrainer.auto_batch()` 直接报错
- 必须手动传 `--batch`
### 6. 2D / 3D / difficulty loss 的关系
Ground3D loss 由三块组成:
1. 2D detection loss
- box
- cls
- dfl
2. difficulty loss
3. 3D loss
其中:
- `camera_mode == "virtual"` 的样本仍参与 `cls`
- 但会被从 2D `box/dfl` 中 mask 掉
- `difficulty loss` 单独相加
- 3D 项受 `loss_3d_weight` 控制
所以 `loss_3d_weight=0` 的 warmup 阶段并不是“只训练 2D box/cls/dfl”而是“训练 2D + difficulty不训练 3D 项”。
### 7. 3D 权重调度
`Ground3DDetectionTrainer.preprocess_batch()` 会在每个 batch 前更新当前 loss 权重:
- `epoch < loss_3d_warmup_epochs``loss_3d_weight = 0`
- 之后在 `loss_3d_ramp_epochs` 内线性拉升
- 最大到 `loss_3d_weight_max`
### 8. validation
Ground3D validation 在通用 2D metrics 之外,还会:
1.`one2one` 头抓取 `preds_3d_selected` / `preds_edge_selected`
2.`batch["calib"]``depth_scale` 恢复预测深度
3. 计算 3D metrics
4.`roi_metrics_only=True`,可跳过 virtual-camera 样本,只统计 ROI 样本
## 一句话总结
当前仓库的训练主干仍然是标准 `BaseTrainer -> _setup_train -> _do_train` 框架,但 Ground3D 在数据集、loss、validation 和 ROI/virtual-camera 预处理上都已经是独立分支,不能再用“普通 YOLO detect 训练流”去概括。