Files
HSAP/datasets/dms/scripts/stratified_split.py
Chengfang Lu 7c43b44c57 feat: initial HSAP platform
Huaxu Sentinel Active Safety Platform with embedded algorithm code,
Docker Compose setup, and vendored dataset scaffolds for clone-and-run.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 16:59:59 +08:00

505 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
按类别分层划分数据集,避免仅按总量随机切分导致 train/val 类别比例失衡。
YOLO 检测:先按「图像所含类别中最稀有类」决定归属,再对各类别依次划分 val。
分类(文件夹按类):每个类别目录内独立划分 train/val或 train/test
用法示例:
# 预览 DDAW 重划分效果(合并现有 train+val 后重分)
python stratified_split.py yolo --root ../gyp/ddaw_1124 --val-ratio 0.1 --dry-run
# 执行划分(会移动 images/labels 下文件)
python stratified_split.py yolo --root ../gyp/ddaw_1124 --val-ratio 0.1 --seed 42
# 分类数据:从 train 按类划出 val
python stratified_split.py classify --root ../gyp/isa_class_0116 --val-ratio 0.1 --dry-run
"""
from __future__ import annotations
import argparse
import random
import shutil
from collections import Counter, defaultdict
from pathlib import Path
# File suffixes accepted as images. Callers compare ``Path.suffix`` against this
# set verbatim (case-sensitive), which is why some upper-case spellings are
# listed explicitly.
IMG_EXTS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".webp",
    ".JPG",
    ".JPEG",
    ".PNG",
}
def _read_yolo_classes(label_path: Path) -> set[int]:
    """Return the distinct class ids that appear in one YOLO label file.

    The class id is the first whitespace-separated token of each non-blank
    line. Lines whose first token is not an integer are skipped. A path that
    is not an existing file yields an empty set.
    """
    found: set[int] = set()
    if not label_path.is_file():
        return found
    text = label_path.read_text(encoding="utf-8", errors="ignore")
    for raw in text.splitlines():
        fields = raw.split()
        if not fields:
            # Blank or whitespace-only line.
            continue
        try:
            found.add(int(fields[0]))
        except ValueError:
            continue
    return found
def _find_image(images_dir: Path, stem: str) -> Path | None:
    """Return the image file in ``images_dir`` whose name is ``stem`` + a known suffix.

    Args:
        images_dir: Directory to look in.
        stem: File name without suffix (normally the label file's stem).

    Returns:
        The first existing ``images_dir / f"{stem}{ext}"`` for ``ext`` in
        ``IMG_EXTS``, or ``None`` when no candidate exists.
    """
    # IMG_EXTS is a set of strings, so its iteration order can change between
    # interpreter runs. Probe in a fixed order (lower-case suffixes first, then
    # alphabetical) so that the same file is picked every time when several
    # images share one stem.
    for ext in sorted(IMG_EXTS, key=lambda e: (e != e.lower(), e)):
        candidate = images_dir / f"{stem}{ext}"
        if candidate.is_file():
            return candidate
    return None
def collect_yolo_samples(root: Path, splits: tuple[str, ...]) -> list[tuple[str, set[int]]]:
    """Collect one ``(stem, class_ids)`` entry per label file in the pooled splits.

    Args:
        root: Dataset root; label files are read from ``root / "labels" / split``.
        splits: Split names to pool, visited in the given order. Splits whose
            labels directory does not exist are skipped.

    Returns:
        A list of ``(stem, class_ids)`` tuples. When the same stem occurs in
        more than one split, only the first occurrence (in ``splits`` order)
        is kept. ``class_ids`` is empty for a label file with no parsable lines.
    """
    samples: list[tuple[str, set[int]]] = []
    seen: set[str] = set()
    for split in splits:
        labels_dir = root / "labels" / split
        if not labels_dir.is_dir():
            continue
        # glob() yields entries in filesystem order, which differs between
        # machines. Sorting makes the sample list, and therefore any seeded
        # shuffle applied to it later, reproducible.
        for label_path in sorted(labels_dir.glob("*.txt")):
            stem = label_path.stem
            if stem in seen:
                continue
            seen.add(stem)
            samples.append((stem, _read_yolo_classes(label_path)))
    return samples
def val_count_for_class(
    n: int,
    val_ratio: float,
    min_val_per_class: int,
    min_train_per_class: int,
    rare_class_train_floor: int,
) -> int:
    """Return how many of a class's ``n`` unassigned samples go to val (the rest go to train).

    Args:
        n: Number of still-unassigned samples of the class.
        val_ratio: Target val fraction, used when the class is not rare.
        min_val_per_class: Lower bound on the val count (ignored when <= 0 for
            non-rare classes); for rare classes it is the val count itself.
        min_train_per_class: Number of samples that must remain in train.
        rare_class_train_floor: Classes with ``n`` at or below this are "rare"
            and skip the ratio entirely.

    Returns:
        An integer in ``[0, n]``.
    """
    if n <= 0:
        return 0
    # The train floor applies to every class. Previously it was only checked on
    # the rare-class path, so a class with
    # rare_class_train_floor < n <= min_train_per_class still lost samples to val.
    if n <= min_train_per_class:
        return 0
    if n <= rare_class_train_floor:
        # Rare class: take the fixed minimum, but never dip into the train floor.
        n_val = min(min_val_per_class, n - min_train_per_class)
    else:
        n_val = int(round(n * val_ratio))
        if min_val_per_class > 0:
            n_val = max(min_val_per_class, n_val)
        if min_train_per_class > 0:
            n_val = min(n_val, n - min_train_per_class)
    # Clamp on both paths so that odd settings (e.g. a negative minimum) can
    # never yield a negative or oversized slice index for the caller.
    return max(0, min(n_val, n))
def stratified_assign(
    samples: list[tuple[str, set[int]]],
    val_ratio: float,
    seed: int,
    min_val_per_class: int = 1,
    min_train_per_class: int = 1,
    rare_class_train_floor: int = 5,
) -> dict[str, str]:
    """Assign every sample stem to ``"train"`` or ``"val"``, class by class.

    Classes are visited from the one contained in the fewest images to the one
    contained in the most. For each class, the images that contain it and have
    not been assigned yet are shuffled with a ``random.Random(seed)`` generator
    and split according to ``val_count_for_class``. Samples without any class
    id always go to train.

    Returns:
        Mapping of stem -> ``"train"`` / ``"val"`` covering every input sample.
    """
    shuffler = random.Random(seed)

    # class id -> stems of the images that contain it, in sample order
    members: dict[int, list[str]] = defaultdict(list)
    for stem, labels in samples:
        for cls_id in labels:
            members[cls_id].append(stem)

    # Rarest class first; ties keep the order in which classes were first seen.
    visit_order = sorted(members, key=lambda k: len(set(members[k])))

    split_of: dict[str, str] = {}
    for cls_id in visit_order:
        pending = [s for s in dict.fromkeys(members[cls_id]) if s not in split_of]
        if not pending:
            continue
        shuffler.shuffle(pending)
        cut = val_count_for_class(
            len(pending),
            val_ratio,
            min_val_per_class,
            min_train_per_class,
            rare_class_train_floor,
        )
        split_of.update((s, "val") for s in pending[:cut])
        split_of.update((s, "train") for s in pending[cut:])

    # Images without labels first, then anything still unassigned: all train.
    for stem, labels in samples:
        if not labels:
            split_of.setdefault(stem, "train")
    for stem, _ in samples:
        split_of.setdefault(stem, "train")
    return split_of
def yolo_class_stats(root: Path, split: str) -> tuple[int, Counter, Counter]:
    """Summarise the label files under ``root / "labels" / split``.

    Returns:
        ``(n_files, instances, images)`` where ``n_files`` is the number of
        ``*.txt`` label files, ``instances[c]`` is the number of label lines of
        class ``c`` and ``images[c]`` is the number of label files that contain
        class ``c`` at least once. All three are zero/empty when the directory
        does not exist.
    """
    instances: Counter = Counter()
    images: Counter = Counter()
    labels_dir = root / "labels" / split
    if not labels_dir.is_dir():
        return 0, instances, images

    n_files = 0
    for n_files, label_path in enumerate(labels_dir.glob("*.txt"), start=1):
        present: set[int] = set()
        text = label_path.read_text(encoding="utf-8", errors="ignore")
        for raw in text.splitlines():
            fields = raw.split()
            if not fields:
                continue
            try:
                cls_id = int(fields[0])
            except ValueError:
                # First token is not a class id; ignore the line.
                continue
            instances[cls_id] += 1
            present.add(cls_id)
        # One image hit per class, regardless of how many boxes it has.
        images.update(present)
    return n_files, instances, images
def print_yolo_stats(root: Path, title: str) -> None:
    """Print per-class instance and image counts for the train and val splits.

    Splits that contain no label files are omitted from the output.
    """
    print(f"\n=== {title} ===")
    for split_name in ("train", "val"):
        total, per_inst, per_img = yolo_class_stats(root, split_name)
        if not total:
            continue
        print(f" [{split_name}] {total} images")
        for cls_id in sorted(per_inst.keys() | per_img.keys()):
            with_cls = per_img[cls_id]
            share = with_cls / total * 100 if total else 0
            print(
                f" cls {cls_id}: instances={per_inst[cls_id]}, images={with_cls} "
                f"({with_cls}/{total}={share:.1f}% of split images)"
            )
def apply_yolo_split(
    root: Path,
    assignment: dict[str, str],
    pool_splits: tuple[str, ...] = ("train", "val"),
    dry_run: bool = False,
) -> None:
    """Move label files (and their images) into the split chosen in ``assignment``.

    Args:
        root: Dataset root containing ``labels/<split>`` and ``images/<split>``.
        assignment: Mapping of label stem -> target split name.
        pool_splits: Splits searched for the current location of each stem.
            When a stem exists in several of them, the first one wins.
        dry_run: When true, only print the number of planned moves; nothing on
            disk is created, moved or deleted.

    Stems in ``assignment`` that have no label file in ``pool_splits`` are
    skipped. A label without a matching image is moved on its own. An existing
    file at the destination is replaced.
    """
    # stem -> (image path or None, label path): where each sample currently lives
    located: dict[str, tuple[Path | None, Path]] = {}
    for split in pool_splits:
        labels_dir = root / "labels" / split
        images_dir = root / "images" / split
        if not labels_dir.is_dir():
            continue
        has_images = images_dir.is_dir()
        for label_path in labels_dir.glob("*.txt"):
            stem = label_path.stem
            if stem in located:
                continue
            img = _find_image(images_dir, stem) if has_images else None
            located[stem] = (img, label_path)

    moves: list[tuple[Path, Path]] = []
    for stem, target_split in assignment.items():
        entry = located.get(stem)
        if entry is None:
            continue
        img_src, lab_src = entry
        lab_dst = root / "labels" / target_split / lab_src.name
        if lab_src.resolve() != lab_dst.resolve():
            moves.append((lab_src, lab_dst))
        if img_src is not None:
            img_dst = root / "images" / target_split / img_src.name
            if img_src.resolve() != img_dst.resolve():
                moves.append((img_src, img_dst))

    print(f" planned moves: {len(moves)}")
    if dry_run:
        return

    # Directories are only created once we are committed to writing; a dry run
    # used to create them as a side effect.
    for split in ("train", "val"):
        (root / "images" / split).mkdir(parents=True, exist_ok=True)
        (root / "labels" / split).mkdir(parents=True, exist_ok=True)

    for src, dst in moves:
        dst.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): a same-named file already at the destination (e.g. a
        # stem present in both pooled splits) is overwritten.
        if dst.exists():
            dst.unlink()
        shutil.move(str(src), str(dst))
def _count_label_instances(label_path: Path) -> Counter:
    """Count label lines per class id in one YOLO label file (read once)."""
    per_class: Counter = Counter()
    text = label_path.read_text(encoding="utf-8", errors="ignore")
    for raw in text.splitlines():
        fields = raw.split()
        if not fields:
            continue
        try:
            per_class[int(fields[0])] += 1
        except ValueError:
            # First token is not a class id; ignore the line.
            continue
    return per_class


def _simulate_yolo_counts(
    root: Path,
    assignment: dict[str, str],
    pool_splits: tuple[str, ...],
) -> tuple[dict[str, Counter], dict[str, Counter]]:
    """Compute per-split instance and image counts as they would be after the split."""
    inst: dict[str, Counter] = {"train": Counter(), "val": Counter()}
    imgs: dict[str, Counter] = {"train": Counter(), "val": Counter()}
    for stem, target in assignment.items():
        # Use the first pooled split that holds this stem's label file.
        for split_name in pool_splits:
            lab = root / "labels" / split_name / f"{stem}.txt"
            if lab.is_file():
                per_class = _count_label_instances(lab)
                inst[target].update(per_class)
                imgs[target].update(per_class.keys())
                break
    return inst, imgs


def _print_dry_run_report(
    root: Path,
    assignment: dict[str, str],
    pool_splits: tuple[str, ...],
    val_ratio: float,
) -> None:
    """Print the simulated per-split statistics and the per-class val ratio table."""
    inst, imgs = _simulate_yolo_counts(root, assignment, pool_splits)
    split_sizes = Counter(assignment.values())

    print("\n=== after (simulated) ===")
    for sp in ("train", "val"):
        print(f" [{sp}] {split_sizes[sp]} images")
        for c in sorted(set(inst[sp]) | set(imgs[sp])):
            print(f" cls {c}: instances={inst[sp][c]}, images={imgs[sp][c]}")

    print("\n=== per-class val ratio (images with class / all images with class) ===")
    print(f" {'cls':>4} {'before':>8} {'after':>8} {'target':>8}")
    before_val: Counter = Counter()
    before_tot: Counter = Counter()
    for split in pool_splits:
        _, _, split_imgs = yolo_class_stats(root, split)
        if split == "val":
            before_val.update(split_imgs)
        before_tot.update(split_imgs)
    after_val = imgs["val"]
    after_tot = imgs["train"] + imgs["val"]
    for c in sorted(set(before_tot) | set(after_tot)):
        b = before_val[c] / before_tot[c] * 100 if before_tot[c] else 0
        a = after_val[c] / after_tot[c] * 100 if after_tot[c] else 0
        print(f" {c:4d} {b:7.1f}% {a:7.1f}% {val_ratio * 100:7.1f}%")


def cmd_yolo(args: argparse.Namespace) -> None:
    """Entry point of the ``yolo`` sub-command.

    Pools the label files of ``--pool-splits``, computes a stratified
    train/val assignment and either prints a simulated report (``--dry-run``)
    or moves the files and prints the resulting statistics.

    Raises:
        SystemExit: If ``<root>/images`` does not exist.
    """
    root = Path(args.root).resolve()
    if not (root / "images").is_dir():
        raise SystemExit(f"not a YOLO dataset root (missing images/): {root}")

    # Drop empty entries (e.g. from a trailing comma): an empty split name
    # would resolve to the labels/ directory itself.
    pool_splits = tuple(
        name for name in (part.strip() for part in args.pool_splits.split(",")) if name
    )
    samples = collect_yolo_samples(root, pool_splits)
    print(f"pool: {root} samples={len(samples)} val_ratio={args.val_ratio} seed={args.seed}")
    print_yolo_stats(root, "before")

    assignment = stratified_assign(
        samples,
        val_ratio=args.val_ratio,
        seed=args.seed,
        min_val_per_class=args.min_val_per_class,
        min_train_per_class=args.min_train_per_class,
        rare_class_train_floor=args.rare_class_train_floor,
    )
    n_val = sum(1 for v in assignment.values() if v == "val")
    print(f"\nplanned: train={len(assignment) - n_val} val={n_val}")

    if args.dry_run:
        # Simulated statistics only; nothing is written to disk.
        _print_dry_run_report(root, assignment, pool_splits, args.val_ratio)
        return

    apply_yolo_split(root, assignment, pool_splits=pool_splits, dry_run=False)
    print_yolo_stats(root, "after")
def stratified_assign_classify(
    class_dirs: list[Path],
    val_ratio: float,
    seed: int,
    min_val_per_class: int,
    min_train_per_class: int,
    rare_class_train_floor: int,
) -> dict[Path, str]:
    """Split the images of each class directory into train/val independently.

    Args:
        class_dirs: One directory per class; only direct children whose suffix
            is in ``IMG_EXTS`` are considered.
        val_ratio, min_val_per_class, min_train_per_class,
        rare_class_train_floor: Forwarded to ``val_count_for_class``.
        seed: Seed of the single ``random.Random`` used for all classes.

    Returns:
        Mapping of image path -> ``"train"`` / ``"val"``. Classes without
        images contribute nothing.
    """
    rng = random.Random(seed)
    assignment: dict[Path, str] = {}
    for class_dir in sorted(class_dirs):
        # iterdir() yields entries in arbitrary order; sort before shuffling so
        # that a given seed produces the same split on every machine.
        files = sorted(
            p for p in class_dir.iterdir() if p.is_file() and p.suffix in IMG_EXTS
        )
        if not files:
            continue
        rng.shuffle(files)
        n_val = val_count_for_class(
            len(files),
            val_ratio,
            min_val_per_class,
            min_train_per_class,
            rare_class_train_floor,
        )
        for p in files[:n_val]:
            assignment[p] = "val"
        for p in files[n_val:]:
            assignment[p] = "train"
    return assignment
def resplit_classify_root(
    root: Path,
    val_ratio: float = 0.1,
    seed: int = 42,
    min_val_per_class: int = 1,
    min_train_per_class: int = 1,
    rare_class_train_floor: int = 5,
    dry_run: bool = False,
) -> dict[str, int]:
    """Pool ``train/`` and ``val/`` per class and re-split them; ``test/`` is not touched.

    Images are first moved into ``root / "_resplit_staging" / <class>``, split
    with ``stratified_assign_classify`` and then moved to
    ``root / <split> / <class>``. Within a class, files are de-duplicated by
    file name.

    Args:
        root: Dataset root containing ``train/<class>/`` and ``val/<class>/``.
        val_ratio, seed, min_val_per_class, min_train_per_class,
        rare_class_train_floor: Split parameters, see ``val_count_for_class``.
        dry_run: When true, nothing is moved and only planned counts are returned.

    Returns:
        ``{"train": ..., "val": ...}``; a dry run additionally carries
        ``"dry_run": True``. After a real run ``"train"`` is the number of
        entries found in the affected ``train/<class>`` directories.
    """
    staging = root / "_resplit_staging"

    # class name -> {file name: path}. The staging directory is scanned as
    # well: if a previous run was interrupted, it holds the only copy of the
    # images moved so far. Those are folded back into the pool here instead of
    # being deleted together with the stale staging directory.
    pooled: dict[str, dict[str, Path]] = defaultdict(dict)
    for base in (root / "train", root / "val", staging):
        if not base.is_dir():
            continue
        for cls_dir in base.iterdir():
            if not cls_dir.is_dir():
                continue
            for f in cls_dir.iterdir():
                if f.is_file() and f.suffix in IMG_EXTS:
                    # NOTE(review): on a name clash the later location wins and
                    # the earlier file stays where it is.
                    pooled[cls_dir.name][f.name] = f

    if dry_run:
        n_tr = n_va = 0
        for by_name in pooled.values():
            n = len(by_name)
            n_val = val_count_for_class(
                n, val_ratio, min_val_per_class, min_train_per_class, rare_class_train_floor,
            )
            n_va += n_val
            n_tr += n - n_val
        return {"train": n_tr, "val": n_va, "dry_run": True}

    # Stage every pooled image under _resplit_staging/<class>/.
    staged_dirs: list[Path] = []
    for cls, by_name in sorted(pooled.items()):
        cls_staging = staging / cls
        cls_staging.mkdir(parents=True, exist_ok=True)
        for name, f in by_name.items():
            dst = cls_staging / name
            # Files recovered from an earlier staging run are already in place.
            if f.resolve() != dst.resolve():
                shutil.move(str(f), str(dst))
        staged_dirs.append(cls_staging)

    assignment = stratified_assign_classify(
        staged_dirs, val_ratio, seed, min_val_per_class, min_train_per_class, rare_class_train_floor,
    )

    (root / "train").mkdir(exist_ok=True)
    (root / "val").mkdir(exist_ok=True)
    n_val = 0
    for src_path, sp in assignment.items():
        dst = root / sp / src_path.parent.name / src_path.name
        dst.parent.mkdir(parents=True, exist_ok=True)
        if dst.exists():
            dst.unlink()
        shutil.move(str(src_path), str(dst))
        if sp == "val":
            n_val += 1

    # Every staged image has been moved out by now; drop the empty skeleton.
    if staging.exists():
        shutil.rmtree(staging, ignore_errors=True)

    n_train = sum(
        len(list((root / "train" / c).iterdir()))
        for c in pooled
        if (root / "train" / c).is_dir()
    )
    return {"train": n_train, "val": n_val}
def cmd_classify(args: argparse.Namespace) -> None:
    """Entry point of the ``classify`` sub-command.

    Splits the images of every class directory under ``<root>/<src-split>``
    into train/val. With ``--dry-run`` only the planned per-class counts are
    printed; otherwise each image is moved to ``<root>/<split>/<class>/``.

    Raises:
        SystemExit: If the source split directory does not exist.
    """
    root = Path(args.root).resolve()
    src_dir = root / args.src_split
    if not src_dir.is_dir():
        raise SystemExit(f"missing source split dir: {src_dir}")

    class_dirs = [entry for entry in src_dir.iterdir() if entry.is_dir()]
    n_images = sum(
        1
        for directory in class_dirs
        for item in directory.iterdir()
        if item.is_file() and item.suffix in IMG_EXTS
    )
    print(f"classify: {root} classes={len(class_dirs)} images={n_images}")

    assignment = stratified_assign_classify(
        class_dirs,
        args.val_ratio,
        args.seed,
        args.min_val_per_class,
        args.min_train_per_class,
        args.rare_class_train_floor,
    )
    n_val = Counter(assignment.values())["val"]
    print(f"planned: train={len(assignment) - n_val} val={n_val}")

    if args.dry_run:
        # (split, class name) -> number of images
        tally = Counter((split, path.parent.name) for path, split in assignment.items())
        print("\n=== per-class counts (simulated) ===")
        for cls_name in sorted({path.parent.name for path in assignment}):
            tr = tally["train", cls_name]
            va = tally["val", cls_name]
            tot = tr + va
            pct = va / tot * 100 if tot else 0
            print(f" {cls_name}: train={tr} val={va} (val%={pct:.1f})")
        return

    for target in ("train", "val"):
        (root / target).mkdir(parents=True, exist_ok=True)

    moved = 0
    for src_path, target_split in assignment.items():
        dst_dir = root / target_split / src_path.parent.name
        dst_dir.mkdir(parents=True, exist_ok=True)
        dst = dst_dir / src_path.name
        if src_path.resolve() == dst.resolve():
            # Already where it belongs.
            continue
        if dst.exists():
            dst.unlink()
        shutil.move(str(src_path), str(dst))
        moved += 1
    print(f"done, moved {moved} files into train/val")
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser with the ``yolo`` and ``classify`` sub-commands.

    Each sub-command stores its handler in ``func`` so the caller can dispatch
    with ``args.func(args)``.
    """
    parser = argparse.ArgumentParser(description="按类别分层划分 DMS 数据集")
    modes = parser.add_subparsers(dest="mode", required=True)

    def add_ratio_and_seed(sub: argparse.ArgumentParser) -> None:
        sub.add_argument("--val-ratio", type=float, default=0.1)
        sub.add_argument("--seed", type=int, default=42)

    def add_floors_and_dry_run(sub: argparse.ArgumentParser) -> None:
        sub.add_argument("--min-val-per-class", type=int, default=1)
        sub.add_argument("--min-train-per-class", type=int, default=1)
        sub.add_argument("--rare-class-train-floor", type=int, default=5)
        sub.add_argument("--dry-run", action="store_true")

    # Options are registered in the same order as before so --help is unchanged.
    yolo = modes.add_parser("yolo", help="YOLO 检测images/labels 的 train+val")
    yolo.add_argument("--root", required=True, help="数据集根目录,含 images/ labels/")
    add_ratio_and_seed(yolo)
    yolo.add_argument("--pool-splits", default="train,val", help="合并哪些 split 后重分")
    add_floors_and_dry_run(yolo)
    yolo.set_defaults(func=cmd_yolo)

    classify = modes.add_parser("classify", help="分类:每类文件夹内独立划分")
    classify.add_argument("--root", required=True)
    classify.add_argument("--src-split", default="train", help="从哪个目录按类采样(如 train")
    add_ratio_and_seed(classify)
    add_floors_and_dry_run(classify)
    classify.set_defaults(func=cmd_classify)

    return parser
def main() -> None:
    """Parse the command line and run the handler of the chosen sub-command."""
    parser = build_parser()
    namespace = parser.parse_args()
    handler = namespace.func
    handler(namespace)


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()