from __future__ import annotations import argparse import json import sys from copy import deepcopy from pathlib import Path from types import SimpleNamespace from typing import Any import torch import torch.nn as nn FILE = Path(__file__).resolve() ROOT = FILE.parents[2] if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) from ultralytics.nn.modules import Detect, RTDETRDecoder from ultralytics.nn.tasks import load_checkpoint from ultralytics.utils import LOGGER from ultralytics.utils.patches import arange_patch, onnx_export_patch DEFAULT_ROI0_MODEL = ROOT / "runs" / "detect" / "mono3d_roi0_20260506_epoch99.pt" DEFAULT_ROI1_MODEL = ROOT / "runs" / "detect" / "mono3d_roi1_20260506_epoch99.pt" DEFAULT_SAVE_DIR = ROOT / "runs" / "export" / "train_mono3d_two_roi_20260506" DEFAULT_IMGSZ_WH = (768, 352) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Merge two yolo26 single-ROI Detect3D checkpoints and export them.") parser.add_argument("--roi0-model-path", type=str, default=str(DEFAULT_ROI0_MODEL), help="Path to ROI0 checkpoint") parser.add_argument("--roi1-model-path", type=str, default=str(DEFAULT_ROI1_MODEL), help="Path to ROI1 checkpoint") parser.add_argument("--save-dir", type=str, default=str(DEFAULT_SAVE_DIR), help="Directory used to store exported files") parser.add_argument( "--imgsz", nargs=2, type=int, default=list(DEFAULT_IMGSZ_WH), metavar=("W", "H"), help="Common example input size as width height, e.g. --imgsz 768 352", ) parser.add_argument("--roi0-imgsz", nargs=2, type=int, default=None, metavar=("W", "H"), help="Optional ROI0 input size override") parser.add_argument("--roi1-imgsz", nargs=2, type=int, default=None, metavar=("W", "H"), help="Optional ROI1 input size override") parser.add_argument("--opset", type=int, default=17, help="ONNX opset version") parser.add_argument("--max-det", type=int, default=300, help="Top-k detections kept by each Detect3D head") parser.add_argument("--dynamic", action="store_true", help="Export with dynamic batch dimension") parser.add_argument("--simplify", action="store_true", help="Try to simplify the exported ONNX graph with onnxslim") export_group = parser.add_mutually_exclusive_group() export_group.add_argument( "--hybrid-outputs", action="store_true", help="Export postprocessed 2D detections together with raw 2D/3D/edge head outputs.", ) export_group.add_argument( "--denorm-branch-outputs", action="store_true", help="Export one2one branch outputs after Detect3D branch denormalization, but before top-k postprocessing.", ) export_group.add_argument( "--postprocessed-outputs", action="store_true", help="Export postprocessed selected outputs instead of raw head tensors.", ) parser.add_argument( "--include-anchor-stride", action="store_true", help="Also export anchors and strides for each ROI. Only used with --postprocessed-outputs.", ) parser.add_argument("--export-separate", action="store_true", help="Export ROI0 and ROI1 as separate single-input artifacts") parser.add_argument("--skip-torchscript", action="store_true", help="Skip TorchScript export") parser.add_argument("--skip-onnx", action="store_true", help="Skip ONNX export") parser.add_argument("--no-fuse", action="store_true", help="Disable Conv-BN fusion before export") parser.add_argument( "--edge-head-mode", type=str, choices=("keep", "drop"), default="keep", help="Control whether edge_head branch outputs are kept in exported artifacts.", ) parser.add_argument( "--fake-3d-branch-mode", type=str, choices=("keep", "drop"), default="keep", help="Control whether fake 3D branch outputs are kept in exported artifacts.", ) return parser.parse_args() def _resolve_imgsz_wh(common_imgsz_wh: tuple[int, int], override: list[int] | tuple[int, int] | None) -> tuple[int, int]: if override is None: return int(common_imgsz_wh[0]), int(common_imgsz_wh[1]) return int(override[0]), int(override[1]) def _make_example_input(imgsz_wh: tuple[int, int]) -> torch.Tensor: width, height = imgsz_wh return torch.zeros(1, 3, height, width, dtype=torch.float32) def _prepare_model_for_export(model: nn.Module, max_det: int, dynamic: bool, fuse: bool) -> nn.Module: model = deepcopy(model).cpu() for parameter in model.parameters(): parameter.requires_grad_(False) model.eval() model.float() if fuse and hasattr(model, "fuse"): model = model.fuse() for module in model.modules(): if isinstance(module, (Detect, RTDETRDecoder)): if hasattr(module, "dynamic"): module.dynamic = dynamic if hasattr(module, "max_det"): module.max_det = int(max_det) if hasattr(module, "shape"): module.shape = None return model def _export_mode(args: argparse.Namespace) -> str: if args.postprocessed_outputs: return "postprocessed_outputs" if args.hybrid_outputs: return "hybrid_outputs" if args.denorm_branch_outputs: return "denorm_branch_outputs" return "raw_head_outputs" def _forward_model_to_head_inputs(model: nn.Module, x: torch.Tensor) -> tuple[nn.Module, Any]: y = [] for module in model.model[:-1]: if module.f != -1: x = y[module.f] if isinstance(module.f, int) else [x if j == -1 else y[j] for j in module.f] x = module(x) y.append(x if module.i in model.save else None) head = model.model[-1] if head.f != -1: x = y[head.f] if isinstance(head.f, int) else [x if j == -1 else y[j] for j in head.f] return head, x def _flatten_detect3d_outputs( outputs: Any, roi_name: str, include_anchor_stride: bool = False, keep_edge_head: bool = True, keep_fake_3d_branch: bool = True, ) -> tuple[torch.Tensor, ...]: if not isinstance(outputs, (tuple, list)) or len(outputs) < 2: raise RuntimeError(f"{roi_name} forward output must be `(detections, raw_preds)` for Detect3D models.") detections, raw_preds = outputs[0], outputs[1] if not isinstance(raw_preds, dict): raise RuntimeError(f"{roi_name} raw prediction payload is missing.") one2one = raw_preds.get("one2one", raw_preds) if not isinstance(one2one, dict): raise RuntimeError(f"{roi_name} one2one prediction payload is missing.") required_keys = ["preds_3d_selected"] if keep_edge_head: required_keys.append("preds_edge_selected") if include_anchor_stride: required_keys += ["anchors_selected", "strides_selected"] missing = [key for key in required_keys if one2one.get(key) is None] if missing: available = sorted(one2one.keys()) raise RuntimeError( f"{roi_name} is missing Detect3D export tensors {missing}. Available keys: {available}. " "This script expects yolo26 end2end Detect3D checkpoints." ) result = ( detections, one2one["preds_3d_selected"], one2one.get("preds_diff_selected"), ) if keep_fake_3d_branch: result = result + (one2one.get("preds_3d_fake_selected"),) if keep_edge_head: result = result + (one2one["preds_edge_selected"],) if include_anchor_stride: result = result + (one2one["anchors_selected"], one2one["strides_selected"]) return result def _denorm_detect3d_branch_outputs( model: nn.Module, x: torch.Tensor, roi_name: str, keep_edge_head: bool = True, keep_fake_3d_branch: bool = True, ) -> tuple[torch.Tensor, ...]: head, head_inputs = _forward_model_to_head_inputs(model, x) if not hasattr(head, "forward_head"): raise RuntimeError(f"{roi_name} final layer does not expose forward_head(), cannot export raw branch outputs.") branch_inputs = [tensor.detach() for tensor in head_inputs] if getattr(head, "end2end", False) else head_inputs branch = head.one2one if getattr(head, "end2end", False) else head.one2many raw_preds = head.forward_head(branch_inputs, **branch) if not isinstance(raw_preds, dict): raise RuntimeError(f"{roi_name} raw branch payload is missing.") required_keys = ["boxes", "scores", "preds_3d"] if keep_edge_head: required_keys.append("preds_edge") missing = [key for key in required_keys if raw_preds.get(key) is None] if missing: available = sorted(raw_preds.keys()) raise RuntimeError(f"{roi_name} raw branch tensors {missing} are missing. Available keys: {available}.") result = ( raw_preds["boxes"], raw_preds["scores"], raw_preds["preds_3d"], raw_preds.get("preds_diff"), ) if keep_fake_3d_branch: result = result + (raw_preds.get("preds_3d_fake"),) if keep_edge_head: result = result + (raw_preds["preds_edge"],) return result def _collect_raw_detect3d_head_outputs( head: nn.Module, head_inputs: list[torch.Tensor], roi_name: str, keep_edge_head: bool = True, keep_fake_3d_branch: bool = True, ) -> tuple[torch.Tensor, ...]: if not getattr(head, "end2end", False): raise RuntimeError(f"{roi_name} raw-head export currently expects an end2end Detect3D head.") branch = head.one2one box_head = branch.get("box_head") cls_head = branch.get("cls_head") diff_head = branch.get("diff_head") head_3d = branch.get("head_3d") fake_head_3d = branch.get("fake_head_3d") edge_head = branch.get("edge_head") required_modules = ( box_head, cls_head, diff_head, head_3d, edge_head, ) if keep_edge_head else ( box_head, cls_head, diff_head, head_3d, ) if keep_fake_3d_branch: required_modules = (*required_modules, fake_head_3d) if any(module is None for module in required_modules): raise RuntimeError(f"{roi_name} raw-head export could not find complete one2one head modules.") bs = head_inputs[0].shape[0] raw_boxes = torch.cat([box_head[i](head_inputs[i]).view(bs, 4 * head.reg_max, -1) for i in range(head.nl)], dim=-1) raw_scores = torch.cat([cls_head[i](head_inputs[i]).view(bs, head.nc, -1) for i in range(head.nl)], dim=-1) raw_diff = torch.cat([diff_head[i](head_inputs[i]).view(bs, 1, -1) for i in range(head.nl)], dim=-1) raw_3d = torch.cat([head_3d[i](head_inputs[i]).view(bs, head.no_3d, -1) for i in range(head.nl)], dim=-1) result = (raw_boxes, raw_scores, raw_3d, raw_diff) if keep_fake_3d_branch: raw_3d_fake = torch.cat([fake_head_3d[i](head_inputs[i]).view(bs, head.no_3d, -1) for i in range(head.nl)], dim=-1) result = result + (raw_3d_fake,) if keep_edge_head: raw_edge = torch.cat([edge_head[i](head_inputs[i]).view(bs, head.no_edge, -1) for i in range(head.nl)], dim=-1) result = result + (raw_edge,) return result def _raw_detect3d_head_outputs( model: nn.Module, x: torch.Tensor, roi_name: str, keep_edge_head: bool = True, keep_fake_3d_branch: bool = True, ) -> tuple[torch.Tensor, ...]: head, head_inputs = _forward_model_to_head_inputs(model, x) return _collect_raw_detect3d_head_outputs( head, head_inputs, roi_name, keep_edge_head=keep_edge_head, keep_fake_3d_branch=keep_fake_3d_branch, ) def _hybrid_detect2d_raw3d_outputs( model: nn.Module, x: torch.Tensor, roi_name: str, keep_edge_head: bool = True, keep_fake_3d_branch: bool = True, ) -> tuple[torch.Tensor, ...]: head, head_inputs = _forward_model_to_head_inputs(model, x) raw_outputs = _collect_raw_detect3d_head_outputs( head, head_inputs, roi_name, keep_edge_head=keep_edge_head, keep_fake_3d_branch=keep_fake_3d_branch, ) raw_boxes, raw_scores, *_ = raw_outputs preds = {"boxes": raw_boxes, "scores": raw_scores, "feats": head_inputs} detections = head.postprocess(head._inference(preds).permute(0, 2, 1)) return (detections, *raw_outputs) class ROISelectedOutputsWrapper(nn.Module): def __init__( self, model: nn.Module, roi_name: str, include_anchor_stride: bool = False, hybrid_outputs: bool = False, denorm_branch_outputs: bool = False, postprocessed_outputs: bool = False, keep_edge_head: bool = True, keep_fake_3d_branch: bool = True, ): super().__init__() self.model = model self.roi_name = roi_name self.include_anchor_stride = include_anchor_stride self.hybrid_outputs = hybrid_outputs self.denorm_branch_outputs = denorm_branch_outputs self.postprocessed_outputs = postprocessed_outputs self.keep_edge_head = keep_edge_head self.keep_fake_3d_branch = keep_fake_3d_branch def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: if self.postprocessed_outputs: return _flatten_detect3d_outputs( self.model(x), self.roi_name, include_anchor_stride=self.include_anchor_stride, keep_edge_head=self.keep_edge_head, keep_fake_3d_branch=self.keep_fake_3d_branch, ) if self.hybrid_outputs: return _hybrid_detect2d_raw3d_outputs( self.model, x, self.roi_name, keep_edge_head=self.keep_edge_head, keep_fake_3d_branch=self.keep_fake_3d_branch, ) if self.denorm_branch_outputs: return _denorm_detect3d_branch_outputs( self.model, x, self.roi_name, keep_edge_head=self.keep_edge_head, keep_fake_3d_branch=self.keep_fake_3d_branch, ) return _raw_detect3d_head_outputs( self.model, x, self.roi_name, keep_edge_head=self.keep_edge_head, keep_fake_3d_branch=self.keep_fake_3d_branch, ) class TwoROIMergedExportModel(nn.Module): def __init__( self, roi0_model: nn.Module, roi1_model: nn.Module, include_anchor_stride: bool = False, hybrid_outputs: bool = False, denorm_branch_outputs: bool = False, postprocessed_outputs: bool = False, keep_edge_head: bool = True, keep_fake_3d_branch: bool = True, ): super().__init__() self.roi0_model = ROISelectedOutputsWrapper( roi0_model, "roi0", include_anchor_stride=include_anchor_stride, hybrid_outputs=hybrid_outputs, denorm_branch_outputs=denorm_branch_outputs, postprocessed_outputs=postprocessed_outputs, keep_edge_head=keep_edge_head, keep_fake_3d_branch=keep_fake_3d_branch, ) self.roi1_model = ROISelectedOutputsWrapper( roi1_model, "roi1", include_anchor_stride=include_anchor_stride, hybrid_outputs=hybrid_outputs, denorm_branch_outputs=denorm_branch_outputs, postprocessed_outputs=postprocessed_outputs, keep_edge_head=keep_edge_head, keep_fake_3d_branch=keep_fake_3d_branch, ) self.stride = getattr(roi0_model, "stride", None) self.names = getattr(roi0_model, "names", None) def forward(self, x_roi0: torch.Tensor, x_roi1: torch.Tensor) -> tuple[torch.Tensor, ...]: return (*self.roi0_model(x_roi0), *self.roi1_model(x_roi1)) def _output_names( prefix: str, include_anchor_stride: bool = False, hybrid_outputs: bool = False, denorm_branch_outputs: bool = False, postprocessed_outputs: bool = False, keep_edge_head: bool = True, keep_fake_3d_branch: bool = True, ) -> list[str]: if postprocessed_outputs: names = [ f"{prefix}_detections", f"{prefix}_preds_3d", f"{prefix}_preds_diff", ] if keep_fake_3d_branch: names.append(f"{prefix}_preds_3d_fake") if keep_edge_head: names.append(f"{prefix}_preds_edge") if include_anchor_stride: names += [f"{prefix}_anchors", f"{prefix}_strides"] return names if hybrid_outputs: names = [ f"{prefix}_detections", f"{prefix}_boxes_head_raw", f"{prefix}_scores_head_raw", f"{prefix}_preds_3d_head_raw", f"{prefix}_preds_diff_head_raw", ] if keep_fake_3d_branch: names.append(f"{prefix}_preds_3d_fake_head_raw") if keep_edge_head: names.append(f"{prefix}_preds_edge_head_raw") return names if denorm_branch_outputs: names = [ f"{prefix}_boxes_raw", f"{prefix}_scores_raw", f"{prefix}_preds_3d_raw", f"{prefix}_preds_diff_raw", ] if keep_fake_3d_branch: names.append(f"{prefix}_preds_3d_fake_raw") if keep_edge_head: names.append(f"{prefix}_preds_edge_raw") return names names = [ f"{prefix}_boxes_head_raw", f"{prefix}_scores_head_raw", f"{prefix}_preds_3d_head_raw", f"{prefix}_preds_diff_head_raw", ] if keep_fake_3d_branch: names.append(f"{prefix}_preds_3d_fake_head_raw") if keep_edge_head: names.append(f"{prefix}_preds_edge_head_raw") return names def _describe_outputs(outputs: tuple[torch.Tensor, ...]) -> list[list[int]]: return [list(output.shape) for output in outputs] def _dynamic_axes(input_names: list[str], output_names: list[str], enabled: bool) -> dict[str, dict[int, str]] | None: if not enabled: return None dynamic_axes = {name: {0: "batch"} for name in input_names} dynamic_axes.update({name: {0: "batch"} for name in output_names}) return dynamic_axes def _write_manifest(path: Path, payload: dict[str, Any]) -> None: path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") def _base_manifest( artifact_name: str, args: argparse.Namespace, input_names: list[str], input_sizes_wh: dict[str, list[int]], output_names: list[str], output_shapes: list[list[int]], ) -> dict[str, Any]: return { "artifact_name": artifact_name, "roi0_model_path": str(Path(args.roi0_model_path).resolve()), "roi1_model_path": str(Path(args.roi1_model_path).resolve()), "input_names": input_names, "input_sizes_wh": input_sizes_wh, "output_names": output_names, "output_shapes": output_shapes, "dynamic_batch": bool(args.dynamic), "max_det": int(args.max_det), "opset": int(args.opset), "fuse": not bool(args.no_fuse), "include_anchor_stride": bool(args.include_anchor_stride), "hybrid_outputs": bool(args.hybrid_outputs), "denorm_branch_outputs": bool(args.denorm_branch_outputs), "postprocessed_outputs": bool(args.postprocessed_outputs), "edge_head_mode": args.edge_head_mode, "keep_edge_head": args.edge_head_mode == "keep", "fake_3d_branch_mode": args.fake_3d_branch_mode, "keep_fake_3d_branch": args.fake_3d_branch_mode == "keep", "export_mode": _export_mode(args), } def _export_torchscript( model: nn.Module, example_inputs: torch.Tensor | tuple[torch.Tensor, ...], save_path: Path, manifest: dict[str, Any], ) -> None: LOGGER.info(f"Exporting TorchScript to {save_path}") traced = torch.jit.trace(model, example_inputs, strict=False) extra_files = {"config.txt": json.dumps(manifest, sort_keys=True)} traced.save(str(save_path), _extra_files=extra_files) LOGGER.info(f"TorchScript export success: {save_path} ({save_path.stat().st_size / (1024 * 1024):.2f} MB)") def _save_onnx_metadata(onnx_path: Path, manifest: dict[str, Any], simplify: bool) -> None: import onnx model_onnx = onnx.load(str(onnx_path)) for key, value in manifest.items(): meta = model_onnx.metadata_props.add() meta.key, meta.value = key, json.dumps(value, ensure_ascii=True) if isinstance(value, (dict, list)) else str(value) if simplify: import onnxslim LOGGER.info("Simplifying ONNX graph with onnxslim") model_onnx = onnxslim.slim(model_onnx) onnx.save(model_onnx, str(onnx_path)) def _export_onnx( model: nn.Module, example_inputs: torch.Tensor | tuple[torch.Tensor, ...], save_path: Path, input_names: list[str], output_names: list[str], manifest: dict[str, Any], opset: int, dynamic: bool, simplify: bool, ) -> None: LOGGER.info(f"Exporting ONNX to {save_path} with opset={opset}") patch_args = SimpleNamespace(dynamic=dynamic, half=False, format="onnx") with onnx_export_patch(): with arange_patch(patch_args): torch.onnx.export( model, example_inputs, str(save_path), verbose=False, opset_version=opset, do_constant_folding=True, input_names=input_names, output_names=output_names, dynamic_axes=_dynamic_axes(input_names, output_names, dynamic), ) try: _save_onnx_metadata(save_path, manifest, simplify=simplify) except ImportError as error: LOGGER.warning(f"Skipping ONNX metadata/simplify step because a package is missing: {error}") except Exception as error: LOGGER.warning(f"ONNX post-processing failed: {error}") LOGGER.info(f"ONNX export success: {save_path} ({save_path.stat().st_size / (1024 * 1024):.2f} MB)") def _load_and_prepare_wrapper(weights_path: str, roi_name: str, args: argparse.Namespace) -> ROISelectedOutputsWrapper: LOGGER.info(f"Loading {roi_name} checkpoint from {weights_path}") model, _ = load_checkpoint(weights_path, device="cpu", fuse=False) prepared_model = _prepare_model_for_export(model, max_det=args.max_det, dynamic=args.dynamic, fuse=not args.no_fuse) return ROISelectedOutputsWrapper( prepared_model, roi_name, include_anchor_stride=args.include_anchor_stride, hybrid_outputs=args.hybrid_outputs, denorm_branch_outputs=args.denorm_branch_outputs, postprocessed_outputs=args.postprocessed_outputs, keep_edge_head=args.edge_head_mode == "keep", keep_fake_3d_branch=args.fake_3d_branch_mode == "keep", ) def _export_single_roi_artifacts( wrapper: ROISelectedOutputsWrapper, example_input: torch.Tensor, save_dir: Path, base_name: str, args: argparse.Namespace, input_size_wh: tuple[int, int], ) -> None: outputs = wrapper(example_input) output_names = _output_names( base_name, include_anchor_stride=args.include_anchor_stride, hybrid_outputs=args.hybrid_outputs, denorm_branch_outputs=args.denorm_branch_outputs, postprocessed_outputs=args.postprocessed_outputs, keep_edge_head=args.edge_head_mode == "keep", keep_fake_3d_branch=args.fake_3d_branch_mode == "keep", ) manifest = _base_manifest( artifact_name=base_name, args=args, input_names=[f"{base_name}_input"], input_sizes_wh={f"{base_name}_input": list(input_size_wh)}, output_names=output_names, output_shapes=_describe_outputs(outputs), ) _write_manifest(save_dir / f"{base_name}.export.json", manifest) if not args.skip_torchscript: _export_torchscript(wrapper, example_input, save_dir / f"{base_name}.torchscript", manifest) if not args.skip_onnx: _export_onnx( wrapper, example_input, save_dir / f"{base_name}.onnx", input_names=[f"{base_name}_input"], output_names=output_names, manifest=manifest, opset=args.opset, dynamic=args.dynamic, simplify=args.simplify, ) def main() -> None: args = parse_args() if args.include_anchor_stride and not args.postprocessed_outputs: LOGGER.warning("--include-anchor-stride is ignored unless --postprocessed-outputs is also set.") LOGGER.info(f"Export mode: {_export_mode(args)}") save_dir = Path(args.save_dir) save_dir.mkdir(parents=True, exist_ok=True) common_imgsz_wh = (int(args.imgsz[0]), int(args.imgsz[1])) roi0_imgsz_wh = _resolve_imgsz_wh(common_imgsz_wh, args.roi0_imgsz) roi1_imgsz_wh = _resolve_imgsz_wh(common_imgsz_wh, args.roi1_imgsz) roi0_wrapper = _load_and_prepare_wrapper(args.roi0_model_path, "roi0", args) roi1_wrapper = _load_and_prepare_wrapper(args.roi1_model_path, "roi1", args) roi0_example_input = _make_example_input(roi0_imgsz_wh) roi1_example_input = _make_example_input(roi1_imgsz_wh) if args.export_separate: LOGGER.info("Exporting ROI0 and ROI1 as separate artifacts") with torch.inference_mode(): _export_single_roi_artifacts(roi0_wrapper, roi0_example_input, save_dir, "roi0_model", args, roi0_imgsz_wh) _export_single_roi_artifacts(roi1_wrapper, roi1_example_input, save_dir, "roi1_model", args, roi1_imgsz_wh) return merged_model = TwoROIMergedExportModel( roi0_wrapper.model, roi1_wrapper.model, include_anchor_stride=args.include_anchor_stride, hybrid_outputs=args.hybrid_outputs, denorm_branch_outputs=args.denorm_branch_outputs, postprocessed_outputs=args.postprocessed_outputs, keep_edge_head=args.edge_head_mode == "keep", keep_fake_3d_branch=args.fake_3d_branch_mode == "keep", ).eval() example_inputs = (roi0_example_input, roi1_example_input) with torch.inference_mode(): outputs = merged_model(*example_inputs) output_names = _output_names( "roi0", include_anchor_stride=args.include_anchor_stride, hybrid_outputs=args.hybrid_outputs, denorm_branch_outputs=args.denorm_branch_outputs, postprocessed_outputs=args.postprocessed_outputs, keep_edge_head=args.edge_head_mode == "keep", keep_fake_3d_branch=args.fake_3d_branch_mode == "keep", ) + _output_names( "roi1", include_anchor_stride=args.include_anchor_stride, hybrid_outputs=args.hybrid_outputs, denorm_branch_outputs=args.denorm_branch_outputs, postprocessed_outputs=args.postprocessed_outputs, keep_edge_head=args.edge_head_mode == "keep", keep_fake_3d_branch=args.fake_3d_branch_mode == "keep", ) LOGGER.info(f"Merged model output shapes: {_describe_outputs(outputs)}") manifest = _base_manifest( artifact_name="merged_model", args=args, input_names=["roi0_input", "roi1_input"], input_sizes_wh={ "roi0_input": list(roi0_imgsz_wh), "roi1_input": list(roi1_imgsz_wh), }, output_names=output_names, output_shapes=_describe_outputs(outputs), ) _write_manifest(save_dir / "merged_model.export.json", manifest) if not args.skip_torchscript: _export_torchscript(merged_model, example_inputs, save_dir / "merged_model.torchscript", manifest) if not args.skip_onnx: _export_onnx( merged_model, example_inputs, save_dir / "merged_model.onnx", input_names=["roi0_input", "roi1_input"], output_names=output_names, manifest=manifest, opset=args.opset, dynamic=args.dynamic, simplify=args.simplify, ) if __name__ == "__main__": main()