Files
yolov26_3d/tools/model_merging/export_single_roi_yolo26.py

409 lines
17 KiB
Python
Raw Normal View History

2026-06-24 09:35:46 +08:00
from __future__ import annotations
import argparse
import json
import sys
from copy import deepcopy
from pathlib import Path
from types import SimpleNamespace
from typing import Any
import torch
import torch.nn as nn
# Absolute, symlink-resolved path of this script.
FILE = Path(__file__).resolve()
# Third ancestor directory of this file (parents[2]); the default checkpoint
# and export locations defined below are built relative to it.
ROOT = FILE.parents[2]
# Append ROOT to the import search path once.
# NOTE(review): presumably needed so the `ultralytics` imports that follow
# resolve when the script is run directly — confirm against the repo layout.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from ultralytics.nn.modules import Detect, RTDETRDecoder
from ultralytics.nn.tasks import load_checkpoint
from ultralytics.utils import LOGGER
from ultralytics.utils.patches import arange_patch, onnx_export_patch
# Checkpoint used when --model-path is not given on the command line.
DEFAULT_MODEL = ROOT / "runs" / "detect" / "train_mono3d_roi0_202603291330_epoch61.pt"
# Output directory used when --save-dir is not given on the command line.
DEFAULT_SAVE_DIR = ROOT / "runs" / "export" / "train_mono3d_single_roi0_202603291330"
# Default example input size as (width, height); overridden by --imgsz W H.
DEFAULT_IMGSZ_WH = (768, 352)
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace. The four output-mode switches
        (``--detections-only``, ``--hybrid-outputs``,
        ``--denorm-branch-outputs``, ``--postprocessed-outputs``) live in one
        mutually exclusive group, so at most one of them can be set.
    """
    cli = argparse.ArgumentParser(description="Export one yolo26 single-ROI Detect3D checkpoint.")

    # Locations and example input geometry.
    cli.add_argument("--model-path", type=str, default=str(DEFAULT_MODEL), help="Path to ROI checkpoint")
    cli.add_argument(
        "--save-dir", type=str, default=str(DEFAULT_SAVE_DIR), help="Directory used to store exported files"
    )
    cli.add_argument(
        "--imgsz",
        nargs=2,
        type=int,
        default=list(DEFAULT_IMGSZ_WH),
        metavar=("W", "H"),
        help="Example input size as width height, e.g. --imgsz 768 352",
    )

    # Export tuning knobs.
    cli.add_argument("--opset", type=int, default=15, help="ONNX opset version")
    cli.add_argument("--max-det", type=int, default=300, help="Top-k detections kept by the Detect3D head")
    cli.add_argument("--dynamic", action="store_true", help="Export with dynamic batch dimension")
    cli.add_argument(
        "--simplify", action="store_true", help="Try to simplify the exported ONNX graph with onnxslim"
    )

    # Output-mode switches; registration order is kept so --help is unchanged.
    mode_switches = (
        ("--detections-only", "Export only postprocessed 2D detections."),
        (
            "--hybrid-outputs",
            "Export postprocessed 2D detections together with raw 2D/3D/edge head outputs.",
        ),
        (
            "--denorm-branch-outputs",
            "Export one2one branch outputs after Detect3D branch denormalization, but before top-k postprocessing.",
        ),
        ("--postprocessed-outputs", "Export postprocessed selected outputs instead of raw head tensors."),
    )
    mode_group = cli.add_mutually_exclusive_group()
    for switch, description in mode_switches:
        mode_group.add_argument(switch, action="store_true", help=description)

    cli.add_argument(
        "--include-anchor-stride",
        action="store_true",
        help="Also export anchors and strides. Only used with --postprocessed-outputs.",
    )

    # Switches that disable individual export steps.
    opt_out_switches = (
        ("--skip-torchscript", "Skip TorchScript export"),
        ("--skip-onnx", "Skip ONNX export"),
        ("--no-fuse", "Disable Conv-BN fusion before export"),
    )
    for switch, description in opt_out_switches:
        cli.add_argument(switch, action="store_true", help=description)

    return cli.parse_args()
def _make_example_input(imgsz_wh: tuple[int, int]) -> torch.Tensor:
width, height = imgsz_wh
return torch.zeros(1, 3, height, width, dtype=torch.float32)
def _prepare_model_for_export(model: nn.Module, max_det: int, dynamic: bool, fuse: bool) -> nn.Module:
    """Return a frozen CPU/float32 copy of ``model`` configured for export.

    The input model is deep-copied, so the caller's instance is not modified.

    Args:
        model: Loaded checkpoint model.
        max_det: Value written to ``max_det`` on detection heads that have it.
        dynamic: Value written to ``dynamic`` on detection heads that have it.
        fuse: When true and the model exposes ``fuse()``, call it and use its
            return value.

    Returns:
        The prepared copy, in eval mode with gradients disabled.
    """
    export_model = deepcopy(model).cpu()
    for param in export_model.parameters():
        param.requires_grad_(False)
    export_model.eval()
    export_model.float()

    if fuse and hasattr(export_model, "fuse"):
        export_model = export_model.fuse()

    head_types = (Detect, RTDETRDecoder)
    for layer in export_model.modules():
        if not isinstance(layer, head_types):
            continue
        # Only attributes that already exist on the head are overridden.
        if hasattr(layer, "dynamic"):
            layer.dynamic = dynamic
        if hasattr(layer, "max_det"):
            layer.max_det = int(max_det)
        if hasattr(layer, "shape"):
            # NOTE(review): presumably clears a cached input shape so it is
            # recomputed on the next forward — confirm against the head code.
            layer.shape = None
    return export_model
def _export_mode(args: argparse.Namespace) -> str:
if args.postprocessed_outputs:
return "postprocessed_outputs"
if args.detections_only:
return "detections_only"
if args.hybrid_outputs:
return "hybrid_outputs"
if args.denorm_branch_outputs:
return "denorm_branch_outputs"
return "raw_head_outputs"
def _forward_model_to_head_inputs(model: nn.Module, x: torch.Tensor) -> tuple[nn.Module, Any]:
y = []
for module in model.model[:-1]:
if module.f != -1:
x = y[module.f] if isinstance(module.f, int) else [x if j == -1 else y[j] for j in module.f]
x = module(x)
y.append(x if module.i in model.save else None)
head = model.model[-1]
if head.f != -1:
x = y[head.f] if isinstance(head.f, int) else [x if j == -1 else y[j] for j in head.f]
return head, x
def _flatten_detect3d_outputs(outputs: Any, include_anchor_stride: bool = False) -> tuple[torch.Tensor, ...]:
if not isinstance(outputs, (tuple, list)) or len(outputs) < 2:
raise RuntimeError("Model forward output must be `(detections, raw_preds)` for Detect3D models.")
detections, raw_preds = outputs[0], outputs[1]
if not isinstance(raw_preds, dict):
raise RuntimeError("Raw prediction payload is missing.")
one2one = raw_preds.get("one2one", raw_preds)
if not isinstance(one2one, dict):
raise RuntimeError("one2one prediction payload is missing.")
required_keys = ["preds_3d_selected", "preds_edge_selected"]
if include_anchor_stride:
required_keys += ["anchors_selected", "strides_selected"]
missing = [key for key in required_keys if one2one.get(key) is None]
if missing:
available = sorted(one2one.keys())
raise RuntimeError(f"Missing Detect3D export tensors {missing}. Available keys: {available}.")
result = (detections, one2one["preds_3d_selected"], one2one["preds_edge_selected"])
if include_anchor_stride:
result = result + (one2one["anchors_selected"], one2one["strides_selected"])
return result
def _denorm_detect3d_branch_outputs(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Return the head's branch outputs produced by ``head.forward_head``.

    For an end2end head the ``one2one`` branch is run on detached head inputs;
    otherwise the ``one2many`` branch is run on the inputs as-is. The branch
    dict is expanded into keyword arguments of ``forward_head``.

    Returns:
        ``(boxes, scores, preds_3d, preds_edge)`` taken from the dict returned
        by ``forward_head``.

    Raises:
        RuntimeError: If ``forward_head`` does not return a dict, or any of
            the four entries is absent or ``None``.
    """
    head, feats = _forward_model_to_head_inputs(model, x)

    if getattr(head, "end2end", False):
        branch_inputs = [feat.detach() for feat in feats]
        branch = head.one2one
    else:
        branch_inputs = feats
        branch = head.one2many

    raw_preds = head.forward_head(branch_inputs, **branch)
    if not isinstance(raw_preds, dict):
        raise RuntimeError("Raw branch payload is missing.")

    wanted = ("boxes", "scores", "preds_3d", "preds_edge")
    missing = [name for name in wanted if raw_preds.get(name) is None]
    if missing:
        available = sorted(raw_preds.keys())
        raise RuntimeError(f"Raw branch tensors {missing} are missing. Available keys: {available}.")

    return tuple(raw_preds[name] for name in wanted)
def _collect_raw_detect3d_head_outputs(head: nn.Module, head_inputs: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
if not getattr(head, "end2end", False):
raise RuntimeError("Raw-head export currently expects an end2end Detect3D head.")
branch = head.one2one
box_head = branch.get("box_head")
cls_head = branch.get("cls_head")
head_3d = branch.get("head_3d")
edge_head = branch.get("edge_head")
if any(module is None for module in (box_head, cls_head, head_3d, edge_head)):
raise RuntimeError("Raw-head export could not find complete one2one head modules.")
bs = head_inputs[0].shape[0]
raw_boxes = torch.cat([box_head[i](head_inputs[i]).view(bs, 4 * head.reg_max, -1) for i in range(head.nl)], dim=-1)
raw_scores = torch.cat([cls_head[i](head_inputs[i]).view(bs, head.nc, -1) for i in range(head.nl)], dim=-1)
raw_3d = torch.cat([head_3d[i](head_inputs[i]).view(bs, head.no_3d, -1) for i in range(head.nl)], dim=-1)
raw_edge = torch.cat([edge_head[i](head_inputs[i]).view(bs, head.no_edge, -1) for i in range(head.nl)], dim=-1)
return raw_boxes, raw_scores, raw_3d, raw_edge
def _raw_detect3d_head_outputs(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Run the model up to its head and return the raw one2one sub-head outputs.

    Returns:
        ``(raw_boxes, raw_scores, raw_3d, raw_edge)`` as produced by
        ``_collect_raw_detect3d_head_outputs``.
    """
    head_and_inputs = _forward_model_to_head_inputs(model, x)
    return _collect_raw_detect3d_head_outputs(*head_and_inputs)
def _hybrid_detect2d_raw3d_outputs(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Return postprocessed detections together with the raw sub-head outputs.

    The raw boxes and scores are passed, along with the head inputs, through
    ``head._inference``; the result is permuted with ``(0, 2, 1)`` and handed
    to ``head.postprocess`` to obtain the detections.

    Returns:
        ``(detections, raw_boxes, raw_scores, raw_3d, raw_edge)``.
    """
    head, feats = _forward_model_to_head_inputs(model, x)
    raw_boxes, raw_scores, raw_3d, raw_edge = _collect_raw_detect3d_head_outputs(head, feats)

    decoded = head._inference({"boxes": raw_boxes, "scores": raw_scores, "feats": feats})
    detections = head.postprocess(decoded.permute(0, 2, 1))
    return detections, raw_boxes, raw_scores, raw_3d, raw_edge
def _detections_only_outputs(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor]:
outputs = model(x)
if not isinstance(outputs, (tuple, list)) or len(outputs) < 1:
raise RuntimeError("Model forward output must contain detections for detections-only export.")
return (outputs[0],)
class SingleROIExportWrapper(nn.Module):
    """Wrap a prepared model so its forward returns a flat tuple of tensors.

    The output layout is chosen by the boolean mode flags given at
    construction. They are checked in this order: ``postprocessed_outputs``,
    ``detections_only``, ``hybrid_outputs``, ``denorm_branch_outputs``; when
    none is set, the raw one2one head outputs are returned.
    """

    def __init__(
        self,
        model: nn.Module,
        include_anchor_stride: bool = False,
        detections_only: bool = False,
        hybrid_outputs: bool = False,
        denorm_branch_outputs: bool = False,
        postprocessed_outputs: bool = False,
    ):
        """Store the wrapped model and the output-mode flags.

        Args:
            model: Model to wrap.
            include_anchor_stride: Append anchors and strides; only read in
                the ``postprocessed_outputs`` mode.
            detections_only: Return only the detections.
            hybrid_outputs: Return detections plus raw head outputs.
            denorm_branch_outputs: Return the branch outputs of the head.
            postprocessed_outputs: Return the flattened postprocessed outputs.
        """
        super().__init__()
        self.model = model
        # Mode flags are plain Python bools, so the branch taken in forward()
        # is decided by these attributes rather than by tensor values.
        self.postprocessed_outputs = postprocessed_outputs
        self.detections_only = detections_only
        self.hybrid_outputs = hybrid_outputs
        self.denorm_branch_outputs = denorm_branch_outputs
        self.include_anchor_stride = include_anchor_stride

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Run the wrapped model and return the outputs for the selected mode."""
        if self.postprocessed_outputs:
            outputs = _flatten_detect3d_outputs(self.model(x), include_anchor_stride=self.include_anchor_stride)
        elif self.detections_only:
            outputs = _detections_only_outputs(self.model, x)
        elif self.hybrid_outputs:
            outputs = _hybrid_detect2d_raw3d_outputs(self.model, x)
        elif self.denorm_branch_outputs:
            outputs = _denorm_detect3d_branch_outputs(self.model, x)
        else:
            outputs = _raw_detect3d_head_outputs(self.model, x)
        return outputs
def _output_names(
include_anchor_stride: bool = False,
detections_only: bool = False,
hybrid_outputs: bool = False,
denorm_branch_outputs: bool = False,
postprocessed_outputs: bool = False,
) -> list[str]:
if postprocessed_outputs:
names = ["detections", "preds_3d", "preds_edge"]
if include_anchor_stride:
names += ["anchors", "strides"]
return names
if detections_only:
return ["detections"]
if hybrid_outputs:
return ["detections", "boxes_head_raw", "scores_head_raw", "preds_3d_head_raw", "preds_edge_head_raw"]
if denorm_branch_outputs:
return ["boxes_raw", "scores_raw", "preds_3d_raw", "preds_edge_raw"]
return ["boxes_head_raw", "scores_head_raw", "preds_3d_head_raw", "preds_edge_head_raw"]
def _describe_outputs(outputs: tuple[torch.Tensor, ...]) -> list[list[int]]:
return [list(output.shape) for output in outputs]
def _dynamic_axes(input_names: list[str], output_names: list[str], enabled: bool) -> dict[str, dict[int, str]] | None:
if not enabled:
return None
dynamic_axes = {name: {0: "batch"} for name in input_names}
dynamic_axes.update({name: {0: "batch"} for name in output_names})
return dynamic_axes
def _write_manifest(path: Path, payload: dict[str, Any]) -> None:
path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
def _base_manifest(args: argparse.Namespace, output_names: list[str], output_shapes: list[list[int]]) -> dict[str, Any]:
    """Assemble the export manifest from the parsed arguments and output info.

    Args:
        args: Parsed command-line arguments.
        output_names: Names of the exported outputs.
        output_shapes: Shapes observed for those outputs.

    Returns:
        A dict describing the export settings. The same manifest is written
        to disk and embedded in the exported artifacts by the callers.
    """
    width, height = int(args.imgsz[0]), int(args.imgsz[1])
    manifest: dict[str, Any] = {
        "artifact_name": "single_roi_model",
        "model_path": str(Path(args.model_path).resolve()),
        "input_names": ["images"],
        "input_sizes_wh": {"images": [width, height]},
        "output_names": output_names,
        "output_shapes": output_shapes,
        "dynamic_batch": bool(args.dynamic),
        "max_det": int(args.max_det),
        "opset": int(args.opset),
        # The CLI flag is a negative switch; the manifest stores the positive.
        "fuse": not bool(args.no_fuse),
    }

    # Mode flags are copied through under their own names, in this order.
    mode_flags = (
        "include_anchor_stride",
        "detections_only",
        "hybrid_outputs",
        "denorm_branch_outputs",
        "postprocessed_outputs",
    )
    for flag in mode_flags:
        manifest[flag] = bool(getattr(args, flag))

    manifest["export_mode"] = _export_mode(args)
    return manifest
def _export_torchscript(model: nn.Module, example_input: torch.Tensor, save_path: Path, manifest: dict[str, Any]) -> None:
    """Trace ``model`` with ``example_input`` and save it as TorchScript.

    The manifest is embedded in the saved archive as the extra file
    ``config.txt`` (JSON with sorted keys).
    """
    LOGGER.info(f"Exporting TorchScript to {save_path}")

    # strict=False is passed through to torch.jit.trace unchanged.
    traced = torch.jit.trace(model, example_input, strict=False)
    extra_files = {"config.txt": json.dumps(manifest, sort_keys=True)}
    traced.save(str(save_path), _extra_files=extra_files)

    size_mb = save_path.stat().st_size / (1024 * 1024)
    LOGGER.info(f"TorchScript export success: {save_path} ({size_mb:.2f} MB)")
def _save_onnx_metadata(onnx_path: Path, manifest: dict[str, Any], simplify: bool) -> None:
    """Attach the manifest to an ONNX file as metadata and optionally simplify it.

    The file at ``onnx_path`` is loaded, one metadata property is added per
    manifest entry, the graph is optionally passed through ``onnxslim.slim``,
    and the result is saved back to the same path.

    Args:
        onnx_path: Existing ONNX file, overwritten in place.
        manifest: Entries to store. Dict and list values are JSON-encoded;
            everything else is converted with ``str``.
        simplify: Run onnxslim after the metadata has been added.

    Raises:
        ImportError: If ``onnx`` (or ``onnxslim`` when ``simplify`` is set)
            is not installed; nothing is saved in that case.
    """
    # Imported lazily so the module can be used without onnx installed.
    import onnx

    graph = onnx.load(str(onnx_path))
    for key, value in manifest.items():
        if isinstance(value, (dict, list)):
            text = json.dumps(value, ensure_ascii=True)
        else:
            text = str(value)
        entry = graph.metadata_props.add()
        entry.key = key
        entry.value = text

    if simplify:
        import onnxslim

        LOGGER.info("Simplifying ONNX graph with onnxslim")
        graph = onnxslim.slim(graph)

    onnx.save(graph, str(onnx_path))
def _export_onnx(
    model: nn.Module,
    example_input: torch.Tensor,
    save_path: Path,
    output_names: list[str],
    manifest: dict[str, Any],
    opset: int,
    dynamic: bool,
    simplify: bool,
) -> None:
    """Export ``model`` to ONNX and then try to add metadata / simplify it.

    The export itself runs inside the ``onnx_export_patch`` and
    ``arange_patch`` context managers. The follow-up metadata/simplify step is
    best effort: any failure there is logged as a warning and the exported
    file is kept as written by ``torch.onnx.export``.

    Args:
        model: Module to export.
        example_input: Example tensor passed to the exporter.
        save_path: Destination ``.onnx`` file.
        output_names: Names assigned to the graph outputs.
        manifest: Metadata to embed in the file.
        opset: ONNX opset version.
        dynamic: Mark axis 0 of the input and all outputs as dynamic.
        simplify: Run onnxslim on the exported graph.
    """
    LOGGER.info(f"Exporting ONNX to {save_path} with opset={opset}")

    input_names = ["images"]
    patch_args = SimpleNamespace(dynamic=dynamic, half=False, format="onnx")
    with onnx_export_patch(), arange_patch(patch_args):
        torch.onnx.export(
            model,
            example_input,
            str(save_path),
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=_dynamic_axes(["images"], output_names, dynamic),
        )

    # Post-processing must never invalidate a successful export, hence the
    # broad handler after the specific missing-package case.
    try:
        _save_onnx_metadata(save_path, manifest, simplify=simplify)
    except ImportError as error:
        LOGGER.warning(f"Skipping ONNX metadata/simplify step because a package is missing: {error}")
    except Exception as error:
        LOGGER.warning(f"ONNX post-processing failed: {error}")

    size_mb = save_path.stat().st_size / (1024 * 1024)
    LOGGER.info(f"ONNX export success: {save_path} ({size_mb:.2f} MB)")
def main() -> None:
    """Command-line entry point: load, prepare, dry-run and export one checkpoint.

    Steps: parse arguments, load the checkpoint on CPU, prepare an export
    copy, wrap it for the selected output mode, run one forward pass to record
    output shapes, write the JSON manifest, then export TorchScript and ONNX
    unless the corresponding skip flag is set.
    """
    args = parse_args()
    if args.include_anchor_stride and not args.postprocessed_outputs:
        LOGGER.warning("--include-anchor-stride is ignored unless --postprocessed-outputs is also set.")
    LOGGER.info(f"Export mode: {_export_mode(args)}")

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # The same mode flags drive both the wrapper and the output names.
    mode_flags = {
        "include_anchor_stride": args.include_anchor_stride,
        "detections_only": args.detections_only,
        "hybrid_outputs": args.hybrid_outputs,
        "denorm_branch_outputs": args.denorm_branch_outputs,
        "postprocessed_outputs": args.postprocessed_outputs,
    }

    # Fusion is deferred to _prepare_model_for_export, so load unfused here.
    model, _ = load_checkpoint(args.model_path, device="cpu", fuse=False)
    prepared_model = _prepare_model_for_export(
        model, max_det=args.max_det, dynamic=args.dynamic, fuse=not args.no_fuse
    )
    export_model = SingleROIExportWrapper(prepared_model, **mode_flags).eval()

    example_input = _make_example_input((int(args.imgsz[0]), int(args.imgsz[1])))
    with torch.inference_mode():
        outputs = export_model(example_input)

    output_names = _output_names(**mode_flags)
    output_shapes = _describe_outputs(outputs)
    LOGGER.info(f"Single ROI model output shapes: {output_shapes}")

    manifest = _base_manifest(args=args, output_names=output_names, output_shapes=output_shapes)
    _write_manifest(save_dir / "single_roi_model.export.json", manifest)

    if not args.skip_torchscript:
        _export_torchscript(export_model, example_input, save_dir / "single_roi_model.torchscript", manifest)

    if args.skip_onnx:
        return
    _export_onnx(
        export_model,
        example_input,
        save_dir / "single_roi_model.onnx",
        output_names=output_names,
        manifest=manifest,
        opset=args.opset,
        dynamic=args.dynamic,
        simplify=args.simplify,
    )
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()