Files
yolov26_3d/tools/clean_ground2d_dataset.py
2026-06-24 09:35:46 +08:00

418 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Clean ground 2D dataset labels and remove corrupt image/label pairs.

This script matches the duplicate-label definition used by
``ultralytics.data.utils.verify_image_label_ground``:

    [class_id, x_center, y_center, width, height, difficulty1 + difficulty2]

Rows that map to the same parsed tuple are duplicates; the first raw row is kept.

By default the script is a dry run. Pass ``--apply`` to rewrite label files,
delete corrupt image/label pairs, and update split list files.
"""
from __future__ import annotations
import argparse
import os
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
import numpy as np
from PIL import Image
try:
from tqdm import tqdm
except ImportError:
tqdm = None
# Lower-case image format names recognised by this script. The set is used both
# to classify split-list entries by file suffix and to validate the format
# reported by PIL for an opened image.
IMG_FORMATS = {
    "avif", "bmp", "dng", "heic", "heif", "jp2", "jpeg",
    "jpeg2000", "jpg", "mpo", "png", "tif", "tiff", "webp",
}
def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        The populated namespace with ``data``, ``splits``, ``apply``,
        ``no_update_lists``, ``no_backup``, ``report`` and ``progress_interval``.
    """
    # Each row is (flag, keyword arguments forwarded to ``add_argument``).
    option_table = (
        (
            "--data",
            {
                "default": "ultralytics/cfg/datasets/mono2d_ground.yaml",
                "help": "Ground 2D dataset YAML path.",
            },
        ),
        (
            "--splits",
            {
                "nargs": "+",
                "default": ["train", "val"],
                "help": "Dataset YAML split keys to scan, for example: train val test.",
            },
        ),
        (
            "--apply",
            {
                "action": "store_true",
                "help": "Actually modify files. Without this flag, only prints what would change.",
            },
        ),
        (
            "--no-update-lists",
            {
                "action": "store_true",
                "help": "Do not remove corrupt entries from split list files.",
            },
        ),
        (
            "--no-backup",
            {
                "action": "store_true",
                "help": "Do not create .bak files before rewriting labels/list files.",
            },
        ),
        (
            "--report",
            {
                "default": "",
                "help": "Optional report file path. Defaults to no report file.",
            },
        ),
        (
            "--progress-interval",
            {
                "type": int,
                "default": 100,
                "help": "Print progress every N entries when tqdm is not installed. Set 0 to disable fallback progress.",
            },
        ),
    )

    cli = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def read_lines(path: Path) -> list[str]:
    """Return the UTF-8 decoded contents of ``path`` as a list of lines without line endings."""
    with path.open(encoding="utf-8") as handle:
        content = handle.read()
    return content.splitlines()
def strip_yaml_comment(value: str) -> str:
    """Strip a trailing YAML comment from ``value`` and trim surrounding whitespace.

    A ``#`` only starts a comment when it is the first character or is preceded
    by whitespace, and never while inside a quoted scalar. This keeps values
    such as ``/data/set#1`` or ``"a # b"`` intact instead of truncating them at
    the first ``#``.

    Quote tracking is deliberately simple: a quote opens only at the start of a
    token, and escaped quotes inside a quoted scalar are not interpreted.

    Args:
        value: The raw text that follows ``key:`` on a YAML line.

    Returns:
        The value without its comment, stripped of leading/trailing whitespace.
    """
    open_quote = None
    for idx, char in enumerate(value):
        previous = value[idx - 1] if idx else ""
        if open_quote is not None:
            if char == open_quote:
                open_quote = None
        elif char in "'\"" and (idx == 0 or previous.isspace() or previous in "[,"):
            # Only a quote at the start of a token opens a quoted scalar;
            # an apostrophe inside a plain word must not hide a later comment.
            open_quote = char
        elif char == "#" and (idx == 0 or previous.isspace()):
            return value[:idx].strip()
    return value.strip()
def parse_scalar(value: str) -> object:
    """Convert the text after ``key:`` into a Python value.

    Returns ``None`` for an empty value, a ``bool`` for ``true``/``false``
    (case-insensitive), a list for ``[a, b, ...]`` (items become floats when
    possible, otherwise unquoted strings), an ``int`` or ``float`` for numeric
    text, and otherwise the string with surrounding quotes removed.
    """
    text = strip_yaml_comment(value)
    if not text:
        return None

    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    if text.startswith("[") and text.endswith("]"):
        elements = []
        for chunk in text[1:-1].split(","):
            token = chunk.strip()
            if not token:
                continue
            try:
                elements.append(float(token))
            except ValueError:
                elements.append(token.strip("'\""))
        return elements

    # Try the narrowest numeric type first, then fall back to a plain string.
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            continue
    return text.strip("'\"")
def load_ground_yaml(path: Path) -> dict[str, object]:
    """Parse the small YAML subset used by mono2d_ground.yaml without external dependencies.

    Only two levels are understood: top-level ``key: value`` pairs, and a
    top-level ``key:`` with no value followed by space-indented ``key: value``
    children, which are collected into a nested dict.
    """
    config: dict[str, object] = {}
    open_section: str | None = None

    for raw in read_lines(path):
        text = raw.strip()
        if not text or text.startswith("#"):
            continue

        name, colon, rest = text.partition(":")
        if not colon:
            # Lines without a colon carry no key/value information.
            continue

        if not raw.startswith(" "):
            # Top-level key: an empty value opens a nested mapping.
            scalar = parse_scalar(rest)
            if scalar is None:
                config[name] = {}
                open_section = name
            else:
                config[name] = scalar
                open_section = None
            continue

        section = config.get(open_section) if open_section else None
        if isinstance(section, dict):
            section[name.strip()] = parse_scalar(rest)

    return config
def img2label_path(img_path: str) -> str:
    """Derive a label path by swapping the last ``/images/`` directory for ``/labels/`` and the extension for ``.txt``."""
    images_dir = f"{os.sep}images{os.sep}"
    labels_dir = f"{os.sep}labels{os.sep}"
    swapped = labels_dir.join(img_path.rsplit(images_dir, 1))
    # Drop everything after the last dot; keep the whole string when there is none.
    stem, dot, _ = swapped.rpartition(".")
    return (stem if dot else swapped) + ".txt"
def write_lines(path: Path, lines: list[str], backup: bool) -> None:
    """Write ``lines`` to ``path`` as UTF-8, optionally backing up the existing file first.

    The backup is a byte-for-byte copy named ``<name>.bak.<YYYYmmdd_HHMMSS>``
    next to the original. Copying bytes (rather than decoding and re-encoding
    text) preserves the original line endings and works for files that are not
    valid UTF-8. If a backup with the same timestamp already exists, a numeric
    suffix is appended so an earlier backup is never overwritten.

    Args:
        path: File to (re)write.
        lines: Lines to write; each is terminated with ``\\n``. An empty list
            produces an empty file.
        backup: When true and ``path`` exists, create the backup before writing.
    """
    if backup and path.exists():
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = path.with_name(f"{path.name}.bak.{stamp}")
        counter = 1
        # Two rewrites within the same second must not clobber the first backup,
        # which holds the only copy of the original contents.
        while backup_path.exists():
            backup_path = path.with_name(f"{path.name}.bak.{stamp}.{counter}")
            counter += 1
        backup_path.write_bytes(path.read_bytes())
    path.write_text("\n".join(lines) + ("\n" if lines else ""), encoding="utf-8")
def split_entries(split_path: Path) -> list[str]:
    """Return the non-blank lines of a split list file with surrounding whitespace removed."""
    trimmed = (raw.strip() for raw in read_lines(split_path))
    return [entry for entry in trimmed if entry]
def expand_split_entry(entry: str, split_path: Path) -> str:
    """Mirror YOLOGroundDataset.get_img_files list-file expansion.

    An entry starting with ``./`` is rewritten to start with the directory that
    contains the split list file; any other entry is returned unchanged.
    """
    if not entry.startswith("./"):
        return entry
    list_dir = str(split_path.parent)
    return list_dir + os.sep + entry[2:]
def resolve_split_path(value: str | None, data_dir: Path) -> Path | None:
if not value:
return None
path = Path(value)
return path if path.is_absolute() else data_dir / path
def ground_image_path_from_label_path(path: str, image_root: Path, label_path: bool = False) -> Path:
    """Mirror YOLOGroundDataset._ground_image_path_from_label_path.

    The entry is made relative (everything after the first path component that
    starts with ``detection2d``/``2ddetection``, case-insensitively, or else the
    path without its root) and re-anchored under ``image_root``.

    For image entries (``label_path=False``) that re-anchored path is returned
    as is. For label entries the function looks for an existing image: first
    ``.jpg`` at the same relative location, then ``.jpg`` with ``labels``
    components replaced by ``images``, then the other known suffixes. When none
    exists, the ``images``-based ``.jpg`` path is returned.
    """
    source = Path(path)
    segments = source.parts

    marker_at = None
    for position, segment in enumerate(segments):
        if segment.lower().startswith(("detection2d", "2ddetection")):
            marker_at = position
            break

    if marker_at is not None and marker_at + 1 < len(segments):
        relative = list(segments[marker_at + 1 :])
    elif source.is_absolute():
        # Drop the filesystem root so the remainder can be joined to image_root.
        relative = list(Path(*segments[1:]).parts)
    else:
        relative = list(source.parts)

    if not label_path:
        return image_root.joinpath(*relative)

    same_place = image_root.joinpath(*relative).with_suffix(".jpg")
    if same_place.exists():
        return same_place

    mirrored = [("images" if segment == "labels" else segment) for segment in relative]
    fallback = image_root.joinpath(*mirrored).with_suffix(".jpg")
    if fallback.exists():
        return fallback

    # Probe png first, then the remaining formats alphabetically, and jpeg last.
    other_suffixes = ("png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg")
    for extension in other_suffixes:
        candidate = fallback.with_suffix(f".{extension}")
        if candidate.exists():
            return candidate
    return fallback
def label_path_for_entry(entry: str) -> Path:
    """Return the label file for a list entry: the entry itself for ``.txt``, otherwise its derived label path."""
    entry_path = Path(entry)
    if entry_path.suffix.lower() == ".txt":
        return entry_path
    return Path(img2label_path(entry))
def parse_known_label_key(parts: list[str], class_map: dict[str, int]) -> tuple[float, ...] | None:
if len(parts) < 7 or parts[0] not in class_map:
return None
coords = [float(x) for x in parts[1:5]]
difficulty = float(parts[5]) + float(parts[6])
row = np.array([class_map[parts[0]], *coords, difficulty], dtype=np.float32)
return tuple(float(x) for x in row)
def deduplicate_label_file(label_path: Path, class_map: dict[str, int]) -> tuple[int, list[str], list[str]]:
    """Return duplicate count, cleaned raw lines, and warnings.

    A row is dropped only when its parsed key was already produced by an
    earlier row. Blank lines, rows that yield no key, and rows that fail to
    parse (which also add a warning) are always kept. A missing file yields
    ``(0, [], [])``.
    """
    if not label_path.exists():
        return 0, [], []

    raw_lines = read_lines(label_path)
    seen_keys = set()
    kept: list[str] = []
    notes: list[str] = []

    for line_no, raw in enumerate(raw_lines, start=1):
        fields = raw.split()
        if not fields:
            kept.append(raw)
            continue
        try:
            key = parse_known_label_key(fields, class_map)
        except ValueError as exc:
            notes.append(f"{label_path}:{line_no}: cannot parse label ({exc})")
            kept.append(raw)
            continue
        if key is not None:
            if key in seen_keys:
                # Later occurrence of an already-seen label: drop it.
                continue
            seen_keys.add(key)
        kept.append(raw)

    return len(raw_lines) - len(kept), kept, notes
def exif_size(img: Image.Image) -> tuple[int, int]:
    """Return the image size, with the two dimensions swapped for JPEGs whose EXIF orientation is 6 or 8.

    Any error while reading EXIF data is ignored and the unmodified size is returned.
    """
    dims = img.size
    if img.format != "JPEG":
        return dims

    orientation_tag = 274  # EXIF tag id queried for the orientation value
    try:
        exif = img.getexif()
        if exif and exif.get(orientation_tag, None) in {6, 8}:
            dims = (dims[1], dims[0])
    except Exception:
        # Best effort: unreadable EXIF must not fail the size lookup.
        pass
    return dims
def image_corruption_reason(image_path: Path) -> str | None:
    """Explain why ``image_path`` is considered corrupt, or return ``None`` when it passes all checks.

    The checks are: the file opens and verifies with PIL, both dimensions are
    at least 10 pixels, the detected format is in ``IMG_FORMATS``, and a file
    with a ``.jpg``/``.jpeg`` suffix ends with the JPEG end-of-image marker.
    Any exception raised along the way (including a missing file) is reported
    as the reason.

    Returns:
        A non-empty description of the problem, or ``None`` if none was found.
    """
    try:
        with Image.open(image_path) as im:
            im.verify()
            shape = exif_size(im)
            if shape[0] <= 9 or shape[1] <= 9:
                return f"image size {shape} <10 pixels"
            if (im.format or "").lower() not in IMG_FORMATS:
                return f"invalid image format {im.format}"
        if image_path.suffix.lower() in {".jpg", ".jpeg"}:
            with image_path.open("rb") as f:
                # A complete JPEG stream ends with the two-byte EOI marker.
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":
                    return "corrupt JPEG end marker"
    except Exception as exc:
        # An exception without a message would stringify to "", which callers
        # testing the reason for truthiness would mistake for "no problem".
        return str(exc) or type(exc).__name__
    return None
def unlink_if_exists(path: Path) -> bool:
    """Delete ``path`` when it exists and report whether a deletion happened."""
    if not path.exists():
        return False
    path.unlink()
    return True
def progress_iter(items: list[str], desc: str, interval: int):
    """Yield ``items`` while showing progress.

    Uses a tqdm bar when tqdm was imported. Otherwise prints ``desc: i/total``
    for the first item, every ``interval``-th item and the last item; an
    ``interval`` of 0 or less disables the printed fallback.
    """
    count = len(items)

    if tqdm is None:
        for position, item in enumerate(items, start=1):
            if interval > 0 and (position == 1 or position == count or position % interval == 0):
                print(f"{desc}: {position}/{count}", flush=True)
            yield item
        return

    yield from tqdm(items, total=count, desc=desc, unit="entry")
def main() -> int:
    """Scan the requested splits, report duplicate label rows and corrupt images, and optionally fix them.

    Without ``--apply`` nothing on disk is modified. With ``--apply`` the
    function rewrites label files that contain duplicate rows, deletes corrupt
    images together with their label files, and (unless ``--no-update-lists``)
    rewrites split list files without the corrupt entries. The summary and
    per-file report are printed and, when ``--report`` is given, also written
    to that file.

    Returns:
        ``0`` on completion.

    Raises:
        SystemExit: If the dataset YAML is missing, or lacks a usable ``path``
            or ``class_map`` entry.
    """
    args = parse_args()
    data_file = Path(args.data)
    if not data_file.is_file():
        raise SystemExit(f"{data_file} does not exist or is not a file.")
    data = load_ground_yaml(data_file)
    data_dir = data_file.parent

    # Validate required keys up front so configuration mistakes produce a clear
    # message instead of a KeyError/TypeError traceback.
    root_value = data.get("path")
    if not root_value or isinstance(root_value, (dict, list)):
        raise SystemExit(f"{data_file} does not contain a dataset root 'path'.")
    image_root = Path(str(root_value))
    class_map = data.get("class_map", {})
    if not isinstance(class_map, dict) or not class_map:
        raise SystemExit(f"{data_file} does not contain class_map; this cleaner is for ground 2D labels.")

    update_lists = args.apply and not args.no_update_lists
    backup = not args.no_backup
    report_lines = []
    # A label file may be reachable from several entries/splits; process it once.
    label_files_seen: set[Path] = set()
    stats = {
        "images": 0,
        "labels": 0,
        "duplicate_rows": 0,
        "labels_rewritten": 0,
        "corrupt_images": 0,
        "images_removed": 0,
        "labels_removed": 0,
        "list_entries_removed": 0,
    }

    for split in args.splits:
        split_value = data.get(split)
        if split_value and not isinstance(split_value, str):
            # Only a single list-file path per split is supported.
            report_lines.append(f"[invalid split] {split}: expected a list file path, got {split_value!r}")
            continue
        split_path = resolve_split_path(split_value, data_dir)
        if split_path is None:
            continue
        if not split_path.exists():
            report_lines.append(f"[missing split] {split}: {split_path}")
            continue

        entries = split_entries(split_path)
        kept_entries = []
        removed_from_split = 0
        for entry in progress_iter(entries, f"Scanning {split}", args.progress_interval):
            expanded_entry = expand_split_entry(entry, split_path)
            suffix = Path(expanded_entry).suffix.lower().lstrip(".")
            if suffix in IMG_FORMATS:
                image_path = ground_image_path_from_label_path(expanded_entry, image_root, label_path=False)
                label_path = label_path_for_entry(expanded_entry)
            elif suffix == "txt":
                label_path = Path(expanded_entry)
                image_path = ground_image_path_from_label_path(expanded_entry, image_root, label_path=True)
            else:
                # Unknown entry types are reported but never dropped from the list.
                report_lines.append(f"[unsupported entry] {split}: {entry}")
                kept_entries.append(entry)
                continue

            stats["images"] += 1
            reason = image_corruption_reason(image_path)
            # Compare against None so an empty reason string still counts as corrupt.
            if reason is not None:
                stats["corrupt_images"] += 1
                removed_from_split += 1
                report_lines.append(f"[corrupt] {image_path} | {reason}")
                report_lines.append(f" label: {label_path}")
                if args.apply:
                    if unlink_if_exists(image_path):
                        stats["images_removed"] += 1
                    if unlink_if_exists(label_path):
                        stats["labels_removed"] += 1
                continue

            kept_entries.append(entry)
            if label_path in label_files_seen:
                continue
            label_files_seen.add(label_path)
            if label_path.exists():
                stats["labels"] += 1
            dup_count, cleaned_lines, warnings = deduplicate_label_file(label_path, class_map)
            report_lines.extend(f"[warning] {warning}" for warning in warnings)
            if dup_count:
                stats["duplicate_rows"] += dup_count
                report_lines.append(f"[duplicates] {label_path}: remove {dup_count}")
                if args.apply:
                    write_lines(label_path, cleaned_lines, backup=backup)
                    stats["labels_rewritten"] += 1

        if update_lists and removed_from_split:
            write_lines(split_path, kept_entries, backup=backup)
            stats["list_entries_removed"] += removed_from_split
            report_lines.append(f"[list updated] {split_path}: removed {removed_from_split} corrupt entries")

    mode = "APPLIED" if args.apply else "DRY RUN"
    summary = [
        f"Mode: {mode}",
        f"Images checked: {stats['images']}",
        f"Label files checked: {stats['labels']}",
        f"Duplicate label rows found: {stats['duplicate_rows']}",
        f"Label files rewritten: {stats['labels_rewritten']}",
        f"Corrupt images found: {stats['corrupt_images']}",
        f"Images removed: {stats['images_removed']}",
        f"Labels removed with corrupt images: {stats['labels_removed']}",
        f"Split list entries removed: {stats['list_entries_removed']}",
    ]
    output = "\n".join(summary + (["", *report_lines] if report_lines else []))
    print(output)
    if args.report:
        report_path = Path(args.report)
        report_path.parent.mkdir(parents=True, exist_ok=True)
        report_path.write_text(output + "\n", encoding="utf-8")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)