1866 lines
84 KiB
Python
Executable File
1866 lines
84 KiB
Python
Executable File
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||
|
||
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import glob
|
||
import json
|
||
import math
|
||
import os
|
||
import random
|
||
from collections import defaultdict
|
||
from copy import deepcopy
|
||
from itertools import repeat
|
||
from multiprocessing.pool import ThreadPool
|
||
from pathlib import Path
|
||
from time import perf_counter
|
||
from typing import Any
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from PIL import Image
|
||
from torch.utils.data import ConcatDataset
|
||
|
||
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM, colorstr
|
||
from ultralytics.utils.instance import Instances
|
||
from ultralytics.utils.ops import resample_segments, segments2boxes
|
||
from ultralytics.utils.plotting_3d import decode_cut_partial_side_edge_from_gt, decode_visible_face_edge_from_gt
|
||
from ultralytics.utils.torch_utils import TORCHVISION_0_18
|
||
|
||
from .augment import (
|
||
Compose,
|
||
Format,
|
||
LetterBox,
|
||
RandomHSV,
|
||
RandomLoadText,
|
||
classify_augmentations,
|
||
classify_transforms,
|
||
v8_transforms,
|
||
)
|
||
from .base import BaseDataset
|
||
from .converter import merge_multi_segment
|
||
from .ground3d_augment import (
|
||
adjust_calib_for_roi_crop,
|
||
apply_simul_transform,
|
||
build_final_resized_calib,
|
||
compute_centered_roi_bounds,
|
||
compute_simul_calib,
|
||
compute_vanishing_point_x,
|
||
compute_vanishing_point_y,
|
||
normalize_roi_depth,
|
||
pack_labels_to_48,
|
||
parse_ground_3d_label_file,
|
||
read_calib_from_path,
|
||
remap_labels_to_roi,
|
||
unpack_labels_from_48,
|
||
)
|
||
from .utils import (
|
||
FORMATS_HELP_MSG,
|
||
HELP_URL,
|
||
IMG_FORMATS,
|
||
check_file_speeds,
|
||
get_hash,
|
||
img2label_paths,
|
||
load_dataset_cache_file,
|
||
save_dataset_cache_file,
|
||
verify_image,
|
||
verify_image_label,
|
||
)
|
||
|
||
# Version tag written into Ultralytics dataset ``*.cache`` files (>= 1.0.0 for Ultralytics YOLO models); a cache
# whose tag differs from this value is not reused and labels are re-scanned.
DATASET_CACHE_VERSION = "1.0.3"

# Ground3D config fields that must be present and non-None in the dataset YAML; no defaults are substituted.
GROUND3D_REQUIRED_DATA_KEYS = (
    "path",
    "ori_img_size",
    "roi",
    "virtual_fx",
    "virtual_camera_prob",
    "crop_center_mode",
)

# Accepted values for the Ground3D ``crop_center_mode`` config field.
GROUND3D_VALID_CROP_CENTER_MODES = {"cxvy", "vxvy"}
|
||
|
||
|
||
class Ground3DImageReadError(RuntimeError):
    """Signal that a Ground3D sample image could not be read, or was read but failed validation."""
|
||
|
||
|
||
class Ground3DCalibrationError(RuntimeError):
    """Signal that a Ground3D sample lacks the calibration data it requires."""
|
||
|
||
|
||
def _resize_ground3d_image_in_steps(
    img: np.ndarray, target_size: tuple[int, int], interpolation: int = cv2.INTER_LINEAR
) -> np.ndarray:
    """Resize a Ground3D image with repeated 0.5x downsampling before the final resize.

    When the target is strictly smaller than the input in both dimensions, the image is halved (rounding up) as long
    as the halved size still covers the target in both dimensions; one final resize then reaches the exact target.
    Otherwise a single direct resize is performed.

    Args:
        img (np.ndarray): Input image; ``img.shape[:2]`` is read as (height, width).
        target_size (tuple[int, int]): Desired output size as (width, height).
        interpolation (int): OpenCV interpolation flag used for every resize step.

    Returns:
        (np.ndarray): The resized image, or ``img`` itself when it already has the target size.

    Raises:
        ValueError: If either target dimension is not positive.
    """
    target_w, target_h = int(target_size[0]), int(target_size[1])
    if target_w <= 0 or target_h <= 0:
        # Guard the halving loop below: a 1-pixel side halves to ceil(0.5) == 1, which never drops under a
        # non-positive target, so the loop would never terminate.
        raise ValueError(f"Ground3D resize target must be positive, but got: {(target_w, target_h)!r}")

    current_h, current_w = img.shape[:2]
    if (current_w, current_h) == (target_w, target_h):
        return img

    if target_w < current_w and target_h < current_h:
        while True:
            next_w = math.ceil(current_w * 0.5)
            next_h = math.ceil(current_h * 0.5)
            if next_w < target_w or next_h < target_h:
                break  # another halving would undershoot the target; finish with one resize below

            img = cv2.resize(img, (next_w, next_h), interpolation=interpolation)
            current_h, current_w = img.shape[:2]
            if (current_w, current_h) == (target_w, target_h):
                return img

    return cv2.resize(img, (target_w, target_h), interpolation=interpolation)
|
||
|
||
|
||
def _require_ground3d_data_value(data: dict[str, Any], key: str) -> Any:
    """Fetch ``data[key]``, refusing entries that are absent or set to ``None``.

    Ground3D configs must spell out these fields; no fallback default is ever substituted.

    Args:
        data (dict[str, Any]): Dataset configuration mapping.
        key (str): Name of the required field.

    Returns:
        (Any): The configured value.

    Raises:
        ValueError: If ``key`` is missing from ``data`` or maps to ``None``.
    """
    value = data.get(key)
    if value is None:
        raise ValueError(f"Ground3D dataset config must define '{key}' explicitly; defaults are not allowed.")
    return value
|
||
|
||
|
||
def _validate_ground3d_pair(value: Any, key: str) -> tuple[int, int]:
    """Validate a Ground3D size-like config field and return it as an integer pair.

    Args:
        value (Any): Raw config value; must be a 2-item list or tuple of values accepted by ``int()``.
        key (str): Config field name, used in error messages.

    Returns:
        (tuple[int, int]): Both items converted with ``int()``.

    Raises:
        ValueError: If ``value`` is not a 2-item list/tuple, an item cannot be converted to ``int``, or an item is
            not positive.
    """
    if not isinstance(value, (list, tuple)) or len(value) != 2:
        raise ValueError(f"Ground3D dataset config field '{key}' must be a 2-item list/tuple, but got: {value!r}")
    try:
        pair = (int(value[0]), int(value[1]))
    except (TypeError, ValueError) as e:
        # Match _validate_ground3d_float: surface a descriptive ValueError instead of a bare int() failure.
        raise ValueError(
            f"Ground3D dataset config field '{key}' must contain integer values, but got: {value!r}"
        ) from e
    if pair[0] <= 0 or pair[1] <= 0:
        raise ValueError(f"Ground3D dataset config field '{key}' must contain positive values, but got: {value!r}")
    return pair
|
||
|
||
|
||
def _validate_ground3d_crop_center_mode(value: Any) -> str:
    """Validate the Ground3D crop-center mode.

    Args:
        value (Any): Raw ``crop_center_mode`` value from the dataset config.

    Returns:
        (str): ``value`` as a plain string, one of ``GROUND3D_VALID_CROP_CENTER_MODES``.

    Raises:
        ValueError: If ``value`` is not a string or is not one of the recognized modes.
    """
    # Check the type first: a set-membership test on an unhashable value (e.g. a YAML list) raises TypeError
    # rather than reaching the descriptive ValueError below.
    if not isinstance(value, str) or value not in GROUND3D_VALID_CROP_CENTER_MODES:
        valid_modes = ", ".join(sorted(GROUND3D_VALID_CROP_CENTER_MODES))
        raise ValueError(
            f"Ground3D dataset config field 'crop_center_mode' must be one of {{{valid_modes}}}, but got: {value!r}"
        )
    return str(value)
|
||
|
||
|
||
def _validate_ground3d_positive_float(value: Any, key: str) -> float:
    """Validate a positive, finite numeric Ground3D config field.

    Args:
        value (Any): Raw config value; anything accepted by ``float()``.
        key (str): Config field name, used in error messages.

    Returns:
        (float): The value converted to ``float``.

    Raises:
        ValueError: If ``value`` is not numeric, is NaN/infinite, or is not strictly positive.
    """
    try:
        number = float(value)
    except (TypeError, ValueError) as e:
        raise ValueError(f"Ground3D dataset config field '{key}' must be numeric, but got: {value!r}") from e
    # ``nan <= 0`` is False, so a plain sign check would accept NaN; reject non-finite values explicitly.
    if not math.isfinite(number):
        raise ValueError(f"Ground3D dataset config field '{key}' must be a finite number, but got: {number!r}")
    if number <= 0:
        raise ValueError(f"Ground3D dataset config field '{key}' must be positive, but got: {number!r}")
    return number
|
||
|
||
|
||
def _validate_ground3d_float(value: Any, key: str) -> float:
    """Coerce a Ground3D config field to ``float``.

    Args:
        value (Any): Raw config value.
        key (str): Config field name, used in the error message.

    Returns:
        (float): ``float(value)``.

    Raises:
        ValueError: If ``float()`` rejects ``value``; the original exception is chained as the cause.
    """
    try:
        converted = float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Ground3D dataset config field '{key}' must be numeric, but got: {value!r}") from err
    return converted
|
||
|
||
|
||
def _ground3d_has_finite_targets(labels_3d: np.ndarray | None) -> bool:
    """Return whether a parsed 3D label array contains any real 3D supervision.

    Args:
        labels_3d (np.ndarray | None): Parsed 3D labels (anything ``np.asarray`` accepts), or ``None``.

    Returns:
        (bool): True if the array is non-empty and holds at least one finite value. ``None``, empty arrays, and
            arrays containing only NaN/inf give False.
    """
    if labels_3d is None:
        return False
    labels_3d = np.asarray(labels_3d)
    # ``ndarray.any()`` yields ``np.bool_``; convert so callers really get the annotated ``bool``.
    return bool(labels_3d.size > 0 and np.isfinite(labels_3d).any())
|
||
|
||
|
||
class YOLODataset(BaseDataset):
    """Dataset class for loading object detection and/or segmentation labels in YOLO format.

    This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
    (OBB) tasks using the YOLO format.

    Attributes:
        use_segments (bool): Indicates if segmentation masks should be used.
        use_keypoints (bool): Indicates if keypoints should be used for pose estimation.
        use_obb (bool): Indicates if oriented bounding boxes should be used.
        data (dict): Dataset configuration dictionary.

    Methods:
        cache_labels: Cache dataset labels, check images and read shapes.
        get_labels: Return list of label dictionaries for YOLO training.
        build_transforms: Build and append transforms to the list.
        close_mosaic: Disable mosaic, copy_paste, mixup and cutmix augmentations and build transformations.
        update_labels_info: Update label format for different tasks.
        collate_fn: Collate data samples into batches.

    Examples:
        >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
        >>> dataset.get_labels()
    """

    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
        """Initialize the YOLODataset.

        Args:
            data (dict, optional): Dataset configuration dictionary. Although optional in the signature it must be
                supplied: ``channels``/``use_yuv444`` are read here and ``names``/``kpt_shape`` in ``cache_labels``.
            task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(
            *args, channels=self.data.get("channels", 3), use_yuv444=self.data.get("use_yuv444", False), **kwargs
        )

    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
        """Cache dataset labels, check images and read shapes.

        Every (image, label) pair is checked by ``verify_image_label`` in a thread pool; valid samples are stored
        with normalized ``xywh`` boxes and the result is saved to ``path``.

        Args:
            path (Path): Path where to save the cache file.

        Returns:
            (dict): Dictionary with ``labels``, ``hash``, ``results`` (found, missing, empty, corrupt, total) and
                ``msgs`` entries.

        Raises:
            ValueError: If the task is pose and ``kpt_shape`` is missing or malformed.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                    repeat(self.single_cls),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self) -> list[dict]:
        """Return list of label dictionaries for YOLO training.

        This method loads labels from disk or cache, verifies their integrity, and prepares them for training.
        A cache is reused only if its version and hash match the current files; otherwise labels are re-scanned.

        Returns:
            (list[dict]): List of label dictionaries, each containing information about an image and its annotations.

        Raises:
            RuntimeError: If no valid images remain after scanning.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache = load_dataset_cache_file(cache_path)  # attempt to load a *.cache file
            # Explicit comparison instead of ``assert``: asserts are stripped under ``python -O``, which would let a
            # stale cache (wrong version or hash) be silently reused.
            exists = cache["version"] == DATASET_CACHE_VERSION and cache["hash"] == get_hash(
                self.label_files + self.im_files
            )
        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
            exists = False
        if not exists:
            cache = self.cache_labels(cache_path)  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache
        for k in ("hash", "version", "msgs"):
            cache.pop(k)  # remove items
        labels = cache["labels"]
        if not labels:
            raise RuntimeError(
                f"No valid images found in {cache_path}. Images with incorrectly formatted labels are ignored. {HELP_URL}"
            )
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp: dict | None = None) -> Compose:
        """Build and append transforms to the list.

        Args:
            hyp (dict, optional): Hyperparameters for transforms. Must be provided: its attributes (``mask_ratio``,
                ``overlap_mask``, ``bgr``, ...) are read unconditionally.

        Returns:
            (Compose): Composed transforms.
        """
        if self.augment:
            if self.rect:  # mosaic, mixup and cutmix are switched off for rectangular training
                hyp.mosaic = 0.0
                hyp.mixup = 0.0
                hyp.cutmix = 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only affect training.
            )
        )
        return transforms

    def close_mosaic(self, hyp: dict) -> None:
        """Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.

        Args:
            hyp (dict): Hyperparameters for transforms; modified in place, then used to rebuild ``self.transforms``.
        """
        hyp.mosaic = 0.0
        hyp.copy_paste = 0.0
        hyp.mixup = 0.0
        hyp.cutmix = 0.0
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label: dict) -> dict:
        """Update label format for different tasks.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances.

        Notes:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # make sure segments interpolate correctly if original length is greater than segment_resamples
            max_len = max(len(s) for s in segments)
            segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
            # list[np.array(segment_resamples, 2)] * num_samples
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch: list[dict]) -> dict:
        """Collate data samples into batches.

        Image-like keys are stacked, ``visuals`` are padded, ``camera_mode`` is kept as a tuple and per-instance
        keys are concatenated. ``batch_idx`` is offset by each sample's position in the batch.

        Args:
            batch (list[dict]): List of dictionaries containing sample data.

        Returns:
            (dict): Collated batch with stacked tensors.
        """
        new_batch = {}
        batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for k, value in zip(keys, values):
            if k in {"img", "text_feats", "sem_masks"}:
                value = torch.stack(value, 0)
            elif k == "visuals":
                value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
            elif k == "camera_mode":
                value = tuple(value)
            if k in {
                "masks",
                "keypoints",
                "bboxes",
                "cls",
                "segments",
                "obb",
                "difficulties",
                "difficulty_levels",
                "labels_3d",
                "edge_faces_points_2d",
                "edge_faces_depths",
                "edge_faces_valid",
                "edge_partial_points_2d",
                "edge_partial_depths",
                "edge_partial_face_type",
                "edge_partial_valid",
            }:
                value = torch.cat(value, 0)
            new_batch[k] = value
        # Add target image index for build_targets(). ``idx + i`` builds new tensors, so the per-sample tensors
        # (shared with the input dicts) are not mutated in place as ``+=`` would do.
        new_batch["batch_idx"] = torch.cat([idx + i for i, idx in enumerate(new_batch["batch_idx"])], 0)
        return new_batch
|
||
|
||
|
||
class YOLOGroundDataset(YOLODataset):
    """Dataset class for loading ground 2D detection labels with difficulty scores.

    This class extends YOLODataset to support custom annotation format with:
        - String class names mapped to numeric IDs via class_map
        - Difficulty scores for each bounding box
        - Optional minimum box size filtering
        - Optional YUV444 color space support

    Attributes:
        class_map (dict): Mapping from class names to class IDs.
        min_wh (float): Minimum box width/height in pixels for filtering.
        use_yuv444 (bool): Whether to use YUV444 color space.
        difficulty_weights (list[float]): Loss weight per difficulty level, indexed by the raw difficulty value.

    Examples:
        >>> data = {"class_map": {"car": 0, "pedestrian": 1}, "min_wh": 2.0}
        >>> dataset = YOLOGroundDataset(img_path="path/to/images", data=data, task="detect")
    """

    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
        """Initialize the YOLOGroundDataset with class mapping and difficulty support.

        Args:
            data (dict, optional): Dataset configuration with 'class_map' and optional 'min_wh', 'use_yuv444',
                'difficulty_weights'. Must be supplied: ``data.get`` is called on it directly.
            task (str): Task type, should be 'detect' for ground dataset.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        self.class_map = data.get("class_map", {})
        self.min_wh = data.get("min_wh", 2.0)
        self.use_yuv444 = data.get("use_yuv444", False)
        self.difficulty_weights = data.get("difficulty_weights", [1.0])
        super().__init__(*args, data=data, task=task, **kwargs)

    def get_img_files(self, img_path: str | list[str]) -> list[str]:
        """Read image files while preserving source paths for label lookup.

        Ground 2D train/val list files may live under one ``Detection2D_*`` root while the corresponding images should
        be loaded from the YAML ``path`` root. The original list paths are retained for label lookup.

        Args:
            img_path (str | list[str]): Directories to glob and/or list files with one path per line. A leading
                ``./`` in a list entry is resolved against the list file's directory.

        Returns:
            (list[str]): Image paths under the configured image root (label paths go to ``ground_label_files``).

        Raises:
            FileNotFoundError: If a source does not exist, yields no entries, or contains unsupported entries.
        """
        try:
            files = []
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)
                if p.is_dir():
                    files += glob.glob(str(p / "**" / "*.*"), recursive=True)
                elif p.is_file():
                    with open(p, encoding="utf-8") as t:
                        lines = [line.strip() for line in t.read().strip().splitlines()]
                    parent = str(p.parent) + os.sep
                    # Only the leading "./" refers to the list file's directory. ``str.replace`` would also rewrite
                    # every later "./" (e.g. the tail of "../"), corrupting relative entries.
                    files += [parent + x[2:] if x.startswith("./") else x for x in lines]
                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            entries = sorted(x.replace("/", os.sep) for x in files if x.strip())
            assert entries, f"{self.prefix}No entries found in {img_path}."
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
        if self.fraction < 1:
            entries = entries[: round(len(entries) * self.fraction)]

        self.ground_label_files = []
        im_files = []
        for f in entries:
            suffix = Path(f).suffix.lower().lstrip(".")
            if suffix in IMG_FORMATS:
                self.ground_label_files.append(img2label_paths([f])[0])
                im_files.append(self._ground_image_path_from_label_path(f))
            elif suffix == "txt":
                self.ground_label_files.append(f)
                im_files.append(self._ground_image_path_from_label_path(f, label_path=True))
            else:
                raise FileNotFoundError(f"{self.prefix}Unsupported ground list entry: {f}. {FORMATS_HELP_MSG}")
        assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
        check_file_speeds(im_files, prefix=self.prefix)
        return im_files

    def _ground_image_path_from_label_path(self, im_file: str, label_path: bool = False) -> str:
        """Map a list-file path from a ``Detection2D_*``/``2Ddetection_*`` root onto the configured image root.

        The path components after the first ``detection2d*``/``2ddetection*`` directory (case-insensitive) are joined
        onto ``data["path"]``. Without such a directory the path is used relative (absolute paths drop their anchor).
        For label paths, ``.jpg`` is tried first as-is, then with ``labels`` directories renamed to ``images``,
        then other image suffixes; if nothing exists the renamed ``.jpg`` path is returned.

        Args:
            im_file (str): Path listed in the train/val list file.
            label_path (bool): Whether ``im_file`` is a label ``.txt`` path rather than an image path.

        Returns:
            (str): Image path under the configured image root.
        """
        image_root = Path(self.data["path"])
        parts = Path(im_file).parts
        detection_root_index = next(
            (i for i, part in enumerate(parts) if part.lower().startswith(("detection2d", "2ddetection"))),
            None,
        )
        if detection_root_index is not None and detection_root_index + 1 < len(parts):
            rel_parts = list(parts[detection_root_index + 1 :])
        else:
            rel_path = Path(im_file) if not Path(im_file).is_absolute() else Path(*parts[1:])
            rel_parts = list(rel_path.parts)

        if label_path:
            image_path = image_root.joinpath(*rel_parts).with_suffix(".jpg")
            if image_path.exists():
                return str(image_path)
            image_rel_parts = ["images" if part == "labels" else part for part in rel_parts]
            image_path = image_root.joinpath(*image_rel_parts).with_suffix(".jpg")
            if image_path.exists():
                return str(image_path)
            for suffix in ("png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg"):
                candidate = image_path.with_suffix(f".{suffix}")
                if candidate.exists():
                    return str(candidate)
            return str(image_path)

        return str(image_root.joinpath(*rel_parts))

    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
        """Cache dataset labels with difficulty scores, check images and read shapes.

        Column 5 of each verified label row is the raw difficulty. It is clipped to the valid indices of
        ``difficulty_weights`` and mapped to a per-box loss weight.

        Args:
            path (Path): Path where to save the cache file.

        Returns:
            (dict): Dictionary containing cached labels and related information.

        Raises:
            ValueError: If ``difficulty_weights`` is empty.
        """
        from ultralytics.data.utils import verify_image_label_ground

        # Lookup table for vectorized weight mapping; an empty table has no valid index to clip to.
        dw = np.asarray(self.difficulty_weights, dtype=np.float32).reshape(-1)
        if dw.size == 0:
            raise ValueError("'difficulty_weights' must contain at least one weight.")

        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label_ground,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.class_map),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    # Map raw difficulty values to loss weights using the difficulty_weights lookup table
                    if len(lb) > 0:
                        raw_diff = lb[:, 5].astype(int).clip(0, len(dw) - 1)
                        loss_weights = dw[raw_diff].reshape(-1, 1)
                        difficulty_levels = raw_diff.astype(np.int64).reshape(-1, 1)
                    else:
                        loss_weights = np.zeros((0, 1), dtype=np.float32)
                        difficulty_levels = np.zeros((0, 1), dtype=np.int64)

                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:5],  # n, 4
                            "difficulties": loss_weights,  # n, 1 (pre-computed loss weights)
                            "difficulty_levels": difficulty_levels,  # n, 1 clipped to [0, len(difficulty_weights) - 1]
                            "segments": segments,
                            "keypoints": None,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self) -> list[dict]:
        """Return ground labels derived from the original train/val list paths.

        Uses ``ground_label_files`` collected by ``get_img_files`` when available. A cache is reused only if its
        version and hash match; otherwise labels are re-scanned via ``cache_labels``.

        Returns:
            (list[dict]): Label dictionaries, one per valid image.

        Raises:
            RuntimeError: If no valid images remain after scanning.
        """
        self.label_files = getattr(self, "ground_label_files", img2label_paths(self.im_files))
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache = load_dataset_cache_file(cache_path)
            # Explicit comparison instead of ``assert`` so stale caches are still rejected under ``python -O``.
            exists = cache["version"] == DATASET_CACHE_VERSION and cache["hash"] == get_hash(
                self.label_files + self.im_files
            )
        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
            exists = False
        if not exists:
            cache = self.cache_labels(cache_path)

        nf, nm, ne, nc, n = cache.pop("results")
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))

        for k in ("hash", "version", "msgs"):
            cache.pop(k)
        labels = cache["labels"]
        if not labels:
            raise RuntimeError(
                f"No valid images found in {cache_path}. Images with incorrectly formatted labels are ignored. {HELP_URL}"
            )
        self.im_files = [lb["im_file"] for lb in labels]
        if sum(len(lb["cls"]) for lb in labels) == 0:
            LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def update_labels_info(self, label: dict) -> dict:
        """Update label format with difficulty scores.

        Args:
            label (dict): Label dictionary containing bboxes, difficulties, etc.

        Returns:
            (dict): Updated label dictionary with instances including difficulties.
        """
        bboxes = label.pop("bboxes")
        difficulties = label.pop("difficulties", None)
        difficulty_levels = label.pop("difficulty_levels", None)
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # Convert empty segments list to numpy array (required by Instances.denormalize)
        if len(segments) == 0:
            segments = np.zeros((0, 1000, 2), dtype=np.float32)

        # Create Instances with difficulties
        label["instances"] = Instances(
            bboxes,
            segments,
            keypoints,
            difficulties=difficulties,
            difficulty_levels=difficulty_levels,
            bbox_format=bbox_format,
            normalized=normalized,
        )
        return label

    def __getitem__(self, index: int) -> dict:
        """Return transformed label with post-augmentation min_wh filtering.

        Filters out boxes whose width and height are both smaller than min_wh pixels after all augmentations
        (resize, crop, mosaic, etc.) have been applied, ensuring the filter operates on final pixel sizes.
        """
        labels = self.transforms(self.get_image_and_label(index))
        if self.min_wh > 0 and "bboxes" in labels and len(labels["bboxes"]) > 0:
            bboxes = labels["bboxes"]  # normalized xywh after Format
            _, h, w = labels["img"].shape  # CHW
            w_px = bboxes[:, 2] * w
            h_px = bboxes[:, 3] * h
            valid = (w_px >= self.min_wh) | (h_px >= self.min_wh)
            if not valid.all():
                labels["bboxes"] = bboxes[valid]
                labels["cls"] = labels["cls"][valid]
                if "batch_idx" in labels:
                    labels["batch_idx"] = labels["batch_idx"][valid]
                if "difficulties" in labels:
                    labels["difficulties"] = labels["difficulties"][valid]
                if "difficulty_levels" in labels:
                    labels["difficulty_levels"] = labels["difficulty_levels"][valid]
        return labels
|
||
|
||
|
||
class YOLOGround3DDataset(YOLOGroundDataset):
|
||
"""Dataset for ground 3D detection backed by GT list files and on-the-fly label loading.
|
||
|
||
Extends YOLOGroundDataset to support joint 2D+3D detection with:
|
||
- On-the-fly parsing of 19-col (complete_3d) and 51-col (face_3d) label files
|
||
- A simple GT-manifest list used for dataset length and sample lookup
|
||
- Only color-space augmentations (3D labels are in camera space)
|
||
- Virtual camera / fisheye augmentation support
|
||
|
||
Attributes:
|
||
face_3d_classes (set): Class IDs with 4-face annotations.
|
||
complete_3d_classes (set): Class IDs with whole-box 3D only.
|
||
norm_scales_3d (dict): Normalization scales for 3D loss computation.
|
||
"""
|
||
|
||
def __init__(
    self,
    img_path: str | list[str],
    imgsz: int = 640,
    cache: bool | str = False,
    augment: bool = True,
    hyp: dict[str, Any] = DEFAULT_CFG,
    prefix: str = "",
    rect: bool = False,
    batch_size: int = 16,
    stride: int = 32,
    pad: float = 0.5,
    single_cls: bool = False,
    classes: list[int] | None = None,
    fraction: float = 1.0,
    data: dict | None = None,
    task: str = "detect",
):
    """Initialize the Ground3D dataset from GT list files without calling the parent initializers.

    Validates the required Ground3D config fields, stores dataset and augmentation settings, loads the GT label
    entries and builds the transforms. No label cache is written and image caching is not used.

    Args:
        img_path (str | list[str]): GT list file(s), read by ``_load_ground3d_label_entries``.
        imgsz (int): Target image size, stored as ``self.imgsz``.
        cache (bool | str): Ignored; a warning is logged when truthy because samples are loaded on the fly.
        augment (bool): Whether augmentation is enabled; also decides whether the image buffer is used.
        hyp (dict[str, Any]): Hyperparameters, read via ``getattr`` for ``face_visibility_score_thresh``,
            ``edge_aux_loss_gain`` and ``batch_timing``, and passed to ``build_transforms``.
        prefix (str): Prefix for log messages.
        rect (bool): Whether rectangular batching is used (``set_rectangle`` is called when True).
        batch_size (int): Batch size; must not be None when ``rect`` is True.
        stride (int): Stride value stored as ``self.stride``.
        pad (float): Padding value stored as ``self.pad``.
        single_cls (bool): Stored as ``self.single_cls``.
        classes (list[int] | None): Forwarded to ``update_labels(include_class=...)``.
        fraction (float): Fraction of GT entries to keep.
        data (dict | None): Dataset config; every key in ``GROUND3D_REQUIRED_DATA_KEYS`` must be present and non-None.
        task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.

    Raises:
        ValueError: If a required Ground3D config field is missing or invalid.
        FileNotFoundError: If the GT list files cannot be loaded.
    """
    init_start = perf_counter()
    data = data or {}
    self.use_segments = task == "segment"
    self.use_keypoints = task == "pose"
    self.use_obb = task == "obb"
    assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
    missing_keys = [key for key in GROUND3D_REQUIRED_DATA_KEYS if key not in data or data[key] is None]
    if missing_keys:
        missing = ", ".join(missing_keys)
        raise ValueError(f"Ground3D dataset config must define required field(s) explicitly: {missing}.")

    # Required dataset root; presence was already checked against GROUND3D_REQUIRED_DATA_KEYS above.
    path_value = str(_require_ground3d_data_value(data, "path")).strip()
    if not path_value:
        raise ValueError("Ground3D dataset config field 'path' must be a non-empty string.")

    # Validated/normalized copies of the required fields; these override the raw values in self.data below.
    normalized_ground3d_cfg = {
        "path": str(Path(path_value).expanduser()),  # dataset root directory (data["path"]), with ~ expanded
        "ori_img_size": list(_validate_ground3d_pair(_require_ground3d_data_value(data, "ori_img_size"), "ori_img_size")),
        "roi": list(_validate_ground3d_pair(_require_ground3d_data_value(data, "roi"), "roi")),
        "virtual_fx": _validate_ground3d_positive_float(_require_ground3d_data_value(data, "virtual_fx"), "virtual_fx"),
        "virtual_camera_prob": _validate_ground3d_float(
            _require_ground3d_data_value(data, "virtual_camera_prob"), "virtual_camera_prob"
        ),
        "crop_center_mode": _validate_ground3d_crop_center_mode(_require_ground3d_data_value(data, "crop_center_mode")),
    }

    self.data = {**data, **normalized_ground3d_cfg}
    # Maps string class names to class IDs (labels may name classes rather than number them).
    self.class_map = data.get("class_map", {})
    # Threshold for filtering tiny 2D boxes; defaults to 2.0 pixels.
    self.min_wh = data.get("min_wh", 2.0)
    self.use_yuv444 = data.get("use_yuv444", False)
    self.difficulty_weights = data.get("difficulty_weights", [1.0])
    self.face_3d_classes = set(data.get("face_3d_classes", []))
    self.complete_3d_classes = set(data.get("complete_3d_classes", []))
    # Normalization scales for the 3D loss (presumably depth/size/offset scales; confirm against the loss code).
    self.norm_scales_3d = data.get("norm_scales_3d", {})
    # Dataset root as a Path; presumably used later to derive image/calib paths from relative label paths.
    self.image_root = Path(self.data["path"])
    self.ori_img_size = self.data["ori_img_size"]
    self.roi = self.data["roi"]
    self.virtual_fx = self.data["virtual_fx"]
    self.virtual_camera_prob = self.data["virtual_camera_prob"]
    self.crop_center_mode = self.data["crop_center_mode"]
    # Presumably enables virtual-camera zoom during validation as well; defaults to False.
    self.virtual_camera_val_zoom = data.get("virtual_camera_val_zoom", False)
    # Swapped order of ori_img_size; presumably (height, width) if ori_img_size is (width, height) - confirm.
    self._ground3d_shape = (int(self.ori_img_size[1]), int(self.ori_img_size[0]))
    self._include_class = None
    self.face_visibility_score_thresh = float(
        # Taken from hyp, falling back to DEFAULT_CFG; presumably used to decide which faces are valid when
        # building face/edge GT (confirm in the label-building code).
        getattr(hyp, "face_visibility_score_thresh", DEFAULT_CFG.face_visibility_score_thresh)
    )
    # Edge GT is only precomputed when the edge auxiliary loss gain is positive.
    self.precompute_edge_gt = float(getattr(hyp, "edge_aux_loss_gain", 1.0)) > 0
    self.profile_timing = bool(getattr(hyp, "batch_timing", False))
    self.skip_bad_images = bool(data.get("skip_bad_images", True))
    # The default is effectively True here, since 'ori_img_size' is required above.
    self.strict_image_shape = bool(data.get("strict_image_shape", "ori_img_size" in data))
    self.min_image_side = max(int(data.get("min_image_side", 10)), 1)  # clamped to at least 1

    self.img_path = img_path
    self.imgsz = imgsz
    self.augment = augment
    self.single_cls = single_cls
    self.prefix = prefix
    self.fraction = fraction
    self.channels = self.data.get("channels", 3)
    # Read grayscale for 1-channel data, otherwise color (cv2.IMREAD_COLOR for the usual 3-channel case).
    self.cv2_flag = cv2.IMREAD_GRAYSCALE if self.channels == 1 else cv2.IMREAD_COLOR

    # Label entries are (list_root, rel_label) tuples; update_labels filters them by ``classes``.
    self.label_entries = self._load_ground3d_label_entries(self.img_path)
    self.labels = self.label_entries
    self.update_labels(include_class=classes)
    self.ni = len(self.labels)
    self._bad_image_mask = np.zeros(self.ni, dtype=bool)
    self._warned_bad_image_indices: set[int] = set()

    self.rect = rect
    self.batch_size = batch_size
    self.stride = stride
    self.pad = pad
    if self.rect:
        assert self.batch_size is not None
        self.set_rectangle()

    self.buffer = []
    # When augmenting, buffer up to min(num samples, batch_size * 8, 1000) images; 0 when not augmenting.
    self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
    self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
    self.npy_files = []
    self.cache = None
    if cache:
        LOGGER.warning(f"{self.prefix}Ground3D dataset ignores cache={cache} and loads samples on-the-fly")

    self.transforms = self.build_transforms(hyp=hyp)
    if LOCAL_RANK in {-1, 0}:
        LOGGER.info(
            f"{self.prefix}Ground3D dataset ready with {self.ni:,} samples in {perf_counter() - init_start:.1f}s"
        )
|
||
|
||
def _load_ground3d_label_entries(self, img_path: str | list[str]) -> list[tuple[str, str]]:
    """Load Ground3D GT list files and return their ``(list_root, rel_label)`` entries.

    Each GT list file contains one label ``.txt`` path per line; blank lines and ``#`` comment lines are skipped.
    ``list_root`` is the resolved directory containing the list file, and ``rel_label`` is the label path exactly as
    written in the list (it is joined onto ``list_root`` later when resolving label/calib files).

    Args:
        img_path (str | list[str]): A GT list file path, or a list of them.

    Returns:
        (list[tuple[str, str]]): ``(list_root, rel_label)`` pairs, truncated to the first ``self.fraction`` share
            when ``self.fraction < 1``.

    Raises:
        FileNotFoundError: If a list file is missing, has no entries, contains a non-``.txt`` entry, or no entries
            are found at all. The specific cause is chained as ``__cause__``.
    """
    items = img_path if isinstance(img_path, list) else [img_path]
    label_entries: list[tuple[str, str]] = []
    log_progress = LOCAL_RANK in {-1, 0}
    scan_start = perf_counter()
    total_entries = 0

    if log_progress:
        LOGGER.info(f"{self.prefix}Initializing Ground3D dataset from {len(items)} GT list file(s)...")

    try:
        for item in items:
            item = Path(item)
            if not item.is_file():
                raise FileNotFoundError(f"{self.prefix}Ground3D dataset expects GT list files, but got: {item}")
            if log_progress:
                LOGGER.info(f"{self.prefix}Reading GT list: {item}")

            # Resolve the list directory once per file instead of once per entry (a filesystem call each time).
            list_root = str(item.parent.resolve())
            file_has_entries = False
            with open(item, encoding="utf-8") as f:
                for line in f:
                    entry = line.strip()
                    if not entry or entry.startswith("#"):
                        continue
                    if Path(entry).suffix.lower() != ".txt":
                        raise ValueError(f"{self.prefix}Ground3D GT list entries must be label .txt paths, but got: {entry}")

                    file_has_entries = True
                    label_entries.append((list_root, entry))
                    total_entries += 1
                    if log_progress and total_entries % 100_000 == 0:
                        LOGGER.info(
                            f"{self.prefix}Parsed {total_entries:,} GT entries in {perf_counter() - scan_start:.1f}s"
                        )

            if not file_has_entries:
                raise FileNotFoundError(f"{self.prefix}No GT entries found in {item}.")

        # Explicit check rather than `assert`, which is stripped when Python runs with -O.
        if not label_entries:
            raise FileNotFoundError(f"{self.prefix}No GT entries found in {img_path}.")
    except Exception as e:
        raise FileNotFoundError(f"{self.prefix}Error loading Ground3D GT list from {img_path}\n{HELP_URL}") from e

    if self.fraction < 1:
        n = round(len(label_entries) * self.fraction)
        label_entries = label_entries[:n]
        if log_progress:
            LOGGER.info(
                f"{self.prefix}Applied dataset fraction={self.fraction:g}: keeping {len(label_entries):,}/{total_entries:,} entries"
            )

    if log_progress:
        LOGGER.info(
            f"{self.prefix}Loaded {len(label_entries):,} Ground3D GT entries from {len(items)} list file(s) "
            f"in {perf_counter() - scan_start:.1f}s"
        )
        # NOTE(review): the original's indentation was lost; this read-speed probe is assumed to run on the
        # logging rank only — confirm.
        LOGGER.info(f"{self.prefix}Checking sampled image access speed...")
        sample_entries = random.sample(label_entries, min(5, len(label_entries)))
        check_file_speeds([self._entry_to_image_file(entry) for entry in sample_entries], prefix=self.prefix)

    # Image paths are derived later: the last "labels" path component becomes "images", trying .png then .jpg.
    return label_entries
|
||
|
||
@staticmethod
|
||
def _label_rel_to_image_rel(rel_label: Path) -> Path:
|
||
"""Convert a relative Ground3D label path to its matching image path."""
|
||
parts = list(rel_label.parts)
|
||
if "labels" in parts:
|
||
parts[len(parts) - 1 - parts[::-1].index("labels")] = "images"
|
||
rel_label = Path(*parts)
|
||
return rel_label.with_suffix(".png")
|
||
|
||
@staticmethod
|
||
def _entry_to_rel_label(label_entry: tuple[str, str]) -> Path:
|
||
"""Return the relative label path stored in a Ground3D GT entry."""
|
||
return Path(label_entry[1])
|
||
|
||
def _entry_to_label_file(self, label_entry: tuple[str, str]) -> str:
|
||
"""Resolve a Ground3D GT entry to its absolute label file path."""
|
||
list_root, _ = label_entry
|
||
return str((Path(list_root) / self._entry_to_rel_label(label_entry)).resolve())
|
||
|
||
@staticmethod
|
||
def _label_rel_to_calib_rel(rel_label: Path) -> Path:
|
||
"""Convert a relative Ground3D label path to its matching calibration path."""
|
||
parts = list(rel_label.parts)
|
||
if "labels" in parts:
|
||
parts[len(parts) - 1 - parts[::-1].index("labels")] = "calib"
|
||
rel_label = Path(*parts)
|
||
return rel_label.with_suffix(".json")
|
||
|
||
def _entry_to_calib_file(self, label_entry: tuple[str, str]) -> str:
|
||
"""Resolve a Ground3D GT entry to its absolute calibration file path under the label root."""
|
||
list_root, _ = label_entry
|
||
return str((Path(list_root) / self._label_rel_to_calib_rel(self._entry_to_rel_label(label_entry))).resolve())
|
||
|
||
def _entry_to_image_file(self, label_entry: tuple[str, str]) -> str:
|
||
"""Resolve a Ground3D GT entry to its absolute image file path."""
|
||
rel_image = self._label_rel_to_image_rel(self._entry_to_rel_label(label_entry))
|
||
image_file = (self.image_root / rel_image).resolve()
|
||
if image_file.exists():
|
||
return str(image_file)
|
||
|
||
jpg_image_file = image_file.with_suffix(".jpg")
|
||
if jpg_image_file.exists():
|
||
return str(jpg_image_file)
|
||
|
||
return str(image_file)
|
||
|
||
def update_labels(self, include_class: list[int] | None) -> None:
|
||
"""Record class filtering rules and apply them lazily on parsed Ground3D labels."""
|
||
self._include_class = np.array(include_class, dtype=np.int64).reshape(1, -1) if include_class is not None else None
|
||
|
||
def set_rectangle(self) -> None:
    """Configure rectangular-batch shapes from the fixed Ground3D source resolution.

    All samples share ``self._ground3d_shape`` (``(h, w)``), so every batch gets the same shape: the source aspect
    ratio fitted into ``self.imgsz`` and rounded up to a multiple of ``self.stride`` after adding ``self.pad``.

    ``self.imgsz`` is either a scalar or a ``(w, h)`` pair — the order ``get_image_and_label`` uses when it unpacks
    ``target_w, target_h = self.imgsz``. It is reordered to ``(h, w)`` here so it lines up with the ``(h, w)``
    ratio vector; previously the raw ``(w, h)`` pair was multiplied in, swapping axes for non-square sizes.

    Sets:
        self.batch: Batch index for every sample, shape ``(ni,)``.
        self.batch_shapes: ``(h, w)`` shape of each batch, shape ``(nb, 2)``.
    """
    if self.ni == 0:
        self.batch = np.zeros(0, dtype=int)
        self.batch_shapes = np.zeros((0, 2), dtype=int)
        return

    bi = (np.arange(self.ni) // self.batch_size).astype(int)  # batch index of each sample
    nb = int(bi[-1]) + 1  # number of batches
    ori_h, ori_w = self._ground3d_shape
    ar = ori_h / ori_w if ori_w else 1.0

    # (h, w) shape relative to the long side.
    if ar < 1:
        base_shape = [ar, 1.0]
    elif ar > 1:
        base_shape = [1.0, 1.0 / ar]
    else:
        base_shape = [1.0, 1.0]

    if isinstance(self.imgsz, (list, tuple)):
        target_w, target_h = self.imgsz
    else:
        target_w = target_h = self.imgsz
    imgsz_hw = np.array([target_h, target_w], dtype=np.float32)

    shapes = np.tile(np.array(base_shape, dtype=np.float32), (nb, 1))
    self.batch_shapes = np.ceil(shapes * imgsz_hw / self.stride + self.pad).astype(int) * self.stride
    self.batch = bi
|
||
|
||
def _filter_ground3d_labels(self, lb_2d: dict[str, Any], lb_3d: np.ndarray | None) -> tuple[dict[str, Any], np.ndarray | None]:
|
||
"""Apply class selection lazily to on-the-fly Ground3D labels."""
|
||
if self._include_class is not None and len(lb_2d["cls"]):
|
||
keep = (lb_2d["cls"] == self._include_class).any(1)
|
||
lb_2d = {
|
||
**lb_2d,
|
||
"cls": lb_2d["cls"][keep],
|
||
"bboxes": lb_2d["bboxes"][keep],
|
||
"difficulties": lb_2d["difficulties"][keep],
|
||
"difficulty_levels": lb_2d["difficulty_levels"][keep],
|
||
"segments": [lb_2d["segments"][si] for si, idx in enumerate(keep.tolist()) if idx] if lb_2d["segments"] else [],
|
||
"keypoints": lb_2d["keypoints"][keep] if lb_2d["keypoints"] is not None else None,
|
||
}
|
||
if lb_3d is not None:
|
||
lb_3d = lb_3d[keep]
|
||
|
||
if self.single_cls and len(lb_2d["cls"]):
|
||
lb_2d = {**lb_2d, "cls": lb_2d["cls"].copy()}
|
||
lb_2d["cls"][:, 0] = 0
|
||
|
||
return lb_2d, lb_3d
|
||
|
||
def _mark_bad_image(self, index: int, im_file: str, reason: str) -> None:
|
||
"""Remember bad images so later epochs skip them immediately."""
|
||
self._bad_image_mask[index] = True
|
||
if index not in self._warned_bad_image_indices:
|
||
LOGGER.warning(f"{self.prefix}Skipping bad Ground3D image [{index}] {im_file}: {reason}")
|
||
self._warned_bad_image_indices.add(index)
|
||
|
||
def _validate_ground3d_image(self, index: int, im_file: str, img: np.ndarray | None) -> np.ndarray:
|
||
"""Validate a decoded Ground3D image before calibration and label work."""
|
||
if self._bad_image_mask[index]:
|
||
raise Ground3DImageReadError(f"{im_file}: previously marked unreadable or invalid")
|
||
if img is None:
|
||
self._mark_bad_image(index, im_file, "cv2.imread() returned None")
|
||
raise Ground3DImageReadError(f"{im_file}: cv2.imread() returned None")
|
||
if not isinstance(img, np.ndarray):
|
||
self._mark_bad_image(index, im_file, f"decoder returned {type(img).__name__}, expected ndarray")
|
||
raise Ground3DImageReadError(f"{im_file}: decoder returned non-array image data")
|
||
|
||
shape = tuple(int(v) for v in img.shape[:2])
|
||
if self.channels == 1:
|
||
valid_channels = img.ndim == 2
|
||
expected_shape = "HxW"
|
||
else:
|
||
valid_channels = img.ndim == 3 and img.shape[2] == self.channels
|
||
expected_shape = f"HxWx{self.channels}"
|
||
if not valid_channels:
|
||
self._mark_bad_image(index, im_file, f"unexpected decoded shape {img.shape}, expected {expected_shape}")
|
||
raise Ground3DImageReadError(f"{im_file}: unexpected decoded shape {img.shape}")
|
||
if min(shape) < self.min_image_side:
|
||
self._mark_bad_image(index, im_file, f"decoded image is too small: {shape}")
|
||
raise Ground3DImageReadError(f"{im_file}: decoded image is too small: {shape}")
|
||
if self.strict_image_shape and shape != self._ground3d_shape:
|
||
self._mark_bad_image(
|
||
index,
|
||
im_file,
|
||
f"decoded shape {shape} does not match expected {self._ground3d_shape}",
|
||
)
|
||
raise Ground3DImageReadError(f"{im_file}: decoded shape {shape} does not match expected {self._ground3d_shape}")
|
||
return img
|
||
|
||
def _read_ground3d_image(self, index: int, im_file: str) -> np.ndarray:
    """Decode ``im_file`` with OpenCV and validate it; known-bad indices fail fast without touching disk.

    Raises:
        Ground3DImageReadError: If the index was already marked bad or validation fails.
    """
    if not self._bad_image_mask[index]:
        decoded = cv2.imread(im_file, self.cv2_flag)
        return self._validate_ground3d_image(index, im_file, decoded)
    raise Ground3DImageReadError(f"{im_file}: previously marked unreadable or invalid")
|
||
|
||
def _load_transformed_ground3d_sample(self, index: int) -> dict[str, Any]:
|
||
"""Load, transform, and finalize one Ground3D sample."""
|
||
labels = self.transforms(self.get_image_and_label(index))
|
||
|
||
# Convert labels_3d: numpy/None -> tensor
|
||
if labels.get("labels_3d") is not None:
|
||
if isinstance(labels["labels_3d"], np.ndarray):
|
||
labels["labels_3d"] = torch.from_numpy(labels["labels_3d"].copy())
|
||
else:
|
||
nl = len(labels["bboxes"]) if "bboxes" in labels else 0
|
||
# Keep absent 3D supervision as NaN so downstream 3D loss/metrics skip pure 2D objects.
|
||
labels["labels_3d"] = torch.full((nl, 42), float("nan"), dtype=torch.float32)
|
||
|
||
if self.min_wh > 0 and "bboxes" in labels and len(labels["bboxes"]) > 0:
|
||
bboxes = labels["bboxes"]
|
||
_, h, w = labels["img"].shape
|
||
w_px = bboxes[:, 2] * w
|
||
h_px = bboxes[:, 3] * h
|
||
valid = (w_px >= self.min_wh) | (h_px >= self.min_wh)
|
||
if not valid.all():
|
||
labels["bboxes"] = bboxes[valid]
|
||
labels["cls"] = labels["cls"][valid]
|
||
if "batch_idx" in labels:
|
||
labels["batch_idx"] = labels["batch_idx"][valid]
|
||
if "difficulties" in labels:
|
||
labels["difficulties"] = labels["difficulties"][valid]
|
||
if "difficulty_levels" in labels:
|
||
labels["difficulty_levels"] = labels["difficulty_levels"][valid]
|
||
labels["labels_3d"] = labels["labels_3d"][valid]
|
||
edge_start = perf_counter() if self.profile_timing else None
|
||
self._build_edge_gt(labels)
|
||
if edge_start is not None and isinstance(labels.get("_profile_timing"), dict):
|
||
labels["_profile_timing"]["precompute_edge_gt"] = perf_counter() - edge_start
|
||
return labels
|
||
|
||
def get_image_and_label(self, index):
    """Load a mono3D sample through an explicit ROI or virtual-camera preprocessing pipeline.

    Steps: read and validate the image, parse and class-filter the label file, read the calibration, then either
    (a) with probability ``self.virtual_camera_prob`` (only when calibration was found) re-render through a
    simulated camera via ``compute_simul_calib`` + ``apply_simul_transform``, or (b) optionally crop an ROI of size
    ``self.roi`` (only when it is smaller than the image), resize to ``self.imgsz`` with adjusted intrinsics, remap
    labels into the crop and pass the 3D labels through ``normalize_roi_depth``.

    NOTE(review): this block was reconstructed from a whitespace-stripped copy; the nesting of the ROI-branch
    timing/normalization statements follows the code's data dependencies — confirm against the canonical file.

    Args:
        index (int): Sample index into ``self.label_entries``.

    Returns:
        (dict): Result of ``self.update_labels_info`` on a dict holding ``im_file``, ``img``, ``ori_shape``,
            ``resized_shape``, ``ratio_pad``, ``calib``, ``camera_mode``, ``labels_3d`` and the parsed 2D fields,
            plus ``_profile_timing`` when ``self.profile_timing`` and ``rect_shape`` when ``self.rect``.

    Raises:
        Ground3DImageReadError: Propagated from ``self._read_ground3d_image`` for unreadable/invalid images.
        Ground3DCalibrationError: If no calibration is found but the sample has finite 3D targets.
    """
    label_entry = self.label_entries[index]
    label_file = self._entry_to_label_file(label_entry)
    im_file = self._entry_to_image_file(label_entry)
    calib_file = self._entry_to_calib_file(label_entry)
    label = {"im_file": im_file}
    # Per-stage durations (seconds) are collected only when batch timing is enabled.
    timings = {} if self.profile_timing else None

    t0 = perf_counter()
    img = self._read_ground3d_image(index, im_file)
    if timings is not None:
        timings["read_image"] = perf_counter() - t0

    ori_h, ori_w = img.shape[:2]
    t0 = perf_counter()
    lb_2d, lb_3d = parse_ground_3d_label_file(
        label_file,
        self.class_map,
        self.difficulty_weights,
        self.face_3d_classes,
        self.complete_3d_classes,
        self.min_wh,
    )
    lb_2d, lb_3d = self._filter_ground3d_labels(lb_2d, lb_3d)
    if timings is not None:
        timings["parse_label"] = perf_counter() - t0
    has_3d_targets = _ground3d_has_finite_targets(lb_3d)
    t0 = perf_counter()
    raw_calib = read_calib_from_path(im_file, image_root=self.image_root, extra_calib_candidates=[calib_file])
    if timings is not None:
        timings["read_calib"] = perf_counter() - t0
    # Missing calibration is only fatal when the sample actually carries 3D supervision.
    if raw_calib is None and has_3d_targets:
        expected_calib = Path(calib_file)
        # Only used for the error message: point at the clip-level camera4.json location.
        expected_calib = (
            expected_calib if expected_calib.name == "camera4.json" else expected_calib.parent / "L2_calib" / "camera4.json"
        )
        LOGGER.error(f"{self.prefix}Missing Ground3D calibration [{index}] {im_file} (expected {expected_calib})")
        raise Ground3DCalibrationError(
            f"{im_file}: calibration file not found (expected clip-level camera file at {expected_calib})"
        )

    # imgsz is unpacked as (w, h) when given as a pair.
    target_w, target_h = self.imgsz if isinstance(self.imgsz, (list, tuple)) else (self.imgsz, self.imgsz)
    roi_w, roi_h = self.roi
    # Consumes one random draw only when calibration exists and 0 < virtual_camera_prob <= 1.
    use_virtual_camera = (
        raw_calib is not None and 0 < self.virtual_camera_prob <= 1 and random.random() < self.virtual_camera_prob
    )
    use_virtual_camera_zoom = self.augment or self.virtual_camera_val_zoom

    # At this point raw_calib can only be None for samples without finite 3D targets (otherwise we raised above).
    roi_active = (raw_calib is not None or not has_3d_targets) and (roi_h < ori_h or roi_w < ori_w)
    sample_ori_shape = (ori_h, ori_w)

    vp_x = compute_vanishing_point_x(raw_calib, ori_w)
    vp_y = compute_vanishing_point_y(raw_calib, ori_h)
    # Horizontal crop centre: vanishing point in "vxvy" mode, otherwise the image centre.
    crop_center_x = vp_x if self.crop_center_mode == "vxvy" else ori_w / 2

    label["camera_mode"] = "virtual" if use_virtual_camera else "roi"

    if use_virtual_camera:
        t0 = perf_counter()
        simul_calib = compute_simul_calib(
            raw_calib,
            (ori_w, ori_h),
            (target_w, target_h),
            crop_center_x,
            vp_y,
            target_fx=self.virtual_fx,
            augment=use_virtual_camera_zoom,
        )
        if timings is not None:
            timings["compute_simul_calib"] = perf_counter() - t0
        # Labels are packed into a single array so image and labels go through the same simulated-camera transform.
        labels_48 = pack_labels_to_48(lb_2d, lb_3d)
        t0 = perf_counter()
        img, labels_48 = apply_simul_transform(
            img, labels_48, simul_calib, raw_calib, (target_w, target_h), augment=self.augment
        )
        if timings is not None:
            timings["apply_simul_transform"] = perf_counter() - t0
        lb_2d, lb_3d = unpack_labels_from_48(labels_48)
        final_calib = {
            "fx": simul_calib["fx"],
            "fy": simul_calib["fy"],
            "cx": simul_calib["cx"],
            "cy": simul_calib["cy"],
            "depth_scale": simul_calib["depth_scale"],
        }
        # The rendered image is already target-sized, so it is treated as the "original" shape.
        sample_ori_shape = (target_h, target_w)
    else:
        # Default: no crop (full image).
        crop_bounds = (0, 0, ori_w, ori_h)
        roi_crop_time = 0.0
        if roi_active:
            t0 = perf_counter()
            crop_bounds = compute_centered_roi_bounds(ori_w, ori_h, roi_w, roi_h, crop_center_x, vp_y)
            crop_x1, crop_y1, crop_x2, crop_y2 = crop_bounds
            img = img[crop_y1:crop_y2, crop_x1:crop_x2]
            sample_ori_shape = (crop_y2 - crop_y1, crop_x2 - crop_x1)
            roi_crop_time = perf_counter() - t0

        t0 = perf_counter()
        calib_for_resize = adjust_calib_for_roi_crop(raw_calib, ori_w, ori_h, crop_bounds)
        img = _resize_ground3d_image_in_steps(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
        final_calib = build_final_resized_calib(
            calib_for_resize["focal_u"],
            calib_for_resize["focal_v"],
            calib_for_resize["cu"],
            calib_for_resize["cv"],
            calib_for_resize["src_w"],
            calib_for_resize["src_h"],
            target_w,
            target_h,
            self.virtual_fx,
            distort_coeffs=calib_for_resize["distort_coeffs"],
        )
        if timings is not None:
            timings["roi_resize"] = perf_counter() - t0
            if roi_crop_time > 0:
                timings["roi_crop"] = roi_crop_time
        if roi_active:
            t0 = perf_counter()
            lb_2d, lb_3d = remap_labels_to_roi(lb_2d, lb_3d, ori_w, ori_h, crop_bounds)
            if timings is not None:
                timings["remap_labels_to_roi"] = perf_counter() - t0
        t0 = perf_counter()
        lb_3d = normalize_roi_depth(lb_3d, final_calib["fx"], self.virtual_fx)
        if timings is not None:
            timings["normalize_roi_depth"] = perf_counter() - t0

    label["ori_shape"] = sample_ori_shape
    label["resized_shape"] = (target_h, target_w)
    # NOTE(review): a plain (h_ratio, w_ratio) pair, presumably not the ((gain, gain), (pad_w, pad_h)) layout
    # other Ultralytics datasets use — confirm downstream consumers expect this.
    label["ratio_pad"] = (target_h / sample_ori_shape[0], target_w / sample_ori_shape[1])
    label["img"] = img
    label["calib"] = final_calib
    if timings is not None:
        # Sum of the stage timings recorded above (computed before this key is added).
        timings["total_sample_build"] = sum(timings.values())
        label["_profile_timing"] = timings
    if self.rect:
        label["rect_shape"] = self.batch_shapes[self.batch[index]]

    label.update(lb_2d)
    label["labels_3d"] = lb_3d
    return self.update_labels_info(label)
|
||
|
||
def build_transforms(self, hyp=None):
    """Return the Ground3D transform pipeline: optional HSV jitter followed by ``Format``.

    Geometric augmentations (Mosaic, RandomPerspective, MixUp, CutMix, RandomFlip) are deliberately absent because
    the 3D labels live in camera space, not image space. No LetterBox is needed either: ``get_image_and_label``
    already resizes every image to the target size.
    """
    augmenting = self.augment and hyp
    pipeline = [RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v)] if augmenting else []
    pipeline.append(
        Format(
            bbox_format="xywh",
            normalize=True,
            batch_idx=True,
            bgr=hyp.bgr if augmenting else 0.0,
        )
    )
    return Compose(pipeline)
|
||
|
||
def update_labels_info(self, label):
    """Run the parent label post-processing while carrying ``labels_3d``, ``calib`` and ``camera_mode`` through.

    These keys are removed before calling the parent (missing ones become ``None``) and re-attached afterwards.
    """
    preserved = {key: label.pop(key, None) for key in ("labels_3d", "calib", "camera_mode")}
    label = super().update_labels_info(label)
    label.update(preserved)
    return label
|
||
|
||
@staticmethod
|
||
def _empty_edge_gt(nl: int) -> dict[str, torch.Tensor]:
|
||
"""Return empty precomputed edge GT tensors aligned with the object dimension."""
|
||
return {
|
||
"edge_faces_points_2d": torch.zeros((nl, 4, 5, 2), dtype=torch.float32),
|
||
"edge_faces_depths": torch.zeros((nl, 4, 5), dtype=torch.float32),
|
||
"edge_faces_valid": torch.zeros((nl, 4), dtype=torch.bool),
|
||
"edge_partial_points_2d": torch.zeros((nl, 5, 2), dtype=torch.float32),
|
||
"edge_partial_depths": torch.zeros((nl, 5), dtype=torch.float32),
|
||
"edge_partial_face_type": torch.full((nl,), -1, dtype=torch.long),
|
||
"edge_partial_valid": torch.zeros((nl,), dtype=torch.bool),
|
||
}
|
||
|
||
def _build_edge_gt(self, labels: dict[str, Any]) -> None:
|
||
"""Precompute GT edge supervision on CPU inside dataloader workers."""
|
||
nl = int(len(labels.get("labels_3d", [])))
|
||
labels.update(self._empty_edge_gt(nl))
|
||
if not self.precompute_edge_gt or nl == 0:
|
||
return
|
||
|
||
calib = labels.get("calib")
|
||
if not isinstance(calib, dict):
|
||
return
|
||
|
||
labels_3d = labels["labels_3d"]
|
||
cls = labels["cls"]
|
||
bboxes = labels["bboxes"]
|
||
_, img_h, img_w = labels["img"].shape
|
||
|
||
labels_3d_np = labels_3d.cpu().numpy() if isinstance(labels_3d, torch.Tensor) else np.asarray(labels_3d)
|
||
cls_np = cls.view(-1).cpu().numpy().astype(int) if isinstance(cls, torch.Tensor) else np.asarray(cls).reshape(-1).astype(int)
|
||
bboxes_np = bboxes.cpu().numpy() if isinstance(bboxes, torch.Tensor) else np.asarray(bboxes)
|
||
|
||
xyxy = np.zeros((nl, 4), dtype=np.float32)
|
||
xyxy[:, 0] = (bboxes_np[:, 0] - bboxes_np[:, 2] / 2) * img_w
|
||
xyxy[:, 1] = (bboxes_np[:, 1] - bboxes_np[:, 3] / 2) * img_h
|
||
xyxy[:, 2] = (bboxes_np[:, 0] + bboxes_np[:, 2] / 2) * img_w
|
||
xyxy[:, 3] = (bboxes_np[:, 1] + bboxes_np[:, 3] / 2) * img_h
|
||
|
||
for obj_idx in range(nl):
|
||
target_42 = labels_3d_np[obj_idx]
|
||
cls_id = int(cls_np[obj_idx])
|
||
bbox_xyxy = xyxy[obj_idx]
|
||
|
||
partial_edge = decode_cut_partial_side_edge_from_gt(
|
||
target_42,
|
||
cls_id,
|
||
calib,
|
||
int(img_w),
|
||
int(img_h),
|
||
self.face_3d_classes,
|
||
self.complete_3d_classes,
|
||
bbox_xyxy=bbox_xyxy,
|
||
)
|
||
if partial_edge is not None:
|
||
labels["edge_partial_points_2d"][obj_idx] = torch.from_numpy(partial_edge["points_2d"])
|
||
labels["edge_partial_depths"][obj_idx] = torch.from_numpy(partial_edge["depths"])
|
||
labels["edge_partial_face_type"][obj_idx] = int(partial_edge["face_type"])
|
||
labels["edge_partial_valid"][obj_idx] = True
|
||
|
||
for face_type in range(4):
|
||
decoded = partial_edge if partial_edge is not None and int(partial_edge["face_type"]) == face_type else None
|
||
if decoded is None:
|
||
decoded = decode_visible_face_edge_from_gt(
|
||
target_42,
|
||
cls_id,
|
||
calib,
|
||
int(img_w),
|
||
int(img_h),
|
||
self.face_3d_classes,
|
||
self.complete_3d_classes,
|
||
face_type=face_type,
|
||
score_thr=self.face_visibility_score_thresh,
|
||
bbox_xyxy=bbox_xyxy,
|
||
)
|
||
if decoded is None:
|
||
continue
|
||
labels["edge_faces_points_2d"][obj_idx, face_type] = torch.from_numpy(decoded["points_2d"])
|
||
labels["edge_faces_depths"][obj_idx, face_type] = torch.from_numpy(decoded["depths"])
|
||
labels["edge_faces_valid"][obj_idx, face_type] = True
|
||
|
||
def __getitem__(self, index):
|
||
"""Return transformed label with post-augmentation min_wh filtering for both 2D and 3D."""
|
||
if not self.skip_bad_images:
|
||
return self._load_transformed_ground3d_sample(int(index))
|
||
|
||
last_error = None
|
||
start_index = int(index)
|
||
for offset in range(self.ni):
|
||
sample_index = (start_index + offset) % self.ni
|
||
try:
|
||
return self._load_transformed_ground3d_sample(sample_index)
|
||
except Ground3DImageReadError as exc:
|
||
last_error = exc
|
||
continue
|
||
|
||
raise RuntimeError(
|
||
f"{self.prefix}Failed to load a valid Ground3D image after checking all {self.ni} samples."
|
||
) from last_error
|
||
|
||
|
||
class YOLOMultiModalDataset(YOLODataset):
    """Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.

    This class extends YOLODataset to add text information for multi-modal model training, enabling models to process
    both image and text data.

    Methods:
        update_labels_info: Add text information for multi-modal model training.
        build_transforms: Enhance data transformations with text augmentation.

    Examples:
        >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
        >>> batch = next(iter(dataset))
        >>> print(batch.keys())  # Should include 'texts'
    """

    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
        """Initialize a YOLOMultiModalDataset.

        Args:
            data (dict, optional): Dataset configuration dictionary.
            task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label: dict) -> dict:
        """Add text information for multi-modal model training.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances and texts.
        """
        labels = super().update_labels_info(label)
        # NOTE: some categories are concatenated with their synonyms by `/`.
        # NOTE: and `RandomLoadText` would randomly select one of them if there are multiple words.
        labels["texts"] = [v.split("/") for v in self.data["names"].values()]
        return labels

    def build_transforms(self, hyp: dict | None = None) -> Compose:
        """Enhance data transformations with optional text augmentation for multi-modal training.

        Args:
            hyp (dict, optional): Hyperparameters for transforms.

        Returns:
            (Compose): Composed transforms including text augmentation if applicable.
        """
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now.
            # NOTE: this implementation is different from official yoloe,
            # the strategy of selecting negative is restricted in one dataset,
            # while official pre-saved neg embeddings from all datasets at once.
            transform = RandomLoadText(
                max_samples=min(self.data["nc"], 80),
                padding=True,
                padding_value=self._get_neg_texts(self.category_freq),
            )
            # Insert before the final transform of the pipeline.
            transforms.insert(-1, transform)
        return transforms

    @property
    def category_names(self):
        """Return category names for the dataset.

        Returns:
            (set[str]): Set of class names, with `/`-separated synonyms split and stripped.
        """
        names = self.data["names"].values()
        return {n.strip() for name in names for n in name.split("/")}  # category names

    @property
    def category_freq(self):
        """Return how many label instances mention each (synonym-split) category name across the dataset."""
        texts = [v.split("/") for v in self.data["names"].values()]
        category_freq = defaultdict(int)
        for label in self.labels:
            for c in label["cls"].squeeze(-1):  # to check
                text = texts[int(c)]
                for t in text:
                    t = t.strip()
                    category_freq[t] += 1
        return category_freq

    @staticmethod
    def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
        """Get negative text samples based on a frequency threshold.

        A category qualifies when its count is at least ``min(max_count, threshold)``; capping by the maximum count
        guarantees the most frequent categories qualify even when every count is below ``threshold``.

        Args:
            category_freq (dict): Mapping from category name to instance count.
            threshold (int): Frequency cutoff. Previously ignored in favour of a hard-coded 100 (the default).

        Returns:
            (list[str]): Qualifying category names, in ``category_freq`` iteration order.
        """
        cutoff = min(max(category_freq.values()), threshold)
        return [k for k, v in category_freq.items() if v >= cutoff]
|
||
|
||
|
||
class GroundingDataset(YOLODataset):
|
||
"""Dataset class for object detection tasks using annotations from a JSON file in grounding format.
|
||
|
||
This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than the standard
|
||
YOLO format text files.
|
||
|
||
Attributes:
|
||
json_file (str): Path to the JSON file containing annotations.
|
||
|
||
Methods:
|
||
get_img_files: Return empty list as image files are read in get_labels.
|
||
get_labels: Load annotations from a JSON file and prepare them for training.
|
||
build_transforms: Configure augmentations for training with optional text loading.
|
||
|
||
Examples:
|
||
>>> dataset = GroundingDataset(img_path="path/to/images", json_file="annotations.json", task="detect")
|
||
>>> len(dataset) # Number of valid images with annotations
|
||
"""
|
||
|
||
def __init__(self, *args, task: str = "detect", json_file: str = "", max_samples: int = 80, **kwargs):
    """Create a grounding dataset backed by a JSON annotation file.

    Args:
        json_file (str): Path to the grounding-format JSON annotation file.
        task (str): Either 'detect' or 'segment'; other values are rejected.
        max_samples (int): Maximum number of text samples for text augmentation.
        *args (Any): Forwarded to the parent class.
        **kwargs (Any): Forwarded to the parent class; must not contain ``data`` because ``data={"channels": 3}``
            is always passed explicitly.
    """
    supported_tasks = {"detect", "segment"}
    assert task in supported_tasks, "GroundingDataset currently only supports `detect` and `segment` tasks"
    # Assigned before the parent constructor runs, presumably because it loads labels via methods that read them.
    self.json_file, self.max_samples = json_file, max_samples
    super().__init__(*args, task=task, data={"channels": 3}, **kwargs)
|
||
|
||
def get_img_files(self, img_path: str) -> list:
    """Return no image files up front; image paths are discovered from the JSON annotations instead.

    Args:
        img_path (str): Image directory (unused here).

    Returns:
        (list): Always an empty list.
    """
    return list()
|
||
|
||
def verify_labels(self, labels: list[dict[str, Any]]) -> None:
    """Check the total box count against the known size of recognized grounding datasets.

    The first known dataset name contained in ``self.json_file`` (checked in the order listed below, so the
    ``_segm`` variants win over their prefixes) determines the expected count.

    Args:
        labels (list[dict[str, Any]]): Label dicts; each ``label["bboxes"]`` must expose ``.shape[0]``.

    Raises:
        AssertionError: If a recognized dataset's instance count differs from the expected one.

    Notes:
        Unrecognized datasets only produce a warning and are not verified.
    """
    known_counts = {
        "final_mixed_train_no_coco_segm": 3662412,
        "final_mixed_train_no_coco": 3681235,
        "final_flickr_separateGT_train_segm": 638214,
        "final_flickr_separateGT_train": 640704,
    }
    total = sum(lb["bboxes"].shape[0] for lb in labels)
    matched = next((name for name in known_counts if name in self.json_file), None)
    if matched is None:
        LOGGER.warning(f"Skipping instance count verification for unrecognized dataset '{self.json_file}'")
        return
    expected = known_counts[matched]
    assert total == expected, f"'{self.json_file}' has {total} instances, expected {expected}."
|
||
|
||
def cache_labels(self, path: Path = Path("./labels.cache")) -> dict[str, Any]:
    """Load annotations from a JSON file, filter, and normalize bounding boxes for each image.

    Images whose file does not exist under ``self.img_path`` are skipped. Crowd annotations, non-positive boxes and
    annotations with an empty caption span are dropped; category names come from the caption spans in
    ``tokens_positive``.

    Args:
        path (Path): Path where to save the cache file.

    Returns:
        (dict[str, Any]): Dictionary containing cached labels and related information.
    """
    x = {"labels": []}
    LOGGER.info("Loading annotation file...")
    # JSON interchange text is UTF-8 (RFC 8259); don't depend on the platform's locale default encoding.
    with open(self.json_file, encoding="utf-8") as f:
        annotations = json.load(f)
    images = {f"{im['id']:d}": im for im in annotations["images"]}
    img_to_anns = defaultdict(list)
    for ann in annotations["annotations"]:
        img_to_anns[ann["image_id"]].append(ann)
    for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
        img = images[f"{img_id:d}"]
        h, w, file_name = img["height"], img["width"], img["file_name"]
        im_file = Path(self.img_path) / file_name
        if not im_file.exists():
            continue
        self.im_files.append(str(im_file))
        bboxes = []
        segments = []
        cat2id = {}
        texts = []
        for ann in anns:
            if ann["iscrowd"]:
                continue
            # COCO xywh (top-left) -> normalized xywh (center).
            box = np.array(ann["bbox"], dtype=np.float32)
            box[:2] += box[2:] / 2
            box[[0, 2]] /= float(w)
            box[[1, 3]] /= float(h)
            if box[2] <= 0 or box[3] <= 0:
                continue

            caption = img["caption"]
            cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]).lower().strip()
            if not cat_name:
                continue

            if cat_name not in cat2id:
                cat2id[cat_name] = len(cat2id)
                texts.append([cat_name])
            cls = cat2id[cat_name]  # class
            box = [cls, *box.tolist()]
            # Skip exact duplicates (same class and box).
            if box not in bboxes:
                bboxes.append(box)
                if ann.get("segmentation") is not None:
                    if len(ann["segmentation"]) == 0:
                        segments.append(box)
                        continue
                    elif len(ann["segmentation"]) > 1:
                        s = merge_multi_segment(ann["segmentation"])
                        s = (np.concatenate(s, axis=0) / np.array([w, h], dtype=np.float32)).reshape(-1).tolist()
                    else:
                        s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
                        s = (
                            (np.array(s, dtype=np.float32).reshape(-1, 2) / np.array([w, h], dtype=np.float32))
                            .reshape(-1)
                            .tolist()
                        )
                    s = [cls, *s]
                    segments.append(s)
        lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)

        if segments:
            # Boxes are re-derived from the segments when segments are present.
            classes = np.array([seg[0] for seg in segments], dtype=np.float32)
            segments = [np.array(seg[1:], dtype=np.float32).reshape(-1, 2) for seg in segments]  # (cls, xy1...)
            lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
        lb = np.array(lb, dtype=np.float32)

        x["labels"].append(
            {
                "im_file": im_file,
                "shape": (h, w),
                "cls": lb[:, 0:1],  # n, 1
                "bboxes": lb[:, 1:],  # n, 4
                "segments": segments,
                "normalized": True,
                "bbox_format": "xywh",
                "texts": texts,
            }
        )
    x["hash"] = get_hash(self.json_file)
    save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
    return x
|
||
|
||
def get_labels(self) -> list[dict]:
    """Load labels from cache or generate them from JSON file.

    The cache lives next to ``self.json_file`` with a ``.cache`` suffix. It is reused only when its stored version
    equals ``DATASET_CACHE_VERSION`` and its stored hash equals ``get_hash(self.json_file)``. Otherwise (or when it is
    missing / cannot be unpickled) it is rebuilt via ``self.cache_labels``.

    Returns:
        (list[dict]): List of label dictionaries, each containing information about an image and its annotations.
    """
    cache_path = Path(self.json_file).with_suffix(".cache")
    try:
        cache = load_dataset_cache_file(cache_path)  # attempt to load a *.cache file
        assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
        assert cache["hash"] == get_hash(self.json_file)  # identical hash
    except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
        # Missing, stale or unreadable cache: regenerate it from the JSON annotations
        cache = self.cache_labels(cache_path)
    for key in ("hash", "version"):  # drop cache metadata, keep only label content
        cache.pop(key)
    labels = cache["labels"]
    self.verify_labels(labels)
    self.im_files = [str(label["im_file"]) for label in labels]
    if LOCAL_RANK in {-1, 0}:
        LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
    return labels
|
||
|
||
def build_transforms(self, hyp: dict | None = None) -> Compose:
    """Build the parent transform pipeline and, when augmenting, add random text sampling.

    Args:
        hyp (dict, optional): Hyperparameters forwarded unchanged to the parent ``build_transforms``.

    Returns:
        (Compose): The parent pipeline. In augment mode a ``RandomLoadText`` step is added with ``insert(-1, ...)``
            (presumably placing it just before the final transform — confirm against ``Compose.insert``).
    """
    pipeline = super().build_transforms(hyp)
    if not self.augment:
        return pipeline

    # NOTE: hard-coded the args for now.
    # NOTE: this implementation is different from official yoloe,
    # the strategy of selecting negative is restricted in one dataset,
    # while official pre-saved neg embeddings from all datasets at once.
    sample_cap = min(self.max_samples, 80)
    negative_texts = self._get_neg_texts(self.category_freq)
    text_loader = RandomLoadText(max_samples=sample_cap, padding=True, padding_value=negative_texts)
    pipeline.insert(-1, text_loader)
    return pipeline
|
||
|
||
@property
def category_names(self):
    """Gather every distinct category name (whitespace-stripped) found in ``self.labels``."""
    names = set()
    for label in self.labels:
        for text in label["texts"]:
            names.update(t.strip() for t in text)
    return names
|
||
|
||
@property
def category_freq(self):
    """Count occurrences of each whitespace-stripped category name across ``self.labels``.

    Returns:
        (defaultdict[str, int]): Category name -> number of times it appears in the labels' ``texts``.
    """
    counts = defaultdict(int)
    stripped_names = (t.strip() for label in self.labels for text in label["texts"] for t in text)
    for name in stripped_names:
        counts[name] += 1
    return counts
|
||
|
||
@staticmethod
def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
    """Get negative text samples based on frequency threshold.

    The effective cut-off is ``min(threshold, highest count)``. This guarantees that at least the most frequent
    category is returned even when no category reaches ``threshold``.

    Args:
        category_freq (dict): Mapping from category name to its occurrence count.
        threshold (int): Minimum count for a category to be used as a negative text (capped at the highest count).

    Returns:
        (list[str]): Category names whose count is at least the effective cut-off, in ``category_freq`` order.
            Empty when ``category_freq`` is empty.
    """
    if not category_freq:
        return []  # max() would raise ValueError on an empty mapping
    cutoff = min(max(category_freq.values()), threshold)  # honour the caller's threshold instead of a fixed 100
    return [k for k, v in category_freq.items() if v >= cutoff]
|
||
|
||
|
||
class YOLOConcatDataset(ConcatDataset):
    """Concatenation of several datasets that share YOLODataset's batch collation.

    Combine existing datasets for YOLO training while keeping one collate function for all of them.

    Methods:
        collate_fn: Static method delegating batch collation to ``YOLODataset.collate_fn``.
        close_mosaic: Forward ``close_mosaic(hyp)`` to every member dataset that defines it.

    Examples:
        >>> dataset1 = YOLODataset(...)
        >>> dataset2 = YOLODataset(...)
        >>> combined_dataset = YOLOConcatDataset([dataset1, dataset2])
    """

    @staticmethod
    def collate_fn(batch: list[dict]) -> dict:
        """Merge per-sample dictionaries into a single batch dictionary.

        Args:
            batch (list[dict]): Samples produced by the member datasets.

        Returns:
            (dict): Batch assembled by ``YOLODataset.collate_fn``.
        """
        collate = YOLODataset.collate_fn
        return collate(batch)

    def close_mosaic(self, hyp: dict) -> None:
        """Ask each member dataset to disable mosaic, copy_paste, mixup and cutmix augmentations.

        Members that do not define ``close_mosaic`` are left untouched.

        Args:
            hyp (dict): Hyperparameters passed to each member's ``close_mosaic``.
        """
        for member in self.datasets:
            if hasattr(member, "close_mosaic"):
                member.close_mosaic(hyp)
|
||
|
||
|
||
# TODO: support semantic segmentation (this class is only a placeholder for now)
class SemanticDataset(BaseDataset):
    """Placeholder dataset for semantic segmentation; no task-specific behavior is implemented yet."""

    def __init__(self):
        """Create the placeholder by calling the parent initializer with no arguments.

        NOTE(review): confirm that ``BaseDataset.__init__`` can be called without arguments.
        """
        super().__init__()
|
||
|
||
|
||
class ClassificationDataset:
    """Dataset class for image classification tasks wrapping torchvision ImageFolder functionality.

    This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently
    handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
    to speed up training.

    Attributes:
        cache_ram (bool): Indicates if caching in RAM is enabled (always reset to False in ``__init__``).
        cache_disk (bool): Indicates if caching on disk is enabled.
        samples (list): A list of lists, each containing the path to an image, its class index, path to its .npy cache
            file (always computed, only used when caching on disk), and the loaded image array (only filled when
            caching in RAM, otherwise None).
        torch_transforms (callable): PyTorch transforms to be applied to the images.
        root (str): Root directory of the dataset.
        prefix (str): Prefix for logging and cache filenames.

    Methods:
        __getitem__: Return transformed image and class index for the given sample index.
        __len__: Return the total number of samples in the dataset.
        verify_images: Verify all images in dataset.
    """

    def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
        """Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.

        Args:
            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
                parameters, and cache settings.
            augment (bool, optional): Whether to apply augmentations to the dataset.
            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification.
        """
        import torchvision  # scope for faster 'import ultralytics'

        # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
        if TORCHVISION_0_18:  # 'allow_empty' argument first introduced in torchvision 0.18
            self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
        else:
            self.base = torchvision.datasets.ImageFolder(root=root)
        self.samples = self.base.samples
        self.root = self.base.root

        # Initialize attributes
        if augment and args.fraction < 1.0:  # reduce training fraction
            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
        if self.cache_ram:
            LOGGER.warning(
                "Classification `cache_ram` training has known memory leak in "
                "https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`."
            )
            self.cache_ram = False
        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
        self.samples = self.verify_images()  # filter out bad images
        self.samples = [[*x, Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
        self.torch_transforms = (
            classify_augmentations(
                size=args.imgsz,
                scale=scale,
                hflip=args.fliplr,
                vflip=args.flipud,
                erasing=args.erasing,
                auto_augment=args.auto_augment,
                hsv_h=args.hsv_h,
                hsv_s=args.hsv_s,
                hsv_v=args.hsv_v,
            )
            if augment
            else classify_transforms(size=args.imgsz)
        )

    def __getitem__(self, i: int) -> dict:
        """Return transformed image and class index for the given sample index.

        Args:
            i (int): Index of the sample to retrieve.

        Returns:
            (dict): Dictionary containing the image and its class index.
        """
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:  # effectively unused: __init__ always resets cache_ram to False
            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self) -> list[tuple]:
        """Verify all images in dataset, reusing a ``<root>.cache`` file when it is still valid.

        The cache is accepted only if its version equals ``DATASET_CACHE_VERSION`` and its hash matches the current
        sample paths. Otherwise every sample is checked with ``verify_image`` in a thread pool and the cache is
        rewritten.

        Returns:
            (list[tuple]): List of valid samples after verification.
        """
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")  # *.cache file path

        try:
            check_file_speeds([file for (file, _) in self.samples[:5]], prefix=self.prefix)  # check image read speeds
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash([s[0] for s in self.samples])  # identical hash
            nf, nc, n, samples = cache.pop("results")  # found, corrupt, number of valid samples, valid samples
            if LOCAL_RANK in {-1, 0}:
                d = f"{desc} {nf} images, {nc} corrupt"
                TQDM(None, desc=d, total=n, initial=n)
                if cache["msgs"]:
                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
            return samples

        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
            # Run scan if *.cache retrieval failed (missing, stale, or pickled against incompatible modules)
            nf, nc, msgs, samples, cache = 0, 0, [], [], {}
            with ThreadPool(NUM_THREADS) as pool:
                results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
                pbar = TQDM(results, desc=desc, total=len(self.samples))
                for sample, nf_f, nc_f, msg in pbar:
                    if nf_f:
                        samples.append(sample)
                    if msg:
                        msgs.append(msg)
                    nf += nf_f
                    nc += nc_f
                    pbar.desc = f"{desc} {nf} images, {nc} corrupt"
                pbar.close()
            if msgs:
                LOGGER.info("\n".join(msgs))
            cache["hash"] = get_hash([s[0] for s in self.samples])
            cache["results"] = nf, nc, len(samples), samples
            cache["msgs"] = msgs  # warnings
            save_dataset_cache_file(self.prefix, path, cache, DATASET_CACHE_VERSION)
            return samples
|