418 lines
14 KiB
Python
418 lines
14 KiB
Python
|
|
#!/usr/bin/env python3
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Clean duplicate ground 2D dataset label rows and remove corrupt image/label pairs.

This script matches the duplicate-label definition used by
``ultralytics.data.utils.verify_image_label_ground``:

    [class_id, x_center, y_center, width, height, difficulty1 + difficulty2]

Rows that map to the same parsed tuple are duplicates; the first raw row is kept.
By default the script is a dry run. Pass ``--apply`` to rewrite label files,
delete corrupt image/label pairs, and update split list files.
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import os
|
||
|
|
from collections import OrderedDict
|
||
|
|
from datetime import datetime
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
from PIL import Image
|
||
|
|
|
||
|
|
try:
|
||
|
|
from tqdm import tqdm
|
||
|
|
except ImportError:
|
||
|
|
tqdm = None
|
||
|
|
|
||
|
|
# Lower-case image extensions (no leading dot) that this script accepts as image entries and image formats.
IMG_FORMATS = {
    "avif", "bmp", "dng", "heic", "heif", "jp2", "jpeg",
    "jpeg2000", "jpg", "mpo", "png", "tif", "tiff", "webp",
}
|
||
|
|
|
||
|
|
|
||
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse the process arguments.

    Returns:
        Namespace exposing ``data``, ``splits``, ``apply``, ``no_update_lists``, ``no_backup``,
        ``report`` and ``progress_interval``.
    """
    # One (flag, keyword-arguments) pair per option, registered in declaration order.
    option_table = [
        (
            "--data",
            dict(
                default="ultralytics/cfg/datasets/mono2d_ground.yaml",
                help="Ground 2D dataset YAML path.",
            ),
        ),
        (
            "--splits",
            dict(
                nargs="+",
                default=["train", "val"],
                help="Dataset YAML split keys to scan, for example: train val test.",
            ),
        ),
        (
            "--apply",
            dict(
                action="store_true",
                help="Actually modify files. Without this flag, only prints what would change.",
            ),
        ),
        (
            "--no-update-lists",
            dict(
                action="store_true",
                help="Do not remove corrupt entries from split list files.",
            ),
        ),
        (
            "--no-backup",
            dict(
                action="store_true",
                help="Do not create .bak files before rewriting labels/list files.",
            ),
        ),
        (
            "--report",
            dict(
                default="",
                help="Optional report file path. Defaults to no report file.",
            ),
        ),
        (
            "--progress-interval",
            dict(
                type=int,
                default=100,
                help="Print progress every N entries when tqdm is not installed. Set 0 to disable fallback progress.",
            ),
        ),
    ]

    cli = argparse.ArgumentParser(description=__doc__)
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|
||
|
|
|
||
|
|
|
||
|
|
def read_lines(path: Path) -> list[str]:
    """Read ``path`` as UTF-8 text and return its lines without line terminators."""
    text = path.read_text(encoding="utf-8")
    return text.splitlines()
|
||
|
|
|
||
|
|
|
||
|
|
def strip_yaml_comment(value: str) -> str:
    """Strip a trailing YAML comment from ``value`` while preserving ``#`` inside path strings.

    A ``#`` is treated as the start of a comment only when it begins the value or directly follows
    whitespace. A ``#`` embedded in a token (``/data/run#3``) or inside a leading quoted scalar
    (``"a # b"``) is kept as part of the value.

    Args:
        value: Raw text found after the ``key:`` separator.

    Returns:
        The value with the comment removed and surrounding whitespace stripped.
    """
    stripped = value.strip()

    # When the scalar opens with a quote, skip past its closing quote so '#' inside it is not a comment.
    scan_from = 0
    if stripped[:1] in ("'", '"'):
        closing = stripped.find(stripped[0], 1)
        if closing != -1:
            scan_from = closing + 1

    for pos in range(scan_from, len(stripped)):
        if stripped[pos] == "#" and (pos == 0 or stripped[pos - 1].isspace()):
            return stripped[:pos].strip()
    return stripped
|
||
|
|
|
||
|
|
|
||
|
|
def parse_scalar(value: str) -> object:
    """Convert the text after a ``key:`` separator into a Python value.

    Returns:
        ``None`` for an empty value, a ``bool`` for ``true``/``false`` (any case), a list for an inline
        ``[a, b]`` sequence (items become ``float`` when possible, otherwise unquoted strings), an ``int``
        or ``float`` for numeric text, and otherwise the text with surrounding quotes removed.
    """
    text = strip_yaml_comment(value)
    if not text:
        return None

    lowered = text.lower()
    if lowered in {"true", "false"}:
        return lowered == "true"

    if text.startswith("[") and text.endswith("]"):

        def convert_item(item: str) -> object:
            # Sequence items are numeric when they parse as float, else plain unquoted strings.
            try:
                return float(item)
            except ValueError:
                return item.strip("'\"")

        pieces = (piece.strip() for piece in text[1:-1].split(","))
        return [convert_item(piece) for piece in pieces if piece]

    # Prefer int over float so "3" stays an integer; fall back to the unquoted string.
    for converter in (int, float):
        try:
            return converter(text)
        except ValueError:
            continue
    return text.strip("'\"")
|
||
|
|
|
||
|
|
|
||
|
|
def load_ground_yaml(path: Path) -> dict[str, object]:
    """Load the small YAML subset used by mono2d_ground.yaml without external dependencies.

    Supported syntax: ``key: value`` pairs at column 0, and one level of space-indented ``key: value``
    pairs below a top-level key whose own value is empty. Blank lines, comment lines and lines without
    a ``:`` are ignored. Keys are stripped of surrounding whitespace and quotes at both levels so that
    ``train : list.txt`` and ``"car": 0`` are looked up as ``train`` and ``car``.

    Args:
        path: YAML file to read.

    Returns:
        Mapping of top-level keys to parsed scalars/lists, or to a nested ``dict`` for mapping sections.
    """
    data: dict[str, object] = {}
    current_map: str | None = None
    for raw_line in read_lines(path):
        if not raw_line.strip() or raw_line.lstrip().startswith("#"):
            continue
        key, sep, value = raw_line.strip().partition(":")
        if not sep:
            continue
        # Normalize keys identically for top-level and nested entries.
        key = key.strip().strip("'\"")
        indent = len(raw_line) - len(raw_line.lstrip(" "))
        if indent == 0:
            parsed = parse_scalar(value)
            if parsed is None:
                # A key with no value opens a nested mapping section.
                data[key] = {}
                current_map = key
            else:
                data[key] = parsed
                current_map = None
        elif current_map and isinstance(data.get(current_map), dict):
            data[current_map][key] = parse_scalar(value)
    return data
|
||
|
|
|
||
|
|
|
||
|
|
def img2label_path(img_path: str) -> str:
    """Derive a label path from an image path.

    The last ``<sep>images<sep>`` directory component is swapped for ``<sep>labels<sep>`` and everything
    after the final ``.`` is replaced by ``txt``.
    """
    images_dir = f"{os.sep}images{os.sep}"
    labels_dir = f"{os.sep}labels{os.sep}"
    swapped = labels_dir.join(img_path.rsplit(images_dir, 1))
    without_ext = swapped.rsplit(".", 1)[0]
    return without_ext + ".txt"
|
||
|
|
|
||
|
|
|
||
|
|
def write_lines(path: Path, lines: list[str], backup: bool) -> None:
    """Write ``lines`` to ``path`` as UTF-8, one per line, optionally backing up the previous content.

    Args:
        path: File to (over)write.
        lines: Lines without terminators; an empty list produces an empty file.
        backup: When true and ``path`` already exists, its current text is first copied to a sibling
            named ``<name>.bak.<YYYYmmdd_HHMMSS>``.
    """
    if backup and path.exists():
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        previous_text = path.read_text(encoding="utf-8")
        path.with_name(f"{path.name}.bak.{timestamp}").write_text(previous_text, encoding="utf-8")

    # Every line gets its own terminator, so an empty list yields an empty file.
    payload = "".join(f"{line}\n" for line in lines)
    path.write_text(payload, encoding="utf-8")
|
||
|
|
|
||
|
|
|
||
|
|
def split_entries(split_path: Path) -> list[str]:
    """Return the non-blank lines of a split list file with surrounding whitespace removed."""
    trimmed = map(str.strip, read_lines(split_path))
    return [entry for entry in trimmed if entry]
|
||
|
|
|
||
|
|
|
||
|
|
def expand_split_entry(entry: str, split_path: Path) -> str:
    """Mirror YOLOGroundDataset.get_img_files list-file expansion.

    An entry beginning with ``./`` has that prefix replaced by the split file's parent directory
    followed by ``os.sep``; any other entry is returned unchanged.
    """
    if not entry.startswith("./"):
        return entry
    parent_prefix = str(split_path.parent) + os.sep
    return parent_prefix + entry[2:]
|
||
|
|
|
||
|
|
|
||
|
|
def resolve_split_path(value: str | None, data_dir: Path) -> Path | None:
|
||
|
|
if not value:
|
||
|
|
return None
|
||
|
|
path = Path(value)
|
||
|
|
return path if path.is_absolute() else data_dir / path
|
||
|
|
|
||
|
|
|
||
|
|
def ground_image_path_from_label_path(path: str, image_root: Path, label_path: bool = False) -> Path:
    """Mirror YOLOGroundDataset._ground_image_path_from_label_path.

    The part of ``path`` after the first component starting with ``detection2d``/``2ddetection``
    (case-insensitive) is re-rooted under ``image_root``. If no such component exists, or it is the last
    component, the whole path is used (minus its root when absolute).

    Args:
        path: Image or label path taken from a split list.
        image_root: Directory under which images are looked up.
        label_path: When true, ``path`` is a label file and an existing image with a known extension is
            searched for; when false the re-rooted path is returned as is.

    Returns:
        The resolved image path. For label inputs with no existing candidate this is the ``.jpg`` path
        with ``labels`` components replaced by ``images``.
    """
    source = Path(path)
    parts = source.parts

    # Locate the first detection-root component, if any.
    root_index = None
    for index, part in enumerate(parts):
        if part.lower().startswith(("detection2d", "2ddetection")):
            root_index = index
            break

    if root_index is not None and root_index + 1 < len(parts):
        rel_parts = list(parts[root_index + 1 :])
    elif source.is_absolute():
        rel_parts = list(Path(*parts[1:]).parts)
    else:
        rel_parts = list(source.parts)

    if not label_path:
        return image_root.joinpath(*rel_parts)

    # First try a .jpg sitting at the label's own relative location.
    direct_jpg = image_root.joinpath(*rel_parts).with_suffix(".jpg")
    if direct_jpg.exists():
        return direct_jpg

    # Then swap every "labels" component for "images" and probe the known extensions.
    swapped_parts = [("images" if part == "labels" else part) for part in rel_parts]
    image_path = image_root.joinpath(*swapped_parts).with_suffix(".jpg")
    if image_path.exists():
        return image_path

    fallback_suffixes = ("png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg")
    for suffix in fallback_suffixes:
        candidate = image_path.with_suffix(f".{suffix}")
        if candidate.exists():
            return candidate
    return image_path
|
||
|
|
|
||
|
|
|
||
|
|
def label_path_for_entry(entry: str) -> Path:
    """Return the label file for a split entry.

    Entries whose suffix is ``.txt`` (any case) are already label paths; every other entry is mapped
    through ``img2label_path``.
    """
    already_label = Path(entry).suffix.lower() == ".txt"
    return Path(entry) if already_label else Path(img2label_path(entry))
|
||
|
|
|
||
|
|
|
||
|
|
def parse_known_label_key(parts: list[str], class_map: dict[str, int]) -> tuple[float, ...] | None:
|
||
|
|
if len(parts) < 7 or parts[0] not in class_map:
|
||
|
|
return None
|
||
|
|
coords = [float(x) for x in parts[1:5]]
|
||
|
|
difficulty = float(parts[5]) + float(parts[6])
|
||
|
|
row = np.array([class_map[parts[0]], *coords, difficulty], dtype=np.float32)
|
||
|
|
return tuple(float(x) for x in row)
|
||
|
|
|
||
|
|
|
||
|
|
def deduplicate_label_file(label_path: Path, class_map: dict[str, int]) -> tuple[int, list[str], list[str]]:
    """Return duplicate count, cleaned raw lines, and warnings.

    Rows are compared by the key from ``parse_known_label_key``; the first row for each key is kept and
    later rows with the same key are dropped. Blank lines, rows that yield no key, and rows that fail
    to parse are kept untouched (the latter also produce a warning). A missing file yields ``(0, [], [])``.
    """
    if not label_path.exists():
        return 0, [], []

    raw_lines = read_lines(label_path)
    seen_keys: set[tuple[float, ...]] = set()
    cleaned: list[str] = []
    warnings: list[str] = []

    for line_no, raw in enumerate(raw_lines, start=1):
        tokens = raw.split()
        key = None
        if tokens:
            try:
                key = parse_known_label_key(tokens, class_map)
            except ValueError as exc:
                warnings.append(f"{label_path}:{line_no}: cannot parse label ({exc})")

        if key is not None:
            if key in seen_keys:
                # Later occurrence of an already-seen key: drop this row.
                continue
            seen_keys.add(key)
        cleaned.append(raw)

    return len(raw_lines) - len(cleaned), cleaned, warnings
|
||
|
|
|
||
|
|
|
||
|
|
def exif_size(img: Image.Image) -> tuple[int, int]:
    """Return the image size, with the two dimensions swapped for rotated JPEGs.

    For images whose ``format`` is ``"JPEG"``, EXIF tag 274 (orientation) is read; values 6 and 8 cause
    the reported size to be swapped. Any failure while reading EXIF data leaves the size unchanged.
    """
    size = img.size
    if img.format != "JPEG":
        return size
    try:
        exif = img.getexif()
        if exif and exif.get(274, None) in {6, 8}:
            return size[1], size[0]
    except Exception:
        # Deliberate best effort: unreadable EXIF data must not fail the size lookup.
        pass
    return size
|
||
|
|
|
||
|
|
|
||
|
|
def image_corruption_reason(image_path: Path) -> str | None:
    """Return a non-empty description of why ``image_path`` is unusable, or ``None`` if all checks pass.

    Checks, in order: the file opens and passes ``Image.verify()``; both dimensions reported by
    ``exif_size`` exceed 9 pixels; the detected format is listed in ``IMG_FORMATS``; and, for files with
    a ``.jpg``/``.jpeg`` suffix, the last two bytes are the JPEG end marker ``FF D9``.

    Args:
        image_path: Image file to inspect.

    Returns:
        ``None`` for a usable image, otherwise a reason string. Exceptions raised by any check are
        reported as the reason rather than propagated.
    """
    try:
        with Image.open(image_path) as im:
            im.verify()
            shape = exif_size(im)
            if shape[0] <= 9 or shape[1] <= 9:
                return f"image size {shape} <10 pixels"
            if (im.format or "").lower() not in IMG_FORMATS:
                return f"invalid image format {im.format}"
        if image_path.suffix.lower() in {".jpg", ".jpeg"}:
            with image_path.open("rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":
                    return "corrupt JPEG end marker"
    except Exception as exc:
        # Callers test the result for truthiness, so an exception with an empty message must still
        # produce a non-empty reason; fall back to the exception class name.
        return str(exc) or type(exc).__name__
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def unlink_if_exists(path: Path) -> bool:
    """Delete ``path`` and report whether anything was removed.

    The deletion is attempted directly instead of checking ``exists()`` first, so a file that disappears
    between the check and the removal cannot raise ``FileNotFoundError``.

    Args:
        path: File to delete.

    Returns:
        ``True`` if the file was removed, ``False`` if it was not there.

    Raises:
        OSError: For failures other than the file being absent (for example, permissions).
    """
    try:
        path.unlink()
    except FileNotFoundError:
        return False
    return True
|
||
|
|
|
||
|
|
|
||
|
|
def progress_iter(items: list[str], desc: str, interval: int):
    """Yield ``items`` in order while reporting progress.

    When ``tqdm`` is available the items are wrapped in a tqdm bar. Otherwise a ``desc: i/total`` line is
    printed for the first item, every ``interval``-th item and the last item; an ``interval`` of zero or
    less disables these prints.
    """
    count = len(items)
    if tqdm is None:
        for position, item in enumerate(items, start=1):
            at_boundary = position in (1, count)
            if interval > 0 and (at_boundary or position % interval == 0):
                print(f"{desc}: {position}/{count}", flush=True)
            yield item
    else:
        yield from tqdm(items, total=count, desc=desc, unit="entry")
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> int:
    """Scan the requested splits, report duplicate label rows and corrupt images, and optionally fix them.

    For every entry of every split list the matching image is checked with ``image_corruption_reason``.
    Corrupt entries are reported and, with ``--apply``, the image and label files are deleted and the
    entry is dropped from the split list (unless ``--no-update-lists`` is given). Label files of healthy
    entries are deduplicated once each and, with ``--apply``, rewritten. A summary followed by the
    detailed report is printed and optionally written to ``--report``.

    Returns:
        ``0`` when the scan completes.

    Raises:
        SystemExit: If the dataset YAML lacks a usable ``path`` or ``class_map`` entry.
    """
    args = parse_args()
    data_file = Path(args.data)
    data = load_ground_yaml(data_file)
    data_dir = data_file.parent

    # Fail with a readable message instead of a bare KeyError/TypeError on an incomplete YAML.
    root_value = data.get("path")
    if not isinstance(root_value, str) or not root_value:
        raise SystemExit(f"{data_file} does not contain a dataset root 'path' entry.")
    image_root = Path(root_value)

    class_map = data.get("class_map", {})
    if not isinstance(class_map, dict) or not class_map:
        raise SystemExit(f"{data_file} does not contain class_map; this cleaner is for ground 2D labels.")

    update_lists = args.apply and not args.no_update_lists
    backup = not args.no_backup
    report_lines = []
    label_files_seen: set[Path] = set()
    stats = {
        "images": 0,
        "labels": 0,
        "duplicate_rows": 0,
        "labels_rewritten": 0,
        "corrupt_images": 0,
        "images_removed": 0,
        "labels_removed": 0,
        "list_entries_removed": 0,
    }

    for split in args.splits:
        split_value = data.get(split)
        if not split_value:
            continue
        if not isinstance(split_value, str):
            # Inline lists or nested mappings cannot be resolved to a single list file.
            report_lines.append(f"[unsupported split] {split}: {split_value!r}")
            continue
        split_path = resolve_split_path(split_value, data_dir)
        if split_path is None:
            continue
        if not split_path.exists():
            report_lines.append(f"[missing split] {split}: {split_path}")
            continue

        entries = split_entries(split_path)
        kept_entries = []
        removed_from_split = 0

        for entry in progress_iter(entries, f"Scanning {split}", args.progress_interval):
            expanded_entry = expand_split_entry(entry, split_path)
            suffix = Path(expanded_entry).suffix.lower().lstrip(".")
            if suffix in IMG_FORMATS:
                image_path = ground_image_path_from_label_path(expanded_entry, image_root, label_path=False)
                label_path = label_path_for_entry(expanded_entry)
            elif suffix == "txt":
                label_path = Path(expanded_entry)
                image_path = ground_image_path_from_label_path(expanded_entry, image_root, label_path=True)
            else:
                # Unknown entry types are reported but never dropped from the list.
                report_lines.append(f"[unsupported entry] {split}: {entry}")
                kept_entries.append(entry)
                continue

            stats["images"] += 1
            reason = image_corruption_reason(image_path)
            if reason:
                stats["corrupt_images"] += 1
                removed_from_split += 1
                report_lines.append(f"[corrupt] {image_path} | {reason}")
                report_lines.append(f" label: {label_path}")
                if args.apply:
                    if unlink_if_exists(image_path):
                        stats["images_removed"] += 1
                    if unlink_if_exists(label_path):
                        stats["labels_removed"] += 1
                continue

            kept_entries.append(entry)
            # Deduplicate each label file at most once, even if several entries point at it.
            if label_path in label_files_seen:
                continue
            label_files_seen.add(label_path)
            if label_path.exists():
                stats["labels"] += 1
                dup_count, cleaned_lines, warnings = deduplicate_label_file(label_path, class_map)
                report_lines.extend(f"[warning] {warning}" for warning in warnings)
                if dup_count:
                    stats["duplicate_rows"] += dup_count
                    report_lines.append(f"[duplicates] {label_path}: remove {dup_count}")
                    if args.apply:
                        write_lines(label_path, cleaned_lines, backup=backup)
                        stats["labels_rewritten"] += 1

        if update_lists and removed_from_split:
            write_lines(split_path, kept_entries, backup=backup)
            stats["list_entries_removed"] += removed_from_split
            report_lines.append(f"[list updated] {split_path}: removed {removed_from_split} corrupt entries")

    mode = "APPLIED" if args.apply else "DRY RUN"
    summary = [
        f"Mode: {mode}",
        f"Images checked: {stats['images']}",
        f"Label files checked: {stats['labels']}",
        f"Duplicate label rows found: {stats['duplicate_rows']}",
        f"Label files rewritten: {stats['labels_rewritten']}",
        f"Corrupt images found: {stats['corrupt_images']}",
        f"Images removed: {stats['images_removed']}",
        f"Labels removed with corrupt images: {stats['labels_removed']}",
        f"Split list entries removed: {stats['list_entries_removed']}",
    ]

    output = "\n".join(summary + (["", *report_lines] if report_lines else []))
    print(output)
    if args.report:
        report_path = Path(args.report)
        report_path.parent.mkdir(parents=True, exist_ok=True)
        report_path.write_text(output + "\n", encoding="utf-8")
    return 0
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
|