# Source snapshot: yolov26_3d/ultralytics/data/dataset.py (2026-06-24 09:35:46 +08:00, 1866 lines, 84 KiB, Python,
# executable file).
# NOTE: This file contains Unicode characters (e.g. in non-English comments) that might be confused with other
# characters. If that is unintentional, review them with a viewer that escapes ambiguous characters.
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import hashlib
import glob
import json
import math
import os
import random
from collections import defaultdict
from copy import deepcopy
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from time import perf_counter
from typing import Any
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.instance import Instances
from ultralytics.utils.ops import resample_segments, segments2boxes
from ultralytics.utils.plotting_3d import decode_cut_partial_side_edge_from_gt, decode_visible_face_edge_from_gt
from ultralytics.utils.torch_utils import TORCHVISION_0_18
from .augment import (
Compose,
Format,
LetterBox,
RandomHSV,
RandomLoadText,
classify_augmentations,
classify_transforms,
v8_transforms,
)
from .base import BaseDataset
from .converter import merge_multi_segment
from .ground3d_augment import (
adjust_calib_for_roi_crop,
apply_simul_transform,
build_final_resized_calib,
compute_centered_roi_bounds,
compute_simul_calib,
compute_vanishing_point_x,
compute_vanishing_point_y,
normalize_roi_depth,
pack_labels_to_48,
parse_ground_3d_label_file,
read_calib_from_path,
remap_labels_to_roi,
unpack_labels_from_48,
)
from .utils import (
FORMATS_HELP_MSG,
HELP_URL,
IMG_FORMATS,
check_file_speeds,
get_hash,
img2label_paths,
load_dataset_cache_file,
save_dataset_cache_file,
verify_image,
verify_image_label,
)
# Version tag stored in Ultralytics dataset *.cache files (>= 1.0.0 for Ultralytics YOLO models). A cache whose tag
# differs from this value is treated as stale and rebuilt.
DATASET_CACHE_VERSION = "1.0.3"

# Fields that every Ground3D dataset config must set explicitly; no defaults are applied for any of them.
GROUND3D_REQUIRED_DATA_KEYS = (
    "path",
    "ori_img_size",
    "roi",
    "virtual_fx",
    "virtual_camera_prob",
    "crop_center_mode",
)

# Accepted values of the Ground3D ``crop_center_mode`` config field.
GROUND3D_VALID_CROP_CENTER_MODES = {"cxvy", "vxvy"}
class Ground3DImageReadError(RuntimeError):
    """Error for a Ground3D sample whose image could not be read or did not pass validation."""
class Ground3DCalibrationError(RuntimeError):
    """Error for a Ground3D sample that lacks the calibration it requires."""
def _resize_ground3d_image_in_steps(
    img: np.ndarray, target_size: tuple[int, int], interpolation: int = cv2.INTER_LINEAR
) -> np.ndarray:
    """Resize a Ground3D image, halving it repeatedly before the final resize when shrinking.

    When both target dimensions are smaller than the current ones, the image is downsampled by 0.5x (rounding up)
    for as long as the halved size still covers the target in both dimensions. A final ``cv2.resize`` then brings
    it to the exact target size. In every other case a single resize is performed.

    Args:
        img (np.ndarray): Input image; height and width are read from ``img.shape[:2]``.
        target_size (tuple[int, int]): Target size as ``(width, height)``.
        interpolation (int): OpenCV interpolation flag used for every resize step.

    Returns:
        (np.ndarray): The resized image, or ``img`` itself when it already has the target size.

    Raises:
        ValueError: If either target dimension is not positive.
    """
    target_w, target_h = int(target_size[0]), int(target_size[1])
    # A non-positive target would otherwise spin forever in the halving loop once the image reaches 1x1, because
    # ceil(1 * 0.5) == 1 never drops below a target of 0.
    if target_w <= 0 or target_h <= 0:
        raise ValueError(f"target_size must contain positive values, but got: {target_size!r}")
    current_h, current_w = img.shape[:2]
    if (current_w, current_h) == (target_w, target_h):
        return img
    if target_w < current_w and target_h < current_h:
        while True:
            next_w = math.ceil(current_w * 0.5)
            next_h = math.ceil(current_h * 0.5)
            # Stop halving once another step would undershoot the target in either dimension.
            if next_w < target_w or next_h < target_h:
                break
            img = cv2.resize(img, (next_w, next_h), interpolation=interpolation)
            current_h, current_w = img.shape[:2]
            if (current_w, current_h) == (target_w, target_h):
                return img
    return cv2.resize(img, (target_w, target_h), interpolation=interpolation)
def _require_ground3d_data_value(data: dict[str, Any], key: str) -> Any:
"""Return a required Ground3D config value or raise a clear error."""
if key not in data or data[key] is None:
raise ValueError(f"Ground3D dataset config must define '{key}' explicitly; defaults are not allowed.")
return data[key]
def _validate_ground3d_pair(value: Any, key: str) -> tuple[int, int]:
"""Validate a Ground3D size-like config field and return it as an integer pair."""
if not isinstance(value, (list, tuple)) or len(value) != 2:
raise ValueError(f"Ground3D dataset config field '{key}' must be a 2-item list/tuple, but got: {value!r}")
pair = (int(value[0]), int(value[1]))
if pair[0] <= 0 or pair[1] <= 0:
raise ValueError(f"Ground3D dataset config field '{key}' must contain positive values, but got: {value!r}")
return pair
def _validate_ground3d_crop_center_mode(value: Any) -> str:
    """Validate the Ground3D ``crop_center_mode`` config field.

    Args:
        value (Any): Candidate mode; must be a string contained in ``GROUND3D_VALID_CROP_CENTER_MODES``.

    Returns:
        (str): The validated mode as a plain ``str``.

    Raises:
        ValueError: If ``value`` is not a string or is not one of the accepted modes.
    """
    # Check the type first: an unhashable value (e.g. a list coming from YAML) would otherwise make the set
    # membership test raise a confusing TypeError instead of the config ValueError below.
    if not isinstance(value, str) or value not in GROUND3D_VALID_CROP_CENTER_MODES:
        valid_modes = ", ".join(sorted(GROUND3D_VALID_CROP_CENTER_MODES))
        raise ValueError(
            f"Ground3D dataset config field 'crop_center_mode' must be one of {{{valid_modes}}}, but got: {value!r}"
        )
    return str(value)
def _validate_ground3d_positive_float(value: Any, key: str) -> float:
"""Validate a positive numeric Ground3D config field."""
value = float(value)
if value <= 0:
raise ValueError(f"Ground3D dataset config field '{key}' must be positive, but got: {value!r}")
return value
def _validate_ground3d_float(value: Any, key: str) -> float:
"""Validate a numeric Ground3D config field."""
try:
return float(value)
except (TypeError, ValueError) as e:
raise ValueError(f"Ground3D dataset config field '{key}' must be numeric, but got: {value!r}") from e
def _ground3d_has_finite_targets(labels_3d: np.ndarray | None) -> bool:
"""Return whether a parsed 3D label array contains any real 3D supervision."""
if labels_3d is None:
return False
labels_3d = np.asarray(labels_3d)
return labels_3d.size > 0 and np.isfinite(labels_3d).any()
class YOLODataset(BaseDataset):
    """Dataset class for loading object detection and/or segmentation labels in YOLO format.

    This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
    (OBB) tasks using the YOLO format.

    Attributes:
        use_segments (bool): Indicates if segmentation masks should be used.
        use_keypoints (bool): Indicates if keypoints should be used for pose estimation.
        use_obb (bool): Indicates if oriented bounding boxes should be used.
        data (dict): Dataset configuration dictionary (an empty dict when ``data=None`` was passed).

    Methods:
        cache_labels: Cache dataset labels, check images and read shapes.
        get_labels: Return list of label dictionaries for YOLO training.
        build_transforms: Build and append transforms to the list.
        close_mosaic: Disable mosaic, copy_paste, mixup and cutmix augmentations and build transformations.
        update_labels_info: Update label format for different tasks.
        collate_fn: Collate data samples into batches.

    Examples:
        >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
        >>> dataset.get_labels()
    """

    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
        """Initialize the YOLODataset.

        Args:
            data (dict, optional): Dataset configuration dictionary. ``None`` is treated as an empty config.
            task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        # The signature allows data=None; fall back to an empty config so the `.get()` calls below do not raise
        # AttributeError on None.
        self.data = data if data is not None else {}
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(
            *args,
            channels=self.data.get("channels", 3),
            use_yuv444=self.data.get("use_yuv444", False),
            **kwargs,
        )

    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
        """Cache dataset labels, check images and read shapes.

        Images and labels are verified in a thread pool with ``verify_image_label``; the collected labels, a hash
        of all label and image paths, summary counts and warnings are written to ``path``.

        Args:
            path (Path): Path where to save the cache file.

        Returns:
            (dict): Dictionary containing cached labels and related information.

        Raises:
            ValueError: If this is a pose dataset and ``kpt_shape`` is missing or invalid.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                    repeat(self.single_cls),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self) -> list[dict]:
        """Return list of label dictionaries for YOLO training.

        This method loads labels from disk or cache, verifies their integrity, and prepares them for training. A
        cache whose version or path hash does not match is rebuilt with ``cache_labels``.

        Returns:
            (list[dict]): List of label dictionaries, each containing information about an image and its annotations.

        Raises:
            RuntimeError: If no valid labelled images were found.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache: drop bookkeeping entries (a plain loop, since only the side effect is needed)
        for k in ("hash", "version", "msgs"):
            cache.pop(k)
        labels = cache["labels"]
        if not labels:
            raise RuntimeError(
                f"No valid images found in {cache_path}. Images with incorrectly formatted labels are ignored. {HELP_URL}"
            )
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp: dict | None = None) -> Compose:
        """Build and append transforms to the list.

        Args:
            hyp (dict, optional): Hyperparameters for transforms. Fields are read and written by attribute (e.g.
                ``hyp.mosaic``), so a namespace-like object is expected despite the ``dict`` annotation.

        Returns:
            (Compose): Composed transforms.
        """
        if self.augment:
            # Mosaic, mixup and cutmix are switched off (in place on `hyp`) when rectangular batching is enabled.
            use_mosaic_family = not self.rect
            hyp.mosaic = hyp.mosaic if use_mosaic_family else 0.0
            hyp.mixup = hyp.mixup if use_mosaic_family else 0.0
            hyp.cutmix = hyp.cutmix if use_mosaic_family else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only affect training.
            )
        )
        return transforms

    def close_mosaic(self, hyp: dict) -> None:
        """Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.

        Args:
            hyp (dict): Hyperparameters for transforms; modified in place and used to rebuild ``self.transforms``.
        """
        hyp.mosaic = 0.0
        hyp.copy_paste = 0.0
        hyp.mixup = 0.0
        hyp.cutmix = 0.0
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label: dict) -> dict:
        """Update label format for different tasks.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances.

        Notes:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # make sure segments interpolate correctly if original length is greater than segment_resamples
            max_len = max(len(s) for s in segments)
            segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
            # list[np.array(segment_resamples, 2)] * num_samples
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch: list[dict]) -> dict:
        """Collate data samples into batches.

        Image-like keys are stacked, ``visuals`` are padded, per-instance keys are concatenated, and ``batch_idx``
        is offset by each sample's position in the batch before concatenation.

        Args:
            batch (list[dict]): List of dictionaries containing sample data; all samples must share the same keys.

        Returns:
            (dict): Collated batch with stacked tensors.
        """
        new_batch = {}
        batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for k, value in zip(keys, values):
            if k in {"img", "text_feats", "sem_masks"}:
                value = torch.stack(value, 0)
            elif k == "visuals":
                value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
            elif k == "camera_mode":
                value = tuple(value)
            if k in {
                "masks",
                "keypoints",
                "bboxes",
                "cls",
                "segments",
                "obb",
                "difficulties",
                "difficulty_levels",
                "labels_3d",
                "edge_faces_points_2d",
                "edge_faces_depths",
                "edge_faces_valid",
                "edge_partial_points_2d",
                "edge_partial_depths",
                "edge_partial_face_type",
                "edge_partial_valid",
            }:
                value = torch.cat(value, 0)
            new_batch[k] = value
        # Add the target image index for build_targets(). Out-of-place `+` avoids mutating the per-sample tensors
        # in the input batch (the previous `+=` modified them in place).
        new_batch["batch_idx"] = torch.cat([idx + i for i, idx in enumerate(new_batch["batch_idx"])], 0)
        return new_batch
class YOLOGroundDataset(YOLODataset):
    """Dataset class for loading ground 2D detection labels with difficulty scores.

    This class extends YOLODataset to support custom annotation format with:
        - String class names mapped to numeric IDs via class_map
        - Difficulty scores for each bounding box
        - Optional minimum box size filtering
        - Optional YUV444 color space support

    Attributes:
        class_map (dict): Mapping from class names to class IDs.
        min_wh (float): Minimum box width/height in pixels for filtering.
        use_yuv444 (bool): Whether to use YUV444 color space.
        difficulty_weights (list[float]): Loss weight per difficulty level; raw difficulty values are clipped to
            ``[0, len(difficulty_weights) - 1]`` and used as indices into this list.

    Examples:
        >>> data = {"class_map": {"car": 0, "pedestrian": 1}, "min_wh": 2.0}
        >>> dataset = YOLOGroundDataset(img_path="path/to/images", data=data, task="detect")
    """

    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
        """Initialize the YOLOGroundDataset with class mapping and difficulty support.

        Args:
            data (dict, optional): Dataset configuration with 'class_map' and optional 'min_wh', 'use_yuv444',
                'difficulty_weights'. ``None`` is treated as an empty config.
            task (str): Task type, should be 'detect' for ground dataset.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        data = {} if data is None else data  # the signature allows None; avoid AttributeError on `.get()`
        self.class_map = data.get("class_map", {})
        self.min_wh = data.get("min_wh", 2.0)
        self.use_yuv444 = data.get("use_yuv444", False)
        self.difficulty_weights = data.get("difficulty_weights", [1.0])
        super().__init__(*args, data=data, task=task, **kwargs)

    def get_img_files(self, img_path: str | list[str]) -> list[str]:
        """Read image files while preserving source paths for label lookup.

        Ground 2D train/val list files may live under one ``Detection2D_*`` root while the corresponding images should
        be loaded from the YAML ``path`` root. The original list paths are retained for label lookup in
        ``self.ground_label_files``.

        Args:
            img_path (str | list[str]): Directory(ies) to scan recursively and/or list file(s) with one path per line.

        Returns:
            (list[str]): Image paths mapped onto the configured image root, in the same order as
                ``self.ground_label_files``.

        Raises:
            FileNotFoundError: If a path does not exist, nothing could be read, or an entry has an unsupported suffix.
        """
        try:
            files = []
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)
                if p.is_dir():
                    files += glob.glob(str(p / "**" / "*.*"), recursive=True)
                elif p.is_file():
                    with open(p, encoding="utf-8") as t:
                        lines = t.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    # Only the leading "./" is relative to the list file. str.replace() would also rewrite the
                    # "./" inside later "../" segments and corrupt the path.
                    files += [parent + x[2:] if x.startswith("./") else x for x in lines]
                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            entries = sorted(x.replace("/", os.sep) for x in files if x.strip())
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if not entries:
                raise FileNotFoundError(f"{self.prefix}No entries found in {img_path}.")
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
        if self.fraction < 1:
            entries = entries[: round(len(entries) * self.fraction)]
        self.ground_label_files = []
        im_files = []
        for f in entries:
            suffix = Path(f).suffix.lower().lstrip(".")
            if suffix in IMG_FORMATS:
                # Image entry: derive the label path the standard way, map the image onto the image root.
                self.ground_label_files.append(img2label_paths([f])[0])
                im_files.append(self._ground_image_path_from_label_path(f))
            elif suffix == "txt":
                # Label entry: keep it as the label path and search for the matching image under the image root.
                self.ground_label_files.append(f)
                im_files.append(self._ground_image_path_from_label_path(f, label_path=True))
            else:
                raise FileNotFoundError(f"{self.prefix}Unsupported ground list entry: {f}. {FORMATS_HELP_MSG}")
        assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
        check_file_speeds(im_files, prefix=self.prefix)
        return im_files

    def _ground_image_path_from_label_path(self, im_file: str, label_path: bool = False) -> str:
        """Map a list-file path from a ``Detection2D_*``/``2Ddetection_*`` root onto the configured image root.

        The components after the first directory whose lowercase name starts with ``detection2d`` or
        ``2ddetection`` are re-rooted under ``self.data["path"]``. Without such a directory, the whole path (minus
        its anchor, if absolute) is used.

        Args:
            im_file (str): Path taken from the list file.
            label_path (bool): Whether ``im_file`` is a label ``.txt`` path whose image must be searched for.

        Returns:
            (str): The image path. For label paths, the first existing candidate is returned; if none exists, the
                ``.jpg`` candidate under an ``images`` directory is returned.
        """
        image_root = Path(self.data["path"])
        parts = Path(im_file).parts
        detection_root_index = next(
            (i for i, part in enumerate(parts) if part.lower().startswith(("detection2d", "2ddetection"))),
            None,
        )
        if detection_root_index is not None and detection_root_index + 1 < len(parts):
            rel_parts = list(parts[detection_root_index + 1 :])
        else:
            rel_path = Path(im_file) if not Path(im_file).is_absolute() else Path(*parts[1:])
            rel_parts = list(rel_path.parts)
        if label_path:
            # 1) same relative path with a .jpg suffix
            image_path = image_root.joinpath(*rel_parts).with_suffix(".jpg")
            if image_path.exists():
                return str(image_path)
            # 2) "labels" directories swapped for "images", .jpg first, then other image suffixes
            image_rel_parts = ["images" if part == "labels" else part for part in rel_parts]
            image_path = image_root.joinpath(*image_rel_parts).with_suffix(".jpg")
            if image_path.exists():
                return str(image_path)
            for suffix in ("png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg"):
                candidate = image_path.with_suffix(f".{suffix}")
                if candidate.exists():
                    return str(candidate)
            return str(image_path)
        return str(image_root.joinpath(*rel_parts))

    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
        """Cache dataset labels with difficulty scores, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file.

        Returns:
            (dict): Dictionary containing cached labels and related information.
        """
        from ultralytics.data.utils import verify_image_label_ground

        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        # Loss-weight lookup table indexed by (clipped) raw difficulty; loop-invariant, so built once.
        weight_table = np.asarray(self.difficulty_weights, dtype=np.float32)
        max_level = len(self.difficulty_weights) - 1
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label_ground,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.class_map),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    # Map raw difficulty values (column 5) to loss weights via the difficulty_weights lookup.
                    if len(lb) > 0:
                        raw_diff = lb[:, 5].astype(int).clip(0, max_level)
                        loss_weights = weight_table[raw_diff].reshape(-1, 1)
                        difficulty_levels = raw_diff.astype(np.int64).reshape(-1, 1)
                    else:
                        loss_weights = np.zeros((0, 1), dtype=np.float32)
                        difficulty_levels = np.zeros((0, 1), dtype=np.int64)
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:5],  # n, 4
                            "difficulties": loss_weights,  # n, 1 (pre-computed loss weights)
                            # n, 1 difficulty level clipped to [0, len(difficulty_weights) - 1].
                            # NOTE(review): with the default difficulty_weights=[1.0] every level becomes 0;
                            # confirm that is intended if levels are used as classification targets.
                            "difficulty_levels": difficulty_levels,
                            "segments": segments,
                            "keypoints": None,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self) -> list[dict]:
        """Return ground labels derived from the original train/val list paths.

        Returns:
            (list[dict]): Cached (or freshly built) label dictionaries.

        Raises:
            RuntimeError: If no valid labelled images were found.
        """
        # Prefer the label paths recorded by get_img_files(); only derive them from image paths when absent.
        # (getattr with a default would compute img2label_paths over every file even when it is not needed.)
        if hasattr(self, "ground_label_files"):
            self.label_files = self.ground_label_files
        else:
            self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True
            assert cache["version"] == DATASET_CACHE_VERSION
            assert cache["hash"] == get_hash(self.label_files + self.im_files)
        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
            cache, exists = self.cache_labels(cache_path), False
        nf, nm, ne, nc, n = cache.pop("results")
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))
        for k in ("hash", "version", "msgs"):
            cache.pop(k)
        labels = cache["labels"]
        if not labels:
            raise RuntimeError(
                f"No valid images found in {cache_path}. Images with incorrectly formatted labels are ignored. {HELP_URL}"
            )
        self.im_files = [lb["im_file"] for lb in labels]
        if sum(len(lb["cls"]) for lb in labels) == 0:
            LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def update_labels_info(self, label: dict) -> dict:
        """Update label format with difficulty scores.

        Args:
            label (dict): Label dictionary containing bboxes, difficulties, etc.

        Returns:
            (dict): Updated label dictionary with instances including difficulties.
        """
        bboxes = label.pop("bboxes")
        difficulties = label.pop("difficulties", None)
        difficulty_levels = label.pop("difficulty_levels", None)
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")
        # Convert empty segments list to numpy array (required by Instances.denormalize).
        # NOTE(review): non-empty segments are passed through without resampling, unlike YOLODataset.
        if len(segments) == 0:
            segments = np.zeros((0, 1000, 2), dtype=np.float32)
        label["instances"] = Instances(
            bboxes,
            segments,
            keypoints,
            difficulties=difficulties,
            difficulty_levels=difficulty_levels,
            bbox_format=bbox_format,
            normalized=normalized,
        )
        return label

    def __getitem__(self, index: int) -> dict:
        """Return transformed label with post-augmentation min_wh filtering.

        Filters out boxes whose width and height are both smaller than min_wh pixels after all augmentations
        (resize, crop, mosaic, etc.) have been applied, ensuring the filter operates on final pixel sizes.
        """
        labels = self.transforms(self.get_image_and_label(index))
        if self.min_wh > 0 and "bboxes" in labels and len(labels["bboxes"]) > 0:
            bboxes = labels["bboxes"]  # normalized xywh after Format
            _, h, w = labels["img"].shape  # CHW
            keep = (bboxes[:, 2] * w >= self.min_wh) | (bboxes[:, 3] * h >= self.min_wh)
            if not keep.all():
                labels["bboxes"] = bboxes[keep]
                labels["cls"] = labels["cls"][keep]
                # Keep optional per-instance fields aligned with the surviving boxes.
                for key in ("batch_idx", "difficulties", "difficulty_levels"):
                    if key in labels:
                        labels[key] = labels[key][keep]
        return labels
class YOLOGround3DDataset(YOLOGroundDataset):
"""Dataset for ground 3D detection backed by GT list files and on-the-fly label loading.
Extends YOLOGroundDataset to support joint 2D+3D detection with:
- On-the-fly parsing of 19-col (complete_3d) and 51-col (face_3d) label files
- A simple GT-manifest list used for dataset length and sample lookup
- Only color-space augmentations (3D labels are in camera space)
- Virtual camera / fisheye augmentation support
Attributes:
face_3d_classes (set): Class IDs with 4-face annotations.
complete_3d_classes (set): Class IDs with whole-box 3D only.
norm_scales_3d (dict): Normalization scales for 3D loss computation.
"""
def __init__(
    self,
    img_path: str | list[str],
    imgsz: int = 640,
    cache: bool | str = False,
    augment: bool = True,
    hyp: dict[str, Any] = DEFAULT_CFG,
    prefix: str = "",
    rect: bool = False,
    batch_size: int = 16,
    stride: int = 32,
    pad: float = 0.5,
    single_cls: bool = False,
    classes: list[int] | None = None,
    fraction: float = 1.0,
    data: dict | None = None,
    task: str = "detect",
):
    """Initialize the Ground3D dataset from GT list files.

    The parent initializers are not called. Every attribute this dataset needs is set up here, in dependency order:
    first config validation, then label-entry loading, class filtering, buffers, and finally transforms.

    Args:
        img_path (str | list[str]): GT list file path(s) whose lines are label ``.txt`` paths.
        imgsz (int): Target image size, stored as ``self.imgsz``.
        cache (bool | str): Ignored. A warning is logged when truthy because samples are loaded on the fly.
        augment (bool): Whether augmentation is enabled. It also enables the image buffer (``max_buffer_length``).
        hyp (dict[str, Any]): Hyperparameters. They are read with ``getattr`` (``face_visibility_score_thresh``,
            ``edge_aux_loss_gain``, ``batch_timing``) and passed to ``build_transforms``.
        prefix (str): Prefix for log messages.
        rect (bool): Whether to use rectangular batches. When True, ``set_rectangle()`` is called.
        batch_size (int): Batch size. It is also used to bound the image buffer length.
        stride (int): Model stride, stored as ``self.stride``.
        pad (float): Padding, stored as ``self.pad``.
        single_cls (bool): Stored as ``self.single_cls``.
        classes (list[int] | None): Class IDs to keep, forwarded to ``update_labels``.
        fraction (float): Fraction of GT entries to keep.
        data (dict | None): Dataset config. It must define every key in ``GROUND3D_REQUIRED_DATA_KEYS``.
        task (str): Task name. It sets ``use_segments`` / ``use_keypoints`` / ``use_obb``.

    Raises:
        ValueError: If a required Ground3D config field is missing, ``None`` or invalid.
        FileNotFoundError: If the GT list files cannot be loaded.
    """
    init_start = perf_counter()
    data = data or {}
    self.use_segments = task == "segment"
    self.use_keypoints = task == "pose"
    self.use_obb = task == "obb"
    assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
    # Report every required field that is absent or None in a single error.
    missing_keys = [key for key in GROUND3D_REQUIRED_DATA_KEYS if key not in data or data[key] is None]
    if missing_keys:
        missing = ", ".join(missing_keys)
        raise ValueError(f"Ground3D dataset config must define required field(s) explicitly: {missing}.")
    # Dataset root directory from data["path"]; it must be non-empty after stripping whitespace.
    path_value = str(_require_ground3d_data_value(data, "path")).strip()
    if not path_value:
        raise ValueError("Ground3D dataset config field 'path' must be a non-empty string.")
    # Validated / normalized copies of the required fields; they override the raw values when merged below.
    normalized_ground3d_cfg = {
        "path": str(Path(path_value).expanduser()),  # dataset root with "~" expanded
        "ori_img_size": list(_validate_ground3d_pair(_require_ground3d_data_value(data, "ori_img_size"), "ori_img_size")),
        "roi": list(_validate_ground3d_pair(_require_ground3d_data_value(data, "roi"), "roi")),
        "virtual_fx": _validate_ground3d_positive_float(_require_ground3d_data_value(data, "virtual_fx"), "virtual_fx"),
        "virtual_camera_prob": _validate_ground3d_float(
            _require_ground3d_data_value(data, "virtual_camera_prob"), "virtual_camera_prob"
        ),
        "crop_center_mode": _validate_ground3d_crop_center_mode(_require_ground3d_data_value(data, "crop_center_mode")),
    }
    self.data = {**data, **normalized_ground3d_cfg}
    # Presumably maps string class names in the label files to numeric class IDs; confirm against the parser.
    self.class_map = data.get("class_map", {})
    # Minimum 2D box width/height in pixels for the inherited post-augmentation filter (default 2).
    self.min_wh = data.get("min_wh", 2.0)
    self.use_yuv444 = data.get("use_yuv444", False)
    self.difficulty_weights = data.get("difficulty_weights", [1.0])
    self.face_3d_classes = set(data.get("face_3d_classes", []))
    self.complete_3d_classes = set(data.get("complete_3d_classes", []))
    # Normalization scales for 3D loss computation (e.g. depth / size / offset scales; TODO confirm keys).
    self.norm_scales_3d = data.get("norm_scales_3d", {})
    # Dataset root as a Path. Image paths are resolved under it; label and calib paths are resolved relative to
    # the GT list file directory instead (see _entry_to_label_file / _entry_to_calib_file).
    self.image_root = Path(self.data["path"])
    self.ori_img_size = self.data["ori_img_size"]
    self.roi = self.data["roi"]
    self.virtual_fx = self.data["virtual_fx"]
    self.virtual_camera_prob = self.data["virtual_camera_prob"]
    self.crop_center_mode = self.data["crop_center_mode"]
    # Presumably enables virtual-camera zoom logic during validation too; defaults to off.
    self.virtual_camera_val_zoom = data.get("virtual_camera_val_zoom", False)
    # (height, width) derived from ori_img_size, which is given as (width, height).
    self._ground3d_shape = (int(self.ori_img_size[1]), int(self.ori_img_size[0]))
    self._include_class = None
    # Visible-face score threshold from hyp, falling back to DEFAULT_CFG. Presumably used when building face /
    # edge GT to decide whether a face is valid; confirm at the use site.
    self.face_visibility_score_thresh = float(
        getattr(hyp, "face_visibility_score_thresh", DEFAULT_CFG.face_visibility_score_thresh)
    )
    # Edge GT is only worth precomputing when the edge auxiliary loss has a positive gain.
    self.precompute_edge_gt = float(getattr(hyp, "edge_aux_loss_gain", 1.0)) > 0
    self.profile_timing = bool(getattr(hyp, "batch_timing", False))
    self.skip_bad_images = bool(data.get("skip_bad_images", True))
    # "ori_img_size" is a required key, so this default is effectively True.
    self.strict_image_shape = bool(data.get("strict_image_shape", "ori_img_size" in data))
    self.min_image_side = max(int(data.get("min_image_side", 10)), 1)
    self.img_path = img_path
    self.imgsz = imgsz
    self.augment = augment
    self.single_cls = single_cls
    self.prefix = prefix
    self.fraction = fraction  # must be set before _load_ground3d_label_entries, which reads it
    self.channels = self.data.get("channels", 3)
    # Read grayscale for single-channel data, color otherwise.
    self.cv2_flag = cv2.IMREAD_GRAYSCALE if self.channels == 1 else cv2.IMREAD_COLOR
    self.label_entries = self._load_ground3d_label_entries(self.img_path)
    self.labels = self.label_entries
    self.update_labels(include_class=classes)
    self.ni = len(self.labels)
    self._bad_image_mask = np.zeros(self.ni, dtype=bool)
    self._warned_bad_image_indices: set[int] = set()
    self.rect = rect
    self.batch_size = batch_size
    self.stride = stride
    self.pad = pad
    if self.rect:
        assert self.batch_size is not None
        self.set_rectangle()
    self.buffer = []
    # With augmentation, buffer at most min(num samples, batch_size * 8, 1000) images; otherwise no buffer.
    self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
    self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
    self.npy_files = []
    self.cache = None
    if cache:
        LOGGER.warning(f"{self.prefix}Ground3D dataset ignores cache={cache} and loads samples on-the-fly")
    self.transforms = self.build_transforms(hyp=hyp)
    if LOCAL_RANK in {-1, 0}:
        LOGGER.info(
            f"{self.prefix}Ground3D dataset ready with {self.ni:,} samples in {perf_counter() - init_start:.1f}s"
        )
def _load_ground3d_label_entries(self, img_path: str | list[str]) -> list[tuple[str, str]]:
    """Load Ground3D GT list files and return ``(list_root, rel_label)`` entries.

    Every non-blank line that does not start with ``#`` must be a label ``.txt`` path. It is stored verbatim as
    ``rel_label``, together with ``list_root``, the resolved directory of the list file it came from.

    Args:
        img_path (str | list[str]): One GT list file path or a list of them.

    Returns:
        (list[tuple[str, str]]): ``(list_root, rel_label)`` pairs in file order. When ``self.fraction < 1``, only
            the first ``round(total * fraction)`` entries are kept.

    Raises:
        FileNotFoundError: If a path is not a file, an entry is not a ``.txt`` path, a list file has no entries, or
            no entries were found overall. The underlying error is chained.
    """
    items = img_path if isinstance(img_path, list) else [img_path]
    label_entries = []
    log_progress = LOCAL_RANK in {-1, 0}
    scan_start = perf_counter()
    total_entries = 0
    if log_progress:
        LOGGER.info(f"{self.prefix}Initializing Ground3D dataset from {len(items)} GT list file(s)...")
    try:
        for item in items:
            item = Path(item)
            if not item.is_file():
                raise FileNotFoundError(f"{self.prefix}Ground3D dataset expects GT list files, but got: {item}")
            if log_progress:
                LOGGER.info(f"{self.prefix}Reading GT list: {item}")
            # Resolve the list directory once per file instead of once per entry. resolve() hits the filesystem,
            # and one shared string also keeps memory flat for multi-million-line lists.
            list_root = str(item.parent.resolve())
            file_has_entries = False
            with open(item, encoding="utf-8") as f:
                for line in f:
                    entry = line.strip()
                    if not entry or entry.startswith("#"):  # skip blank lines and comments
                        continue
                    if Path(entry).suffix.lower() != ".txt":
                        raise ValueError(f"{self.prefix}Ground3D GT list entries must be label .txt paths, but got: {entry}")
                    file_has_entries = True
                    label_entries.append((list_root, entry))
                    total_entries += 1
                    if log_progress and total_entries % 100_000 == 0:
                        LOGGER.info(
                            f"{self.prefix}Parsed {total_entries:,} GT entries in {perf_counter() - scan_start:.1f}s"
                        )
            if not file_has_entries:
                raise FileNotFoundError(f"{self.prefix}No GT entries found in {item}.")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not label_entries:
            raise FileNotFoundError(f"{self.prefix}No GT entries found in {img_path}.")
    except Exception as e:
        raise FileNotFoundError(f"{self.prefix}Error loading Ground3D GT list from {img_path}\n{HELP_URL}") from e
    if self.fraction < 1:
        n = round(len(label_entries) * self.fraction)
        label_entries = label_entries[:n]
        if log_progress:
            LOGGER.info(
                f"{self.prefix}Applied dataset fraction={self.fraction:g}: keeping {len(label_entries):,}/{total_entries:,} entries"
            )
    if log_progress:
        LOGGER.info(
            f"{self.prefix}Loaded {len(label_entries):,} Ground3D GT entries from {len(items)} list file(s) "
            f"in {perf_counter() - scan_start:.1f}s"
        )
        # NOTE(review): the original indentation was lost. The sampled speed check is assumed to run only on the
        # logging rank, like the surrounding messages; confirm.
        LOGGER.info(f"{self.prefix}Checking sampled image access speed...")
        sample_entries = random.sample(label_entries, min(5, len(label_entries)))
        check_file_speeds([self._entry_to_image_file(entry) for entry in sample_entries], prefix=self.prefix)
    return label_entries
@staticmethod
def _label_rel_to_image_rel(rel_label: Path) -> Path:
"""Convert a relative Ground3D label path to its matching image path."""
parts = list(rel_label.parts)
if "labels" in parts:
parts[len(parts) - 1 - parts[::-1].index("labels")] = "images"
rel_label = Path(*parts)
return rel_label.with_suffix(".png")
@staticmethod
def _entry_to_rel_label(label_entry: tuple[str, str]) -> Path:
"""Return the relative label path stored in a Ground3D GT entry."""
return Path(label_entry[1])
def _entry_to_label_file(self, label_entry: tuple[str, str]) -> str:
    """Return the absolute, resolved label file path for a ``(list_root, relative_label)`` GT entry.

    The relative label path is joined onto the GT list's own directory (``list_root``).
    """
    list_root, _ = label_entry
    absolute = Path(list_root).joinpath(self._entry_to_rel_label(label_entry)).resolve()
    return str(absolute)
@staticmethod
def _label_rel_to_calib_rel(rel_label: Path) -> Path:
"""Convert a relative Ground3D label path to its matching calibration path."""
parts = list(rel_label.parts)
if "labels" in parts:
parts[len(parts) - 1 - parts[::-1].index("labels")] = "calib"
rel_label = Path(*parts)
return rel_label.with_suffix(".json")
def _entry_to_calib_file(self, label_entry: tuple[str, str]) -> str:
    """Return the absolute calibration JSON path for a ``(list_root, relative_label)`` GT entry.

    The calibration path is resolved under the GT list's directory (``list_root``), not under ``image_root``.
    """
    list_root, _ = label_entry
    rel_calib = self._label_rel_to_calib_rel(self._entry_to_rel_label(label_entry))
    return str(Path(list_root).joinpath(rel_calib).resolve())
def _entry_to_image_file(self, label_entry: tuple[str, str]) -> str:
    """Return the absolute image path for a GT entry, resolved under ``self.image_root``.

    The path produced by ``_label_rel_to_image_rel`` is returned if it exists; otherwise a sibling with a ``.jpg``
    suffix is returned if that exists; if neither exists, the first candidate is returned anyway.
    """
    primary = (self.image_root / self._label_rel_to_image_rel(self._entry_to_rel_label(label_entry))).resolve()
    for candidate in (primary, primary.with_suffix(".jpg")):
        if candidate.exists():
            return str(candidate)
    return str(primary)
def update_labels(self, include_class: list[int] | None) -> None:
"""Record class filtering rules and apply them lazily on parsed Ground3D labels."""
self._include_class = np.array(include_class, dtype=np.int64).reshape(1, -1) if include_class is not None else None
def set_rectangle(self) -> None:
    """Assign images to consecutive batches and give every batch one stride-aligned shape.

    All Ground3D images share ``self._ground3d_shape``, so a single aspect-ratio row (``[h, w]`` ratios with the
    longer side equal to 1) is repeated for every batch, scaled by ``imgsz``, and rounded up to a multiple of
    ``self.stride`` after adding ``self.pad``. An empty dataset gets empty ``batch`` / ``batch_shapes`` arrays.

    NOTE(review): ``get_image_and_label`` unpacks a list/tuple ``imgsz`` as ``(w, h)``, while here it is
    multiplied element-wise against an ``[h, w]`` ratio row; confirm which order is intended.
    """
    if self.ni == 0:
        self.batch = np.zeros(0, dtype=int)
        self.batch_shapes = np.zeros((0, 2), dtype=int)
        return

    batch_index = np.floor(np.arange(self.ni) / self.batch_size).astype(int)
    num_batches = batch_index[-1] + 1

    src_h, src_w = self._ground3d_shape
    aspect = src_h / src_w if src_w else 1.0
    if aspect < 1:
        unit_shape = [aspect, 1.0]
    elif aspect > 1:
        unit_shape = [1.0, 1.0 / aspect]
    else:
        unit_shape = [1.0, 1.0]

    size = self.imgsz if isinstance(self.imgsz, (list, tuple)) else [self.imgsz, self.imgsz]
    size = np.array(size, dtype=np.float32)
    per_batch = np.repeat(np.array([unit_shape], dtype=np.float32), num_batches, axis=0)
    self.batch_shapes = np.ceil(per_batch * size / self.stride + self.pad).astype(int) * self.stride
    self.batch = batch_index
def _filter_ground3d_labels(self, lb_2d: dict[str, Any], lb_3d: np.ndarray | None) -> tuple[dict[str, Any], np.ndarray | None]:
"""Apply class selection lazily to on-the-fly Ground3D labels."""
if self._include_class is not None and len(lb_2d["cls"]):
keep = (lb_2d["cls"] == self._include_class).any(1)
lb_2d = {
**lb_2d,
"cls": lb_2d["cls"][keep],
"bboxes": lb_2d["bboxes"][keep],
"difficulties": lb_2d["difficulties"][keep],
"difficulty_levels": lb_2d["difficulty_levels"][keep],
"segments": [lb_2d["segments"][si] for si, idx in enumerate(keep.tolist()) if idx] if lb_2d["segments"] else [],
"keypoints": lb_2d["keypoints"][keep] if lb_2d["keypoints"] is not None else None,
}
if lb_3d is not None:
lb_3d = lb_3d[keep]
if self.single_cls and len(lb_2d["cls"]):
lb_2d = {**lb_2d, "cls": lb_2d["cls"].copy()}
lb_2d["cls"][:, 0] = 0
return lb_2d, lb_3d
def _mark_bad_image(self, index: int, im_file: str, reason: str) -> None:
    """Flag ``index`` in ``self._bad_image_mask`` and log a warning the first time that index is flagged."""
    self._bad_image_mask[index] = True
    if index in self._warned_bad_image_indices:
        return  # already reported once
    LOGGER.warning(f"{self.prefix}Skipping bad Ground3D image [{index}] {im_file}: {reason}")
    self._warned_bad_image_indices.add(index)
def _validate_ground3d_image(self, index: int, im_file: str, img: np.ndarray | None) -> np.ndarray:
"""Validate a decoded Ground3D image before calibration and label work."""
if self._bad_image_mask[index]:
raise Ground3DImageReadError(f"{im_file}: previously marked unreadable or invalid")
if img is None:
self._mark_bad_image(index, im_file, "cv2.imread() returned None")
raise Ground3DImageReadError(f"{im_file}: cv2.imread() returned None")
if not isinstance(img, np.ndarray):
self._mark_bad_image(index, im_file, f"decoder returned {type(img).__name__}, expected ndarray")
raise Ground3DImageReadError(f"{im_file}: decoder returned non-array image data")
shape = tuple(int(v) for v in img.shape[:2])
if self.channels == 1:
valid_channels = img.ndim == 2
expected_shape = "HxW"
else:
valid_channels = img.ndim == 3 and img.shape[2] == self.channels
expected_shape = f"HxWx{self.channels}"
if not valid_channels:
self._mark_bad_image(index, im_file, f"unexpected decoded shape {img.shape}, expected {expected_shape}")
raise Ground3DImageReadError(f"{im_file}: unexpected decoded shape {img.shape}")
if min(shape) < self.min_image_side:
self._mark_bad_image(index, im_file, f"decoded image is too small: {shape}")
raise Ground3DImageReadError(f"{im_file}: decoded image is too small: {shape}")
if self.strict_image_shape and shape != self._ground3d_shape:
self._mark_bad_image(
index,
im_file,
f"decoded shape {shape} does not match expected {self._ground3d_shape}",
)
raise Ground3DImageReadError(f"{im_file}: decoded shape {shape} does not match expected {self._ground3d_shape}")
return img
def _read_ground3d_image(self, index: int, im_file: str) -> np.ndarray:
    """Decode ``im_file`` with OpenCV (``self.cv2_flag``) and validate it.

    Indices already flagged in ``self._bad_image_mask`` fail immediately with ``Ground3DImageReadError``,
    without touching the disk.
    """
    if not self._bad_image_mask[index]:
        decoded = cv2.imread(im_file, self.cv2_flag)
        return self._validate_ground3d_image(index, im_file, decoded)
    raise Ground3DImageReadError(f"{im_file}: previously marked unreadable or invalid")
def _load_transformed_ground3d_sample(self, index: int) -> dict[str, Any]:
    """Build sample ``index``, run ``self.transforms``, then finalize its targets.

    Finalization steps:
    - ``labels_3d`` is converted to a tensor. When the sample has no 3D labels (``None``), it becomes an all-NaN
      ``(n_boxes, 42)`` float32 tensor.
    - When ``self.min_wh > 0``, boxes whose pixel width AND pixel height are both below ``min_wh`` are dropped,
      together with their cls, batch_idx, difficulty and 3D rows.
    - Edge GT tensors are attached by ``self._build_edge_gt``. When profiling is on and the sample carries a
      timing dict, that step's duration is recorded as ``precompute_edge_gt``.
    """
    sample = self.transforms(self.get_image_and_label(index))

    targets_3d = sample.get("labels_3d")
    if targets_3d is None:
        count = len(sample["bboxes"]) if "bboxes" in sample else 0
        # NaN (not zero) marks "no 3D supervision" so 3D losses/metrics skip pure-2D objects.
        sample["labels_3d"] = torch.full((count, 42), float("nan"), dtype=torch.float32)
    elif isinstance(targets_3d, np.ndarray):
        sample["labels_3d"] = torch.from_numpy(targets_3d.copy())

    if self.min_wh > 0 and "bboxes" in sample and len(sample["bboxes"]):
        boxes = sample["bboxes"]
        _, img_h, img_w = sample["img"].shape
        # Boxes are normalized xywh; keep a box if either side reaches min_wh pixels.
        keep = (boxes[:, 2] * img_w >= self.min_wh) | (boxes[:, 3] * img_h >= self.min_wh)
        if not keep.all():
            sample["bboxes"] = boxes[keep]
            sample["cls"] = sample["cls"][keep]
            for key in ("batch_idx", "difficulties", "difficulty_levels"):
                if key in sample:
                    sample[key] = sample[key][keep]
            sample["labels_3d"] = sample["labels_3d"][keep]

    started = perf_counter() if self.profile_timing else None
    self._build_edge_gt(sample)
    profile = sample.get("_profile_timing")
    if started is not None and isinstance(profile, dict):
        profile["precompute_edge_gt"] = perf_counter() - started
    return sample
def get_image_and_label(self, index):
    """Load one Ground3D sample and bring image, 2D/3D labels and calibration to the target resolution.

    Two preprocessing paths exist, selected per sample:

    - "virtual": used when calibration is available and a ``random.random()`` draw falls below
      ``self.virtual_camera_prob`` (only if that probability is in ``(0, 1]``). The image and the packed 48-column
      labels are warped into a simulated camera via ``apply_simul_transform``.
    - "roi": used otherwise. An optional ROI crop (centered horizontally on the vanishing point when
      ``crop_center_mode == "vxvy"``, else on the image center) is followed by a resize to the target size and a
      calibration rebuild. Labels are remapped to the crop and depth-normalized against ``self.virtual_fx``.

    Args:
        index (int): Index into ``self.label_entries``.

    Returns:
        (dict): Result of ``self.update_labels_info`` applied to the assembled sample (image, calibration, camera
            mode, shape metadata, 2D labels and ``labels_3d``; plus ``_profile_timing`` when profiling).

    Raises:
        Ground3DImageReadError: Propagated from ``_read_ground3d_image`` for unreadable or invalid images.
        Ground3DCalibrationError: If no calibration is found but the sample has 3D targets.
    """
    label_entry = self.label_entries[index]
    label_file = self._entry_to_label_file(label_entry)
    im_file = self._entry_to_image_file(label_entry)
    calib_file = self._entry_to_calib_file(label_entry)
    label = {"im_file": im_file}
    # Per-stage durations are only collected when profiling is enabled.
    timings = {} if self.profile_timing else None
    t0 = perf_counter()
    img = self._read_ground3d_image(index, im_file)
    if timings is not None:
        timings["read_image"] = perf_counter() - t0
    ori_h, ori_w = img.shape[:2]
    t0 = perf_counter()
    lb_2d, lb_3d = parse_ground_3d_label_file(
        label_file,
        self.class_map,
        self.difficulty_weights,
        self.face_3d_classes,
        self.complete_3d_classes,
        self.min_wh,
    )
    lb_2d, lb_3d = self._filter_ground3d_labels(lb_2d, lb_3d)
    if timings is not None:
        timings["parse_label"] = perf_counter() - t0
    # Presumably True when lb_3d holds any finite 3D target (helper defined elsewhere in this module).
    has_3d_targets = _ground3d_has_finite_targets(lb_3d)
    t0 = perf_counter()
    raw_calib = read_calib_from_path(im_file, image_root=self.image_root, extra_calib_candidates=[calib_file])
    if timings is not None:
        timings["read_calib"] = perf_counter() - t0
    if raw_calib is None and has_3d_targets:
        # 3D targets are unusable without intrinsics: report the clip-level camera file that was expected.
        expected_calib = Path(calib_file)
        expected_calib = (
            expected_calib if expected_calib.name == "camera4.json" else expected_calib.parent / "L2_calib" / "camera4.json"
        )
        LOGGER.error(f"{self.prefix}Missing Ground3D calibration [{index}] {im_file} (expected {expected_calib})")
        raise Ground3DCalibrationError(
            f"{im_file}: calibration file not found (expected clip-level camera file at {expected_calib})"
        )
    # A list/tuple imgsz is unpacked as (width, height) here.
    target_w, target_h = self.imgsz if isinstance(self.imgsz, (list, tuple)) else (self.imgsz, self.imgsz)
    roi_w, roi_h = self.roi
    # The random draw only happens when calibration exists and the probability is in (0, 1].
    use_virtual_camera = (
        raw_calib is not None and 0 < self.virtual_camera_prob <= 1 and random.random() < self.virtual_camera_prob
    )
    use_virtual_camera_zoom = self.augment or self.virtual_camera_val_zoom
    # Crop only if the ROI is smaller than the image and calibration exists (or there are no 3D targets to keep).
    roi_active = (raw_calib is not None or not has_3d_targets) and (roi_h < ori_h or roi_w < ori_w)
    sample_ori_shape = (ori_h, ori_w)
    vp_x = compute_vanishing_point_x(raw_calib, ori_w)
    vp_y = compute_vanishing_point_y(raw_calib, ori_h)
    crop_center_x = vp_x if self.crop_center_mode == "vxvy" else ori_w / 2
    label["camera_mode"] = "virtual" if use_virtual_camera else "roi"
    if use_virtual_camera:
        # Virtual-camera path: warp image and packed labels into a simulated camera at the target size.
        t0 = perf_counter()
        simul_calib = compute_simul_calib(
            raw_calib,
            (ori_w, ori_h),
            (target_w, target_h),
            crop_center_x,
            vp_y,
            target_fx=self.virtual_fx,
            augment=use_virtual_camera_zoom,
        )
        if timings is not None:
            timings["compute_simul_calib"] = perf_counter() - t0
        labels_48 = pack_labels_to_48(lb_2d, lb_3d)
        t0 = perf_counter()
        img, labels_48 = apply_simul_transform(
            img, labels_48, simul_calib, raw_calib, (target_w, target_h), augment=self.augment
        )
        if timings is not None:
            timings["apply_simul_transform"] = perf_counter() - t0
        lb_2d, lb_3d = unpack_labels_from_48(labels_48)
        final_calib = {
            "fx": simul_calib["fx"],
            "fy": simul_calib["fy"],
            "cx": simul_calib["cx"],
            "cy": simul_calib["cy"],
            "depth_scale": simul_calib["depth_scale"],
        }
        # The warped image already has the target size.
        sample_ori_shape = (target_h, target_w)
    else:
        # ROI path: optional crop, then resize to the target size and rebuild intrinsics for the resized image.
        crop_bounds = (0, 0, ori_w, ori_h)
        roi_crop_time = 0.0
        if roi_active:
            t0 = perf_counter()
            crop_bounds = compute_centered_roi_bounds(ori_w, ori_h, roi_w, roi_h, crop_center_x, vp_y)
            crop_x1, crop_y1, crop_x2, crop_y2 = crop_bounds
            img = img[crop_y1:crop_y2, crop_x1:crop_x2]
            sample_ori_shape = (crop_y2 - crop_y1, crop_x2 - crop_x1)
            roi_crop_time = perf_counter() - t0
        t0 = perf_counter()
        calib_for_resize = adjust_calib_for_roi_crop(raw_calib, ori_w, ori_h, crop_bounds)
        img = _resize_ground3d_image_in_steps(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
        final_calib = build_final_resized_calib(
            calib_for_resize["focal_u"],
            calib_for_resize["focal_v"],
            calib_for_resize["cu"],
            calib_for_resize["cv"],
            calib_for_resize["src_w"],
            calib_for_resize["src_h"],
            target_w,
            target_h,
            self.virtual_fx,
            distort_coeffs=calib_for_resize["distort_coeffs"],
        )
        if timings is not None:
            timings["roi_resize"] = perf_counter() - t0
            if roi_crop_time > 0:
                timings["roi_crop"] = roi_crop_time
        if roi_active:
            t0 = perf_counter()
            lb_2d, lb_3d = remap_labels_to_roi(lb_2d, lb_3d, ori_w, ori_h, crop_bounds)
            if timings is not None:
                timings["remap_labels_to_roi"] = perf_counter() - t0
        t0 = perf_counter()
        # Normalize depth targets against virtual_fx using the focal length of the resized image.
        lb_3d = normalize_roi_depth(lb_3d, final_calib["fx"], self.virtual_fx)
        if timings is not None:
            timings["normalize_roi_depth"] = perf_counter() - t0
    label["ori_shape"] = sample_ori_shape
    label["resized_shape"] = (target_h, target_w)
    # Per-axis (h, w) scale from the pre-resize shape to the target shape.
    label["ratio_pad"] = (target_h / sample_ori_shape[0], target_w / sample_ori_shape[1])
    label["img"] = img
    label["calib"] = final_calib
    if timings is not None:
        timings["total_sample_build"] = sum(timings.values())
        label["_profile_timing"] = timings
    if self.rect:
        label["rect_shape"] = self.batch_shapes[self.batch[index]]
    label.update(lb_2d)
    label["labels_3d"] = lb_3d
    return self.update_labels_info(label)
def build_transforms(self, hyp=None):
    """Compose the Ground3D transform pipeline.

    Only photometric augmentation (``RandomHSV``) is used, and only when both ``self.augment`` and ``hyp`` are
    set. Geometric augmentations (Mosaic, RandomPerspective, MixUp, CutMix, RandomFlip) are excluded because the
    3D labels live in camera space, not image space. No LetterBox is needed because ``get_image_and_label``
    already resizes to the target size. The final ``Format`` step emits normalized xywh boxes plus ``batch_idx``.
    """
    use_aug = bool(self.augment and hyp)
    pipeline = [RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v)] if use_aug else []
    pipeline.append(
        Format(
            bbox_format="xywh",
            normalize=True,
            batch_idx=True,
            bgr=hyp.bgr if use_aug else 0.0,
        )
    )
    return Compose(pipeline)
def update_labels_info(self, label):
    """Run the parent label formatting while carrying ``labels_3d``, ``calib`` and ``camera_mode`` through.

    The three keys are removed before delegating to the parent and re-attached to its result. A key that was
    missing comes back as ``None``.
    """
    extras = {key: label.pop(key, None) for key in ("labels_3d", "calib", "camera_mode")}
    label = super().update_labels_info(label)
    label.update(extras)
    return label
@staticmethod
def _empty_edge_gt(nl: int) -> dict[str, torch.Tensor]:
"""Return empty precomputed edge GT tensors aligned with the object dimension."""
return {
"edge_faces_points_2d": torch.zeros((nl, 4, 5, 2), dtype=torch.float32),
"edge_faces_depths": torch.zeros((nl, 4, 5), dtype=torch.float32),
"edge_faces_valid": torch.zeros((nl, 4), dtype=torch.bool),
"edge_partial_points_2d": torch.zeros((nl, 5, 2), dtype=torch.float32),
"edge_partial_depths": torch.zeros((nl, 5), dtype=torch.float32),
"edge_partial_face_type": torch.full((nl,), -1, dtype=torch.long),
"edge_partial_valid": torch.zeros((nl,), dtype=torch.bool),
}
def _build_edge_gt(self, labels: dict[str, Any]) -> None:
    """Populate ``labels`` in place with per-object edge-GT tensors, computed on CPU.

    The ``edge_*`` entries are always reset to ``self._empty_edge_gt(n)`` placeholders first. Decoding only runs
    when ``self.precompute_edge_gt`` is set, ``labels["labels_3d"]`` is non-empty, and ``labels["calib"]`` is a
    dict.

    For each object:
    - ``decode_cut_partial_side_edge_from_gt`` may return a partial side edge. It is stored in the
      ``edge_partial_*`` slots and reused as the face edge for its ``face_type``.
    - Every other face slot (0-3) is filled from ``decode_visible_face_edge_from_gt``.
    A decoder returning ``None`` leaves its slot at the placeholder value.
    """
    num_objects = int(len(labels.get("labels_3d", [])))
    labels.update(self._empty_edge_gt(num_objects))
    calib = labels.get("calib")
    if not self.precompute_edge_gt or num_objects == 0 or not isinstance(calib, dict):
        return

    def as_numpy(value):
        # Tensors go through .cpu().numpy(); anything else through np.asarray.
        return value.cpu().numpy() if isinstance(value, torch.Tensor) else np.asarray(value)

    _, img_h, img_w = labels["img"].shape
    targets = as_numpy(labels["labels_3d"])
    cls = labels["cls"]
    class_ids = (cls.view(-1).cpu().numpy() if isinstance(cls, torch.Tensor) else np.asarray(cls).reshape(-1)).astype(int)
    boxes = as_numpy(labels["bboxes"])

    # Normalized center-xywh -> pixel xyxy, passed to the decoders as bbox_xyxy.
    half_w = boxes[:, 2] / 2
    half_h = boxes[:, 3] / 2
    xyxy = np.zeros((num_objects, 4), dtype=np.float32)
    xyxy[:, 0] = (boxes[:, 0] - half_w) * img_w
    xyxy[:, 1] = (boxes[:, 1] - half_h) * img_h
    xyxy[:, 2] = (boxes[:, 0] + half_w) * img_w
    xyxy[:, 3] = (boxes[:, 1] + half_h) * img_h

    # Positional arguments shared by both decoders after (target, cls_id).
    decode_args = (calib, int(img_w), int(img_h), self.face_3d_classes, self.complete_3d_classes)
    for i in range(num_objects):
        target_42 = targets[i]
        cls_id = int(class_ids[i])
        bbox_xyxy = xyxy[i]
        partial = decode_cut_partial_side_edge_from_gt(target_42, cls_id, *decode_args, bbox_xyxy=bbox_xyxy)
        partial_face = None
        if partial is not None:
            labels["edge_partial_points_2d"][i] = torch.from_numpy(partial["points_2d"])
            labels["edge_partial_depths"][i] = torch.from_numpy(partial["depths"])
            partial_face = int(partial["face_type"])
            labels["edge_partial_face_type"][i] = partial_face
            labels["edge_partial_valid"][i] = True
        for face_type in range(4):
            if partial_face == face_type:
                edge = partial
            else:
                edge = decode_visible_face_edge_from_gt(
                    target_42,
                    cls_id,
                    *decode_args,
                    face_type=face_type,
                    score_thr=self.face_visibility_score_thresh,
                    bbox_xyxy=bbox_xyxy,
                )
            if edge is None:
                continue
            labels["edge_faces_points_2d"][i, face_type] = torch.from_numpy(edge["points_2d"])
            labels["edge_faces_depths"][i, face_type] = torch.from_numpy(edge["depths"])
            labels["edge_faces_valid"][i, face_type] = True
def __getitem__(self, index):
    """Return the transformed sample at ``index``.

    When ``self.skip_bad_images`` is set, a ``Ground3DImageReadError`` makes the loader try the following indices
    in order, wrapping around, until one loads. If all ``self.ni`` samples fail, a ``RuntimeError`` chained to
    the last read error is raised. Other exceptions propagate unchanged.
    """
    first = int(index)
    if not self.skip_bad_images:
        return self._load_transformed_ground3d_sample(first)
    failure = None
    for step in range(self.ni):
        try:
            return self._load_transformed_ground3d_sample((first + step) % self.ni)
        except Ground3DImageReadError as err:
            failure = err
    raise RuntimeError(
        f"{self.prefix}Failed to load a valid Ground3D image after checking all {self.ni} samples."
    ) from failure
class YOLOMultiModalDataset(YOLODataset):
    """Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.

    This class extends YOLODataset to add text information for multi-modal model training, enabling models to process
    both image and text data.

    Methods:
        update_labels_info: Add text information for multi-modal model training.
        build_transforms: Enhance data transformations with text augmentation.

    Examples:
        >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
        >>> batch = next(iter(dataset))
        >>> print(batch.keys())  # Should include 'texts'
    """

    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
        """Initialize a YOLOMultiModalDataset.

        Args:
            data (dict, optional): Dataset configuration dictionary.
            task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label: dict) -> dict:
        """Add text information for multi-modal model training.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances and texts.
        """
        labels = super().update_labels_info(label)
        # NOTE: some categories are concatenated with their synonyms by `/`;
        # `RandomLoadText` randomly selects one of them when there are multiple words.
        labels["texts"] = [v.split("/") for v in self.data["names"].values()]
        return labels

    def build_transforms(self, hyp: dict | None = None) -> Compose:
        """Enhance data transformations with optional text augmentation for multi-modal training.

        Args:
            hyp (dict, optional): Hyperparameters for transforms.

        Returns:
            (Compose): Composed transforms including text augmentation if applicable.
        """
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now.
            # NOTE: this implementation is different from official yoloe,
            # the strategy of selecting negative is restricted in one dataset,
            # while official pre-saved neg embeddings from all datasets at once.
            transform = RandomLoadText(
                max_samples=min(self.data["nc"], 80),
                padding=True,
                padding_value=self._get_neg_texts(self.category_freq),
            )
            transforms.insert(-1, transform)
        return transforms

    @property
    def category_names(self):
        """Return category names for the dataset.

        Returns:
            (set[str]): Set of class names, with `/`-separated synonyms split and stripped.
        """
        names = self.data["names"].values()
        return {n.strip() for name in names for n in name.split("/")}  # category names

    @property
    def category_freq(self):
        """Return how many label instances each (synonym-expanded, stripped) category name has in the dataset."""
        texts = [v.split("/") for v in self.data["names"].values()]
        category_freq = defaultdict(int)
        for label in self.labels:
            # reshape(-1) (not squeeze(-1)) so 1-D or (n, 1) class arrays both iterate element-wise.
            for c in label["cls"].reshape(-1):
                for t in texts[int(c)]:
                    category_freq[t.strip()] += 1
        return category_freq

    @staticmethod
    def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
        """Get negative text samples based on a frequency threshold.

        Args:
            category_freq (dict): Mapping from category name to instance count.
            threshold (int): Frequency cut-off, capped at the largest observed count so that at least the most
                frequent category always qualifies.

        Returns:
            (list[str]): Category names whose count is at least the effective threshold; empty for an empty map.
        """
        if not category_freq:
            return []
        threshold = min(max(category_freq.values()), threshold)
        return [k for k, v in category_freq.items() if v >= threshold]
class GroundingDataset(YOLODataset):
    """Dataset class for object detection tasks using annotations from a JSON file in grounding format.

    This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than the standard
    YOLO format text files.

    Attributes:
        json_file (str): Path to the JSON file containing annotations.

    Methods:
        get_img_files: Return empty list as image files are read in get_labels.
        get_labels: Load annotations from a JSON file and prepare them for training.
        build_transforms: Configure augmentations for training with optional text loading.

    Examples:
        >>> dataset = GroundingDataset(img_path="path/to/images", json_file="annotations.json", task="detect")
        >>> len(dataset)  # Number of valid images with annotations
    """

    def __init__(self, *args, task: str = "detect", json_file: str = "", max_samples: int = 80, **kwargs):
        """Initialize a GroundingDataset for object detection.

        Args:
            json_file (str): Path to the JSON file containing annotations.
            task (str): Must be 'detect' or 'segment' for GroundingDataset.
            max_samples (int): Maximum number of samples to load for text augmentation.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
        self.json_file = json_file
        self.max_samples = max_samples
        super().__init__(*args, task=task, data={"channels": 3}, **kwargs)

    def get_img_files(self, img_path: str) -> list:
        """The image files would be read in `get_labels` function, return empty list here.

        Args:
            img_path (str): Path to the directory containing images.

        Returns:
            (list): Empty list as image files are read in get_labels.
        """
        return []

    def verify_labels(self, labels: list[dict[str, Any]]) -> None:
        """Verify the number of instances in the dataset matches expected counts.

        This method checks if the total number of bounding box instances in the provided labels matches the expected
        count for known datasets. It performs validation against a predefined set of datasets with known instance
        counts.

        Args:
            labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary contains dataset
                annotations. Each label dict must have a 'bboxes' key with a numpy array or tensor containing bounding
                box coordinates.

        Raises:
            AssertionError: If the actual instance count doesn't match the expected count for a recognized dataset.

        Notes:
            For unrecognized datasets (those not in the predefined expected_counts),
            a warning is logged and verification is skipped.
        """
        expected_counts = {
            "final_mixed_train_no_coco_segm": 3662412,
            "final_mixed_train_no_coco": 3681235,
            "final_flickr_separateGT_train_segm": 638214,
            "final_flickr_separateGT_train": 640704,
        }
        instance_count = sum(label["bboxes"].shape[0] for label in labels)
        for data_name, count in expected_counts.items():
            if data_name in self.json_file:
                assert instance_count == count, f"'{self.json_file}' has {instance_count} instances, expected {count}."
                return
        LOGGER.warning(f"Skipping instance count verification for unrecognized dataset '{self.json_file}'")

    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict[str, Any]:
        """Load annotations from a JSON file, filter, and normalize bounding boxes for each image.

        Crowd annotations, non-positive-size boxes and annotations with an empty caption span are skipped.
        Duplicate (class, box) pairs within an image are kept once, and a segment is only collected for a kept
        box, so segments stay aligned with boxes. When any segment exists, the boxes are recomputed from segments.

        Args:
            path (Path): Path where to save the cache file.

        Returns:
            (dict[str, Any]): Dictionary containing cached labels and related information.
        """
        x = {"labels": []}
        LOGGER.info("Loading annotation file...")
        with open(self.json_file, encoding="utf-8") as f:  # JSON text is UTF-8 (RFC 8259)
            annotations = json.load(f)
        images = {f"{im['id']:d}": im for im in annotations["images"]}
        img_to_anns = defaultdict(list)
        for ann in annotations["annotations"]:
            img_to_anns[ann["image_id"]].append(ann)
        for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
            img = images[f"{img_id:d}"]
            h, w, file_name = img["height"], img["width"], img["file_name"]
            im_file = Path(self.img_path) / file_name
            if not im_file.exists():
                continue
            self.im_files.append(str(im_file))
            bboxes = []
            segments = []
            cat2id = {}
            texts = []
            wh = np.array([w, h], dtype=np.float32)
            for ann in anns:
                if ann["iscrowd"]:
                    continue
                box = np.array(ann["bbox"], dtype=np.float32)
                box[:2] += box[2:] / 2  # top-left xywh -> center xywh
                box[[0, 2]] /= float(w)
                box[[1, 3]] /= float(h)
                if box[2] <= 0 or box[3] <= 0:
                    continue

                caption = img["caption"]
                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]).lower().strip()
                if not cat_name:
                    continue

                if cat_name not in cat2id:
                    cat2id[cat_name] = len(cat2id)
                    texts.append([cat_name])
                cls = cat2id[cat_name]  # class
                box = [cls, *box.tolist()]
                if box in bboxes:
                    continue  # duplicate: skip its segment too so segments stay aligned with bboxes
                bboxes.append(box)

                segmentation = ann.get("segmentation")
                if segmentation is None:
                    continue
                if len(segmentation) == 0:
                    # No polygon: use the box corners so segments2boxes() recovers the same box
                    # (the raw xywh values are not polygon points).
                    cx, cy, bw, bh = box[1:]
                    x1, y1, x2, y2 = cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2
                    s = [x1, y1, x2, y1, x2, y2, x1, y2]
                elif len(segmentation) > 1:
                    s = merge_multi_segment(segmentation)
                    s = (np.concatenate(s, axis=0) / wh).reshape(-1).tolist()
                else:
                    s = [j for i in segmentation for j in i]  # all segments concatenated
                    s = (np.array(s, dtype=np.float32).reshape(-1, 2) / wh).reshape(-1).tolist()
                segments.append([cls, *s])

            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
            if segments:
                classes = np.array([seg[0] for seg in segments], dtype=np.float32)
                segments = [np.array(seg[1:], dtype=np.float32).reshape(-1, 2) for seg in segments]  # (cls, xy1...)
                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
            lb = np.array(lb, dtype=np.float32)

            x["labels"].append(
                {
                    "im_file": im_file,
                    "shape": (h, w),
                    "cls": lb[:, 0:1],  # n, 1
                    "bboxes": lb[:, 1:],  # n, 4
                    "segments": segments,
                    "normalized": True,
                    "bbox_format": "xywh",
                    "texts": texts,
                }
            )
        x["hash"] = get_hash(self.json_file)
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self) -> list[dict]:
        """Load labels from cache or generate them from JSON file.

        The cache is rebuilt when it is missing, unreadable, lacks version/hash entries, or its version or JSON hash
        does not match. Validation uses explicit checks (not ``assert``) so it still runs under ``python -O``.

        Returns:
            (list[dict]): List of label dictionaries, each containing information about an image and its annotations.
        """
        cache_path = Path(self.json_file).with_suffix(".cache")
        try:
            cache = load_dataset_cache_file(cache_path)  # attempt to load a *.cache file
            cache_ok = cache["version"] == DATASET_CACHE_VERSION and cache["hash"] == get_hash(self.json_file)
        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError, KeyError):
            cache_ok = False
        if not cache_ok:
            cache = self.cache_labels(cache_path)  # run cache ops
        for k in ("hash", "version"):
            cache.pop(k, None)  # remove metadata items
        labels = cache["labels"]
        self.verify_labels(labels)
        self.im_files = [str(label["im_file"]) for label in labels]
        if LOCAL_RANK in {-1, 0}:
            LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
        return labels

    def build_transforms(self, hyp: dict | None = None) -> Compose:
        """Configure augmentations for training with optional text loading.

        Args:
            hyp (dict, optional): Hyperparameters for transforms.

        Returns:
            (Compose): Composed transforms including text augmentation if applicable.
        """
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now.
            # NOTE: this implementation is different from official yoloe,
            # the strategy of selecting negative is restricted in one dataset,
            # while official pre-saved neg embeddings from all datasets at once.
            transform = RandomLoadText(
                max_samples=min(self.max_samples, 80),
                padding=True,
                padding_value=self._get_neg_texts(self.category_freq),
            )
            transforms.insert(-1, transform)
        return transforms

    @property
    def category_names(self):
        """Return unique category names from the dataset."""
        return {t.strip() for label in self.labels for text in label["texts"] for t in text}

    @property
    def category_freq(self):
        """Return, for each stripped category name, the number of images whose text list contains it."""
        category_freq = defaultdict(int)
        for label in self.labels:
            for text in label["texts"]:
                for t in text:
                    category_freq[t.strip()] += 1
        return category_freq

    @staticmethod
    def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
        """Get negative text samples based on a frequency threshold.

        Args:
            category_freq (dict): Mapping from category name to count.
            threshold (int): Frequency cut-off, capped at the largest observed count so that at least the most
                frequent category always qualifies.

        Returns:
            (list[str]): Category names whose count is at least the effective threshold; empty for an empty map.
        """
        if not category_freq:
            return []
        threshold = min(max(category_freq.values()), threshold)
        return [k for k, v in category_freq.items() if v >= threshold]
class YOLOConcatDataset(ConcatDataset):
    """Concatenation of several YOLO-style datasets that share ``YOLODataset``'s collation.

    Indexing and length come from ``torch.utils.data.ConcatDataset``. This subclass adds a YOLO ``collate_fn``
    and forwards ``close_mosaic`` to the member datasets that support it.

    Methods:
        collate_fn: Static method that collates data samples into batches using YOLODataset's collation function.

    Examples:
        >>> dataset1 = YOLODataset(...)
        >>> dataset2 = YOLODataset(...)
        >>> combined_dataset = YOLOConcatDataset([dataset1, dataset2])
    """

    @staticmethod
    def collate_fn(batch: list[dict]) -> dict:
        """Collate samples into a batch by delegating to ``YOLODataset.collate_fn``.

        Args:
            batch (list[dict]): List of dictionaries containing sample data.

        Returns:
            (dict): Collated batch.
        """
        collate = YOLODataset.collate_fn
        return collate(batch)

    def close_mosaic(self, hyp: dict) -> None:
        """Call ``close_mosaic(hyp)`` on every member dataset that defines it; other members are left alone.

        Args:
            hyp (dict): Hyperparameters passed through to each member's ``close_mosaic``.
        """
        for member in (ds for ds in self.datasets if hasattr(ds, "close_mosaic")):
            member.close_mosaic(hyp)
# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """Placeholder for a future semantic segmentation dataset (not implemented)."""

    def __init__(self):
        """Create the placeholder by calling the ``BaseDataset`` initializer with no arguments."""
        super(SemanticDataset, self).__init__()
class ClassificationDataset:
    """Dataset class for image classification tasks wrapping torchvision ImageFolder functionality.

    This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently
    handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
    to speed up training.

    Attributes:
        cache_ram (bool): Indicates if caching in RAM is enabled. It is always reset to False during initialization
            because of a known memory leak (a warning is logged when RAM caching was requested).
        cache_disk (bool): Indicates if caching on disk is enabled.
        samples (list): A list of lists, each ``[image_path, class_index, npy_path, image]``. ``npy_path`` is always
            the image path with a ``.npy`` suffix (only used when caching on disk) and ``image`` starts as ``None``
            (only filled when caching in RAM).
        torch_transforms (callable): PyTorch transforms to be applied to the images.
        root (str): Root directory of the dataset.
        prefix (str): Prefix for logging and cache filenames.

    Methods:
        __getitem__: Return transformed image and class index for the given sample index.
        __len__: Return the total number of samples in the dataset.
        verify_images: Verify all images in dataset.
    """

    def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
        """Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.

        Args:
            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
                parameters, and cache settings.
            augment (bool, optional): Whether to apply augmentations to the dataset.
            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification.
        """
        import torchvision  # scope for faster 'import ultralytics'

        # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
        if TORCHVISION_0_18:  # 'allow_empty' argument first introduced in torchvision 0.18
            self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
        else:
            self.base = torchvision.datasets.ImageFolder(root=root)
        self.samples = self.base.samples
        self.root = self.base.root

        # Initialize attributes
        if augment and args.fraction < 1.0:  # reduce training fraction
            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
        if self.cache_ram:
            LOGGER.warning(
                "Classification `cache_ram` training has known memory leak in "
                "https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`."
            )
            self.cache_ram = False
        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
        self.samples = self.verify_images()  # filter out bad images
        # Expand each (file, index) pair into a mutable [file, index, npy_path, image] record
        self.samples = [[*x, Path(x[0]).with_suffix(".npy"), None] for x in self.samples]
        scale = (1.0 - args.scale, 1.0)  # lower crop-scale bound derived from args.scale
        self.torch_transforms = (
            classify_augmentations(
                size=args.imgsz,
                scale=scale,
                hflip=args.fliplr,
                vflip=args.flipud,
                erasing=args.erasing,
                auto_augment=args.auto_augment,
                hsv_h=args.hsv_h,
                hsv_s=args.hsv_s,
                hsv_v=args.hsv_v,
            )
            if augment
            else classify_transforms(size=args.imgsz)
        )

    def __getitem__(self, i: int) -> dict:
        """Return transformed image and class index for the given sample index.

        Args:
            i (int): Index of the sample to retrieve.

        Returns:
            (dict): Dictionary with key ``"img"`` (the transformed image) and ``"cls"`` (its class index).
        """
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self) -> list[tuple]:
        """Verify all images in the dataset, reusing a ``*.cache`` file when it is valid.

        The cache lives at ``Path(root).with_suffix(".cache")``. It is reused only when its version equals
        ``DATASET_CACHE_VERSION`` and its hash matches the current image file list. Otherwise (missing, stale, or
        malformed cache) every image is re-verified in a thread pool and a fresh cache file is written.

        Returns:
            (list[tuple]): List of valid samples after verification.
        """
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")  # *.cache file path
        file_hash = get_hash([sample[0] for sample in self.samples])  # hash of the current image file list

        try:
            check_file_speeds([file for (file, _) in self.samples[:5]], prefix=self.prefix)  # check image read speeds
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == file_hash  # identical hash
            # Layout written by the scan below: (found, corrupt, number of valid samples, valid samples)
            nf, nc, n, samples = cache.pop("results")
            if LOCAL_RANK in {-1, 0}:
                d = f"{desc} {nf} images, {nc} corrupt"
                TQDM(None, desc=d, total=n, initial=n)
                if cache["msgs"]:
                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
            return samples
        except (FileNotFoundError, AssertionError, AttributeError, KeyError):
            # Run scan if *.cache retrieval failed; KeyError covers cache dicts missing expected entries
            nf, nc, msgs, samples = 0, 0, [], []
            with ThreadPool(NUM_THREADS) as pool:
                results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
                pbar = TQDM(results, desc=desc, total=len(self.samples))
                for sample, nf_f, nc_f, msg in pbar:
                    if nf_f:
                        samples.append(sample)
                    if msg:
                        msgs.append(msg)
                    nf += nf_f
                    nc += nc_f
                    pbar.desc = f"{desc} {nf} images, {nc} corrupt"
                pbar.close()
            if msgs:
                LOGGER.info("\n".join(msgs))
            new_cache = {
                "hash": file_hash,
                "results": (nf, nc, len(samples), samples),
                "msgs": msgs,  # warnings
            }
            save_dataset_cache_file(self.prefix, path, new_cache, DATASET_CACHE_VERSION)
            return samples