Files
yolov26_3d/tools/test_mono2d_difficulty.py
2026-06-24 09:35:46 +08:00

528 lines
21 KiB
Python
Executable File

#!/usr/bin/env python3
"""Run Ground 2D detection inference and visualize predicted difficulty.
Predicted difficulty class 1 boxes are drawn with dashed lines. Difficulty class 0 boxes
are drawn with solid lines.
"""
from __future__ import annotations
import argparse
import csv
from pathlib import Path
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from ultralytics.data.augment import LetterBox
from ultralytics.data.utils import IMG_FORMATS
from ultralytics.utils import YAML
from ultralytics.utils.nms import TorchNMS
from ultralytics.utils.ops import scale_boxes
from ultralytics.utils.torch_utils import select_device
# Default checkpoint (--weights) and dataset YAML (--data) used by the CLI. Both are relative
# paths, presumably meant to be resolved from the repository root as the working directory.
DEFAULT_WEIGHTS = "runs/detect/train_mono2d_20260429/weights/best.pt"
DEFAULT_DATA = "ultralytics/cfg/datasets/mono2d_ground.yaml"
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        Options selecting the checkpoint and input (dataset split, image source or video),
        NMS / difficulty thresholds, GT-matching settings and the output location.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    add = cli.add_argument

    # Model and input selection.
    add("--weights", default=DEFAULT_WEIGHTS, help="Path to Ground 2D checkpoint.")
    add("--data", default=DEFAULT_DATA, help="Dataset YAML used to resolve val entries.")
    add("--split", default="val", choices=("train", "val", "test"), help="Dataset split to run.")
    add("--clip-id", default="", help="Substring filter for clip/image/label paths within the split.")
    add(
        "--source",
        default="",
        help="Optional image, directory, text file, or video source. Overrides --split.",
    )

    # Inference and post-processing.
    add("--imgsz", type=int, default=704, help="Square inference image size.")
    add("--conf", type=float, default=0.25, help="Confidence threshold.")
    add("--iou", type=float, default=0.7, help="Class-aware NMS IoU threshold.")
    add("--max-det", type=int, default=300, help="Maximum detections per image after NMS.")
    add("--diff-thres", type=float, default=0.7, help="Sigmoid threshold for one-logit difficulty output.")

    # Evaluation against GT labels.
    add("--eval-iou", type=float, default=0.5, help="IoU threshold for GT matching in difficulty accuracy.")
    add(
        "--gt-diff-combine",
        default="first",
        choices=("first", "sum", "max"),
        help="How to combine two 2D label difficulty flags. Ground 2D dataloader uses 'first'.",
    )

    # Runtime and output.
    add("--device", default="0", help="Device, e.g. 0 or cpu.")
    add("--output", default="runs/detect/test_mono2d_difficulty", help="Directory for annotated images.")
    add("--limit", type=int, default=0, help="Optional maximum number of images to process.")
    add("--video-stride", type=int, default=1, help="Process every Nth frame for video sources.")
    return cli.parse_args()
def read_lines(path: Path) -> list[str]:
    """Return the non-blank lines of a UTF-8 text file, each stripped of surrounding whitespace."""
    with open(path, encoding="utf-8") as handle:
        content = handle.read()
    stripped = (raw.strip() for raw in content.splitlines())
    return [text for text in stripped if text]
def resolve_relative_entry(entry: str, list_file: Path | None = None) -> Path:
    """Turn one line of a split list into a path.

    Absolute entries are returned unchanged. Entries starting with ``./`` are anchored at the
    directory containing ``list_file`` when one is given. Every other entry is returned as-is
    (a relative path stays relative).
    """
    candidate = Path(entry)
    anchored = list_file is not None and not candidate.is_absolute() and entry.startswith("./")
    return list_file.parent / entry[2:] if anchored else candidate
def image_path_from_ground_entry(entry: Path, image_root: Path, label_path: bool) -> Path:
    """Map a Ground 2D list entry to its image path following YOLOGroundDataset rules.

    The part of ``entry`` after the first component starting with ``detection2d`` /
    ``2ddetection`` (case-insensitive) is re-rooted under ``image_root``. Without such a
    component, the whole relative path is used; an absolute path has its root dropped first.

    For image entries (``label_path=False``), that re-rooted path is returned directly. For
    label entries, the candidates are tried in this order, and the first that exists wins:

    1. the same path with a ``.jpg`` suffix;
    2. the path with ``labels`` components swapped for ``images``, with a ``.jpg`` suffix;
    3. the swapped path with ``.png``, the remaining ``IMG_FORMATS`` (sorted), then ``.jpeg``.

    If none exists, candidate 2 is returned.
    """
    parts = entry.parts
    root_idx = None
    for idx, component in enumerate(parts):
        if component.lower().startswith(("detection2d", "2ddetection")):
            root_idx = idx
            break

    if root_idx is not None and root_idx + 1 < len(parts):
        rel_parts = list(parts[root_idx + 1 :])
    elif entry.is_absolute():
        rel_parts = list(parts[1:])
    else:
        rel_parts = list(parts)

    direct = image_root.joinpath(*rel_parts)
    if not label_path:
        return direct

    same_dir_jpg = direct.with_suffix(".jpg")
    if same_dir_jpg.exists():
        return same_dir_jpg

    swapped = image_root.joinpath(*("images" if p == "labels" else p for p in rel_parts)).with_suffix(".jpg")
    if swapped.exists():
        return swapped

    # IMG_FORMATS is only consulted once the cheap .jpg probes have failed.
    fallback_exts = ["png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg"]
    for ext in fallback_exts:
        alternative = swapped.with_suffix("." + ext)
        if alternative.exists():
            return alternative
    return swapped
def label_path_from_image_entry(entry: Path) -> Path:
    """Derive the label ``.txt`` path for an image list entry.

    Like ultralytics' ``img2label_paths``, only the LAST ``images`` directory component is
    replaced by ``labels``, and the extension becomes ``.txt``. Replacing every ``images``
    component would corrupt paths whose dataset root itself lives under a directory named
    ``images`` (e.g. ``/mnt/images/ds/images/a.jpg``).
    """
    parts = list(entry.parts)
    # Scan directory components only (the final part is the file name), right to left.
    for idx in range(len(parts) - 2, -1, -1):
        if parts[idx] == "images":
            parts[idx] = "labels"
            break
    return Path(*parts).with_suffix(".txt")
def collect_source_images(source: str) -> list[Path]:
    """Expand ``--source`` into a list of image paths.

    * Directory: every file below it (recursively) whose extension is in ``IMG_FORMATS``,
      compared case-insensitively, returned sorted.
    * ``.txt`` file: one entry per non-blank line; ``./`` entries are resolved against the
      file's directory.
    * Anything else: returned as a single-element list. Existence is not checked here.
    """
    source_path = Path(source)
    if source_path.is_dir():
        # One recursive walk with a case-insensitive suffix test also picks up "*.JPG"/"*.Png",
        # which per-extension globs miss on case-sensitive filesystems.
        return sorted(
            path
            for path in source_path.rglob("*")
            if path.is_file() and path.suffix.lower().lstrip(".") in IMG_FORMATS
        )
    if source_path.suffix.lower().lstrip(".") == "txt":
        return [resolve_relative_entry(line, source_path) for line in read_lines(source_path)]
    return [source_path]
# File extensions (lower-case, without the dot) treated as video sources.
_VIDEO_SUFFIXES = frozenset({"mp4", "avi", "mov", "mkv", "webm", "m4v", "mpg", "mpeg"})


def is_video_path(path: Path) -> bool:
    """Return True when ``path`` has a supported video extension (case-insensitive)."""
    extension = path.suffix.lower().lstrip(".")
    return extension in _VIDEO_SUFFIXES
def collect_dataset_images(data_yaml: str, split: str, clip_id: str = "") -> list[Path]:
    """Return only the image paths of ``collect_dataset_samples`` for the same arguments."""
    samples = collect_dataset_samples(data_yaml, split, clip_id)
    return [image_path for image_path, _label_path in samples]
def collect_dataset_samples(data_yaml: str, split: str, clip_id: str = "") -> list[tuple[Path, Path]]:
    """Resolve ``(image_path, label_path)`` pairs for one split of a Ground 2D dataset YAML.

    The split value may be a single item or a list of items. Each item is handled as follows:

    * An existing file is read as a list file; its ``./`` entries are anchored at the list
      file's directory.
    * Otherwise the item is treated as a single entry.

    Each entry is then classified:

    * Image entries (extension in ``IMG_FORMATS``) get their label via
      ``label_path_from_image_entry``.
    * ``.txt`` entries are label files whose image is located with
      ``image_path_from_ground_entry``.
    * Anything else is ignored.

    When ``clip_id`` is non-empty, only samples whose entry/image/label path text contains it
    are kept.

    Raises:
        ValueError: if the YAML has no (or an empty) value for ``split``.
    """
    data = YAML.load(data_yaml)
    # Like ultralytics' check_det_dataset, fall back to the YAML's directory when 'path' is unset.
    image_root = Path(data.get("path") or Path(data_yaml).parent)
    split_value = data.get(split)
    if not split_value:
        raise ValueError(f"Dataset YAML has no '{split}' split: {data_yaml}")
    items = split_value if isinstance(split_value, list) else [split_value]
    samples: list[tuple[Path, Path]] = []
    for item in items:
        item_path = Path(item)
        # ultralytics resolves split list files relative to the dataset root. Keep the CWD-relative
        # lookup when it exists (existing setups are unaffected), and only fall back to the root
        # for .txt list files; otherwise "val.txt" would be mistaken for a single label entry.
        if not item_path.is_absolute() and not item_path.exists():
            rooted = image_root / item_path
            if rooted.suffix.lower() == ".txt" and rooted.is_file():
                item_path = rooted
        is_list_file = item_path.is_file()
        entries = read_lines(item_path) if is_list_file else [str(item_path)]
        list_file = item_path if is_list_file else None
        for entry_text in entries:
            entry = resolve_relative_entry(entry_text, list_file)
            suffix = entry.suffix.lower().lstrip(".")
            if suffix in IMG_FORMATS:
                image_path = image_path_from_ground_entry(entry, image_root, label_path=False)
                label_path = label_path_from_image_entry(entry)
            elif suffix == "txt":
                image_path = image_path_from_ground_entry(entry, image_root, label_path=True)
                label_path = entry
            else:
                continue
            haystack = f"{entry_text} {entry} {image_path} {label_path}"
            if clip_id and clip_id not in haystack:
                continue
            samples.append((image_path, label_path))
    return samples
def class_names_from_model(model) -> dict[int, str]:
    """Normalize ``model.names`` into an ``{int class id: str name}`` mapping.

    A dict has its keys cast to ``int``; a list or tuple is indexed by position. Any other
    value, including a missing attribute, yields an empty mapping.
    """
    raw_names = getattr(model, "names", None)
    if isinstance(raw_names, dict):
        pairs = raw_names.items()
    elif isinstance(raw_names, (list, tuple)):
        pairs = enumerate(raw_names)
    else:
        return {}
    return {int(key): str(value) for key, value in pairs}
def class_map_from_data(data_yaml: str) -> dict[str, int]:
    """Read the ``class_map`` section (label class name -> class id) of a dataset YAML.

    Keys are cast to ``str`` and values to ``int``. A missing or explicitly null ``class_map``
    (or an empty YAML) yields an empty mapping instead of raising AttributeError. GT parsing
    then simply finds no known classes.
    """
    data = YAML.load(data_yaml) or {}
    class_map = data.get("class_map") or {}
    return {str(k): int(v) for k, v in class_map.items()}
def preprocess_image(img: np.ndarray, imgsz: int, stride: int, device: torch.device) -> tuple[torch.Tensor, np.ndarray]:
    """Letterbox a BGR image and convert it into the model's input tensor.

    Mirrors ultralytics' predictor preprocessing:

    * letterbox to ``imgsz x imgsz`` (``auto=False``, so no stride-minimal padding);
    * BGR -> RGB;
    * HWC -> CHW;
    * add a batch dimension;
    * scale pixel values by 1/255 (assumes 8-bit input, as produced by cv2).

    Returns:
        The ``(1, C, imgsz, imgsz)`` float tensor on ``device``, and the letterboxed (still BGR)
        image. Callers use the latter for its shape when rescaling boxes.
    """
    letterbox = LetterBox(new_shape=(imgsz, imgsz), auto=False, stride=stride)
    img_resized = letterbox(image=img)
    # cv2 delivers BGR; ultralytics models are trained and predicted on RGB input.
    chw = img_resized[..., ::-1].transpose(2, 0, 1)
    tensor = torch.from_numpy(np.ascontiguousarray(chw)).to(device)
    # Divide by 255 (not 256) so that a full-intensity pixel maps to exactly 1.0, as in training.
    return tensor.float().unsqueeze(0) / 255.0, img_resized
def difficulty_from_logits(logits: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Decode difficulty-head logits into binary labels and the confidence of the chosen label.

    The channel dimension ``C`` is the last axis; any leading shape is accepted
    (e.g. ``(N, C)`` or ``(B, N, C)``).

    * ``C == 1``: sigmoid gives the probability of label 1. The label is 1 when
      ``prob >= threshold``. The confidence is ``prob`` for label 1 and ``1 - prob`` for label 0.
    * ``C == 2``: softmax over the two channels. The label is the argmax (``threshold`` is
      ignored), and the confidence is that channel's probability.

    Returns:
        ``(labels, confidences)``; both have the input shape without the channel axis.

    Raises:
        RuntimeError: if the last dimension is neither 1 nor 2.
    """
    channels = logits.shape[-1]
    if channels == 1:
        prob = logits.squeeze(-1).sigmoid()
        cls = (prob >= threshold).long()
        conf = torch.where(cls.bool(), prob, 1.0 - prob)
        return cls, conf
    if channels == 2:
        probs = logits.softmax(-1)
        cls = logits.argmax(-1).long()
        # Gather along the last axis so batched (B, N, 2) inputs work as well as (N, 2).
        return cls, probs.gather(-1, cls.unsqueeze(-1)).squeeze(-1)
    raise RuntimeError(f"Expected difficulty logits with 1 or 2 channels, got shape {tuple(logits.shape)}")
def predict_one(model, img: np.ndarray, args: argparse.Namespace, device: torch.device, stride: int):
    """Run the detector on one image and return post-processed detections.

    Pipeline:

    1. preprocess;
    2. forward pass;
    3. keep rows with score > ``args.conf``;
    4. class-aware NMS at ``args.iou``, capped to ``args.max_det``;
    5. rescale boxes from the letterboxed input back to ``img``;
    6. decode the per-detection difficulty logits with ``args.diff_thres``.

    Detection rows are read as ``[box(4), score, class, ...]``.

    Returns:
        A tuple ``(boxes, scores, classes, diffs, resized)``:

        * ``boxes``: rounded ``(N, 4)`` numpy array in ``img`` pixels;
        * ``scores`` and ``classes``: Python lists;
        * ``diffs``: list of ``(difficulty_class, difficulty_confidence)`` tuples;
        * ``resized``: the letterboxed image fed to the model.

    Raises:
        RuntimeError: if the model output is not a ``(detections, raw_predictions)`` tuple, or
            the raw predictions lack ``preds_diff_selected``.
    """
    im, resized = preprocess_image(img, args.imgsz, stride, device)
    with torch.inference_mode():
        output = model(im)
    if not (isinstance(output, tuple) and len(output) >= 2):
        raise RuntimeError("Model output did not include raw prediction metadata with preds_diff.")
    det_batch, raw_preds = output
    # Difficulty logits live in the raw prediction dict, under the "one2one" branch when present.
    branch = raw_preds.get("one2one", raw_preds)
    diff_batch = branch.get("preds_diff_selected")
    if diff_batch is None:
        raise RuntimeError(
            "Model output has no preds_diff_selected. Make sure the checkpoint was trained with the difficulty head."
        )

    det, diff_logits = det_batch[0], diff_batch[0]
    confident = det[:, 4] > args.conf
    det, diff_logits = det[confident], diff_logits[confident]
    if det.numel() == 0:
        return np.zeros((0, 4), dtype=np.float32), [], [], [], resized

    kept = TorchNMS.batched_nms(det[:, :4], det[:, 4], det[:, 5], args.iou)[: args.max_det]
    det, diff_logits = det[kept], diff_logits[kept]

    boxes = scale_boxes(resized.shape[:2], det[:, :4].clone(), img.shape[:2]).round().cpu().numpy()
    scores = det[:, 4].cpu().numpy().tolist()
    classes = det[:, 5].long().cpu().numpy().tolist()
    diff_cls, diff_conf = difficulty_from_logits(diff_logits, args.diff_thres)
    diffs = list(zip(diff_cls.cpu().numpy().tolist(), diff_conf.cpu().numpy().tolist()))
    return boxes, scores, classes, diffs, resized
def parse_gt_label(
    label_path: Path,
    class_map: dict[str, int],
    img_shape: tuple[int, int],
    diff_combine: str = "sum",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Parse a Ground 2D label file into xyxy pixel boxes, class ids and binary difficulty targets.

    Each line is ``name cx cy w h diff1 [diff2 ...]``, with box values normalized by the image
    size. Lines with fewer than 6 fields, or whose class name is not in ``class_map``, are skipped.

    Args:
        label_path: label ``.txt`` file; a missing file yields empty arrays.
        class_map: label class name -> class id.
        img_shape: ``(height, width)`` of the image, used to de-normalize the boxes.
        diff_combine: how to merge ``diff1`` and ``diff2`` when both are present:

            * ``"first"``: ``diff1`` only (what the Ground 2D dataloader uses);
            * ``"sum"``: ``diff1 + diff2``;
            * ``"max"``: the larger of the two.

            Note that this function defaults to ``"sum"``, while the CLI passes
            ``--gt-diff-combine`` (default ``"first"``) explicitly.

    Returns:
        A tuple ``(boxes, classes, diffs)``:

        * ``boxes``: ``(N, 4)`` float32, always 2-D even when empty;
        * ``classes``: ``(N,)`` int64;
        * ``diffs``: ``(N,)`` int64, where 1 (hard) means the combined difficulty, truncated
          to int, is >= 2.

    Raises:
        ValueError: if ``diff_combine`` is not ``"first"``, ``"sum"`` or ``"max"``.
    """
    if diff_combine not in ("first", "sum", "max"):
        # Previously any unknown value silently behaved like "max".
        raise ValueError(f"diff_combine must be 'first', 'sum' or 'max', got {diff_combine!r}")
    if not label_path.exists():
        return np.zeros((0, 4), dtype=np.float32), np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.int64)
    img_h, img_w = img_shape
    boxes, classes, diffs = [], [], []
    for line in read_lines(label_path):
        parts = line.split()
        if len(parts) < 6:
            continue
        cls_id = class_map.get(parts[0])
        if cls_id is None:
            continue
        x, y, w, h = (float(v) for v in parts[1:5])
        raw_diff = float(parts[5])
        # Match YOLOGroundDataset: diff2 is truncation, so "first" ignores it (and never parses it).
        if len(parts) >= 7 and diff_combine != "first":
            second = float(parts[6])
            raw_diff = raw_diff + second if diff_combine == "sum" else max(raw_diff, second)
        boxes.append([(x - w * 0.5) * img_w, (y - h * 0.5) * img_h, (x + w * 0.5) * img_w, (y + h * 0.5) * img_h])
        classes.append(cls_id)
        diffs.append(1 if int(raw_diff) >= 2 else 0)
    # reshape keeps a (0, 4) box array when the file exists but has no usable lines.
    return (
        np.asarray(boxes, dtype=np.float32).reshape(-1, 4),
        np.asarray(classes, dtype=np.int64),
        np.asarray(diffs, dtype=np.int64),
    )
def box_iou_np(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two sets of xyxy boxes.

    Returns a ``(len(boxes1), len(boxes2))`` matrix, or an all-zero float32 matrix of that shape
    when either set is empty. Negative extents are clamped to zero, and the union is floored at
    1e-9 to avoid division by zero.
    """
    n_a, n_b = len(boxes1), len(boxes2)
    if not (n_a and n_b):
        return np.zeros((n_a, n_b), dtype=np.float32)

    def _area(b: np.ndarray) -> np.ndarray:
        return np.clip(b[:, 2] - b[:, 0], 0, None) * np.clip(b[:, 3] - b[:, 1], 0, None)

    a = boxes1[:, None, :]
    b = boxes2[None, :, :]
    inter_w = np.clip(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0, None)
    inter_h = np.clip(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0, None)
    inter = inter_w * inter_h
    union = _area(boxes1)[:, None] + _area(boxes2)[None, :] - inter
    return inter / np.clip(union, 1e-9, None)
def match_difficulty_predictions(
    boxes: np.ndarray,
    classes: list[int],
    diffs: list[tuple[int, float]],
    gt_boxes: np.ndarray,
    gt_classes: np.ndarray,
    gt_diffs: np.ndarray,
    iou_thr: float,
) -> tuple[list[int], dict[int, tuple[int, float]], int, int]:
    """Greedily match predictions to same-class GT boxes and score their difficulty labels.

    A candidate ``(gt, pred)`` pair must share a class and have IoU >= ``iou_thr``. Pairs are
    visited in descending IoU order (exact ties keep a stable, deterministic order), and each GT
    box and each prediction is used at most once.

    Returns:
        A tuple ``(wrong_pred_indices, match_info, correct, matched)``:

        * ``wrong_pred_indices``: matched predictions whose difficulty (``diffs[i][0]``)
          disagrees with their GT;
        * ``match_info``: ``{pred_index: (gt_difficulty, iou)}`` for every matched prediction;
        * ``correct``: the number of agreeing matches;
        * ``matched``: the total number of matches.
    """
    if len(boxes) == 0 or len(gt_boxes) == 0:
        return [], {}, 0, 0
    pred_classes = np.asarray(classes, dtype=np.int64)
    ious = box_iou_np(gt_boxes, boxes)
    # Use a boolean class mask instead of zeroing IoUs, so cross-class pairs can never qualify,
    # even with iou_thr <= 0.
    same_class = gt_classes[:, None] == pred_classes[None, :]
    pairs = np.argwhere(same_class & (ious >= iou_thr))
    if len(pairs) == 0:
        return [], {}, 0, 0
    pair_ious = ious[pairs[:, 0], pairs[:, 1]]
    order = np.argsort(-pair_ious, kind="stable")
    used_gt, used_pred = set(), set()
    wrong_pred_indices: list[int] = []
    match_info: dict[int, tuple[int, float]] = {}
    correct = matched = 0
    for pair_idx in order:
        gi, pi = int(pairs[pair_idx, 0]), int(pairs[pair_idx, 1])
        if gi in used_gt or pi in used_pred:
            continue
        used_gt.add(gi)
        used_pred.add(pi)
        matched += 1
        pred_diff = int(diffs[pi][0])
        gt_diff = int(gt_diffs[gi])
        match_info[pi] = (gt_diff, float(ious[gi, pi]))
        if pred_diff == gt_diff:
            correct += 1
        else:
            wrong_pred_indices.append(pi)
    return wrong_pred_indices, match_info, correct, matched
def write_prediction_rows(
    writer,
    image_id: str,
    boxes,
    scores,
    classes,
    diffs,
    names: dict[int, str],
    frame: int = -1,
    match_info: dict[int, tuple[int, float]] | None = None,
):
    """Append one CSV row per detection to ``writer``.

    Row layout, in order:

    * source, frame, class id, class name;
    * score (6 decimals);
    * predicted difficulty class and its confidence (6 decimals);
    * matched GT difficulty (-1 when unmatched);
    * correctness flag (1/0, or -1 when unmatched);
    * match IoU (6 decimals, -1.0 when unmatched);
    * the box coordinates truncated to int.
    """
    lookup = match_info if match_info else {}
    records = zip(boxes, scores, classes, diffs)
    for det_idx, (box, score, cls_id, difficulty) in enumerate(records):
        diff_cls, diff_prob = difficulty
        gt_diff, match_iou = lookup.get(det_idx, (-1, -1.0))
        if gt_diff < 0:
            correct = -1
        else:
            correct = int(gt_diff == diff_cls)
        row = [
            image_id,
            frame,
            cls_id,
            names.get(cls_id, str(cls_id)),
            f"{score:.6f}",
            diff_cls,
            f"{diff_prob:.6f}",
            gt_diff,
            correct,
            f"{match_iou:.6f}",
        ]
        row.extend(int(coord) for coord in box)
        writer.writerow(row)
def draw_dashed_rectangle(img: np.ndarray, p1: tuple[int, int], p2: tuple[int, int], color, thickness=2, dash=10, gap=6):
    """Draw an axis-aligned dashed rectangle onto ``img`` in place.

    Dashes are ``dash`` pixels long and separated by ``gap`` pixels. They start at ``p1``'s
    coordinates and are clipped so they never extend past ``p2``. Nothing is drawn along an axis
    where ``p1`` is not less than ``p2``, so ``p1`` is expected to be the top-left corner.
    """
    left, top = p1
    right, bottom = p2
    period = dash + gap
    # Horizontal edges: top then bottom for each dash position.
    for start in range(left, right, period):
        end = min(start + dash, right)
        for row in (top, bottom):
            cv2.line(img, (start, row), (end, row), color, thickness)
    # Vertical edges: left then right for each dash position.
    for start in range(top, bottom, period):
        end = min(start + dash, bottom)
        for col in (left, right):
            cv2.line(img, (col, start), (col, end), color, thickness)
def draw_x_flag(img: np.ndarray, box, color=(0, 255, 255), thickness=3):
    """Cross out ``box`` (x1, y1, x2, y2) on ``img`` in place by drawing both of its diagonals."""
    left, top, right, bottom = map(int, box)
    diagonals = (((left, top), (right, bottom)), ((left, bottom), (right, top)))
    for start, end in diagonals:
        cv2.line(img, start, end, color, thickness)
def draw_predictions(
    img: np.ndarray,
    boxes,
    scores,
    classes,
    diffs,
    names: dict[int, str],
    wrong_indices: set[int] | None = None,
) -> np.ndarray:
    """Return a copy of ``img`` with every detection drawn on it.

    Difficulty-1 boxes are drawn dashed in (0, 0, 255), and difficulty-0 boxes solid in
    (0, 200, 0). Each box gets a filled label ``"<name> <score> diff<cls>:<prob>"`` just above
    it (clamped to the top edge). Predictions whose index is in ``wrong_indices`` are
    additionally crossed out.
    """
    canvas = img.copy()
    flagged = wrong_indices if wrong_indices else set()
    font = cv2.FONT_HERSHEY_SIMPLEX
    for det_idx, (box, score, cls_id, (diff_cls, diff_prob)) in enumerate(zip(boxes, scores, classes, diffs)):
        x1, y1, x2, y2 = (int(v) for v in box)
        hard = diff_cls == 1
        color = (0, 0, 255) if hard else (0, 200, 0)
        if hard:
            draw_dashed_rectangle(canvas, (x1, y1), (x2, y2), color, thickness=2)
        else:
            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)

        text = f"{names.get(cls_id, str(cls_id))} {score:.2f} diff{diff_cls}:{diff_prob:.2f}"
        (text_w, text_h), baseline = cv2.getTextSize(text, font, 0.5, 1)
        label_top = max(y1 - text_h - baseline - 2, 0)
        cv2.rectangle(canvas, (x1, label_top), (x1 + text_w + 4, label_top + text_h + baseline + 4), color, -1)
        cv2.putText(canvas, text, (x1 + 2, label_top + text_h + 2), font, 0.5, (255, 255, 255), 1)

        if det_idx in flagged:
            draw_x_flag(canvas, box)
    return canvas
# Column order of predictions.csv; must match the row layout written by write_prediction_rows.
_CSV_HEADER = [
    "source",
    "frame",
    "class_id",
    "class_name",
    "conf",
    "difficulty",
    "difficulty_prob",
    "gt_difficulty",
    "difficulty_correct",
    "match_iou",
    "x1",
    "y1",
    "x2",
    "y2",
]


def _unique_save_path(out_dir: Path, image_path: Path, used_names: set[str]) -> Path:
    """Return ``out_dir / image_path.name``, appending ``_<n>`` if this run already used that name."""
    name = image_path.name
    counter = 0
    while name in used_names:
        counter += 1
        name = f"{image_path.stem}_{counter}{image_path.suffix}"
    used_names.add(name)
    return out_dir / name


def _run_video(model, args, device, stride, names, video_source: Path, writer, out_dir: Path) -> Path:
    """Annotate every ``--video-stride``-th frame of ``video_source`` and return the saved video path.

    Raises:
        FileNotFoundError: if the video cannot be opened.
    """
    cap = cv2.VideoCapture(str(video_source))
    if not cap.isOpened():
        raise FileNotFoundError(f"Unable to open video source: {video_source}")
    step = max(args.video_stride, 1)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    save_path = out_dir / f"{video_source.stem}_difficulty.mp4"
    writer_video = cv2.VideoWriter(str(save_path), cv2.VideoWriter_fourcc(*"mp4v"), fps / step, (width, height))
    try:
        frame_idx = -1
        processed = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame_idx += 1
            if frame_idx % step != 0:
                continue
            boxes, scores, classes, diffs, _ = predict_one(model, frame, args, device, stride)
            writer_video.write(draw_predictions(frame, boxes, scores, classes, diffs, names))
            write_prediction_rows(writer, str(video_source), boxes, scores, classes, diffs, names, frame=frame_idx)
            processed += 1
            print(f"[frame {frame_idx}] {len(boxes)} detections")
            if args.limit > 0 and processed >= args.limit:
                break
    finally:
        # Release handles even if inference fails mid-stream, so the partial video is finalized.
        cap.release()
        writer_video.release()
    return save_path


def _run_images(model, args, device, stride, names, sample_pairs, writer, out_dir: Path, class_map) -> tuple[int, int, int]:
    """Annotate each ``(image_path, label_path | None)`` pair.

    Returns:
        ``(correct, matched, wrong)`` difficulty counts summed over all pairs that have labels.
    """
    total_correct = total_matched = total_wrong = 0
    used_names: set[str] = set()
    for idx, (image_path, label_path) in enumerate(sample_pairs, 1):
        img = cv2.imread(str(image_path))
        if img is None:
            print(f"[{idx}/{len(sample_pairs)}] skip unreadable: {image_path}")
            continue
        boxes, scores, classes, diffs, _ = predict_one(model, img, args, device, stride)
        wrong_indices: list[int] = []
        match_info: dict[int, tuple[int, float]] = {}
        if label_path is not None:
            gt_boxes, gt_classes, gt_diffs = parse_gt_label(
                label_path, class_map, img.shape[:2], diff_combine=args.gt_diff_combine
            )
            wrong_indices, match_info, correct, matched = match_difficulty_predictions(
                boxes, classes, diffs, gt_boxes, gt_classes, gt_diffs, args.eval_iou
            )
            total_correct += correct
            total_matched += matched
            total_wrong += len(wrong_indices)
        drawn = draw_predictions(img, boxes, scores, classes, diffs, names, wrong_indices=set(wrong_indices))
        # Frames from different clips often share a file name; avoid silently overwriting them.
        save_path = _unique_save_path(out_dir, image_path, used_names)
        cv2.imwrite(str(save_path), drawn)
        write_prediction_rows(writer, str(image_path), boxes, scores, classes, diffs, names, match_info=match_info)
        print(
            f"[{idx}/{len(sample_pairs)}] {image_path} -> {save_path} "
            f"({len(boxes)} detections, {len(wrong_indices)} wrong difficulty)"
        )
    return total_correct, total_matched, total_wrong


def main() -> None:
    """Command-line entry point.

    Loads the checkpoint, then does one of the following:

    * annotates a video, when ``--source`` has a video extension;
    * otherwise runs over the images from ``--source``, or from the ``--split`` of ``--data``.

    Annotated media and ``predictions.csv`` are written to ``--output``. For dataset splits,
    the difficulty accuracy of GT-matched predictions is printed.

    Raises:
        FileNotFoundError: if no images match the source/split/clip-id, or the video cannot be opened.
    """
    args = parse_args()
    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)
    device = select_device(args.device)
    yolo = YOLO(args.weights)
    model = yolo.model.to(device).eval()
    stride = int(getattr(model, "stride", torch.tensor([32])).max())
    names = class_names_from_model(model)

    video_source = Path(args.source) if args.source and is_video_path(Path(args.source)) else None
    sample_pairs: list[tuple[Path, Path | None]] = []
    if video_source is None:
        # Collect the samples once and validate before anything is written.
        if args.source:
            sample_pairs = [(image_path, None) for image_path in collect_source_images(args.source)]
        else:
            sample_pairs = collect_dataset_samples(args.data, args.split, args.clip_id)
        if args.limit > 0:
            sample_pairs = sample_pairs[: args.limit]
        if not sample_pairs:
            raise FileNotFoundError("No images matched the requested source/split/clip-id.")

    csv_path = out_dir / "predictions.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(_CSV_HEADER)
        if video_source is not None:
            save_path = _run_video(model, args, device, stride, names, video_source, writer, out_dir)
            print(f"Saved annotated video to {save_path}")
            print(f"Saved prediction CSV to {csv_path}")
            return
        # The class map is only needed to parse GT labels, which exist only for dataset splits;
        # skipping it lets --source runs work without a readable --data YAML.
        class_map = {} if args.source else class_map_from_data(args.data)
        total_correct, total_matched, total_wrong = _run_images(
            model, args, device, stride, names, sample_pairs, writer, out_dir, class_map
        )

    if total_matched:
        acc = total_correct / total_matched
        print(
            f"Difficulty accuracy @IoU {args.eval_iou:.2f}: "
            f"{acc:.4f} ({total_correct}/{total_matched}), wrong={total_wrong}"
        )
    elif not args.source:
        print(f"Difficulty accuracy @IoU {args.eval_iou:.2f}: no matched predictions")
    print(f"Saved annotations to {out_dir}")
    print(f"Saved prediction CSV to {csv_path}")
# Script entry point: run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    main()