Files
yolov26_3d/tools/model_merging/merge_models_of_2roi.py
2026-06-24 09:35:46 +08:00

330 lines
14 KiB
Python
Executable File

import argparse
import json
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
# Absolute, symlink-resolved path of this script.
FILE = Path(__file__).resolve()
# Parent of the directory that contains this script.
# NOTE(review): presumably this is the directory holding `models/`, `utils/` and
# `export.py`, which the imports that follow rely on - confirm against the repo layout.
ROOT = FILE.parents[1]
# Make ROOT importable (added once) so the project-local imports that follow can resolve.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from models.common import DetectMultiBackend
from export import attempt_load
from utils.general import LOGGER
def analyze_checkpoint(weights_path):
    """Log a size breakdown of a checkpoint file.

    The checkpoint is loaded on CPU. For a dict checkpoint the top-level keys
    are logged and the ``model`` / ``ema`` entries are inspected; a checkpoint
    that is itself an object exposing ``parameters()`` is inspected directly.
    For every inspected object the parameter count and its theoretical FP32
    size are logged, followed by the summed totals, the size of the file on
    disk and the ratio between the theoretical and the on-disk size.

    Args:
        weights_path: Path of the checkpoint file to analyze.

    Note:
        ``torch.load`` unpickles arbitrary Python objects, so only analyze
        checkpoints that come from a trusted source.
    """
    bytes_per_mb = 1024 ** 2
    bytes_per_param = 4  # FP32
    separator = '=' * 80

    ckpt = torch.load(weights_path, map_location='cpu')

    LOGGER.info(f"\n{separator}")
    LOGGER.info(f"CHECKPOINT ANALYSIS: {Path(weights_path).name}")
    LOGGER.info(separator)

    if isinstance(ckpt, dict):
        LOGGER.info(f"Checkpoint keys: {list(ckpt.keys())}")
        components = [(key, value) for key, value in ckpt.items() if key in ('model', 'ema')]
    else:
        # A checkpoint saved as a bare object has no keys; inspect the object itself.
        components = [(type(ckpt).__name__, ckpt)]

    total_params = 0
    for key, model_obj in components:
        # Entries without parameters() (e.g. None or a plain state_dict) are skipped.
        if not hasattr(model_obj, 'parameters'):
            continue
        params = sum(p.numel() for p in model_obj.parameters())
        param_size_mb = params * bytes_per_param / bytes_per_mb
        total_params += params
        LOGGER.info(f"  - {key}: {params:,} parameters ({param_size_mb:.2f}MB if FP32)")

    theoretical_mb = total_params * bytes_per_param / bytes_per_mb
    LOGGER.info(f"Total parameters: {total_params:,}")
    LOGGER.info(f"Theoretical FP32 size: {theoretical_mb:.2f}MB")

    # Compare against what is actually stored on disk.
    actual_size = Path(weights_path).stat().st_size / bytes_per_mb
    compression_ratio = theoretical_mb / actual_size if actual_size > 0 else 0
    LOGGER.info(f"Actual file size: {actual_size:.2f}MB")
    LOGGER.info(f"Compression ratio: {compression_ratio:.2f}x")
    LOGGER.info(f"{separator}\n")
def load_model_lightweight(weights_path, device='cpu'):
    """Load just the model object from a checkpoint, ready for inference.

    The checkpoint is read with ``torch.load``. The ``model`` entry is
    preferred; the ``ema`` entry is used (with a warning) only when no usable
    ``model`` entry exists. A checkpoint that is itself an ``nn.Module`` is
    used directly. The selected model is converted to float32, switched to
    eval mode, and its ``hyp`` / ``gr`` attributes are removed when present.

    Args:
        weights_path: Path of the checkpoint file.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The loaded model in float32 and eval mode.

    Raises:
        ValueError: If the checkpoint holds neither a usable ``model`` nor a
            usable ``ema`` entry.

    Note:
        ``torch.load`` unpickles arbitrary Python objects, so only load
        checkpoints that come from a trusted source.
    """
    LOGGER.info(f"Loading model with lightweight method: {weights_path}")
    ckpt = torch.load(weights_path, map_location=device)

    entries = ckpt if isinstance(ckpt, dict) else {}
    if isinstance(ckpt, nn.Module):
        # The file stores the module itself rather than a dict of entries.
        model = ckpt
    elif entries.get('model') is not None:
        model = entries['model']
    elif entries.get('ema') is not None:
        # Reached when 'model' is missing or stored as None.
        LOGGER.warning("Only EMA weights available, using them (may increase size)")
        model = entries['ema']
    else:
        raise ValueError(f"No model found in checkpoint: {weights_path}")

    # Convert to float32 and eval mode
    model = model.float().eval()

    # Drop training-only attributes that are not needed for export.
    for attr in ('hyp', 'gr'):
        if hasattr(model, attr):
            delattr(model, attr)
    return model
class Model_Merged(nn.Module):
    """Run two independent ROI models as a single two-input module.

    Each sub-model receives its own input; the merged output is the outputs of
    the ROI0 model followed by the outputs of the ROI1 model.

    Args:
        model_roi0: Model for the first ROI. Must expose ``model`` (indexable,
            last element is the head), ``stride`` and ``names``.
        model_roi1: Model for the second ROI, with the same attributes.
    """

    def __init__(self, model_roi0, model_roi1):
        super().__init__()
        self.model_roi0 = model_roi0
        self.model_roi1 = model_roi1
        for sub_model in (self.model_roi0, self.model_roi1):
            # Set the `export_raw` flag on the last layer (the detection head).
            # NOTE(review): presumably this makes the head emit raw feature maps - confirm in the head's code.
            sub_model.model[-1].export_raw = True
            # Make sure both models are in eval mode.
            sub_model.eval()
        self.stride = model_roi0.stride  # assuming both models have the same stride
        self.names = model_roi0.names  # assuming both models have the same class names

    @staticmethod
    def _as_outputs(out):
        """Wrap a lone tensor in a list so outputs are always concatenated, never summed."""
        return [out] if isinstance(out, torch.Tensor) else out

    def forward(self, x_roi0, x_roi1):
        """Run both sub-models without gradients and concatenate their outputs.

        Args:
            x_roi0: Input for the ROI0 model.
            x_roi1: Input for the ROI1 model.

        Returns:
            ROI0 outputs followed by ROI1 outputs: a list when both sub-models
            return lists (or single tensors), otherwise a tuple.
        """
        with torch.no_grad():
            out_roi0 = self._as_outputs(self.model_roi0(x_roi0))
            out_roi1 = self._as_outputs(self.model_roi1(x_roi1))
        if isinstance(out_roi0, list) and isinstance(out_roi1, list):
            return out_roi0 + out_roi1
        # Mixed list/tuple outputs cannot be added directly; normalise to tuples.
        return tuple(out_roi0) + tuple(out_roi1)
def merge_models(roi0_model_path, roi1_model_path, save_path, use_lightweight=False):
    """Load the two ROI checkpoints and wrap them in a ``Model_Merged``.

    The on-disk sizes of both checkpoints are logged first. The models are
    then loaded on CPU, either through ``load_model_lightweight`` or through
    ``attempt_load`` with ``fuse=False``.

    Args:
        roi0_model_path: Checkpoint path of the ROI0 model.
        roi1_model_path: Checkpoint path of the ROI1 model.
        save_path: Accepted for interface compatibility; not used here.
        use_lightweight: If True, use custom lightweight loading (experimental).

    Returns:
        A ``Model_Merged`` holding both loaded models.
    """
    bytes_per_mb = 1024 * 1024

    # Report original model sizes
    size_roi0 = Path(roi0_model_path).stat().st_size / bytes_per_mb
    size_roi1 = Path(roi1_model_path).stat().st_size / bytes_per_mb
    LOGGER.info(f"Original model sizes: ROI0={size_roi0:.2f}MB, ROI1={size_roi1:.2f}MB, Total={size_roi0+size_roi1:.2f}MB")

    if not use_lightweight:
        loaded = []
        for tag, path in (("ROI0", roi0_model_path), ("ROI1", roi1_model_path)):
            LOGGER.info(f"Loading {tag} model: {path}")
            # fuse=False keeps the layers as stored in the checkpoint
            loaded.append(attempt_load(path, device='cpu', fuse=False))
        model_roi0, model_roi1 = loaded
    else:
        LOGGER.info("Using lightweight model loading (experimental)")
        model_roi0 = load_model_lightweight(roi0_model_path, device='cpu')
        model_roi1 = load_model_lightweight(roi1_model_path, device='cpu')

    return Model_Merged(model_roi0, model_roi1)
# Output names of the merged ONNX graph: all ROI0 outputs first, then all ROI1
# outputs; within each ROI the 2D outputs precede the 3D ones (8x, 16x, 32x).
MERGED_OUTPUT_NAMES = [
    "roi0_output_2d_8x", "roi0_output_2d_16x", "roi0_output_2d_32x",
    "roi0_output_3d_8x", "roi0_output_3d_16x", "roi0_output_3d_32x",
    "roi1_output_2d_8x", "roi1_output_2d_16x", "roi1_output_2d_32x",
    "roi1_output_3d_8x", "roi1_output_3d_16x", "roi1_output_3d_32x",
]


def build_parser():
    """Return the command-line parser of the merge/export script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--roi0_model_path', type=str, default='release/yolov5s-30w/roi0_last.pt', help='Path to the ROI0 model file')
    parser.add_argument('--roi1_model_path', type=str, default='release/yolov5s-30w/roi1_last.pt', help='Path to the ROI1 model file')
    parser.add_argument('--save_dir', type=str, default='release/yolov5s-30w', help='Path to save the merged model file')
    parser.add_argument('--imgsz', nargs=2, type=int, default=[352, 704], metavar=('H', 'W'), help='Input image size (height width), e.g. --imgsz 352 704')
    parser.add_argument('--opset', type=int, default=11, help='ONNX opset version (11, 12, 13, 17, etc.)')
    parser.add_argument('--skip-export', action='store_true', help='Skip exporting to TorchScript/ONNX')
    parser.add_argument('--lightweight', action='store_true', help='Use lightweight model loading (may reduce export size)')
    parser.add_argument('--export-separate', action='store_true', help='Export two models separately instead of merged (saves space)')
    parser.add_argument('--analyze', action='store_true', help='Analyze checkpoint structure and exit')
    return parser


def _file_size_mb(path):
    """Return the size of the file at *path* in MB."""
    return Path(path).stat().st_size / (1024 * 1024)


def _export_single(model, example_input, save_dir, tag, opset):
    """Export one ROI model to TorchScript and ONNX; return True on success.

    A failure of either step is logged and skips the remaining steps for this
    model, but never raises, so the other model can still be exported.
    """
    ts_path = Path(save_dir) / f"{tag.lower()}_model.torchscript"
    onnx_path = Path(save_dir) / f"{tag.lower()}_model.onnx"
    try:
        LOGGER.info(f"Exporting {tag} to TorchScript: {ts_path}")
        traced = torch.jit.trace(model, example_input, strict=False)
        traced.save(str(ts_path))
        LOGGER.info(f"{tag} TorchScript export success: {_file_size_mb(ts_path):.2f}MB")

        LOGGER.info(f"Exporting {tag} to ONNX: {onnx_path}")
        torch.onnx.export(
            model, example_input, str(onnx_path),
            verbose=False, opset_version=opset, do_constant_folding=True,
            input_names=["input"], output_names=["output"],
        )
        LOGGER.info(f"{tag} ONNX export success: {_file_size_mb(onnx_path):.2f}MB")
    except Exception as e:  # export backends raise many unrelated exception types
        LOGGER.error(f"{tag} export failed: {e}")
        return False
    return True


def _export_merged_torchscript(merged_model, example_inputs, save_dir):
    """Trace the merged model and save it as TorchScript; return True on success."""
    save_path_ts = Path(save_dir) / "merged_model.torchscript"
    try:
        LOGGER.info(f"Exporting to TorchScript: {save_path_ts}")
        # A tuple of example inputs is required for a multi-input model.
        traced = torch.jit.trace(merged_model, example_inputs, strict=False)
        traced.save(str(save_path_ts))
        LOGGER.info(f"TorchScript export success: {save_path_ts} (Size: {_file_size_mb(save_path_ts):.2f}MB)")
    except Exception as e:
        LOGGER.error(f"TorchScript export failed: {e}")
        return False
    return True


def _export_merged_onnx(merged_model, example_inputs, save_dir, opset):
    """Export the merged model to ONNX with a static batch size; return True on success."""
    save_path_onnx = Path(save_dir) / "merged_model.onnx"
    try:
        LOGGER.info(f"PyTorch version: {torch.__version__}")
        LOGGER.info(f"Exporting to ONNX with opset_version={opset}: {save_path_onnx}")
        torch.onnx.export(
            merged_model,
            example_inputs,  # tuple of example inputs
            str(save_path_onnx),
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=["roi0_input", "roi1_input"],
            output_names=list(MERGED_OUTPUT_NAMES),
            # Static shapes. For a dynamic batch size, map every input name and
            # every name in MERGED_OUTPUT_NAMES to {0: "batch"}.
            dynamic_axes=None,
        )
        LOGGER.info(f"ONNX export success: {save_path_onnx}")
        LOGGER.info(f"ONNX model size before simplification: {_file_size_mb(save_path_onnx):.2f}MB")
        # Simplification is deliberately not run for the multi-input model.
        LOGGER.warning("Skipping ONNX simplification for multi-input model (often increases size)")
        LOGGER.info("If you still want to try simplification, use: onnxslim input.onnx output.onnx")
    except Exception as e:
        LOGGER.error(f"ONNX export failed: {e}")
        return False
    return True


def main(argv=None):
    """Merge the two ROI models and export them.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success (or when only analysis / no export
        was requested), 1 when at least one export failed.
    """
    args = build_parser().parse_args(argv)

    # Analyze checkpoint if requested
    if args.analyze:
        analyze_checkpoint(args.roi0_model_path)
        if args.roi0_model_path != args.roi1_model_path:
            analyze_checkpoint(args.roi1_model_path)
        return 0

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    merged_model = merge_models(args.roi0_model_path, args.roi1_model_path, args.save_dir, args.lightweight)
    merged_model.eval()

    h, w = args.imgsz
    im0 = torch.randn(1, 3, h, w)  # example input for roi0
    im1 = torch.randn(1, 3, h, w)  # example input for roi1
    merged_model(im0, im1)  # one forward pass as a sanity check before exporting
    print("Merged model created successfully.")

    # Important note about model size
    LOGGER.info("=" * 80)
    LOGGER.info("MODEL SIZE EXPLANATION:")
    LOGGER.info("The merged model contains TWO complete and independent models.")
    LOGGER.info("Expected export size: ~2x original size (≈100MB for 2x25MB models)")
    LOGGER.info("This is NORMAL because the two models have different trained weights.")
    if args.export_separate:
        LOGGER.info("Using --export-separate to save space by exporting separately.")
    LOGGER.info("=" * 80)

    if args.skip_export:
        LOGGER.info("Skipping export (--skip-export flag set)")
        return 0

    # Option 1: export the two models separately
    if args.export_separate:
        LOGGER.info("Exporting models SEPARATELY to save space...")
        results = [
            _export_single(merged_model.model_roi0, im0, save_dir, "ROI0", args.opset),
            _export_single(merged_model.model_roi1, im1, save_dir, "ROI1", args.opset),
        ]
        LOGGER.info("=" * 80)
        LOGGER.info("SEPARATE EXPORT COMPLETE")
        LOGGER.info("Use these files for inference by loading both models separately")
        LOGGER.info("=" * 80)
        return 0 if all(results) else 1

    # Option 2: export the merged model (default, but larger)
    ts_ok = _export_merged_torchscript(merged_model, (im0, im1), save_dir)
    onnx_ok = _export_merged_onnx(merged_model, (im0, im1), save_dir, args.opset)
    return 0 if ts_ok and onnx_ok else 1


if __name__ == "__main__":
    sys.exit(main())