528 lines
21 KiB
Python
Executable File
528 lines
21 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Run Ground 2D detection inference and visualize predicted difficulty.
|
|
|
|
Predicted difficulty class 1 boxes are drawn with dashed lines. Difficulty class 0 boxes
|
|
are drawn with solid lines.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
|
|
from ultralytics import YOLO
|
|
from ultralytics.data.augment import LetterBox
|
|
from ultralytics.data.utils import IMG_FORMATS
|
|
from ultralytics.utils import YAML
|
|
from ultralytics.utils.nms import TorchNMS
|
|
from ultralytics.utils.ops import scale_boxes
|
|
from ultralytics.utils.torch_utils import select_device
|
|
|
|
|
|
# Checkpoint loaded when --weights is not given on the command line.
DEFAULT_WEIGHTS: str = "runs/detect/train_mono2d_20260429/weights/best.pt"

# Dataset YAML used when --data is not given. This script reads its "path" (dataset root),
# the split entries (train/val/test), and the optional label "class_map" from it.
DEFAULT_DATA: str = "ultralytics/cfg/datasets/mono2d_ground.yaml"
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace. Path-like options are returned as plain strings.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument

    # Model and data selection.
    add("--weights", default=DEFAULT_WEIGHTS, help="Path to Ground 2D checkpoint.")
    add("--data", default=DEFAULT_DATA, help="Dataset YAML used to resolve val entries.")
    add("--split", default="val", choices=("train", "val", "test"), help="Dataset split to run.")
    add("--clip-id", default="", help="Substring filter for clip/image/label paths within the split.")
    add(
        "--source",
        default="",
        help="Optional image, directory, text file, or video source. Overrides --split.",
    )

    # Inference and post-processing thresholds.
    add("--imgsz", type=int, default=704, help="Square inference image size.")
    add("--conf", type=float, default=0.25, help="Confidence threshold.")
    add("--iou", type=float, default=0.7, help="Class-aware NMS IoU threshold.")
    add("--max-det", type=int, default=300, help="Maximum detections per image after NMS.")
    add("--diff-thres", type=float, default=0.7, help="Sigmoid threshold for one-logit difficulty output.")

    # Difficulty evaluation against ground-truth labels.
    add("--eval-iou", type=float, default=0.5, help="IoU threshold for GT matching in difficulty accuracy.")
    add(
        "--gt-diff-combine",
        default="first",
        choices=("first", "sum", "max"),
        help="How to combine two 2D label difficulty flags. Ground 2D dataloader uses 'first'.",
    )

    # Runtime and output.
    add("--device", default="0", help="Device, e.g. 0 or cpu.")
    add("--output", default="runs/detect/test_mono2d_difficulty", help="Directory for annotated images.")
    add("--limit", type=int, default=0, help="Optional maximum number of images to process.")
    add("--video-stride", type=int, default=1, help="Process every Nth frame for video sources.")

    namespace = parser.parse_args()
    return namespace
|
|
|
|
|
|
def read_lines(path: Path) -> list[str]:
    """Return the non-blank lines of a UTF-8 text file, each stripped of surrounding whitespace."""
    with open(path, encoding="utf-8") as handle:
        raw_text = handle.read()
    stripped = (candidate.strip() for candidate in raw_text.splitlines())
    return [text for text in stripped if text]
|
|
|
|
|
|
def resolve_relative_entry(entry: str, list_file: Path | None = None) -> Path:
    """Turn one line of a list file into a path.

    Absolute entries are returned unchanged. Entries written as ``./something`` are anchored to
    the directory containing ``list_file`` (when one is given). Any other relative entry is
    returned as-is, i.e. relative to the current working directory.
    """
    candidate = Path(entry)
    anchored_to_list = list_file is not None and entry.startswith("./")
    if candidate.is_absolute() or not anchored_to_list:
        return candidate
    return list_file.parent / entry[2:]
|
|
|
|
|
|
def image_path_from_ground_entry(entry: Path, image_root: Path, label_path: bool) -> Path:
    """Map a Ground 2D list entry to its image path following YOLOGroundDataset rules.

    The entry is re-rooted under ``image_root``. If some path component starts with
    ``detection2d`` or ``2ddetection`` (case-insensitive) and is not the last component, only the
    components after the first such marker are kept. Otherwise the whole entry is used, minus its
    filesystem anchor when it is absolute.

    Args:
        entry: A path taken from a split list: either an image path or a label ``.txt`` path.
        image_root: Directory the relative components are joined onto (the YAML ``path`` in this script).
        label_path: True when ``entry`` is a label file, in which case an image file is located by
            probing candidate suffixes on disk.

    Returns:
        For image entries, the re-rooted path without any existence check. For label entries, the
        first existing candidate image; if none exists, the ``.jpg`` candidate on the
        ``labels`` -> ``images`` remapped path, which callers should expect may not exist.
    """
    parts = entry.parts
    # Index of the first component marking the detection dataset root, if any.
    detection_root_index = next(
        (i for i, part in enumerate(parts) if part.lower().startswith(("detection2d", "2ddetection"))),
        None,
    )
    if detection_root_index is not None and detection_root_index + 1 < len(parts):
        rel_parts = list(parts[detection_root_index + 1 :])
    else:
        # No usable marker: keep the whole entry, dropping the anchor of absolute paths so that
        # joinpath below does not discard image_root.
        rel_parts = list(entry.parts if not entry.is_absolute() else entry.parts[1:])

    image_path = image_root.joinpath(*rel_parts)
    if not label_path:
        return image_path

    # 1) Same relative location as the label file, with a .jpg suffix.
    image_path = image_path.with_suffix(".jpg")
    if image_path.exists():
        return image_path

    # 2) Swap every "labels" directory for "images" and try .jpg again.
    image_rel_parts = ["images" if part == "labels" else part for part in rel_parts]
    image_path = image_root.joinpath(*image_rel_parts).with_suffix(".jpg")
    if image_path.exists():
        return image_path

    # 3) On the remapped path, probe png first, then the remaining IMG_FORMATS alphabetically, then jpeg.
    for suffix in ("png", *sorted(IMG_FORMATS - {"jpg", "jpeg", "png"}), "jpeg"):
        candidate = image_path.with_suffix(f".{suffix}")
        if candidate.exists():
            return candidate
    # Nothing found on disk: fall back to the remapped .jpg candidate.
    return image_path
|
|
|
|
|
|
def label_path_from_image_entry(entry: Path) -> Path:
    """Derive a label path from an image path.

    Every path component exactly equal to ``images`` becomes ``labels`` and the suffix becomes ``.txt``.
    """
    swapped = (("labels" if component == "images" else component) for component in entry.parts)
    return Path(*swapped).with_suffix(".txt")
|
|
|
|
|
|
def collect_source_images(source: str) -> list[Path]:
    """Expand a ``--source`` argument into a list of image paths.

    * Directory: every file below it (recursively) whose extension, compared case-insensitively,
      is in ``IMG_FORMATS``, sorted.
    * ``.txt`` file: one entry per non-blank line, resolved with ``resolve_relative_entry``.
    * Anything else: the path itself, returned without an existence check.
    """
    source_path = Path(source)
    if source_path.is_dir():
        # Match extensions case-insensitively so files like "IMG_0001.JPG" are found on
        # case-sensitive filesystems; skip directories whose names merely look like images.
        return sorted(
            path
            for path in source_path.rglob("*")
            if path.suffix.lower().lstrip(".") in IMG_FORMATS and path.is_file()
        )

    if source_path.suffix.lower().lstrip(".") == "txt":
        return [resolve_relative_entry(line, source_path) for line in read_lines(source_path)]

    return [source_path]
|
|
|
|
|
|
# Lower-case container extensions (without the dot) treated as video sources.
_VIDEO_SUFFIXES = frozenset({"mp4", "avi", "mov", "mkv", "webm", "m4v", "mpg", "mpeg"})


def is_video_path(path: Path) -> bool:
    """Return True when ``path`` has a known video extension (case-insensitive)."""
    extension = path.suffix.lower().lstrip(".")
    return extension in _VIDEO_SUFFIXES
|
|
|
|
|
|
def collect_dataset_images(data_yaml: str, split: str, clip_id: str = "") -> list[Path]:
    """Return just the image paths from ``collect_dataset_samples`` for the same arguments, in order."""
    samples = collect_dataset_samples(data_yaml, split, clip_id)
    return [image_path for image_path, _label_path in samples]
|
|
|
|
|
|
def collect_dataset_samples(data_yaml: str, split: str, clip_id: str = "") -> list[tuple[Path, Path]]:
    """Resolve ``(image_path, label_path)`` pairs for one split of a Ground 2D dataset YAML.

    The YAML must provide ``path`` (the root that image paths are re-rooted under) and the
    requested split, either as a single item or a list of items. An item that is an existing file
    is read as a list file (one entry per line). Any other item is treated as a single entry.
    Split items themselves are used as written, i.e. relative paths are relative to the current
    working directory.

    Entry handling:
      * Image suffix: the image path is re-rooted via ``image_path_from_ground_entry``, and the label
        path is derived from the entry by swapping ``images`` -> ``labels``.
      * ``.txt`` suffix: the entry is the label path, and a matching image is searched for on disk.
      * Anything else is skipped.

    Args:
        data_yaml: Path to the dataset YAML.
        split: Split key to read (e.g. ``"val"``).
        clip_id: Optional substring. A sample is kept only if it occurs in the raw entry text, the
            resolved entry, the image path, or the label path.

    Returns:
        The matching samples, in list order.

    Raises:
        KeyError: if the YAML has no ``path`` key.
        ValueError: if the split is missing or empty.
    """
    data = YAML.load(data_yaml)
    image_root = Path(data["path"])
    split_value = data.get(split)
    if not split_value:
        raise ValueError(f"Dataset YAML has no '{split}' split: {data_yaml}")

    items = split_value if isinstance(split_value, list) else [split_value]
    samples: list[tuple[Path, Path]] = []
    for item in items:
        item_path = Path(item)
        # Stat the item once rather than once per entry: a list file may hold thousands of lines.
        list_file = item_path if item_path.is_file() else None
        entries = read_lines(list_file) if list_file is not None else [str(item_path)]
        for entry_text in entries:
            entry = resolve_relative_entry(entry_text, list_file)
            suffix = entry.suffix.lower().lstrip(".")
            if suffix in IMG_FORMATS:
                image_path = image_path_from_ground_entry(entry, image_root, label_path=False)
                label_path = label_path_from_image_entry(entry)
            elif suffix == "txt":
                image_path = image_path_from_ground_entry(entry, image_root, label_path=True)
                label_path = entry
            else:
                continue

            haystack = f"{entry_text} {entry} {image_path} {label_path}"
            if clip_id and clip_id not in haystack:
                continue
            samples.append((image_path, label_path))
    return samples
|
|
|
|
|
|
def class_names_from_model(model) -> dict[int, str]:
    """Normalize ``model.names`` into an ``{int class id: str name}`` mapping.

    Dict names have their keys cast to int. List/tuple names are numbered by position. A missing
    attribute or any other type yields an empty dict.
    """
    raw_names = getattr(model, "names", None)
    if isinstance(raw_names, dict):
        pairs = raw_names.items()
    elif isinstance(raw_names, (list, tuple)):
        pairs = enumerate(raw_names)
    else:
        return {}
    return {int(key): str(value) for key, value in pairs}
|
|
|
|
|
|
def class_map_from_data(data_yaml: str) -> dict[str, int]:
    """Load the optional ``class_map`` (label class name -> model class id) from a dataset YAML.

    Returns an empty dict when the key is absent or has no value.
    """
    data = YAML.load(data_yaml)
    # ``or {}`` also covers an explicit null (``class_map:`` with no value), which loads as None.
    raw_map = data.get("class_map") or {}
    return {str(name): int(class_id) for name, class_id in raw_map.items()}
|
|
|
|
|
|
def preprocess_image(img: np.ndarray, imgsz: int, stride: int, device: torch.device) -> tuple[torch.Tensor, np.ndarray]:
    """Letterbox a BGR image and convert it to a normalized RGB NCHW float tensor.

    This mirrors Ultralytics' predictor preprocessing: letterbox to ``imgsz x imgsz``, flip channels
    BGR -> RGB (the model is trained on RGB), and scale pixel values to ``[0, 1]`` by dividing by 255.

    Args:
        img: HWC image as produced by OpenCV (assumed 3-channel BGR).
        imgsz: Square target size.
        stride: Model stride passed to ``LetterBox``.
        device: Device the tensor is moved to.

    Returns:
        ``(tensor, img_resized)``: the batched ``(1, C, H, W)`` input tensor and the letterboxed
        (still BGR, HWC) image. The caller uses the latter's shape to map boxes back.
    """
    letterbox = LetterBox(new_shape=(imgsz, imgsz), auto=False, stride=stride)
    img_resized = letterbox(image=img)
    # HWC BGR -> CHW RGB. OpenCV loads BGR, but Ultralytics trains and predicts on RGB input.
    chw_rgb = np.ascontiguousarray(img_resized[..., ::-1].transpose(2, 0, 1))
    tensor = torch.from_numpy(chw_rgb).to(device)
    # Divide by 255 (not 256) so the value range matches the training normalization exactly.
    return tensor.float().unsqueeze(0) / 255.0, img_resized
|
|
|
|
|
|
def difficulty_from_logits(logits: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Return binary difficulty labels and class-specific confidences for selected predictions.

    The head layout is selected by the size of the last dimension:

    * 1 channel: sigmoid probability of class 1. The class is 1 when ``prob >= threshold``, and the
      confidence is the probability of the chosen class (``prob`` or ``1 - prob``).
    * 2 channels: softmax over the last dimension. The class is the argmax and ``threshold`` is unused.

    Any number of leading dimensions is accepted (e.g. ``(N, C)`` or ``(B, N, C)``). Both outputs have
    the input's shape without its last dimension; the labels are int64.

    Raises:
        RuntimeError: if the last dimension is neither 1 nor 2.
    """
    channels = logits.shape[-1]
    if channels == 1:
        prob = logits.squeeze(-1).sigmoid()
        cls = (prob >= threshold).long()
        conf = torch.where(cls.bool(), prob, 1.0 - prob)
        return cls, conf

    if channels == 2:
        probs = logits.softmax(-1)
        cls = logits.argmax(-1).long()
        # Gather along the last dim (not dim 1) so inputs with extra leading dims work too,
        # matching the shape handling of the 1-channel branch.
        conf = probs.gather(-1, cls.unsqueeze(-1)).squeeze(-1)
        return cls, conf

    raise RuntimeError(f"Expected difficulty logits with 1 or 2 channels, got shape {tuple(logits.shape)}")
|
|
|
|
|
|
def predict_one(model, img: np.ndarray, args: argparse.Namespace, device: torch.device, stride: int):
    """Run inference on one image and post-process detections together with their difficulty.

    Pipeline:
      1. Preprocess the image and run the forward pass.
      2. Keep rows with ``conf > args.conf``.
      3. Apply class-aware NMS (``args.iou``) and keep at most ``args.max_det`` rows.
      4. Rescale boxes to the original image.
      5. Decode the per-detection difficulty logits with ``args.diff_thres``.

    The model is expected to return a tuple whose first two items are ``(det, pred_dict)``; any extra
    trailing items are ignored. Rows of ``det[0]`` are presumably ``[x1, y1, x2, y2, conf, cls]`` in
    letterboxed pixels. ``pred_dict`` (or its ``"one2one"`` entry) must hold ``preds_diff_selected``,
    aligned row-for-row with ``det``.

    Returns:
        ``(boxes, scores, classes, diffs, resized)``:
          * ``boxes``: ``(N, 4)`` numpy xyxy boxes in original-image pixels, rounded.
          * ``scores``, ``classes``: per-detection score and class-id lists.
          * ``diffs``: list of ``(difficulty_class, difficulty_conf)`` tuples.
          * ``resized``: the letterboxed image.

    Raises:
        RuntimeError: if the output lacks the raw prediction dict or ``preds_diff_selected``.
    """
    im, resized = preprocess_image(img, args.imgsz, stride, device)
    with torch.inference_mode():
        output = model(im)
    if not isinstance(output, tuple) or len(output) < 2:
        raise RuntimeError("Model output did not include raw prediction metadata with preds_diff.")

    # Index rather than unpack, so outputs with extra trailing items pass, consistent with the guard above.
    det, pred_dict = output[0], output[1]
    one2one = pred_dict.get("one2one", pred_dict)
    diff_selected = one2one.get("preds_diff_selected")
    if diff_selected is None:
        raise RuntimeError(
            "Model output has no preds_diff_selected. Make sure the checkpoint was trained with the difficulty head."
        )

    # Single-image batch: drop the batch dimension.
    det = det[0]
    diff_selected = diff_selected[0]

    # Confidence filter, applied to detections and difficulty logits in lockstep.
    keep_conf = det[:, 4] > args.conf
    det = det[keep_conf]
    diff_selected = diff_selected[keep_conf]
    if det.numel() == 0:
        return np.zeros((0, 4), dtype=np.float32), [], [], [], resized

    # Class-aware NMS, then cap the number of detections.
    keep_nms = TorchNMS.batched_nms(det[:, :4], det[:, 4], det[:, 5], args.iou)[: args.max_det]
    det = det[keep_nms]
    diff_selected = diff_selected[keep_nms]

    # Map boxes from the letterboxed input back to original-image pixels.
    boxes = scale_boxes(resized.shape[:2], det[:, :4].clone(), img.shape[:2]).round().cpu().numpy()
    scores = det[:, 4].cpu().numpy().tolist()
    classes = det[:, 5].long().cpu().numpy().tolist()
    diff_cls, diff_prob = difficulty_from_logits(diff_selected, args.diff_thres)
    diffs = list(zip(diff_cls.cpu().numpy().tolist(), diff_prob.cpu().numpy().tolist()))
    return boxes, scores, classes, diffs, resized
|
|
|
|
|
|
def parse_gt_label(
    label_path: Path,
    class_map: dict[str, int],
    img_shape: tuple[int, int],
    diff_combine: str = "sum",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Parse Ground 2D labels into xyxy boxes, class ids, and binary difficulty targets.

    Each usable line looks like ``<class_name> <cx> <cy> <w> <h> <diff1> [<diff2> ...]``, with the box
    values normalized by the image width/height. Lines are skipped when they have fewer than six fields
    or when their class name is not in ``class_map``.

    Args:
        label_path: Label text file. A missing file yields empty arrays.
        class_map: Maps label class names to model class ids.
        img_shape: ``(height, width)`` of the image, used to convert to pixels.
        diff_combine: How to merge ``diff1`` and ``diff2`` when both are present:
            * ``"first"``: diff1 only.
            * ``"sum"``: diff1 + diff2.
            * ``"max"``: the larger of the two.
            With a single flag, diff1 is used. The CLI default is ``"first"``; this function's
            default stays ``"sum"`` for backward compatibility.

    Returns:
        ``(boxes, classes, diffs)``:
          * ``boxes``: ``(N, 4)`` float32 xyxy pixel boxes. The shape is ``(0, 4)`` even when no line is usable.
          * ``classes``: ``(N,)`` int64 class ids.
          * ``diffs``: ``(N,)`` int64 targets, equal to 1 when the combined flag truncated to int is >= 2.

    Raises:
        ValueError: if ``diff_combine`` is not ``"first"``, ``"sum"`` or ``"max"``.
    """
    if diff_combine not in ("first", "sum", "max"):
        raise ValueError(f"Unsupported diff_combine {diff_combine!r}; expected 'first', 'sum' or 'max'.")
    if not label_path.exists():
        return np.zeros((0, 4), dtype=np.float32), np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.int64)

    img_h, img_w = img_shape
    boxes, classes, diffs = [], [], []
    for line in read_lines(label_path):
        parts = line.split()
        if len(parts) < 6:
            continue
        cls_id = class_map.get(parts[0])
        if cls_id is None:
            continue
        x, y, w, h = (float(v) for v in parts[1:5])

        diff1 = float(parts[5])
        if len(parts) >= 7 and diff_combine != "first":
            diff2 = float(parts[6])
            raw_diff = diff1 + diff2 if diff_combine == "sum" else max(diff1, diff2)
        else:
            # "first" matches YOLOGroundDataset, where (per the original author) diff2 encodes truncation.
            raw_diff = diff1

        # Normalized center/size -> absolute corner coordinates.
        boxes.append([(x - w * 0.5) * img_w, (y - h * 0.5) * img_h, (x + w * 0.5) * img_w, (y + h * 0.5) * img_h])
        classes.append(cls_id)
        diffs.append(1 if int(raw_diff) >= 2 else 0)

    # reshape keeps the (N, 4) contract when no line was usable (asarray([]) would be shape (0,)).
    return (
        np.asarray(boxes, dtype=np.float32).reshape(-1, 4),
        np.asarray(classes, dtype=np.int64),
        np.asarray(diffs, dtype=np.int64),
    )
|
|
|
|
|
|
def box_iou_np(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """Pairwise IoU between two sets of xyxy boxes.

    Returns an ``(len(boxes1), len(boxes2))`` matrix. When either set is empty, the result is a float32
    zero matrix of that shape. Negative widths/heights count as zero area, and the union is clamped
    to ``1e-9`` to avoid division by zero.
    """
    count1, count2 = len(boxes1), len(boxes2)
    if not (count1 and count2):
        return np.zeros((count1, count2), dtype=np.float32)

    left = boxes1[:, None, :]
    right = boxes2[None, :, :]
    top_left = np.maximum(left[..., :2], right[..., :2])
    bottom_right = np.minimum(left[..., 2:], right[..., 2:])
    overlap = np.clip(bottom_right - top_left, 0, None)
    intersection = overlap[..., 0] * overlap[..., 1]

    def _area(b: np.ndarray) -> np.ndarray:
        return np.clip(b[:, 2] - b[:, 0], 0, None) * np.clip(b[:, 3] - b[:, 1], 0, None)

    union = _area(boxes1)[:, None] + _area(boxes2)[None, :] - intersection
    return intersection / np.clip(union, 1e-9, None)
|
|
|
|
|
|
def match_difficulty_predictions(
    boxes: np.ndarray,
    classes: list[int],
    diffs: list[tuple[int, float]],
    gt_boxes: np.ndarray,
    gt_classes: np.ndarray,
    gt_diffs: np.ndarray,
    iou_thr: float,
) -> tuple[list[int], dict[int, tuple[int, float]], int, int]:
    """Return wrong prediction indices, matched GT info by pred index, correct count, and matched count.

    Greedy one-to-one matching:
      * A candidate pair must share the same class and have ``IoU >= iou_thr``.
      * Candidates are taken in descending IoU order, with ties broken by GT index and then prediction index.
      * Each GT box and each prediction is used at most once.

    A match is "correct" when the predicted difficulty class (``diffs[i][0]``) equals the GT difficulty.

    Returns:
        ``(wrong_pred_indices, match_info, correct, matched)``:
          * ``wrong_pred_indices``: matched predictions whose difficulty differs from the GT.
          * ``match_info``: maps each matched prediction index to ``(gt_difficulty, iou)``.
          * ``correct``: number of matches with the right difficulty.
          * ``matched``: total number of matches.
    """
    if len(boxes) == 0 or len(gt_boxes) == 0:
        return [], {}, 0, 0

    pred_classes = np.asarray(classes, dtype=np.int64)
    ious = box_iou_np(gt_boxes, boxes)
    # Use an explicit class mask instead of zeroing cross-class IoUs. A zeroed pair still satisfies
    # ``iou >= iou_thr`` when iou_thr <= 0, which would match boxes of different classes.
    same_class = gt_classes[:, None] == pred_classes[None, :]
    pairs = np.argwhere(same_class & (ious >= iou_thr))
    if len(pairs) == 0:
        return [], {}, 0, 0

    pair_ious = ious[pairs[:, 0], pairs[:, 1]]
    # Stable sort on negated IoU: highest IoU first, with ties kept in deterministic argwhere order.
    order = np.argsort(-pair_ious, kind="stable")

    used_gt: set[int] = set()
    used_pred: set[int] = set()
    wrong_pred_indices: list[int] = []
    match_info: dict[int, tuple[int, float]] = {}
    correct = matched = 0
    for pair_idx in order:
        gi, pi = int(pairs[pair_idx, 0]), int(pairs[pair_idx, 1])
        if gi in used_gt or pi in used_pred:
            continue
        used_gt.add(gi)
        used_pred.add(pi)
        matched += 1

        gt_diff = int(gt_diffs[gi])
        match_info[pi] = (gt_diff, float(ious[gi, pi]))
        if int(diffs[pi][0]) == gt_diff:
            correct += 1
        else:
            wrong_pred_indices.append(pi)
    return wrong_pred_indices, match_info, correct, matched
|
|
|
|
|
|
def write_prediction_rows(
    writer,
    image_id: str,
    boxes,
    scores,
    classes,
    diffs,
    names: dict[int, str],
    frame: int = -1,
    match_info: dict[int, tuple[int, float]] | None = None,
):
    """Write one CSV row per prediction.

    Row layout: source, frame, class id, class name, score, difficulty class, difficulty confidence,
    GT difficulty, correctness flag, match IoU, then the four integer box coordinates. Predictions
    missing from ``match_info`` get GT difficulty -1, correctness -1 and IoU -1.0. Otherwise the
    correctness flag is 1 or 0.
    """
    lookup = match_info or {}
    records = zip(boxes, scores, classes, diffs)
    for pred_idx, (box, score, cls_id, (diff_cls, diff_prob)) in enumerate(records):
        gt_diff, match_iou = lookup.get(pred_idx, (-1, -1.0))
        if gt_diff >= 0:
            correct = int(gt_diff == diff_cls)
        else:
            correct = -1
        class_name = names.get(cls_id, str(cls_id))
        row = [
            image_id,
            frame,
            cls_id,
            class_name,
            f"{score:.6f}",
            diff_cls,
            f"{diff_prob:.6f}",
            gt_diff,
            correct,
            f"{match_iou:.6f}",
        ]
        row.extend(int(coord) for coord in box)
        writer.writerow(row)
|
|
|
|
|
|
def draw_dashed_rectangle(img: np.ndarray, p1: tuple[int, int], p2: tuple[int, int], color, thickness=2, dash=10, gap=6):
    """Draw a dashed axis-aligned rectangle in place on ``img``.

    Dashes are ``dash`` pixels long, separated by ``gap`` pixels, and clipped at the far corner.
    ``p1`` is the top-left corner and ``p2`` the bottom-right corner.
    """
    left, top = p1
    right, bottom = p2
    period = dash + gap

    # Horizontal edges (top, then bottom) for each dash start.
    for start_x in range(left, right, period):
        end_x = min(start_x + dash, right)
        for edge_y in (top, bottom):
            cv2.line(img, (start_x, edge_y), (end_x, edge_y), color, thickness)

    # Vertical edges (left, then right) for each dash start.
    for start_y in range(top, bottom, period):
        end_y = min(start_y + dash, bottom)
        for edge_x in (left, right):
            cv2.line(img, (edge_x, start_y), (edge_x, end_y), color, thickness)
|
|
|
|
|
|
def draw_x_flag(img: np.ndarray, box, color=(0, 255, 255), thickness=3):
    """Draw an X across ``box`` (xyxy) in place, used to flag a wrong difficulty prediction."""
    left, top, right, bottom = (int(coord) for coord in box)
    diagonals = (
        ((left, top), (right, bottom)),
        ((left, bottom), (right, top)),
    )
    for start, end in diagonals:
        cv2.line(img, start, end, color, thickness)
|
|
|
|
|
|
def draw_predictions(
    img: np.ndarray,
    boxes,
    scores,
    classes,
    diffs,
    names: dict[int, str],
    wrong_indices: set[int] | None = None,
) -> np.ndarray:
    """Return a copy of ``img`` with every prediction drawn on it.

    Drawing rules:
      * Difficulty class 1: red dashed box.
      * Other difficulty classes: green solid box.
      * Every box gets a filled caption with class name, score and difficulty confidence.
      * Predictions whose index is in ``wrong_indices`` are additionally crossed out with an X.

    The input image is not modified.
    """
    canvas = img.copy()
    flagged = wrong_indices or set()
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5

    for pred_idx, (box, score, cls_id, (diff_cls, diff_prob)) in enumerate(zip(boxes, scores, classes, diffs)):
        left, top, right, bottom = (int(coord) for coord in box)
        is_hard = diff_cls == 1
        color = (0, 0, 255) if is_hard else (0, 200, 0)
        if is_hard:
            draw_dashed_rectangle(canvas, (left, top), (right, bottom), color, thickness=2)
        else:
            cv2.rectangle(canvas, (left, top), (right, bottom), color, 2)

        caption = f"{names.get(cls_id, str(cls_id))} {score:.2f} diff{diff_cls}:{diff_prob:.2f}"
        (text_w, text_h), baseline = cv2.getTextSize(caption, font, font_scale, 1)
        # Place the caption above the box, clamped to the top image edge.
        caption_top = max(top - text_h - baseline - 2, 0)
        cv2.rectangle(canvas, (left, caption_top), (left + text_w + 4, caption_top + text_h + baseline + 4), color, -1)
        cv2.putText(canvas, caption, (left + 2, caption_top + text_h + 2), font, font_scale, (255, 255, 255), 1)

        if pred_idx in flagged:
            draw_x_flag(canvas, box)
    return canvas
|
|
|
|
|
|
# Column order of predictions.csv; must match the row layout of write_prediction_rows.
_CSV_HEADER = [
    "source",
    "frame",
    "class_id",
    "class_name",
    "conf",
    "difficulty",
    "difficulty_prob",
    "gt_difficulty",
    "difficulty_correct",
    "match_iou",
    "x1",
    "y1",
    "x2",
    "y2",
]


def _collect_sample_pairs(args: argparse.Namespace) -> list[tuple[Path, Path | None]]:
    """Return ``(image_path, label_path or None)`` pairs for --source or the dataset split, truncated to --limit."""
    if args.source:
        # Free-form sources carry no ground-truth labels.
        pairs = [(image_path, None) for image_path in collect_source_images(args.source)]
    else:
        pairs = list(collect_dataset_samples(args.data, args.split, args.clip_id))
    if args.limit > 0:
        pairs = pairs[: args.limit]
    return pairs


def _unique_save_path(out_dir: Path, image_path: Path, idx: int, used_names: set[str]) -> Path:
    """Pick an output filename for ``image_path`` that has not yet been used in this run.

    The original filename is kept when possible. On a collision (e.g. ``000001.jpg`` in several clips),
    the parent directory name is prefixed; if that is also taken, the 1-based sample index is used.
    """
    candidates = (
        image_path.name,
        f"{image_path.parent.name}_{image_path.name}",
        f"{idx:06d}_{image_path.name}",
    )
    chosen = next((name for name in candidates if name not in used_names), candidates[-1])
    used_names.add(chosen)
    return out_dir / chosen


def _run_video(model, video_source: Path, args, device, stride: int, names, writer, out_dir: Path) -> None:
    """Annotate every ``--video-stride``-th frame of a video, writing an MP4 plus one CSV row per detection."""
    step = max(args.video_stride, 1)
    cap = cv2.VideoCapture(str(video_source))
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"Unable to open video source: {video_source}")
        # Fall back to 25 fps when the container reports no frame rate (0).
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        save_path = out_dir / f"{video_source.stem}_difficulty.mp4"
        # Divide the output rate by the stride so the annotated clip keeps the source duration.
        writer_video = cv2.VideoWriter(str(save_path), cv2.VideoWriter_fourcc(*"mp4v"), fps / step, (width, height))
        try:
            frame_idx = -1
            processed = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                frame_idx += 1
                if frame_idx % step != 0:
                    continue
                boxes, scores, classes, diffs, _ = predict_one(model, frame, args, device, stride)
                writer_video.write(draw_predictions(frame, boxes, scores, classes, diffs, names))
                write_prediction_rows(writer, str(video_source), boxes, scores, classes, diffs, names, frame=frame_idx)
                processed += 1
                print(f"[frame {frame_idx}] {len(boxes)} detections")
                if args.limit > 0 and processed >= args.limit:
                    break
        finally:
            writer_video.release()
    finally:
        # Release the capture even if inference or writing raised.
        cap.release()
    print(f"Saved annotated video to {save_path}")


def _run_images(model, sample_pairs, args, device, stride: int, names, writer, out_dir: Path) -> None:
    """Annotate each image, score difficulty against labels when available, and print the accuracy summary."""
    # The class map is only needed to parse labels, so label-free --source runs do not need the data YAML.
    needs_labels = any(label_path is not None for _, label_path in sample_pairs)
    class_map = class_map_from_data(args.data) if needs_labels else {}

    total_correct = total_matched = total_wrong = 0
    used_names: set[str] = set()
    total = len(sample_pairs)
    for idx, (image_path, label_path) in enumerate(sample_pairs, 1):
        img = cv2.imread(str(image_path))
        if img is None:
            print(f"[{idx}/{total}] skip unreadable: {image_path}")
            continue

        boxes, scores, classes, diffs, _ = predict_one(model, img, args, device, stride)
        wrong_indices: list[int] = []
        match_info: dict[int, tuple[int, float]] = {}
        if label_path is not None:
            gt_boxes, gt_classes, gt_diffs = parse_gt_label(
                label_path, class_map, img.shape[:2], diff_combine=args.gt_diff_combine
            )
            wrong_indices, match_info, correct, matched = match_difficulty_predictions(
                boxes, classes, diffs, gt_boxes, gt_classes, gt_diffs, args.eval_iou
            )
            total_correct += correct
            total_matched += matched
            total_wrong += len(wrong_indices)

        drawn = draw_predictions(img, boxes, scores, classes, diffs, names, wrong_indices=set(wrong_indices))
        # Unique names prevent same-named frames from different clips overwriting each other.
        save_path = _unique_save_path(out_dir, image_path, idx, used_names)
        if not cv2.imwrite(str(save_path), drawn):
            print(f"[{idx}/{total}] failed to write {save_path}")
        write_prediction_rows(writer, str(image_path), boxes, scores, classes, diffs, names, match_info=match_info)
        print(
            f"[{idx}/{total}] {image_path} -> {save_path} "
            f"({len(boxes)} detections, {len(wrong_indices)} wrong difficulty)"
        )

    if total_matched:
        acc = total_correct / total_matched
        print(
            f"Difficulty accuracy @IoU {args.eval_iou:.2f}: "
            f"{acc:.4f} ({total_correct}/{total_matched}), wrong={total_wrong}"
        )
    elif not args.source:
        print(f"Difficulty accuracy @IoU {args.eval_iou:.2f}: no matched predictions")
    print(f"Saved annotations to {out_dir}")


def main() -> None:
    """Run difficulty inference on a video, an image source, or a dataset split.

    Outputs are written under ``--output``:
      * ``predictions.csv`` with one row per detection.
      * For videos, an annotated MP4.
      * For images, one annotated image per input.

    Raises:
        FileNotFoundError: if no images match the request, or the video cannot be opened.
    """
    args = parse_args()
    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)

    device = select_device(args.device)
    yolo = YOLO(args.weights)
    model = yolo.model.to(device).eval()
    stride = int(getattr(model, "stride", torch.tensor([32])).max())
    names = class_names_from_model(model)

    video_source = Path(args.source) if args.source and is_video_path(Path(args.source)) else None
    # Resolve the sample list once and reuse it for both the emptiness check and inference.
    sample_pairs = [] if video_source is not None else _collect_sample_pairs(args)
    if video_source is None and not sample_pairs:
        raise FileNotFoundError("No images matched the requested source/split/clip-id.")

    csv_path = out_dir / "predictions.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(_CSV_HEADER)
        if video_source is not None:
            _run_video(model, video_source, args, device, stride, names, writer, out_dir)
        else:
            _run_images(model, sample_pairs, args, device, stride, names, writer, out_dir)
    print(f"Saved prediction CSV to {csv_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when the module is imported for its helpers.
    main()
|