330 lines
14 KiB
Python
Executable File
330 lines
14 KiB
Python
Executable File
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
FILE = Path(__file__).resolve()
|
|
ROOT = FILE.parents[1]
|
|
if str(ROOT) not in sys.path:
|
|
sys.path.append(str(ROOT))
|
|
|
|
from models.common import DetectMultiBackend
|
|
from export import attempt_load
|
|
from utils.general import LOGGER
|
|
|
|
|
|
def analyze_checkpoint(weights_path):
    """Log a size breakdown of a checkpoint file.

    The checkpoint is loaded on CPU. For a dict checkpoint the top-level keys
    are logged and the parameters of the objects stored under ``'model'`` and
    ``'ema'`` are counted; a non-dict checkpoint that exposes ``parameters()``
    is counted as a single component. The theoretical FP32 footprint of the
    counted parameters is then compared with the file size on disk.

    Args:
        weights_path: Path (``str`` or ``Path``) of the checkpoint file.

    Returns:
        None. Every result is emitted through ``LOGGER``.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only analyze checkpoints from a trusted source.
    ckpt = torch.load(weights_path, map_location='cpu')

    bytes_per_mb = 1024 ** 2
    banner = '=' * 80

    LOGGER.info(f"\n{banner}")
    LOGGER.info(f"CHECKPOINT ANALYSIS: {Path(weights_path).name}")
    LOGGER.info(f"{banner}")

    # Collect the (label, object) pairs whose parameters should be counted.
    if isinstance(ckpt, dict):
        LOGGER.info(f"Checkpoint keys: {list(ckpt.keys())}")
        components = [(key, value) for key, value in ckpt.items() if key in ('model', 'ema')]
    else:
        # A checkpoint saved as a bare object has no keys to list; count it
        # directly instead of skipping the analysis or failing later.
        LOGGER.warning(f"Checkpoint is a {type(ckpt).__name__}, not a dict; analyzing it as a single component")
        components = [(type(ckpt).__name__, ckpt)]

    # Bound before the loop so the summary below is valid even when nothing
    # countable is found.
    total_params = 0
    for key, model_obj in components:
        if not hasattr(model_obj, 'parameters'):
            continue  # e.g. an entry that is None or a plain state_dict
        params = sum(p.numel() for p in model_obj.parameters())
        param_size_mb = params * 4 / bytes_per_mb  # FP32 = 4 bytes per parameter
        total_params += params
        LOGGER.info(f" - {key}: {params:,} parameters ({param_size_mb:.2f}MB if FP32)")

    theoretical_mb = total_params * 4 / bytes_per_mb
    LOGGER.info(f"Total parameters: {total_params:,}")
    LOGGER.info(f"Theoretical FP32 size: {theoretical_mb:.2f}MB")

    # Compare against the real size on disk.
    actual_size = Path(weights_path).stat().st_size / bytes_per_mb
    compression_ratio = theoretical_mb / actual_size if actual_size > 0 else 0
    LOGGER.info(f"Actual file size: {actual_size:.2f}MB")
    LOGGER.info(f"Compression ratio: {compression_ratio:.2f}x")

    LOGGER.info(f"{banner}\n")
|
|
|
|
|
|
def load_model_lightweight(weights_path, device='cpu'):
    """Load the model object stored in a checkpoint without the standard loader.

    The checkpoint is read with ``torch.load``. The object stored under
    ``'model'`` is used when present; ``'ema'`` is used only as a fallback.
    A checkpoint that is itself an ``nn.Module`` is used directly. The result
    is cast to FP32, switched to eval mode, and has its ``hyp`` and ``gr``
    attributes removed when they exist.

    Note: this returns the full pickled module object, not a ``state_dict``.

    Args:
        weights_path: Path of the checkpoint file.
        device: Value passed to ``torch.load`` as ``map_location``.

    Returns:
        The loaded module, in FP32 and eval mode.

    Raises:
        ValueError: If no model object can be found in the checkpoint.
    """
    # Not referenced below.
    # NOTE(review): presumably kept so that ``models.yolo`` is importable
    # before the checkpoint is unpickled — confirm before removing.
    from models.yolo import Model  # noqa: F401

    LOGGER.info(f"Loading model with lightweight method: {weights_path}")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    ckpt = torch.load(weights_path, map_location=device)

    if isinstance(ckpt, dict):
        # Prefer the regular model; a key that is present but None is treated
        # the same as a missing key so the EMA fallback still applies.
        model = ckpt.get('model')
        if model is None:
            model = ckpt.get('ema')
            if model is None:
                raise ValueError(f"No model found in checkpoint: {weights_path}")
            LOGGER.warning("Only EMA weights available, using them (may increase size)")
    elif isinstance(ckpt, nn.Module):
        # Checkpoint was saved as a bare module rather than a dict.
        model = ckpt
    else:
        raise ValueError(f"No model found in checkpoint: {weights_path}")

    # Convert to float32 and eval mode
    model = model.float().eval()

    # Drop training-only attributes when present
    for attr in ('hyp', 'gr'):
        if hasattr(model, attr):
            delattr(model, attr)

    return model
|
|
|
|
|
|
class Model_Merged(nn.Module):
    """Run two models side by side as a single two-input module.

    Each sub-model receives its own input; the two results are joined with
    the ``+`` operator and returned as one value.

    ``stride`` and ``names`` are copied from ``model_roi0``. Both models are
    assumed, but not checked, to share the same values.

    Args:
        model_roi0: Model applied to the first input. Must expose ``model``
            (indexable, last element is the head), ``stride`` and ``names``.
        model_roi1: Model applied to the second input, same requirements.
    """

    def __init__(self, model_roi0, model_roi1):
        super().__init__()
        self.model_roi0 = model_roi0
        self.model_roi1 = model_roi1

        # Set ``export_raw`` on the last layer of each sub-model.
        # NOTE(review): presumably the Detect head reads this flag to return
        # raw per-scale outputs — confirm against the head implementation.
        self.model_roi0.model[-1].export_raw = True
        self.model_roi1.model[-1].export_raw = True

        # Make sure both sub-models are in eval mode.
        self.model_roi0.eval()
        self.model_roi1.eval()

        self.stride = model_roi0.stride  # assuming both models have the same stride
        self.names = model_roi0.names  # assuming both models have the same class names

    def forward(self, x_roi0, x_roi1):
        """Run both sub-models without gradient tracking and join their outputs.

        Args:
            x_roi0: Input passed to ``model_roi0``.
            x_roi1: Input passed to ``model_roi1``.

        Returns:
            ``out_roi0 + out_roi1``. With list/tuple outputs this is a
            concatenation (ROI0 entries first).
            NOTE(review): if a sub-model returned a single tensor, ``+`` would
            add element-wise instead — verify the head returns a sequence.
        """
        with torch.no_grad():
            out_roi0 = self.model_roi0(x_roi0)
            out_roi1 = self.model_roi1(x_roi1)

        return out_roi0 + out_roi1
|
|
|
|
def merge_models(roi0_model_path, roi1_model_path, save_path, use_lightweight=False):
    """Build a ``Model_Merged`` from two checkpoint files.

    The on-disk size of both checkpoints is logged first. Each checkpoint is
    then loaded on CPU, either with ``load_model_lightweight`` or with
    ``attempt_load(..., fuse=False)``, and the two models are wrapped in a
    ``Model_Merged``.

    Args:
        roi0_model_path: Checkpoint path of the ROI0 model.
        roi1_model_path: Checkpoint path of the ROI1 model.
        save_path: Accepted for interface compatibility; not used here.
        use_lightweight: If True, use the custom lightweight loader
            (experimental) instead of ``attempt_load``.

    Returns:
        The ``Model_Merged`` instance wrapping both loaded models.
    """
    bytes_per_mb = 1024 * 1024

    # Report original model sizes
    size_roi0 = Path(roi0_model_path).stat().st_size / bytes_per_mb
    size_roi1 = Path(roi1_model_path).stat().st_size / bytes_per_mb
    LOGGER.info(f"Original model sizes: ROI0={size_roi0:.2f}MB, ROI1={size_roi1:.2f}MB, Total={size_roi0+size_roi1:.2f}MB")

    checkpoints = (("ROI0", roi0_model_path), ("ROI1", roi1_model_path))

    if use_lightweight:
        LOGGER.info("Using lightweight model loading (experimental)")
        loaded = [load_model_lightweight(path, device='cpu') for _, path in checkpoints]
    else:
        loaded = []
        for tag, path in checkpoints:
            LOGGER.info(f"Loading {tag} model: {path}")
            # fuse=False: layers are loaded unfused
            loaded.append(attempt_load(path, device='cpu', fuse=False))

    return Model_Merged(*loaded)
|
|
|
|
def _parse_args():
    """Parse the command-line options of the merge/export script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--roi0_model_path', type=str, default='release/yolov5s-30w/roi0_last.pt', help='Path to the ROI0 model file')
    parser.add_argument('--roi1_model_path', type=str, default='release/yolov5s-30w/roi1_last.pt', help='Path to the ROI1 model file')
    parser.add_argument('--save_dir', type=str, default='release/yolov5s-30w', help='Path to save the merged model file')
    parser.add_argument('--imgsz', nargs=2, type=int, default=[352, 704], metavar=('H', 'W'), help='Input image size (height width), e.g. --imgsz 352 704')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version (11, 12, 13, 17, etc.)')
    parser.add_argument('--skip-export', action='store_true', help='Skip exporting to TorchScript/ONNX')
    parser.add_argument('--lightweight', action='store_true', help='Use lightweight model loading (may reduce export size)')
    parser.add_argument('--export-separate', action='store_true', help='Export two models separately instead of merged (saves space)')
    parser.add_argument('--analyze', action='store_true', help='Analyze checkpoint structure and exit')
    return parser.parse_args()


def _file_size_mb(path):
    """Return the size of the file at ``path`` (a ``Path``) in MB (1024 * 1024 bytes)."""
    return path.stat().st_size / (1024 * 1024)


def _export_separate(merged_model, im0, im1, save_dir, opset):
    """Export each ROI sub-model to its own TorchScript and ONNX file.

    For each sub-model the TorchScript export runs first; if it fails, the
    ONNX export of that sub-model is skipped. A failure is logged and does
    not stop the other sub-model from being exported.

    Returns:
        True if every export succeeded, False otherwise.
    """
    all_ok = True
    jobs = (
        ("ROI0", merged_model.model_roi0, im0),
        ("ROI1", merged_model.model_roi1, im1),
    )
    for tag, sub_model, example in jobs:
        stem = tag.lower()
        ts_path = save_dir / f"{stem}_model.torchscript"
        onnx_path = save_dir / f"{stem}_model.onnx"
        try:
            LOGGER.info(f"Exporting {tag} to TorchScript: {ts_path}")
            traced = torch.jit.trace(sub_model, example, strict=False)
            traced.save(str(ts_path))
            LOGGER.info(f"{tag} TorchScript export success: {_file_size_mb(ts_path):.2f}MB")

            LOGGER.info(f"Exporting {tag} to ONNX: {onnx_path}")
            torch.onnx.export(
                sub_model, example, str(onnx_path),
                verbose=False, opset_version=opset, do_constant_folding=True,
                input_names=["input"], output_names=["output"],
            )
            LOGGER.info(f"{tag} ONNX export success: {_file_size_mb(onnx_path):.2f}MB")
        except Exception as e:  # best-effort: keep going with the other model
            LOGGER.error(f"{tag} export failed: {e}")
            all_ok = False
    return all_ok


def _log_simplification_note():
    """Log why ONNX simplification is skipped, or that onnxslim is missing."""
    try:
        import onnx  # noqa: F401
        import onnxslim  # noqa: F401
    except ImportError:
        LOGGER.warning("onnxslim not installed, skipping simplification. Install with: pip install onnxslim")
    except Exception as e:
        LOGGER.warning(f"ONNX simplification failed: {e}")
    else:
        # Simplification is intentionally not run for the two-input model.
        LOGGER.warning("Skipping ONNX simplification for multi-input model (often increases size)")
        LOGGER.info("If you still want to try simplification, use: onnxslim input.onnx output.onnx")


def _export_merged(merged_model, im0, im1, save_dir, opset):
    """Export the merged two-input model to TorchScript and then to ONNX.

    The two exports are independent: a TorchScript failure is logged and the
    ONNX export is still attempted.

    Returns:
        True if both exports succeeded, False otherwise.
    """
    all_ok = True

    save_path_ts = save_dir / "merged_model.torchscript"
    try:
        LOGGER.info(f"Exporting to TorchScript: {save_path_ts}")
        # A tuple of example inputs is required for a multi-input model.
        traced = torch.jit.trace(merged_model, (im0, im1), strict=False)
        traced.save(str(save_path_ts))
        LOGGER.info(f"TorchScript export success: {save_path_ts} (Size: {_file_size_mb(save_path_ts):.2f}MB)")
    except Exception as e:
        LOGGER.error(f"TorchScript export failed: {e}")
        all_ok = False

    save_path_onnx = save_dir / "merged_model.onnx"
    try:
        LOGGER.info(f"PyTorch version: {torch.__version__}")

        # opset 11: widely supported by TensorRT, OpenVINO, ONNX Runtime
        # opset 13: better operator coverage
        # opset 17+: required for newer PyTorch versions
        LOGGER.info(f"Exporting to ONNX with opset_version={opset}: {save_path_onnx}")
        torch.onnx.export(
            merged_model,
            (im0, im1),  # tuple of example inputs
            str(save_path_onnx),
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=["roi0_input", "roi1_input"],
            output_names=[
                "roi0_output_2d_8x", "roi0_output_2d_16x", "roi0_output_2d_32x",
                "roi0_output_3d_8x", "roi0_output_3d_16x", "roi0_output_3d_32x",
                "roi1_output_2d_8x", "roi1_output_2d_16x", "roi1_output_2d_32x",
                "roi1_output_3d_8x", "roi1_output_3d_16x", "roi1_output_3d_32x",
            ],
            # Static batch size. To enable a dynamic batch, every name in
            # input_names/output_names above needs a {0: "batch"} entry.
            dynamic_axes=None,
        )
        LOGGER.info(f"ONNX export success: {save_path_onnx}")
        LOGGER.info(f"ONNX model size before simplification: {_file_size_mb(save_path_onnx):.2f}MB")

        _log_simplification_note()
    except Exception as e:
        LOGGER.error(f"ONNX export failed: {e}")
        all_ok = False

    return all_ok


def main():
    """Merge two ROI models and export them according to the CLI flags.

    Returns:
        Process exit status: 0 on success (including ``--analyze`` and
        ``--skip-export``), 1 if any export step failed.
    """
    args = _parse_args()

    # Analyze checkpoint(s) and stop if requested
    if args.analyze:
        analyze_checkpoint(args.roi0_model_path)
        if args.roi0_model_path != args.roi1_model_path:
            analyze_checkpoint(args.roi1_model_path)
        return 0

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    merged_model = merge_models(args.roi0_model_path, args.roi1_model_path, args.save_dir, args.lightweight)
    merged_model.eval()

    h, w = args.imgsz
    im0 = torch.randn(1, 3, h, w)  # example input for roi0
    im1 = torch.randn(1, 3, h, w)  # example input for roi1

    # Dry run: fails here, before any export, if the models cannot run.
    merged_model(im0, im1)

    print("Merged model created successfully.")

    # Important note about model size
    LOGGER.info("=" * 80)
    LOGGER.info("MODEL SIZE EXPLANATION:")
    LOGGER.info("The merged model contains TWO complete and independent models.")
    LOGGER.info("Expected export size: ~2x original size (≈100MB for 2x25MB models)")
    LOGGER.info("This is NORMAL because the two models have different trained weights.")
    if args.export_separate:
        LOGGER.info("Using --export-separate to save space by exporting separately.")
    LOGGER.info("=" * 80)

    if args.skip_export:
        LOGGER.info("Skipping export (--skip-export flag set)")
        return 0

    # Option 1: export each model to its own files
    if args.export_separate:
        LOGGER.info("Exporting models SEPARATELY to save space...")
        all_ok = _export_separate(merged_model, im0, im1, save_dir, args.opset)
        LOGGER.info("=" * 80)
        LOGGER.info("SEPARATE EXPORT COMPLETE")
        LOGGER.info("Use these files for inference by loading both models separately")
        LOGGER.info("=" * 80)
        return 0 if all_ok else 1

    # Option 2: export the merged model (default, but larger)
    all_ok = _export_merged(merged_model, im0, im1, save_dir, args.opset)
    # Report failure through the exit status so callers can detect it.
    return 0 if all_ok else 1


if __name__ == "__main__":
    sys.exit(main())
|