232 lines
9.7 KiB
Python
Executable File
232 lines
9.7 KiB
Python
Executable File
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
|
|
from __future__ import annotations
|
|
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from ultralytics.nn.modules import Detect, Pose, Pose26
|
|
from ultralytics.utils import LOGGER
|
|
from ultralytics.utils.downloads import attempt_download_asset
|
|
from ultralytics.utils.files import spaces_in_path
|
|
from ultralytics.utils.tal import make_anchors
|
|
|
|
|
|
def tf_wrapper(model: torch.nn.Module) -> torch.nn.Module:
    """Patch detection heads in place with TensorFlow-export decode methods and return the model.

    Every `Detect` submodule gets `_get_decode_boxes` rebound to `_tf_decode_boxes`. Heads that are also `Pose`
    instances additionally get `kpts_decode` rebound to `_tf_kpts_decode`, with `is_pose26` enabled only when the
    head's exact type is `Pose26`.

    Args:
        model (torch.nn.Module): Model whose submodules are scanned and patched in place.

    Returns:
        (torch.nn.Module): The same model object that was passed in.
    """
    import types

    for head in (module for module in model.modules() if isinstance(module, Detect)):
        head._get_decode_boxes = types.MethodType(_tf_decode_boxes, head)
        if isinstance(head, Pose):
            # Exact type check: only a Pose26 head itself switches the keypoint decode formula
            kpts_decoder = partial(_tf_kpts_decode, is_pose26=type(head) is Pose26)
            head.kpts_decode = types.MethodType(kpts_decoder, head)
    return model
|
|
|
|
|
|
def _tf_decode_boxes(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
    """Decode bounding boxes for TensorFlow export, scaled by a grid-normalised stride factor.

    Args:
        self: Detection head providing `format`, `dynamic`, `shape`, `stride`, `strides`, `anchors`, `dfl` and
            `decode_bboxes`.
        x (dict[str, torch.Tensor]): Head outputs; reads `x["feats"]` (first entry gives the BCHW shape) and
            `x["boxes"]`.

    Returns:
        (torch.Tensor): Result of `self.decode_bboxes` applied to the normalised box distances and anchors.
    """
    feat_shape = x["feats"][0].shape  # BCHW
    raw_boxes = x["boxes"]

    # Anchors are rebuilt unless exporting for "imx"; otherwise only when dynamic or the input shape changed
    if self.format != "imx":
        if self.dynamic or self.shape != feat_shape:
            anchor_points, stride_tensor = make_anchors(x["feats"], self.stride, 0.5)
            self.anchors = anchor_points.transpose(0, 1)
            self.strides = stride_tensor.transpose(0, 1)
            self.shape = feat_shape

    grid_h, grid_w = feat_shape[2:4]
    grid_whwh = torch.tensor([grid_w, grid_h] * 2, device=raw_boxes.device).reshape(1, 4, 1)
    scale = self.strides / (self.stride[0] * grid_whwh)

    distances = self.dfl(raw_boxes) * scale
    anchor_xy = self.anchors.unsqueeze(0) * scale[:, :2]
    return self.decode_bboxes(distances, anchor_xy)
|
|
|
|
|
|
def _tf_kpts_decode(self, kpts: torch.Tensor, is_pose26: bool = False) -> torch.Tensor:
    """Decode keypoints for TensorFlow export, scaling x/y by a grid-normalised stride factor.

    Args:
        self: Pose head providing `kpt_shape`, `shape`, `strides`, `stride`, `anchors` and `nk`.
        kpts (torch.Tensor): Raw keypoint predictions; viewed internally as (batch, *kpt_shape, -1).
        is_pose26 (bool): When True x/y decode as `raw + anchors`; otherwise as `raw * 2 + (anchors - 0.5)`.

    Returns:
        (torch.Tensor): Decoded keypoints viewed as (batch, nk, -1).
    """
    batch = kpts.shape[0]
    raw = kpts.view(batch, *self.kpt_shape, -1)

    # Precompute normalization factor to increase numerical stability
    grid_h, grid_w = self.shape[2:4]
    grid_wh = torch.tensor([grid_w, grid_h], device=raw.device).reshape(1, 2, 1)
    scale = self.strides / (self.stride[0] * grid_wh)

    xy = raw[:, :, :2]
    if is_pose26:
        shifted = xy + self.anchors
    else:
        shifted = xy * 2.0 + (self.anchors - 0.5)
    decoded = shifted * scale

    if self.kpt_shape[1] == 3:
        # A third per-keypoint channel is passed through a sigmoid and appended after x/y
        decoded = torch.cat((decoded, raw[:, :, 2:3].sigmoid()), 2)
    return decoded.view(batch, self.nk, -1)
|
|
|
|
|
|
def onnx2saved_model(
    onnx_file: str,
    output_dir: Path,
    int8: bool = False,
    images: np.ndarray | None = None,
    disable_group_convolution: bool = False,
    prefix="",
):
    """Convert an ONNX model to TensorFlow SavedModel format using onnx2tf.

    Args:
        onnx_file (str): ONNX file path.
        output_dir (Path): Output directory path for the SavedModel.
        int8 (bool, optional): Enable INT8 quantization. Defaults to False.
        images (np.ndarray | None, optional): Calibration images for INT8 quantization in BHWC format. Only used
            when `int8` is True. Defaults to None.
        disable_group_convolution (bool, optional): Disable group convolution optimization. Defaults to False.
        prefix (str, optional): Logging prefix. Defaults to "".

    Returns:
        (keras.Model): Converted Keras model.

    Notes:
        - Requires onnx2tf package. Downloads the onnx2tf sample calibration file into the working directory if
          it is not already present.
        - The temporary INT8 calibration image file is removed even when the conversion raises.
        - With `int8`, `*_dynamic_range_quant.tflite` files are renamed to `*_int8.tflite` and
          `*_integer_quant_with_int16_act.tflite` files are deleted after conversion.
    """
    # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
    onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
    if not onnx2tf_file.exists():
        attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)

    np_data = None
    # int8 calibration images file (only relevant for INT8 export)
    tmp_file = output_dir / "tmp_tflite_int8_calibration_images.npy" if int8 else None
    try:
        if tmp_file is not None and images is not None:
            output_dir.mkdir(parents=True, exist_ok=True)
            np.save(str(tmp_file), images)  # BHWC
            np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]

        # Patch onnx.helper for onnx_graphsurgeon compatibility with ONNX>=1.17
        # The float32_to_bfloat16 function was removed in ONNX 1.17, but onnx_graphsurgeon still uses it
        import onnx.helper

        if not hasattr(onnx.helper, "float32_to_bfloat16"):
            import struct

            def float32_to_bfloat16(fval):
                """Convert float32 to bfloat16 (truncates lower 16 bits of mantissa)."""
                ival = struct.unpack("=I", struct.pack("=f", fval))[0]
                return ival >> 16

            onnx.helper.float32_to_bfloat16 = float32_to_bfloat16

        import onnx2tf  # scoped for after ONNX export for reduced conflict during import

        LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
        keras_model = onnx2tf.convert(
            input_onnx_file_path=onnx_file,
            output_folder_path=str(output_dir),
            not_use_onnxsim=True,
            verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
            output_integer_quantized_tflite=int8,
            custom_input_op_name_np_data_path=np_data,
            enable_batchmatmul_unfold=not int8,  # fix lower no. of detected objects on GPU delegate
            output_signaturedefs=True,  # fix error with Attention block group convolution
            disable_group_convolution=disable_group_convolution,  # fix error with group convolution
        )
    finally:
        # Always drop the temporary calibration file, including when the import or conversion fails
        if tmp_file is not None:
            tmp_file.unlink(missing_ok=True)

    # Remove/rename TFLite models
    if int8:
        for file in output_dir.rglob("*_dynamic_range_quant.tflite"):
            file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
        for file in output_dir.rglob("*_integer_quant_with_int16_act.tflite"):
            file.unlink()  # delete extra fp16 activation TFLite files
    return keras_model
|
|
|
|
|
|
def keras2pb(keras_model, file: Path, prefix=""):
    """Convert a Keras model to TensorFlow GraphDef (.pb) format.

    Args:
        keras_model (keras.Model): Keras model to convert to frozen graph format.
        file (Path): Output file path. The graph is written to `file.parent` under `file.name` exactly as given;
            the suffix is not modified here, so pass a path that already ends in `.pb`.
        prefix (str, optional): Logging prefix. Defaults to "".

    Notes:
        Creates a frozen graph by converting variables to constants for inference optimization. The concrete
        function is traced from the shape and dtype of the model's first input only.
    """
    import tensorflow as tf
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

    LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
    file = Path(file)  # also accept plain string paths
    first_input = keras_model.inputs[0]
    full_model = tf.function(lambda x: keras_model(x))  # full model
    concrete_func = full_model.get_concrete_function(tf.TensorSpec(first_input.shape, first_input.dtype))
    frozen_func = convert_variables_to_constants_v2(concrete_func)
    # write_graph takes the Graph directly, so no separate as_graph_def() call is needed
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(file.parent), name=file.name, as_text=False)
|
|
|
|
|
|
def tflite2edgetpu(tflite_file: str | Path, output_dir: str | Path, prefix: str = ""):
    """Compile a TensorFlow Lite model for Edge TPU by invoking the `edgetpu_compiler` command-line tool.

    Args:
        tflite_file (str | Path): Path to the input TensorFlow Lite (.tflite) model file.
        output_dir (str | Path): Output directory passed to the compiler via `--out_dir`.
        prefix (str, optional): Logging prefix. Defaults to "".

    Notes:
        Requires the Edge TPU compiler to be installed. The command is run through the shell and its exit status
        is not checked by this function.
    """
    import subprocess

    cmd_parts = (
        "edgetpu_compiler",
        f'--out_dir "{output_dir}"',
        "--show_operations",
        "--search_delegate",
        "--delegate_search_step 30",
        "--timeout_sec 180",
        f'"{tflite_file}"',
    )
    cmd = " ".join(cmd_parts)
    LOGGER.info(f"{prefix} running '{cmd}'")
    # NOTE(review): paths are interpolated into a shell string; a path containing a double quote would break it
    subprocess.run(cmd, shell=True)
|
|
|
|
|
|
def pb2tfjs(pb_file: str, output_dir: str, half: bool = False, int8: bool = False, prefix: str = ""):
    """Convert a TensorFlow GraphDef (.pb) model to TensorFlow.js format via `tensorflowjs_converter`.

    Args:
        pb_file (str): Path to the input TensorFlow GraphDef (.pb) model file.
        output_dir (str): Output directory path for the converted TensorFlow.js model.
        half (bool, optional): Enable FP16 quantization (takes precedence over `int8`). Defaults to False.
        int8 (bool, optional): Enable INT8 quantization. Defaults to False.
        prefix (str, optional): Logging prefix. Defaults to "".

    Notes:
        Requires the tensorflowjs package. Output node names are derived from the parsed GraphDef with
        `gd_outputs`. Both paths are passed through `spaces_in_path` for the converter call, and a warning is
        logged if the output directory contains spaces.
    """
    import subprocess

    import tensorflow as tf
    import tensorflowjs as tfjs

    LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")

    graph_def = tf.Graph().as_graph_def()  # TF GraphDef
    with open(pb_file, "rb") as stream:
        graph_def.ParseFromString(stream.read())
    outputs = ",".join(gd_outputs(graph_def))
    LOGGER.info(f"\n{prefix} output node names: {outputs}")

    if half:
        quantization = "--quantize_float16"
    elif int8:
        quantization = "--quantize_uint8"
    else:
        quantization = ""

    with spaces_in_path(pb_file) as fpb_, spaces_in_path(output_dir) as f_:  # exporter cannot handle spaces in paths
        cmd = " ".join(
            (
                "tensorflowjs_converter",
                "--input_format=tf_frozen_model",
                quantization,  # an empty flag leaves a double space, which the shell ignores
                f"--output_node_names={outputs}",
                f'"{fpb_}"',
                f'"{f_}"',
            )
        )
        LOGGER.info(f"{prefix} running '{cmd}'")
        subprocess.run(cmd, shell=True)

    if " " in output_dir:
        LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{output_dir}'.")
|
|
|
|
|
|
def gd_outputs(gd):
    """Return TensorFlow GraphDef model output node names.

    A node is treated as an output when no other node lists it as an input. Input references are normalised
    before comparison, because a GraphDef input may be written as `name`, `name:<slot>` (a specific output slot)
    or `^name` (a control dependency); all three forms consume the node `name`.

    Args:
        gd: GraphDef-like object whose `node` attribute iterates over entries with `name` and `input`.

    Returns:
        (list[str]): Sorted `"<name>:0"` strings for every unconsumed node whose name does not start with "NoOp".
    """
    node_names = set()
    consumed = set()
    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
        node_names.add(node.name)
        # Strip the control-dependency marker and the output-slot suffix to recover the producer's name
        consumed.update(ref.lstrip("^").split(":", 1)[0] for ref in node.input)
    return sorted(f"{name}:0" for name in node_names - consumed if not name.startswith("NoOp"))
|