741 lines
27 KiB
Python
Executable File
741 lines
27 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# Absolute path of this script and its third ancestor directory (`parents[2]`).
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]

# Append ROOT to the import search path once; presumably this lets the
# `ultralytics` imports that follow resolve when the script is run directly.
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))
|
|
|
|
from ultralytics.nn.modules import Detect, RTDETRDecoder
|
|
from ultralytics.nn.tasks import load_checkpoint
|
|
from ultralytics.utils import LOGGER
|
|
from ultralytics.utils.patches import arange_patch, onnx_export_patch
|
|
|
|
|
|
# Default checkpoint paths for the two ROI models, located under ROOT/runs/detect.
DEFAULT_ROI0_MODEL = ROOT.joinpath("runs", "detect", "mono3d_roi0_20260506_epoch99.pt")
DEFAULT_ROI1_MODEL = ROOT.joinpath("runs", "detect", "mono3d_roi1_20260506_epoch99.pt")

# Default directory that receives the exported artifacts.
DEFAULT_SAVE_DIR = ROOT.joinpath("runs", "export", "train_mono3d_two_roi_20260506")

# Default example input size, ordered (width, height).
DEFAULT_IMGSZ_WH = (768, 352)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line options of the export script and parse them.

    Returns:
        argparse.Namespace: The parsed options. ``--hybrid-outputs``,
            ``--denorm-branch-outputs`` and ``--postprocessed-outputs`` are
            registered in one mutually exclusive group.
    """
    cli = argparse.ArgumentParser(description="Merge two yolo26 single-ROI Detect3D checkpoints and export them.")
    size_metavar = ("W", "H")

    def _add_flags(target, flags):
        """Register a sequence of ``(option, help)`` pairs as ``store_true`` switches on ``target``."""
        for option, text in flags:
            target.add_argument(option, action="store_true", help=text)

    # Checkpoints, output directory and example input sizes.
    cli.add_argument("--roi0-model-path", type=str, default=str(DEFAULT_ROI0_MODEL), help="Path to ROI0 checkpoint")
    cli.add_argument("--roi1-model-path", type=str, default=str(DEFAULT_ROI1_MODEL), help="Path to ROI1 checkpoint")
    cli.add_argument(
        "--save-dir", type=str, default=str(DEFAULT_SAVE_DIR), help="Directory used to store exported files"
    )
    cli.add_argument(
        "--imgsz",
        nargs=2,
        type=int,
        default=list(DEFAULT_IMGSZ_WH),
        metavar=size_metavar,
        help="Common example input size as width height, e.g. --imgsz 768 352",
    )
    cli.add_argument(
        "--roi0-imgsz", nargs=2, type=int, default=None, metavar=size_metavar, help="Optional ROI0 input size override"
    )
    cli.add_argument(
        "--roi1-imgsz", nargs=2, type=int, default=None, metavar=size_metavar, help="Optional ROI1 input size override"
    )

    # General export settings.
    cli.add_argument("--opset", type=int, default=17, help="ONNX opset version")
    cli.add_argument("--max-det", type=int, default=300, help="Top-k detections kept by each Detect3D head")
    _add_flags(
        cli,
        (
            ("--dynamic", "Export with dynamic batch dimension"),
            ("--simplify", "Try to simplify the exported ONNX graph with onnxslim"),
        ),
    )

    # Output layout: at most one of these may be given; none selects the raw head outputs.
    _add_flags(
        cli.add_mutually_exclusive_group(),
        (
            ("--hybrid-outputs", "Export postprocessed 2D detections together with raw 2D/3D/edge head outputs."),
            (
                "--denorm-branch-outputs",
                "Export one2one branch outputs after Detect3D branch denormalization, but before top-k postprocessing.",
            ),
            ("--postprocessed-outputs", "Export postprocessed selected outputs instead of raw head tensors."),
        ),
    )

    _add_flags(
        cli,
        (
            (
                "--include-anchor-stride",
                "Also export anchors and strides for each ROI. Only used with --postprocessed-outputs.",
            ),
            ("--export-separate", "Export ROI0 and ROI1 as separate single-input artifacts"),
            ("--skip-torchscript", "Skip TorchScript export"),
            ("--skip-onnx", "Skip ONNX export"),
            ("--no-fuse", "Disable Conv-BN fusion before export"),
        ),
    )

    # Optional branches that can be kept in or dropped from the exported outputs.
    branch_options = (
        ("--edge-head-mode", "Control whether edge_head branch outputs are kept in exported artifacts."),
        ("--fake-3d-branch-mode", "Control whether fake 3D branch outputs are kept in exported artifacts."),
    )
    for option, text in branch_options:
        cli.add_argument(option, type=str, choices=("keep", "drop"), default="keep", help=text)

    return cli.parse_args()
|
|
|
|
|
|
def _resolve_imgsz_wh(common_imgsz_wh: tuple[int, int], override: list[int] | tuple[int, int] | None) -> tuple[int, int]:
|
|
if override is None:
|
|
return int(common_imgsz_wh[0]), int(common_imgsz_wh[1])
|
|
return int(override[0]), int(override[1])
|
|
|
|
|
|
def _make_example_input(imgsz_wh: tuple[int, int]) -> torch.Tensor:
|
|
width, height = imgsz_wh
|
|
return torch.zeros(1, 3, height, width, dtype=torch.float32)
|
|
|
|
|
|
def _prepare_model_for_export(model: nn.Module, max_det: int, dynamic: bool, fuse: bool) -> nn.Module:
    """Return an export-ready copy of ``model``; the input model itself is not modified.

    The copy is moved to CPU, has gradients disabled on all parameters, is put in
    eval mode and cast to float. When ``fuse`` is set and the model exposes
    ``fuse()``, the result of that call replaces the copy. Every ``Detect`` /
    ``RTDETRDecoder`` submodule then gets its ``dynamic`` and ``max_det``
    attributes set and its ``shape`` attribute reset to ``None`` (each only if
    the attribute already exists).

    Args:
        model: Source model.
        max_det: Value written to ``max_det`` of the detection heads.
        dynamic: Value written to ``dynamic`` of the detection heads.
        fuse: Whether to call ``model.fuse()`` when available.

    Returns:
        nn.Module: The prepared copy.
    """
    export_model = deepcopy(model).cpu()
    for param in export_model.parameters():
        param.requires_grad_(False)
    export_model.eval()
    export_model.float()

    if fuse and hasattr(export_model, "fuse"):
        export_model = export_model.fuse()

    for layer in export_model.modules():
        if not isinstance(layer, (Detect, RTDETRDecoder)):
            continue
        if hasattr(layer, "dynamic"):
            layer.dynamic = dynamic
        if hasattr(layer, "max_det"):
            layer.max_det = int(max_det)
        if hasattr(layer, "shape"):
            layer.shape = None

    return export_model
|
|
|
|
|
|
def _export_mode(args: argparse.Namespace) -> str:
|
|
if args.postprocessed_outputs:
|
|
return "postprocessed_outputs"
|
|
if args.hybrid_outputs:
|
|
return "hybrid_outputs"
|
|
if args.denorm_branch_outputs:
|
|
return "denorm_branch_outputs"
|
|
return "raw_head_outputs"
|
|
|
|
|
|
def _forward_model_to_head_inputs(model: nn.Module, x: torch.Tensor) -> tuple[nn.Module, Any]:
|
|
y = []
|
|
for module in model.model[:-1]:
|
|
if module.f != -1:
|
|
x = y[module.f] if isinstance(module.f, int) else [x if j == -1 else y[j] for j in module.f]
|
|
x = module(x)
|
|
y.append(x if module.i in model.save else None)
|
|
|
|
head = model.model[-1]
|
|
if head.f != -1:
|
|
x = y[head.f] if isinstance(head.f, int) else [x if j == -1 else y[j] for j in head.f]
|
|
return head, x
|
|
|
|
|
|
def _flatten_detect3d_outputs(
|
|
outputs: Any,
|
|
roi_name: str,
|
|
include_anchor_stride: bool = False,
|
|
keep_edge_head: bool = True,
|
|
keep_fake_3d_branch: bool = True,
|
|
) -> tuple[torch.Tensor, ...]:
|
|
if not isinstance(outputs, (tuple, list)) or len(outputs) < 2:
|
|
raise RuntimeError(f"{roi_name} forward output must be `(detections, raw_preds)` for Detect3D models.")
|
|
|
|
detections, raw_preds = outputs[0], outputs[1]
|
|
if not isinstance(raw_preds, dict):
|
|
raise RuntimeError(f"{roi_name} raw prediction payload is missing.")
|
|
|
|
one2one = raw_preds.get("one2one", raw_preds)
|
|
if not isinstance(one2one, dict):
|
|
raise RuntimeError(f"{roi_name} one2one prediction payload is missing.")
|
|
|
|
required_keys = ["preds_3d_selected"]
|
|
if keep_edge_head:
|
|
required_keys.append("preds_edge_selected")
|
|
if include_anchor_stride:
|
|
required_keys += ["anchors_selected", "strides_selected"]
|
|
missing = [key for key in required_keys if one2one.get(key) is None]
|
|
if missing:
|
|
available = sorted(one2one.keys())
|
|
raise RuntimeError(
|
|
f"{roi_name} is missing Detect3D export tensors {missing}. Available keys: {available}. "
|
|
"This script expects yolo26 end2end Detect3D checkpoints."
|
|
)
|
|
|
|
result = (
|
|
detections,
|
|
one2one["preds_3d_selected"],
|
|
one2one.get("preds_diff_selected"),
|
|
)
|
|
if keep_fake_3d_branch:
|
|
result = result + (one2one.get("preds_3d_fake_selected"),)
|
|
if keep_edge_head:
|
|
result = result + (one2one["preds_edge_selected"],)
|
|
if include_anchor_stride:
|
|
result = result + (one2one["anchors_selected"], one2one["strides_selected"])
|
|
return result
|
|
|
|
|
|
def _denorm_detect3d_branch_outputs(
    model: nn.Module,
    x: torch.Tensor,
    roi_name: str,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> tuple[torch.Tensor, ...]:
    """Run the model up to its head and return the outputs of ``head.forward_head``.

    For a head with a truthy ``end2end`` attribute the ``one2one`` branch is used
    and the head inputs are detached first; otherwise ``one2many`` is used with
    the inputs as-is. Every tensor that ends up in the returned tuple is validated
    first, so a missing entry raises a clear error instead of yielding ``None``.

    Args:
        model: Prepared detection model.
        x: Input tensor.
        roi_name: Label used in error messages.
        keep_edge_head: Append ``preds_edge``.
        keep_fake_3d_branch: Append ``preds_3d_fake``.

    Returns:
        tuple: ``(boxes, scores, preds_3d, preds_diff[, preds_3d_fake][, preds_edge])``.

    Raises:
        RuntimeError: If the head lacks ``forward_head``, its result is not a dict,
            or a required tensor is missing.
    """
    head, head_inputs = _forward_model_to_head_inputs(model, x)
    if not hasattr(head, "forward_head"):
        raise RuntimeError(f"{roi_name} final layer does not expose forward_head(), cannot export raw branch outputs.")

    is_end2end = getattr(head, "end2end", False)
    branch_inputs = [tensor.detach() for tensor in head_inputs] if is_end2end else head_inputs
    branch = head.one2one if is_end2end else head.one2many
    raw_preds = head.forward_head(branch_inputs, **branch)
    if not isinstance(raw_preds, dict):
        raise RuntimeError(f"{roi_name} raw branch payload is missing.")

    # Keys are listed in output order; each one is emitted below, so each one must be present.
    required_keys = ["boxes", "scores", "preds_3d", "preds_diff"]
    if keep_fake_3d_branch:
        required_keys.append("preds_3d_fake")
    if keep_edge_head:
        required_keys.append("preds_edge")

    missing = [key for key in required_keys if raw_preds.get(key) is None]
    if missing:
        available = sorted(raw_preds.keys())
        raise RuntimeError(f"{roi_name} raw branch tensors {missing} are missing. Available keys: {available}.")

    return tuple(raw_preds[key] for key in required_keys)
|
|
|
|
|
|
def _collect_raw_detect3d_head_outputs(
|
|
head: nn.Module,
|
|
head_inputs: list[torch.Tensor],
|
|
roi_name: str,
|
|
keep_edge_head: bool = True,
|
|
keep_fake_3d_branch: bool = True,
|
|
) -> tuple[torch.Tensor, ...]:
|
|
if not getattr(head, "end2end", False):
|
|
raise RuntimeError(f"{roi_name} raw-head export currently expects an end2end Detect3D head.")
|
|
|
|
branch = head.one2one
|
|
box_head = branch.get("box_head")
|
|
cls_head = branch.get("cls_head")
|
|
diff_head = branch.get("diff_head")
|
|
head_3d = branch.get("head_3d")
|
|
fake_head_3d = branch.get("fake_head_3d")
|
|
edge_head = branch.get("edge_head")
|
|
required_modules = (
|
|
box_head,
|
|
cls_head,
|
|
diff_head,
|
|
head_3d,
|
|
edge_head,
|
|
) if keep_edge_head else (
|
|
box_head,
|
|
cls_head,
|
|
diff_head,
|
|
head_3d,
|
|
)
|
|
if keep_fake_3d_branch:
|
|
required_modules = (*required_modules, fake_head_3d)
|
|
if any(module is None for module in required_modules):
|
|
raise RuntimeError(f"{roi_name} raw-head export could not find complete one2one head modules.")
|
|
|
|
bs = head_inputs[0].shape[0]
|
|
raw_boxes = torch.cat([box_head[i](head_inputs[i]).view(bs, 4 * head.reg_max, -1) for i in range(head.nl)], dim=-1)
|
|
raw_scores = torch.cat([cls_head[i](head_inputs[i]).view(bs, head.nc, -1) for i in range(head.nl)], dim=-1)
|
|
raw_diff = torch.cat([diff_head[i](head_inputs[i]).view(bs, 1, -1) for i in range(head.nl)], dim=-1)
|
|
raw_3d = torch.cat([head_3d[i](head_inputs[i]).view(bs, head.no_3d, -1) for i in range(head.nl)], dim=-1)
|
|
result = (raw_boxes, raw_scores, raw_3d, raw_diff)
|
|
if keep_fake_3d_branch:
|
|
raw_3d_fake = torch.cat([fake_head_3d[i](head_inputs[i]).view(bs, head.no_3d, -1) for i in range(head.nl)], dim=-1)
|
|
result = result + (raw_3d_fake,)
|
|
if keep_edge_head:
|
|
raw_edge = torch.cat([edge_head[i](head_inputs[i]).view(bs, head.no_edge, -1) for i in range(head.nl)], dim=-1)
|
|
result = result + (raw_edge,)
|
|
return result
|
|
|
|
|
|
def _raw_detect3d_head_outputs(
    model: nn.Module,
    x: torch.Tensor,
    roi_name: str,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> tuple[torch.Tensor, ...]:
    """Run ``model`` up to its head and return the raw one2one sub-head outputs.

    Args:
        model: Prepared detection model.
        x: Input tensor.
        roi_name: Label used in error messages.
        keep_edge_head: Include the edge-head output.
        keep_fake_3d_branch: Include the fake 3D branch output.

    Returns:
        tuple: The result of ``_collect_raw_detect3d_head_outputs`` for the model's head.
    """
    detect_head, detect_inputs = _forward_model_to_head_inputs(model, x)
    branch_flags = {"keep_edge_head": keep_edge_head, "keep_fake_3d_branch": keep_fake_3d_branch}
    return _collect_raw_detect3d_head_outputs(detect_head, detect_inputs, roi_name, **branch_flags)
|
|
|
|
|
|
def _hybrid_detect2d_raw3d_outputs(
    model: nn.Module,
    x: torch.Tensor,
    roi_name: str,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> tuple[torch.Tensor, ...]:
    """Return postprocessed detections followed by the raw one2one sub-head outputs.

    The raw boxes and scores are fed, together with the head inputs as ``feats``,
    through ``head._inference`` and ``head.postprocess`` to obtain the detections.

    Args:
        model: Prepared detection model.
        x: Input tensor.
        roi_name: Label used in error messages.
        keep_edge_head: Include the edge-head output.
        keep_fake_3d_branch: Include the fake 3D branch output.

    Returns:
        tuple: ``(detections, *raw_outputs)``.
    """
    detect_head, detect_inputs = _forward_model_to_head_inputs(model, x)
    raw_outputs = _collect_raw_detect3d_head_outputs(
        detect_head,
        detect_inputs,
        roi_name,
        keep_edge_head=keep_edge_head,
        keep_fake_3d_branch=keep_fake_3d_branch,
    )

    # The first two raw outputs are the box and score tensors.
    head_preds = {"boxes": raw_outputs[0], "scores": raw_outputs[1], "feats": detect_inputs}
    decoded = detect_head._inference(head_preds)
    detections = detect_head.postprocess(decoded.permute(0, 2, 1))
    return (detections,) + tuple(raw_outputs)
|
|
|
|
|
|
class ROISelectedOutputsWrapper(nn.Module):
    """Wrap one ROI model so that its forward pass returns a flat tuple of export tensors.

    Which tensors are returned depends on the mode flags, checked in this order:
    ``postprocessed_outputs``, ``hybrid_outputs``, ``denorm_branch_outputs``;
    when none is set the raw head outputs are returned.
    """

    def __init__(
        self,
        model: nn.Module,
        roi_name: str,
        include_anchor_stride: bool = False,
        hybrid_outputs: bool = False,
        denorm_branch_outputs: bool = False,
        postprocessed_outputs: bool = False,
        keep_edge_head: bool = True,
        keep_fake_3d_branch: bool = True,
    ):
        """Store the wrapped model and the output-selection options.

        Args:
            model: The prepared ROI model.
            roi_name: Label used in error messages.
            include_anchor_stride: Forwarded to the postprocessed-output path only.
            hybrid_outputs: Select the hybrid output mode.
            denorm_branch_outputs: Select the denormalized branch output mode.
            postprocessed_outputs: Select the postprocessed output mode.
            keep_edge_head: Keep edge-head outputs.
            keep_fake_3d_branch: Keep fake 3D branch outputs.
        """
        super().__init__()
        self.model = model
        self.roi_name = roi_name
        # Output mode switches.
        self.include_anchor_stride = include_anchor_stride
        self.hybrid_outputs = hybrid_outputs
        self.denorm_branch_outputs = denorm_branch_outputs
        self.postprocessed_outputs = postprocessed_outputs
        # Optional branch switches.
        self.keep_edge_head = keep_edge_head
        self.keep_fake_3d_branch = keep_fake_3d_branch

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Return the output tuple for ``x`` according to the configured mode."""
        branch_flags = {
            "keep_edge_head": self.keep_edge_head,
            "keep_fake_3d_branch": self.keep_fake_3d_branch,
        }

        if self.postprocessed_outputs:
            # Only this mode runs the model's own forward pass end to end.
            predictions = self.model(x)
            return _flatten_detect3d_outputs(
                predictions,
                self.roi_name,
                include_anchor_stride=self.include_anchor_stride,
                **branch_flags,
            )

        if self.hybrid_outputs:
            collect = _hybrid_detect2d_raw3d_outputs
        elif self.denorm_branch_outputs:
            collect = _denorm_detect3d_branch_outputs
        else:
            collect = _raw_detect3d_head_outputs
        return collect(self.model, x, self.roi_name, **branch_flags)
|
|
|
|
|
|
class TwoROIMergedExportModel(nn.Module):
    """Two-input module that runs a ROI0 and a ROI1 model and concatenates their output tuples."""

    def __init__(
        self,
        roi0_model: nn.Module,
        roi1_model: nn.Module,
        include_anchor_stride: bool = False,
        hybrid_outputs: bool = False,
        denorm_branch_outputs: bool = False,
        postprocessed_outputs: bool = False,
        keep_edge_head: bool = True,
        keep_fake_3d_branch: bool = True,
    ):
        """Wrap both ROI models in ``ROISelectedOutputsWrapper`` with identical options.

        ``stride`` and ``names`` are copied from ``roi0_model`` (``None`` when absent).

        Args:
            roi0_model: Prepared model for the first input.
            roi1_model: Prepared model for the second input.
            include_anchor_stride: Forwarded to both wrappers.
            hybrid_outputs: Forwarded to both wrappers.
            denorm_branch_outputs: Forwarded to both wrappers.
            postprocessed_outputs: Forwarded to both wrappers.
            keep_edge_head: Forwarded to both wrappers.
            keep_fake_3d_branch: Forwarded to both wrappers.
        """
        super().__init__()
        shared_options = {
            "include_anchor_stride": include_anchor_stride,
            "hybrid_outputs": hybrid_outputs,
            "denorm_branch_outputs": denorm_branch_outputs,
            "postprocessed_outputs": postprocessed_outputs,
            "keep_edge_head": keep_edge_head,
            "keep_fake_3d_branch": keep_fake_3d_branch,
        }
        self.roi0_model = ROISelectedOutputsWrapper(roi0_model, "roi0", **shared_options)
        self.roi1_model = ROISelectedOutputsWrapper(roi1_model, "roi1", **shared_options)

        self.stride = getattr(roi0_model, "stride", None)
        self.names = getattr(roi0_model, "names", None)

    def forward(self, x_roi0: torch.Tensor, x_roi1: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Return the ROI0 outputs followed by the ROI1 outputs as one flat tuple."""
        roi0_outputs = self.roi0_model(x_roi0)
        roi1_outputs = self.roi1_model(x_roi1)
        return (*roi0_outputs, *roi1_outputs)
|
|
|
|
|
|
def _output_names(
|
|
prefix: str,
|
|
include_anchor_stride: bool = False,
|
|
hybrid_outputs: bool = False,
|
|
denorm_branch_outputs: bool = False,
|
|
postprocessed_outputs: bool = False,
|
|
keep_edge_head: bool = True,
|
|
keep_fake_3d_branch: bool = True,
|
|
) -> list[str]:
|
|
if postprocessed_outputs:
|
|
names = [
|
|
f"{prefix}_detections",
|
|
f"{prefix}_preds_3d",
|
|
f"{prefix}_preds_diff",
|
|
]
|
|
if keep_fake_3d_branch:
|
|
names.append(f"{prefix}_preds_3d_fake")
|
|
if keep_edge_head:
|
|
names.append(f"{prefix}_preds_edge")
|
|
if include_anchor_stride:
|
|
names += [f"{prefix}_anchors", f"{prefix}_strides"]
|
|
return names
|
|
|
|
if hybrid_outputs:
|
|
names = [
|
|
f"{prefix}_detections",
|
|
f"{prefix}_boxes_head_raw",
|
|
f"{prefix}_scores_head_raw",
|
|
f"{prefix}_preds_3d_head_raw",
|
|
f"{prefix}_preds_diff_head_raw",
|
|
]
|
|
if keep_fake_3d_branch:
|
|
names.append(f"{prefix}_preds_3d_fake_head_raw")
|
|
if keep_edge_head:
|
|
names.append(f"{prefix}_preds_edge_head_raw")
|
|
return names
|
|
|
|
if denorm_branch_outputs:
|
|
names = [
|
|
f"{prefix}_boxes_raw",
|
|
f"{prefix}_scores_raw",
|
|
f"{prefix}_preds_3d_raw",
|
|
f"{prefix}_preds_diff_raw",
|
|
]
|
|
if keep_fake_3d_branch:
|
|
names.append(f"{prefix}_preds_3d_fake_raw")
|
|
if keep_edge_head:
|
|
names.append(f"{prefix}_preds_edge_raw")
|
|
return names
|
|
|
|
names = [
|
|
f"{prefix}_boxes_head_raw",
|
|
f"{prefix}_scores_head_raw",
|
|
f"{prefix}_preds_3d_head_raw",
|
|
f"{prefix}_preds_diff_head_raw",
|
|
]
|
|
if keep_fake_3d_branch:
|
|
names.append(f"{prefix}_preds_3d_fake_head_raw")
|
|
if keep_edge_head:
|
|
names.append(f"{prefix}_preds_edge_head_raw")
|
|
return names
|
|
|
|
|
|
def _describe_outputs(outputs: tuple[torch.Tensor, ...]) -> list[list[int]]:
|
|
return [list(output.shape) for output in outputs]
|
|
|
|
|
|
def _dynamic_axes(input_names: list[str], output_names: list[str], enabled: bool) -> dict[str, dict[int, str]] | None:
|
|
if not enabled:
|
|
return None
|
|
dynamic_axes = {name: {0: "batch"} for name in input_names}
|
|
dynamic_axes.update({name: {0: "batch"} for name in output_names})
|
|
return dynamic_axes
|
|
|
|
|
|
def _write_manifest(path: Path, payload: dict[str, Any]) -> None:
|
|
path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
|
|
|
|
|
def _base_manifest(
    artifact_name: str,
    args: argparse.Namespace,
    input_names: list[str],
    input_sizes_wh: dict[str, list[int]],
    output_names: list[str],
    output_shapes: list[list[int]],
) -> dict[str, Any]:
    """Assemble the manifest dictionary that describes one exported artifact.

    Args:
        artifact_name: Name of the artifact being described.
        args: Parsed command-line options; checkpoint paths are stored resolved.
        input_names: Names of the model inputs.
        input_sizes_wh: Mapping from input name to ``[width, height]``.
        output_names: Names of the model outputs.
        output_shapes: Shapes of the model outputs.

    Returns:
        dict[str, Any]: The manifest with I/O description, export settings and mode flags.
    """
    edge_mode = args.edge_head_mode
    fake_3d_mode = args.fake_3d_branch_mode

    # What was exported and from which checkpoints.
    manifest: dict[str, Any] = {
        "artifact_name": artifact_name,
        "roi0_model_path": str(Path(args.roi0_model_path).resolve()),
        "roi1_model_path": str(Path(args.roi1_model_path).resolve()),
        "input_names": input_names,
        "input_sizes_wh": input_sizes_wh,
        "output_names": output_names,
        "output_shapes": output_shapes,
    }
    # General export settings.
    manifest.update(
        {
            "dynamic_batch": bool(args.dynamic),
            "max_det": int(args.max_det),
            "opset": int(args.opset),
            "fuse": not bool(args.no_fuse),
        }
    )
    # Output-selection flags and the mode name derived from them.
    manifest.update(
        {
            "include_anchor_stride": bool(args.include_anchor_stride),
            "hybrid_outputs": bool(args.hybrid_outputs),
            "denorm_branch_outputs": bool(args.denorm_branch_outputs),
            "postprocessed_outputs": bool(args.postprocessed_outputs),
            "edge_head_mode": edge_mode,
            "keep_edge_head": edge_mode == "keep",
            "fake_3d_branch_mode": fake_3d_mode,
            "keep_fake_3d_branch": fake_3d_mode == "keep",
            "export_mode": _export_mode(args),
        }
    )
    return manifest
|
|
|
|
|
|
def _export_torchscript(
    model: nn.Module,
    example_inputs: torch.Tensor | tuple[torch.Tensor, ...],
    save_path: Path,
    manifest: dict[str, Any],
) -> None:
    """Trace ``model`` with ``torch.jit.trace`` and save it together with the manifest.

    The manifest is embedded as JSON in the extra file ``config.txt``.

    Args:
        model: Module to trace.
        example_inputs: Example input(s) used for tracing.
        save_path: Destination file.
        manifest: Export description stored alongside the traced module.
    """
    LOGGER.info(f"Exporting TorchScript to {save_path}")

    scripted = torch.jit.trace(model, example_inputs, strict=False)
    embedded = {"config.txt": json.dumps(manifest, sort_keys=True)}
    scripted.save(str(save_path), _extra_files=embedded)

    size_mb = save_path.stat().st_size / (1024 * 1024)
    LOGGER.info(f"TorchScript export success: {save_path} ({size_mb:.2f} MB)")
|
|
|
|
|
|
def _save_onnx_metadata(onnx_path: Path, manifest: dict[str, Any], simplify: bool) -> None:
    """Optionally simplify an exported ONNX file and attach the manifest as metadata.

    The file is loaded, optionally simplified with ``onnxslim``, annotated with
    one ``metadata_props`` entry per manifest item, and saved back in place.
    Simplification is best-effort: if ``onnxslim`` is missing or fails, a warning
    is logged and the unsimplified graph is kept, so the metadata is still written.

    Args:
        onnx_path: ONNX file to update in place.
        manifest: Key/value pairs to store; dict and list values are JSON-encoded,
            everything else is converted with ``str``.
        simplify: Whether to try simplifying the graph.

    Raises:
        ImportError: If the ``onnx`` package is not installed.
    """
    import onnx

    model_onnx = onnx.load(str(onnx_path))

    if simplify:
        try:
            import onnxslim
        except ImportError as error:
            LOGGER.warning(f"Skipping ONNX simplification because onnxslim is not available: {error}")
        else:
            LOGGER.info("Simplifying ONNX graph with onnxslim")
            try:
                model_onnx = onnxslim.slim(model_onnx)
            except Exception as error:  # best-effort step: keep the loaded graph and continue
                LOGGER.warning(f"ONNX simplification failed, keeping the unsimplified graph: {error}")

    # Metadata is added after simplification so it is attached to the graph that gets saved.
    for key, value in manifest.items():
        meta = model_onnx.metadata_props.add()
        meta.key = key
        meta.value = json.dumps(value, ensure_ascii=True) if isinstance(value, (dict, list)) else str(value)

    onnx.save(model_onnx, str(onnx_path))
|
|
|
|
|
|
def _export_onnx(
    model: nn.Module,
    example_inputs: torch.Tensor | tuple[torch.Tensor, ...],
    save_path: Path,
    input_names: list[str],
    output_names: list[str],
    manifest: dict[str, Any],
    opset: int,
    dynamic: bool,
    simplify: bool,
) -> None:
    """Export ``model`` to ONNX and then try to add metadata / simplify the result.

    The export runs inside the ``onnx_export_patch`` and ``arange_patch`` context
    managers. Failures of the metadata/simplify step are logged as warnings and
    do not abort the export.

    Args:
        model: Module to export.
        example_inputs: Example input(s) for the exporter.
        save_path: Destination ``.onnx`` file.
        input_names: Names assigned to the graph inputs.
        output_names: Names assigned to the graph outputs.
        manifest: Export description written as ONNX metadata.
        opset: ONNX opset version.
        dynamic: Whether axis 0 of all inputs and outputs is exported as dynamic.
        simplify: Whether to try simplifying the graph afterwards.
    """
    LOGGER.info(f"Exporting ONNX to {save_path} with opset={opset}")

    patch_args = SimpleNamespace(dynamic=dynamic, half=False, format="onnx")
    axes = _dynamic_axes(input_names, output_names, dynamic)
    with onnx_export_patch(), arange_patch(patch_args):
        torch.onnx.export(
            model,
            example_inputs,
            str(save_path),
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=axes,
        )

    # Post-processing is optional: the exported file is kept even if this step fails.
    try:
        _save_onnx_metadata(save_path, manifest, simplify=simplify)
    except ImportError as error:
        LOGGER.warning(f"Skipping ONNX metadata/simplify step because a package is missing: {error}")
    except Exception as error:
        LOGGER.warning(f"ONNX post-processing failed: {error}")

    size_mb = save_path.stat().st_size / (1024 * 1024)
    LOGGER.info(f"ONNX export success: {save_path} ({size_mb:.2f} MB)")
|
|
|
|
|
|
def _load_and_prepare_wrapper(weights_path: str, roi_name: str, args: argparse.Namespace) -> ROISelectedOutputsWrapper:
    """Load one checkpoint, prepare it for export and wrap it for output selection.

    Args:
        weights_path: Path of the checkpoint to load.
        roi_name: Label for this ROI, used in logs and error messages.
        args: Parsed command-line options.

    Returns:
        ROISelectedOutputsWrapper: Wrapper around the prepared model, configured from ``args``.
    """
    LOGGER.info(f"Loading {roi_name} checkpoint from {weights_path}")

    # Fusion is deferred to _prepare_model_for_export so that --no-fuse is honoured.
    checkpoint_model, _ = load_checkpoint(weights_path, device="cpu", fuse=False)
    export_ready = _prepare_model_for_export(
        checkpoint_model,
        max_det=args.max_det,
        dynamic=args.dynamic,
        fuse=not args.no_fuse,
    )

    wrapper_options = {
        "include_anchor_stride": args.include_anchor_stride,
        "hybrid_outputs": args.hybrid_outputs,
        "denorm_branch_outputs": args.denorm_branch_outputs,
        "postprocessed_outputs": args.postprocessed_outputs,
        "keep_edge_head": args.edge_head_mode == "keep",
        "keep_fake_3d_branch": args.fake_3d_branch_mode == "keep",
    }
    return ROISelectedOutputsWrapper(export_ready, roi_name, **wrapper_options)
|
|
|
|
|
|
def _export_single_roi_artifacts(
    wrapper: ROISelectedOutputsWrapper,
    example_input: torch.Tensor,
    save_dir: Path,
    base_name: str,
    args: argparse.Namespace,
    input_size_wh: tuple[int, int],
) -> None:
    """Export one ROI wrapper as standalone artifacts named after ``base_name``.

    Always writes ``<base_name>.export.json``; writes ``<base_name>.torchscript``
    and ``<base_name>.onnx`` unless the corresponding skip flag is set.

    Args:
        wrapper: Wrapped ROI model to export.
        example_input: Example input used for the dry run and for tracing.
        save_dir: Directory that receives the files.
        base_name: File stem and prefix of the input/output names.
        args: Parsed command-line options.
        input_size_wh: Input size as ``(width, height)``, recorded in the manifest.
    """
    input_name = f"{base_name}_input"

    # Dry run to record the actual output shapes in the manifest.
    sample_outputs = wrapper(example_input)
    output_names = _output_names(
        base_name,
        include_anchor_stride=args.include_anchor_stride,
        hybrid_outputs=args.hybrid_outputs,
        denorm_branch_outputs=args.denorm_branch_outputs,
        postprocessed_outputs=args.postprocessed_outputs,
        keep_edge_head=args.edge_head_mode == "keep",
        keep_fake_3d_branch=args.fake_3d_branch_mode == "keep",
    )
    manifest = _base_manifest(
        artifact_name=base_name,
        args=args,
        input_names=[input_name],
        input_sizes_wh={input_name: list(input_size_wh)},
        output_names=output_names,
        output_shapes=_describe_outputs(sample_outputs),
    )
    _write_manifest(save_dir / f"{base_name}.export.json", manifest)

    if not args.skip_torchscript:
        _export_torchscript(wrapper, example_input, save_dir / f"{base_name}.torchscript", manifest)

    if args.skip_onnx:
        return
    _export_onnx(
        wrapper,
        example_input,
        save_dir / f"{base_name}.onnx",
        input_names=[input_name],
        output_names=output_names,
        manifest=manifest,
        opset=args.opset,
        dynamic=args.dynamic,
        simplify=args.simplify,
    )
|
|
|
|
|
|
def main() -> None:
    """Entry point: load both ROI checkpoints and export them separately or as one merged model.

    With ``--export-separate`` each ROI is exported as its own artifact set.
    Otherwise both models are combined into a two-input ``TwoROIMergedExportModel``
    and exported as ``merged_model.*`` together with a JSON manifest.
    """
    args = parse_args()
    if args.include_anchor_stride and not args.postprocessed_outputs:
        LOGGER.warning("--include-anchor-stride is ignored unless --postprocessed-outputs is also set.")
    LOGGER.info(f"Export mode: {_export_mode(args)}")

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Per-ROI sizes fall back to the common --imgsz when no override is given.
    common_imgsz_wh = (int(args.imgsz[0]), int(args.imgsz[1]))
    roi0_imgsz_wh = _resolve_imgsz_wh(common_imgsz_wh, args.roi0_imgsz)
    roi1_imgsz_wh = _resolve_imgsz_wh(common_imgsz_wh, args.roi1_imgsz)

    roi0_wrapper = _load_and_prepare_wrapper(args.roi0_model_path, "roi0", args)
    roi1_wrapper = _load_and_prepare_wrapper(args.roi1_model_path, "roi1", args)

    roi0_example_input = _make_example_input(roi0_imgsz_wh)
    roi1_example_input = _make_example_input(roi1_imgsz_wh)

    if args.export_separate:
        LOGGER.info("Exporting ROI0 and ROI1 as separate artifacts")
        separate_jobs = (
            (roi0_wrapper, roi0_example_input, "roi0_model", roi0_imgsz_wh),
            (roi1_wrapper, roi1_example_input, "roi1_model", roi1_imgsz_wh),
        )
        with torch.inference_mode():
            for wrapper, example_input, base_name, imgsz_wh in separate_jobs:
                _export_single_roi_artifacts(wrapper, example_input, save_dir, base_name, args, imgsz_wh)
        return

    # Options shared by the merged model and by the output-name generation.
    mode_options = {
        "include_anchor_stride": args.include_anchor_stride,
        "hybrid_outputs": args.hybrid_outputs,
        "denorm_branch_outputs": args.denorm_branch_outputs,
        "postprocessed_outputs": args.postprocessed_outputs,
        "keep_edge_head": args.edge_head_mode == "keep",
        "keep_fake_3d_branch": args.fake_3d_branch_mode == "keep",
    }

    # The merged model re-wraps the prepared inner models of the two wrappers.
    merged_model = TwoROIMergedExportModel(roi0_wrapper.model, roi1_wrapper.model, **mode_options).eval()
    example_inputs = (roi0_example_input, roi1_example_input)
    input_names = ["roi0_input", "roi1_input"]

    # Dry run to record the actual output shapes.
    with torch.inference_mode():
        outputs = merged_model(*example_inputs)

    output_names = _output_names("roi0", **mode_options) + _output_names("roi1", **mode_options)
    LOGGER.info(f"Merged model output shapes: {_describe_outputs(outputs)}")

    manifest = _base_manifest(
        artifact_name="merged_model",
        args=args,
        input_names=list(input_names),
        input_sizes_wh={
            "roi0_input": list(roi0_imgsz_wh),
            "roi1_input": list(roi1_imgsz_wh),
        },
        output_names=output_names,
        output_shapes=_describe_outputs(outputs),
    )
    _write_manifest(save_dir / "merged_model.export.json", manifest)

    if not args.skip_torchscript:
        _export_torchscript(merged_model, example_inputs, save_dir / "merged_model.torchscript", manifest)

    if not args.skip_onnx:
        _export_onnx(
            merged_model,
            example_inputs,
            save_dir / "merged_model.onnx",
            input_names=list(input_names),
            output_names=output_names,
            manifest=manifest,
            opset=args.opset,
            dynamic=args.dynamic,
            simplify=args.simplify,
        )


if __name__ == "__main__":
    main()
|