# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license from __future__ import annotations import hashlib import glob import json import math import os import random from collections import defaultdict from copy import deepcopy from itertools import repeat from multiprocessing.pool import ThreadPool from pathlib import Path from time import perf_counter from typing import Any import cv2 import numpy as np import torch from PIL import Image from torch.utils.data import ConcatDataset from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM, colorstr from ultralytics.utils.instance import Instances from ultralytics.utils.ops import resample_segments, segments2boxes from ultralytics.utils.plotting_3d import decode_cut_partial_side_edge_from_gt, decode_visible_face_edge_from_gt from ultralytics.utils.torch_utils import TORCHVISION_0_18 from .augment import ( Compose, Format, LetterBox, RandomHSV, RandomLoadText, classify_augmentations, classify_transforms, v8_transforms, ) from .base import BaseDataset from .converter import merge_multi_segment from .ground3d_augment import ( adjust_calib_for_roi_crop, apply_simul_transform, build_final_resized_calib, compute_centered_roi_bounds, compute_simul_calib, compute_vanishing_point_x, compute_vanishing_point_y, normalize_roi_depth, pack_labels_to_48, parse_ground_3d_label_file, read_calib_from_path, remap_labels_to_roi, unpack_labels_from_48, ) from .utils import ( FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS, check_file_speeds, get_hash, img2label_paths, load_dataset_cache_file, save_dataset_cache_file, verify_image, verify_image_label, ) # Ultralytics dataset *.cache version, >= 1.0.0 for Ultralytics YOLO models DATASET_CACHE_VERSION = "1.0.3" GROUND3D_REQUIRED_DATA_KEYS = ("path", "ori_img_size", "roi", "virtual_fx", "virtual_camera_prob", "crop_center_mode") GROUND3D_VALID_CROP_CENTER_MODES = {"cxvy", "vxvy"} class Ground3DImageReadError(RuntimeError): """Raised when a Ground3D sample image is unreadable or fails validation.""" class Ground3DCalibrationError(RuntimeError): """Raised when a Ground3D sample is missing required calibration.""" def _resize_ground3d_image_in_steps( img: np.ndarray, target_size: tuple[int, int], interpolation: int = cv2.INTER_LINEAR ) -> np.ndarray: """Resize a Ground3D image with repeated 0.5x downsampling before the final resize.""" target_w, target_h = int(target_size[0]), int(target_size[1]) current_h, current_w = img.shape[:2] if (current_w, current_h) == (target_w, target_h): return img if target_w < current_w and target_h < current_h: while True: next_w = math.ceil(current_w * 0.5) next_h = math.ceil(current_h * 0.5) if next_w < target_w or next_h < target_h: break img = cv2.resize(img, (next_w, next_h), interpolation=interpolation) current_h, current_w = img.shape[:2] if (current_w, current_h) == (target_w, target_h): return img return cv2.resize(img, (target_w, target_h), interpolation=interpolation) def _require_ground3d_data_value(data: dict[str, Any], key: str) -> Any: """Return a required Ground3D config value or raise a clear error.""" if key not in data or data[key] is None: raise ValueError(f"Ground3D dataset config must define '{key}' explicitly; defaults are not allowed.") return data[key] def _validate_ground3d_pair(value: Any, key: str) -> tuple[int, int]: """Validate a Ground3D size-like config field and return it as an integer pair.""" if not isinstance(value, (list, tuple)) or len(value) != 2: raise ValueError(f"Ground3D dataset config field '{key}' must be a 2-item list/tuple, but got: {value!r}") pair = (int(value[0]), int(value[1])) if pair[0] <= 0 or pair[1] <= 0: raise ValueError(f"Ground3D dataset config field '{key}' must contain positive values, but got: {value!r}") return pair def _validate_ground3d_crop_center_mode(value: Any) -> str: """Validate the Ground3D crop-center mode.""" if value not in GROUND3D_VALID_CROP_CENTER_MODES: valid_modes = ", ".join(sorted(GROUND3D_VALID_CROP_CENTER_MODES)) raise ValueError( f"Ground3D dataset config field 'crop_center_mode' must be one of {{{valid_modes}}}, but got: {value!r}" ) return str(value) def _validate_ground3d_positive_float(value: Any, key: str) -> float: """Validate a positive numeric Ground3D config field.""" value = float(value) if value <= 0: raise ValueError(f"Ground3D dataset config field '{key}' must be positive, but got: {value!r}") return value def _validate_ground3d_float(value: Any, key: str) -> float: """Validate a numeric Ground3D config field.""" try: return float(value) except (TypeError, ValueError) as e: raise ValueError(f"Ground3D dataset config field '{key}' must be numeric, but got: {value!r}") from e def _ground3d_has_finite_targets(labels_3d: np.ndarray | None) -> bool: """Return whether a parsed 3D label array contains any real 3D supervision.""" if labels_3d is None: return False labels_3d = np.asarray(labels_3d) return labels_3d.size > 0 and np.isfinite(labels_3d).any() class YOLODataset(BaseDataset): """Dataset class for loading object detection and/or segmentation labels in YOLO format. This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box (OBB) tasks using the YOLO format. Attributes: use_segments (bool): Indicates if segmentation masks should be used. use_keypoints (bool): Indicates if keypoints should be used for pose estimation. use_obb (bool): Indicates if oriented bounding boxes should be used. data (dict): Dataset configuration dictionary. Methods: cache_labels: Cache dataset labels, check images and read shapes. get_labels: Return list of label dictionaries for YOLO training. build_transforms: Build and append transforms to the list. close_mosaic: Disable mosaic, copy_paste, mixup and cutmix augmentations and build transformations. update_labels_info: Update label format for different tasks. collate_fn: Collate data samples into batches. Examples: >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect") >>> dataset.get_labels() """ def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs): """Initialize the YOLODataset. Args: data (dict, optional): Dataset configuration dictionary. task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'. *args (Any): Additional positional arguments for the parent class. **kwargs (Any): Additional keyword arguments for the parent class. """ self.use_segments = task == "segment" self.use_keypoints = task == "pose" self.use_obb = task == "obb" self.data = data assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." super().__init__(*args, channels=self.data.get("channels", 3), use_yuv444=self.data.get("use_yuv444", False), **kwargs) def cache_labels(self, path: Path = Path("./labels.cache")) -> dict: """Cache dataset labels, check images and read shapes. Args: path (Path): Path where to save the cache file. Returns: (dict): Dictionary containing cached labels and related information. """ x = {"labels": []} nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages desc = f"{self.prefix}Scanning {path.parent / path.stem}..." total = len(self.im_files) nkpt, ndim = self.data.get("kpt_shape", (0, 0)) if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}): raise ValueError( "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'" ) with ThreadPool(NUM_THREADS) as pool: results = pool.imap( func=verify_image_label, iterable=zip( self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints), repeat(len(self.data["names"])), repeat(nkpt), repeat(ndim), repeat(self.single_cls), ), ) pbar = TQDM(results, desc=desc, total=total) for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar: nm += nm_f nf += nf_f ne += ne_f nc += nc_f if im_file: x["labels"].append( { "im_file": im_file, "shape": shape, "cls": lb[:, 0:1], # n, 1 "bboxes": lb[:, 1:], # n, 4 "segments": segments, "keypoints": keypoint, "normalized": True, "bbox_format": "xywh", } ) if msg: msgs.append(msg) pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" pbar.close() if msgs: LOGGER.info("\n".join(msgs)) if nf == 0: LOGGER.warning(f"{self.prefix}No labels found in {path}. {HELP_URL}") x["hash"] = get_hash(self.label_files + self.im_files) x["results"] = nf, nm, ne, nc, len(self.im_files) x["msgs"] = msgs # warnings save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) return x def get_labels(self) -> list[dict]: """Return list of label dictionaries for YOLO training. This method loads labels from disk or cache, verifies their integrity, and prepares them for training. Returns: (list[dict]): List of label dictionaries, each containing information about an image and its annotations. """ self.label_files = img2label_paths(self.im_files) cache_path = Path(self.label_files[0]).parent.with_suffix(".cache") try: cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file assert cache["version"] == DATASET_CACHE_VERSION # matches current version assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError): cache, exists = self.cache_labels(cache_path), False # run cache ops # Display cache nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total if exists and LOCAL_RANK in {-1, 0}: d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results if cache["msgs"]: LOGGER.info("\n".join(cache["msgs"])) # display warnings # Read cache [cache.pop(k) for k in ("hash", "version", "msgs")] # remove items labels = cache["labels"] if not labels: raise RuntimeError( f"No valid images found in {cache_path}. Images with incorrectly formatted labels are ignored. {HELP_URL}" ) self.im_files = [lb["im_file"] for lb in labels] # update im_files # Check if the dataset is all boxes or all segments lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels) len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths)) if len_segments and len_boxes != len_segments: LOGGER.warning( f"Box and segment counts should be equal, but got len(segments) = {len_segments}, " f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. " "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset." ) for lb in labels: lb["segments"] = [] if len_cls == 0: LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}") return labels def build_transforms(self, hyp: dict | None = None) -> Compose: """Build and append transforms to the list. Args: hyp (dict, optional): Hyperparameters for transforms. Returns: (Compose): Composed transforms. """ if self.augment: hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0 hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0 hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0 transforms = v8_transforms(self, self.imgsz, hyp) else: transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)]) transforms.append( Format( bbox_format="xywh", normalize=True, return_mask=self.use_segments, return_keypoint=self.use_keypoints, return_obb=self.use_obb, batch_idx=True, mask_ratio=hyp.mask_ratio, mask_overlap=hyp.overlap_mask, bgr=hyp.bgr if self.augment else 0.0, # only affect training. ) ) return transforms def close_mosaic(self, hyp: dict) -> None: """Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0. Args: hyp (dict): Hyperparameters for transforms. """ hyp.mosaic = 0.0 hyp.copy_paste = 0.0 hyp.mixup = 0.0 hyp.cutmix = 0.0 self.transforms = self.build_transforms(hyp) def update_labels_info(self, label: dict) -> dict: """Update label format for different tasks. Args: label (dict): Label dictionary containing bboxes, segments, keypoints, etc. Returns: (dict): Updated label dictionary with instances. Notes: cls is not with bboxes now, classification and semantic segmentation need an independent cls label Can also support classification and semantic segmentation by adding or removing dict keys there. """ bboxes = label.pop("bboxes") segments = label.pop("segments", []) keypoints = label.pop("keypoints", None) bbox_format = label.pop("bbox_format") normalized = label.pop("normalized") # NOTE: do NOT resample oriented boxes segment_resamples = 100 if self.use_obb else 1000 if len(segments) > 0: # make sure segments interpolate correctly if original length is greater than segment_resamples max_len = max(len(s) for s in segments) segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples # list[np.array(segment_resamples, 2)] * num_samples segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0) else: segments = np.zeros((0, segment_resamples, 2), dtype=np.float32) label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized) return label @staticmethod def collate_fn(batch: list[dict]) -> dict: """Collate data samples into batches. Args: batch (list[dict]): List of dictionaries containing sample data. Returns: (dict): Collated batch with stacked tensors. """ new_batch = {} batch = [dict(sorted(b.items())) for b in batch] # make sure the keys are in the same order keys = batch[0].keys() values = list(zip(*[list(b.values()) for b in batch])) for i, k in enumerate(keys): value = values[i] if k in {"img", "text_feats", "sem_masks"}: value = torch.stack(value, 0) elif k == "visuals": value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True) elif k == "camera_mode": value = tuple(value) if k in { "masks", "keypoints", "bboxes", "cls", "segments", "obb", "difficulties", "difficulty_levels", "labels_3d", "edge_faces_points_2d", "edge_faces_depths", "edge_faces_valid", "edge_partial_points_2d", "edge_partial_depths", "edge_partial_face_type", "edge_partial_valid", }: value = torch.cat(value, 0) new_batch[k] = value new_batch["batch_idx"] = list(new_batch["batch_idx"]) for i in range(len(new_batch["batch_idx"])): new_batch["batch_idx"][i] += i # add target image index for build_targets() new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0) return new_batch class YOLOGroundDataset(YOLODataset): """Dataset class for loading ground 2D detection labels with difficulty scores. This class extends YOLODataset to support custom annotation format with: - String class names mapped to numeric IDs via class_map - Difficulty scores for each bounding box - Optional minimum box size filtering - Optional YUV444 color space support Attributes: class_map (dict): Mapping from class names to class IDs. min_wh (float): Minimum box width/height in pixels for filtering. use_yuv444 (bool): Whether to use YUV444 color space. Examples: >>> data = {"class_map": {"car": 0, "pedestrian": 1}, "min_wh": 2.0} >>> dataset = YOLOGroundDataset(img_path="path/to/images", data=data, task="detect") """ def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs): """Initialize the YOLOGroundDataset with class mapping and difficulty support. Args: data (dict, optional): Dataset configuration with 'class_map' and optional 'min_wh', 'use_yuv444'. task (str): Task type, should be 'detect' for ground dataset. *args (Any): Additional positional arguments for the parent class. **kwargs (Any): Additional keyword arguments for the parent class. """ self.class_map = data.get("class_map", {}) self.min_wh = data.get("min_wh", 2.0) self.use_yuv444 = data.get("use_yuv444", False) self.difficulty_weights = data.get("difficulty_weights", [1.0]) super().__init__(*args, data=data, task=task, **kwargs) def get_img_files(self, img_path: str | list[str]) -> list[str]: """Read image files while preserving source paths for label lookup. Ground 2D train/val list files may live under one ``Detection2D_*`` root while the corresponding images should be loaded from the YAML ``path`` root. The original list paths are retained for label lookup. """ try: files = [] for p in img_path if isinstance(img_path, list) else [img_path]: p = Path(p) if p.is_dir(): files += glob.glob(str(p / "**" / "*.*"), recursive=True) elif p.is_file(): with open(p, encoding="utf-8") as t: lines = t.read().strip().splitlines() parent = str(p.parent) + os.sep files += [x.replace("./", parent) if x.startswith("./") else x for x in lines] else: raise FileNotFoundError(f"{self.prefix}{p} does not exist") entries = sorted(x.replace("/", os.sep) for x in files if x.strip()) assert entries, f"{self.prefix}No entries found in {img_path}." except Exception as e: raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e if self.fraction < 1: entries = entries[: round(len(entries) * self.fraction)] self.ground_label_files = [] im_files = [] for f in entries: suffix = Path(f).suffix.lower().lstrip(".") if suffix in IMG_FORMATS: self.ground_label_files.append(img2label_paths([f])[0]) im_files.append(self._ground_image_path_from_label_path(f)) elif suffix == "txt": self.ground_label_files.append(f) im_files.append(self._ground_image_path_from_label_path(f, label_path=True)) else: raise FileNotFoundError(f"{self.prefix}Unsupported ground list entry: {f}. {FORMATS_HELP_MSG}") assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}" check_file_speeds(im_files, prefix=self.prefix) return im_files def _ground_image_path_from_label_path(self, im_file: str, label_path: bool = False) -> str: """Map a list-file path from a ``Detection2D_*``/``2Ddetection_*`` root onto the configured image root.""" image_root = Path(self.data["path"]) parts = Path(im_file).parts detection_root_index = next( (i for i, part in enumerate(parts) if part.lower().startswith(("detection2d", "2ddetection"))), None, ) if detection_root_index is not None and detection_root_index + 1 < len(parts): rel_parts = list(parts[detection_root_index + 1 :]) else: rel_path = Path(im_file) if not Path(im_file).is_absolute() else Path(*parts[1:]) rel_parts = list(rel_path.parts) if label_path: image_path = image_root.joinpath(*rel_parts).with_suffix(".jpg") if image_path.exists(): return str(image_path) image_rel_parts = ["images" if part == "labels" else part for part in rel_parts] image_path = image_root.joinpath(*image_rel_parts).with_suffix(".jpg") if image_path.exists(): return str(image_path) for suffix in ("png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg"): candidate = image_path.with_suffix(f".{suffix}") if candidate.exists(): return str(candidate) return str(image_path) return str(image_root.joinpath(*rel_parts)) def cache_labels(self, path: Path = Path("./labels.cache")) -> dict: """Cache dataset labels with difficulty scores, check images and read shapes. Args: path (Path): Path where to save the cache file. Returns: (dict): Dictionary containing cached labels and related information. """ from ultralytics.data.utils import verify_image_label_ground x = {"labels": []} nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages desc = f"{self.prefix}Scanning {path.parent / path.stem}..." total = len(self.im_files) with ThreadPool(NUM_THREADS) as pool: results = pool.imap( func=verify_image_label_ground, iterable=zip( self.im_files, self.label_files, repeat(self.prefix), repeat(self.class_map), ), ) pbar = TQDM(results, desc=desc, total=total) for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar: nm += nm_f nf += nf_f ne += ne_f nc += nc_f if im_file: # Map raw difficulty values to loss weights using difficulty_weights lookup dw = self.difficulty_weights if len(lb) > 0: raw_diff = lb[:, 5].astype(int).clip(0, len(dw) - 1) loss_weights = np.array([dw[d] for d in raw_diff], dtype=np.float32).reshape(-1, 1) difficulty_levels = raw_diff.astype(np.int64).reshape(-1, 1) else: loss_weights = np.zeros((0, 1), dtype=np.float32) difficulty_levels = np.zeros((0, 1), dtype=np.int64) x["labels"].append( { "im_file": im_file, "shape": shape, "cls": lb[:, 0:1], # n, 1 "bboxes": lb[:, 1:5], # n, 4 "difficulties": loss_weights, # n, 1 (pre-computed loss weights) "difficulty_levels": difficulty_levels, # n, 1 raw class target in [0, 3] "segments": segments, "keypoints": None, "normalized": True, "bbox_format": "xywh", } ) if msg: msgs.append(msg) pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" pbar.close() if msgs: LOGGER.info("\n".join(msgs)) if nf == 0: LOGGER.warning(f"{self.prefix}No labels found in {path}. {HELP_URL}") x["hash"] = get_hash(self.label_files + self.im_files) x["results"] = nf, nm, ne, nc, len(self.im_files) x["msgs"] = msgs # warnings save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) return x def get_labels(self) -> list[dict]: """Return ground labels derived from the original train/val list paths.""" self.label_files = getattr(self, "ground_label_files", img2label_paths(self.im_files)) cache_path = Path(self.label_files[0]).parent.with_suffix(".cache") try: cache, exists = load_dataset_cache_file(cache_path), True assert cache["version"] == DATASET_CACHE_VERSION assert cache["hash"] == get_hash(self.label_files + self.im_files) except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError): cache, exists = self.cache_labels(cache_path), False nf, nm, ne, nc, n = cache.pop("results") if exists and LOCAL_RANK in {-1, 0}: d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt" TQDM(None, desc=self.prefix + d, total=n, initial=n) if cache["msgs"]: LOGGER.info("\n".join(cache["msgs"])) [cache.pop(k) for k in ("hash", "version", "msgs")] labels = cache["labels"] if not labels: raise RuntimeError( f"No valid images found in {cache_path}. Images with incorrectly formatted labels are ignored. {HELP_URL}" ) self.im_files = [lb["im_file"] for lb in labels] if sum(len(lb["cls"]) for lb in labels) == 0: LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}") return labels def update_labels_info(self, label: dict) -> dict: """Update label format with difficulty scores. Args: label (dict): Label dictionary containing bboxes, difficulties, etc. Returns: (dict): Updated label dictionary with instances including difficulties. """ bboxes = label.pop("bboxes") difficulties = label.pop("difficulties", None) difficulty_levels = label.pop("difficulty_levels", None) segments = label.pop("segments", []) keypoints = label.pop("keypoints", None) bbox_format = label.pop("bbox_format") normalized = label.pop("normalized") # Convert empty segments list to numpy array (required by Instances.denormalize) if len(segments) == 0: segments = np.zeros((0, 1000, 2), dtype=np.float32) # Create Instances with difficulties label["instances"] = Instances( bboxes, segments, keypoints, difficulties=difficulties, difficulty_levels=difficulty_levels, bbox_format=bbox_format, normalized=normalized, ) return label def __getitem__(self, index: int) -> dict: """Return transformed label with post-augmentation min_wh filtering. Filters out boxes whose width and height are both smaller than min_wh pixels after all augmentations (resize, crop, mosaic, etc.) have been applied, ensuring the filter operates on final pixel sizes. """ labels = self.transforms(self.get_image_and_label(index)) if self.min_wh > 0 and "bboxes" in labels and len(labels["bboxes"]) > 0: bboxes = labels["bboxes"] # normalized xywh after Format _, h, w = labels["img"].shape # CHW w_px = bboxes[:, 2] * w h_px = bboxes[:, 3] * h valid = (w_px >= self.min_wh) | (h_px >= self.min_wh) if not valid.all(): labels["bboxes"] = bboxes[valid] labels["cls"] = labels["cls"][valid] if "batch_idx" in labels: labels["batch_idx"] = labels["batch_idx"][valid] if "difficulties" in labels: labels["difficulties"] = labels["difficulties"][valid] if "difficulty_levels" in labels: labels["difficulty_levels"] = labels["difficulty_levels"][valid] return labels class YOLOGround3DDataset(YOLOGroundDataset): """Dataset for ground 3D detection backed by GT list files and on-the-fly label loading. Extends YOLOGroundDataset to support joint 2D+3D detection with: - On-the-fly parsing of 19-col (complete_3d) and 51-col (face_3d) label files - A simple GT-manifest list used for dataset length and sample lookup - Only color-space augmentations (3D labels are in camera space) - Virtual camera / fisheye augmentation support Attributes: face_3d_classes (set): Class IDs with 4-face annotations. complete_3d_classes (set): Class IDs with whole-box 3D only. norm_scales_3d (dict): Normalization scales for 3D loss computation. """ def __init__( self, img_path: str | list[str], imgsz: int = 640, cache: bool | str = False, augment: bool = True, hyp: dict[str, Any] = DEFAULT_CFG, prefix: str = "", rect: bool = False, batch_size: int = 16, stride: int = 32, pad: float = 0.5, single_cls: bool = False, classes: list[int] | None = None, fraction: float = 1.0, data: dict | None = None, task: str = "detect", ): init_start = perf_counter() data = data or {} self.use_segments = task == "segment" self.use_keypoints = task == "pose" self.use_obb = task == "obb" assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." missing_keys = [key for key in GROUND3D_REQUIRED_DATA_KEYS if key not in data or data[key] is None] if missing_keys: missing = ", ".join(missing_keys) raise ValueError(f"Ground3D dataset config must define required field(s) explicitly: {missing}.") path_value = str(_require_ground3d_data_value(data, "path")).strip() # 检查 3D dataset 必需配置是否缺失。依赖文件顶部 GROUND3D_REQUIRED_DATA_KEYS = ("path", "ori_img_size", "roi", "virtual_fx", "virtual_camera_prob", "crop_center_mode")。 if not path_value: raise ValueError("Ground3D dataset config field 'path' must be a non-empty string.") normalized_ground3d_cfg = { "path": str(Path(path_value).expanduser()), # 读取 data["path"],也就是数据根目录。 "ori_img_size": list(_validate_ground3d_pair(_require_ground3d_data_value(data, "ori_img_size"), "ori_img_size")), # "roi": list(_validate_ground3d_pair(_require_ground3d_data_value(data, "roi"), "roi")), "virtual_fx": _validate_ground3d_positive_float(_require_ground3d_data_value(data, "virtual_fx"), "virtual_fx"), "virtual_camera_prob": _validate_ground3d_float( _require_ground3d_data_value(data, "virtual_camera_prob"), "virtual_camera_prob" ), "crop_center_mode": _validate_ground3d_crop_center_mode(_require_ground3d_data_value(data, "crop_center_mode")), } self.data = {**data, **normalized_ground3d_cfg} self.class_map = data.get("class_map", {}) # class_map:字符串类别名到 class id 的映射。3D label 里类别可能是名字,需要映射成数字类别 self.min_wh = data.get("min_wh", 2.0) # min_wh:过滤过小 2D 框的阈值,默认 2 像素 self.use_yuv444 = data.get("use_yuv444", False) self.difficulty_weights = data.get("difficulty_weights", [1.0]) self.face_3d_classes = set(data.get("face_3d_classes", [])) self.complete_3d_classes = set(data.get("complete_3d_classes", [])) self.norm_scales_3d = data.get("norm_scales_3d", {}) # 3D loss 归一化尺度,例如 depth、尺寸、offset 等归一化参数 self.image_root = Path(self.data["path"]) # 数据根目录 Path 对象,后续从 label 相对路径推导 image/calib 路径。 self.ori_img_size = self.data["ori_img_size"] self.roi = self.data["roi"] self.virtual_fx = self.data["virtual_fx"] self.virtual_camera_prob = self.data["virtual_camera_prob"] self.crop_center_mode = self.data["crop_center_mode"] self.virtual_camera_val_zoom = data.get("virtual_camera_val_zoom", False) # 验证阶段是否也启用 virtual camera zoom 相关逻辑。默认关闭 self._ground3d_shape = (int(self.ori_img_size[1]), int(self.ori_img_size[0])) self._include_class = None self.face_visibility_score_thresh = float( getattr(hyp, "face_visibility_score_thresh", DEFAULT_CFG.face_visibility_score_thresh) # 可见面分数阈值,从 hyp 里取,取不到就用 DEFAULT_CFG。后续构建 edge GT / face GT 时判断某个 face 是否有效。 ) self.precompute_edge_gt = float(getattr(hyp, "edge_aux_loss_gain", 1.0)) > 0 self.profile_timing = bool(getattr(hyp, "batch_timing", False)) self.skip_bad_images = bool(data.get("skip_bad_images", True)) self.strict_image_shape = bool(data.get("strict_image_shape", "ori_img_size" in data)) self.min_image_side = max(int(data.get("min_image_side", 10)), 1) self.img_path = img_path self.imgsz = imgsz self.augment = augment self.single_cls = single_cls self.prefix = prefix self.fraction = fraction self.channels = self.data.get("channels", 3) self.cv2_flag = cv2.IMREAD_GRAYSCALE if self.channels == 1 else cv2.IMREAD_COLOR # 根据通道数决定 OpenCV 按灰度还是彩色读取。Ground 3D 通常是 3 通道,所以一般是 cv2.IMREAD_COLOR self.label_entries = self._load_ground3d_label_entries(self.img_path) self.labels = self.label_entries self.update_labels(include_class=classes) self.ni = len(self.labels) self._bad_image_mask = np.zeros(self.ni, dtype=bool) self._warned_bad_image_indices: set[int] = set() self.rect = rect self.batch_size = batch_size self.stride = stride self.pad = pad if self.rect: assert self.batch_size is not None self.set_rectangle() self.buffer = [] self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 # 训练增强时最多缓存 min(样本数, batch*8, 1000) 张;验证或不增强时为 0 self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni self.npy_files = [] self.cache = None if cache: LOGGER.warning(f"{self.prefix}Ground3D dataset ignores cache={cache} and loads samples on-the-fly") self.transforms = self.build_transforms(hyp=hyp) if LOCAL_RANK in {-1, 0}: LOGGER.info( f"{self.prefix}Ground3D dataset ready with {self.ni:,} samples in {perf_counter() - init_start:.1f}s" ) def _load_ground3d_label_entries(self, img_path: str | list[str]) -> list[tuple[str, str]]: """Load Ground3D GT list files and return `(list_root, rel_label)` entries. 说明返回二元组:list_root 是 list 文件所在目录,rel_label 是 list 里写的 label 相对路径""" items = img_path if isinstance(img_path, list) else [img_path] label_entries = [] log_progress = LOCAL_RANK in {-1, 0} scan_start = perf_counter() total_entries = 0 if log_progress: LOGGER.info(f"{self.prefix}Initializing Ground3D dataset from {len(items)} GT list file(s)...") try: for item in items: item = Path(item) if not item.is_file(): raise FileNotFoundError(f"{self.prefix}Ground3D dataset expects GT list files, but got: {item}") if log_progress: LOGGER.info(f"{self.prefix}Reading GT list: {item}") with open(item, encoding="utf-8") as f: file_has_entries = False for line in f: entry = line.strip() if not entry or entry.lstrip().startswith("#"): continue if Path(entry).suffix.lower() != ".txt": raise ValueError(f"{self.prefix}Ground3D GT list entries must be label .txt paths, but got: {entry}") file_has_entries = True label_entries.append((str(item.parent.resolve()), entry)) total_entries += 1 if log_progress and total_entries % 100_000 == 0: LOGGER.info( f"{self.prefix}Parsed {total_entries:,} GT entries in {perf_counter() - scan_start:.1f}s" ) if not file_has_entries: raise FileNotFoundError(f"{self.prefix}No GT entries found in {item}.") assert label_entries, f"{self.prefix}No GT entries found in {img_path}." except Exception as e: raise FileNotFoundError(f"{self.prefix}Error loading Ground3D GT list from {img_path}\n{HELP_URL}") from e if self.fraction < 1: n = round(len(label_entries) * self.fraction) label_entries = label_entries[:n] if log_progress: LOGGER.info( f"{self.prefix}Applied dataset fraction={self.fraction:g}: keeping {len(label_entries):,}/{total_entries:,} entries" ) if log_progress: LOGGER.info( f"{self.prefix}Loaded {len(label_entries):,} Ground3D GT entries from {len(items)} list file(s) " f"in {perf_counter() - scan_start:.1f}s" ) LOGGER.info(f"{self.prefix}Checking sampled image access speed...") sample_entries = random.sample(label_entries, min(5, len(label_entries))) check_file_speeds([self._entry_to_image_file(entry) for entry in sample_entries], prefix=self.prefix) return label_entries # label 路径里的 labels 会被换成 images,后缀优先 .png,找不到再尝试 .jpg。 @staticmethod def _label_rel_to_image_rel(rel_label: Path) -> Path: """Convert a relative Ground3D label path to its matching image path.""" parts = list(rel_label.parts) if "labels" in parts: parts[len(parts) - 1 - parts[::-1].index("labels")] = "images" rel_label = Path(*parts) return rel_label.with_suffix(".png") @staticmethod def _entry_to_rel_label(label_entry: tuple[str, str]) -> Path: """Return the relative label path stored in a Ground3D GT entry.""" return Path(label_entry[1]) def _entry_to_label_file(self, label_entry: tuple[str, str]) -> str: """Resolve a Ground3D GT entry to its absolute label file path.""" list_root, _ = label_entry return str((Path(list_root) / self._entry_to_rel_label(label_entry)).resolve()) @staticmethod def _label_rel_to_calib_rel(rel_label: Path) -> Path: """Convert a relative Ground3D label path to its matching calibration path.""" parts = list(rel_label.parts) if "labels" in parts: parts[len(parts) - 1 - parts[::-1].index("labels")] = "calib" rel_label = Path(*parts) return rel_label.with_suffix(".json") def _entry_to_calib_file(self, label_entry: tuple[str, str]) -> str: """Resolve a Ground3D GT entry to its absolute calibration file path under the label root.""" list_root, _ = label_entry return str((Path(list_root) / self._label_rel_to_calib_rel(self._entry_to_rel_label(label_entry))).resolve()) def _entry_to_image_file(self, label_entry: tuple[str, str]) -> str: """Resolve a Ground3D GT entry to its absolute image file path.""" rel_image = self._label_rel_to_image_rel(self._entry_to_rel_label(label_entry)) image_file = (self.image_root / rel_image).resolve() if image_file.exists(): return str(image_file) jpg_image_file = image_file.with_suffix(".jpg") if jpg_image_file.exists(): return str(jpg_image_file) return str(image_file) def update_labels(self, include_class: list[int] | None) -> None: """Record class filtering rules and apply them lazily on parsed Ground3D labels.""" self._include_class = np.array(include_class, dtype=np.int64).reshape(1, -1) if include_class is not None else None def set_rectangle(self) -> None: """Configure rectangular batching from the fixed Ground3D source shape.""" if self.ni == 0: self.batch = np.zeros(0, dtype=int) self.batch_shapes = np.zeros((0, 2), dtype=int) return bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) nb = bi[-1] + 1 ori_h, ori_w = self._ground3d_shape ar = ori_h / ori_w if ori_w else 1.0 if ar < 1: base_shape = [ar, 1.0] elif ar > 1: base_shape = [1.0, 1.0 / ar] else: base_shape = [1.0, 1.0] imgsz = np.array(self.imgsz if isinstance(self.imgsz, (list, tuple)) else [self.imgsz, self.imgsz], dtype=np.float32) shapes = np.tile(np.array(base_shape, dtype=np.float32), (nb, 1)) self.batch_shapes = np.ceil(shapes * imgsz / self.stride + self.pad).astype(int) * self.stride self.batch = bi def _filter_ground3d_labels(self, lb_2d: dict[str, Any], lb_3d: np.ndarray | None) -> tuple[dict[str, Any], np.ndarray | None]: """Apply class selection lazily to on-the-fly Ground3D labels.""" if self._include_class is not None and len(lb_2d["cls"]): keep = (lb_2d["cls"] == self._include_class).any(1) lb_2d = { **lb_2d, "cls": lb_2d["cls"][keep], "bboxes": lb_2d["bboxes"][keep], "difficulties": lb_2d["difficulties"][keep], "difficulty_levels": lb_2d["difficulty_levels"][keep], "segments": [lb_2d["segments"][si] for si, idx in enumerate(keep.tolist()) if idx] if lb_2d["segments"] else [], "keypoints": lb_2d["keypoints"][keep] if lb_2d["keypoints"] is not None else None, } if lb_3d is not None: lb_3d = lb_3d[keep] if self.single_cls and len(lb_2d["cls"]): lb_2d = {**lb_2d, "cls": lb_2d["cls"].copy()} lb_2d["cls"][:, 0] = 0 return lb_2d, lb_3d def _mark_bad_image(self, index: int, im_file: str, reason: str) -> None: """Remember bad images so later epochs skip them immediately.""" self._bad_image_mask[index] = True if index not in self._warned_bad_image_indices: LOGGER.warning(f"{self.prefix}Skipping bad Ground3D image [{index}] {im_file}: {reason}") self._warned_bad_image_indices.add(index) def _validate_ground3d_image(self, index: int, im_file: str, img: np.ndarray | None) -> np.ndarray: """Validate a decoded Ground3D image before calibration and label work.""" if self._bad_image_mask[index]: raise Ground3DImageReadError(f"{im_file}: previously marked unreadable or invalid") if img is None: self._mark_bad_image(index, im_file, "cv2.imread() returned None") raise Ground3DImageReadError(f"{im_file}: cv2.imread() returned None") if not isinstance(img, np.ndarray): self._mark_bad_image(index, im_file, f"decoder returned {type(img).__name__}, expected ndarray") raise Ground3DImageReadError(f"{im_file}: decoder returned non-array image data") shape = tuple(int(v) for v in img.shape[:2]) if self.channels == 1: valid_channels = img.ndim == 2 expected_shape = "HxW" else: valid_channels = img.ndim == 3 and img.shape[2] == self.channels expected_shape = f"HxWx{self.channels}" if not valid_channels: self._mark_bad_image(index, im_file, f"unexpected decoded shape {img.shape}, expected {expected_shape}") raise Ground3DImageReadError(f"{im_file}: unexpected decoded shape {img.shape}") if min(shape) < self.min_image_side: self._mark_bad_image(index, im_file, f"decoded image is too small: {shape}") raise Ground3DImageReadError(f"{im_file}: decoded image is too small: {shape}") if self.strict_image_shape and shape != self._ground3d_shape: self._mark_bad_image( index, im_file, f"decoded shape {shape} does not match expected {self._ground3d_shape}", ) raise Ground3DImageReadError(f"{im_file}: decoded shape {shape} does not match expected {self._ground3d_shape}") return img def _read_ground3d_image(self, index: int, im_file: str) -> np.ndarray: """Read and validate a Ground3D image, failing with a skip-safe exception.""" if self._bad_image_mask[index]: raise Ground3DImageReadError(f"{im_file}: previously marked unreadable or invalid") img = cv2.imread(im_file, self.cv2_flag) return self._validate_ground3d_image(index, im_file, img) def _load_transformed_ground3d_sample(self, index: int) -> dict[str, Any]: """Load, transform, and finalize one Ground3D sample.""" labels = self.transforms(self.get_image_and_label(index)) # Convert labels_3d: numpy/None -> tensor if labels.get("labels_3d") is not None: if isinstance(labels["labels_3d"], np.ndarray): labels["labels_3d"] = torch.from_numpy(labels["labels_3d"].copy()) else: nl = len(labels["bboxes"]) if "bboxes" in labels else 0 # Keep absent 3D supervision as NaN so downstream 3D loss/metrics skip pure 2D objects. labels["labels_3d"] = torch.full((nl, 42), float("nan"), dtype=torch.float32) if self.min_wh > 0 and "bboxes" in labels and len(labels["bboxes"]) > 0: bboxes = labels["bboxes"] _, h, w = labels["img"].shape w_px = bboxes[:, 2] * w h_px = bboxes[:, 3] * h valid = (w_px >= self.min_wh) | (h_px >= self.min_wh) if not valid.all(): labels["bboxes"] = bboxes[valid] labels["cls"] = labels["cls"][valid] if "batch_idx" in labels: labels["batch_idx"] = labels["batch_idx"][valid] if "difficulties" in labels: labels["difficulties"] = labels["difficulties"][valid] if "difficulty_levels" in labels: labels["difficulty_levels"] = labels["difficulty_levels"][valid] labels["labels_3d"] = labels["labels_3d"][valid] edge_start = perf_counter() if self.profile_timing else None self._build_edge_gt(labels) if edge_start is not None and isinstance(labels.get("_profile_timing"), dict): labels["_profile_timing"]["precompute_edge_gt"] = perf_counter() - edge_start return labels def get_image_and_label(self, index): """Load a mono3D sample through an explicit ROI or virtual-camera preprocessing pipeline.""" label_entry = self.label_entries[index] label_file = self._entry_to_label_file(label_entry) im_file = self._entry_to_image_file(label_entry) calib_file = self._entry_to_calib_file(label_entry) label = {"im_file": im_file} timings = {} if self.profile_timing else None t0 = perf_counter() img = self._read_ground3d_image(index, im_file) if timings is not None: timings["read_image"] = perf_counter() - t0 ori_h, ori_w = img.shape[:2] t0 = perf_counter() lb_2d, lb_3d = parse_ground_3d_label_file( label_file, self.class_map, self.difficulty_weights, self.face_3d_classes, self.complete_3d_classes, self.min_wh, ) lb_2d, lb_3d = self._filter_ground3d_labels(lb_2d, lb_3d) if timings is not None: timings["parse_label"] = perf_counter() - t0 has_3d_targets = _ground3d_has_finite_targets(lb_3d) t0 = perf_counter() raw_calib = read_calib_from_path(im_file, image_root=self.image_root, extra_calib_candidates=[calib_file]) if timings is not None: timings["read_calib"] = perf_counter() - t0 if raw_calib is None and has_3d_targets: expected_calib = Path(calib_file) expected_calib = ( expected_calib if expected_calib.name == "camera4.json" else expected_calib.parent / "L2_calib" / "camera4.json" ) LOGGER.error(f"{self.prefix}Missing Ground3D calibration [{index}] {im_file} (expected {expected_calib})") raise Ground3DCalibrationError( f"{im_file}: calibration file not found (expected clip-level camera file at {expected_calib})" ) target_w, target_h = self.imgsz if isinstance(self.imgsz, (list, tuple)) else (self.imgsz, self.imgsz) roi_w, roi_h = self.roi use_virtual_camera = ( raw_calib is not None and 0 < self.virtual_camera_prob <= 1 and random.random() < self.virtual_camera_prob ) use_virtual_camera_zoom = self.augment or self.virtual_camera_val_zoom roi_active = (raw_calib is not None or not has_3d_targets) and (roi_h < ori_h or roi_w < ori_w) sample_ori_shape = (ori_h, ori_w) vp_x = compute_vanishing_point_x(raw_calib, ori_w) vp_y = compute_vanishing_point_y(raw_calib, ori_h) crop_center_x = vp_x if self.crop_center_mode == "vxvy" else ori_w / 2 label["camera_mode"] = "virtual" if use_virtual_camera else "roi" if use_virtual_camera: t0 = perf_counter() simul_calib = compute_simul_calib( raw_calib, (ori_w, ori_h), (target_w, target_h), crop_center_x, vp_y, target_fx=self.virtual_fx, augment=use_virtual_camera_zoom, ) if timings is not None: timings["compute_simul_calib"] = perf_counter() - t0 labels_48 = pack_labels_to_48(lb_2d, lb_3d) t0 = perf_counter() img, labels_48 = apply_simul_transform( img, labels_48, simul_calib, raw_calib, (target_w, target_h), augment=self.augment ) if timings is not None: timings["apply_simul_transform"] = perf_counter() - t0 lb_2d, lb_3d = unpack_labels_from_48(labels_48) final_calib = { "fx": simul_calib["fx"], "fy": simul_calib["fy"], "cx": simul_calib["cx"], "cy": simul_calib["cy"], "depth_scale": simul_calib["depth_scale"], } sample_ori_shape = (target_h, target_w) else: crop_bounds = (0, 0, ori_w, ori_h) roi_crop_time = 0.0 if roi_active: t0 = perf_counter() crop_bounds = compute_centered_roi_bounds(ori_w, ori_h, roi_w, roi_h, crop_center_x, vp_y) crop_x1, crop_y1, crop_x2, crop_y2 = crop_bounds img = img[crop_y1:crop_y2, crop_x1:crop_x2] sample_ori_shape = (crop_y2 - crop_y1, crop_x2 - crop_x1) roi_crop_time = perf_counter() - t0 t0 = perf_counter() calib_for_resize = adjust_calib_for_roi_crop(raw_calib, ori_w, ori_h, crop_bounds) img = _resize_ground3d_image_in_steps(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR) final_calib = build_final_resized_calib( calib_for_resize["focal_u"], calib_for_resize["focal_v"], calib_for_resize["cu"], calib_for_resize["cv"], calib_for_resize["src_w"], calib_for_resize["src_h"], target_w, target_h, self.virtual_fx, distort_coeffs=calib_for_resize["distort_coeffs"], ) if timings is not None: timings["roi_resize"] = perf_counter() - t0 if roi_crop_time > 0: timings["roi_crop"] = roi_crop_time if roi_active: t0 = perf_counter() lb_2d, lb_3d = remap_labels_to_roi(lb_2d, lb_3d, ori_w, ori_h, crop_bounds) if timings is not None: timings["remap_labels_to_roi"] = perf_counter() - t0 t0 = perf_counter() lb_3d = normalize_roi_depth(lb_3d, final_calib["fx"], self.virtual_fx) if timings is not None: timings["normalize_roi_depth"] = perf_counter() - t0 label["ori_shape"] = sample_ori_shape label["resized_shape"] = (target_h, target_w) label["ratio_pad"] = (target_h / sample_ori_shape[0], target_w / sample_ori_shape[1]) label["img"] = img label["calib"] = final_calib if timings is not None: timings["total_sample_build"] = sum(timings.values()) label["_profile_timing"] = timings if self.rect: label["rect_shape"] = self.batch_shapes[self.batch[index]] label.update(lb_2d) label["labels_3d"] = lb_3d return self.update_labels_info(label) def build_transforms(self, hyp=None): """Build transforms: only color-space augmentations for 3D training. No geometric augmentations (Mosaic, RandomPerspective, MixUp, CutMix, RandomFlip) since 3D labels are in camera space, not image space. No LetterBox since image is already resized to target in get_image_and_label. """ transforms = [] if self.augment and hyp: transforms.append(RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v)) transforms.append( Format( bbox_format="xywh", normalize=True, batch_idx=True, bgr=hyp.bgr if self.augment and hyp else 0.0, ) ) return Compose(transforms) def update_labels_info(self, label): """Update label format, preserving labels_3d, calib, and camera mode in the dict.""" labels_3d = label.pop("labels_3d", None) calib = label.pop("calib", None) camera_mode = label.pop("camera_mode", None) label = super().update_labels_info(label) label["labels_3d"] = labels_3d label["calib"] = calib label["camera_mode"] = camera_mode return label @staticmethod def _empty_edge_gt(nl: int) -> dict[str, torch.Tensor]: """Return empty precomputed edge GT tensors aligned with the object dimension.""" return { "edge_faces_points_2d": torch.zeros((nl, 4, 5, 2), dtype=torch.float32), "edge_faces_depths": torch.zeros((nl, 4, 5), dtype=torch.float32), "edge_faces_valid": torch.zeros((nl, 4), dtype=torch.bool), "edge_partial_points_2d": torch.zeros((nl, 5, 2), dtype=torch.float32), "edge_partial_depths": torch.zeros((nl, 5), dtype=torch.float32), "edge_partial_face_type": torch.full((nl,), -1, dtype=torch.long), "edge_partial_valid": torch.zeros((nl,), dtype=torch.bool), } def _build_edge_gt(self, labels: dict[str, Any]) -> None: """Precompute GT edge supervision on CPU inside dataloader workers.""" nl = int(len(labels.get("labels_3d", []))) labels.update(self._empty_edge_gt(nl)) if not self.precompute_edge_gt or nl == 0: return calib = labels.get("calib") if not isinstance(calib, dict): return labels_3d = labels["labels_3d"] cls = labels["cls"] bboxes = labels["bboxes"] _, img_h, img_w = labels["img"].shape labels_3d_np = labels_3d.cpu().numpy() if isinstance(labels_3d, torch.Tensor) else np.asarray(labels_3d) cls_np = cls.view(-1).cpu().numpy().astype(int) if isinstance(cls, torch.Tensor) else np.asarray(cls).reshape(-1).astype(int) bboxes_np = bboxes.cpu().numpy() if isinstance(bboxes, torch.Tensor) else np.asarray(bboxes) xyxy = np.zeros((nl, 4), dtype=np.float32) xyxy[:, 0] = (bboxes_np[:, 0] - bboxes_np[:, 2] / 2) * img_w xyxy[:, 1] = (bboxes_np[:, 1] - bboxes_np[:, 3] / 2) * img_h xyxy[:, 2] = (bboxes_np[:, 0] + bboxes_np[:, 2] / 2) * img_w xyxy[:, 3] = (bboxes_np[:, 1] + bboxes_np[:, 3] / 2) * img_h for obj_idx in range(nl): target_42 = labels_3d_np[obj_idx] cls_id = int(cls_np[obj_idx]) bbox_xyxy = xyxy[obj_idx] partial_edge = decode_cut_partial_side_edge_from_gt( target_42, cls_id, calib, int(img_w), int(img_h), self.face_3d_classes, self.complete_3d_classes, bbox_xyxy=bbox_xyxy, ) if partial_edge is not None: labels["edge_partial_points_2d"][obj_idx] = torch.from_numpy(partial_edge["points_2d"]) labels["edge_partial_depths"][obj_idx] = torch.from_numpy(partial_edge["depths"]) labels["edge_partial_face_type"][obj_idx] = int(partial_edge["face_type"]) labels["edge_partial_valid"][obj_idx] = True for face_type in range(4): decoded = partial_edge if partial_edge is not None and int(partial_edge["face_type"]) == face_type else None if decoded is None: decoded = decode_visible_face_edge_from_gt( target_42, cls_id, calib, int(img_w), int(img_h), self.face_3d_classes, self.complete_3d_classes, face_type=face_type, score_thr=self.face_visibility_score_thresh, bbox_xyxy=bbox_xyxy, ) if decoded is None: continue labels["edge_faces_points_2d"][obj_idx, face_type] = torch.from_numpy(decoded["points_2d"]) labels["edge_faces_depths"][obj_idx, face_type] = torch.from_numpy(decoded["depths"]) labels["edge_faces_valid"][obj_idx, face_type] = True def __getitem__(self, index): """Return transformed label with post-augmentation min_wh filtering for both 2D and 3D.""" if not self.skip_bad_images: return self._load_transformed_ground3d_sample(int(index)) last_error = None start_index = int(index) for offset in range(self.ni): sample_index = (start_index + offset) % self.ni try: return self._load_transformed_ground3d_sample(sample_index) except Ground3DImageReadError as exc: last_error = exc continue raise RuntimeError( f"{self.prefix}Failed to load a valid Ground3D image after checking all {self.ni} samples." ) from last_error class YOLOMultiModalDataset(YOLODataset): """Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support. This class extends YOLODataset to add text information for multi-modal model training, enabling models to process both image and text data. Methods: update_labels_info: Add text information for multi-modal model training. build_transforms: Enhance data transformations with text augmentation. Examples: >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect") >>> batch = next(iter(dataset)) >>> print(batch.keys()) # Should include 'texts' """ def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs): """Initialize a YOLOMultiModalDataset. Args: data (dict, optional): Dataset configuration dictionary. task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'. *args (Any): Additional positional arguments for the parent class. **kwargs (Any): Additional keyword arguments for the parent class. """ super().__init__(*args, data=data, task=task, **kwargs) def update_labels_info(self, label: dict) -> dict: """Add text information for multi-modal model training. Args: label (dict): Label dictionary containing bboxes, segments, keypoints, etc. Returns: (dict): Updated label dictionary with instances and texts. """ labels = super().update_labels_info(label) # NOTE: some categories are concatenated with its synonyms by `/`. # NOTE: and `RandomLoadText` would randomly select one of them if there are multiple words. labels["texts"] = [v.split("/") for _, v in self.data["names"].items()] return labels def build_transforms(self, hyp: dict | None = None) -> Compose: """Enhance data transformations with optional text augmentation for multi-modal training. Args: hyp (dict, optional): Hyperparameters for transforms. Returns: (Compose): Composed transforms including text augmentation if applicable. """ transforms = super().build_transforms(hyp) if self.augment: # NOTE: hard-coded the args for now. # NOTE: this implementation is different from official yoloe, # the strategy of selecting negative is restricted in one dataset, # while official pre-saved neg embeddings from all datasets at once. transform = RandomLoadText( max_samples=min(self.data["nc"], 80), padding=True, padding_value=self._get_neg_texts(self.category_freq), ) transforms.insert(-1, transform) return transforms @property def category_names(self): """Return category names for the dataset. Returns: (set[str]): Set of class names. """ names = self.data["names"].values() return {n.strip() for name in names for n in name.split("/")} # category names @property def category_freq(self): """Return frequency of each category in the dataset.""" texts = [v.split("/") for v in self.data["names"].values()] category_freq = defaultdict(int) for label in self.labels: for c in label["cls"].squeeze(-1): # to check text = texts[int(c)] for t in text: t = t.strip() category_freq[t] += 1 return category_freq @staticmethod def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]: """Get negative text samples based on frequency threshold.""" threshold = min(max(category_freq.values()), 100) return [k for k, v in category_freq.items() if v >= threshold] class GroundingDataset(YOLODataset): """Dataset class for object detection tasks using annotations from a JSON file in grounding format. This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than the standard YOLO format text files. Attributes: json_file (str): Path to the JSON file containing annotations. Methods: get_img_files: Return empty list as image files are read in get_labels. get_labels: Load annotations from a JSON file and prepare them for training. build_transforms: Configure augmentations for training with optional text loading. Examples: >>> dataset = GroundingDataset(img_path="path/to/images", json_file="annotations.json", task="detect") >>> len(dataset) # Number of valid images with annotations """ def __init__(self, *args, task: str = "detect", json_file: str = "", max_samples: int = 80, **kwargs): """Initialize a GroundingDataset for object detection. Args: json_file (str): Path to the JSON file containing annotations. task (str): Must be 'detect' or 'segment' for GroundingDataset. max_samples (int): Maximum number of samples to load for text augmentation. *args (Any): Additional positional arguments for the parent class. **kwargs (Any): Additional keyword arguments for the parent class. """ assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks" self.json_file = json_file self.max_samples = max_samples super().__init__(*args, task=task, data={"channels": 3}, **kwargs) def get_img_files(self, img_path: str) -> list: """The image files would be read in `get_labels` function, return empty list here. Args: img_path (str): Path to the directory containing images. Returns: (list): Empty list as image files are read in get_labels. """ return [] def verify_labels(self, labels: list[dict[str, Any]]) -> None: """Verify the number of instances in the dataset matches expected counts. This method checks if the total number of bounding box instances in the provided labels matches the expected count for known datasets. It performs validation against a predefined set of datasets with known instance counts. Args: labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary contains dataset annotations. Each label dict must have a 'bboxes' key with a numpy array or tensor containing bounding box coordinates. Raises: AssertionError: If the actual instance count doesn't match the expected count for a recognized dataset. Notes: For unrecognized datasets (those not in the predefined expected_counts), a warning is logged and verification is skipped. """ expected_counts = { "final_mixed_train_no_coco_segm": 3662412, "final_mixed_train_no_coco": 3681235, "final_flickr_separateGT_train_segm": 638214, "final_flickr_separateGT_train": 640704, } instance_count = sum(label["bboxes"].shape[0] for label in labels) for data_name, count in expected_counts.items(): if data_name in self.json_file: assert instance_count == count, f"'{self.json_file}' has {instance_count} instances, expected {count}." return LOGGER.warning(f"Skipping instance count verification for unrecognized dataset '{self.json_file}'") def cache_labels(self, path: Path = Path("./labels.cache")) -> dict[str, Any]: """Load annotations from a JSON file, filter, and normalize bounding boxes for each image. Args: path (Path): Path where to save the cache file. Returns: (dict[str, Any]): Dictionary containing cached labels and related information. """ x = {"labels": []} LOGGER.info("Loading annotation file...") with open(self.json_file) as f: annotations = json.load(f) images = {f"{x['id']:d}": x for x in annotations["images"]} img_to_anns = defaultdict(list) for ann in annotations["annotations"]: img_to_anns[ann["image_id"]].append(ann) for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"): img = images[f"{img_id:d}"] h, w, f = img["height"], img["width"], img["file_name"] im_file = Path(self.img_path) / f if not im_file.exists(): continue self.im_files.append(str(im_file)) bboxes = [] segments = [] cat2id = {} texts = [] for ann in anns: if ann["iscrowd"]: continue box = np.array(ann["bbox"], dtype=np.float32) box[:2] += box[2:] / 2 box[[0, 2]] /= float(w) box[[1, 3]] /= float(h) if box[2] <= 0 or box[3] <= 0: continue caption = img["caption"] cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]).lower().strip() if not cat_name: continue if cat_name not in cat2id: cat2id[cat_name] = len(cat2id) texts.append([cat_name]) cls = cat2id[cat_name] # class box = [cls, *box.tolist()] if box not in bboxes: bboxes.append(box) if ann.get("segmentation") is not None: if len(ann["segmentation"]) == 0: segments.append(box) continue elif len(ann["segmentation"]) > 1: s = merge_multi_segment(ann["segmentation"]) s = (np.concatenate(s, axis=0) / np.array([w, h], dtype=np.float32)).reshape(-1).tolist() else: s = [j for i in ann["segmentation"] for j in i] # all segments concatenated s = ( (np.array(s, dtype=np.float32).reshape(-1, 2) / np.array([w, h], dtype=np.float32)) .reshape(-1) .tolist() ) s = [cls, *s] segments.append(s) lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32) if segments: classes = np.array([x[0] for x in segments], dtype=np.float32) segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in segments] # (cls, xy1...) lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) lb = np.array(lb, dtype=np.float32) x["labels"].append( { "im_file": im_file, "shape": (h, w), "cls": lb[:, 0:1], # n, 1 "bboxes": lb[:, 1:], # n, 4 "segments": segments, "normalized": True, "bbox_format": "xywh", "texts": texts, } ) x["hash"] = get_hash(self.json_file) save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) return x def get_labels(self) -> list[dict]: """Load labels from cache or generate them from JSON file. Returns: (list[dict]): List of label dictionaries, each containing information about an image and its annotations. """ cache_path = Path(self.json_file).with_suffix(".cache") try: cache, _ = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file assert cache["version"] == DATASET_CACHE_VERSION # matches current version assert cache["hash"] == get_hash(self.json_file) # identical hash except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError): cache, _ = self.cache_labels(cache_path), False # run cache ops [cache.pop(k) for k in ("hash", "version")] # remove items labels = cache["labels"] self.verify_labels(labels) self.im_files = [str(label["im_file"]) for label in labels] if LOCAL_RANK in {-1, 0}: LOGGER.info(f"Load {self.json_file} from cache file {cache_path}") return labels def build_transforms(self, hyp: dict | None = None) -> Compose: """Configure augmentations for training with optional text loading. Args: hyp (dict, optional): Hyperparameters for transforms. Returns: (Compose): Composed transforms including text augmentation if applicable. """ transforms = super().build_transforms(hyp) if self.augment: # NOTE: hard-coded the args for now. # NOTE: this implementation is different from official yoloe, # the strategy of selecting negative is restricted in one dataset, # while official pre-saved neg embeddings from all datasets at once. transform = RandomLoadText( max_samples=min(self.max_samples, 80), padding=True, padding_value=self._get_neg_texts(self.category_freq), ) transforms.insert(-1, transform) return transforms @property def category_names(self): """Return unique category names from the dataset.""" return {t.strip() for label in self.labels for text in label["texts"] for t in text} @property def category_freq(self): """Return frequency of each category in the dataset.""" category_freq = defaultdict(int) for label in self.labels: for text in label["texts"]: for t in text: t = t.strip() category_freq[t] += 1 return category_freq @staticmethod def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]: """Get negative text samples based on frequency threshold.""" threshold = min(max(category_freq.values()), 100) return [k for k, v in category_freq.items() if v >= threshold] class YOLOConcatDataset(ConcatDataset): """Dataset as a concatenation of multiple datasets. This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same collation function. Methods: collate_fn: Static method that collates data samples into batches using YOLODataset's collation function. Examples: >>> dataset1 = YOLODataset(...) >>> dataset2 = YOLODataset(...) >>> combined_dataset = YOLOConcatDataset([dataset1, dataset2]) """ @staticmethod def collate_fn(batch: list[dict]) -> dict: """Collate data samples into batches. Args: batch (list[dict]): List of dictionaries containing sample data. Returns: (dict): Collated batch with stacked tensors. """ return YOLODataset.collate_fn(batch) def close_mosaic(self, hyp: dict) -> None: """Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0. Args: hyp (dict): Hyperparameters for transforms. """ for dataset in self.datasets: if not hasattr(dataset, "close_mosaic"): continue dataset.close_mosaic(hyp) # TODO: support semantic segmentation class SemanticDataset(BaseDataset): """Semantic Segmentation Dataset.""" def __init__(self): """Initialize a SemanticDataset object.""" super().__init__() class ClassificationDataset: """Dataset class for image classification tasks wrapping torchvision ImageFolder functionality. This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep learning models, with optional image transformations and caching mechanisms to speed up training. Attributes: cache_ram (bool): Indicates if caching in RAM is enabled. cache_disk (bool): Indicates if caching on disk is enabled. samples (list): A list of lists, each containing the path to an image, its class index, path to its .npy cache file (if caching on disk), and optionally the loaded image array (if caching in RAM). torch_transforms (callable): PyTorch transforms to be applied to the images. root (str): Root directory of the dataset. prefix (str): Prefix for logging and cache filenames. Methods: __getitem__: Return transformed image and class index for the given sample index. __len__: Return the total number of samples in the dataset. verify_images: Verify all images in dataset. """ def __init__(self, root: str, args, augment: bool = False, prefix: str = ""): """Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings. Args: root (str): Path to the dataset directory where images are stored in a class-specific folder structure. args (Namespace): Configuration containing dataset-related settings such as image size, augmentation parameters, and cache settings. augment (bool, optional): Whether to apply augmentations to the dataset. prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification. """ import torchvision # scope for faster 'import ultralytics' # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18 self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True) else: self.base = torchvision.datasets.ImageFolder(root=root) self.samples = self.base.samples self.root = self.base.root # Initialize attributes if augment and args.fraction < 1.0: # reduce training fraction self.samples = self.samples[: round(len(self.samples) * args.fraction)] self.prefix = colorstr(f"{prefix}: ") if prefix else "" self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM if self.cache_ram: LOGGER.warning( "Classification `cache_ram` training has known memory leak in " "https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`." ) self.cache_ram = False self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files self.samples = self.verify_images() # filter out bad images self.samples = [[*list(x), Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im scale = (1.0 - args.scale, 1.0) # (0.08, 1.0) self.torch_transforms = ( classify_augmentations( size=args.imgsz, scale=scale, hflip=args.fliplr, vflip=args.flipud, erasing=args.erasing, auto_augment=args.auto_augment, hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v, ) if augment else classify_transforms(size=args.imgsz) ) def __getitem__(self, i: int) -> dict: """Return transformed image and class index for the given sample index. Args: i (int): Index of the sample to retrieve. Returns: (dict): Dictionary containing the image and its class index. """ f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image if self.cache_ram: if im is None: # Warning: two separate if statements required here, do not combine this with previous line im = self.samples[i][3] = cv2.imread(f) elif self.cache_disk: if not fn.exists(): # load npy np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False) im = np.load(fn) else: # read image im = cv2.imread(f) # BGR # Convert NumPy array to PIL image im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) sample = self.torch_transforms(im) return {"img": sample, "cls": j} def __len__(self) -> int: """Return the total number of samples in the dataset.""" return len(self.samples) def verify_images(self) -> list[tuple]: """Verify all images in dataset. Returns: (list[tuple]): List of valid samples after verification. """ desc = f"{self.prefix}Scanning {self.root}..." path = Path(self.root).with_suffix(".cache") # *.cache file path try: check_file_speeds([file for (file, _) in self.samples[:5]], prefix=self.prefix) # check image read speeds cache = load_dataset_cache_file(path) # attempt to load a *.cache file assert cache["version"] == DATASET_CACHE_VERSION # matches current version assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total if LOCAL_RANK in {-1, 0}: d = f"{desc} {nf} images, {nc} corrupt" TQDM(None, desc=d, total=n, initial=n) if cache["msgs"]: LOGGER.info("\n".join(cache["msgs"])) # display warnings return samples except (FileNotFoundError, AssertionError, AttributeError): # Run scan if *.cache retrieval failed nf, nc, msgs, samples, x = 0, 0, [], [], {} with ThreadPool(NUM_THREADS) as pool: results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix))) pbar = TQDM(results, desc=desc, total=len(self.samples)) for sample, nf_f, nc_f, msg in pbar: if nf_f: samples.append(sample) if msg: msgs.append(msg) nf += nf_f nc += nc_f pbar.desc = f"{desc} {nf} images, {nc} corrupt" pbar.close() if msgs: LOGGER.info("\n".join(msgs)) x["hash"] = get_hash([x[0] for x in self.samples]) x["results"] = nf, nc, len(samples), samples x["msgs"] = msgs # warnings save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) return samples