409 lines
17 KiB
Python
Executable File
409 lines
17 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# FILE: absolute path of this script. ROOT: the directory two levels above the script's
# folder (presumably the repository root that contains the local `ultralytics` package).
# ROOT is appended to sys.path, if not already present, so the `ultralytics` imports below
# resolve when the script is run directly.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.append(str(ROOT))
|
|
|
|
from ultralytics.nn.modules import Detect, RTDETRDecoder
|
|
from ultralytics.nn.tasks import load_checkpoint
|
|
from ultralytics.utils import LOGGER
|
|
from ultralytics.utils.patches import arange_patch, onnx_export_patch
|
|
|
|
|
|
# CLI defaults: checkpoint to export, directory for the exported artifacts (both under ROOT),
# and the example input size as (width, height).
DEFAULT_MODEL = ROOT.joinpath("runs", "detect", "train_mono3d_roi0_202603291330_epoch61.pt")
DEFAULT_SAVE_DIR = ROOT.joinpath("runs", "export", "train_mono3d_single_roi0_202603291330")
DEFAULT_IMGSZ_WH = (768, 352)
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the single-ROI Detect3D export script.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as before. Pass an explicit list to drive the parser from code,
            e.g. from tests or another script.

    Returns:
        The parsed namespace. ``imgsz`` is a two-element ``[W, H]`` list. At most one of the
        four output-mode flags can be set, because they share a mutually exclusive group.

    Exits through ``parser.error`` (status 2) when ``--imgsz`` or ``--max-det`` is not
    positive. This rejects bad values before the checkpoint is loaded instead of failing
    later during export.
    """
    parser = argparse.ArgumentParser(description="Export one yolo26 single-ROI Detect3D checkpoint.")
    parser.add_argument("--model-path", type=str, default=str(DEFAULT_MODEL), help="Path to ROI checkpoint")
    parser.add_argument("--save-dir", type=str, default=str(DEFAULT_SAVE_DIR), help="Directory used to store exported files")
    parser.add_argument(
        "--imgsz",
        nargs=2,
        type=int,
        default=list(DEFAULT_IMGSZ_WH),
        metavar=("W", "H"),
        help="Example input size as width height, e.g. --imgsz 768 352",
    )
    parser.add_argument("--opset", type=int, default=15, help="ONNX opset version")
    parser.add_argument("--max-det", type=int, default=300, help="Top-k detections kept by the Detect3D head")
    parser.add_argument("--dynamic", action="store_true", help="Export with dynamic batch dimension")
    parser.add_argument("--simplify", action="store_true", help="Try to simplify the exported ONNX graph with onnxslim")

    # Output modes are mutually exclusive. With none set, raw head tensors are exported.
    export_group = parser.add_mutually_exclusive_group()
    export_group.add_argument(
        "--detections-only",
        action="store_true",
        help="Export only postprocessed 2D detections.",
    )
    export_group.add_argument(
        "--hybrid-outputs",
        action="store_true",
        help="Export postprocessed 2D detections together with raw 2D/3D/edge head outputs.",
    )
    export_group.add_argument(
        "--denorm-branch-outputs",
        action="store_true",
        help="Export one2one branch outputs after Detect3D branch denormalization, but before top-k postprocessing.",
    )
    export_group.add_argument(
        "--postprocessed-outputs",
        action="store_true",
        help="Export postprocessed selected outputs instead of raw head tensors.",
    )

    parser.add_argument(
        "--include-anchor-stride",
        action="store_true",
        help="Also export anchors and strides. Only used with --postprocessed-outputs.",
    )
    parser.add_argument("--skip-torchscript", action="store_true", help="Skip TorchScript export")
    parser.add_argument("--skip-onnx", action="store_true", help="Skip ONNX export")
    parser.add_argument("--no-fuse", action="store_true", help="Disable Conv-BN fusion before export")

    args = parser.parse_args(argv)

    # Validate sizes up front; parser.error prints usage and exits with status 2.
    width, height = args.imgsz
    if width <= 0 or height <= 0:
        parser.error(f"--imgsz values must be positive, got W={width} H={height}")
    if args.max_det <= 0:
        parser.error(f"--max-det must be positive, got {args.max_det}")
    return args
|
|
|
|
|
|
def _make_example_input(imgsz_wh: tuple[int, int]) -> torch.Tensor:
|
|
width, height = imgsz_wh
|
|
return torch.zeros(1, 3, height, width, dtype=torch.float32)
|
|
|
|
|
|
def _prepare_model_for_export(model: nn.Module, max_det: int, dynamic: bool, fuse: bool) -> nn.Module:
    """Return a frozen, float32, CPU deep copy of ``model`` configured for export.

    The input model is not modified. On the copy:

    * all parameters have ``requires_grad`` turned off;
    * the model is put in eval mode and cast to float;
    * ``model.fuse()`` is called when ``fuse`` is true and the model provides it;
    * every ``Detect`` / ``RTDETRDecoder`` submodule gets ``dynamic``, ``max_det`` (as int)
      and ``shape = None``, each only when the module already has that attribute.

    ``shape = None`` is presumably there to make the head rebuild cached anchors/strides for
    the export input size (TODO confirm against the head implementation).
    """
    export_copy = deepcopy(model).cpu()
    export_copy.requires_grad_(False)
    export_copy.eval()
    export_copy.float()
    if fuse and hasattr(export_copy, "fuse"):
        export_copy = export_copy.fuse()

    for submodule in export_copy.modules():
        if not isinstance(submodule, (Detect, RTDETRDecoder)):
            continue
        if hasattr(submodule, "dynamic"):
            submodule.dynamic = dynamic
        if hasattr(submodule, "max_det"):
            submodule.max_det = int(max_det)
        if hasattr(submodule, "shape"):
            submodule.shape = None
    return export_copy
|
|
|
|
|
|
def _export_mode(args: argparse.Namespace) -> str:
|
|
if args.postprocessed_outputs:
|
|
return "postprocessed_outputs"
|
|
if args.detections_only:
|
|
return "detections_only"
|
|
if args.hybrid_outputs:
|
|
return "hybrid_outputs"
|
|
if args.denorm_branch_outputs:
|
|
return "denorm_branch_outputs"
|
|
return "raw_head_outputs"
|
|
|
|
|
|
def _forward_model_to_head_inputs(model: nn.Module, x: torch.Tensor) -> tuple[nn.Module, Any]:
|
|
y = []
|
|
for module in model.model[:-1]:
|
|
if module.f != -1:
|
|
x = y[module.f] if isinstance(module.f, int) else [x if j == -1 else y[j] for j in module.f]
|
|
x = module(x)
|
|
y.append(x if module.i in model.save else None)
|
|
|
|
head = model.model[-1]
|
|
if head.f != -1:
|
|
x = y[head.f] if isinstance(head.f, int) else [x if j == -1 else y[j] for j in head.f]
|
|
return head, x
|
|
|
|
|
|
def _flatten_detect3d_outputs(outputs: Any, include_anchor_stride: bool = False) -> tuple[torch.Tensor, ...]:
|
|
if not isinstance(outputs, (tuple, list)) or len(outputs) < 2:
|
|
raise RuntimeError("Model forward output must be `(detections, raw_preds)` for Detect3D models.")
|
|
|
|
detections, raw_preds = outputs[0], outputs[1]
|
|
if not isinstance(raw_preds, dict):
|
|
raise RuntimeError("Raw prediction payload is missing.")
|
|
|
|
one2one = raw_preds.get("one2one", raw_preds)
|
|
if not isinstance(one2one, dict):
|
|
raise RuntimeError("one2one prediction payload is missing.")
|
|
|
|
required_keys = ["preds_3d_selected", "preds_edge_selected"]
|
|
if include_anchor_stride:
|
|
required_keys += ["anchors_selected", "strides_selected"]
|
|
missing = [key for key in required_keys if one2one.get(key) is None]
|
|
if missing:
|
|
available = sorted(one2one.keys())
|
|
raise RuntimeError(f"Missing Detect3D export tensors {missing}. Available keys: {available}.")
|
|
|
|
result = (detections, one2one["preds_3d_selected"], one2one["preds_edge_selected"])
|
|
if include_anchor_stride:
|
|
result = result + (one2one["anchors_selected"], one2one["strides_selected"])
|
|
return result
|
|
|
|
|
|
def _denorm_detect3d_branch_outputs(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Run one head branch via ``head.forward_head`` and return its four output tensors.

    For an end2end head, the ``one2one`` branch is used on detached copies of the head
    inputs. Otherwise the ``one2many`` branch is used on the inputs as-is. No top-k
    postprocessing is applied.

    Returns:
        ``(boxes, scores, preds_3d, preds_edge)`` taken from the branch output dict.

    Raises:
        RuntimeError: if the branch output is not a dict or any of the four keys is missing/``None``.
    """
    head, head_inputs = _forward_model_to_head_inputs(model, x)

    if getattr(head, "end2end", False):
        feed = [feature.detach() for feature in head_inputs]
        branch_modules = head.one2one
    else:
        feed = head_inputs
        branch_modules = head.one2many

    branch_out = head.forward_head(feed, **branch_modules)
    if not isinstance(branch_out, dict):
        raise RuntimeError("Raw branch payload is missing.")

    wanted = ("boxes", "scores", "preds_3d", "preds_edge")
    missing = [name for name in wanted if branch_out.get(name) is None]
    if missing:
        raise RuntimeError(f"Raw branch tensors {missing} are missing. Available keys: {sorted(branch_out.keys())}.")

    return tuple(branch_out[name] for name in wanted)
|
|
|
|
|
|
def _collect_raw_detect3d_head_outputs(head: nn.Module, head_inputs: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
|
|
if not getattr(head, "end2end", False):
|
|
raise RuntimeError("Raw-head export currently expects an end2end Detect3D head.")
|
|
|
|
branch = head.one2one
|
|
box_head = branch.get("box_head")
|
|
cls_head = branch.get("cls_head")
|
|
head_3d = branch.get("head_3d")
|
|
edge_head = branch.get("edge_head")
|
|
if any(module is None for module in (box_head, cls_head, head_3d, edge_head)):
|
|
raise RuntimeError("Raw-head export could not find complete one2one head modules.")
|
|
|
|
bs = head_inputs[0].shape[0]
|
|
raw_boxes = torch.cat([box_head[i](head_inputs[i]).view(bs, 4 * head.reg_max, -1) for i in range(head.nl)], dim=-1)
|
|
raw_scores = torch.cat([cls_head[i](head_inputs[i]).view(bs, head.nc, -1) for i in range(head.nl)], dim=-1)
|
|
raw_3d = torch.cat([head_3d[i](head_inputs[i]).view(bs, head.no_3d, -1) for i in range(head.nl)], dim=-1)
|
|
raw_edge = torch.cat([edge_head[i](head_inputs[i]).view(bs, head.no_edge, -1) for i in range(head.nl)], dim=-1)
|
|
return raw_boxes, raw_scores, raw_3d, raw_edge
|
|
|
|
|
|
def _raw_detect3d_head_outputs(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Forward ``x`` up to the head and return the head's raw one2one conv outputs.

    Returns ``(raw_boxes, raw_scores, raw_3d, raw_edge)``, as produced by
    ``_collect_raw_detect3d_head_outputs``.
    """
    return _collect_raw_detect3d_head_outputs(*_forward_model_to_head_inputs(model, x))
|
|
|
|
|
|
def _hybrid_detect2d_raw3d_outputs(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Return postprocessed 2D detections together with the raw one2one head tensors.

    The raw box/score tensors are decoded with ``head._inference``, transposed with
    ``permute(0, 2, 1)``, and passed to ``head.postprocess`` to produce ``detections``.

    Returns:
        ``(detections, raw_boxes, raw_scores, raw_3d, raw_edge)``.
    """
    head, features = _forward_model_to_head_inputs(model, x)
    raw_boxes, raw_scores, raw_3d, raw_edge = _collect_raw_detect3d_head_outputs(head, features)

    decoded = head._inference({"boxes": raw_boxes, "scores": raw_scores, "feats": features})
    detections = head.postprocess(decoded.permute(0, 2, 1))
    return detections, raw_boxes, raw_scores, raw_3d, raw_edge
|
|
|
|
|
|
def _detections_only_outputs(model: nn.Module, x: torch.Tensor) -> tuple[torch.Tensor]:
|
|
outputs = model(x)
|
|
if not isinstance(outputs, (tuple, list)) or len(outputs) < 1:
|
|
raise RuntimeError("Model forward output must contain detections for detections-only export.")
|
|
return (outputs[0],)
|
|
|
|
|
|
class SingleROIExportWrapper(nn.Module):
    """Export-time wrapper that selects which tensors the traced model returns.

    ``forward`` checks the mode flags in this priority order:

    1. ``postprocessed_outputs``: full model forward, flattened by ``_flatten_detect3d_outputs``
       (``include_anchor_stride`` controls the extra anchor/stride tensors);
    2. ``detections_only``: only the first model output;
    3. ``hybrid_outputs``: postprocessed 2D detections plus raw head tensors;
    4. ``denorm_branch_outputs``: one head branch's ``forward_head`` outputs;
    5. otherwise, raw one2one head conv outputs.
    """

    def __init__(
        self,
        model: nn.Module,
        include_anchor_stride: bool = False,
        detections_only: bool = False,
        hybrid_outputs: bool = False,
        denorm_branch_outputs: bool = False,
        postprocessed_outputs: bool = False,
    ):
        super().__init__()
        # Registered as a submodule so .eval()/state handling propagate to the wrapped model.
        self.model = model
        # Plain (non-tensor) mode flags, read on every forward call.
        self.include_anchor_stride = include_anchor_stride
        self.detections_only = detections_only
        self.hybrid_outputs = hybrid_outputs
        self.denorm_branch_outputs = denorm_branch_outputs
        self.postprocessed_outputs = postprocessed_outputs

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Return the output tuple for the selected export mode (see class docstring)."""
        if self.postprocessed_outputs:
            return _flatten_detect3d_outputs(self.model(x), include_anchor_stride=self.include_anchor_stride)

        # Remaining modes all share the (model, x) signature; the first enabled one wins.
        mode_handlers = (
            (self.detections_only, _detections_only_outputs),
            (self.hybrid_outputs, _hybrid_detect2d_raw3d_outputs),
            (self.denorm_branch_outputs, _denorm_detect3d_branch_outputs),
        )
        for enabled, handler in mode_handlers:
            if enabled:
                return handler(self.model, x)
        return _raw_detect3d_head_outputs(self.model, x)
|
|
|
|
|
|
def _output_names(
|
|
include_anchor_stride: bool = False,
|
|
detections_only: bool = False,
|
|
hybrid_outputs: bool = False,
|
|
denorm_branch_outputs: bool = False,
|
|
postprocessed_outputs: bool = False,
|
|
) -> list[str]:
|
|
if postprocessed_outputs:
|
|
names = ["detections", "preds_3d", "preds_edge"]
|
|
if include_anchor_stride:
|
|
names += ["anchors", "strides"]
|
|
return names
|
|
if detections_only:
|
|
return ["detections"]
|
|
if hybrid_outputs:
|
|
return ["detections", "boxes_head_raw", "scores_head_raw", "preds_3d_head_raw", "preds_edge_head_raw"]
|
|
if denorm_branch_outputs:
|
|
return ["boxes_raw", "scores_raw", "preds_3d_raw", "preds_edge_raw"]
|
|
return ["boxes_head_raw", "scores_head_raw", "preds_3d_head_raw", "preds_edge_head_raw"]
|
|
|
|
|
|
def _describe_outputs(outputs: tuple[torch.Tensor, ...]) -> list[list[int]]:
|
|
return [list(output.shape) for output in outputs]
|
|
|
|
|
|
def _dynamic_axes(input_names: list[str], output_names: list[str], enabled: bool) -> dict[str, dict[int, str]] | None:
|
|
if not enabled:
|
|
return None
|
|
dynamic_axes = {name: {0: "batch"} for name in input_names}
|
|
dynamic_axes.update({name: {0: "batch"} for name in output_names})
|
|
return dynamic_axes
|
|
|
|
|
|
def _write_manifest(path: Path, payload: dict[str, Any]) -> None:
|
|
path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
|
|
|
|
|
def _base_manifest(args: argparse.Namespace, output_names: list[str], output_shapes: list[list[int]]) -> dict[str, Any]:
    """Assemble the export manifest shared by the JSON sidecar, TorchScript and ONNX metadata.

    Values are normalized to plain ``str``/``int``/``bool``/list types so they serialize
    cleanly. Key insertion order is kept stable because the ONNX metadata is written in
    this order.
    """
    width, height = int(args.imgsz[0]), int(args.imgsz[1])
    mode_flags = {
        "include_anchor_stride": bool(args.include_anchor_stride),
        "detections_only": bool(args.detections_only),
        "hybrid_outputs": bool(args.hybrid_outputs),
        "denorm_branch_outputs": bool(args.denorm_branch_outputs),
        "postprocessed_outputs": bool(args.postprocessed_outputs),
    }
    return {
        "artifact_name": "single_roi_model",
        "model_path": str(Path(args.model_path).resolve()),
        "input_names": ["images"],
        "input_sizes_wh": {"images": [width, height]},
        "output_names": output_names,
        "output_shapes": output_shapes,
        "dynamic_batch": bool(args.dynamic),
        "max_det": int(args.max_det),
        "opset": int(args.opset),
        "fuse": not bool(args.no_fuse),
        **mode_flags,
        "export_mode": _export_mode(args),
    }
|
|
|
|
|
|
def _export_torchscript(model: nn.Module, example_input: torch.Tensor, save_path: Path, manifest: dict[str, Any]) -> None:
    """Trace ``model`` on ``example_input`` and save it as TorchScript at ``save_path``.

    Tracing uses ``torch.jit.trace(..., strict=False)``. The manifest is embedded as
    sorted-key JSON in the archive's ``config.txt`` extra file. Progress and the resulting
    file size (MB) are logged.
    """
    LOGGER.info(f"Exporting TorchScript to {save_path}")
    traced_module = torch.jit.trace(model, example_input, strict=False)
    extra_files = {"config.txt": json.dumps(manifest, sort_keys=True)}
    traced_module.save(str(save_path), _extra_files=extra_files)
    size_mb = save_path.stat().st_size / (1024 * 1024)
    LOGGER.info(f"TorchScript export success: {save_path} ({size_mb:.2f} MB)")
|
|
|
|
|
|
def _save_onnx_metadata(onnx_path: Path, manifest: dict[str, Any], simplify: bool) -> None:
    """Optionally simplify an exported ONNX file, embed the manifest as metadata, and save in place.

    Steps:

    1. Load the model from ``onnx_path``.
    2. If ``simplify`` is true, run ``onnxslim.slim`` as a best-effort step. When onnxslim
       is missing or raises, a warning is logged and the unsimplified graph is kept.
    3. Add one ``metadata_props`` entry per manifest key. Dicts and lists are stored as
       ASCII JSON; everything else is stored via ``str``.
    4. Write the model back to ``onnx_path``.

    Simplification runs before the metadata is attached. This means a simplifier that
    rebuilds the graph cannot drop the metadata, and a missing or failing simplifier never
    prevents the metadata from being written.

    Raises:
        ImportError: if ``onnx`` itself is not installed.
    """
    import onnx

    model_onnx = onnx.load(str(onnx_path))

    if simplify:
        try:
            import onnxslim
        except ImportError as error:
            LOGGER.warning(f"Skipping ONNX simplification because onnxslim is missing: {error}")
        else:
            LOGGER.info("Simplifying ONNX graph with onnxslim")
            try:
                model_onnx = onnxslim.slim(model_onnx)
            except Exception as error:  # best-effort: keep the unsimplified graph
                LOGGER.warning(f"onnxslim simplification failed, keeping unsimplified graph: {error}")

    for key, value in manifest.items():
        meta = model_onnx.metadata_props.add()
        meta.key = key
        meta.value = json.dumps(value, ensure_ascii=True) if isinstance(value, (dict, list)) else str(value)

    onnx.save(model_onnx, str(onnx_path))
|
|
|
|
|
|
def _export_onnx(
    model: nn.Module,
    example_input: torch.Tensor,
    save_path: Path,
    output_names: list[str],
    manifest: dict[str, Any],
    opset: int,
    dynamic: bool,
    simplify: bool,
) -> None:
    """Export ``model`` to ONNX at ``save_path``, then attach metadata and optionally simplify.

    The export runs inside ultralytics' ``onnx_export_patch`` and ``arange_patch`` context
    managers. The input is named ``"images"``, and the batch dim is made dynamic when
    ``dynamic`` is true. Post-processing (metadata/simplify) is best-effort: a missing
    package or any other failure there is logged as a warning, and the plain exported file
    is kept.
    """
    LOGGER.info(f"Exporting ONNX to {save_path} with opset={opset}")
    patch_args = SimpleNamespace(dynamic=dynamic, half=False, format="onnx")
    axes = _dynamic_axes(["images"], output_names, dynamic)

    with onnx_export_patch(), arange_patch(patch_args):
        torch.onnx.export(
            model,
            example_input,
            str(save_path),
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=["images"],
            output_names=output_names,
            dynamic_axes=axes,
        )

    try:
        _save_onnx_metadata(save_path, manifest, simplify=simplify)
    except ImportError as error:
        LOGGER.warning(f"Skipping ONNX metadata/simplify step because a package is missing: {error}")
    except Exception as error:
        LOGGER.warning(f"ONNX post-processing failed: {error}")

    size_mb = save_path.stat().st_size / (1024 * 1024)
    LOGGER.info(f"ONNX export success: {save_path} ({size_mb:.2f} MB)")
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: load the checkpoint, probe its outputs, and export it.

    Steps: parse args, load the checkpoint on CPU, prepare an export copy, wrap it for the
    selected output mode, and run one forward pass on a zero input to record output shapes.
    The manifest JSON is then written, followed by the TorchScript and/or ONNX exports
    (unless skipped), all into ``--save-dir``.
    """
    args = parse_args()
    if args.include_anchor_stride and not args.postprocessed_outputs:
        LOGGER.warning("--include-anchor-stride is ignored unless --postprocessed-outputs is also set.")
    LOGGER.info(f"Export mode: {_export_mode(args)}")

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # The same mode flags drive both the wrapper's forward and the output naming.
    mode_flags = dict(
        include_anchor_stride=args.include_anchor_stride,
        detections_only=args.detections_only,
        hybrid_outputs=args.hybrid_outputs,
        denorm_branch_outputs=args.denorm_branch_outputs,
        postprocessed_outputs=args.postprocessed_outputs,
    )

    model, _ = load_checkpoint(args.model_path, device="cpu", fuse=False)
    prepared_model = _prepare_model_for_export(model, max_det=args.max_det, dynamic=args.dynamic, fuse=not args.no_fuse)
    export_model = SingleROIExportWrapper(prepared_model, **mode_flags).eval()

    width, height = int(args.imgsz[0]), int(args.imgsz[1])
    example_input = _make_example_input((width, height))
    with torch.inference_mode():
        outputs = export_model(example_input)

    output_names = _output_names(**mode_flags)
    output_shapes = _describe_outputs(outputs)
    LOGGER.info(f"Single ROI model output shapes: {output_shapes}")

    manifest = _base_manifest(args=args, output_names=output_names, output_shapes=output_shapes)
    _write_manifest(save_dir / "single_roi_model.export.json", manifest)

    if not args.skip_torchscript:
        _export_torchscript(export_model, example_input, save_dir / "single_roi_model.torchscript", manifest)
    if not args.skip_onnx:
        _export_onnx(
            export_model,
            example_input,
            save_dir / "single_roi_model.onnx",
            output_names=output_names,
            manifest=manifest,
            opset=args.opset,
            dynamic=args.dynamic,
            simplify=args.simplify,
        )


if __name__ == "__main__":
    main()
|