feat: initial HSAP platform
Huaxu Sentinel Active Safety Platform with embedded algorithm code, Docker Compose setup, and vendored dataset scaffolds for clone-and-run.

Co-authored-by: Cursor <cursoragent@cursor.com>
@@ -0,0 +1,24 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

[package]
name = "yolov8-rs"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
clap = { version = "4.2.4", features = ["derive"] }
image = { version = "0.25.2" }
imageproc = { version = "0.25.0" }
ndarray = { version = "0.16" }
ort = { version = "2.0.0-rc.5", features = ["cuda", "tensorrt", "download-binaries", "copy-dylibs", "half"] }
rusttype = { version = "0.9.3" }
anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" }
rand = { version = "0.8.5" }
chrono = { version = "0.4.30" }
half = { version = "2.3.1" }
dirs = { version = "5.0.1" }
ureq = { version = "2.9.1" }
ab_glyph = "0.2.29"
@@ -0,0 +1,230 @@
# YOLOv8-ONNXRuntime-Rust for All Key YOLO Tasks

This repository provides a Rust demonstration for performing Ultralytics YOLOv8 tasks like [Classification](https://docs.ultralytics.com/tasks/classify/), [Segmentation](https://docs.ultralytics.com/tasks/segment/), [Detection](https://docs.ultralytics.com/tasks/detect/), [Pose Estimation](https://docs.ultralytics.com/tasks/pose/), and [Oriented Bounding Box (OBB)](https://docs.ultralytics.com/tasks/obb/) detection using the [ONNXRuntime](https://onnxruntime.ai/).

## ✨ Recently Updated

- Added YOLOv8-OBB demo.
- Updated ONNXRuntime dependency to 1.19.x.

Newly updated YOLOv8 example code is located in [this repository](https://github.com/jamjamjon/usls/tree/main/examples/yolo).

## 🚀 Features

- Supports `Classification`, `Segmentation`, `Detection`, `Pose(Keypoints)-Detection`, and `OBB` tasks.
- Supports `FP16` & `FP32` [ONNX](https://onnx.ai/) models.
- Supports `CPU`, `CUDA`, and `TensorRT` execution providers to accelerate computation.
- Supports dynamic input shapes (`batch`, `width`, `height`).

## 🛠️ Installation

### 1. Install Rust

Please follow the official Rust installation guide: [https://www.rust-lang.org/tools/install](https://rust-lang.org/tools/install/).

### 2. ONNXRuntime Linking

- #### For detailed setup instructions, refer to the [ORT documentation](https://ort.pyke.io/setup/linking).

- #### For Linux or macOS Users:
  - Download the ONNX Runtime package from the [Releases page](https://github.com/microsoft/onnxruntime/releases).
  - Set up the library path by exporting the `ORT_DYLIB_PATH` environment variable:
    ```bash
    export ORT_DYLIB_PATH=/path/to/onnxruntime/lib/libonnxruntime.so.1.19.0 # Adjust version/path as needed
    ```

### 3. [Optional] Install CUDA & CuDNN & TensorRT

- The CUDA execution provider requires [CUDA](https://developer.nvidia.com/cuda-toolkit) v11.6+.
- The TensorRT execution provider requires CUDA v11.4+ and [TensorRT](https://developer.nvidia.com/tensorrt) v8.4+. You may also need [cuDNN](https://developer.nvidia.com/cudnn).

## ▶️ Get Started

### 1. Export the Ultralytics YOLOv8 ONNX Models

First, install the Ultralytics package:

```bash
pip install -U ultralytics
```

Then, export the desired [Ultralytics YOLOv8](https://docs.ultralytics.com/models/yolov8/) models to the ONNX format. See the [Export documentation](https://docs.ultralytics.com/modes/export/) for more details.

```bash
# Export ONNX model with dynamic shapes (recommended for flexibility)
yolo export model=yolov8m.pt format=onnx simplify dynamic
yolo export model=yolov8m-cls.pt format=onnx simplify dynamic
yolo export model=yolov8m-pose.pt format=onnx simplify dynamic
yolo export model=yolov8m-seg.pt format=onnx simplify dynamic
# yolo export model=yolov8m-obb.pt format=onnx simplify dynamic # Add OBB export if needed

# Export ONNX model with constant shapes (if dynamic shapes are not required)
# yolo export model=yolov8m.pt format=onnx simplify
# yolo export model=yolov8m-cls.pt format=onnx simplify
# yolo export model=yolov8m-pose.pt format=onnx simplify
# yolo export model=yolov8m-seg.pt format=onnx simplify
# yolo export model=yolov8m-obb.pt format=onnx simplify
```

### 2. Run Inference

The following command runs inference with the specified ONNX model on the source image using the CPU.

```bash
cargo run --release -- --model MODEL_PATH.onnx --source SOURCE_IMAGE.jpg
```

#### Using GPU Acceleration

Set `--cuda` to use the CUDA execution provider for faster inference on NVIDIA GPUs.

```bash
cargo run --release -- --cuda --model MODEL_PATH.onnx --source SOURCE_IMAGE.jpg
```

Set `--trt` to use the TensorRT execution provider. You can also set `--fp16` simultaneously to leverage the TensorRT FP16 engine for potentially even greater speed, especially on compatible hardware.

```bash
cargo run --release -- --trt --fp16 --model MODEL_PATH.onnx --source SOURCE_IMAGE.jpg
```

#### Specifying Device and Batch Size

Set `--device-id` to select a specific GPU device (the `device_id` field in `cli.rs` is exposed by `clap` in kebab-case, like `--batch-min` and `--batch-max`). If the specified device ID is invalid (e.g., setting `--device-id 1` when only one GPU exists), `ort` will automatically fall back to the `CPU` execution provider without causing a panic.

```bash
cargo run --release -- --cuda --device-id 0 --model MODEL_PATH.onnx --source SOURCE_IMAGE.jpg
```

Set `--batch` to perform inference with a specific batch size.

```bash
cargo run --release -- --cuda --batch 2 --model MODEL_PATH.onnx --source SOURCE_IMAGE.jpg
```

If you're using `--trt` with a model exported with dynamic batch dimensions, you can explicitly specify the minimum, optimal, and maximum batch sizes for TensorRT optimization using `--batch-min`, `--batch`, and `--batch-max`. Refer to the [TensorRT Execution Provider documentation](https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#explicit-shape-range-for-dynamic-shape-input) for details.
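
The relationship between the three batch flags can be sketched as a small validation step. This is only an illustration, not code from this example: `BatchRange` and `profile_shapes` are hypothetical names, the input tensor name `images` is an assumption about the exported model, and the `name:NxCxHxW` string layout follows the explicit shape range format described in the linked TensorRT documentation.

```rust
/// Hypothetical helper (not part of this example's source): a TensorRT
/// batch range that enforces 1 <= min <= opt <= max.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BatchRange {
    pub min: u32,
    pub opt: u32,
    pub max: u32,
}

impl BatchRange {
    /// Validates the values passed as `--batch-min`, `--batch`, `--batch-max`.
    pub fn new(min: u32, opt: u32, max: u32) -> Result<Self, String> {
        if min == 0 {
            return Err("batch-min must be at least 1".to_string());
        }
        if min > opt || opt > max {
            return Err(format!(
                "expected min <= opt <= max, got min={min}, opt={opt}, max={max}"
            ));
        }
        Ok(Self { min, opt, max })
    }

    /// Builds `name:NxCxHxW` strings for the min, opt, and max shapes.
    /// The channel count of 3 assumes an RGB input.
    pub fn profile_shapes(&self, input: &str, height: u32, width: u32) -> [String; 3] {
        [self.min, self.opt, self.max].map(|b| format!("{input}:{b}x3x{height}x{width}"))
    }
}
```

With the defaults declared in `cli.rs` (`batch = 1`, `batch_min = 1`, `batch_max = 32`), `BatchRange::new(1, 1, 32)` succeeds, while a range such as `(4, 2, 32)` is rejected before any engine is built.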

#### Dynamic Image Size

Set `--height` and `--width` to perform inference with dynamic image sizes. **Note:** The ONNX model must have been exported with dynamic input shapes (`dynamic=True`).

```bash
cargo run --release -- --cuda --width 480 --height 640 --model MODEL_PATH_dynamic.onnx --source SOURCE_IMAGE.jpg
```
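
For tasks other than classification, the preprocessing in `model.rs` keeps the aspect ratio when fitting an image into the requested `--width` and `--height`, and the same ratio is used to map boxes back to the original image. The sketch below restates that arithmetic as free functions; `letterbox_size` and `to_original` are hypothetical names chosen for illustration, while the formula mirrors the `scale_wh` method in the source.

```rust
/// Returns (ratio, resized_width, resized_height) for fitting a
/// `w0 x h0` image inside a `w1 x h1` model input without distortion.
/// Mirrors `YOLOv8::scale_wh`; the function name is illustrative.
pub fn letterbox_size(w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
    // The limiting side decides the scale factor.
    let r = (w1 / w0).min(h1 / h0);
    (r, (w0 * r).round(), (h0 * r).round())
}

/// Maps a coordinate predicted in model-input space back to the
/// original image by undoing the resize ratio.
pub fn to_original(coord: f32, ratio: f32) -> f32 {
    coord / ratio
}
```

For a 1920x1080 image and a 640x640 input, the ratio is one third, so the image is resized to 640x360 and the remaining rows keep the gray fill value (`144 / 255`) that `preprocess` writes before copying pixels.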

#### Profiling Performance

Set `--profile` to measure the time consumed in each stage of the inference pipeline (preprocessing, H2D transfer, inference, D2H transfer, postprocessing). **Note:** Models often require a few "warm-up" runs (1-3 iterations) before reaching optimal performance. Ensure you run the command enough times to get a stable performance evaluation.

```bash
cargo run --release -- --trt --fp16 --profile --model MODEL_PATH.onnx --source SOURCE_IMAGE.jpg
```

Example Profile Output (yolov8m.onnx, batch=1, 3 runs, trt, fp16, RTX 3060Ti):

```text
==> 0 # Warm-up run
[Model Preprocess]: 12.75788ms
[ORT H2D]: 237.118µs
[ORT Inference]: 507.895469ms
[ORT D2H]: 191.655µs
[Model Inference]: 508.34589ms
[Model Postprocess]: 1.061122ms
==> 1 # Stable run
[Model Preprocess]: 13.658655ms
[ORT H2D]: 209.975µs
[ORT Inference]: 5.12372ms
[ORT D2H]: 182.389µs
[Model Inference]: 5.530022ms
[Model Postprocess]: 1.04851ms
==> 2 # Stable run
[Model Preprocess]: 12.475332ms
[ORT H2D]: 246.127µs
[ORT Inference]: 5.048432ms
[ORT D2H]: 187.117µs
[Model Inference]: 5.493119ms
[Model Postprocess]: 1.040906ms
```
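
The per-stage timings above come from wrapping each stage with `std::time::Instant`, as `YOLOv8::run` does. A minimal sketch of that pattern, plus one way to discard warm-up runs when averaging, might look like this; `timed` and `mean_after_warmup` are hypothetical helpers and are not defined in this example.

```rust
use std::time::{Duration, Instant};

/// Runs `f`, optionally prints `[label]: elapsed` in the same style as the
/// profile output, and returns the result together with the elapsed time.
pub fn timed<T>(label: &str, profile: bool, f: impl FnOnce() -> T) -> (T, Duration) {
    let start = Instant::now();
    let out = f();
    let elapsed = start.elapsed();
    if profile {
        println!("[{label}]: {elapsed:?}");
    }
    (out, elapsed)
}

/// Averages the samples that remain after skipping `warmup` runs.
/// Returns `None` when no samples are left.
pub fn mean_after_warmup(samples: &[Duration], warmup: usize) -> Option<Duration> {
    let stable = samples.get(warmup..)?;
    if stable.is_empty() {
        return None;
    }
    let total: Duration = stable.iter().sum();
    Some(total / stable.len() as u32)
}
```

Skipping the first sample matters here because the warm-up run in the output above spends roughly 508 ms in `[ORT Inference]`, while the stable runs take about 5 ms.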

#### Other Options

- `--conf`: Confidence threshold for detections \[default: 0.3].
- `--iou`: IoU (Intersection over Union) threshold for Non-Maximum Suppression (NMS) \[default: 0.45].
- `--kconf`: Confidence threshold for keypoints (in Pose Estimation) \[default: 0.55].
- `--plot`: Plot the inference results with random RGB colors and save the output image to the `runs` directory.

You can view all available command-line arguments by running:

```bash
# Clone the repository if you haven't already
# git clone https://github.com/ultralytics/ultralytics
# cd ultralytics/examples/YOLOv8-ONNXRuntime-Rust

cargo run --release -- --help
```
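
To make the roles of `--conf` and `--iou` concrete, here is a simplified sketch of confidence filtering followed by greedy NMS. It uses a hypothetical `Detection` struct in place of the example's `Bbox` type, whose definition is not shown here, so treat the field names and the `filter_and_nms` function as assumptions rather than the actual implementation.

```rust
/// Simplified stand-in for a detection box (illustrative only).
#[derive(Debug, Clone, Copy)]
pub struct Detection {
    pub xmin: f32,
    pub ymin: f32,
    pub width: f32,
    pub height: f32,
    pub confidence: f32,
}

impl Detection {
    /// Intersection over Union of two axis-aligned boxes.
    pub fn iou(&self, other: &Detection) -> f32 {
        let left = self.xmin.max(other.xmin);
        let top = self.ymin.max(other.ymin);
        let right = (self.xmin + self.width).min(other.xmin + other.width);
        let bottom = (self.ymin + self.height).min(other.ymin + other.height);
        let inter = (right - left).max(0.0) * (bottom - top).max(0.0);
        let union = self.width * self.height + other.width * other.height - inter;
        if union > 0.0 {
            inter / union
        } else {
            0.0
        }
    }
}

/// Drops boxes below `conf`, then keeps a box only if its IoU with every
/// higher-confidence box already kept does not exceed `iou`.
pub fn filter_and_nms(mut dets: Vec<Detection>, conf: f32, iou: f32) -> Vec<Detection> {
    dets.retain(|d| d.confidence >= conf);
    dets.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
    let mut kept: Vec<Detection> = Vec::new();
    for d in dets {
        if kept.iter().all(|k| k.iou(&d) <= iou) {
            kept.push(d);
        }
    }
    kept
}
```

Raising `--iou` keeps more overlapping boxes, while raising `--conf` discards more low-scoring boxes before NMS runs.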

## 🖼️ Examples



### Classification

Running a dynamic shape ONNX classification model on the `CPU` with a specific image size (`--height 224 --width 224`). The plotted result image will be saved in the `runs` directory.

```bash
cargo run --release -- --model ../assets/weights/yolov8m-cls-dyn.onnx --source ../assets/images/dog.jpg --height 224 --width 224 --plot --profile
```

Example output:

```text
Summary:
> Task: Classify (Ultralytics 8.0.217) # Version might differ
> EP: Cpu
> Dtype: Float32
> Batch: 1 (Dynamic), Height: 224 (Dynamic), Width: 224 (Dynamic)
> nc: 1000 nk: 0, nm: 0, conf: 0.3, kconf: 0.55, iou: 0.45

[Model Preprocess]: 16.363477ms
[ORT H2D]: 50.722µs
[ORT Inference]: 16.295808ms
[ORT D2H]: 8.37µs
[Model Inference]: 16.367046ms
[Model Postprocess]: 3.527µs
[
    YOLOResult {
        Probs(top5): Some([(208, 0.6950566), (209, 0.13823675), (178, 0.04849795), (215, 0.019029364), (212, 0.016506357)]), # Class IDs and confidences
        Bboxes: None,
        Keypoints: None,
        Masks: None,
    },
]
```
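
The `Probs(top5)` entry lists `(class_id, confidence)` pairs sorted by descending confidence. A minimal sketch of such a selection over a plain probability slice is shown below; the free function `topk` is hypothetical and only illustrates the idea, since the example's own `Embedding::topk` implementation is not shown here.

```rust
/// Returns the `k` highest-scoring `(class_id, probability)` pairs,
/// ordered from most to least likely.
pub fn topk(probs: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    // `total_cmp` gives a total order on f32, so NaN values cannot cause a panic.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    indexed.truncate(k);
    indexed
}
```

Sorting the full vector is simple and adequate for 1000 classes; a partial selection would only matter for much larger outputs.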

### Object Detection

Using the `CUDA` execution provider and a dynamic image size (`--height 640 --width 480`).

```bash
cargo run --release -- --cuda --model ../assets/weights/yolov8m-dynamic.onnx --source ../assets/images/bus.jpg --plot --height 640 --width 480
```

### Pose Detection

Using the `TensorRT` execution provider.

```bash
cargo run --release -- --trt --model ../assets/weights/yolov8m-pose.onnx --source ../assets/images/bus.jpg --plot
```

### Instance Segmentation

Using the `TensorRT` execution provider with an FP16 model (`--fp16`).

```bash
cargo run --release -- --trt --fp16 --model ../assets/weights/yolov8m-seg.onnx --source ../assets/images/0172.jpg --plot
```

## 🤝 Contributing

Contributions are welcome! If you find any issues or have suggestions for improvement, please feel free to open an issue or submit a pull request to the main [Ultralytics repository](https://github.com/ultralytics/ultralytics).
@@ -0,0 +1,89 @@
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

use clap::Parser;

use crate::YOLOTask;

#[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    /// Path to the ONNX model
    #[arg(long, required = true)]
    pub model: String,

    /// Path to the input image
    #[arg(long, required = true)]
    pub source: String,

    /// GPU device id
    #[arg(long, default_value_t = 0)]
    pub device_id: i32,

    /// Use the TensorRT execution provider
    #[arg(long)]
    pub trt: bool,

    /// Use the CUDA execution provider
    #[arg(long)]
    pub cuda: bool,

    /// Input batch size (also the optimal batch size for TensorRT)
    #[arg(long, default_value_t = 1)]
    pub batch: u32,

    /// Minimum batch size for TensorRT
    #[arg(long, default_value_t = 1)]
    pub batch_min: u32,

    /// Maximum batch size for TensorRT
    #[arg(long, default_value_t = 32)]
    pub batch_max: u32,

    /// Use FP16 with the TensorRT execution provider
    #[arg(long)]
    pub fp16: bool,

    /// YOLO task to run
    #[arg(long, value_enum)]
    pub task: Option<YOLOTask>,

    /// Number of classes
    #[arg(long)]
    pub nc: Option<u32>,

    /// Number of keypoints
    #[arg(long)]
    pub nk: Option<u32>,

    /// Number of masks
    #[arg(long)]
    pub nm: Option<u32>,

    /// Input image width
    #[arg(long)]
    pub width: Option<u32>,

    /// Input image height
    #[arg(long)]
    pub height: Option<u32>,

    /// Confidence threshold
    #[arg(long, required = false, default_value_t = 0.3)]
    pub conf: f32,

    /// IoU threshold used in NMS
    #[arg(long, required = false, default_value_t = 0.45)]
    pub iou: f32,

    /// Confidence threshold for keypoints
    #[arg(long, required = false, default_value_t = 0.55)]
    pub kconf: f32,

    /// Plot the inference result and save it
    #[arg(long)]
    pub plot: bool,

    /// Report the time consumed in each stage
    #[arg(long)]
    pub profile: bool,
}
@@ -0,0 +1,153 @@
#![allow(clippy::type_complexity)]
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

use std::io::{Read, Write};

pub mod cli;
pub mod model;
pub mod ort_backend;
pub mod yolo_result;
pub use crate::cli::Args;
pub use crate::model::YOLOv8;
pub use crate::ort_backend::{Batch, OrtBackend, OrtConfig, OrtEP, YOLOTask};
pub use crate::yolo_result::{Bbox, Embedding, Point2, YOLOResult};

/// Greedy non-maximum suppression, performed in place.
///
/// Entries are sorted by descending confidence; an entry is kept only if its
/// IoU with every previously kept entry does not exceed `iou_threshold`.
pub fn non_max_suppression(
    xs: &mut Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)>,
    iou_threshold: f32,
) {
    // `total_cmp` avoids the panic that `partial_cmp(..).unwrap()` hits on NaN.
    xs.sort_by(|b1, b2| b2.0.confidence().total_cmp(&b1.0.confidence()));

    // Kept entries are compacted into xs[..current_index].
    let mut current_index = 0;
    for index in 0..xs.len() {
        let mut drop = false;
        for prev_index in 0..current_index {
            let iou = xs[prev_index].0.iou(&xs[index].0);
            if iou > iou_threshold {
                drop = true;
                break;
            }
        }
        if !drop {
            xs.swap(current_index, index);
            current_index += 1;
        }
    }
    xs.truncate(current_index);
}

/// Formats the current time (UTC+8) as year, month, day, hour, minute,
/// second, and nanoseconds, joined by `delimiter`.
pub fn gen_time_string(delimiter: &str) -> String {
    let offset = chrono::FixedOffset::east_opt(8 * 60 * 60).unwrap(); // Beijing
    let t_now = chrono::Utc::now().with_timezone(&offset);
    let fmt = format!(
        "%Y{}%m{}%d{}%H{}%M{}%S{}%f",
        delimiter, delimiter, delimiter, delimiter, delimiter, delimiter
    );
    t_now.format(&fmt).to_string()
}

/// Keypoint index pairs that are connected when drawing a pose skeleton.
pub const SKELETON: [(usize, usize); 16] = [
    (0, 1),
    (0, 2),
    (1, 3),
    (2, 4),
    (5, 6),
    (5, 11),
    (6, 12),
    (11, 12),
    (5, 7),
    (6, 8),
    (7, 9),
    (8, 10),
    (11, 13),
    (12, 14),
    (13, 15),
    (14, 16),
];

pub fn check_font(font: &str) -> rusttype::Font<'static> {
    // check then load font

    // ultralytics font path
    let font_path_config = match dirs::config_dir() {
        Some(mut d) => {
            d.push("Ultralytics");
            d.push(font);
            d
        }
        None => panic!("Unsupported operating system. Linux, macOS, and Windows are supported."),
    };

    // current font path
    let font_path_current = std::path::PathBuf::from(font);

    // check font
    let font_path = if font_path_config.exists() {
        font_path_config
    } else if font_path_current.exists() {
        font_path_current
    } else {
        println!("Downloading font...");
        let source_url = "https://ultralytics.com/assets/Arial.ttf";
        let resp = ureq::get(source_url)
            .timeout(std::time::Duration::from_secs(500))
            .call()
            .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));

        // read to buffer with size limit (10MB max for font file)
        const MAX_FONT_SIZE: u64 = 10 * 1024 * 1024;
        let mut buffer = vec![];
        resp.into_reader()
            .take(MAX_FONT_SIZE)
            .read_to_end(&mut buffer)
            .unwrap();

        // save
        let file = std::fs::File::create(font).unwrap();
        let mut writer = std::io::BufWriter::new(file);
        writer.write_all(&buffer).unwrap();
        println!("Font saved at: {:?}", font_path_current.display());
        font_path_current
    };

    // load font
    let buffer = std::fs::read(font_path).unwrap();
    rusttype::Font::try_from_vec(buffer).unwrap()
}

use ab_glyph::FontArc;
pub fn load_font() -> FontArc {
    use std::path::Path;
    let font_path = Path::new("./font/Arial.ttf");
    match font_path.try_exists() {
        Ok(true) => {
            let buffer = std::fs::read(font_path).unwrap();
            FontArc::try_from_vec(buffer).unwrap()
        }
        Ok(false) => {
            std::fs::create_dir_all("./font").unwrap();
            println!("Downloading font...");
            let source_url = "https://ultralytics.com/assets/Arial.ttf";
            let resp = ureq::get(source_url)
                .timeout(std::time::Duration::from_secs(500))
                .call()
                .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));

            // read to buffer with size limit (10MB max for font file)
            const MAX_FONT_SIZE: u64 = 10 * 1024 * 1024;
            let mut buffer = vec![];
            resp.into_reader()
                .take(MAX_FONT_SIZE)
                .read_to_end(&mut buffer)
                .unwrap();
            // save
            let mut fd = std::fs::File::create(font_path).unwrap();
            fd.write_all(&buffer).unwrap();
            println!("Font saved at: {:?}", font_path.display());
            FontArc::try_from_vec(buffer).unwrap()
        }
        Err(e) => {
            panic!("Failed to load font {}", e);
        }
    }
}
@@ -0,0 +1,30 @@
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

use clap::Parser;

use yolov8_rs::{Args, YOLOv8};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args = Args::parse();

    // 1. Load the image
    let x = image::ImageReader::open(&args.source)?
        .with_guessed_format()?
        .decode()?;

    // 2. The model supports dynamic batch inference, so the input is a Vec
    let xs = vec![x];

    // You can test `--batch 2` with this
    // let xs = vec![x.clone(), x];

    // 3. Build the YOLOv8 model
    let mut model = YOLOv8::new(args)?;
    model.summary(); // model info

    // 4. Run inference
    let ys = model.run(&xs)?;
    println!("{:?}", ys);

    Ok(())
}
@@ -0,0 +1,652 @@
#![allow(clippy::type_complexity)]
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

use ab_glyph::FontArc;
use anyhow::Result;
use image::{DynamicImage, GenericImageView, ImageBuffer};
use ndarray::{s, Array, Axis, IxDyn};
use rand::{thread_rng, Rng};
use std::path::PathBuf;

use crate::{
    gen_time_string, load_font, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
    OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON,
};

pub struct YOLOv8 {
    // YOLOv8 model for all yolo-tasks
    engine: OrtBackend,
    nc: u32,
    nk: u32,
    nm: u32,
    height: u32,
    width: u32,
    batch: u32,
    task: YOLOTask,
    conf: f32,
    kconf: f32,
    iou: f32,
    names: Vec<String>,
    color_palette: Vec<(u8, u8, u8)>,
    profile: bool,
    plot: bool,
}

impl YOLOv8 {
    pub fn new(config: Args) -> Result<Self> {
        // execution provider
        let ep = if config.trt {
            OrtEP::Trt(config.device_id)
        } else if config.cuda {
            OrtEP::CUDA(config.device_id)
        } else {
            OrtEP::CPU
        };

        // batch
        let batch = Batch {
            opt: config.batch,
            min: config.batch_min,
            max: config.batch_max,
        };

        // build ort engine
        let ort_args = OrtConfig {
            ep,
            batch,
            f: config.model,
            task: config.task,
            trt_fp16: config.fp16,
            image_size: (config.height, config.width),
        };
        let engine = OrtBackend::build(ort_args)?;

        // get batch, height, width, tasks, nc, nk, nm
        let (batch, height, width, task) = (
            engine.batch(),
            engine.height(),
            engine.width(),
            engine.task(),
        );
        let nc = engine.nc().or(config.nc).unwrap_or_else(|| {
            panic!("Failed to get num_classes, make it explicit with `--nc`");
        });
        let (nk, nm) = match task {
            YOLOTask::Pose => {
                let nk = engine.nk().or(config.nk).unwrap_or_else(|| {
                    panic!("Failed to get num_keypoints, make it explicit with `--nk`");
                });
                (nk, 0)
            }
            YOLOTask::Segment => {
                let nm = engine.nm().or(config.nm).unwrap_or_else(|| {
                    panic!("Failed to get num_masks, make it explicit with `--nm`");
                });
                (0, nm)
            }
            _ => (0, 0),
        };

        // class names
        let names = engine.names().unwrap_or(vec!["Unknown".to_string()]);

        // color palette
        let mut rng = thread_rng();
        let color_palette: Vec<_> = names
            .iter()
            .map(|_| {
                (
                    rng.gen_range(0..=255),
                    rng.gen_range(0..=255),
                    rng.gen_range(0..=255),
                )
            })
            .collect();

        Ok(Self {
            engine,
            names,
            conf: config.conf,
            kconf: config.kconf,
            iou: config.iou,
            color_palette,
            profile: config.profile,
            plot: config.plot,
            nc,
            nk,
            nm,
            height,
            width,
            batch,
            task,
        })
    }

    /// Returns `(ratio, new_width, new_height)` for fitting a `w0 x h0` image
    /// inside `w1 x h1` while preserving the aspect ratio.
    pub fn scale_wh(&self, w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
        let r = (w1 / w0).min(h1 / h0);
        (r, (w0 * r).round(), (h0 * r).round())
    }

    pub fn preprocess(&mut self, xs: &Vec<DynamicImage>) -> Result<Array<f32, IxDyn>> {
        // NCHW tensor pre-filled with the gray padding value
        let mut ys =
            Array::ones((xs.len(), 3, self.height() as usize, self.width() as usize)).into_dyn();
        ys.fill(144.0 / 255.0);
        for (idx, x) in xs.iter().enumerate() {
            let img = match self.task() {
                YOLOTask::Classify => x.resize_exact(
                    self.width(),
                    self.height(),
                    image::imageops::FilterType::Triangle,
                ),
                _ => {
                    let (w0, h0) = x.dimensions();
                    let w0 = w0 as f32;
                    let h0 = h0 as f32;
                    let (_, w_new, h_new) =
                        self.scale_wh(w0, h0, self.width() as f32, self.height() as f32); // f32 round
                    x.resize_exact(
                        w_new as u32,
                        h_new as u32,
                        if let YOLOTask::Segment = self.task() {
                            image::imageops::FilterType::CatmullRom
                        } else {
                            image::imageops::FilterType::Triangle
                        },
                    )
                }
            };

            // copy pixels into the top-left corner, normalized to [0, 1]
            for (x, y, rgb) in img.pixels() {
                let x = x as usize;
                let y = y as usize;
                let [r, g, b, _] = rgb.0;
                ys[[idx, 0, y, x]] = (r as f32) / 255.0;
                ys[[idx, 1, y, x]] = (g as f32) / 255.0;
                ys[[idx, 2, y, x]] = (b as f32) / 255.0;
            }
        }

        Ok(ys)
    }

    pub fn run(&mut self, xs: &Vec<DynamicImage>) -> Result<Vec<YOLOResult>> {
        // pre-process
        let t_pre = std::time::Instant::now();
        let xs_ = self.preprocess(xs)?;
        if self.profile {
            println!("[Model Preprocess]: {:?}", t_pre.elapsed());
        }

        // run
        let t_run = std::time::Instant::now();
        let ys = self.engine.run(xs_, self.profile)?;
        if self.profile {
            println!("[Model Inference]: {:?}", t_run.elapsed());
        }

        // post-process
        let t_post = std::time::Instant::now();
        let ys = self.postprocess(ys, xs)?;
        if self.profile {
            println!("[Model Postprocess]: {:?}", t_post.elapsed());
        }

        // plot and save
        if self.plot {
            self.plot_and_save(&ys, xs, Some(&SKELETON));
        }
        Ok(ys)
    }

    pub fn postprocess(
        &self,
        xs: Vec<Array<f32, IxDyn>>,
        xs0: &[DynamicImage],
    ) -> Result<Vec<YOLOResult>> {
        if let YOLOTask::Classify = self.task() {
            let mut ys = Vec::new();
            let preds = &xs[0];
            for batch in preds.axis_iter(Axis(0)) {
                ys.push(YOLOResult::new(
                    Some(Embedding::new(batch.into_owned())),
                    None,
                    None,
                    None,
                ));
            }
            Ok(ys)
        } else {
            const CXYWH_OFFSET: usize = 4; // cxcywh
            const KPT_STEP: usize = 3; // xyconf
            let preds = &xs[0];
            let protos = {
                if xs.len() > 1 {
                    Some(&xs[1])
                } else {
                    None
                }
            };
            let mut ys = Vec::new();
            for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() {
                // [bs, 4 + nc + nm, anchors]
                // input image
                let width_original = xs0[idx].width() as f32;
                let height_original = xs0[idx].height() as f32;
                let ratio = (self.width() as f32 / width_original)
                    .min(self.height() as f32 / height_original);

                // save each result
                let mut data: Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)> = Vec::new();
                for pred in anchor.axis_iter(Axis(1)) {
                    // split preds for different tasks
                    let bbox = pred.slice(s![0..CXYWH_OFFSET]);
                    let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc() as usize]);
                    let kpts = {
                        if let YOLOTask::Pose = self.task() {
                            Some(pred.slice(s![pred.len() - KPT_STEP * self.nk() as usize..]))
                        } else {
                            None
                        }
                    };
                    let coefs = {
                        if let YOLOTask::Segment = self.task() {
                            Some(pred.slice(s![pred.len() - self.nm() as usize..]).to_vec())
                        } else {
                            None
                        }
                    };

                    // confidence and id
                    let (id, &confidence) = clss
                        .into_iter()
                        .enumerate()
                        .reduce(|max, x| if x.1 > max.1 { x } else { max })
                        .unwrap(); // definitely will not panic!

                    // confidence filter
                    if confidence < self.conf {
                        continue;
                    }

                    // bbox re-scale
                    let cx = bbox[0] / ratio;
                    let cy = bbox[1] / ratio;
                    let w = bbox[2] / ratio;
                    let h = bbox[3] / ratio;
                    let x = cx - w / 2.;
                    let y = cy - h / 2.;
                    let y_bbox = Bbox::new(
                        x.max(0.0f32).min(width_original),
                        y.max(0.0f32).min(height_original),
                        w,
                        h,
                        id,
                        confidence,
                    );

                    // kpts
                    let y_kpts = {
                        if let Some(kpts) = kpts {
                            let mut kpts_ = Vec::new();
                            // rescale
                            for i in 0..self.nk() as usize {
                                let kx = kpts[KPT_STEP * i] / ratio;
                                let ky = kpts[KPT_STEP * i + 1] / ratio;
                                let kconf = kpts[KPT_STEP * i + 2];
                                if kconf < self.kconf {
                                    kpts_.push(Point2::default());
                                } else {
                                    kpts_.push(Point2::new_with_conf(
                                        kx.max(0.0f32).min(width_original),
                                        ky.max(0.0f32).min(height_original),
                                        kconf,
                                    ));
                                }
                            }
                            Some(kpts_)
                        } else {
                            None
                        }
                    };

                    // data merged
                    data.push((y_bbox, y_kpts, coefs));
                }

                // nms
                non_max_suppression(&mut data, self.iou);

                // decode
                let mut y_bboxes: Vec<Bbox> = Vec::new();
                let mut y_kpts: Vec<Vec<Point2>> = Vec::new();
                let mut y_masks: Vec<Vec<u8>> = Vec::new();
                for elem in data.into_iter() {
                    if let Some(kpts) = elem.1 {
                        y_kpts.push(kpts)
                    }

                    // decode masks
                    if let Some(coefs) = elem.2 {
                        let proto = protos.unwrap().slice(s![idx, .., .., ..]);
                        let (nm, nh, nw) = proto.dim();

                        // coefs * proto -> mask
                        let coefs = Array::from_shape_vec((1, nm), coefs)?; // (1, nm)

                        let proto = proto.to_owned();
                        let proto = proto.to_shape((nm, nh * nw))?; // (nm, nh*nw)
                        let mask = coefs.dot(&proto); // (1, nh*nw)
                        let mask = mask.to_shape((nh, nw, 1))?; // (nh, nw, 1)

                        // build image from ndarray
                        let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
                            match ImageBuffer::from_raw(
                                nw as u32,
                                nh as u32,
                                mask.to_owned().into_raw_vec_and_offset().0,
                            ) {
                                Some(image) => image,
                                None => panic!("can not create image from ndarray"),
                            };
                        let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn

                        // rescale masks
                        let (_, w_mask, h_mask) =
                            self.scale_wh(width_original, height_original, nw as f32, nh as f32);
                        let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
                        let mask_original = mask_cropped.resize_exact(
                            // resize_to_fill
                            width_original as u32,
                            height_original as u32,
                            match self.task() {
                                YOLOTask::Segment => image::imageops::FilterType::CatmullRom,
                                _ => image::imageops::FilterType::Triangle,
                            },
                        );

                        // crop-mask with bbox
                        let mut mask_original_cropped = mask_original.into_luma8();
                        for y in 0..height_original as usize {
                            for x in 0..width_original as usize {
                                if x < elem.0.xmin() as usize
                                    || x > elem.0.xmax() as usize
                                    || y < elem.0.ymin() as usize
                                    || y > elem.0.ymax() as usize
                                {
                                    mask_original_cropped.put_pixel(
                                        x as u32,
                                        y as u32,
                                        image::Luma([0u8]),
                                    );
                                }
                            }
                        }
                        y_masks.push(mask_original_cropped.into_raw());
                    }
                    y_bboxes.push(elem.0);
                }

                // save each result
                let y = YOLOResult {
                    probs: None,
                    bboxes: if !y_bboxes.is_empty() {
                        Some(y_bboxes)
                    } else {
                        None
                    },
                    keypoints: if !y_kpts.is_empty() {
                        Some(y_kpts)
                    } else {
                        None
                    },
                    masks: if !y_masks.is_empty() {
                        Some(y_masks)
                    } else {
                        None
                    },
                };
                ys.push(y);
            }

            Ok(ys)
        }
    }

    pub fn plot_and_save(
        &self,
        ys: &[YOLOResult],
        xs0: &[DynamicImage],
        skeletons: Option<&[(usize, usize)]>,
    ) {
        // check font then load
        let font: FontArc = load_font();
        for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() {
            let mut img = img0.to_rgb8();

            // draw for classifier
            if let Some(probs) = y.probs() {
                for (i, k) in probs.topk(5).iter().enumerate() {
                    let legend = format!("{} {:.2}%", self.names[k.0], k.1);
                    let scale = 32;
                    let legend_size = img.width().max(img.height()) / scale;
                    let x = img.width() / 20;
                    let y = img.height() / 20 + i as u32 * legend_size;

                    imageproc::drawing::draw_text_mut(
                        &mut img,
                        image::Rgb([0, 255, 0]),
                        x as i32,
                        y as i32,
                        legend_size as f32,
                        &font,
                        &legend,
                    );
                }
            }

            // draw bboxes & keypoints
            if let Some(bboxes) = y.bboxes() {
                for (_idx, bbox) in bboxes.iter().enumerate() {
                    // rect
                    imageproc::drawing::draw_hollow_rect_mut(
                        &mut img,
                        imageproc::rect::Rect::at(bbox.xmin() as i32, bbox.ymin() as i32)
                            .of_size(bbox.width() as u32, bbox.height() as u32),
                        image::Rgb(self.color_palette[bbox.id()].into()),
                    );

                    // text
                    let legend = format!("{} {:.2}%", self.names[bbox.id()], bbox.confidence());
                    let scale = 40;
                    let legend_size = img.width().max(img.height()) / scale;
                    imageproc::drawing::draw_text_mut(
                        &mut img,
                        image::Rgb(self.color_palette[bbox.id()].into()),
                        bbox.xmin() as i32,
                        (bbox.ymin() - legend_size as f32) as i32,
                        legend_size as f32,
                        &font,
                        &legend,
                    );
                }
            }

            // draw kpts
            if let Some(keypoints) = y.keypoints() {
                for kpts in keypoints.iter() {
                    for kpt in kpts.iter() {
                        // filter
                        if kpt.confidence() < self.kconf {
                            continue;
                        }

                        // draw point
                        imageproc::drawing::draw_filled_circle_mut(
                            &mut img,
                            (kpt.x() as i32, kpt.y() as i32),
                            2,
                            image::Rgb([0, 255, 0]),
                        );
                    }

                    // draw skeleton if has
                    if let Some(skeletons) = skeletons {
                        for &(idx1, idx2) in skeletons.iter() {
                            let kpt1 = &kpts[idx1];
                            let kpt2 = &kpts[idx2];
                            if kpt1.confidence() < self.kconf || kpt2.confidence() < self.kconf {
                                continue;
                            }
                            imageproc::drawing::draw_line_segment_mut(
                                &mut img,
                                (kpt1.x(), kpt1.y()),
                                (kpt2.x(), kpt2.y()),
                                image::Rgb([233, 14, 57]),
                            );
                        }
                    }
                }
            }

            // draw mask
|
||||
if let Some(masks) = y.masks() {
|
||||
for (mask, _bbox) in masks.iter().zip(y.bboxes().unwrap().iter()) {
|
||||
let mask_nd: ImageBuffer<image::Luma<_>, Vec<u8>> =
|
||||
match ImageBuffer::from_vec(img.width(), img.height(), mask.to_vec()) {
|
||||
Some(image) => image,
|
||||
None => panic!("can not create image from ndarray"),
|
||||
};
|
||||
|
||||
for _x in 0..img.width() {
|
||||
for _y in 0..img.height() {
|
||||
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_nd, _x, _y);
|
||||
if mask_p.0[0] > 0 {
|
||||
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, _x, _y);
|
||||
// img_p.0[2] = self.color_palette[bbox.id()].2 / 2;
|
||||
// img_p.0[1] = self.color_palette[bbox.id()].1 / 2;
|
||||
// img_p.0[0] = self.color_palette[bbox.id()].0 / 2;
|
||||
img_p.0[2] /= 2;
|
||||
img_p.0[1] = 255 - (255 - img_p.0[1]) / 2; // brighten green from its own value, not from the already-halved blue channel
|
||||
img_p.0[0] /= 2;
|
||||
imageproc::drawing::Canvas::draw_pixel(&mut img, _x, _y, img_p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mkdir and save
|
||||
let mut runs = PathBuf::from("runs");
|
||||
if !runs.exists() {
|
||||
std::fs::create_dir_all(&runs).unwrap();
|
||||
}
|
||||
runs.push(gen_time_string("-"));
|
||||
let saveout = format!("{}.jpg", runs.to_str().unwrap());
|
||||
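// report a failed save instead of silently discarding the error
if let Err(e) = img.save(&saveout) {
    eprintln!("> Failed to save {}: {}", saveout, e);
}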
|
||||
}
|
||||
}
|
||||
|
||||
pub fn summary(&self) {
|
||||
println!(
|
||||
"\nSummary:\n\
|
||||
> Task: {:?}{}\n\
|
||||
> EP: {:?} {}\n\
|
||||
> Dtype: {:?}\n\
|
||||
> Batch: {} ({}), Height: {} ({}), Width: {} ({})\n\
|
||||
> nc: {} nk: {}, nm: {}, conf: {}, kconf: {}, iou: {}\n\
|
||||
",
|
||||
self.task(),
|
||||
match self.engine.author().zip(self.engine.version()) {
|
||||
Some((author, ver)) => format!(" ({} {})", author, ver),
|
||||
None => String::from(""),
|
||||
},
|
||||
self.engine.ep(),
|
||||
if let OrtEP::CPU = self.engine.ep() {
|
||||
""
|
||||
} else {
|
||||
"(May still fall back to CPU)"
|
||||
},
|
||||
self.engine.dtype(),
|
||||
self.batch(),
|
||||
if self.engine.is_batch_dynamic() {
|
||||
"Dynamic"
|
||||
} else {
|
||||
"Const"
|
||||
},
|
||||
self.height(),
|
||||
if self.engine.is_height_dynamic() {
|
||||
"Dynamic"
|
||||
} else {
|
||||
"Const"
|
||||
},
|
||||
self.width(),
|
||||
if self.engine.is_width_dynamic() {
|
||||
"Dynamic"
|
||||
} else {
|
||||
"Const"
|
||||
},
|
||||
self.nc(),
|
||||
self.nk(),
|
||||
self.nm(),
|
||||
self.conf,
|
||||
self.kconf,
|
||||
self.iou,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn engine(&self) -> &OrtBackend {
|
||||
&self.engine
|
||||
}
|
||||
|
||||
pub fn conf(&self) -> f32 {
|
||||
self.conf
|
||||
}
|
||||
|
||||
pub fn set_conf(&mut self, val: f32) {
|
||||
self.conf = val;
|
||||
}
|
||||
|
||||
pub fn conf_mut(&mut self) -> &mut f32 {
|
||||
&mut self.conf
|
||||
}
|
||||
|
||||
pub fn kconf(&self) -> f32 {
|
||||
self.kconf
|
||||
}
|
||||
|
||||
pub fn iou(&self) -> f32 {
|
||||
self.iou
|
||||
}
|
||||
|
||||
pub fn task(&self) -> &YOLOTask {
|
||||
&self.task
|
||||
}
|
||||
|
||||
pub fn batch(&self) -> u32 {
|
||||
self.batch
|
||||
}
|
||||
|
||||
pub fn width(&self) -> u32 {
|
||||
self.width
|
||||
}
|
||||
|
||||
pub fn height(&self) -> u32 {
|
||||
self.height
|
||||
}
|
||||
|
||||
pub fn nc(&self) -> u32 {
|
||||
self.nc
|
||||
}
|
||||
|
||||
pub fn nk(&self) -> u32 {
|
||||
self.nk
|
||||
}
|
||||
|
||||
pub fn nm(&self) -> u32 {
|
||||
self.nm
|
||||
}
|
||||
|
||||
pub fn names(&self) -> &Vec<String> {
|
||||
&self.names
|
||||
}
|
||||
}
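
The "crop-mask with bbox" step in the post-processing code above zeroes every mask pixel that falls outside the detection's bounding box. A minimal, dependency-free sketch of that step on a flat row-major buffer is shown below. `crop_mask_to_bbox` and its tuple parameter are hypothetical names used only for illustration; the crate itself works on an `image` luma buffer instead.

```rust
/// Zero every pixel of a row-major, single-channel mask that lies outside
/// the inclusive pixel rectangle [xmin, xmax] x [ymin, ymax].
/// Hypothetical helper for illustration, not part of the crate above.
fn crop_mask_to_bbox(
    mask: &mut [u8],
    width: usize,
    height: usize,
    (xmin, ymin, xmax, ymax): (f32, f32, f32, f32),
) {
    assert_eq!(mask.len(), width * height, "mask size must match dimensions");
    // Truncating casts, as in the original loop; negative values saturate to 0.
    let (x0, y0) = (xmin as usize, ymin as usize);
    let (x1, y1) = (xmax as usize, ymax as usize);
    for y in 0..height {
        for x in 0..width {
            if x < x0 || x > x1 || y < y0 || y > y1 {
                mask[y * width + x] = 0;
            }
        }
    }
}

fn main() {
    // 4x3 mask with every pixel set; keep only columns 1..=2 and rows 0..=1.
    let mut mask = vec![255u8; 12];
    crop_mask_to_bbox(&mut mask, 4, 3, (1.0, 0.0, 2.9, 1.2));
    for row in mask.chunks(4) {
        println!("{:?}", row);
    }
}
```

Because the comparison uses truncated coordinates, a box edge at 2.9 keeps column 2 and drops column 3.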
|
||||
@@ -0,0 +1,609 @@
|
||||
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::ValueEnum;
|
||||
use half::f16;
|
||||
use ndarray::{Array, CowArray, IxDyn};
|
||||
use ort::execution_providers::{
|
||||
CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
|
||||
TensorRTExecutionProvider,
|
||||
};
|
||||
use ort::session::builder::SessionBuilder;
|
||||
use ort::session::Session;
|
||||
use ort::tensor::TensorElementType;
|
||||
use ort::value::ValueType;
|
||||
use regex::Regex;
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
|
||||
pub enum YOLOTask {
|
||||
// YOLO tasks
|
||||
Classify,
|
||||
Detect,
|
||||
Pose,
|
||||
Segment,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum OrtEP {
|
||||
// ONNXRuntime execution provider
|
||||
CPU,
|
||||
CUDA(i32),
|
||||
Trt(i32),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Batch {
|
||||
pub opt: u32,
|
||||
pub min: u32,
|
||||
pub max: u32,
|
||||
}
|
||||
|
||||
impl Default for Batch {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
opt: 1,
|
||||
min: 1,
|
||||
max: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct OrtInputs {
|
||||
// ONNX model inputs attrs
|
||||
pub shapes: Vec<Vec<i64>>,
|
||||
//pub dtypes: Vec<TensorElementDataType>,
|
||||
pub dtypes: Vec<TensorElementType>,
|
||||
pub names: Vec<String>,
|
||||
pub sizes: Vec<Vec<u32>>,
|
||||
}
|
||||
|
||||
impl OrtInputs {
|
||||
pub fn new(session: &Session) -> Self {
|
||||
let mut shapes = Vec::new();
|
||||
let mut dtypes = Vec::new();
|
||||
let mut names = Vec::new();
|
||||
for i in session.inputs.iter() {
|
||||
/* let shape: Vec<i32> = i
|
||||
.dimensions()
|
||||
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
|
||||
.collect();
|
||||
shapes.push(shape); */
|
||||
if let ValueType::Tensor {
|
||||
ty,
|
||||
dimension_symbols: _,
|
||||
shape,
|
||||
} = &i.input_type
|
||||
{
|
||||
dtypes.push(ty.clone());
|
||||
let shape = shape.to_vec().clone();
|
||||
shapes.push(shape);
|
||||
} else {
|
||||
panic!("Unsupported data format, {} - {}", file!(), line!());
|
||||
}
|
||||
//dtypes.push(i.input_type);
|
||||
names.push(i.name.clone());
|
||||
}
|
||||
Self {
|
||||
shapes,
|
||||
dtypes,
|
||||
names,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OrtConfig {
|
||||
// ORT config
|
||||
pub f: String,
|
||||
pub task: Option<YOLOTask>,
|
||||
pub ep: OrtEP,
|
||||
pub trt_fp16: bool,
|
||||
pub batch: Batch,
|
||||
pub image_size: (Option<u32>, Option<u32>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OrtBackend {
|
||||
// ORT engine
|
||||
session: Session,
|
||||
task: YOLOTask,
|
||||
ep: OrtEP,
|
||||
batch: Batch,
|
||||
inputs: OrtInputs,
|
||||
}
|
||||
|
||||
impl OrtBackend {
|
||||
pub fn build(args: OrtConfig) -> Result<Self> {
|
||||
// build env & session
|
||||
// in version 2.x environment is removed
|
||||
/* let env = ort::EnvironmentBuilder
|
||||
::with_name("YOLOv8")
|
||||
.build()?
|
||||
.into_arc(); */
|
||||
let sessionbuilder = SessionBuilder::new()?;
|
||||
let session = sessionbuilder.commit_from_file(&args.f)?;
|
||||
//let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
|
||||
|
||||
// get inputs
|
||||
let mut inputs = OrtInputs::new(&session);
|
||||
|
||||
// batch size
|
||||
let mut batch = args.batch;
|
||||
let batch = if inputs.shapes[0][0] == -1 {
|
||||
batch
|
||||
} else {
|
||||
assert_eq!(
|
||||
inputs.shapes[0][0] as u32, batch.opt,
|
||||
"Expected batch size: {}, got {}. Try using `--batch {}`.",
|
||||
inputs.shapes[0][0] as u32, batch.opt, inputs.shapes[0][0] as u32
|
||||
);
|
||||
batch.opt = inputs.shapes[0][0] as u32;
|
||||
batch
|
||||
};
|
||||
|
||||
// input size: height and width
|
||||
let height = if inputs.shapes[0][2] == -1 {
|
||||
match args.image_size.0 {
|
||||
Some(height) => height,
|
||||
None => panic!("Failed to get model height. Make it explicit with `--height`"),
|
||||
}
|
||||
} else {
|
||||
inputs.shapes[0][2] as u32
|
||||
};
|
||||
let width = if inputs.shapes[0][3] == -1 {
|
||||
match args.image_size.1 {
|
||||
Some(width) => width,
|
||||
None => panic!("Failed to get model width. Make it explicit with `--width`"),
|
||||
}
|
||||
} else {
|
||||
inputs.shapes[0][3] as u32
|
||||
};
|
||||
inputs.sizes.push(vec![height, width]);
|
||||
|
||||
// build provider
|
||||
let (ep, provider) = match args.ep {
|
||||
OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id),
|
||||
OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
|
||||
_ => (
|
||||
OrtEP::CPU,
|
||||
ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
|
||||
),
|
||||
};
|
||||
|
||||
// build session again with the new provider
|
||||
let session = SessionBuilder::new()?
|
||||
// .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
|
||||
.with_execution_providers([provider])?
|
||||
.commit_from_file(args.f)?;
|
||||
|
||||
// task: using given one or guessing
|
||||
let task = match args.task {
|
||||
Some(task) => task,
|
||||
None => match session.metadata() {
|
||||
Err(_) => panic!("No metadata found. Try making it explicit by `--task`"),
|
||||
Ok(metadata) => match metadata.custom("task") {
|
||||
Err(_) => panic!("Can not get custom value. Try making it explicit by `--task`"),
|
||||
Ok(value) => match value {
|
||||
None => panic!("No corresponding value of `task` found in metadata. Make it explicit by `--task`"),
|
||||
Some(task) => match task.as_str() {
|
||||
"classify" => YOLOTask::Classify,
|
||||
"detect" => YOLOTask::Detect,
|
||||
"pose" => YOLOTask::Pose,
|
||||
"segment" => YOLOTask::Segment,
|
||||
x => todo!("{:?} is not supported for now!", x),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
session,
|
||||
task,
|
||||
ep,
|
||||
batch,
|
||||
inputs,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fetch_inputs_from_session(
|
||||
session: &Session,
|
||||
) -> (Vec<Vec<i64>>, Vec<TensorElementType>, Vec<String>) {
|
||||
// get inputs attrs from ONNX model
|
||||
let mut shapes = Vec::new();
|
||||
let mut dtypes = Vec::new();
|
||||
let mut names = Vec::new();
|
||||
for i in session.inputs.iter() {
|
||||
if let ValueType::Tensor {
|
||||
ty,
|
||||
dimension_symbols: _,
|
||||
shape,
|
||||
} = &i.input_type
|
||||
{
|
||||
dtypes.push(ty.clone());
|
||||
let shape = shape.to_vec().clone();
|
||||
shapes.push(shape);
|
||||
} else {
|
||||
panic!("Unsupported data format, {} - {}", file!(), line!());
|
||||
}
|
||||
names.push(i.name.clone());
|
||||
}
|
||||
(shapes, dtypes, names)
|
||||
}
|
||||
|
||||
pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) {
|
||||
let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id);
|
||||
if let Ok(true) = cuda_provider.is_available() {
|
||||
(
|
||||
OrtEP::CUDA(device_id),
|
||||
ExecutionProviderDispatch::from(cuda_provider),
|
||||
)
|
||||
} else {
|
||||
println!("> CUDA is not available! Using CPU.");
|
||||
(
|
||||
OrtEP::CPU,
|
||||
ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_ep_trt(
|
||||
device_id: i32,
|
||||
fp16: bool,
|
||||
batch: &Batch,
|
||||
inputs: &OrtInputs,
|
||||
) -> (OrtEP, ExecutionProviderDispatch) {
|
||||
// set TensorRT
|
||||
let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id);
|
||||
|
||||
//trt_provider.
|
||||
if let Ok(true) = trt_provider.is_available() {
|
||||
let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
|
||||
if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 {
|
||||
panic!(
|
||||
"Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`",
|
||||
inputs.dtypes[0]
|
||||
);
|
||||
}
|
||||
// dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
|
||||
let mut opt_string = String::new();
|
||||
let mut min_string = String::new();
|
||||
let mut max_string = String::new();
|
||||
for name in inputs.names.iter() {
|
||||
let s_opt = format!("{}:{}x3x{}x{},", name, batch.opt, height, width);
|
||||
let s_min = format!("{}:{}x3x{}x{},", name, batch.min, height, width);
|
||||
let s_max = format!("{}:{}x3x{}x{},", name, batch.max, height, width);
|
||||
opt_string.push_str(s_opt.as_str());
|
||||
min_string.push_str(s_min.as_str());
|
||||
max_string.push_str(s_max.as_str());
|
||||
}
|
||||
let _ = opt_string.pop();
|
||||
let _ = min_string.pop();
|
||||
let _ = max_string.pop();
|
||||
|
||||
let trt_provider = trt_provider
|
||||
.with_profile_opt_shapes(opt_string)
|
||||
.with_profile_min_shapes(min_string)
|
||||
.with_profile_max_shapes(max_string)
|
||||
.with_fp16(fp16)
|
||||
.with_timing_cache(true);
|
||||
(
|
||||
OrtEP::Trt(device_id),
|
||||
ExecutionProviderDispatch::from(trt_provider),
|
||||
)
|
||||
} else {
|
||||
println!("> TensorRT is not available! Try using CUDA...");
|
||||
Self::set_ep_cuda(device_id)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fetch_from_metadata(&self, key: &str) -> Option<String> {
|
||||
// fetch value from onnx model file by key
|
||||
match self.session.metadata() {
|
||||
Err(_) => None,
|
||||
Ok(metadata) => match metadata.custom(key) {
|
||||
Err(_) => None,
|
||||
Ok(value) => value,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
|
||||
// ORT inference
|
||||
match self.dtype() {
|
||||
TensorElementType::Float16 => self.run_fp16(xs, profile),
|
||||
TensorElementType::Float32 => self.run_fp32(xs, profile),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_fp16(
|
||||
&mut self,
|
||||
xs: Array<f32, IxDyn>,
|
||||
profile: bool,
|
||||
) -> Result<Vec<Array<f32, IxDyn>>> {
|
||||
// f32->f16
|
||||
let t = std::time::Instant::now();
|
||||
let xs = xs.mapv(f16::from_f32);
|
||||
if profile {
|
||||
println!("[ORT f32->f16]: {:?}", t.elapsed());
|
||||
}
|
||||
|
||||
// h2d
|
||||
let t = std::time::Instant::now();
|
||||
let xs = CowArray::from(xs);
|
||||
if profile {
|
||||
println!("[ORT H2D]: {:?}", t.elapsed());
|
||||
}
|
||||
|
||||
// prepare input Value from the ndarray (needed because SessionInputValue implements From<Value<_>>)
|
||||
let t = std::time::Instant::now();
|
||||
let input = ort::value::Value::from_array(xs.into_owned())?;
|
||||
if profile {
|
||||
println!("[ORT Prepare Value]: {:?}", t.elapsed());
|
||||
}
|
||||
|
||||
// run
let t = std::time::Instant::now();
let ys = self.session.run(ort::inputs![input])?;
if profile {
    println!("[ORT Inference]: {:?}", t.elapsed());
}

// d2h
Ok(ys
    .iter()
    .map(|(_k, v)| {
        // d2h: try_extract_tensor returns (shape, slice) for this output
        let t = std::time::Instant::now();
        let (shape, slice) = v.try_extract_tensor::<f16>().unwrap();
        if profile {
            println!("[ORT D2H]: {:?}", t.elapsed());
        }

        // f16->f32
        let t_ = std::time::Instant::now();
        // Build the ndarray from the shape reported at run time. The static
        // model shape can contain -1 for dynamic dimensions, which would not
        // match the length of the returned slice.
        let dims = shape.iter().map(|&d| d as usize).collect::<Vec<_>>();
        let arr_f16 = Array::from_shape_vec(IxDyn(&dims), slice.to_vec()).unwrap();
        let v = arr_f16.mapv(f16::to_f32);
        if profile {
            println!("[ORT f16->f32]: {:?}", t_.elapsed());
        }
        v
    })
    .collect::<Vec<Array<_, _>>>())
|
||||
}
|
||||
|
||||
pub fn run_fp32(
|
||||
&mut self,
|
||||
xs: Array<f32, IxDyn>,
|
||||
profile: bool,
|
||||
) -> Result<Vec<Array<f32, IxDyn>>> {
|
||||
// h2d
|
||||
let t = std::time::Instant::now();
|
||||
let xs = CowArray::from(xs);
|
||||
if profile {
|
||||
println!("[ORT H2D]: {:?}", t.elapsed());
|
||||
}
|
||||
|
||||
// prepare input Value from the ndarray (needed because SessionInputValue implements From<Value<_>>)
|
||||
let t = std::time::Instant::now();
|
||||
let input = ort::value::Value::from_array(xs.into_owned())?;
|
||||
if profile {
|
||||
println!("[ORT Prepare Value]: {:?}", t.elapsed());
|
||||
}
|
||||
|
||||
// run
let t = std::time::Instant::now();
let ys = self.session.run(ort::inputs![input])?;
if profile {
    println!("[ORT Inference]: {:?}", t.elapsed());
}

// d2h
Ok(ys
    .iter()
    .map(|(_k, v)| {
        let t = std::time::Instant::now();
        // try_extract_tensor returns (shape, slice) for this output
        let (shape, slice) = v.try_extract_tensor::<f32>().unwrap();
        if profile {
            println!("[ORT D2H]: {:?}", t.elapsed());
        }

        // Build the ndarray from the shape reported at run time. The static
        // model shape can contain -1 for dynamic dimensions, which would not
        // match the length of the returned slice.
        let dims = shape.iter().map(|&d| d as usize).collect::<Vec<_>>();
        Array::from_shape_vec(IxDyn(&dims), slice.to_vec()).unwrap()
    })
    .collect::<Vec<Array<f32, IxDyn>>>())
|
||||
}
|
||||
|
||||
pub fn output_shapes(&self) -> Vec<Vec<i64>> {
|
||||
let mut shapes = Vec::new();
|
||||
for output in &self.session.outputs {
|
||||
if let ValueType::Tensor { shape, .. } = &output.output_type {
|
||||
shapes.push(shape.to_vec().clone());
|
||||
} else {
|
||||
panic!("not support data format, {} - {}", file!(), line!());
|
||||
}
|
||||
}
|
||||
shapes
|
||||
}
|
||||
|
||||
pub fn output_dtypes(&self) -> Vec<TensorElementType> {
|
||||
let mut dtypes = Vec::new();
|
||||
for output in &self.session.outputs {
|
||||
if let ValueType::Tensor {
|
||||
ty,
|
||||
shape: _,
|
||||
dimension_symbols: _,
|
||||
} = &output.output_type
|
||||
{
|
||||
dtypes.push(ty.clone());
|
||||
} else {
|
||||
panic!("not support data format, {} - {}", file!(), line!());
|
||||
}
|
||||
}
|
||||
dtypes
|
||||
}
|
||||
|
||||
pub fn input_shapes(&self) -> &Vec<Vec<i64>> {
|
||||
&self.inputs.shapes
|
||||
}
|
||||
|
||||
pub fn input_names(&self) -> &Vec<String> {
|
||||
&self.inputs.names
|
||||
}
|
||||
|
||||
pub fn input_dtypes(&self) -> &Vec<TensorElementType> {
|
||||
&self.inputs.dtypes
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> TensorElementType {
|
||||
self.input_dtypes()[0]
|
||||
}
|
||||
|
||||
pub fn height(&self) -> u32 {
|
||||
self.inputs.sizes[0][0]
|
||||
}
|
||||
|
||||
pub fn width(&self) -> u32 {
|
||||
self.inputs.sizes[0][1]
|
||||
}
|
||||
|
||||
pub fn is_height_dynamic(&self) -> bool {
|
||||
self.input_shapes()[0][2] == -1
|
||||
}
|
||||
|
||||
pub fn is_width_dynamic(&self) -> bool {
|
||||
self.input_shapes()[0][3] == -1
|
||||
}
|
||||
|
||||
pub fn batch(&self) -> u32 {
|
||||
self.batch.opt
|
||||
}
|
||||
|
||||
pub fn is_batch_dynamic(&self) -> bool {
|
||||
self.input_shapes()[0][0] == -1
|
||||
}
|
||||
|
||||
pub fn ep(&self) -> &OrtEP {
|
||||
&self.ep
|
||||
}
|
||||
|
||||
pub fn task(&self) -> YOLOTask {
|
||||
self.task.clone()
|
||||
}
|
||||
|
||||
pub fn names(&self) -> Option<Vec<String>> {
|
||||
// class names, metadata parsing
|
||||
// String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}`
|
||||
match self.fetch_from_metadata("names") {
|
||||
Some(names) => {
|
||||
let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap();
|
||||
let mut names_ = vec![];
|
||||
for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) {
|
||||
names_.push(name.to_string());
|
||||
}
|
||||
Some(names_)
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nk(&self) -> Option<u32> {
|
||||
// num_keypoints, parsed from the `kpt_shape` metadata string in the ONNX model, e.g. `[17, 3]`
|
||||
match self.fetch_from_metadata("kpt_shape") {
|
||||
None => None,
|
||||
Some(kpt_string) => {
|
||||
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
|
||||
let caps = re.captures(&kpt_string).unwrap();
|
||||
Some(caps.get(1).unwrap().as_str().parse::<u32>().unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nc(&self) -> Option<u32> {
|
||||
// num_classes
|
||||
match self.names() {
|
||||
// by names
|
||||
Some(names) => Some(names.len() as u32),
|
||||
None => match self.task() {
|
||||
// by task calculation
|
||||
YOLOTask::Classify => Some(self.output_shapes()[0][1] as u32),
|
||||
YOLOTask::Detect => {
|
||||
if self.output_shapes()[0][1] == -1 {
|
||||
None
|
||||
} else {
|
||||
// cxywhclss
|
||||
Some(self.output_shapes()[0][1] as u32 - 4)
|
||||
}
|
||||
}
|
||||
YOLOTask::Pose => {
|
||||
match self.nk() {
|
||||
None => None,
|
||||
Some(nk) => {
|
||||
if self.output_shapes()[0][1] == -1 {
|
||||
None
|
||||
} else {
|
||||
// cxywhclss3*kpt
|
||||
Some(self.output_shapes()[0][1] as u32 - 4 - 3 * nk)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
YOLOTask::Segment => {
|
||||
if self.output_shapes()[0][1] == -1 {
|
||||
None
|
||||
} else {
|
||||
// cxywhclssnm
|
||||
Some((self.output_shapes()[0][1] - self.output_shapes()[1][1]) as u32 - 4)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nm(&self) -> Option<u32> {
|
||||
// num_masks
|
||||
match self.task() {
|
||||
YOLOTask::Segment => Some(self.output_shapes()[1][1] as u32),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn na(&self) -> Option<u32> {
|
||||
// num_anchors
|
||||
match self.task() {
|
||||
YOLOTask::Segment | YOLOTask::Detect | YOLOTask::Pose => {
|
||||
if self.output_shapes()[0][2] == -1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.output_shapes()[0][2] as u32)
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn author(&self) -> Option<String> {
|
||||
self.fetch_from_metadata("author")
|
||||
}
|
||||
|
||||
pub fn version(&self) -> Option<String> {
|
||||
self.fetch_from_metadata("version")
|
||||
}
|
||||
}
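
`set_ep_trt` above builds the min, opt and max profile strings in the `name:dim_1xdim_2x...` form by pushing a trailing comma after each input and popping the last one. A minimal sketch of the same string layout using `join` follows. `profile_shapes` is a hypothetical helper, and the fixed channel count of 3 is taken from the format strings in the source rather than from the model.

```rust
/// Build a profile-shape string such as "images:1x3x640x640" for every
/// input name, comma-separated. Hypothetical helper for illustration;
/// the channel count is fixed at 3 as in the format strings above.
fn profile_shapes(names: &[&str], batch: u32, height: u32, width: u32) -> String {
    names
        .iter()
        .map(|name| format!("{}:{}x3x{}x{}", name, batch, height, width))
        .collect::<Vec<_>>()
        .join(",")
}

fn main() {
    let names = ["images"];
    // One string per batch setting, mirroring the min / opt / max fields of `Batch`.
    for (label, batch) in [("min", 1), ("opt", 4), ("max", 8)] {
        println!("{}: {}", label, profile_shapes(&names, batch, 640, 640));
    }
}
```

Joining avoids the trailing separator altogether, so an empty input list yields an empty string instead of relying on `pop`.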
|
||||
@@ -0,0 +1,242 @@
|
||||
// Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
use ndarray::{Array, Axis, IxDyn};
|
||||
|
||||
#[derive(Clone, PartialEq, Default)]
|
||||
pub struct YOLOResult {
|
||||
// YOLO tasks results of an image
|
||||
pub probs: Option<Embedding>,
|
||||
pub bboxes: Option<Vec<Bbox>>,
|
||||
pub keypoints: Option<Vec<Vec<Point2>>>,
|
||||
pub masks: Option<Vec<Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for YOLOResult {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("YOLOResult")
|
||||
.field(
|
||||
"Probs(top5)",
|
||||
&format_args!("{:?}", self.probs().map(|probs| probs.topk(5))),
|
||||
)
|
||||
.field("Bboxes", &self.bboxes)
|
||||
.field("Keypoints", &self.keypoints)
|
||||
.field(
|
||||
"Masks",
|
||||
&format_args!("{:?}", self.masks().map(|masks| masks.len())),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl YOLOResult {
|
||||
pub fn new(
|
||||
probs: Option<Embedding>,
|
||||
bboxes: Option<Vec<Bbox>>,
|
||||
keypoints: Option<Vec<Vec<Point2>>>,
|
||||
masks: Option<Vec<Vec<u8>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
probs,
|
||||
bboxes,
|
||||
keypoints,
|
||||
masks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn probs(&self) -> Option<&Embedding> {
|
||||
self.probs.as_ref()
|
||||
}
|
||||
|
||||
pub fn keypoints(&self) -> Option<&Vec<Vec<Point2>>> {
|
||||
self.keypoints.as_ref()
|
||||
}
|
||||
|
||||
pub fn masks(&self) -> Option<&Vec<Vec<u8>>> {
|
||||
self.masks.as_ref()
|
||||
}
|
||||
|
||||
pub fn bboxes(&self) -> Option<&Vec<Bbox>> {
|
||||
self.bboxes.as_ref()
|
||||
}
|
||||
|
||||
pub fn bboxes_mut(&mut self) -> Option<&mut Vec<Bbox>> {
|
||||
self.bboxes.as_mut()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Default)]
|
||||
pub struct Point2 {
|
||||
// A point2d with x, y, conf
|
||||
x: f32,
|
||||
y: f32,
|
||||
confidence: f32,
|
||||
}
|
||||
|
||||
impl Point2 {
|
||||
pub fn new_with_conf(x: f32, y: f32, confidence: f32) -> Self {
|
||||
Self { x, y, confidence }
|
||||
}
|
||||
|
||||
pub fn new(x: f32, y: f32) -> Self {
|
||||
Self {
|
||||
x,
|
||||
y,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn x(&self) -> f32 {
|
||||
self.x
|
||||
}
|
||||
|
||||
pub fn y(&self) -> f32 {
|
||||
self.y
|
||||
}
|
||||
|
||||
pub fn confidence(&self) -> f32 {
|
||||
self.confidence
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Default)]
|
||||
pub struct Embedding {
|
||||
// A float32 n-dimensional tensor
|
||||
data: Array<f32, IxDyn>,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn new(data: Array<f32, IxDyn>) -> Self {
|
||||
Self { data }
|
||||
}
|
||||
|
||||
pub fn data(&self) -> &Array<f32, IxDyn> {
|
||||
&self.data
|
||||
}
|
||||
|
||||
pub fn topk(&self, k: usize) -> Vec<(usize, f32)> {
|
||||
let mut probs = self
|
||||
.data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(a, b)| (a, *b))
|
||||
.collect::<Vec<_>>();
|
||||
// total_cmp gives a total order, so a NaN score cannot cause a panic here
probs.sort_by(|a, b| b.1.total_cmp(&a.1));
|
||||
let mut topk = Vec::new();
|
||||
for &(id, confidence) in probs.iter().take(k) {
|
||||
topk.push((id, confidence));
|
||||
}
|
||||
topk
|
||||
}
|
||||
|
||||
pub fn norm(&self) -> Array<f32, IxDyn> {
|
||||
let std_ = self.data.mapv(|x| x * x).sum_axis(Axis(0)).mapv(f32::sqrt);
|
||||
self.data.clone() / std_
|
||||
}
|
||||
|
||||
pub fn top1(&self) -> (usize, f32) {
|
||||
self.topk(1)[0]
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Default)]
|
||||
pub struct Bbox {
|
||||
// a bounding box around an object
|
||||
xmin: f32,
|
||||
ymin: f32,
|
||||
width: f32,
|
||||
height: f32,
|
||||
id: usize,
|
||||
confidence: f32,
|
||||
}
|
||||
|
||||
impl Bbox {
|
||||
pub fn new_from_xywh(xmin: f32, ymin: f32, width: f32, height: f32) -> Self {
|
||||
Self {
|
||||
xmin,
|
||||
ymin,
|
||||
width,
|
||||
height,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(xmin: f32, ymin: f32, width: f32, height: f32, id: usize, confidence: f32) -> Self {
|
||||
Self {
|
||||
xmin,
|
||||
ymin,
|
||||
width,
|
||||
height,
|
||||
id,
|
||||
confidence,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn width(&self) -> f32 {
|
||||
self.width
|
||||
}
|
||||
|
||||
pub fn height(&self) -> f32 {
|
||||
self.height
|
||||
}
|
||||
|
||||
pub fn xmin(&self) -> f32 {
|
||||
self.xmin
|
||||
}
|
||||
|
||||
pub fn ymin(&self) -> f32 {
|
||||
self.ymin
|
||||
}
|
||||
|
||||
pub fn xmax(&self) -> f32 {
|
||||
self.xmin + self.width
|
||||
}
|
||||
|
||||
pub fn ymax(&self) -> f32 {
|
||||
self.ymin + self.height
|
||||
}
|
||||
|
||||
pub fn tl(&self) -> Point2 {
|
||||
Point2::new(self.xmin, self.ymin)
|
||||
}
|
||||
|
||||
pub fn br(&self) -> Point2 {
|
||||
Point2::new(self.xmax(), self.ymax())
|
||||
}
|
||||
|
||||
pub fn cxcy(&self) -> Point2 {
|
||||
Point2::new(self.xmin + self.width / 2., self.ymin + self.height / 2.)
|
||||
}
|
||||
|
||||
pub fn id(&self) -> usize {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub fn confidence(&self) -> f32 {
|
||||
self.confidence
|
||||
}
|
||||
|
||||
pub fn area(&self) -> f32 {
|
||||
self.width * self.height
|
||||
}
|
||||
|
||||
pub fn intersection_area(&self, another: &Bbox) -> f32 {
|
||||
let l = self.xmin.max(another.xmin);
|
||||
let r = (self.xmin + self.width).min(another.xmin + another.width);
|
||||
let t = self.ymin.max(another.ymin);
|
||||
let b = (self.ymin + self.height).min(another.ymin + another.height);
|
||||
// no `+ 1.` offset: `area()` is width * height, so the intersection must use the same convention
(r - l).max(0.) * (b - t).max(0.)
|
||||
}
|
||||
|
||||
pub fn union(&self, another: &Bbox) -> f32 {
|
||||
self.area() + another.area() - self.intersection_area(another)
|
||||
}
|
||||
|
||||
pub fn iou(&self, another: &Bbox) -> f32 {
|
||||
let union = self.union(another);
|
||||
if union <= 0.0 {
|
||||
0.0
|
||||
} else {
|
||||
self.intersection_area(another) / union
|
||||
}
|
||||
}
|
||||
}
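
The `iou` method above is the kind of overlap score typically fed into non-maximum suppression. The sketch below shows IoU on a standalone `Rect` type together with a generic greedy NMS. Both `Rect` and `nms` are hypothetical names for illustration, the intersection uses continuous coordinates (no one-pixel offset), and the NMS shown is a textbook variant, not necessarily the one this crate applies.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
struct Rect {
    xmin: f32,
    ymin: f32,
    width: f32,
    height: f32,
    confidence: f32,
}

impl Rect {
    fn area(&self) -> f32 {
        self.width * self.height
    }

    /// Overlap area; zero when the rectangles do not intersect.
    fn intersection_area(&self, other: &Rect) -> f32 {
        let l = self.xmin.max(other.xmin);
        let r = (self.xmin + self.width).min(other.xmin + other.width);
        let t = self.ymin.max(other.ymin);
        let b = (self.ymin + self.height).min(other.ymin + other.height);
        (r - l).max(0.0) * (b - t).max(0.0)
    }

    fn iou(&self, other: &Rect) -> f32 {
        let inter = self.intersection_area(other);
        let union = self.area() + other.area() - inter;
        if union <= 0.0 {
            0.0
        } else {
            inter / union
        }
    }
}

/// Greedy NMS: visit boxes by descending confidence and keep a box only if
/// its IoU with every already-kept box is at most `iou_threshold`.
fn nms(mut boxes: Vec<Rect>, iou_threshold: f32) -> Vec<Rect> {
    boxes.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
    let mut kept: Vec<Rect> = Vec::new();
    for candidate in boxes {
        if kept.iter().all(|k| k.iou(&candidate) <= iou_threshold) {
            kept.push(candidate);
        }
    }
    kept
}

fn main() {
    let boxes = vec![
        Rect { xmin: 0.0, ymin: 0.0, width: 10.0, height: 10.0, confidence: 0.9 },
        Rect { xmin: 1.0, ymin: 1.0, width: 10.0, height: 10.0, confidence: 0.8 },
        Rect { xmin: 50.0, ymin: 50.0, width: 10.0, height: 10.0, confidence: 0.7 },
    ];
    // The second box overlaps the first heavily and is suppressed.
    for kept in nms(boxes, 0.5) {
        println!("{:?}", kept);
    }
}
```

With this convention a box compared with itself scores exactly 1.0, and disjoint boxes score 0.0.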
|
||||