#!/usr/bin/env python3 """Run Ground 2D detection inference and visualize predicted difficulty. Predicted difficulty class 1 boxes are drawn with dashed lines. Difficulty class 0 boxes are drawn with solid lines. """ from __future__ import annotations import argparse import csv from pathlib import Path import cv2 import numpy as np import torch from ultralytics import YOLO from ultralytics.data.augment import LetterBox from ultralytics.data.utils import IMG_FORMATS from ultralytics.utils import YAML from ultralytics.utils.nms import TorchNMS from ultralytics.utils.ops import scale_boxes from ultralytics.utils.torch_utils import select_device DEFAULT_WEIGHTS = "runs/detect/train_mono2d_20260429/weights/best.pt" DEFAULT_DATA = "ultralytics/cfg/datasets/mono2d_ground.yaml" def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--weights", default=DEFAULT_WEIGHTS, help="Path to Ground 2D checkpoint.") parser.add_argument("--data", default=DEFAULT_DATA, help="Dataset YAML used to resolve val entries.") parser.add_argument("--split", default="val", choices=("train", "val", "test"), help="Dataset split to run.") parser.add_argument("--clip-id", default="", help="Substring filter for clip/image/label paths within the split.") parser.add_argument( "--source", default="", help="Optional image, directory, text file, or video source. Overrides --split.", ) parser.add_argument("--imgsz", type=int, default=704, help="Square inference image size.") parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold.") parser.add_argument("--iou", type=float, default=0.7, help="Class-aware NMS IoU threshold.") parser.add_argument("--max-det", type=int, default=300, help="Maximum detections per image after NMS.") parser.add_argument("--diff-thres", type=float, default=0.7, help="Sigmoid threshold for one-logit difficulty output.") parser.add_argument("--eval-iou", type=float, default=0.5, help="IoU threshold for GT matching in difficulty accuracy.") parser.add_argument( "--gt-diff-combine", default="first", choices=("first", "sum", "max"), help="How to combine two 2D label difficulty flags. Ground 2D dataloader uses 'first'.", ) parser.add_argument("--device", default="0", help="Device, e.g. 0 or cpu.") parser.add_argument("--output", default="runs/detect/test_mono2d_difficulty", help="Directory for annotated images.") parser.add_argument("--limit", type=int, default=0, help="Optional maximum number of images to process.") parser.add_argument("--video-stride", type=int, default=1, help="Process every Nth frame for video sources.") return parser.parse_args() def read_lines(path: Path) -> list[str]: with open(path, encoding="utf-8") as f: return [line.strip() for line in f.read().splitlines() if line.strip()] def resolve_relative_entry(entry: str, list_file: Path | None = None) -> Path: path = Path(entry) if path.is_absolute(): return path if list_file is not None and entry.startswith("./"): return list_file.parent / entry[2:] return path def image_path_from_ground_entry(entry: Path, image_root: Path, label_path: bool) -> Path: """Map a Ground 2D list entry to its image path following YOLOGroundDataset rules.""" parts = entry.parts detection_root_index = next( (i for i, part in enumerate(parts) if part.lower().startswith(("detection2d", "2ddetection"))), None, ) if detection_root_index is not None and detection_root_index + 1 < len(parts): rel_parts = list(parts[detection_root_index + 1 :]) else: rel_parts = list(entry.parts if not entry.is_absolute() else entry.parts[1:]) image_path = image_root.joinpath(*rel_parts) if not label_path: return image_path image_path = image_path.with_suffix(".jpg") if image_path.exists(): return image_path image_rel_parts = ["images" if part == "labels" else part for part in rel_parts] image_path = image_root.joinpath(*image_rel_parts).with_suffix(".jpg") if image_path.exists(): return image_path for suffix in ("png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg"): candidate = image_path.with_suffix(f".{suffix}") if candidate.exists(): return candidate return image_path def label_path_from_image_entry(entry: Path) -> Path: parts = ["labels" if part == "images" else part for part in entry.parts] return Path(*parts).with_suffix(".txt") def collect_source_images(source: str) -> list[Path]: source_path = Path(source) if source_path.is_dir(): images = [] for suffix in IMG_FORMATS: images.extend(source_path.rglob(f"*.{suffix}")) return sorted(images) if source_path.suffix.lower().lstrip(".") == "txt": return [resolve_relative_entry(line, source_path) for line in read_lines(source_path)] return [source_path] def is_video_path(path: Path) -> bool: return path.suffix.lower().lstrip(".") in {"mp4", "avi", "mov", "mkv", "webm", "m4v", "mpg", "mpeg"} def collect_dataset_images(data_yaml: str, split: str, clip_id: str = "") -> list[Path]: return [sample[0] for sample in collect_dataset_samples(data_yaml, split, clip_id)] def collect_dataset_samples(data_yaml: str, split: str, clip_id: str = "") -> list[tuple[Path, Path]]: data = YAML.load(data_yaml) image_root = Path(data["path"]) split_value = data.get(split) if not split_value: raise ValueError(f"Dataset YAML has no '{split}' split: {data_yaml}") items = split_value if isinstance(split_value, list) else [split_value] samples: list[tuple[Path, Path]] = [] for item in items: item_path = Path(item) entries = read_lines(item_path) if item_path.is_file() else [str(item_path)] for entry_text in entries: entry = resolve_relative_entry(entry_text, item_path if item_path.is_file() else None) suffix = entry.suffix.lower().lstrip(".") if suffix in IMG_FORMATS: image_path = image_path_from_ground_entry(entry, image_root, label_path=False) label_path = label_path_from_image_entry(entry) elif suffix == "txt": image_path = image_path_from_ground_entry(entry, image_root, label_path=True) label_path = entry else: continue haystack = f"{entry_text} {entry} {image_path} {label_path}" if clip_id and clip_id not in haystack: continue samples.append((image_path, label_path)) return samples def class_names_from_model(model) -> dict[int, str]: names = getattr(model, "names", None) if isinstance(names, dict): return {int(k): str(v) for k, v in names.items()} if isinstance(names, (list, tuple)): return {i: str(v) for i, v in enumerate(names)} return {} def class_map_from_data(data_yaml: str) -> dict[str, int]: data = YAML.load(data_yaml) return {str(k): int(v) for k, v in data.get("class_map", {}).items()} def preprocess_image(img: np.ndarray, imgsz: int, stride: int, device: torch.device) -> tuple[torch.Tensor, np.ndarray]: letterbox = LetterBox(new_shape=(imgsz, imgsz), auto=False, stride=stride) img_resized = letterbox(image=img) tensor = torch.from_numpy(np.ascontiguousarray(img_resized.transpose(2, 0, 1))).to(device) return tensor.float().unsqueeze(0) / 256.0, img_resized def difficulty_from_logits(logits: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]: """Return binary difficulty labels and class-specific confidences for selected predictions.""" if logits.shape[-1] == 1: prob = logits.squeeze(-1).sigmoid() cls = (prob >= threshold).long() conf = torch.where(cls.bool(), prob, 1.0 - prob) return cls, conf if logits.shape[-1] == 2: probs = logits.softmax(-1) cls = logits.argmax(-1).long() return cls, probs.gather(1, cls[:, None]).squeeze(1) raise RuntimeError(f"Expected difficulty logits with 1 or 2 channels, got shape {tuple(logits.shape)}") def predict_one(model, img: np.ndarray, args: argparse.Namespace, device: torch.device, stride: int): im, resized = preprocess_image(img, args.imgsz, stride, device) with torch.inference_mode(): output = model(im) if not isinstance(output, tuple) or len(output) < 2: raise RuntimeError("Model output did not include raw prediction metadata with preds_diff.") det, pred_dict = output one2one = pred_dict.get("one2one", pred_dict) diff_selected = one2one.get("preds_diff_selected") if diff_selected is None: raise RuntimeError( "Model output has no preds_diff_selected. Make sure the checkpoint was trained with the difficulty head." ) det = det[0] diff_selected = diff_selected[0] keep_conf = det[:, 4] > args.conf det = det[keep_conf] diff_selected = diff_selected[keep_conf] if det.numel() == 0: return np.zeros((0, 4), dtype=np.float32), [], [], [], resized keep_nms = TorchNMS.batched_nms(det[:, :4], det[:, 4], det[:, 5], args.iou)[: args.max_det] det = det[keep_nms] diff_selected = diff_selected[keep_nms] boxes = det[:, :4].clone() boxes = scale_boxes(resized.shape[:2], boxes, img.shape[:2]).round().cpu().numpy() scores = det[:, 4].cpu().numpy().tolist() classes = det[:, 5].long().cpu().numpy().tolist() diff_cls, diff_prob = difficulty_from_logits(diff_selected, args.diff_thres) return boxes, scores, classes, list(zip(diff_cls.cpu().numpy().tolist(), diff_prob.cpu().numpy().tolist())), resized def parse_gt_label( label_path: Path, class_map: dict[str, int], img_shape: tuple[int, int], diff_combine: str = "sum", ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Parse Ground 2D labels into xyxy boxes, class ids, and binary difficulty targets.""" if not label_path.exists(): return np.zeros((0, 4), dtype=np.float32), np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.int64) img_h, img_w = img_shape boxes, classes, diffs = [], [], [] for line in read_lines(label_path): parts = line.split() if len(parts) < 6: continue cls_id = class_map.get(parts[0]) if cls_id is None: continue x, y, w, h = [float(v) for v in parts[1:5]] if len(parts) >= 7: # Match YOLOGroundDataset: diff2 is truncation, so use only diff1 by default. if diff_combine == "first": raw_diff = float(parts[5]) elif diff_combine == "sum": raw_diff = float(parts[5]) + float(parts[6]) else: raw_diff = max(float(parts[5]), float(parts[6])) else: raw_diff = float(parts[5]) x1 = (x - w * 0.5) * img_w y1 = (y - h * 0.5) * img_h x2 = (x + w * 0.5) * img_w y2 = (y + h * 0.5) * img_h boxes.append([x1, y1, x2, y2]) classes.append(cls_id) diffs.append(1 if int(raw_diff) >= 2 else 0) return np.asarray(boxes, dtype=np.float32), np.asarray(classes, dtype=np.int64), np.asarray(diffs, dtype=np.int64) def box_iou_np(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray: if len(boxes1) == 0 or len(boxes2) == 0: return np.zeros((len(boxes1), len(boxes2)), dtype=np.float32) lt = np.maximum(boxes1[:, None, :2], boxes2[None, :, :2]) rb = np.minimum(boxes1[:, None, 2:], boxes2[None, :, 2:]) wh = np.clip(rb - lt, 0, None) inter = wh[:, :, 0] * wh[:, :, 1] area1 = np.clip(boxes1[:, 2] - boxes1[:, 0], 0, None) * np.clip(boxes1[:, 3] - boxes1[:, 1], 0, None) area2 = np.clip(boxes2[:, 2] - boxes2[:, 0], 0, None) * np.clip(boxes2[:, 3] - boxes2[:, 1], 0, None) return inter / np.clip(area1[:, None] + area2[None, :] - inter, 1e-9, None) def match_difficulty_predictions( boxes: np.ndarray, classes: list[int], diffs: list[tuple[int, float]], gt_boxes: np.ndarray, gt_classes: np.ndarray, gt_diffs: np.ndarray, iou_thr: float, ) -> tuple[list[int], dict[int, tuple[int, float]], int, int]: """Return wrong prediction indices, matched GT info by pred index, correct count, and matched count.""" if len(boxes) == 0 or len(gt_boxes) == 0: return [], {}, 0, 0 pred_classes = np.asarray(classes, dtype=np.int64) ious = box_iou_np(gt_boxes, boxes) ious[gt_classes[:, None] != pred_classes[None, :]] = 0.0 pairs = np.argwhere(ious >= iou_thr) if len(pairs) == 0: return [], {}, 0, 0 order = np.argsort(ious[pairs[:, 0], pairs[:, 1]])[::-1] used_gt, used_pred = set(), set() wrong_pred_indices: list[int] = [] match_info: dict[int, tuple[int, float]] = {} correct = matched = 0 for pair_idx in order: gi, pi = int(pairs[pair_idx, 0]), int(pairs[pair_idx, 1]) if gi in used_gt or pi in used_pred: continue used_gt.add(gi) used_pred.add(pi) matched += 1 pred_diff = int(diffs[pi][0]) gt_diff = int(gt_diffs[gi]) match_iou = float(ious[gi, pi]) match_info[pi] = (gt_diff, match_iou) if pred_diff == gt_diff: correct += 1 else: wrong_pred_indices.append(pi) return wrong_pred_indices, match_info, correct, matched def write_prediction_rows( writer, image_id: str, boxes, scores, classes, diffs, names: dict[int, str], frame: int = -1, match_info: dict[int, tuple[int, float]] | None = None, ): match_info = match_info or {} for pi, (box, score, cls_id, (diff_cls, diff_prob)) in enumerate(zip(boxes, scores, classes, diffs)): gt_diff, match_iou = match_info.get(pi, (-1, -1.0)) correct = int(gt_diff == diff_cls) if gt_diff >= 0 else -1 writer.writerow([ image_id, frame, cls_id, names.get(cls_id, str(cls_id)), f"{score:.6f}", diff_cls, f"{diff_prob:.6f}", gt_diff, correct, f"{match_iou:.6f}", *[int(v) for v in box], ]) def draw_dashed_rectangle(img: np.ndarray, p1: tuple[int, int], p2: tuple[int, int], color, thickness=2, dash=10, gap=6): x1, y1 = p1 x2, y2 = p2 for x in range(x1, x2, dash + gap): cv2.line(img, (x, y1), (min(x + dash, x2), y1), color, thickness) cv2.line(img, (x, y2), (min(x + dash, x2), y2), color, thickness) for y in range(y1, y2, dash + gap): cv2.line(img, (x1, y), (x1, min(y + dash, y2)), color, thickness) cv2.line(img, (x2, y), (x2, min(y + dash, y2)), color, thickness) def draw_x_flag(img: np.ndarray, box, color=(0, 255, 255), thickness=3): x1, y1, x2, y2 = [int(v) for v in box] cv2.line(img, (x1, y1), (x2, y2), color, thickness) cv2.line(img, (x1, y2), (x2, y1), color, thickness) def draw_predictions( img: np.ndarray, boxes, scores, classes, diffs, names: dict[int, str], wrong_indices: set[int] | None = None, ) -> np.ndarray: out = img.copy() wrong_indices = wrong_indices or set() for pi, (box, score, cls_id, (diff_cls, diff_prob)) in enumerate(zip(boxes, scores, classes, diffs)): x1, y1, x2, y2 = [int(v) for v in box] color = (0, 0, 255) if diff_cls == 1 else (0, 200, 0) if diff_cls == 1: draw_dashed_rectangle(out, (x1, y1), (x2, y2), color, thickness=2) else: cv2.rectangle(out, (x1, y1), (x2, y2), color, 2) label = f"{names.get(cls_id, str(cls_id))} {score:.2f} diff{diff_cls}:{diff_prob:.2f}" (tw, th), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) y_text = max(y1 - th - baseline - 2, 0) cv2.rectangle(out, (x1, y_text), (x1 + tw + 4, y_text + th + baseline + 4), color, -1) cv2.putText(out, label, (x1 + 2, y_text + th + 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) if pi in wrong_indices: draw_x_flag(out, box) return out def main() -> None: args = parse_args() out_dir = Path(args.output) out_dir.mkdir(parents=True, exist_ok=True) device = select_device(args.device) yolo = YOLO(args.weights) model = yolo.model.to(device).eval() stride = int(getattr(model, "stride", torch.tensor([32])).max()) names = class_names_from_model(model) video_source = Path(args.source) if args.source and is_video_path(Path(args.source)) else None images = [] if video_source else collect_source_images(args.source) if args.source else collect_dataset_images(args.data, args.split, args.clip_id) if args.limit > 0: images = images[: args.limit] if video_source is None and not images: raise FileNotFoundError("No images matched the requested source/split/clip-id.") csv_path = out_dir / "predictions.csv" with open(csv_path, "w", newline="", encoding="utf-8") as f: writer = csv.writer(f) writer.writerow([ "source", "frame", "class_id", "class_name", "conf", "difficulty", "difficulty_prob", "gt_difficulty", "difficulty_correct", "match_iou", "x1", "y1", "x2", "y2", ]) if video_source is not None: cap = cv2.VideoCapture(str(video_source)) if not cap.isOpened(): raise FileNotFoundError(f"Unable to open video source: {video_source}") fps = cap.get(cv2.CAP_PROP_FPS) or 25.0 width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) save_path = out_dir / f"{video_source.stem}_difficulty.mp4" writer_video = cv2.VideoWriter( str(save_path), cv2.VideoWriter_fourcc(*"mp4v"), fps / max(args.video_stride, 1), (width, height), ) frame_idx = -1 processed = 0 while True: ok, frame = cap.read() if not ok: break frame_idx += 1 if frame_idx % max(args.video_stride, 1) != 0: continue boxes, scores, classes, diffs, _ = predict_one(model, frame, args, device, stride) drawn = draw_predictions(frame, boxes, scores, classes, diffs, names) writer_video.write(drawn) write_prediction_rows(writer, str(video_source), boxes, scores, classes, diffs, names, frame=frame_idx) processed += 1 print(f"[frame {frame_idx}] {len(boxes)} detections") if args.limit > 0 and processed >= args.limit: break cap.release() writer_video.release() print(f"Saved annotated video to {save_path}") print(f"Saved prediction CSV to {csv_path}") return class_map = class_map_from_data(args.data) samples = collect_source_images(args.source) if args.source else collect_dataset_samples(args.data, args.split, args.clip_id) if args.source: sample_pairs = [(image_path, None) for image_path in samples] else: sample_pairs = samples if args.limit > 0: sample_pairs = sample_pairs[: args.limit] total_correct = 0 total_matched = 0 total_wrong = 0 for idx, sample in enumerate(sample_pairs, 1): image_path, label_path = sample img = cv2.imread(str(image_path)) if img is None: print(f"[{idx}/{len(sample_pairs)}] skip unreadable: {image_path}") continue boxes, scores, classes, diffs, _ = predict_one(model, img, args, device, stride) wrong_indices: list[int] = [] match_info: dict[int, tuple[int, float]] = {} if label_path is not None: gt_boxes, gt_classes, gt_diffs = parse_gt_label( label_path, class_map, img.shape[:2], diff_combine=args.gt_diff_combine ) wrong_indices, match_info, correct, matched = match_difficulty_predictions( boxes, classes, diffs, gt_boxes, gt_classes, gt_diffs, args.eval_iou ) total_correct += correct total_matched += matched total_wrong += len(wrong_indices) drawn = draw_predictions(img, boxes, scores, classes, diffs, names, wrong_indices=set(wrong_indices)) save_path = out_dir / image_path.name cv2.imwrite(str(save_path), drawn) write_prediction_rows(writer, str(image_path), boxes, scores, classes, diffs, names, match_info=match_info) print( f"[{idx}/{len(sample_pairs)}] {image_path} -> {save_path} " f"({len(boxes)} detections, {len(wrong_indices)} wrong difficulty)" ) if total_matched: acc = total_correct / total_matched print( f"Difficulty accuracy @IoU {args.eval_iou:.2f}: " f"{acc:.4f} ({total_correct}/{total_matched}), wrong={total_wrong}" ) elif not args.source: print(f"Difficulty accuracy @IoU {args.eval_iou:.2f}: no matched predictions") print(f"Saved annotations to {out_dir}") print(f"Saved prediction CSV to {csv_path}") if __name__ == "__main__": main()