Files
yolov26_3d/tools/model_merging/merge_models_of_2roi_yolo26.py
2026-06-24 09:35:46 +08:00

741 lines
27 KiB
Python
Executable File

from __future__ import annotations
import argparse
import json
import sys
from copy import deepcopy
from pathlib import Path
from types import SimpleNamespace
from typing import Any
import torch
import torch.nn as nn
# Make the repository root importable so the `ultralytics` imports that follow resolve to this checkout.
# This script sits at <root>/tools/model_merging/, so parents[2] is the repository root.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]
if str(ROOT) not in sys.path:
    # Appended rather than prepended: an `ultralytics` package already importable from an earlier
    # sys.path entry would still take precedence. NOTE(review): confirm that is intended for this fork.
    sys.path.append(str(ROOT))
from ultralytics.nn.modules import Detect, RTDETRDecoder
from ultralytics.nn.tasks import load_checkpoint
from ultralytics.utils import LOGGER
from ultralytics.utils.patches import arange_patch, onnx_export_patch
# Default checkpoint/output locations, all relative to the repository root.
DEFAULT_ROI0_MODEL = ROOT.joinpath("runs", "detect", "mono3d_roi0_20260506_epoch99.pt")
DEFAULT_ROI1_MODEL = ROOT.joinpath("runs", "detect", "mono3d_roi1_20260506_epoch99.pt")
DEFAULT_SAVE_DIR = ROOT.joinpath("runs", "export", "train_mono3d_two_roi_20260506")
# Default example input size as (width, height).
DEFAULT_IMGSZ_WH = (768, 352)
def parse_args() -> argparse.Namespace:
    """Build the CLI for the two-ROI merge/export tool and parse ``sys.argv``.

    Options are registered in a fixed order, so ``--help`` lists them consistently. The three
    output-layout switches (hybrid / denorm-branch / postprocessed) are mutually exclusive. When none
    of them is given, raw head tensors are exported.
    """
    parser = argparse.ArgumentParser(description="Merge two yolo26 single-ROI Detect3D checkpoints and export them.")
    add = parser.add_argument

    # Checkpoint inputs and output directory.
    for roi, default_model in (("roi0", DEFAULT_ROI0_MODEL), ("roi1", DEFAULT_ROI1_MODEL)):
        add(f"--{roi}-model-path", type=str, default=str(default_model), help=f"Path to {roi.upper()} checkpoint")
    add("--save-dir", type=str, default=str(DEFAULT_SAVE_DIR), help="Directory used to store exported files")

    # Example input sizes as (width, height); per-ROI overrides default to None.
    add(
        "--imgsz",
        nargs=2,
        type=int,
        default=list(DEFAULT_IMGSZ_WH),
        metavar=("W", "H"),
        help="Common example input size as width height, e.g. --imgsz 768 352",
    )
    for roi in ("roi0", "roi1"):
        add(
            f"--{roi}-imgsz",
            nargs=2,
            type=int,
            default=None,
            metavar=("W", "H"),
            help=f"Optional {roi.upper()} input size override",
        )

    # Exporter knobs.
    add("--opset", type=int, default=17, help="ONNX opset version")
    add("--max-det", type=int, default=300, help="Top-k detections kept by each Detect3D head")
    add("--dynamic", action="store_true", help="Export with dynamic batch dimension")
    add("--simplify", action="store_true", help="Try to simplify the exported ONNX graph with onnxslim")

    # Output layout: at most one of these may be passed.
    layout_group = parser.add_mutually_exclusive_group()
    for flag, description in (
        ("--hybrid-outputs", "Export postprocessed 2D detections together with raw 2D/3D/edge head outputs."),
        (
            "--denorm-branch-outputs",
            "Export one2one branch outputs after Detect3D branch denormalization, but before top-k postprocessing.",
        ),
        ("--postprocessed-outputs", "Export postprocessed selected outputs instead of raw head tensors."),
    ):
        layout_group.add_argument(flag, action="store_true", help=description)

    add(
        "--include-anchor-stride",
        action="store_true",
        help="Also export anchors and strides for each ROI. Only used with --postprocessed-outputs.",
    )
    for flag, description in (
        ("--export-separate", "Export ROI0 and ROI1 as separate single-input artifacts"),
        ("--skip-torchscript", "Skip TorchScript export"),
        ("--skip-onnx", "Skip ONNX export"),
        ("--no-fuse", "Disable Conv-BN fusion before export"),
    ):
        add(flag, action="store_true", help=description)

    # Optional branches that can be pruned from the exported graph.
    for flag, branch_label in (("--edge-head-mode", "edge_head branch"), ("--fake-3d-branch-mode", "fake 3D branch")):
        add(
            flag,
            type=str,
            choices=("keep", "drop"),
            default="keep",
            help=f"Control whether {branch_label} outputs are kept in exported artifacts.",
        )
    return parser.parse_args()
def _resolve_imgsz_wh(common_imgsz_wh: tuple[int, int], override: list[int] | tuple[int, int] | None) -> tuple[int, int]:
    """Return ``(width, height)`` as ints, preferring ``override`` over ``common_imgsz_wh`` when given."""
    chosen = common_imgsz_wh if override is None else override
    return int(chosen[0]), int(chosen[1])
def _make_example_input(imgsz_wh: tuple[int, int]) -> torch.Tensor:
    """Return an all-zero float32 tensor of shape ``(1, 3, height, width)`` used for tracing/export."""
    w, h = imgsz_wh
    return torch.zeros((1, 3, h, w), dtype=torch.float32)
def _prepare_model_for_export(model: nn.Module, max_det: int, dynamic: bool, fuse: bool) -> nn.Module:
    """Return an export-ready deep copy of ``model``; the caller's instance is left untouched.

    The copy is moved to CPU, frozen (``requires_grad=False``), put in eval mode, cast with ``.float()``
    and, when ``fuse`` is set and the model has a ``fuse()`` method, replaced by ``fuse()``'s result.
    Every ``Detect``/``RTDETRDecoder`` module then gets ``dynamic``, ``max_det`` (as int) and
    ``shape=None``, each only if that module already defines the attribute.
    NOTE(review): resetting ``shape`` presumably forces anchors/strides to be rebuilt for the
    export input size — confirm against the head implementation.
    """
    export_model = deepcopy(model).cpu()
    for weight in export_model.parameters():
        weight.requires_grad_(False)
    export_model.eval()
    export_model.float()
    if fuse and hasattr(export_model, "fuse"):
        export_model = export_model.fuse()

    head_overrides = {"dynamic": dynamic, "max_det": int(max_det), "shape": None}
    for submodule in export_model.modules():
        if not isinstance(submodule, (Detect, RTDETRDecoder)):
            continue
        for attr_name, attr_value in head_overrides.items():
            if hasattr(submodule, attr_name):
                setattr(submodule, attr_name, attr_value)
    return export_model
def _export_mode(args: argparse.Namespace) -> str:
    """Return the active output-layout mode label.

    Precedence is postprocessed > hybrid > denorm-branch. With none of them set, the label is
    ``"raw_head_outputs"``. Each mode label is identical to the ``args`` attribute that enables it.
    """
    for mode in ("postprocessed_outputs", "hybrid_outputs", "denorm_branch_outputs"):
        if getattr(args, mode):
            return mode
    return "raw_head_outputs"
def _forward_model_to_head_inputs(model: nn.Module, x: torch.Tensor) -> tuple[nn.Module, Any]:
    """Run every layer of ``model.model`` except the last and return ``(head, head_input)``.

    Each layer's ``f`` selects its input:
    - ``-1`` is the previous layer's output.
    - An int indexes an earlier layer's saved output.
    - A list builds a list from both.
    A layer's output is kept only if its index ``i`` is in ``model.save``.
    The final layer (the head) is returned unevaluated, together with the input it would receive.
    """

    def gather(source: Any, current: Any, cache: list[Any]) -> Any:
        # Resolve a layer's `f` spec against the running output and the saved outputs.
        if source == -1:
            return current
        if isinstance(source, int):
            return cache[source]
        return [current if idx == -1 else cache[idx] for idx in source]

    cache: list[Any] = []
    features = x
    for layer in model.model[:-1]:
        features = layer(gather(layer.f, features, cache))
        cache.append(features if layer.i in model.save else None)
    head = model.model[-1]
    return head, gather(head.f, features, cache)
def _flatten_detect3d_outputs(
    outputs: Any,
    roi_name: str,
    include_anchor_stride: bool = False,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> tuple[torch.Tensor, ...]:
    """Flatten a Detect3D ``(detections, raw_preds)`` forward result into a flat tuple of export tensors.

    ``raw_preds`` must be a dict. Its ``"one2one"`` entry is used when present, otherwise the dict itself.
    The returned order is: detections, preds_3d, preds_diff, then preds_3d_fake (if
    ``keep_fake_3d_branch``), preds_edge (if ``keep_edge_head``), and anchors/strides (if
    ``include_anchor_stride``).

    Raises:
        RuntimeError: if ``outputs`` is not a 2+ item tuple/list, a payload is not a dict, or any tensor
            that will be exported is missing or ``None``.
    """
    if not isinstance(outputs, (tuple, list)) or len(outputs) < 2:
        raise RuntimeError(f"{roi_name} forward output must be `(detections, raw_preds)` for Detect3D models.")
    detections, raw_preds = outputs[0], outputs[1]
    if not isinstance(raw_preds, dict):
        raise RuntimeError(f"{roi_name} raw prediction payload is missing.")
    one2one = raw_preds.get("one2one", raw_preds)
    if not isinstance(one2one, dict):
        raise RuntimeError(f"{roi_name} one2one prediction payload is missing.")

    # Every key listed here becomes a graph output, so all of them are validated. A `None` slipping through
    # would only fail later, opaquely, when output shapes are collected, and would misalign output names.
    export_keys = ["preds_3d_selected", "preds_diff_selected"]
    if keep_fake_3d_branch:
        export_keys.append("preds_3d_fake_selected")
    if keep_edge_head:
        export_keys.append("preds_edge_selected")
    if include_anchor_stride:
        export_keys += ["anchors_selected", "strides_selected"]

    missing = [key for key in export_keys if one2one.get(key) is None]
    if missing:
        available = sorted(one2one.keys())
        raise RuntimeError(
            f"{roi_name} is missing Detect3D export tensors {missing}. Available keys: {available}. "
            "This script expects yolo26 end2end Detect3D checkpoints."
        )
    return (detections, *(one2one[key] for key in export_keys))
def _denorm_detect3d_branch_outputs(
    model: nn.Module,
    x: torch.Tensor,
    roi_name: str,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> tuple[torch.Tensor, ...]:
    """Run the model up to its head, call the head's ``forward_head`` on one branch, and return its tensors.

    End2end heads use the ``one2one`` branch on detached inputs; other heads use ``one2many``. The returned
    order is: boxes, scores, preds_3d, preds_diff, then preds_3d_fake (if ``keep_fake_3d_branch``) and
    preds_edge (if ``keep_edge_head``).

    Raises:
        RuntimeError: if the head has no ``forward_head``, the branch payload is not a dict, or any tensor
            that will be exported is missing or ``None``.
    """
    head, head_inputs = _forward_model_to_head_inputs(model, x)
    if not hasattr(head, "forward_head"):
        raise RuntimeError(f"{roi_name} final layer does not expose forward_head(), cannot export raw branch outputs.")
    end2end = getattr(head, "end2end", False)
    branch_inputs = [tensor.detach() for tensor in head_inputs] if end2end else head_inputs
    branch = head.one2one if end2end else head.one2many
    raw_preds = head.forward_head(branch_inputs, **branch)
    if not isinstance(raw_preds, dict):
        raise RuntimeError(f"{roi_name} raw branch payload is missing.")

    # Every key listed here becomes a graph output, so all of them are validated. A `None` slipping through
    # would only fail later, opaquely, when output shapes are collected, and would misalign output names.
    export_keys = ["boxes", "scores", "preds_3d", "preds_diff"]
    if keep_fake_3d_branch:
        export_keys.append("preds_3d_fake")
    if keep_edge_head:
        export_keys.append("preds_edge")

    missing = [key for key in export_keys if raw_preds.get(key) is None]
    if missing:
        available = sorted(raw_preds.keys())
        raise RuntimeError(f"{roi_name} raw branch tensors {missing} are missing. Available keys: {available}.")
    return tuple(raw_preds[key] for key in export_keys)
def _collect_raw_detect3d_head_outputs(
    head: nn.Module,
    head_inputs: list[torch.Tensor],
    roi_name: str,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> tuple[torch.Tensor, ...]:
    """Apply the one2one per-level conv branches of an end2end head and return their raw maps.

    Each branch runs on every level ``0..head.nl-1`` and is viewed as ``(batch, channels, -1)``. The levels
    are then concatenated along the last axis. Channel counts come from the head: ``4 * reg_max`` for boxes,
    ``nc`` for scores, 1 for diff, ``no_3d`` for the 3D and fake-3D branches, and ``no_edge`` for edge.
    The returned order is: boxes, scores, 3D, diff, then fake 3D and edge when kept.

    Raises:
        RuntimeError: if the head is not end2end, or a needed branch is absent from ``head.one2one``.
    """
    if not getattr(head, "end2end", False):
        raise RuntimeError(f"{roi_name} raw-head export currently expects an end2end Detect3D head.")
    branch = head.one2one

    needed = ["box_head", "cls_head", "diff_head", "head_3d"]
    if keep_edge_head:
        needed.append("edge_head")
    if keep_fake_3d_branch:
        needed.append("fake_head_3d")
    if any(branch.get(name) is None for name in needed):
        raise RuntimeError(f"{roi_name} raw-head export could not find complete one2one head modules.")

    batch = head_inputs[0].shape[0]

    def flatten_branch(name: str, channels: int) -> torch.Tensor:
        # Run one branch on every pyramid level and join the flattened maps.
        convs = branch.get(name)
        per_level = [convs[level](head_inputs[level]).view(batch, channels, -1) for level in range(head.nl)]
        return torch.cat(per_level, dim=-1)

    # Branches are evaluated in a fixed order (boxes, scores, diff, 3D, fake, edge) so the traced graph is stable.
    boxes = flatten_branch("box_head", 4 * head.reg_max)
    scores = flatten_branch("cls_head", head.nc)
    diff = flatten_branch("diff_head", 1)
    maps_3d = flatten_branch("head_3d", head.no_3d)
    collected = [boxes, scores, maps_3d, diff]
    if keep_fake_3d_branch:
        collected.append(flatten_branch("fake_head_3d", head.no_3d))
    if keep_edge_head:
        collected.append(flatten_branch("edge_head", head.no_edge))
    return tuple(collected)
def _raw_detect3d_head_outputs(
    model: nn.Module,
    x: torch.Tensor,
    roi_name: str,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> tuple[torch.Tensor, ...]:
    """Run ``model`` up to its head and return the head's raw one2one branch maps (no decoding)."""
    head, features = _forward_model_to_head_inputs(model, x)
    branch_flags = {"keep_edge_head": keep_edge_head, "keep_fake_3d_branch": keep_fake_3d_branch}
    return _collect_raw_detect3d_head_outputs(head, features, roi_name, **branch_flags)
def _hybrid_detect2d_raw3d_outputs(
    model: nn.Module,
    x: torch.Tensor,
    roi_name: str,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> tuple[torch.Tensor, ...]:
    """Return postprocessed 2D detections followed by every raw one2one head map.

    The raw box and score maps are passed through the head's ``_inference`` and ``postprocess`` methods
    to produce ``detections``. All raw maps are also returned, unchanged, after it.
    """
    head, features = _forward_model_to_head_inputs(model, x)
    raw_maps = _collect_raw_detect3d_head_outputs(
        head,
        features,
        roi_name,
        keep_edge_head=keep_edge_head,
        keep_fake_3d_branch=keep_fake_3d_branch,
    )
    # The first two raw maps are always boxes and scores.
    decode_inputs = {"boxes": raw_maps[0], "scores": raw_maps[1], "feats": features}
    decoded = head._inference(decode_inputs)
    # NOTE(review): the permute assumes `_inference` returns (batch, channels, anchors) — confirm against the head.
    detections = head.postprocess(decoded.permute(0, 2, 1))
    return (detections, *raw_maps)
class ROISelectedOutputsWrapper(nn.Module):
    """Wrap one single-ROI Detect3D model so its forward returns a flat tuple of export tensors.

    The mode flags are checked in this order: ``postprocessed_outputs``, ``hybrid_outputs``,
    ``denorm_branch_outputs``. If none is set, raw head maps are returned. ``include_anchor_stride``
    only affects the postprocessed mode.
    """

    def __init__(
        self,
        model: nn.Module,
        roi_name: str,
        include_anchor_stride: bool = False,
        hybrid_outputs: bool = False,
        denorm_branch_outputs: bool = False,
        postprocessed_outputs: bool = False,
        keep_edge_head: bool = True,
        keep_fake_3d_branch: bool = True,
    ):
        super().__init__()
        self.model = model  # registered as the single child module
        self.roi_name = roi_name  # prefix used in error messages
        # Output-layout switches.
        self.include_anchor_stride = include_anchor_stride
        self.hybrid_outputs = hybrid_outputs
        self.denorm_branch_outputs = denorm_branch_outputs
        self.postprocessed_outputs = postprocessed_outputs
        # Optional branches to keep in the outputs.
        self.keep_edge_head = keep_edge_head
        self.keep_fake_3d_branch = keep_fake_3d_branch

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        branch_flags = {"keep_edge_head": self.keep_edge_head, "keep_fake_3d_branch": self.keep_fake_3d_branch}
        if self.postprocessed_outputs:
            # Uses the model's own forward and selects tensors from its (detections, raw_preds) result.
            return _flatten_detect3d_outputs(
                self.model(x),
                self.roi_name,
                include_anchor_stride=self.include_anchor_stride,
                **branch_flags,
            )
        if self.hybrid_outputs:
            collect = _hybrid_detect2d_raw3d_outputs
        elif self.denorm_branch_outputs:
            collect = _denorm_detect3d_branch_outputs
        else:
            collect = _raw_detect3d_head_outputs
        return collect(self.model, x, self.roi_name, **branch_flags)
class TwoROIMergedExportModel(nn.Module):
    """Expose two independent ROI models behind one two-input forward.

    Both models are wrapped in ``ROISelectedOutputsWrapper`` with identical layout flags. ``forward``
    returns ROI0's tensors followed by ROI1's as one flat tuple. ``stride`` and ``names`` are copied from
    the ROI0 model, or set to ``None`` if it lacks them.
    """

    def __init__(
        self,
        roi0_model: nn.Module,
        roi1_model: nn.Module,
        include_anchor_stride: bool = False,
        hybrid_outputs: bool = False,
        denorm_branch_outputs: bool = False,
        postprocessed_outputs: bool = False,
        keep_edge_head: bool = True,
        keep_fake_3d_branch: bool = True,
    ):
        super().__init__()
        shared_layout = {
            "include_anchor_stride": include_anchor_stride,
            "hybrid_outputs": hybrid_outputs,
            "denorm_branch_outputs": denorm_branch_outputs,
            "postprocessed_outputs": postprocessed_outputs,
            "keep_edge_head": keep_edge_head,
            "keep_fake_3d_branch": keep_fake_3d_branch,
        }
        self.roi0_model = ROISelectedOutputsWrapper(roi0_model, "roi0", **shared_layout)
        self.roi1_model = ROISelectedOutputsWrapper(roi1_model, "roi1", **shared_layout)
        self.stride = getattr(roi0_model, "stride", None)
        self.names = getattr(roi0_model, "names", None)

    def forward(self, x_roi0: torch.Tensor, x_roi1: torch.Tensor) -> tuple[torch.Tensor, ...]:
        roi0_outputs = self.roi0_model(x_roi0)
        roi1_outputs = self.roi1_model(x_roi1)
        return tuple(roi0_outputs) + tuple(roi1_outputs)
def _output_names(
    prefix: str,
    include_anchor_stride: bool = False,
    hybrid_outputs: bool = False,
    denorm_branch_outputs: bool = False,
    postprocessed_outputs: bool = False,
    keep_edge_head: bool = True,
    keep_fake_3d_branch: bool = True,
) -> list[str]:
    """Return one ROI's graph output names, in the order the matching forward path emits them.

    Mode precedence is postprocessed > hybrid > denorm-branch > raw head. The modes differ in three
    ways: whether ``detections`` leads, which suffix the remaining stems get, and whether the stems start
    with boxes/scores. Anchor/stride names are appended only in postprocessed mode.
    """
    full_stems = ["boxes", "scores", "preds_3d", "preds_diff"]
    if postprocessed_outputs:
        leading, suffix, stems = ["detections"], "", ["preds_3d", "preds_diff"]
    elif hybrid_outputs:
        leading, suffix, stems = ["detections"], "_head_raw", full_stems
    elif denorm_branch_outputs:
        leading, suffix, stems = [], "_raw", full_stems
    else:
        leading, suffix, stems = [], "_head_raw", full_stems

    if keep_fake_3d_branch:
        stems.append("preds_3d_fake")
    if keep_edge_head:
        stems.append("preds_edge")

    names = [f"{prefix}_{stem}" for stem in leading]
    names.extend(f"{prefix}_{stem}{suffix}" for stem in stems)
    if postprocessed_outputs and include_anchor_stride:
        names.extend((f"{prefix}_anchors", f"{prefix}_strides"))
    return names
def _describe_outputs(outputs: tuple[torch.Tensor, ...]) -> list[list[int]]:
    """Return each output tensor's shape as a plain list (JSON-serializable)."""
    return [[*tensor.shape] for tensor in outputs]
def _dynamic_axes(input_names: list[str], output_names: list[str], enabled: bool) -> dict[str, dict[int, str]] | None:
    """Return a ``dynamic_axes`` map marking dim 0 as ``"batch"`` for every input/output, or ``None`` if disabled."""
    if not enabled:
        return None
    # A fresh {0: "batch"} mapping per graph input and output name.
    return {name: {0: "batch"} for name in (*input_names, *output_names)}
def _write_manifest(path: Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` to ``path`` as UTF-8 JSON (2-space indent, sorted keys), overwriting any existing file."""
    # Serialize first so a non-serializable payload never truncates an existing file.
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(serialized)
def _base_manifest(
    artifact_name: str,
    args: argparse.Namespace,
    input_names: list[str],
    input_sizes_wh: dict[str, list[int]],
    output_names: list[str],
    output_shapes: list[list[int]],
) -> dict[str, Any]:
    """Build the metadata dict describing one exported artifact.

    It records the resolved checkpoint paths, the I/O names and shapes, and the export options taken
    from ``args``. Key insertion order is fixed, which matters for consumers that iterate the dict in order.
    """
    return dict(
        artifact_name=artifact_name,
        roi0_model_path=str(Path(args.roi0_model_path).resolve()),
        roi1_model_path=str(Path(args.roi1_model_path).resolve()),
        input_names=input_names,
        input_sizes_wh=input_sizes_wh,
        output_names=output_names,
        output_shapes=output_shapes,
        dynamic_batch=bool(args.dynamic),
        max_det=int(args.max_det),
        opset=int(args.opset),
        fuse=not bool(args.no_fuse),
        include_anchor_stride=bool(args.include_anchor_stride),
        hybrid_outputs=bool(args.hybrid_outputs),
        denorm_branch_outputs=bool(args.denorm_branch_outputs),
        postprocessed_outputs=bool(args.postprocessed_outputs),
        edge_head_mode=args.edge_head_mode,
        keep_edge_head=args.edge_head_mode == "keep",
        fake_3d_branch_mode=args.fake_3d_branch_mode,
        keep_fake_3d_branch=args.fake_3d_branch_mode == "keep",
        export_mode=_export_mode(args),
    )
def _export_torchscript(
    model: nn.Module,
    example_inputs: torch.Tensor | tuple[torch.Tensor, ...],
    save_path: Path,
    manifest: dict[str, Any],
) -> None:
    """Trace ``model`` (``strict=False``) and save it to ``save_path`` with ``manifest`` embedded as ``config.txt``."""
    LOGGER.info(f"Exporting TorchScript to {save_path}")
    scripted = torch.jit.trace(model, example_inputs, strict=False)
    scripted.save(str(save_path), _extra_files={"config.txt": json.dumps(manifest, sort_keys=True)})
    size_mb = save_path.stat().st_size / (1024 * 1024)
    LOGGER.info(f"TorchScript export success: {save_path} ({size_mb:.2f} MB)")
def _save_onnx_metadata(onnx_path: Path, manifest: dict[str, Any], simplify: bool) -> None:
    """Optionally simplify an exported ONNX file, then embed ``manifest`` as metadata and save it in place.

    Simplification runs before the metadata is attached, so the metadata never depends on whether
    onnxslim preserves ``metadata_props``. Simplification is best-effort: if onnxslim is missing or fails,
    a warning is logged, the unsimplified graph is kept, and the metadata is still written.
    Dict/list values are stored as JSON; all other values via ``str()``.

    Raises:
        ImportError: if the ``onnx`` package itself is not installed.
    """
    import onnx

    model_onnx = onnx.load(str(onnx_path))
    if simplify:
        try:
            import onnxslim
        except ImportError as error:
            LOGGER.warning(f"Skipping ONNX simplification because onnxslim is missing: {error}")
        else:
            LOGGER.info("Simplifying ONNX graph with onnxslim")
            try:
                model_onnx = onnxslim.slim(model_onnx)
            except Exception as error:  # best-effort: keep the unsimplified graph
                LOGGER.warning(f"onnxslim simplification failed, keeping the unsimplified graph: {error}")
    for key, value in manifest.items():
        meta = model_onnx.metadata_props.add()
        meta.key = key
        meta.value = json.dumps(value, ensure_ascii=True) if isinstance(value, (dict, list)) else str(value)
    onnx.save(model_onnx, str(onnx_path))
def _export_onnx(
    model: nn.Module,
    example_inputs: torch.Tensor | tuple[torch.Tensor, ...],
    save_path: Path,
    input_names: list[str],
    output_names: list[str],
    manifest: dict[str, Any],
    opset: int,
    dynamic: bool,
    simplify: bool,
) -> None:
    """Export ``model`` to ONNX inside ultralytics' ``onnx_export_patch`` and ``arange_patch`` contexts.

    Post-processing (metadata and optional simplification) is best-effort: any failure there is logged as a
    warning and the exported file is kept.
    """
    LOGGER.info(f"Exporting ONNX to {save_path} with opset={opset}")
    patch_args = SimpleNamespace(dynamic=dynamic, half=False, format="onnx")
    with onnx_export_patch(), arange_patch(patch_args):
        torch.onnx.export(
            model,
            example_inputs,
            str(save_path),
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=_dynamic_axes(input_names, output_names, dynamic),
        )

    try:
        _save_onnx_metadata(save_path, manifest, simplify=simplify)
    except ImportError as error:
        LOGGER.warning(f"Skipping ONNX metadata/simplify step because a package is missing: {error}")
    except Exception as error:
        LOGGER.warning(f"ONNX post-processing failed: {error}")

    size_mb = save_path.stat().st_size / (1024 * 1024)
    LOGGER.info(f"ONNX export success: {save_path} ({size_mb:.2f} MB)")
def _load_and_prepare_wrapper(weights_path: str, roi_name: str, args: argparse.Namespace) -> ROISelectedOutputsWrapper:
    """Load one checkpoint on CPU (unfused), prepare a copy for export, and wrap it per the CLI layout flags."""
    LOGGER.info(f"Loading {roi_name} checkpoint from {weights_path}")
    loaded, _checkpoint = load_checkpoint(weights_path, device="cpu", fuse=False)
    export_ready = _prepare_model_for_export(loaded, max_det=args.max_det, dynamic=args.dynamic, fuse=not args.no_fuse)
    layout = {
        "include_anchor_stride": args.include_anchor_stride,
        "hybrid_outputs": args.hybrid_outputs,
        "denorm_branch_outputs": args.denorm_branch_outputs,
        "postprocessed_outputs": args.postprocessed_outputs,
        "keep_edge_head": args.edge_head_mode == "keep",
        "keep_fake_3d_branch": args.fake_3d_branch_mode == "keep",
    }
    return ROISelectedOutputsWrapper(export_ready, roi_name, **layout)
def _export_single_roi_artifacts(
    wrapper: ROISelectedOutputsWrapper,
    example_input: torch.Tensor,
    save_dir: Path,
    base_name: str,
    args: argparse.Namespace,
    input_size_wh: tuple[int, int],
) -> None:
    """Export one ROI wrapper as a standalone single-input artifact set.

    Runs the wrapper once on ``example_input`` to record output shapes and writes
    ``<base_name>.export.json``. Unless skipped via ``args``, it then writes ``<base_name>.torchscript``
    and ``<base_name>.onnx``. The graph input is named ``<base_name>_input``.
    """
    input_name = f"{base_name}_input"
    layout = {
        "include_anchor_stride": args.include_anchor_stride,
        "hybrid_outputs": args.hybrid_outputs,
        "denorm_branch_outputs": args.denorm_branch_outputs,
        "postprocessed_outputs": args.postprocessed_outputs,
        "keep_edge_head": args.edge_head_mode == "keep",
        "keep_fake_3d_branch": args.fake_3d_branch_mode == "keep",
    }

    outputs = wrapper(example_input)
    output_names = _output_names(base_name, **layout)
    manifest = _base_manifest(
        artifact_name=base_name,
        args=args,
        input_names=[input_name],
        input_sizes_wh={input_name: list(input_size_wh)},
        output_names=output_names,
        output_shapes=_describe_outputs(outputs),
    )
    _write_manifest(save_dir / f"{base_name}.export.json", manifest)

    if not args.skip_torchscript:
        _export_torchscript(wrapper, example_input, save_dir / f"{base_name}.torchscript", manifest)
    if not args.skip_onnx:
        _export_onnx(
            wrapper,
            example_input,
            save_dir / f"{base_name}.onnx",
            input_names=[input_name],
            output_names=output_names,
            manifest=manifest,
            opset=args.opset,
            dynamic=args.dynamic,
            simplify=args.simplify,
        )
def main() -> None:
    """Load both ROI checkpoints and export them, either merged (default) or separately.

    Steps:
    1. Parse CLI args and create ``save_dir``.
    2. Resolve the per-ROI example input sizes.
    3. Load and prepare both checkpoints.
    4. Export either two single-input artifact sets (``--export-separate``, run under
       ``torch.inference_mode``) or one two-input ``merged_model`` set: a JSON manifest, plus
       TorchScript/ONNX unless skipped.
    """
    args = parse_args()
    if args.include_anchor_stride and not args.postprocessed_outputs:
        LOGGER.warning("--include-anchor-stride is ignored unless --postprocessed-outputs is also set.")
    LOGGER.info(f"Export mode: {_export_mode(args)}")

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Example sizes as (width, height); --roi0-imgsz / --roi1-imgsz override the common --imgsz.
    common_imgsz_wh = (int(args.imgsz[0]), int(args.imgsz[1]))
    roi0_imgsz_wh = _resolve_imgsz_wh(common_imgsz_wh, args.roi0_imgsz)
    roi1_imgsz_wh = _resolve_imgsz_wh(common_imgsz_wh, args.roi1_imgsz)

    roi0_wrapper = _load_and_prepare_wrapper(args.roi0_model_path, "roi0", args)
    roi1_wrapper = _load_and_prepare_wrapper(args.roi1_model_path, "roi1", args)
    roi0_example_input = _make_example_input(roi0_imgsz_wh)
    roi1_example_input = _make_example_input(roi1_imgsz_wh)

    if args.export_separate:
        LOGGER.info("Exporting ROI0 and ROI1 as separate artifacts")
        with torch.inference_mode():
            _export_single_roi_artifacts(roi0_wrapper, roi0_example_input, save_dir, "roi0_model", args, roi0_imgsz_wh)
            _export_single_roi_artifacts(roi1_wrapper, roi1_example_input, save_dir, "roi1_model", args, roi1_imgsz_wh)
        return

    layout = {
        "include_anchor_stride": args.include_anchor_stride,
        "hybrid_outputs": args.hybrid_outputs,
        "denorm_branch_outputs": args.denorm_branch_outputs,
        "postprocessed_outputs": args.postprocessed_outputs,
        "keep_edge_head": args.edge_head_mode == "keep",
        "keep_fake_3d_branch": args.fake_3d_branch_mode == "keep",
    }
    # Re-wrap the already-prepared inner models with the same layout flags, behind one two-input forward.
    merged_model = TwoROIMergedExportModel(roi0_wrapper.model, roi1_wrapper.model, **layout).eval()
    example_inputs = (roi0_example_input, roi1_example_input)
    with torch.inference_mode():
        outputs = merged_model(*example_inputs)

    output_names = _output_names("roi0", **layout) + _output_names("roi1", **layout)
    output_shapes = _describe_outputs(outputs)
    LOGGER.info(f"Merged model output shapes: {output_shapes}")

    merged_input_names = ["roi0_input", "roi1_input"]
    manifest = _base_manifest(
        artifact_name="merged_model",
        args=args,
        input_names=merged_input_names,
        input_sizes_wh={"roi0_input": list(roi0_imgsz_wh), "roi1_input": list(roi1_imgsz_wh)},
        output_names=output_names,
        output_shapes=output_shapes,
    )
    _write_manifest(save_dir / "merged_model.export.json", manifest)

    if not args.skip_torchscript:
        _export_torchscript(merged_model, example_inputs, save_dir / "merged_model.torchscript", manifest)
    if not args.skip_onnx:
        _export_onnx(
            merged_model,
            example_inputs,
            save_dir / "merged_model.onnx",
            input_names=list(merged_input_names),
            output_names=output_names,
            manifest=manifest,
            opset=args.opset,
            dynamic=args.dynamic,
            simplify=args.simplify,
        )
if __name__ == "__main__":
    # Script entry point; importing this module does not start an export.
    main()