Files
HSAP/platform/as_platform/data/ingest/dms_coco.py

89 lines
2.8 KiB
Python
Raw Normal View History

"""DMS COCO-format adapter."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from as_platform.data.ingest.base import IngestAdapter, IngestContext, NormalizedDataset
COCO_NAMES = ("instances_train.json", "instances_val.json", "instances_test.json", "annotations.json")
def _read_json(path: Path) -> dict[str, Any] | None:
try:
return json.loads(path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
return None
class DmsCocoAdapter(IngestAdapter):
format_id = "dms_coco"
projects = ("dms",)
def _find_coco_files(self, root: Path) -> list[Path]:
files: list[Path] = []
for name in COCO_NAMES:
p = root / "annotations" / name
if p.is_file():
files.append(p)
for name in COCO_NAMES:
p = root / name
if p.is_file():
files.append(p)
return files
def can_handle(self, ctx: IngestContext) -> bool:
root = ctx.source_path
return len(self._find_coco_files(root)) > 0
def inspect(self, ctx: IngestContext) -> NormalizedDataset:
root = ctx.source_path
files = self._find_coco_files(root)
split_counts = {"train": 0, "val": 0, "test": 0}
ann_count = 0
categories: set[str] = set()
warnings: list[str] = []
for f in files:
data = _read_json(f)
if not data:
warnings.append(f"failed to parse {f.name}")
continue
images = data.get("images") or []
anns = data.get("annotations") or []
cats = data.get("categories") or []
ann_count += len(anns)
for c in cats:
name = c.get("name")
if isinstance(name, str):
categories.add(name)
lower = f.name.lower()
if "train" in lower:
split_counts["train"] += len(images)
elif "val" in lower:
split_counts["val"] += len(images)
elif "test" in lower:
split_counts["test"] += len(images)
else:
split_counts["train"] += len(images)
return NormalizedDataset(
format_id=self.format_id,
project=ctx.project,
task=ctx.task,
source_path=str(root),
split_counts=split_counts,
sample_count=sum(split_counts.values()),
annotation_count=ann_count,
artifacts=[self._artifact_name(root, f) for f in files],
warnings=warnings,
extra={"categories": sorted(categories)},
)
@staticmethod
def _artifact_name(root: Path, path: Path) -> str:
try:
return str(path.relative_to(root))
except ValueError:
return path.name