690 lines
29 KiB
Python
Executable File
690 lines
29 KiB
Python
Executable File
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
|
|
import warnings
|
|
|
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TESTS_RUNNING, colorstr, torch_utils
|
|
from ultralytics.utils.torch_utils import smart_inference_mode
|
|
|
|
# Defined unconditionally so module helpers (e.g. _close_writer, warning messages) never hit NameError when the
# integration is disabled or TensorBoard cannot be imported.
WRITER = None  # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["tensorboard"] is True  # verify integration is enabled

    # Imports below only required if TensorBoard enabled
    from copy import deepcopy
    from pathlib import Path

    import cv2
    import numpy as np
    import torch
    from torch.utils.tensorboard import SummaryWriter

except (ImportError, AssertionError, TypeError, AttributeError):
    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
    # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
    SummaryWriter = None
|
|
|
|
# Module-level bookkeeping shared by the validation image-logging callbacks.
_val_sample_indices_to_log = set()  # rank-local sample indices selected for image logging in the current pass
_val_seen_samples = _val_index_offset = 0  # samples consumed so far / offset back to dataset-global indices
_current_epoch = 0  # NOTE(review): appears unused in this module (only named in a `global` statement)
|
|
|
|
|
|
def _reset_val_logging_state() -> None:
    """Clear the per-pass sample selection, seen-sample counter and index offset used for validation image logging."""
    global _val_sample_indices_to_log, _val_seen_samples, _val_index_offset
    _val_sample_indices_to_log, _val_seen_samples, _val_index_offset = set(), 0, 0
|
|
|
|
|
|
def _close_writer() -> None:
    """Flush and close the active TensorBoard writer, if any, and clear the global ``WRITER`` reference.

    The global is cleared before flushing, so it is reset even when ``flush()`` or ``close()`` raises.
    ``close()`` is still attempted when ``flush()`` fails, and any exception propagates to the caller.
    """
    global WRITER
    writer, WRITER = WRITER, None  # detach first so a failing flush/close cannot leave a stale writer behind
    if writer:
        try:
            writer.flush()
        finally:
            writer.close()
|
|
|
|
|
|
def _sanitize_tb_tag_component(value: str, max_len: int = 80) -> str:
    """Return a TensorBoard-friendly tag component.

    Every character that is not alphanumeric, ``-``, ``_`` or ``.`` is replaced with ``_``. Leading and trailing
    ``.``/``_`` are stripped, and the result is truncated to ``max_len`` characters. Separators exposed by the
    truncation are stripped as well, so the component never ends in ``.`` or ``_``.

    Args:
        value (str): Raw text, converted with ``str()`` and whitespace-stripped first (e.g. a file stem).
        max_len (int): Maximum length of the returned component; values <= 0 yield the fallback.

    Returns:
        (str): The sanitized component, or ``"unknown"`` if nothing usable remains.
    """
    cleaned = "".join(c if c.isalnum() or c in "-_." else "_" for c in str(value).strip())
    # Strip, truncate, then strip the tail again: truncation can re-expose a trailing separator.
    cleaned = cleaned.strip("._")[: max(max_len, 0)].rstrip("._")
    return cleaned or "unknown"
|
|
|
|
|
|
def _create_detection_image(images, targets, paths=None, names=None, is_labels=True, conf_thres=0.25):
    """Create detection visualization image for TensorBoard logging.

    Only the first image of the batch is rendered (``max_subplots = 1``). The canvas is downscaled when its longest
    side would exceed ``max_size`` (1920 px). Boxes and labels are drawn with the Ultralytics ``Annotator`` in PIL mode.

    Args:
        images (torch.Tensor | np.ndarray): Batch of images (B, C, H, W). If the first image's maximum value is <= 1,
            the batch is treated as normalized and multiplied by 255.
        targets (torch.Tensor | np.ndarray): Target labels or predictions
            - For labels: (N, 6) [batch_idx, class_id, x, y, w, h]
            - For predictions: (N, 7) [batch_idx, class_id, x, y, w, h, conf]
            Boxes may be normalized or pixel xywh: if all of an image's box values are <= 1.01 they are treated as
            normalized, otherwise as pixels (rescaled together with the canvas).
        paths (list, optional): Image paths; the file name (first 40 chars) is written in the top-left corner.
        names (dict | list, optional): Class names; IDs without a name are shown as the numeric ID.
        is_labels (bool): Whether targets are ground truth labels (True) or predictions (False).
        conf_thres (float): Predictions are drawn only when confidence > this value (ignored for labels).

    Returns:
        (np.ndarray | None): Annotated image array, or None when no writer is active or rendering fails (in the
            latter case a warning is logged).
    """
    if not WRITER:
        return None

    try:
        from ultralytics.utils.plotting import Annotator, colors
        from ultralytics.utils.ops import xywh2xyxy
        import math
        from pathlib import Path

        # Convert to numpy
        if isinstance(images, torch.Tensor):
            images = images.cpu().float().numpy()
        if isinstance(targets, torch.Tensor):
            targets = targets.cpu().numpy()

        max_size = 1920  # max image size
        max_subplots = 1  # limit to 1 image for cleaner visualization
        bs, _, h, w = images.shape  # batch size, _, height, width
        bs = min(bs, max_subplots)
        ns = np.ceil(bs**0.5)  # number of subplots (square)

        # Normalization is inferred from the first image only
        if np.max(images[0]) <= 1:
            images = images * 255  # de-normalise

        # Build Image
        mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)
        for i, im in enumerate(images):
            if i == max_subplots:
                break
            x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
            im = im.transpose(1, 2, 0)  # CHW to HWC
            # Images are in RGB format (converted from BGR in augmentation pipeline with default bgr=0.0)
            # No color conversion needed for TensorBoard
            mosaic[y : y + h, x : x + w, :] = im.astype(np.uint8)

        # Resize (optional)
        scale = max_size / ns / max(h, w)
        if scale < 1:
            # h/w now describe one (resized) tile; later box scaling and block origins use these values
            h = math.ceil(scale * h)
            w = math.ceil(scale * w)
            mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)), interpolation=cv2.INTER_AREA)

        # Annotate with thin lines
        fs = int((h + w) * ns * 0.01)  # font size
        line_width = max(1, round(fs / 30))  # thin line width
        annotator = Annotator(mosaic, line_width=line_width, font_size=fs, pil=True)

        for i in range(bs):
            x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
            annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
            if paths:
                annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))

            if len(targets) > 0:
                ti = targets[targets[:, 0] == i]  # image targets
                boxes = xywh2xyxy(ti[:, 2:6]).T  # (4, n) rows: x1, y1, x2, y2
                classes = ti[:, 1].astype("int")
                conf = None if is_labels else ti[:, 6]  # confidence only for predictions

                if boxes.shape[1]:
                    if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                        boxes[[0, 2]] *= w  # scale to pixels
                        boxes[[1, 3]] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        boxes *= scale
                boxes[[0, 2]] += x  # shift into this tile of the mosaic
                boxes[[1, 3]] += y

                for j, box in enumerate(boxes.T.tolist()):
                    cls = classes[j]
                    color = colors(cls)

                    # Get class name
                    if names:
                        if isinstance(names, dict):
                            cls_name = names.get(cls, str(cls))
                        elif isinstance(names, list):
                            cls_name = names[cls] if cls < len(names) else str(cls)
                        else:
                            cls_name = str(cls)
                    else:
                        cls_name = str(cls)

                    # For predictions, show confidence score; for labels, show class name
                    if is_labels or (conf is not None and conf[j] > conf_thres):
                        if not is_labels and conf is not None:
                            label = f"{cls_name} {conf[j]:.2f}"  # Show class and confidence for predictions
                        else:
                            label = cls_name  # Show class name for ground truth
                        annotator.box_label(box, label, color=color)

        # Get result from Annotator (returns RGB when pil=True)
        result = np.asarray(annotator.result())
        return result

    except Exception as e:
        # Best-effort visualization: never let a drawing failure interrupt the caller
        LOGGER.warning(f"{PREFIX}Failed to create detection image: {e}")
        return None
|
|
|
|
|
|
def _output_to_target(output, max_det=300):
    """Convert model output to a flat target array for plotting.

    Args:
        output (list): Per-image predictions, either dicts or tensors.
            - For dict format: [{'bboxes': (N, 4) xyxy, 'conf': (N,), 'cls': (N,), ...}, ...]
            - For tensor format: [(N, 6) [x1, y1, x2, y2, conf, cls], ...]
        max_det (int): Maximum detections kept per image (the first ``max_det`` rows).

    Returns:
        (np.ndarray): (M, 7) array of [batch_idx, class_id, x, y, w, h, conf]. The xywh values are in the same units
            as the input boxes (not normalized). An empty (0, 7) array is returned when there are no detections or
            the conversion fails (a warning is logged in that case).
    """
    try:
        from ultralytics.utils.ops import xyxy2xywh

        targets = []
        for i, o in enumerate(output):
            if isinstance(o, dict):  # dict format (new postprocess output)
                box = o.get("bboxes", torch.empty((0, 4)))
                conf = o.get("conf", torch.empty((0,)))
                cls = o.get("cls", torch.empty((0,)))
            else:  # tensor format (legacy)
                if len(o) == 0:
                    continue
                box, conf, cls = o[:, :4], o[:, 4:5], o[:, 5:6]
            if len(box) == 0:
                continue

            # Limit detections and make conf/cls column vectors
            box, conf, cls = box[:max_det], conf[:max_det], cls[:max_det]
            if conf.dim() == 1:
                conf = conf.unsqueeze(-1)
            if cls.dim() == 1:
                cls = cls.unsqueeze(-1)

            # Put every column on one floating dtype/device so concatenation never relies on implicit type promotion
            dtype = torch.promote_types(box.dtype, torch.float32)
            device = box.device
            box = box.to(device=device, dtype=dtype)
            conf = conf.to(device=device, dtype=dtype)
            cls = cls.to(device=device, dtype=dtype)
            j = torch.full((len(box), 1), i, device=device, dtype=dtype)  # batch index

            # Concatenate: [batch_idx, cls, xywh, conf]
            targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))

        return torch.cat(targets, 0).detach().cpu().numpy() if targets else np.empty((0, 7))

    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to convert output to target: {e}")
        return np.empty((0, 7))
|
|
|
|
|
|
def _log_scalars(scalars: dict | None, step: int = 0) -> None:
|
|
"""Log scalar values to TensorBoard.
|
|
|
|
Args:
|
|
scalars (dict | None): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are
|
|
the corresponding scalar values. If None or empty, nothing is logged.
|
|
step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.
|
|
|
|
Examples:
|
|
Log training metrics
|
|
>>> metrics = {"loss": 0.5, "accuracy": 0.95}
|
|
>>> _log_scalars(metrics, step=100)
|
|
"""
|
|
if WRITER and scalars:
|
|
for k, v in scalars.items():
|
|
WRITER.add_scalar(k, v, step)
|
|
|
|
|
|
def _add_traced_graph(model, im) -> None:
    """Trace ``model`` on ``im`` with ``torch.jit.trace`` and add the traced graph to the active ``WRITER``.

    The tracer's "already a ScriptModule" UserWarning is suppressed, and a success message is logged afterwards.
    Tracing or writer errors propagate to the caller.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="The input to trace is already a ScriptModule, tracing it is a no-op.*",
            category=UserWarning,
        )
        WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
    LOGGER.info(f"{PREFIX}model graph visualization added ✅")


@smart_inference_mode()
def _log_tensorboard_graph(trainer) -> None:
    """Add a traced graph of ``trainer.model`` to TensorBoard.

    A CPU copy of the unwrapped model is traced with an all-zeros (1, 3, H, W) input sized from
    ``trainer.args.imgsz``. If that fails, a second copy is fused and every module with an ``export`` attribute is
    switched to TorchScript export mode. A dry forward pass is then run and tracing is retried (this path handles
    models such as RTDETR). If both attempts fail, one warning reporting both errors is logged; nothing is raised.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): Trainer providing ``model`` and ``args.imgsz``.

    Notes:
        Requires the global WRITER to be initialized. Tracing runs on CPU to avoid corrupting the CUDA context if
        tracing fails.
    """
    size = trainer.args.imgsz
    if isinstance(size, int):
        size = (size, size)
    dummy = torch.zeros((1, 3, *size), device="cpu", dtype=torch.float32)

    try:  # attempt 1: plain trace (YOLO)
        _add_traced_graph(deepcopy(torch_utils.unwrap_model(trainer.model)).cpu().eval(), dummy)
    except Exception as e1:
        try:  # attempt 2: TorchScript export steps (RTDETR)
            export_model = deepcopy(torch_utils.unwrap_model(trainer.model)).cpu().eval().fuse(verbose=False)
            for module in export_model.modules():
                if hasattr(module, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
                    module.export = True
                    module.format = "torchscript"
            export_model(dummy)  # dry run
            _add_traced_graph(export_model, dummy)
        except Exception as e2:
            LOGGER.warning(f"{PREFIX}TensorBoard graph visualization failure: {e1} -> {e2}")
|
|
|
|
|
|
def on_pretrain_routine_start(trainer) -> None:
    """Initialize TensorBoard logging with a SummaryWriter that writes to ``trainer.save_dir``.

    A writer left open by an earlier run in the same process (e.g. one that never reached ``on_train_end``) is
    flushed and closed first so its file handle is not leaked. If initialization fails, a warning is logged and
    ``WRITER`` is left as ``None``, so this run is not logged.
    """
    global WRITER
    if not SummaryWriter:
        return
    try:
        _close_writer()  # release any stale writer before replacing the global reference
        WRITER = SummaryWriter(str(trainer.save_dir))
        LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
    except Exception as e:
        WRITER = None  # never keep logging into a stale or half-initialized writer
        LOGGER.warning(f"{PREFIX}TensorBoard not initialized correctly, not logging this run. {e}")
|
|
|
|
|
|
def on_train_start(trainer) -> None:
    """Add the model graph to TensorBoard, provided a writer is active."""
    if not WRITER:
        return
    _log_tensorboard_graph(trainer)
|
|
|
|
|
|
def on_train_epoch_end(trainer) -> None:
    """Log the epoch's training-loss components and learning rates, using ``trainer.epoch + 1`` as the step."""
    step = trainer.epoch + 1
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), step)
    _log_scalars(trainer.lr, step)
|
|
|
|
|
|
def on_fit_epoch_end(trainer) -> None:
    """Log ``trainer.metrics`` for the finished epoch, using ``trainer.epoch + 1`` as the step."""
    metrics, step = trainer.metrics, trainer.epoch + 1
    _log_scalars(metrics, step)
|
|
|
|
|
|
def on_val_start(validator, max_samples: int = 50) -> None:
    """Select which validation samples will be rendered to TensorBoard during this validation pass.

    Up to ``max_samples`` indices are spread evenly over the validation set. When the dataloader's sampler exposes
    ``_get_rank_indices()`` (rank-sharded validation), only indices inside this rank's ``[start, end)`` range are
    kept. They are stored relative to ``start``, and ``start`` is remembered as the offset that restores global
    indices. If selection fails, a warning is logged and image logging is disabled for the pass instead of
    interrupting validation.

    Args:
        validator: Validator whose ``dataloader`` is about to be iterated.
        max_samples (int): Maximum number of samples selected across the whole validation set.
    """
    global _val_sample_indices_to_log, _val_seen_samples, _val_index_offset

    if not WRITER:
        return

    try:
        dataloader = validator.dataloader
        dataset = getattr(dataloader, "dataset", None)
        total_samples = len(dataset or [])
        if total_samples <= 0:
            # NOTE(review): fallback uses len(dataloader), presumably a batch count rather than a sample count
            total_samples = len(dataloader)
        num_samples = max(0, min(max_samples, total_samples))
        sample_indices = np.linspace(0, max(total_samples - 1, 0), num_samples, dtype=int).tolist()

        offset, selected = 0, set(sample_indices)
        sampler = getattr(dataloader, "sampler", None)
        if sampler and hasattr(sampler, "_get_rank_indices"):
            start_idx, end_idx = sampler._get_rank_indices()
            offset = start_idx
            selected = {idx - start_idx for idx in sample_indices if start_idx <= idx < end_idx}
    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to select validation samples for logging: {e}")
        offset, selected = 0, set()

    # Publish the new state in one place so a failure above never leaves it half-updated
    _val_index_offset = offset
    _val_sample_indices_to_log = selected
    _val_seen_samples = 0
|
|
|
|
|
|
def on_val_batch_end(validator) -> None:
    """Log the selected samples of the current validation batch to TensorBoard as 2x2 grids.

    A running count of samples seen in this pass maps selected (rank-local) indices onto positions in the current
    batch. Each hit is rendered by ``_log_3d_visualization`` with its dataset-level index
    (``index + _val_index_offset``). Failures are logged as warnings and never interrupt validation.
    """
    global _val_seen_samples

    if not (WRITER and _val_sample_indices_to_log):
        return

    try:
        batch = validator.batch
        n = len(batch.get("im_file", [])) or int(batch["img"].shape[0])
        first = _val_seen_samples
        _val_seen_samples = first + n
        wanted = [idx for idx in range(first, first + n) if idx in _val_sample_indices_to_log]
        # 2x2 grid [Target 2D | Target 3D] / [Pred 2D | Pred 3D]
        for idx in wanted:
            _log_3d_visualization(validator, batch, idx - first, idx + _val_index_offset)
    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to log validation images: {e}")
|
|
|
|
|
|
def _draw_dashed_rectangle(img, p1, p2, color, thickness=1, dash=8, gap=5):
    """Draw, in place, the outline of the rectangle from ``p1`` (top-left) to ``p2`` (bottom-right) as dashes.

    Each side is drawn as anti-aliased segments ``dash`` pixels long separated by ``gap`` pixels. The horizontal
    edges are skipped when ``p2[0] <= p1[0]``, and the vertical edges when ``p2[1] <= p1[1]``.
    """
    (left, top), (right, bottom) = p1, p2
    period = dash + gap
    for xs in range(left, right, period):
        xe = min(xs + dash, right)
        for yy in (top, bottom):
            cv2.line(img, (xs, yy), (xe, yy), color, thickness, cv2.LINE_AA)
    for ys in range(top, bottom, period):
        ye = min(ys + dash, bottom)
        for xx in (left, right):
            cv2.line(img, (xx, ys), (xx, ye), color, thickness, cv2.LINE_AA)
|
|
|
|
|
|
def _draw_2d_boxes_on_image(img, boxes_xyxy, cls_ids, names=None, confs=None, thickness=1, diff_cls=None):
    """Draw labelled 2D bounding boxes on an image in place using cv2 (BGR color conventions).

    Args:
        img: (H, W, 3) uint8 image (BGR); modified in place.
        boxes_xyxy: (N, 4) array of [x1, y1, x2, y2] in pixel coords (truncated to int).
        cls_ids: (N,) array of class IDs.
        names: Dict or list of class names; IDs without a name (including negative IDs) are shown as the number.
        confs: (N,) array of confidences appended to the labels, or None for GT.
        thickness: Line thickness.
        diff_cls: Optional (N,) array of per-box difficulty classes. Boxes whose value is 1 get a dashed outline.
            When the array is given, every label gets a ``d<value>`` suffix.

    Notes:
        Labels sit on a filled background just above the box. When there is no room above (a box touching the top
        of the image), the label is placed just inside the box's top edge instead, so it stays visible.
    """
    from ultralytics.utils.plotting import colors as get_color

    font, font_scale, font_thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1

    for i in range(len(boxes_xyxy)):
        x1, y1, x2, y2 = (int(v) for v in boxes_xyxy[i])
        cls = int(cls_ids[i])
        color = get_color(cls, bgr=True)

        # Outline: dashed for difficulty class 1, solid otherwise
        if diff_cls is not None and int(diff_cls[i]) == 1:
            _draw_dashed_rectangle(img, (x1, y1), (x2, y2), color, thickness=max(thickness, 1))
        else:
            cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness, cv2.LINE_AA)

        # Label text
        if isinstance(names, dict):
            cls_name = names.get(cls, str(cls))
        elif isinstance(names, list) and 0 <= cls < len(names):
            cls_name = names[cls]
        else:
            cls_name = str(cls)
        label = f"{cls_name} {confs[i]:.2f}" if confs is not None else cls_name
        if diff_cls is not None:
            label = f"{label} d{int(diff_cls[i])}"

        # Label placement: above the box when it fits, otherwise inside the top edge
        (tw, th), _ = cv2.getTextSize(label, font, font_scale, font_thickness)
        bg_h = th + 4  # background height: text height plus padding
        if y1 - bg_h >= 0:
            bg_top, bg_bottom = y1 - bg_h, y1
        else:
            bg_top, bg_bottom = y1, y1 + bg_h
        cv2.rectangle(img, (x1, bg_top), (x1 + tw, bg_bottom), color, -1)
        cv2.putText(img, label, (x1, bg_bottom - 2), font, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)
|
|
|
|
|
|
def _log_3d_visualization(validator, batch, batch_sample_idx=0, sample_idx=None):
    """Log 2x2 grid visualization: [Target 2D | Target 3D] / [Pred 2D | Pred 3D].

    GT depth is restored in decode_3d_target() using calib["depth_scale"].
    Prediction depth is restored once upstream in Ground3DDetectionValidator.postprocess().
    Both are projected with the original batch calibration.

    The grid for one image of ``batch`` is written to the active ``WRITER`` via ``add_image``. The tag is
    ``val/<rank_tag>/<camera_mode>/sample<idx:04d>_<image_stem>_2d3d`` and the step is ``validator.trainer.epoch``
    while training (0 otherwise). The function returns without logging in these cases:

    - the batch has no ``labels_3d``;
    - no trainer model is reachable;
    - the model defines neither ``face_3d_classes`` nor ``complete_3d_classes``;
    - ``batch_idx`` or ``cls`` is missing;
    - no calibration exists for the image (this case logs a warning).

    Any other error is caught and logged as a warning.

    Args:
        validator: Validator providing ``trainer``, ``args``, ``names``, ``pred`` (per-image dicts with ``bboxes``,
            ``cls`` and ``conf``) and ``training``. Optionally it also provides the per-image ``_preds_3d_selected``,
            ``_preds_edge_selected``, ``_preds_diff_selected``, ``_anchors_selected`` and ``_strides_selected``
            tensors used for the prediction 3D/difficulty overlays.
        batch (dict): Validation batch. Reads ``img``, ``labels_3d``, ``batch_idx``, ``cls``, ``bboxes`` and ``calib``,
            plus optionally ``camera_mode``, ``im_file``, ``edge_faces_points_2d`` and ``edge_faces_valid``.
        batch_sample_idx (int): Index of the image within ``batch`` to render.
        sample_idx (int, optional): Dataset-level index used in the tag and header; defaults to ``batch_sample_idx``.
    """
    # NOTE(review): this batch-wide check means predictions are never visualized for batches without 3D labels
    labels_3d = batch.get("labels_3d")
    if labels_3d is None or len(labels_3d) == 0:
        return

    try:
        from ultralytics.utils.ops import xywh2xyxy
        from ultralytics.utils.plotting_3d import (
            collect_precomputed_edge_points_2d,
            decode_3d_prediction,
            decode_3d_target,
            draw_3d_box,
        )

        # Get 3D class config from model
        model = validator.trainer.model if hasattr(validator, "trainer") and validator.trainer else None
        if model is None:
            return
        if hasattr(model, "module"):  # unwrap wrappers that expose the real model as `.module`
            model = model.module
        face_3d_classes = getattr(model, "face_3d_classes", set())
        complete_3d_classes = getattr(model, "complete_3d_classes", set())
        if not face_3d_classes and not complete_3d_classes:
            return

        images = batch["img"]
        batch_idx_t = batch.get("batch_idx")
        batch_cls_t = batch.get("cls")
        batch_bboxes = batch.get("bboxes")
        if batch_idx_t is None or batch_cls_t is None:
            return

        # Flatten per-object batch indices / classes so they can be masked per image
        batch_idx_np = batch_idx_t.cpu().numpy().ravel()
        cls_np = batch_cls_t.cpu().numpy().ravel()
        labels_3d_np = labels_3d.cpu().numpy()
        edge_faces_points_2d = batch.get("edge_faces_points_2d")
        edge_faces_valid = batch.get("edge_faces_valid")
        edge_faces_points_2d_np = edge_faces_points_2d.cpu().numpy() if edge_faces_points_2d is not None else None
        edge_faces_valid_np = edge_faces_valid.cpu().numpy() if edge_faces_valid is not None else None
        face_visibility_score_thresh = float(
            getattr(validator.args, "face_visibility_score_thresh", DEFAULT_CFG.face_visibility_score_thresh)
        )
        _, _, img_h, img_w = images.shape

        batch_calib = batch.get("calib")
        calib = batch_calib[batch_sample_idx] if batch_calib is not None and batch_sample_idx < len(batch_calib) else None
        if calib is None:
            LOGGER.warning(f"{PREFIX}No calib in batch, skipping 3D visualization")
            return

        batch_camera_mode = batch.get("camera_mode")
        # NOTE(review): camera_mode is used verbatim in the TensorBoard tag below (not passed through the sanitizer)
        camera_mode = (
            str(batch_camera_mode[batch_sample_idx])
            if isinstance(batch_camera_mode, (list, tuple)) and batch_sample_idx < len(batch_camera_mode)
            else "unknown"
        )

        epoch = validator.trainer.epoch if validator.training and validator.trainer else 0
        global_sample_idx = sample_idx if sample_idx is not None else batch_sample_idx
        rank_tag = f"rank{RANK if RANK >= 0 else 0}"
        batch_im_files = batch.get("im_file")
        image_name = None
        if isinstance(batch_im_files, (list, tuple)) and batch_sample_idx < len(batch_im_files):
            image_name = Path(str(batch_im_files[batch_sample_idx])).name
        image_tag = _sanitize_tb_tag_component(Path(image_name).stem if image_name else "unknown")

        # --- Prepare base image ---
        # assumes images are float CHW in [0, 1] and RGB-ordered — TODO confirm against the validator's preprocess
        im0 = images[batch_sample_idx].cpu().numpy().transpose(1, 2, 0)
        im0 = np.ascontiguousarray(im0 * 255, dtype=np.uint8)
        im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR)

        # Create 4 copies for the 2x2 grid
        im_gt_2d = im0.copy()
        im_gt_3d = im0.copy()
        im_pred_2d = im0.copy()
        im_pred_3d = im0.copy()

        line_thick = 1

        # --- GT: Decode and draw boxes for selected image ---
        mask_i = batch_idx_np == batch_sample_idx
        if batch_bboxes is not None and mask_i.any():
            # GT boxes are scaled by the image size below, i.e. treated as normalized xywh
            gt_bboxes_xywh = batch_bboxes[mask_i].cpu().numpy()
            gt_cls = cls_np[mask_i]
            gt_bboxes_xyxy = xywh2xyxy(torch.from_numpy(gt_bboxes_xywh)).numpy()
            gt_bboxes_xyxy[:, [0, 2]] *= img_w
            gt_bboxes_xyxy[:, [1, 3]] *= img_h
            _draw_2d_boxes_on_image(im_gt_2d, gt_bboxes_xyxy, gt_cls, names=validator.names, thickness=line_thick)

        gt_indices = np.where(mask_i)[0]
        for local_idx, idx in enumerate(gt_indices):
            # Guard mirrors the condition under which gt_bboxes_xyxy was defined above
            bbox_xyxy = gt_bboxes_xyxy[local_idx] if batch_bboxes is not None and mask_i.any() else None
            d = decode_3d_target(
                labels_3d_np[idx],
                int(cls_np[idx]),
                calib,
                img_w,
                img_h,
                face_3d_classes,
                complete_3d_classes,
                score_thr=face_visibility_score_thresh,
                bbox_xyxy=bbox_xyxy,
            )
            if d is not None and d.get("corners_3d") is not None:
                # Use precomputed 2D edge points from the batch when available for this object
                if edge_faces_points_2d_np is not None and edge_faces_valid_np is not None and idx < len(edge_faces_points_2d_np):
                    edge_points_2d = collect_precomputed_edge_points_2d(
                        edge_faces_points_2d_np[idx],
                        edge_faces_valid_np[idx],
                        visible_face_types=d.get("visible_face_types", ()),
                    )
                    if edge_points_2d is not None:
                        d = {**d, "edge_points_2d": edge_points_2d}
                draw_3d_box(
                    im_gt_3d,
                    d["corners_3d"],
                    calib,
                    d.get("face_center_2d"),
                    d.get("face_color"),
                    edge_points_2d=d.get("edge_points_2d"),
                    edge_color=(0, 255, 0),
                    thickness=line_thick,
                )

        # --- Predictions: Decode and draw boxes for selected image ---
        preds_3d_sel = getattr(validator, "_preds_3d_selected", None)
        preds_edge_sel = getattr(validator, "_preds_edge_selected", None)
        preds_diff_sel = getattr(validator, "_preds_diff_selected", None)
        anchors_sel = getattr(validator, "_anchors_selected", None)
        strides_sel = getattr(validator, "_strides_selected", None)
        preds_2d = validator.pred

        if preds_2d is not None and len(preds_2d) > batch_sample_idx:
            pred_i = preds_2d[batch_sample_idx]
            # prediction boxes are drawn without rescaling, so presumably pixel xyxy — TODO confirm
            pred_bboxes = pred_i["bboxes"].cpu().numpy()
            pred_cls_np = pred_i["cls"].cpu().numpy()
            pred_conf_np = pred_i["conf"].cpu().numpy()
            # Display threshold: the stricter of args.visualize_conf (default 0.25) and args.conf (default 0.0)
            display_conf = max(float(getattr(validator.args, "visualize_conf", 0.25)), float(getattr(validator.args, "conf", 0.0)))
            display_mask = pred_conf_np >= display_conf

            if display_mask.any():
                pred_diff_np = None
                if preds_diff_sel is not None:
                    diff_logits = preds_diff_sel[batch_sample_idx]
                    if diff_logits.shape[-1] == 1:
                        # single logit: sigmoid probability thresholded at args.diff_thres (default 0.7)
                        diff_prob = diff_logits[..., 0].sigmoid()
                        diff_thres = float(getattr(validator.args, "diff_thres", 0.7))
                        pred_diff_np = (diff_prob >= diff_thres).cpu().numpy().astype(np.int64)
                    elif diff_logits.shape[-1] == 2:
                        # two logits: argmax class
                        pred_diff_np = diff_logits.softmax(-1).argmax(-1).cpu().numpy().astype(np.int64)
                _draw_2d_boxes_on_image(
                    im_pred_2d,
                    pred_bboxes[display_mask],
                    pred_cls_np[display_mask],
                    names=validator.names,
                    confs=pred_conf_np[display_mask],
                    thickness=line_thick,
                    diff_cls=pred_diff_np[display_mask] if pred_diff_np is not None else None,
                )

            if preds_3d_sel is not None and anchors_sel is not None and strides_sel is not None:
                p3d = preds_3d_sel[batch_sample_idx].cpu().numpy()
                pedge = preds_edge_sel[batch_sample_idx].cpu().numpy() if preds_edge_sel is not None else None
                anchors_np = anchors_sel[batch_sample_idx].cpu().numpy()
                strides_np = strides_sel[batch_sample_idx].cpu().numpy()
                # Only predictions that passed the display threshold get a 3D overlay
                for i in np.where(display_mask)[0]:
                    d = decode_3d_prediction(
                        p3d[i],
                        anchors_np[:, i],
                        float(strides_np[i]),
                        calib,
                        img_w,
                        img_h,
                        face_3d_classes,
                        complete_3d_classes,
                        int(pred_cls_np[i]),
                        pred_edge_60=pedge[i] if pedge is not None else None,
                        bbox_xyxy=pred_bboxes[i],
                    )
                    if d is not None and d.get("corners_3d") is not None:
                        draw_3d_box(
                            im_pred_3d,
                            d["corners_3d"],
                            calib,
                            d.get("face_center_2d"),
                            d.get("face_color"),
                            edge_points_2d=d.get("edge_points_2d"),
                            edge_color=(0, 165, 255),
                            thickness=line_thick,
                        )

        # --- Add labels to each quadrant ---
        font = cv2.FONT_HERSHEY_SIMPLEX
        for im_q, label in [(im_gt_2d, "Target 2D"), (im_gt_3d, "Target 3D"),
                            (im_pred_2d, "Pred 2D"), (im_pred_3d, "Pred 3D")]:
            cv2.putText(im_q, label, (10, 40), font, 1.0, (0, 255, 255), 1, cv2.LINE_AA)

        # --- Assemble 2x2 grid ---
        top_row = np.hstack([im_gt_2d, im_gt_3d])
        bot_row = np.hstack([im_pred_2d, im_pred_3d])
        grid = np.vstack([top_row, bot_row])

        # Dark 30 px header bar with file name, sample index, rank and camera mode
        header = f"{image_name or 'unknown'} | idx={global_sample_idx:04d} | {rank_tag} | mode={camera_mode}"
        cv2.rectangle(grid, (0, 0), (grid.shape[1], 30), (32, 32, 32), -1)
        cv2.putText(grid, header[:200], (10, 21), font, 0.55, (255, 255, 255), 1, cv2.LINE_AA)

        # Convert BGR -> RGB for TensorBoard
        grid_rgb = cv2.cvtColor(grid, cv2.COLOR_BGR2RGB)

        WRITER.add_image(
            f"val/{rank_tag}/{camera_mode}/sample{global_sample_idx:04d}_{image_tag}_2d3d",
            grid_rgb,
            dataformats="HWC",
            global_step=epoch,
        )

    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to log 3D visualization: {e}")
|
|
|
|
|
|
def on_val_end(validator) -> None:
    """Discard the sample selection and counters accumulated during the validation pass that just ended."""
    _reset_val_logging_state()
|
|
|
|
|
|
def on_train_end(trainer) -> None:
    """Finish TensorBoard logging: drop validation-sampling state, then flush and close the writer."""
    for finalize in (_reset_val_logging_state, _close_writer):
        finalize()
|
|
|
|
|
|
# Event-name -> handler table for this integration; left empty when SummaryWriter is None
# (TensorBoard unavailable or the integration disabled).
callbacks = {}
if SummaryWriter:
    callbacks = {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_start": on_train_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_epoch_end": on_train_epoch_end,
        "on_val_start": on_val_start,
        "on_val_batch_end": on_val_batch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
|