Huaxu Sentinel Active Safety Platform with embedded algorithm code, Docker Compose setup, and vendored dataset scaffolds for clone-and-run. Co-authored-by: Cursor <cursoragent@cursor.com>
87 lines
2.5 KiB
Python
87 lines
2.5 KiB
Python
"""LangChain 风格 Tool 注册(纯 Python + 可选 langchain)。"""
|
||
from __future__ import annotations
|
||
|
||
from typing import Any, Callable
|
||
|
||
from as_platform.audit.queue import submit_approval
|
||
from as_platform.data.core import get_catalog, get_pending_report, load_wf
|
||
from as_platform.jobs.queue import get_job, list_jobs
|
||
|
||
import yaml
|
||
from pathlib import Path
|
||
|
||
# Base directory for workspace-relative paths (e.g. get_model_versions reads
# "datasets/dms/manifests/..." under it): two levels above the directory that
# contains this file. Presumably the repository root — confirm if files move.
WORKSPACE: Path = Path(__file__).resolve().parents[2]
|
||
|
||
|
||
def list_pending_batches() -> dict[str, Any]:
    """Return the pending-batch report.

    Thin tool wrapper: delegates to ``get_pending_report`` and passes its
    result through unchanged.
    """
    report = get_pending_report()
    return report
|
||
|
||
|
||
def get_dataset_catalog() -> dict[str, Any]:
    """Return the dataset catalog.

    Thin tool wrapper: delegates to ``get_catalog`` and passes its result
    through unchanged.
    """
    catalog = get_catalog()
    return catalog
|
||
|
||
|
||
def submit_build_for_batch(
    task: str,
    batch: str,
    pack: str = "dms_v2",
    submitted_by: str | None = None,
) -> dict:
    """Submit a ``build_dms`` request for one batch to the approval queue.

    Args:
        task: Task identifier forwarded in the request params.
        batch: Batch identifier; also embedded in the approval note.
        pack: Pack identifier forwarded in the request params.
        submitted_by: Optional submitter, passed through to ``submit_approval``.

    Returns:
        Whatever ``submit_approval`` returns.
    """
    build_params = {"task": task, "pack": pack, "batch": batch}
    return submit_approval(
        "build_dms",
        build_params,
        submitted_by=submitted_by,
        note=f"agent build {batch}",
    )
|
||
|
||
|
||
def submit_train_job(
    project: str,
    task: str,
    track: str = "platform",
    submitted_by: str | None = None,
) -> dict:
    """Submit a training request to the approval queue.

    ``project == "dms"`` selects the ``train_dms`` action and adds ``task`` to
    the params. Every other value selects ``train_lane`` and sends only
    ``track``.

    NOTE(review): an unrecognised project (e.g. a typo) is silently treated as
    a lane job, and ``task`` then appears only in the note — confirm intended.

    Args:
        project: Project name; only ``"dms"`` is special-cased.
        task: Task identifier (sent in params for DMS only; always in the note).
        track: Track value forwarded in the params.
        submitted_by: Optional submitter, passed through to ``submit_approval``.

    Returns:
        Whatever ``submit_approval`` returns.
    """
    is_dms = project == "dms"
    params: dict[str, Any] = {"track": track}
    if is_dms:
        params["task"] = task
    return submit_approval(
        "train_dms" if is_dms else "train_lane",
        params,
        submitted_by=submitted_by,
        note=f"agent train {project}/{task}",
    )
|
||
|
||
|
||
def get_job_status(job_id: str) -> dict[str, Any] | None:
    """Look up a job by id.

    Delegates to ``get_job``. Its result, which may be ``None`` per the return
    annotation, is returned unchanged.
    """
    status = get_job(job_id)
    return status
|
||
|
||
|
||
def get_model_versions(task: str) -> dict[str, Any]:
    """Return the recorded training versions for ``task``.

    Reads ``datasets/dms/manifests/train_versions.yaml`` under ``WORKSPACE``
    and returns the entry stored under the ``task`` key.

    Args:
        task: Top-level key to look up in the manifest.

    Returns:
        The manifest entry for ``task``. Returns ``{}`` when any of these hold:
        the manifest file does not exist, the file is empty, its top level is
        not a mapping, or ``task`` is absent or mapped to null.
    """
    manifest = WORKSPACE / "datasets/dms/manifests/train_versions.yaml"
    if not manifest.is_file():
        return {}
    data = yaml.safe_load(manifest.read_text(encoding="utf-8"))
    # safe_load yields None for an empty document (and could yield a list or
    # scalar); calling .get() on those would raise AttributeError.
    if not isinstance(data, dict):
        return {}
    entry = data.get(task)
    # A key present with a null value should honour the dict return type.
    return entry if entry is not None else {}
|
||
|
||
|
||
# Name -> callable lookup table consumed by ``invoke_tool``. Keys are the
# public tool names callers pass in; keep them in sync with the functions.
TOOL_REGISTRY: dict[str, Callable[..., Any]] = {
    # read-only queries
    "list_pending_batches": list_pending_batches,
    "get_dataset_catalog": get_dataset_catalog,
    # approval-queue submissions
    "submit_build_for_batch": submit_build_for_batch,
    "submit_train_job": submit_train_job,
    # job / model lookups
    "get_job_status": get_job_status,
    "get_model_versions": get_model_versions,
}
|
||
|
||
|
||
def invoke_tool(name: str, **kwargs: Any) -> Any:
    """Dispatch ``name`` to the matching ``TOOL_REGISTRY`` entry.

    Args:
        name: Registered tool name.
        **kwargs: Keyword arguments forwarded to the tool.

    Returns:
        The tool's return value.

    Raises:
        ValueError: If ``name`` is not registered.
    """
    handler = TOOL_REGISTRY.get(name)
    if handler is None:
        raise ValueError(f"未知 tool: {name}")
    return handler(**kwargs)
|
||
|
||
|
||
def as_langchain_tools() -> list[Any]:
    """Wrap selected tools as LangChain tool objects.

    Only ``list_pending_batches`` and ``get_dataset_catalog`` are exposed.
    ``langchain_core`` is optional: when it cannot be imported, an empty list
    is returned.
    """
    try:
        from langchain_core.tools import tool
    except ImportError:
        # Optional dependency missing: expose no LangChain tools.
        return []

    # NOTE: @tool takes each tool's name from the function name and its
    # description from the docstring, so the inner names and the (Chinese)
    # docstrings below are runtime data and must stay as they are.

    # Description: lists pending batches and their labeling-submission status.
    @tool
    def t_list_pending_batches() -> dict:
        """列出待处理批次与送标状态。"""
        return list_pending_batches()

    # Description: returns DMS/Lane dataset catalog statistics.
    @tool
    def t_get_dataset_catalog() -> dict:
        """获取 DMS/Lane 数据目录统计。"""
        return get_dataset_catalog()

    exposed = [t_list_pending_batches, t_get_dataset_catalog]
    return exposed
|