#!/usr/bin/env python3
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Clean ground 2D dataset labels and corrupt image pairs.

This script matches the duplicate-label definition used by
``ultralytics.data.utils.verify_image_label_ground``:

    [class_id, x_center, y_center, width, height, difficulty1 + difficulty2]

Rows that map to the same parsed tuple are duplicates; the first raw row is kept.
By default the script is a dry run. Pass ``--apply`` to rewrite label files,
delete corrupt image/label pairs, and update split list files.
"""

from __future__ import annotations

import argparse
import os
from collections import OrderedDict
from datetime import datetime
from pathlib import Path

import numpy as np

try:
    from PIL import Image
except ImportError:
    # Pillow is only needed for image verification. Keeping the import soft lets
    # ``--help`` and the label-only helpers work; image_corruption_reason() fails
    # fast with an actionable message when Pillow is missing.
    Image = None

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None

IMG_FORMATS = {
    "avif",
    "bmp",
    "dng",
    "heic",
    "heif",
    "jp2",
    "jpeg",
    "jpeg2000",
    "jpg",
    "mpo",
    "png",
    "tif",
    "tiff",
    "webp",
}


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the cleaner.

    Returns:
        The parsed namespace with ``data``, ``splits``, ``apply``, ``no_update_lists``,
        ``no_backup``, ``report`` and ``progress_interval`` attributes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--data",
        default="ultralytics/cfg/datasets/mono2d_ground.yaml",
        help="Ground 2D dataset YAML path.",
    )
    parser.add_argument(
        "--splits",
        nargs="+",
        default=["train", "val"],
        help="Dataset YAML split keys to scan, for example: train val test.",
    )
    parser.add_argument(
        "--apply",
        action="store_true",
        help="Actually modify files. Without this flag, only prints what would change.",
    )
    parser.add_argument(
        "--no-update-lists",
        action="store_true",
        help="Do not remove corrupt entries from split list files.",
    )
    parser.add_argument(
        "--no-backup",
        action="store_true",
        help="Do not create .bak files before rewriting labels/list files.",
    )
    parser.add_argument(
        "--report",
        default="",
        help="Optional report file path. Defaults to no report file.",
    )
    parser.add_argument(
        "--progress-interval",
        type=int,
        default=100,
        help="Print progress every N entries when tqdm is not installed. "
        "Set 0 to disable fallback progress.",
    )
    return parser.parse_args()


def read_lines(path: Path) -> list[str]:
    """Return the UTF-8 text of ``path`` split into lines without line terminators."""
    return path.read_text(encoding="utf-8").splitlines()


def strip_yaml_comment(value: str) -> str:
    """Strip a trailing YAML comment while preserving path strings.

    A ``#`` starts a comment only when it is the first character or is preceded by
    whitespace, and only outside a quoted section. This keeps values such as
    ``/data/run#2`` or ``'a # b'`` intact.

    Args:
        value: Raw text found after the ``key:`` separator.

    Returns:
        The value without its comment, stripped of surrounding whitespace.
    """
    quote = ""
    previous = " "  # Treat the start of the value like whitespace.
    for idx, char in enumerate(value):
        if quote:
            if char == quote:
                quote = ""
        elif char in "'\"" and (previous.isspace() or previous in "[,"):
            # Quotes only open at the start of a scalar or flow-list item, so an
            # apostrophe inside a plain word does not hide a later comment.
            quote = char
        elif char == "#" and previous.isspace():
            return value[:idx].strip()
        previous = char
    return value.strip()


def parse_scalar(value: str) -> object:
    """Convert a raw YAML value into a Python object.

    Args:
        value: Raw text found after the ``key:`` separator; may carry a comment.

    Returns:
        ``None`` for an empty value, a ``bool`` for ``true``/``false``, a list for an
        inline ``[a, b]`` sequence (items become ``float`` when possible, otherwise
        unquoted strings), an ``int`` or ``float`` for numbers, and otherwise the
        string with surrounding quotes removed.
    """
    value = strip_yaml_comment(value)
    if not value:
        return None
    if value.lower() in {"true", "false"}:
        return value.lower() == "true"
    if value.startswith("[") and value.endswith("]"):
        items = [x.strip() for x in value[1:-1].split(",") if x.strip()]
        parsed = []
        for item in items:
            try:
                parsed.append(float(item))
            except ValueError:
                parsed.append(item.strip("'\""))
        return parsed
    try:
        return int(value)
    except ValueError:
        try:
            return float(value)
        except ValueError:
            return value.strip("'\"")


def load_ground_yaml(path: Path) -> dict[str, object]:
    """Load the small YAML subset used by mono2d_ground.yaml without external dependencies.

    Only top-level ``key: value`` pairs and one level of nested ``key: value``
    mappings are understood; anything else is ignored.

    Args:
        path: YAML file to read.

    Returns:
        A dict of top-level keys. A key with an empty value maps to a dict that is
        filled from the indented lines following it.
    """
    data: dict[str, object] = {}
    current_map: str | None = None
    for raw_line in read_lines(path):
        if not raw_line.strip() or raw_line.lstrip().startswith("#"):
            continue
        indent = len(raw_line) - len(raw_line.lstrip(" "))
        line = raw_line.strip()
        key, sep, value = line.partition(":")
        if not sep:
            continue
        # Keys may be quoted or padded in YAML; normalise them so lookups such as
        # ``parts[0] in class_map`` see the bare name.
        key = key.strip().strip("'\"")
        if indent == 0:
            parsed = parse_scalar(value)
            if parsed is None:
                # An empty value opens a nested mapping.
                data[key] = {}
                current_map = key
            else:
                data[key] = parsed
                current_map = None
        elif current_map and isinstance(data.get(current_map), dict):
            data[current_map][key] = parse_scalar(value)
    return data


def img2label_path(img_path: str) -> str:
    """Map an image path to its label path.

    The last ``/images/`` directory component is replaced by ``/labels/`` and the
    file extension by ``.txt``.
    """
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"
    return sb.join(img_path.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt"


def write_lines(path: Path, lines: list[str], backup: bool) -> None:
    """Write ``lines`` to ``path``, optionally keeping a timestamped backup.

    Args:
        path: File to (re)write as UTF-8 text.
        lines: Lines without terminators; an empty list produces an empty file.
        backup: When true and ``path`` exists, copy it to ``<name>.bak.<timestamp>``
            first.
    """
    if backup and path.exists():
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = path.with_name(f"{path.name}.bak.{stamp}")
        # Copy raw bytes so the backup is exact (line endings and encoding untouched).
        backup_path.write_bytes(path.read_bytes())
    path.write_text("\n".join(lines) + ("\n" if lines else ""), encoding="utf-8")


def split_entries(split_path: Path) -> list[str]:
    """Return the non-blank, stripped lines of a split list file."""
    return [line.strip() for line in read_lines(split_path) if line.strip()]


def expand_split_entry(entry: str, split_path: Path) -> str:
    """Mirror YOLOGroundDataset.get_img_files list-file expansion.

    A leading ``./`` is replaced by the directory of the split list file; other
    entries are returned unchanged.
    """
    if entry.startswith("./"):
        return entry.replace("./", str(split_path.parent) + os.sep, 1)
    return entry


def resolve_split_path(value: str | None, data_dir: Path) -> Path | None:
    """Resolve a split value from the dataset YAML to a path.

    Args:
        value: Split path from the YAML, or a falsy value when the split is absent.
        data_dir: Directory that relative split paths are resolved against.

    Returns:
        ``None`` for a falsy ``value``, otherwise an absolute path as-is or a
        relative path joined onto ``data_dir``.
    """
    if not value:
        return None
    path = Path(value)
    return path if path.is_absolute() else data_dir / path


def ground_image_path_from_label_path(path: str, image_root: Path, label_path: bool = False) -> Path:
    """Mirror YOLOGroundDataset._ground_image_path_from_label_path.

    Args:
        path: Image or label path taken from a split list.
        image_root: Root directory the relative part of ``path`` is joined onto.
        label_path: Whether ``path`` is a label file that must be mapped to an image.

    Returns:
        The resolved image path. For label entries the first existing candidate is
        returned; if none exists, the ``.jpg`` path under ``images`` is returned.
    """
    parts = Path(path).parts
    # Everything after the first "detection2d*" / "2ddetection*" component is
    # treated as the path relative to image_root.
    detection_root_index = next(
        (i for i, part in enumerate(parts) if part.lower().startswith(("detection2d", "2ddetection"))),
        None,
    )
    if detection_root_index is not None and detection_root_index + 1 < len(parts):
        rel_parts = list(parts[detection_root_index + 1 :])
    else:
        rel_path = Path(path) if not Path(path).is_absolute() else Path(*parts[1:])
        rel_parts = list(rel_path.parts)

    if not label_path:
        return image_root.joinpath(*rel_parts)

    # Label entry: try the same location first, then the "images" sibling tree.
    image_path = image_root.joinpath(*rel_parts).with_suffix(".jpg")
    if image_path.exists():
        return image_path
    image_rel_parts = ["images" if part == "labels" else part for part in rel_parts]
    image_path = image_root.joinpath(*image_rel_parts).with_suffix(".jpg")
    if image_path.exists():
        return image_path
    for suffix in ("png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg"):
        candidate = image_path.with_suffix(f".{suffix}")
        if candidate.exists():
            return candidate
    return image_path


def label_path_for_entry(entry: str) -> Path:
    """Return the label path for a split entry that is either a label or an image."""
    suffix = Path(entry).suffix.lower().lstrip(".")
    if suffix == "txt":
        return Path(entry)
    return Path(img2label_path(entry))


def parse_known_label_key(parts: list[str], class_map: dict[str, int]) -> tuple[float, ...] | None:
    """Build the duplicate-detection key for one label row.

    Args:
        parts: Whitespace-split fields of a label row.
        class_map: Mapping from class name (first field) to class id.

    Returns:
        ``(class_id, x_center, y_center, width, height, difficulty1 + difficulty2)``
        rounded through float32, or ``None`` when the row has fewer than seven fields
        or an unknown class name.

    Raises:
        ValueError: If a numeric field cannot be parsed as a float.
    """
    if len(parts) < 7 or parts[0] not in class_map:
        return None
    coords = [float(x) for x in parts[1:5]]
    difficulty = float(parts[5]) + float(parts[6])
    # Round-trip through float32 so rows that are equal at that precision compare equal.
    row = np.array([class_map[parts[0]], *coords, difficulty], dtype=np.float32)
    return tuple(float(x) for x in row)


def deduplicate_label_file(label_path: Path, class_map: dict[str, int]) -> tuple[int, list[str], list[str]]:
    """Return duplicate count, cleaned raw lines, and warnings.

    Blank rows, rows that cannot be parsed and rows with an unknown class are always
    kept; only later repeats of an already seen key are dropped.

    Args:
        label_path: Label file to inspect; a missing file yields ``(0, [], [])``.
        class_map: Mapping from class name to class id.

    Returns:
        A tuple ``(removed_row_count, kept_raw_lines, warnings)``.
    """
    if not label_path.exists():
        return 0, [], []

    raw_lines = read_lines(label_path)
    seen: OrderedDict[tuple[float, ...], int] = OrderedDict()
    keep = [True] * len(raw_lines)
    warnings = []
    for idx, line in enumerate(raw_lines):
        stripped = line.strip()
        if not stripped:
            continue
        parts = stripped.split()
        try:
            key = parse_known_label_key(parts, class_map)
        except ValueError as exc:
            warnings.append(f"{label_path}:{idx + 1}: cannot parse label ({exc})")
            continue
        if key is None:
            continue
        if key in seen:
            keep[idx] = False
        else:
            seen[key] = idx

    cleaned = [line for line, should_keep in zip(raw_lines, keep) if should_keep]
    return len(raw_lines) - len(cleaned), cleaned, warnings


def exif_size(img: Image.Image) -> tuple[int, int]:
    """Return the image size, swapping width and height for EXIF orientations 6 and 8."""
    size = img.size
    if img.format == "JPEG":
        try:
            if exif := img.getexif():
                if exif.get(274, None) in {6, 8}:  # 274 is the EXIF orientation tag.
                    size = size[1], size[0]
        except Exception:
            # Best effort: unreadable EXIF data leaves the reported size unchanged.
            pass
    return size


def image_corruption_reason(image_path: Path) -> str | None:
    """Explain why an image is unusable, or return ``None`` when it passes all checks.

    The checks are: Pillow can open and verify the file, both dimensions are at least
    10 pixels, the detected format is in ``IMG_FORMATS``, and files with a
    ``.jpg``/``.jpeg`` suffix end with the JPEG end-of-image marker.

    Args:
        image_path: Image file to verify.

    Returns:
        A non-empty reason string for a failing image (including a file that cannot
        be opened), otherwise ``None``.

    Raises:
        RuntimeError: If Pillow is not installed.
    """
    if Image is None:
        # Raised outside the catch-all below so a missing dependency is never
        # reported as image corruption.
        raise RuntimeError("Pillow is required to verify images; install it with 'pip install pillow'.")
    try:
        with Image.open(image_path) as im:
            im.verify()
            shape = exif_size(im)
            if shape[0] <= 9 or shape[1] <= 9:
                return f"image size {shape} <10 pixels"
            if (im.format or "").lower() not in IMG_FORMATS:
                return f"invalid image format {im.format}"
        if image_path.suffix.lower() in {".jpg", ".jpeg"}:
            with image_path.open("rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":
                    return "corrupt JPEG end marker"
    except Exception as exc:
        # Callers test the reason for truthiness, so never return an empty string.
        return str(exc) or type(exc).__name__
    return None


def unlink_if_exists(path: Path) -> bool:
    """Delete ``path`` and report whether a file was actually removed."""
    try:
        path.unlink()
    except FileNotFoundError:
        return False
    return True
def progress_iter(items: list[str], desc: str, interval: int):
    """Yield ``items`` in order while reporting progress.

    Uses tqdm when it is installed. Otherwise prints ``desc: idx/total`` for the first
    item, every ``interval``-th item and the last item.

    Args:
        items: Entries to iterate over.
        desc: Label shown in the progress output.
        interval: Fallback print interval; ``0`` or less disables fallback output.

    Yields:
        Each element of ``items``.
    """
    total = len(items)
    if tqdm is not None:
        yield from tqdm(items, total=total, desc=desc, unit="entry")
        return
    for idx, item in enumerate(items, 1):
        if interval > 0 and (idx == 1 or idx % interval == 0 or idx == total):
            print(f"{desc}: {idx}/{total}", flush=True)
        yield item


def main() -> int:
    """Scan the requested splits, report problems and optionally fix them.

    For every split entry the image is verified; corrupt images are reported and, with
    ``--apply``, deleted together with their label file and dropped from the split
    list. Each remaining label file is de-duplicated once.

    Returns:
        Process exit code ``0``.

    Raises:
        SystemExit: If the dataset YAML lacks ``path`` or a usable ``class_map``.
    """
    args = parse_args()
    data_file = Path(args.data)
    data = load_ground_yaml(data_file)
    data_dir = data_file.parent

    root_value = data.get("path")
    if not root_value:
        raise SystemExit(f"{data_file} does not contain path; cannot locate the image root.")
    # str() guards against a root that the YAML loader parsed as a number.
    image_root = Path(str(root_value))

    class_map = data.get("class_map", {})
    if not isinstance(class_map, dict) or not class_map:
        raise SystemExit(f"{data_file} does not contain class_map; this cleaner is for ground 2D labels.")

    update_lists = args.apply and not args.no_update_lists
    backup = not args.no_backup
    report_lines = []
    label_files_seen: set[Path] = set()
    stats = {
        "images": 0,
        "labels": 0,
        "duplicate_rows": 0,
        "labels_rewritten": 0,
        "corrupt_images": 0,
        "images_removed": 0,
        "labels_removed": 0,
        "list_entries_removed": 0,
    }

    for split in args.splits:
        split_value = data.get(split)
        if not split_value:
            continue
        if not isinstance(split_value, str):
            # Inline lists or numbers cannot be resolved to a single list file.
            report_lines.append(f"[unsupported split] {split}: {split_value!r}")
            continue
        split_path = resolve_split_path(split_value, data_dir)
        if split_path is None:
            continue
        if not split_path.exists():
            report_lines.append(f"[missing split] {split}: {split_path}")
            continue
        if not split_path.is_file():
            report_lines.append(f"[unsupported split] {split}: {split_path} is not a list file")
            continue

        entries = split_entries(split_path)
        kept_entries = []
        removed_from_split = 0
        for entry in progress_iter(entries, f"Scanning {split}", args.progress_interval):
            expanded_entry = expand_split_entry(entry, split_path)
            suffix = Path(expanded_entry).suffix.lower().lstrip(".")
            if suffix in IMG_FORMATS:
                image_path = ground_image_path_from_label_path(expanded_entry, image_root, label_path=False)
                label_path = label_path_for_entry(expanded_entry)
            elif suffix == "txt":
                label_path = Path(expanded_entry)
                image_path = ground_image_path_from_label_path(expanded_entry, image_root, label_path=True)
            else:
                # Unknown entries are reported but left in the split list untouched.
                report_lines.append(f"[unsupported entry] {split}: {entry}")
                kept_entries.append(entry)
                continue

            stats["images"] += 1
            reason = image_corruption_reason(image_path)
            if reason:
                stats["corrupt_images"] += 1
                removed_from_split += 1
                report_lines.append(f"[corrupt] {image_path} | {reason}")
                report_lines.append(f" label: {label_path}")
                if args.apply:
                    if unlink_if_exists(image_path):
                        stats["images_removed"] += 1
                    if unlink_if_exists(label_path):
                        stats["labels_removed"] += 1
                continue

            kept_entries.append(entry)
            # A label file shared by several entries is only de-duplicated once.
            if label_path in label_files_seen:
                continue
            label_files_seen.add(label_path)
            if label_path.exists():
                stats["labels"] += 1
            dup_count, cleaned_lines, warnings = deduplicate_label_file(label_path, class_map)
            report_lines.extend(f"[warning] {warning}" for warning in warnings)
            if dup_count:
                stats["duplicate_rows"] += dup_count
                report_lines.append(f"[duplicates] {label_path}: remove {dup_count}")
                if args.apply:
                    write_lines(label_path, cleaned_lines, backup=backup)
                    stats["labels_rewritten"] += 1

        if update_lists and removed_from_split:
            write_lines(split_path, kept_entries, backup=backup)
            stats["list_entries_removed"] += removed_from_split
            report_lines.append(f"[list updated] {split_path}: removed {removed_from_split} corrupt entries")

    mode = "APPLIED" if args.apply else "DRY RUN"
    summary = [
        f"Mode: {mode}",
        f"Images checked: {stats['images']}",
        f"Label files checked: {stats['labels']}",
        f"Duplicate label rows found: {stats['duplicate_rows']}",
        f"Label files rewritten: {stats['labels_rewritten']}",
        f"Corrupt images found: {stats['corrupt_images']}",
        f"Images removed: {stats['images_removed']}",
        f"Labels removed with corrupt images: {stats['labels_removed']}",
        f"Split list entries removed: {stats['list_entries_removed']}",
    ]
    output = "\n".join(summary + (["", *report_lines] if report_lines else []))
    print(output)

    if args.report:
        report_path = Path(args.report)
        report_path.parent.mkdir(parents=True, exist_ok=True)
        report_path.write_text(output + "\n", encoding="utf-8")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())