Files
yolov26_3d/tools/pdcl_inference/extract_gt_by_group_id.py
2026-06-24 09:35:46 +08:00

659 lines
23 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
按 group_id 提取真值产物(标注 JSON + 标定文件,可选保存图片)。
用法:
python extract_gt_by_group_id.py \
--group_ids <uuid1> <uuid2> ... \
--output /data/output_dir \
[--car_name G1M3_0630] \
[--product_type 2d-3d-association] \
[--camera_id camera4] \
[--image_suffix png] \
[--annotation_stride 2] \
[--skip_images] \
[--workers 4]
或者通过文本文件批量输入(每行: group_uuid [date] [car_name])
python extract_gt_by_group_id.py \
--uuid_file /data/group_ids.txt \
--output /data/output_dir
"""
import argparse
import io
import os
import tarfile
from datetime import datetime, timezone, timedelta
from math import ceil
from multiprocessing import Pool
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from dotenv import load_dotenv
load_dotenv()
from pdcl_dss import Group
# NOTE(security): a real-looking STS secret is committed to source control here.
# It should live only in the environment / `.env` file and be rotated. Until then
# it is kept purely as a fallback: `setdefault` no longer overwrites a value that
# was already provided by the environment or by `load_dotenv()` (which itself
# never overrides existing variables), so deployments can supply their own.
os.environ.setdefault('STS_UID', 'dis-uploader')
os.environ.setdefault('STS_SECRET_KEY', '277310cc09724d315514a79701fecb0f')
# All human-readable date strings are rendered in Beijing time (UTC+8).
BEIJING_TZ = timezone(timedelta(hours=8))
# Sentinels meaning "not supplied on the CLI / uuid file; try to infer per group".
DEFAULT_CAR_NAME = 'unknown'
DEFAULT_DATE_NAME = '00000000000000'
DEFAULT_IMAGE_SUFFIX = 'png'
# 1 keeps every annotation; N keeps every N-th candidate annotation.
DEFAULT_ANNOTATION_STRIDE = 1
SUPPORTED_IMAGE_SUFFIXES = {'png', 'jpg', 'jpeg'}
# ---------------------------------------------------------------------------
# 视频解码工具
# ---------------------------------------------------------------------------
def _plane_to_ndarray(plane) -> np.ndarray:
stride = plane.line_size
height = plane.height
width = plane.width
arr = np.frombuffer(plane, dtype=np.uint8)
if stride == width:
return arr.reshape(height, width)
return arr.reshape(height, stride)[:, :width]
def _yuvj420p_to_bgr(frame) -> np.ndarray:
    """Convert a PyAV planar YUV 4:2:0 frame (YUVJ420P) to a BGR ndarray.

    The separate U and V planes are interleaved column-wise into one
    half-height chroma plane, producing OpenCV's NV12 layout (Y plane followed
    by interleaved UV), which ``cv2.cvtColor`` then converts to BGR.
    NOTE(review): the interleave assumes even frame width/height — confirm for
    every camera this is used with.
    """
    import cv2

    y_plane, u_plane, v_plane = (
        _plane_to_ndarray(frame.planes[idx]) for idx in range(3)
    )
    rows, cols = y_plane.shape
    chroma = np.zeros((rows // 2, cols), dtype=np.uint8)
    chroma[:, ::2] = u_plane   # even columns carry U
    chroma[:, 1::2] = v_plane  # odd columns carry V
    nv12 = np.vstack((y_plane, chroma))
    return cv2.cvtColor(nv12, cv2.COLOR_YUV2BGR_NV12)
def _h265_bytes_to_bgr(frame_bytes: bytes) -> np.ndarray:
    """Decode an in-memory H.265 payload and return its first video frame as BGR.

    Args:
        frame_bytes: encoded payload of a single camera message.

    Returns:
        The first decoded frame converted by ``_yuvj420p_to_bgr``.

    Raises:
        ValueError: if the payload yields no video frame.
    """
    import av

    # Use the container as a context manager so its demuxer/decoder resources
    # are released on every exit path (early return after the first frame, or
    # an exception while decoding). Previously one container leaked per frame.
    with av.open(io.BytesIO(frame_bytes)) as container:
        for av_frame in container.decode(video=0):
            return _yuvj420p_to_bgr(av_frame)
    raise ValueError("H.265 解码失败")
# ---------------------------------------------------------------------------
# PDCL 数据读取
# ---------------------------------------------------------------------------
def get_frame_timestamps(group_uuid: str) -> pd.DataFrame:
    """Read ``frame_timestamps.csv`` out of the group's ``clean_data_archive.tar``.

    Args:
        group_uuid: PDCL group identifier.

    Returns:
        The CSV parsed by ``pd.read_csv`` with default options.

    Raises:
        KeyError: if the archive has no ``frame_timestamps.csv`` member.
        ValueError: if that member exists but is not a regular file.
    """
    with Group(group_uuid) as group:
        with group.open('clean_data_archive.tar', mode='rb') as f:
            # Close the TarFile deterministically (it was previously left open).
            with tarfile.open(fileobj=f) as tar:
                csv_obj = tar.extractfile(tar.getmember('frame_timestamps.csv'))
                if csv_obj is None:
                    # extractfile() returns None for non-regular members; fail
                    # with an explicit message instead of read_csv(None).
                    raise ValueError(
                        f'frame_timestamps.csv in clean_data_archive.tar of '
                        f'{group_uuid} is not a regular file'
                    )
                # Parse while the tar stream is still open.
                return pd.read_csv(csv_obj)
def get_group_meta(group_uuid: str) -> Dict:
    """Return a plain ``dict`` copy of the group's ``meta`` mapping.

    The copy is taken while the group is open, so it stays usable afterwards.
    """
    with Group(group_uuid) as group:
        meta_copy = dict(group.meta)
    return meta_copy
def get_clip_paths(group_uuid: str) -> List[str]:
    """Return, for every clip of the group, the cache path of its first listed file.

    Clips are visited in ``group.list_clip_ukeys()`` order. A clip whose
    ``list_files()`` returns no files raises IndexError.
    """
    from pdcl_dss import Clip

    def _first_file_cache_path(clip_ukey) -> str:
        clip = Clip(clip_ukey)
        file_list, _unused = clip.list_files()
        return clip.get_cache_path(file_list[0])

    with Group(group_uuid) as group:
        return [_first_file_cache_path(ukey) for ukey in group.list_clip_ukeys()]
def _find_index_by_timestamp(df: pd.DataFrame, camera_id: str, ts_ns: int) -> int:
return df[abs(df[camera_id] - ts_ns) < 1e-3].index[0]
def _build_timestamp_index(df: pd.DataFrame, camera_id: str) -> Dict[int, int]:
if camera_id not in df.columns:
return {}
timestamp_index: Dict[int, int] = {}
for json_index, value in df[camera_id].items():
if pd.isna(value):
continue
timestamp_index[int(value)] = int(json_index)
return timestamp_index
def _find_index_by_timestamp_fast(
timestamp_index: Dict[int, int],
df: pd.DataFrame,
camera_id: str,
ts_ns: int,
) -> int:
json_index = timestamp_index.get(int(ts_ns))
if json_index is not None:
return json_index
return _find_index_by_timestamp(df, camera_id, ts_ns)
def _is_missing_car_name(car_name: str) -> bool:
return not car_name or car_name == DEFAULT_CAR_NAME
def _is_missing_date_name(date_name: str) -> bool:
return not date_name or date_name == DEFAULT_DATE_NAME
def _format_ms_to_date_name(timestamp_ms: int) -> str:
    """Render an epoch-milliseconds timestamp as ``YYYYmmddHHMMSS`` in Beijing time."""
    seconds = timestamp_ms / 1000
    local_dt = datetime.fromtimestamp(seconds, tz=BEIJING_TZ)
    return local_dt.strftime('%Y%m%d%H%M%S')
def _get_frame_id_column(camera_id: str) -> str:
return f'{camera_id}_frame_id'
def _is_target_camera_annotation_member(member_name: str, camera_id: str) -> bool:
normalized_name = member_name.replace('\\', '/')
target_fragment = f'/{camera_id}/camera_radar_lidar_jsons/'
return normalized_name.endswith('.json') and target_fragment in f'/{normalized_name}'
def normalize_image_suffix(image_suffix: str) -> str:
    """Canonicalize an image extension: trim, lower-case, drop leading dots.

    Raises:
        ValueError: if the canonical form is not in SUPPORTED_IMAGE_SUFFIXES.
    """
    candidate = image_suffix.strip().lower().lstrip('.')
    if candidate in SUPPORTED_IMAGE_SUFFIXES:
        return candidate
    supported = ', '.join(sorted(SUPPORTED_IMAGE_SUFFIXES))
    raise ValueError(f'不支持的 image_suffix: {image_suffix}. 支持: {supported}')
def normalize_annotation_stride(annotation_stride: int) -> int:
    """Validate the annotation sampling step and return it unchanged.

    Raises:
        ValueError: if ``annotation_stride`` is smaller than 1.
    """
    too_small = annotation_stride < 1
    if too_small:
        raise ValueError(f'annotation_stride 必须 >= 1当前为: {annotation_stride}')
    return annotation_stride
def _get_frame_timestamp_ns(df: pd.DataFrame, camera_id: str, json_index: int) -> Optional[int]:
try:
value = df.at[json_index, camera_id]
except (KeyError, ValueError):
return None
if pd.isna(value):
return None
return int(value)
def _build_frame_basename(
    car_name: str,
    date_name: str,
    group_uuid: str,
    json_index: int,
    camera_id: str,
    df: pd.DataFrame,
) -> str:
    """Build the shared output basename (no extension) for one frame.

    Format: ``car_date_group_<json_index:06d>[_<frame_id>][_<timestamp_ns>]``.
    The optional parts are appended only when ``df`` has a non-NaN value for
    them. A missing ``json_index`` row in the frame-id column raises KeyError.
    """
    pieces = [car_name, date_name, group_uuid, f'{json_index:06d}']

    id_column = _get_frame_id_column(camera_id)
    has_frame_id_column = id_column in df.columns
    if has_frame_id_column:
        raw_frame_id = df.at[json_index, id_column]
        if pd.notna(raw_frame_id):
            pieces.append(str(int(raw_frame_id)))

    stamp_ns = _get_frame_timestamp_ns(df, camera_id, json_index)
    if stamp_ns is not None:
        pieces.append(str(stamp_ns))

    return '_'.join(pieces)
def _infer_car_name(group_meta: Dict) -> Tuple[str, str]:
for key in ('plate_number', 'vehicle_name', 'car_name', 'plateNumber', 'vehicleName', 'carName'):
value = str(group_meta.get(key, '')).strip()
if value:
return value, key
project_name = str(group_meta.get('project_name', '')).strip()
if project_name:
return project_name, 'project_name'
return DEFAULT_CAR_NAME, 'default'
def _infer_date_name(group_meta: Dict, df: pd.DataFrame, camera_id: str) -> Tuple[str, str]:
    """Infer a ``YYYYmmddHHMMSS`` (Beijing time) date string for a group.

    Sources, in order:
      1. ``group_meta['collection_time']``, read as epoch milliseconds, when it
         is a finite positive int/float (bool excluded).
      2. The first non-NaN numeric value of a timestamp-like column of ``df``
         (``camera_id`` first, then the remaining columns, skipping unnamed
         index columns and frame_id/clip_uuid/point_count/speed columns),
         read as epoch nanoseconds.
      3. ``DEFAULT_DATE_NAME``.

    Returns:
        (date_name, source) describing which input was used.
    """
    # Plausible epoch-nanosecond window: 2000-01-01 .. 2100-01-01 UTC. A column
    # whose first value lies outside is a counter or uses other units; reading
    # it as ns would give a 1970 date, or overflow and fail the whole group.
    min_ns = 946_684_800 * 10**9
    max_ns = 4_102_444_800 * 10**9

    collection_time = group_meta.get('collection_time')
    if (
        isinstance(collection_time, (int, float))
        and not isinstance(collection_time, bool)  # bool is an int subclass
        and np.isfinite(collection_time)           # int(inf) would raise
        and collection_time > 0
    ):
        return _format_ms_to_date_name(int(collection_time)), 'collection_time'

    candidate_columns = [camera_id] if camera_id in df.columns else []
    candidate_columns.extend(col for col in df.columns if col not in candidate_columns)
    skip_tokens = ('frame_id', 'clip_uuid', 'point_count', 'speed')
    for col in candidate_columns:
        col_name = str(col)
        if col_name.startswith('Unnamed:'):
            continue
        if any(token in col_name for token in skip_tokens):
            continue
        values = pd.to_numeric(df[col], errors='coerce').dropna()
        if values.empty:
            continue
        first_value = values.iloc[0]
        if not (min_ns <= first_value < max_ns):
            continue
        timestamp_ns = int(first_value)
        dt = datetime.fromtimestamp(timestamp_ns / 1e9, tz=timezone.utc).astimezone(BEIJING_TZ)
        return dt.strftime('%Y%m%d%H%M%S'), f'frame_timestamps:{col_name}'
    return DEFAULT_DATE_NAME, 'default'
def resolve_group_identifiers(
    group_uuid: str,
    car_name: str,
    date_name: str,
    df: pd.DataFrame,
    camera_id: str,
) -> Tuple[str, str, List[str]]:
    """
    Fill in car_name/date only when they are missing or the DEFAULT_* placeholders.

    Group metadata is fetched only if at least one value needs inferring.

    Returns:
        (car_name, date_name, log_lines). log_lines has one entry per value
        that was auto-filled, naming the source used.
    """
    missing_car = _is_missing_car_name(car_name)
    missing_date = _is_missing_date_name(date_name)
    notes: List[str] = []
    if not (missing_car or missing_date):
        return car_name, date_name, notes

    meta = get_group_meta(group_uuid)
    resolved_car_name, resolved_date_name = car_name, date_name
    if missing_car:
        resolved_car_name, source = _infer_car_name(meta)
        notes.append(f'[{group_uuid}] 自动补全 car_name={resolved_car_name} (source={source})')
    if missing_date:
        resolved_date_name, source = _infer_date_name(meta, df, camera_id)
        notes.append(f'[{group_uuid}] 自动补全 date={resolved_date_name} (source={source})')
    return resolved_car_name, resolved_date_name, notes
# ---------------------------------------------------------------------------
# 图片保存
# ---------------------------------------------------------------------------
def save_images(
    group_uuid: str,
    car_name: str,
    date_name: str,
    valid_ids: Optional[List[int]],
    clip_paths: List[str],
    camera_id: str,
    df: pd.DataFrame,
    timestamp_index: Dict[int, int],
    output_path: str,
    image_suffix: str,
) -> int:
    """
    Decode the ``camera_id`` video stream of every clip and write frames as images.

    Each struct-encoded ``VideoMessage`` is mapped back to a row of ``df``
    (frame_timestamps.csv) through its ``log_time``. That row index
    (``json_index``) filters the frame against ``valid_ids`` and names the
    output file ``<output_path>/images/<basename>.<image_suffix>``.

    Args:
        valid_ids: json_index values to save; ``None`` saves every frame.
        timestamp_index: prebuilt timestamp -> json_index map for ``camera_id``.
        image_suffix: file extension passed to ``cv2.imwrite``.

    Returns:
        Number of images written successfully. Frames without a matching
        timestamp row, frames that fail to decode and frames that fail to
        write are skipped with a warning.
    """
    import cv2
    from pdcl_pyclip.decoder_struct import StructDecoder
    from pdcl_pyclip.msg_camera import VideoMessage
    from pdcl_pyclip.reader import ClipReader

    image_dir = os.path.join(output_path, 'images')
    os.makedirs(image_dir, exist_ok=True)
    valid_set = set(valid_ids) if valid_ids is not None else None
    struct_decoder = StructDecoder()
    saved = 0
    for clip_path in clip_paths:
        reader = ClipReader(clip_path)
        for schema, channel, msg in reader.iter_messages(topics=[camera_id]):
            if schema.encoding != 'struct':
                continue
            data = struct_decoder.decode(schema, channel, msg)
            if not isinstance(data, VideoMessage):
                continue
            try:
                json_index = _find_index_by_timestamp_fast(
                    timestamp_index, df, camera_id, msg.log_time
                )
            except IndexError:
                # A single video frame with no row in frame_timestamps.csv used
                # to abort image export for the whole group; skip just it.
                print(f" [WARN] 时间戳 {msg.log_time} 在 frame_timestamps 中无对应帧,跳过")
                continue
            if valid_set is not None and json_index not in valid_set:
                continue
            frame_timestamp_ns = _get_frame_timestamp_ns(df, camera_id, json_index)
            if frame_timestamp_ns is not None and msg.log_time != frame_timestamp_ns:
                print(
                    f" [WARN] 帧 {json_index} 时间戳不一致: "
                    f"msg.log_time={msg.log_time}, df[{camera_id}]={frame_timestamp_ns}"
                )
            try:
                img = _h265_bytes_to_bgr(data.payload)
            except Exception as e:
                # Best-effort: one undecodable frame must not stop the export.
                print(f" [WARN] 帧 {json_index} 解码失败: {e}")
                continue
            fname = (
                f'{_build_frame_basename(car_name, date_name, group_uuid, json_index, camera_id, df)}'
                f'.{image_suffix}'
            )
            if not cv2.imwrite(os.path.join(image_dir, fname), img):
                print(f" [WARN] 帧 {json_index} 写图失败: {fname}")
                continue
            saved += 1
    return saved
# ---------------------------------------------------------------------------
# 标注 / 标定文件解包
# ---------------------------------------------------------------------------
def extract_annotations_and_calib(
    group_uuid: str,
    car_name: str,
    date_name: str,
    product_type: str,
    camera_id: str,
    df: pd.DataFrame,
    output_path: str,
    annotation_stride: int = DEFAULT_ANNOTATION_STRIDE,
) -> List[int]:
    """
    Unpack calibration files and target-camera annotation JSONs from a product tar.

    Streams ``<product_type>_archive.tar`` of the group and, per member:
      * skips paths containing ``depth_map_fg`` and non-regular members;
      * writes any path containing ``calib`` into ``<output_path>/calib`` under its
        basename (same-named files from different folders overwrite each other);
      * writes ``.../<camera_id>/camera_radar_lidar_jsons/<index>.json`` into
        ``<output_path>/annotations``, renamed via ``_build_frame_basename``.
        Only every ``annotation_stride``-th candidate is kept, counted in
        archive order (NOTE: not necessarily frame-index order).

    Returns:
        json_index of every annotation written, in archive order. Empty when
        the group has no ``product_type`` product.
    """
    tar_filename = f'{product_type}_archive.tar'
    valid_ids: List[int] = []
    candidate_annotation_count = 0
    written_names = set()

    with Group(group_uuid) as group:
        if group.get_product(product_type) is None:
            print(f" [WARN] Group {group_uuid} 不含产物 {product_type},跳过")
            return valid_ids

        calib_dir, annotations_dir = (
            os.path.join(output_path, sub_dir) for sub_dir in ('calib', 'annotations')
        )
        for target_dir in (calib_dir, annotations_dir):
            os.makedirs(target_dir, exist_ok=True)

        with group.open(tar_filename, mode='rb') as f:
            # 'r|*' = sequential stream read: each member must be consumed
            # before advancing, which is why data is read immediately below.
            with tarfile.open(fileobj=f, mode='r|*') as tar:
                for member in tar:
                    if 'depth_map_fg' in member.name:
                        continue
                    file_obj = tar.extractfile(member)
                    if file_obj is None:
                        continue

                    # Calibration file: keep its original basename.
                    if 'calib' in member.name:
                        calib_target = os.path.join(calib_dir, os.path.basename(member.name))
                        with open(calib_target, 'wb') as fw:
                            fw.write(file_obj.read())
                        continue

                    if not _is_target_camera_annotation_member(member.name, camera_id):
                        continue

                    stem = os.path.basename(member.name)[:-5]  # strip '.json'
                    try:
                        json_index = int(stem)
                    except ValueError:
                        print(f" [WARN] 无法解析帧号: {member.name}")
                        continue

                    candidate_annotation_count += 1
                    # Keep candidates 1, 1 + stride, 1 + 2 * stride, ...
                    if (candidate_annotation_count - 1) % annotation_stride:
                        continue

                    try:
                        new_fname = (
                            f'{_build_frame_basename(car_name, date_name, group_uuid, json_index, camera_id, df)}.json'
                        )
                    except (KeyError, IndexError, ValueError):
                        # Frame not present in df: fall back to the minimal name.
                        new_fname = f'{car_name}_{date_name}_{group_uuid}_{json_index:06d}.json'
                    if new_fname in written_names:
                        print(f" [WARN] 重复标注输出名,跳过覆盖: {new_fname} <- {member.name}")
                        continue

                    with open(os.path.join(annotations_dir, new_fname), 'wb') as fw:
                        fw.write(file_obj.read())
                    written_names.add(new_fname)
                    valid_ids.append(json_index)

    print(
        f" [INFO] annotations({camera_id}) 保存 {len(valid_ids)} / "
        f"{candidate_annotation_count} (annotation_stride={annotation_stride})"
    )
    return valid_ids
# ---------------------------------------------------------------------------
# 单个 group 的完整处理流程
# ---------------------------------------------------------------------------
def process_group(
    group_uuid: str,
    car_name: str,
    date_name: str,
    output_root: str,
    product_type: str = '2d-3d-association',
    camera_id: str = 'camera4',
    image_suffix: str = DEFAULT_IMAGE_SUFFIX,
    annotation_stride: int = DEFAULT_ANNOTATION_STRIDE,
    skip_images: bool = False,
) -> Tuple[str, str]:
    """
    Run the whole extraction pipeline for one group.

    Steps:
      1. Load frame_timestamps.csv.
      2. Auto-fill car/date names when they are the DEFAULT_* placeholders.
      3. Extract annotations + calibration into ``<output_root>/<group_uuid>``.
      4. Unless ``skip_images``, decode and save the images for the kept frames.

    Any exception is caught, printed with its traceback and reported as a
    ``'failed: ...'`` status, so one bad group does not stop a batch.

    Returns:
        (group_uuid, status); status starts with ``'success'`` or ``'failed'``.
    """
    output_path = os.path.join(output_root, group_uuid)
    try:
        df = get_frame_timestamps(group_uuid)
        car_name, date_name, messages = resolve_group_identifiers(
            group_uuid, car_name, date_name, df, camera_id
        )
        print(
            f"[{group_uuid}] 开始处理 "
            f"(car={car_name}, date={date_name}, "
            f"annotation_stride={annotation_stride}, skip_images={skip_images})"
        )
        for note in messages:
            print(note)

        # Annotations + calibration; the returned ids decide which images to save.
        valid_ids = sorted(
            extract_annotations_and_calib(
                group_uuid,
                car_name,
                date_name,
                product_type,
                camera_id,
                df,
                output_path,
                annotation_stride,
            )
        )
        if not valid_ids:
            print(f"[{group_uuid}] 无有效标注帧,跳过后续处理")
            return group_uuid, 'success (no valid frames)'
        if skip_images:
            print(f"[{group_uuid}] 共 {len(valid_ids)} 个有效帧,按要求跳过图片保存")
            return group_uuid, f'success ({len(valid_ids)} annotations only)'

        print(f"[{group_uuid}] 共 {len(valid_ids)} 个有效帧,开始保存图片")
        # Images for exactly the kept annotation frames.
        timestamp_index = _build_timestamp_index(df, camera_id)
        clip_paths = get_clip_paths(group_uuid)
        n_saved = save_images(
            group_uuid,
            car_name,
            date_name,
            valid_ids,
            clip_paths,
            camera_id,
            df,
            timestamp_index,
            output_path,
            image_suffix,
        )
        print(f"[{group_uuid}] 完成,保存图片 {n_saved}")
        return group_uuid, f'success ({n_saved} images)'
    except Exception as e:
        import traceback
        print(f"[{group_uuid}] 处理失败: {e}\n{traceback.format_exc()}")
        return group_uuid, f'failed: {e}'
# ---------------------------------------------------------------------------
# 多进程入口
# ---------------------------------------------------------------------------
def _worker(args):
    """Pool entry point: expand one packed argument tuple into ``process_group``.

    ``Pool.imap_unordered`` hands each task over as a single object, so the
    arguments travel as one tuple in ``process_group``'s positional order.
    """
    positional = tuple(args)
    return process_group(*positional)
def run(
    tasks: List[Tuple[str, str, str]],  # [(group_uuid, car_name, date_name), ...]
    output_root: str,
    product_type: str,
    camera_id: str,
    image_suffix: str,
    annotation_stride: int,
    skip_images: bool,
    workers: int,
):
    """Process all tasks (in a process pool when more than one worker helps) and print a summary.

    The pool size is ``workers`` clamped to ``[1, len(tasks)]``. With one worker
    everything runs in-process. Results arrive in completion order when pooled.
    """
    if not tasks:
        print('没有需要处理的 group')
        return

    shared_options = (
        output_root,
        product_type,
        camera_id,
        image_suffix,
        annotation_stride,
        skip_images,
    )
    full_tasks = [(uuid, car, date) + shared_options for uuid, car, date in tasks]

    actual_workers = max(1, min(workers, len(full_tasks)))
    if actual_workers > 1:
        # Roughly four chunks per worker: fewer IPC round-trips, yet stragglers
        # still get rebalanced.
        chunksize = max(1, ceil(len(full_tasks) / (actual_workers * 4)))
        with Pool(processes=actual_workers) as pool:
            results = list(pool.imap_unordered(_worker, full_tasks, chunksize=chunksize))
    else:
        results = list(map(_worker, full_tasks))

    statuses = [status for _, status in results]
    success = sum(status.startswith('success') for status in statuses)
    failed = sum(status.startswith('failed') for status in statuses)
    print(f"\n完成!总计 {len(results)} 成功 {success} 失败 {failed} workers {actual_workers}")
    for uuid, res in results:
        if res.startswith('failed'):
            print(f" FAILED {uuid}: {res}")
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_uuid_file(filepath: str) -> List[Tuple[str, str, str]]:
    """
    Parse a group list file into ``(group_uuid, car_name, date)`` tasks.

    Each line is ``group_uuid [date] [car_name]`` separated by whitespace.
    Blank lines and lines starting with ``#`` are ignored. A missing date or
    car_name becomes the DEFAULT_* placeholder, which ``process_group``
    later tries to auto-fill.

    Duplicates are detected by group_uuid alone. Every task writes into
    ``<output>/<group_uuid>``, so two lines for the same group, even with
    different date/car_name, would run concurrently against the same
    directory. The first occurrence wins.
    """
    tasks: List[Tuple[str, str, str]] = []
    first_task_by_uuid: Dict[str, Tuple[str, str, str]] = {}
    duplicate_count = 0
    # Explicit encoding: do not depend on the locale; 'utf-8-sig' also strips
    # a BOM that would otherwise be glued onto the first uuid.
    with open(filepath, encoding='utf-8-sig') as f:
        for line in f:
            parts = line.split()
            if not parts or parts[0].startswith('#'):
                continue
            uuid = parts[0]
            date = parts[1] if len(parts) > 1 else DEFAULT_DATE_NAME
            car_name = parts[2] if len(parts) > 2 else DEFAULT_CAR_NAME
            task = (uuid, car_name, date)
            kept = first_task_by_uuid.get(uuid)
            if kept is not None:
                duplicate_count += 1
                if kept != task:
                    print(f"[uuid_file] group_uuid {uuid} 重复且参数不同,保留首次: {kept}")
                continue
            first_task_by_uuid[uuid] = task
            tasks.append(task)
    if duplicate_count:
        print(f"[uuid_file] 跳过重复 group_uuid {duplicate_count}")
    return tasks
def main():
    """CLI entry point: parse arguments, build the task list and run it.

    Invalid ``--image_suffix`` / ``--annotation_stride`` values are reported
    as argparse usage errors (exit status 2) instead of a traceback. Duplicate
    ``--group_ids`` are dropped, keeping first occurrences, because each group
    writes into its own ``<output>/<group_uuid>`` directory.
    """
    parser = argparse.ArgumentParser(description='按 group_id 提取 PDCL 真值产物')
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument('--group_ids', nargs='+', metavar='UUID',
                     help='直接指定一个或多个 group UUID')
    src.add_argument('--uuid_file', metavar='FILE',
                     help='UUID 列表文件(每行: uuid [date] [car_name]')
    parser.add_argument('--output', required=True, help='输出根目录')
    parser.add_argument('--car_name', default=DEFAULT_CAR_NAME,
                        help='车辆标识;默认会尝试从 group 元数据自动推断')
    parser.add_argument('--date', default=DEFAULT_DATE_NAME,
                        help='日期字符串;默认会尝试从 collection_time 或时间戳自动推断')
    parser.add_argument('--product_type', default='2d-3d-association',
                        help='真值产物类型(默认: 2d-3d-association')
    parser.add_argument('--camera_id', default='camera4', help='相机 topic 名称')
    parser.add_argument('--image_suffix', default=DEFAULT_IMAGE_SUFFIX,
                        help='输出图像后缀,支持 png/jpg/jpeg默认: png')
    parser.add_argument('--annotation_stride', type=int, default=DEFAULT_ANNOTATION_STRIDE,
                        help='标注保存抽样步长1 为全保存2 为 2 抽 1同时图片也会与保留标注对齐')
    parser.add_argument('--skip_images', action='store_true',
                        help='只提取 annotations/calib不解码或保存图片帧')
    parser.add_argument('--workers', type=int, default=4, help='并行进程数')
    args = parser.parse_args()

    try:
        image_suffix = normalize_image_suffix(args.image_suffix)
        annotation_stride = normalize_annotation_stride(args.annotation_stride)
    except ValueError as e:
        # parser.error prints usage plus the message and exits with status 2.
        parser.error(str(e))

    if args.uuid_file:
        tasks = parse_uuid_file(args.uuid_file)
    else:
        # The same uuid twice would be processed concurrently into one directory.
        unique_group_ids = list(dict.fromkeys(args.group_ids))
        skipped = len(args.group_ids) - len(unique_group_ids)
        if skipped:
            print(f"[group_ids] 跳过重复 group_uuid {skipped}")
        tasks = [(uuid, args.car_name, args.date) for uuid in unique_group_ids]

    os.makedirs(args.output, exist_ok=True)
    run(
        tasks,
        args.output,
        args.product_type,
        args.camera_id,
        image_suffix,
        annotation_stride,
        args.skip_images,
        args.workers,
    )


if __name__ == '__main__':
    main()