Files
yolov26_3d/ultralytics/utils/callbacks/tensorboard.py
2026-06-24 09:35:46 +08:00

690 lines
29 KiB
Python
Executable File

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import warnings
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TESTS_RUNNING, colorstr, torch_utils
from ultralytics.utils.torch_utils import smart_inference_mode
# Defined unconditionally so module-level helpers can safely test `if WRITER:` even when the integration is disabled
WRITER = None  # TensorBoard SummaryWriter instance, created in on_pretrain_routine_start()
PREFIX = colorstr("TensorBoard: ")

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["tensorboard"] is True  # verify integration is enabled

    # Imports below only required if TensorBoard enabled
    from copy import deepcopy
    from pathlib import Path

    import cv2
    import numpy as np
    import torch
    from torch.utils.tensorboard import SummaryWriter

except (ImportError, AssertionError, TypeError, AttributeError, KeyError):
    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
    # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
    # KeyError: settings file without a "tensorboard" entry -> treat the integration as disabled instead of crashing
    SummaryWriter = None

# Global state for validation image logging
_val_sample_indices_to_log = set()  # rank-local sample indices selected for logging in the current validation pass
_val_seen_samples = 0  # number of validation samples consumed so far in the current pass
_val_index_offset = 0  # added to rank-local indices to recover dataset-level indices
_current_epoch = 0  # NOTE(review): declared `global` in on_val_batch_end but never read or written in this file
def _reset_val_logging_state() -> None:
    """Clear the validation image-logging bookkeeping: selected indices, seen-sample counter and rank offset."""
    global _val_sample_indices_to_log, _val_seen_samples, _val_index_offset
    _val_sample_indices_to_log, _val_seen_samples, _val_index_offset = set(), 0, 0
def _close_writer() -> None:
    """Flush and close the active TensorBoard writer, if any, and clear the global handle.

    The global ``WRITER`` is detached *before* flushing, so a failing ``flush()`` or ``close()`` can never leave a
    half-closed writer registered for later callbacks. Exceptions from ``flush()``/``close()`` still propagate, and
    ``close()`` is attempted even when ``flush()`` fails.
    """
    global WRITER
    writer, WRITER = WRITER, None
    if writer:
        try:
            writer.flush()
        finally:
            writer.close()
def _sanitize_tb_tag_component(value: str, max_len: int = 80) -> str:
    """Turn ``value`` into a safe TensorBoard tag fragment.

    Characters other than alphanumerics, ``-``, ``_`` and ``.`` become ``_``; leading/trailing ``.``/``_`` are
    trimmed; the result is cut to ``max_len`` characters. An empty result yields ``"unknown"``.
    """
    keep = ("-", "_", ".")
    pieces = []
    for ch in str(value).strip():
        pieces.append(ch if (ch.isalnum() or ch in keep) else "_")
    cleaned = "".join(pieces).strip("._")
    if not cleaned:
        return "unknown"
    return cleaned[:max_len]
def _create_detection_image(images, targets, paths=None, names=None, is_labels=True, conf_thres=0.25):
    """Create a detection visualization image for TensorBoard logging.

    Only the first image of the batch is drawn (``max_subplots = 1``).

    Args:
        images (torch.Tensor | np.ndarray): Batch of images (B, C, H, W), in [0, 1] or [0, 255].
        targets (torch.Tensor | np.ndarray | None): Target labels or predictions; None means "no boxes".
            - For labels: (N, 6) [batch_idx, class_id, x, y, w, h]
            - For predictions: (N, 7) [batch_idx, class_id, x, y, w, h, conf]
            Box coordinates may be normalized (max <= 1.01) or absolute pixels of ``images``.
        paths (list, optional): Image paths; the file name is written in the top-left corner.
        names (dict | list | tuple, optional): Class names; IDs without a name fall back to the numeric ID.
        is_labels (bool): Whether targets are ground truth labels (True) or predictions (False).
        conf_thres (float): Predictions with confidence <= this threshold are not drawn.

    Returns:
        (np.ndarray | None): RGB image with bounding boxes drawn, or None when no writer is active or drawing
            failed (the failure is logged as a warning).
    """
    if not WRITER:
        return None
    try:
        import math

        from ultralytics.utils.ops import xywh2xyxy
        from ultralytics.utils.plotting import Annotator, colors

        # Convert inputs to numpy
        if isinstance(images, torch.Tensor):
            images = images.cpu().float().numpy()
        if targets is None:
            targets = np.empty((0, 7))
        elif isinstance(targets, torch.Tensor):
            targets = targets.cpu().numpy()

        max_size = 1920  # max mosaic size in pixels
        max_subplots = 1  # limit to 1 image for cleaner visualization
        bs, _, h, w = images.shape  # batch size, _, height, width
        bs = min(bs, max_subplots)
        ns = np.ceil(bs**0.5)  # number of subplots per side (square grid)
        if np.max(images[0]) <= 1:
            images = images * 255  # de-normalise

        # Build mosaic
        mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)
        for i in range(bs):
            x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
            # Images are in RGB format (converted from BGR in augmentation pipeline with default bgr=0.0), so no
            # color conversion is needed for TensorBoard. Clip before the cast so out-of-range values saturate
            # instead of wrapping around in uint8.
            mosaic[y : y + h, x : x + w, :] = np.clip(images[i].transpose(1, 2, 0), 0, 255).astype(np.uint8)

        # Downscale the mosaic if it exceeds max_size
        scale = max_size / ns / max(h, w)
        if scale < 1:
            h = math.ceil(scale * h)
            w = math.ceil(scale * w)
            mosaic = cv2.resize(mosaic, tuple(int(v * ns) for v in (w, h)), interpolation=cv2.INTER_AREA)

        # Annotate with thin lines
        fs = int((h + w) * ns * 0.01)  # font size
        line_width = max(1, round(fs / 30))  # thin line width
        annotator = Annotator(mosaic, line_width=line_width, font_size=fs, pil=True)
        for i in range(bs):
            x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
            annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
            if paths:
                annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))
            if len(targets) == 0:
                continue
            ti = targets[targets[:, 0] == i]  # targets belonging to image i
            boxes = xywh2xyxy(ti[:, 2:6]).T
            classes = ti[:, 1].astype("int")
            conf = None if is_labels else ti[:, 6]  # confidence only for predictions
            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # normalized coordinates (0.01 tolerance)
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coordinates must follow the mosaic downscale
                    boxes *= scale
            boxes[[0, 2]] += x
            boxes[[1, 3]] += y
            for j, box in enumerate(boxes.T.tolist()):
                # Labels are always drawn; predictions only above the confidence threshold (with their score)
                if is_labels:
                    suffix = ""
                elif conf is not None and conf[j] > conf_thres:
                    suffix = f" {conf[j]:.2f}"
                else:
                    continue
                cls = int(classes[j])
                if isinstance(names, dict):
                    cls_name = str(names.get(cls, cls))
                elif isinstance(names, (list, tuple)) and 0 <= cls < len(names):
                    cls_name = str(names[cls])
                else:
                    cls_name = str(cls)
                annotator.box_label(box, cls_name + suffix, color=colors(cls))
        # Get result from Annotator (returns RGB when pil=True)
        return np.asarray(annotator.result())
    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to create detection image: {e}")
        return None
def _output_to_target(output, max_det=300):
    """Convert per-image model predictions into one flat target array for plotting.

    Args:
        output (list): Per-image predictions in either format:
            - dict: {'bboxes': (N, 4) xyxy, 'conf': (N,), 'cls': (N,), ...}; missing keys are treated as empty.
            - tensor (legacy): (N, >=6) rows of [x1, y1, x2, y2, conf, cls, ...].
        max_det (int): Maximum detections kept per image (the first ``max_det`` rows; no re-sorting is done).

    Returns:
        np.ndarray: (M, 7) array of [batch_idx, class_id, x, y, w, h, conf]. Boxes are converted from xyxy to
            center-xywh in the same coordinate space as the input (pixel boxes stay in pixels). An empty (0, 7)
            array is returned when there are no detections or the conversion fails (the failure is logged).
    """
    try:
        from ultralytics.utils.ops import xyxy2xywh

        targets = []
        for i, o in enumerate(output if output is not None else []):
            if isinstance(o, dict):
                # Dict format (new postprocess output)
                bboxes = o.get("bboxes", torch.empty((0, 4)))
                conf = o.get("conf", torch.empty((0,)))
                cls = o.get("cls", torch.empty((0,)))
            else:
                # Tensor format (legacy); check emptiness before 2D slicing
                if len(o) == 0:
                    continue
                bboxes, conf, cls = o[:, :4], o[:, 4], o[:, 5]
            if len(bboxes) == 0:
                continue
            # Limit detections and bring conf/cls to column vectors with the boxes' dtype/device so cat() is safe
            bboxes = bboxes[:max_det].float()
            conf = conf[:max_det].reshape(-1, 1).to(bboxes)
            cls = cls[:max_det].reshape(-1, 1).to(bboxes)
            j = torch.full((len(bboxes), 1), i, device=bboxes.device, dtype=bboxes.dtype)  # batch index
            # Concatenate: [batch_idx, cls, xywh, conf]
            targets.append(torch.cat((j, cls, xyxy2xywh(bboxes), conf), 1))
        if not targets:
            return np.empty((0, 7))
        return torch.cat(targets, 0).cpu().numpy()
    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to convert output to target: {e}")
        return np.empty((0, 7))
def _log_scalars(scalars: dict | None, step: int = 0) -> None:
"""Log scalar values to TensorBoard.
Args:
scalars (dict | None): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are
the corresponding scalar values. If None or empty, nothing is logged.
step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.
Examples:
Log training metrics
>>> metrics = {"loss": 0.5, "accuracy": 0.95}
>>> _log_scalars(metrics, step=100)
"""
if WRITER and scalars:
for k, v in scalars.items():
WRITER.add_scalar(k, v, step)
def _add_traced_graph(model, im) -> None:
    """Trace ``model`` on example input ``im`` with ``torch.jit.trace`` and add the graph to the global WRITER."""
    with warnings.catch_warnings():
        # Tracing an already-scripted module is a harmless no-op; TracerWarnings are expected noise here
        warnings.filterwarnings(
            "ignore",
            message="The input to trace is already a ScriptModule, tracing it is a no-op.*",
            category=UserWarning,
        )
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
    LOGGER.info(f"{PREFIX}model graph visualization added ✅")


@smart_inference_mode()
def _log_tensorboard_graph(trainer) -> None:
    """Log model graph to TensorBoard.

    The model is traced on a zero-filled dummy image of size ``trainer.args.imgsz``. A plain trace (suitable for YOLO
    models) is tried first; if it fails, the model is fused and its export-capable heads are switched to TorchScript
    export mode before tracing again (needed for models such as RTDETR). Both attempts run on a float32 CPU deep copy
    so the live training model is never modified. Failures are logged as a warning and never raised.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): Trainer whose ``model`` is visualized; ``args.imgsz`` may be
            an int or an (h, w) pair.

    Notes:
        Requires the global WRITER to be initialized (see on_pretrain_routine_start).
    """
    # Input image - use CPU to avoid corrupting CUDA context if tracing fails
    imgsz = trainer.args.imgsz
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
    im = torch.zeros((1, 3, *imgsz), device="cpu", dtype=torch.float32)

    def cpu_model_copy():
        # float32 so the parameters match the float32 dummy input even when the model is held in half precision
        return deepcopy(torch_utils.unwrap_model(trainer.model)).cpu().float().eval()

    # Try simple method first (YOLO)
    try:
        _add_traced_graph(cpu_model_copy(), im)
        return
    except Exception as e:
        first_error = e

    # Fallback to TorchScript export steps (RTDETR)
    try:
        model = cpu_model_copy().fuse(verbose=False)
        for m in model.modules():
            if hasattr(m, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
                m.export = True
                m.format = "torchscript"
        model(im)  # dry run
        _add_traced_graph(model, im)
    except Exception as e2:
        LOGGER.warning(f"{PREFIX}TensorBoard graph visualization failure: {first_error} -> {e2}")
def on_pretrain_routine_start(trainer) -> None:
    """Create the global TensorBoard SummaryWriter in ``trainer.save_dir``.

    Does nothing when TensorBoard is unavailable. If the writer cannot be created, a warning is logged and the run
    continues without TensorBoard logging.
    """
    global WRITER
    if not SummaryWriter:
        return
    try:
        WRITER = SummaryWriter(str(trainer.save_dir))
        LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
    except Exception as err:
        LOGGER.warning(f"{PREFIX}TensorBoard not initialized correctly, not logging this run. {err}")
def on_train_start(trainer) -> None:
    """Add the model graph to TensorBoard at the start of training, if a writer is active."""
    if not WRITER:
        return
    _log_tensorboard_graph(trainer)
def on_train_epoch_end(trainer) -> None:
    """Log the epoch's training losses and learning rates, stepped at the 1-based epoch number."""
    step = trainer.epoch + 1
    train_losses = trainer.label_loss_items(trainer.tloss, prefix="train")
    _log_scalars(train_losses, step)
    _log_scalars(trainer.lr, step)
def on_fit_epoch_end(trainer) -> None:
    """Log ``trainer.metrics`` at the end of each fit epoch, stepped at the 1-based epoch number."""
    epoch_step = trainer.epoch + 1
    _log_scalars(trainer.metrics, epoch_step)
def _safe_len(obj) -> int:
    """Return ``len(obj)``, or 0 when ``obj`` is None or does not support ``len()`` (e.g. an iterable-style dataset)."""
    if obj is None:
        return 0
    try:
        return len(obj)
    except TypeError:
        return 0


def on_val_start(validator, max_samples: int = 50) -> None:
    """Initialize validation logging by selecting which samples will be logged as images in this pass.

    Up to ``max_samples`` indices are spread evenly over the dataset with ``np.linspace``. When the dataloader's
    sampler exposes ``_get_rank_indices()`` (rank-sharded validation), only indices inside this rank's
    ``[start, end)`` range are kept. They are stored relative to ``start``, and ``start`` is kept as the offset that
    maps them back to dataset indices.

    Args:
        validator: Validator exposing ``dataloader`` (with optional ``dataset`` and ``sampler`` attributes).
        max_samples (int): Maximum number of samples selected across the whole dataset.
    """
    global _val_sample_indices_to_log, _val_seen_samples, _val_index_offset
    _val_seen_samples = 0  # restart the running sample counter consumed by on_val_batch_end()
    if not WRITER:
        return
    dataloader = validator.dataloader
    total_samples = _safe_len(getattr(dataloader, "dataset", None))
    if total_samples <= 0:
        # Dataset size unknown: fall back to the dataloader length (number of batches)
        total_samples = _safe_len(dataloader)
    num_samples = max(0, min(max_samples, total_samples))
    sample_indices = np.linspace(0, max(total_samples - 1, 0), num_samples, dtype=int).tolist()

    sampler = getattr(dataloader, "sampler", None)
    if sampler is not None and hasattr(sampler, "_get_rank_indices"):
        start_idx, end_idx = sampler._get_rank_indices()
        _val_index_offset = start_idx
        _val_sample_indices_to_log = {idx - start_idx for idx in sample_indices if start_idx <= idx < end_idx}
    else:
        _val_index_offset = 0
        _val_sample_indices_to_log = set(sample_indices)
def on_val_batch_end(validator) -> None:
    """Log every pre-selected sample in the current validation batch as a 2x2 TensorBoard grid.

    The running counter ``_val_seen_samples`` maps batch positions to rank-local sample indices. Positions selected
    in on_val_start() are drawn as [Target 2D | Target 3D] / [Pred 2D | Pred 3D]. Errors are logged, not raised.
    """
    global _val_seen_samples
    if not (WRITER and _val_sample_indices_to_log):
        return
    try:
        batch = validator.batch
        n_in_batch = len(batch.get("im_file", [])) or int(batch["img"].shape[0])
        first = _val_seen_samples
        _val_seen_samples = first + n_in_batch
        for local_idx in range(first, first + n_in_batch):
            if local_idx not in _val_sample_indices_to_log:
                continue
            # 2x2 grid [Target 2D | Target 3D] / [Pred 2D | Pred 3D]
            _log_3d_visualization(validator, batch, local_idx - first, local_idx + _val_index_offset)
    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to log validation images: {e}")
def _draw_dashed_rectangle(img, p1, p2, color, thickness=1, dash=8, gap=5):
    """Draw a dashed axis-aligned rectangle on ``img`` in place using cv2 line segments.

    Args:
        img (np.ndarray): Image to draw on (modified in place).
        p1 (tuple): One corner ``(x, y)``; any corner order is accepted and coordinates are rounded to int.
        p2 (tuple): The opposite corner ``(x, y)``.
        color (tuple): Line color in the image's channel order.
        thickness (int): Line thickness.
        dash (int): Length of each dash in pixels.
        gap (int): Gap between consecutive dashes in pixels.
    """
    # Normalize to integer top-left / bottom-right: range() yields nothing when start > stop and rejects floats,
    # which previously made inverted or float boxes silently disappear or raise.
    x1, x2 = sorted((int(round(p1[0])), int(round(p2[0]))))
    y1, y2 = sorted((int(round(p1[1])), int(round(p2[1]))))
    step = max(1, dash + gap)  # range() rejects a zero step
    for x in range(x1, x2, step):
        x_end = min(x + dash, x2)
        cv2.line(img, (x, y1), (x_end, y1), color, thickness, cv2.LINE_AA)
        cv2.line(img, (x, y2), (x_end, y2), color, thickness, cv2.LINE_AA)
    for y in range(y1, y2, step):
        y_end = min(y + dash, y2)
        cv2.line(img, (x1, y), (x1, y_end), color, thickness, cv2.LINE_AA)
        cv2.line(img, (x2, y), (x2, y_end), color, thickness, cv2.LINE_AA)
def _draw_2d_boxes_on_image(img, boxes_xyxy, cls_ids, names=None, confs=None, thickness=1, diff_cls=None):
    """Draw labelled 2D bounding boxes on an image in place using cv2 (BGR color conventions).

    Args:
        img (np.ndarray): (H, W, 3) uint8 image (BGR); modified in place.
        boxes_xyxy (array-like): (N, 4) boxes as [x1, y1, x2, y2] in pixel coords.
        cls_ids (array-like): (N,) class IDs.
        names (dict | list | tuple, optional): Class names; IDs without a name fall back to the numeric ID.
        confs (array-like, optional): (N,) confidences appended to the label, or None for GT.
        thickness (int): Box line thickness.
        diff_cls (array-like, optional): (N,) per-box difficulty class. Boxes with value 1 are drawn dashed; when
            given, every label gets a ``d<value>`` suffix.
    """
    from ultralytics.utils.plotting import colors as get_color

    font, font_scale, font_thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
    for i, box in enumerate(boxes_xyxy):
        x1, y1, x2, y2 = (int(v) for v in box)
        cls = int(cls_ids[i])
        color = get_color(cls, bgr=True)
        if diff_cls is not None and int(diff_cls[i]) == 1:
            _draw_dashed_rectangle(img, (x1, y1), (x2, y2), color, thickness=max(thickness, 1))
        else:
            cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness, cv2.LINE_AA)

        # Label text
        if isinstance(names, dict):
            cls_name = str(names.get(cls, cls))
        elif isinstance(names, (list, tuple)) and 0 <= cls < len(names):
            cls_name = str(names[cls])
        else:
            cls_name = str(cls)
        label = f"{cls_name} {confs[i]:.2f}" if confs is not None else cls_name
        if diff_cls is not None:
            label = f"{label} d{int(diff_cls[i])}"

        # Place the label above the box; if that would fall off the top edge, draw it just inside the box instead
        (tw, th), _ = cv2.getTextSize(label, font, font_scale, font_thickness)
        top = y1 - th - 4 if y1 - th - 4 >= 0 else y1
        cv2.rectangle(img, (x1, top), (x1 + tw, top + th + 4), color, -1)
        cv2.putText(img, label, (x1, top + th + 2), font, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)
def _log_3d_visualization(validator, batch, batch_sample_idx=0, sample_idx=None):
    """Log 2x2 grid visualization: [Target 2D | Target 3D] / [Pred 2D | Pred 3D].

    GT depth is restored in decode_3d_target() using calib["depth_scale"].
    Prediction depth is restored once upstream in Ground3DDetectionValidator.postprocess().
    Both are projected with the original batch calibration.

    Args:
        validator: Validator providing ``trainer`` (whose model carries ``face_3d_classes`` / ``complete_3d_classes``),
            ``args``, ``names``, ``training``, ``pred`` and, optionally, the per-image buffers ``_preds_3d_selected``,
            ``_preds_edge_selected``, ``_preds_diff_selected``, ``_anchors_selected`` and ``_strides_selected``.
        batch (dict): Validation batch. Reads "labels_3d", "img", "batch_idx", "cls", "bboxes", "calib" and, when
            present, "edge_faces_points_2d", "edge_faces_valid", "camera_mode" and "im_file".
        batch_sample_idx (int): Position of the image to visualize inside the batch.
        sample_idx (int, optional): Global sample index used in the header and tag; defaults to ``batch_sample_idx``.

    Notes:
        Returns without logging when there are no 3D labels, no model / 3D class configuration, no "batch_idx" or
        "cls", or no calibration for the selected image. Exceptions raised after the initial "labels_3d" check are
        caught and reported with ``LOGGER.warning``. The grid is written with ``WRITER.add_image`` (HWC, RGB) at the
        trainer's current epoch while training, otherwise at step 0.
    """
    labels_3d = batch.get("labels_3d")
    if labels_3d is None or len(labels_3d) == 0:
        return
    try:
        from ultralytics.utils.ops import xywh2xyxy
        from ultralytics.utils.plotting_3d import (
            collect_precomputed_edge_points_2d,
            decode_3d_prediction,
            decode_3d_target,
            draw_3d_box,
        )
        # Get 3D class config from model
        model = validator.trainer.model if hasattr(validator, "trainer") and validator.trainer else None
        if model is None:
            return
        if hasattr(model, "module"):
            # presumably a DDP/DataParallel wrapper; the 3D class sets live on the wrapped model
            model = model.module
        face_3d_classes = getattr(model, "face_3d_classes", set())
        complete_3d_classes = getattr(model, "complete_3d_classes", set())
        if not face_3d_classes and not complete_3d_classes:
            return  # model has no 3D classes configured: nothing 3D to draw
        images = batch["img"]
        batch_idx_t = batch.get("batch_idx")
        batch_cls_t = batch.get("cls")
        batch_bboxes = batch.get("bboxes")
        if batch_idx_t is None or batch_cls_t is None:
            return
        # Flatten per-object image index and class to 1D arrays aligned with labels_3d rows
        batch_idx_np = batch_idx_t.cpu().numpy().ravel()
        cls_np = batch_cls_t.cpu().numpy().ravel()
        labels_3d_np = labels_3d.cpu().numpy()
        # Optional precomputed GT edge points per object (indexed by the same object index as labels_3d)
        edge_faces_points_2d = batch.get("edge_faces_points_2d")
        edge_faces_valid = batch.get("edge_faces_valid")
        edge_faces_points_2d_np = edge_faces_points_2d.cpu().numpy() if edge_faces_points_2d is not None else None
        edge_faces_valid_np = edge_faces_valid.cpu().numpy() if edge_faces_valid is not None else None
        # Score threshold passed to decode_3d_target() as score_thr (validator arg, else the default config value)
        face_visibility_score_thresh = float(
            getattr(validator.args, "face_visibility_score_thresh", DEFAULT_CFG.face_visibility_score_thresh)
        )
        _, _, img_h, img_w = images.shape
        # Per-image calibration used for projecting both GT and prediction 3D boxes
        batch_calib = batch.get("calib")
        calib = batch_calib[batch_sample_idx] if batch_calib is not None and batch_sample_idx < len(batch_calib) else None
        if calib is None:
            LOGGER.warning(f"{PREFIX}No calib in batch, skipping 3D visualization")
            return
        batch_camera_mode = batch.get("camera_mode")
        camera_mode = (
            str(batch_camera_mode[batch_sample_idx])
            if isinstance(batch_camera_mode, (list, tuple)) and batch_sample_idx < len(batch_camera_mode)
            else "unknown"
        )
        # TensorBoard step: the trainer's epoch while training, otherwise 0
        epoch = validator.trainer.epoch if validator.training and validator.trainer else 0
        global_sample_idx = sample_idx if sample_idx is not None else batch_sample_idx
        rank_tag = f"rank{RANK if RANK >= 0 else 0}"
        batch_im_files = batch.get("im_file")
        image_name = None
        if isinstance(batch_im_files, (list, tuple)) and batch_sample_idx < len(batch_im_files):
            image_name = Path(str(batch_im_files[batch_sample_idx])).name
        image_tag = _sanitize_tb_tag_component(Path(image_name).stem if image_name else "unknown")
        # --- Prepare base image ---
        # CHW -> HWC; assumes the batch image is float in [0, 1] (scaled by 255 here) and RGB — TODO confirm upstream
        im0 = images[batch_sample_idx].cpu().numpy().transpose(1, 2, 0)
        im0 = np.ascontiguousarray(im0 * 255, dtype=np.uint8)
        im0 = cv2.cvtColor(im0, cv2.COLOR_RGB2BGR)  # draw in BGR (cv2 convention); converted back before logging
        # Create 4 copies for the 2x2 grid
        im_gt_2d = im0.copy()
        im_gt_3d = im0.copy()
        im_pred_2d = im0.copy()
        im_pred_3d = im0.copy()
        line_thick = 1
        # --- GT: Decode and draw boxes for selected image ---
        mask_i = batch_idx_np == batch_sample_idx  # objects belonging to the selected image
        if batch_bboxes is not None and mask_i.any():
            # GT boxes are normalized xywh: convert to xyxy and scale to pixels
            gt_bboxes_xywh = batch_bboxes[mask_i].cpu().numpy()
            gt_cls = cls_np[mask_i]
            gt_bboxes_xyxy = xywh2xyxy(torch.from_numpy(gt_bboxes_xywh)).numpy()
            gt_bboxes_xyxy[:, [0, 2]] *= img_w
            gt_bboxes_xyxy[:, [1, 3]] *= img_h
            _draw_2d_boxes_on_image(im_gt_2d, gt_bboxes_xyxy, gt_cls, names=validator.names, thickness=line_thick)
        # GT 3D boxes are drawn even without 2D boxes; bbox_xyxy is then None
        gt_indices = np.where(mask_i)[0]
        for local_idx, idx in enumerate(gt_indices):
            bbox_xyxy = gt_bboxes_xyxy[local_idx] if batch_bboxes is not None and mask_i.any() else None
            d = decode_3d_target(
                labels_3d_np[idx],
                int(cls_np[idx]),
                calib,
                img_w,
                img_h,
                face_3d_classes,
                complete_3d_classes,
                score_thr=face_visibility_score_thresh,
                bbox_xyxy=bbox_xyxy,
            )
            if d is not None and d.get("corners_3d") is not None:
                # Prefer precomputed GT edge points (restricted to the decoded visible faces) when available
                if edge_faces_points_2d_np is not None and edge_faces_valid_np is not None and idx < len(edge_faces_points_2d_np):
                    edge_points_2d = collect_precomputed_edge_points_2d(
                        edge_faces_points_2d_np[idx],
                        edge_faces_valid_np[idx],
                        visible_face_types=d.get("visible_face_types", ()),
                    )
                    if edge_points_2d is not None:
                        d = {**d, "edge_points_2d": edge_points_2d}
                draw_3d_box(
                    im_gt_3d,
                    d["corners_3d"],
                    calib,
                    d.get("face_center_2d"),
                    d.get("face_color"),
                    edge_points_2d=d.get("edge_points_2d"),
                    edge_color=(0, 255, 0),
                    thickness=line_thick,
                )
        # --- Predictions: Decode and draw boxes for selected image ---
        # Per-image raw head outputs presumably stashed by the validator for the selected predictions — verify
        preds_3d_sel = getattr(validator, "_preds_3d_selected", None)
        preds_edge_sel = getattr(validator, "_preds_edge_selected", None)
        preds_diff_sel = getattr(validator, "_preds_diff_selected", None)
        anchors_sel = getattr(validator, "_anchors_selected", None)
        strides_sel = getattr(validator, "_strides_selected", None)
        preds_2d = validator.pred
        if preds_2d is not None and len(preds_2d) > batch_sample_idx:
            pred_i = preds_2d[batch_sample_idx]
            # assumes pred bboxes are pixel xyxy (they are drawn directly as pixel coordinates) — TODO confirm
            pred_bboxes = pred_i["bboxes"].cpu().numpy()
            pred_cls_np = pred_i["cls"].cpu().numpy()
            pred_conf_np = pred_i["conf"].cpu().numpy()
            # Display threshold: the larger of visualize_conf (default 0.25) and the validation conf
            display_conf = max(float(getattr(validator.args, "visualize_conf", 0.25)), float(getattr(validator.args, "conf", 0.0)))
            display_mask = pred_conf_np >= display_conf
            if display_mask.any():
                pred_diff_np = None
                if preds_diff_sel is not None:
                    diff_logits = preds_diff_sel[batch_sample_idx]
                    if diff_logits.shape[-1] == 1:
                        # Single logit: sigmoid probability thresholded at diff_thres (default 0.7)
                        diff_prob = diff_logits[..., 0].sigmoid()
                        diff_thres = float(getattr(validator.args, "diff_thres", 0.7))
                        pred_diff_np = (diff_prob >= diff_thres).cpu().numpy().astype(np.int64)
                    elif diff_logits.shape[-1] == 2:
                        # Two logits: take the argmax class
                        pred_diff_np = diff_logits.softmax(-1).argmax(-1).cpu().numpy().astype(np.int64)
                _draw_2d_boxes_on_image(
                    im_pred_2d,
                    pred_bboxes[display_mask],
                    pred_cls_np[display_mask],
                    names=validator.names,
                    confs=pred_conf_np[display_mask],
                    thickness=line_thick,
                    diff_cls=pred_diff_np[display_mask] if pred_diff_np is not None else None,
                )
                if preds_3d_sel is not None and anchors_sel is not None and strides_sel is not None:
                    p3d = preds_3d_sel[batch_sample_idx].cpu().numpy()
                    pedge = preds_edge_sel[batch_sample_idx].cpu().numpy() if preds_edge_sel is not None else None
                    # anchors are indexed as [:, i] below, presumably (2, N) — verify against the validator
                    anchors_np = anchors_sel[batch_sample_idx].cpu().numpy()
                    strides_np = strides_sel[batch_sample_idx].cpu().numpy()
                    for i in np.where(display_mask)[0]:
                        d = decode_3d_prediction(
                            p3d[i],
                            anchors_np[:, i],
                            float(strides_np[i]),
                            calib,
                            img_w,
                            img_h,
                            face_3d_classes,
                            complete_3d_classes,
                            int(pred_cls_np[i]),
                            pred_edge_60=pedge[i] if pedge is not None else None,
                            bbox_xyxy=pred_bboxes[i],
                        )
                        if d is not None and d.get("corners_3d") is not None:
                            draw_3d_box(
                                im_pred_3d,
                                d["corners_3d"],
                                calib,
                                d.get("face_center_2d"),
                                d.get("face_color"),
                                edge_points_2d=d.get("edge_points_2d"),
                                edge_color=(0, 165, 255),
                                thickness=line_thick,
                            )
        # --- Add labels to each quadrant ---
        font = cv2.FONT_HERSHEY_SIMPLEX
        for im_q, label in [(im_gt_2d, "Target 2D"), (im_gt_3d, "Target 3D"),
                            (im_pred_2d, "Pred 2D"), (im_pred_3d, "Pred 3D")]:
            cv2.putText(im_q, label, (10, 40), font, 1.0, (0, 255, 255), 1, cv2.LINE_AA)
        # --- Assemble 2x2 grid ---
        top_row = np.hstack([im_gt_2d, im_gt_3d])
        bot_row = np.hstack([im_pred_2d, im_pred_3d])
        grid = np.vstack([top_row, bot_row])
        # 30 px header bar across the top of the grid with file name, sample index, rank and camera mode
        header = f"{image_name or 'unknown'} | idx={global_sample_idx:04d} | {rank_tag} | mode={camera_mode}"
        cv2.rectangle(grid, (0, 0), (grid.shape[1], 30), (32, 32, 32), -1)
        cv2.putText(grid, header[:200], (10, 21), font, 0.55, (255, 255, 255), 1, cv2.LINE_AA)
        # Convert BGR -> RGB for TensorBoard
        grid_rgb = cv2.cvtColor(grid, cv2.COLOR_BGR2RGB)
        WRITER.add_image(
            f"val/{rank_tag}/{camera_mode}/sample{global_sample_idx:04d}_{image_tag}_2d3d",
            grid_rgb,
            dataformats="HWC",
            global_step=epoch,
        )
    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to log 3D visualization: {e}")
def on_val_end(validator) -> None:
    """Forget this pass's selected samples and counters so the next validation starts from a clean state."""
    _reset_val_logging_state()
def on_train_end(trainer) -> None:
    """Finish TensorBoard logging: clear validation sampling state first, then flush and close the writer."""
    _reset_val_logging_state()
    _close_writer()
# Callback hooks exported by this module; left empty when TensorBoard is unavailable or disabled
if SummaryWriter:
    callbacks = {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_start": on_train_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_epoch_end": on_train_epoch_end,
        "on_val_start": on_val_start,
        "on_val_batch_end": on_val_batch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
else:
    callbacks = {}