Files
HSAP/platform/as_platform/labeling/batch_index.py

295 lines
9.9 KiB
Python
Raw Normal View History

"""批次索引:扫盘结果落库,列表页只查 DB<200ms"""
from __future__ import annotations
import time
from datetime import datetime, timezone
from typing import Any
from as_platform.data.core import get_pending_report, load_wf
from as_platform.db.engine import session_scope
from as_platform.db.models import BatchIndex, LabelingCampaign
from as_platform.labeling.scope import enrich_batch_labels, load_dms_registry
from as_platform.labeling.stage import STAGE_ALIASES, effective_stage
def _utcnow() -> datetime:
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    return datetime.now(tz=timezone.utc)
def campaign_id_for_row(project: str, task: str, mode: str | None, batch: str, location: str) -> str:
    """Return the campaign id for the given batch coordinates.

    Thin wrapper that forwards all five arguments, positionally and in order,
    to ``as_platform.labeling.service._campaign_id``. The import happens at
    call time (presumably to avoid an import cycle with the service module --
    TODO confirm).
    """
    from as_platform.labeling.service import _campaign_id as build_campaign_id

    campaign_id = build_campaign_id(project, task, mode, batch, location)
    return campaign_id
def _batch_dict_to_fields(b: dict[str, Any], reg: dict) -> dict[str, Any] | None:
    """Translate a scanned batch dict into ``BatchIndex`` column values.

    Args:
        b: Raw batch description (from the scan report or registry fallback).
        reg: Registry passed through to ``enrich_batch_labels``.

    Returns:
        A dict of column values keyed by column name (including a freshly
        computed ``campaign_id`` and ``indexed_at``), or ``None`` when the
        batch must not be indexed: either it is a registry-only entry without
        a batch name, or neither its effective nor its raw stage is one of
        the indexable stages.
    """
    # Registry-only placeholders that carry no batch name cannot be indexed.
    if not b.get("batch") and b.get("registry_only"):
        return None

    row = enrich_batch_labels(b, reg)

    raw_stage = row.get("stage")
    stage = effective_stage(raw_stage) or raw_stage or "raw_pool"
    indexable_stages = (
        "raw_pool",
        "out_for_labeling",
        "returned",
        "labeling_submitted",
        "in_review",
        "review_approved",
        "review_rejected",
        "ingested",
    )
    if not (stage in indexable_stages or raw_stage in indexable_stages):
        return None

    project = row.get("project") or "dms"
    task = row.get("task") or ""
    mode = row.get("mode")
    batch = row.get("batch") or ""
    location = row.get("location") or "inbox"
    counts = row.get("counts") or {}

    fields: dict[str, Any] = {
        # The id is derived from the empty-string form of a missing task,
        # while the stored column uses None for a missing task.
        "campaign_id": campaign_id_for_row(project, task, mode, batch, location),
        "project": project,
        "task": task or None,
        "mode": mode,
        "batch": batch,
        "pack": row.get("pack"),
        "location": location,
        "stage": stage,
        "batch_path": row.get("path"),
    }
    # Columns copied verbatim from the enriched row.
    fields.update((key, row.get(key)) for key in ("scope_key", "engineer", "format"))
    fields.update(
        image_count=int(counts.get("images") or 0),
        label_count=int(counts.get("labels") or 0),
        has_meta=bool(row.get("has_meta")),
        registry_only=bool(row.get("registry_only")),
    )
    fields.update(
        (key, row.get(key))
        for key in (
            "next_cli",
            "domain",
            "domain_label",
            "task_label",
            "mode_label",
            "labeling_profile",
            "export_default",
            "ml_adapter",
        )
    )
    fields["indexed_at"] = _utcnow()
    return fields
def upsert_batch_dict(b: dict[str, Any], *, reg: dict | None = None) -> str | None:
    """Write a single batch into the index (on registration or single-batch update).

    Args:
        b: Raw batch description.
        reg: Registry to use; when falsy, ``load_dms_registry()`` is called.

    Returns:
        The campaign id of the inserted/updated row, or ``None`` when the
        batch is not indexable (``_batch_dict_to_fields`` returned nothing).
    """
    registry = reg or load_dms_registry()
    values = _batch_dict_to_fields(b, registry)
    if not values:
        return None

    # An upsert always clears the archived flag on the row.
    values["archived"] = False
    campaign_id = values["campaign_id"]

    with session_scope() as db:
        existing = db.get(BatchIndex, campaign_id)
        if not existing:
            db.add(BatchIndex(**values))
        else:
            for column, value in values.items():
                setattr(existing, column, value)
        db.commit()
    return campaign_id
def archive_batch(campaign_id: str) -> dict[str, Any]:
    """Soft-delete a batch: hide it from the workbench without deleting data on disk.

    Args:
        campaign_id: Primary key of the ``BatchIndex`` row.

    Returns:
        ``{"ok": True, "campaign_id": ...}``; an extra
        ``"already_archived": True`` key is present when the row was already
        archived (in which case nothing is modified).

    Raises:
        FileNotFoundError: No index row exists for ``campaign_id``.
        ValueError: The row's stage is not ``"raw_pool"``, or a
            ``LabelingCampaign`` exists whose status is neither
            ``"not_opened"`` nor the empty string.
    """
    with session_scope() as db:
        entry = db.get(BatchIndex, campaign_id)
        if not entry:
            raise FileNotFoundError(campaign_id)

        result: dict[str, Any] = {"ok": True, "campaign_id": campaign_id}
        if entry.archived:
            result["already_archived"] = True
            return result

        if entry.stage != "raw_pool":
            raise ValueError("仅「待送标」批次可移除")

        campaign = db.get(LabelingCampaign, campaign_id)
        if campaign:
            untouched_statuses = ("not_opened", "")
            if campaign.status not in untouched_statuses:
                raise ValueError("已开标的批次不可移除,请先在标注进度中处理")

        entry.archived = True
        entry.indexed_at = _utcnow()
        db.commit()
        return result
def sync_index_stage(campaign_id: str, stage: str) -> None:
    """Update the stored stage of one index row, if the row exists.

    The stage is first passed through ``effective_stage``; when that yields a
    falsy value the given ``stage`` is stored unchanged. ``indexed_at`` is
    refreshed together with the stage. A missing row is silently ignored.
    """
    normalized_stage = effective_stage(stage) or stage
    with session_scope() as db:
        entry = db.get(BatchIndex, campaign_id)
        if entry:
            entry.indexed_at = _utcnow()
            entry.stage = normalized_stage
        db.commit()
def index_count() -> int:
    """Return the number of rows in the batch index table (no filters applied)."""
    with session_scope() as db:
        all_rows = db.query(BatchIndex)
        total = all_rows.count()
        return total
def index_is_empty() -> bool:
    """Return ``True`` when the batch index table holds no rows at all."""
    row_total = index_count()
    return row_total == 0
def get_batch_by_campaign_id(campaign_id: str) -> dict[str, Any] | None:
    """Fetch one index row by campaign id.

    Returns:
        The row rendered via ``BatchIndex.to_list_row()``, or ``None`` when no
        row with that id exists.
    """
    with session_scope() as db:
        entry = db.get(BatchIndex, campaign_id)
        if not entry:
            return None
        # Render while the session is still open.
        return entry.to_list_row()
def rebuild_batch_index(wf: dict | None = None) -> dict[str, Any]:
    """Run a full scan and rebuild the index (refresh / first load / bulk sync after registration).

    Candidates are the ``"batches"`` of the pending report followed by the
    registry fallback batches. Every indexable candidate is inserted or
    updated, except that rows already marked archived are left untouched.
    Afterwards, non-archived rows whose id was not seen in this scan are
    deleted -- but only when the scan produced at least one id.

    Args:
        wf: Workflow configuration; when falsy, ``load_wf()`` is called.

    Returns:
        ``{"ok": True, "count": <rows written>, "elapsed_ms": <int>,
        "updated_at": <report's updated_at>}``.
    """
    from as_platform.labeling.service import _registry_fallback_batches

    started = time.perf_counter()
    workflow = wf or load_wf()
    registry = load_dms_registry()
    report = get_pending_report(workflow)

    scanned: list[dict[str, Any]] = [
        *(report.get("batches") or []),
        *_registry_fallback_batches(workflow, registry),
    ]

    live_ids: set[str] = set()
    written = 0
    with session_scope() as db:
        for candidate in scanned:
            values = _batch_dict_to_fields(candidate, registry)
            if not values:
                continue

            campaign_id = values["campaign_id"]
            live_ids.add(campaign_id)

            existing = db.get(BatchIndex, campaign_id)
            if existing:
                # Archived rows stay hidden and are not refreshed by a rescan.
                if existing.archived:
                    continue
                for column, value in values.items():
                    setattr(existing, column, value)
            else:
                db.add(BatchIndex(**values))
            written += 1

        # Prune rows that vanished from the scan; skipped entirely when the
        # scan yielded no ids at all.
        if live_ids:
            stale_rows = db.query(BatchIndex).filter(
                BatchIndex.campaign_id.notin_(live_ids),
                BatchIndex.archived.is_(False),
            )
            stale_rows.delete(synchronize_session=False)
        db.commit()

    duration_ms = round((time.perf_counter() - started) * 1000)
    summary: dict[str, Any] = {"ok": True, "count": written}
    summary["elapsed_ms"] = duration_ms
    summary["updated_at"] = report.get("updated_at")
    return summary
def _expand_stage_filters_for_sql(stage_filters: list[str]) -> list[str]:
    """Expand filter stages into the stage values to match in the index table.

    Every requested stage is kept, and each key of ``STAGE_ALIASES`` whose
    value equals a requested stage is added (e.g. aliases such as
    ``review_approved``). The result contains no duplicates; its order is
    unspecified.
    """
    matching_aliases = {
        alias
        for wanted in stage_filters
        for alias, canonical in STAGE_ALIASES.items()
        if canonical == wanted
    }
    return list(matching_aliases.union(stage_filters))
def _index_query(db, *, stage_filters: list[str], q: str | None):
    """Build the base query for listing visible index rows.

    Rows flagged ``registry_only`` or ``archived`` are always excluded.

    Args:
        db: Open database session used to create the query.
        stage_filters: Stages to restrict to (expanded with their aliases);
            an empty list applies no stage restriction.
        q: Optional free-text search. After stripping, it is matched
            case-insensitively as a literal substring against ``batch``,
            ``task`` and ``project``. ``%`` and ``_`` in the text are escaped
            so they match themselves instead of acting as LIKE wildcards.

    Returns:
        The filtered query (no ordering or pagination applied).
    """
    from sqlalchemy import func, or_
    from as_platform.config import IS_POSTGRES

    query = db.query(BatchIndex).filter(
        BatchIndex.registry_only.is_(False),
        BatchIndex.archived.is_(False),
    )
    if stage_filters:
        query = query.filter(
            BatchIndex.stage.in_(_expand_stage_filters_for_sql(stage_filters))
        )

    text = (q or "").strip()
    if not text:
        return query

    escape_char = "/"
    pattern = f"%{_escape_like(text, escape_char)}%"
    if IS_POSTGRES:
        return query.filter(
            or_(
                BatchIndex.batch.ilike(pattern, escape=escape_char),
                BatchIndex.task.ilike(pattern, escape=escape_char),
                BatchIndex.project.ilike(pattern, escape=escape_char),
            )
        )

    # Backends without ILIKE: lower-case both sides and use plain LIKE.
    lowered = pattern.lower()
    return query.filter(
        or_(
            func.lower(BatchIndex.batch).like(lowered, escape=escape_char),
            func.lower(func.coalesce(BatchIndex.task, "")).like(
                lowered, escape=escape_char
            ),
            func.lower(BatchIndex.project).like(lowered, escape=escape_char),
        )
    )


def _escape_like(text: str, escape: str = "/") -> str:
    """Escape LIKE wildcards (and the escape character itself) in *text*."""
    # The escape character must be doubled first so the escapes added for
    # "%" and "_" are not themselves re-escaped.
    return (
        text.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )
def list_batches_from_index(
    *,
    stage: str | None = None,
    stages: list[str] | None = None,
    offset: int = 0,
    limit: int = 20,
    q: str | None = None,
) -> dict[str, Any]:
    """Return one page of batches from the index, enriched with campaign data.

    Args:
        stage: Single stage filter; merged into ``stages``. Surrounding
            whitespace is stripped and a blank value is ignored, matching how
            ``stages`` entries are treated.
        stages: Stage filters; blank entries are dropped.
        offset: Number of rows to skip (negative values are treated as 0).
        limit: Page size (values below 1 are treated as 1).
        q: Optional free-text search forwarded to ``_index_query``.

    Returns:
        A dict with:
        - ``items``: page rows ordered by ``indexed_at`` descending. Each row
          gets ``campaign_status`` (``"not_opened"`` when no campaign exists),
          assignee fields when a campaign exists, and progress counters for
          campaigns that are in a progress stage.
        - ``total``: number of rows matching the filters.
        - ``offset`` / ``limit``: echoed exactly as passed in.
        - ``updated_at``: ISO timestamp of the newest ``indexed_at`` on this
          page, or ``None`` for an empty page.
        - ``source``: always ``"index"``.
    """
    import logging

    logger = logging.getLogger(__name__)

    stage_filters = [s.strip() for s in (stages or []) if s and s.strip()]
    # Normalise the single-stage filter the same way as the list entries so a
    # padded or whitespace-only value cannot become a filter matching nothing.
    single_stage = (stage or "").strip()
    if single_stage and single_stage not in stage_filters:
        stage_filters.append(single_stage)

    with session_scope() as db:
        base = _index_query(db, stage_filters=stage_filters, q=q)
        total = base.count()
        recs = (
            base.order_by(BatchIndex.indexed_at.desc())
            .offset(max(0, offset))
            .limit(max(1, limit))
            .all()
        )
        # Materialise while the session is open.
        page = [rec.to_list_row() for rec in recs]
        # Skip missing timestamps: comparing None with datetime raises TypeError.
        latest_indexed = max(
            (rec.indexed_at for rec in recs if rec.indexed_at is not None),
            default=None,
        )

    cids = [r["campaign_id"] for r in page if r.get("campaign_id")]
    camp_map: dict[str, dict[str, Any]] = {}
    if cids:
        with session_scope() as db:
            camps = db.query(LabelingCampaign).filter(LabelingCampaign.id.in_(cids)).all()
            for c in camps:
                camp_map[c.id] = {
                    "status": c.status,
                    "assigned_to_user_id": c.assigned_to_user_id,
                    "assigned_to_name": c.assigned_to_name,
                }

    progress_stages = frozenset({"out_for_labeling", "in_progress"})
    out: list[dict[str, Any]] = []
    for row in page:
        cid = row.get("campaign_id")
        camp = camp_map.get(cid) if cid else None
        status = camp["status"] if camp else "not_opened"
        if camp:
            row["assigned_to_user_id"] = camp["assigned_to_user_id"]
            row["assigned_to_name"] = camp["assigned_to_name"]
        row["campaign_status"] = status

        eff_stage = row.get("stage") or ""
        if camp and (status in progress_stages or eff_stage in progress_stages):
            try:
                # Imported lazily inside the try so that an import failure is
                # also covered by the best-effort fallback below.
                from as_platform.labeling.progress import campaign_progress_summary

                row.update(campaign_progress_summary(cid))
            except Exception:
                # Best effort: the list must still render, but the failure is
                # logged instead of being dropped silently.
                logger.warning(
                    "campaign progress summary failed for %s", cid, exc_info=True
                )
                row.update({"total_tasks": 0, "completed_tasks": 0, "assigned_tasks": 0})
        out.append(row)

    updated_at = latest_indexed.isoformat() if latest_indexed else None
    return {
        "items": out,
        "total": total,
        "offset": offset,
        "limit": limit,
        "updated_at": updated_at,
        "source": "index",
    }