Files
HSAP/platform/as_platform/agents/tools.py
Chengfang Lu 7c43b44c57 feat: initial HSAP platform
Huaxu Sentinel Active Safety Platform with embedded algorithm code,
Docker Compose setup, and vendored dataset scaffolds for clone-and-run.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-25 16:59:59 +08:00

87 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""LangChain 风格 Tool 注册(纯 Python + 可选 langchain)"""
from __future__ import annotations
from typing import Any, Callable
from as_platform.audit.queue import submit_approval
from as_platform.data.core import get_catalog, get_pending_report, load_wf
from as_platform.jobs.queue import get_job, list_jobs
import yaml
from pathlib import Path
# Third ancestor (``parents[2]``) of this file's resolved absolute path, i.e.
# the directory two levels above the one that contains this module.  Paths to
# on-disk data used by this module are built relative to it.
WORKSPACE = Path(__file__).resolve().parents[2]
def list_pending_batches() -> dict[str, Any]:
    """Return the pending-batch report.

    Thin agent-facing wrapper: the result of ``get_pending_report()`` is
    passed through unchanged.
    """
    report = get_pending_report()
    return report
def get_dataset_catalog() -> dict[str, Any]:
    """Return the dataset catalog.

    Thin agent-facing wrapper: the result of ``get_catalog()`` is passed
    through unchanged.
    """
    catalog = get_catalog()
    return catalog
def submit_build_for_batch(task: str, batch: str, pack: str = "dms_v2", submitted_by: str | None = None) -> dict:
    """Submit a ``build_dms`` approval request for one batch.

    Args:
        task: Task identifier, forwarded in the request parameters.
        batch: Batch identifier; forwarded in the parameters and embedded in
            the request note.
        pack: Pack name forwarded in the parameters (defaults to ``"dms_v2"``).
        submitted_by: Optional submitter, forwarded to ``submit_approval``.

    Returns:
        Whatever ``submit_approval`` returns for the request.
    """
    build_params = {"task": task, "pack": pack, "batch": batch}
    request_note = f"agent build {batch}"
    return submit_approval(
        "build_dms",
        build_params,
        submitted_by=submitted_by,
        note=request_note,
    )
def submit_train_job(project: str, task: str, track: str = "platform", submitted_by: str | None = None) -> dict:
    """Submit a training approval request for a project.

    The action is ``"train_dms"`` when ``project`` equals ``"dms"``; every
    other project value is submitted as ``"train_lane"``.

    Args:
        project: Project name; selects the action as described above.
        task: Task identifier. It is included in the request parameters only
            for the ``"dms"`` project, but always appears in the request note.
        track: Track name forwarded in the parameters (defaults to
            ``"platform"``).
        submitted_by: Optional submitter, forwarded to ``submit_approval``.

    Returns:
        Whatever ``submit_approval`` returns for the request.
    """
    params: dict[str, Any] = {"track": track}
    if project == "dms":
        action = "train_dms"
        params["task"] = task
    else:
        action = "train_lane"
    request_note = f"agent train {project}/{task}"
    return submit_approval(action, params, submitted_by=submitted_by, note=request_note)
def get_job_status(job_id: str) -> dict[str, Any] | None:
    """Look up a job by its identifier.

    Thin agent-facing wrapper: the result of ``get_job(job_id)`` is passed
    through unchanged (annotated as a dict or ``None``).
    """
    job = get_job(job_id)
    return job
def get_model_versions(task: str) -> dict[str, Any]:
    """Return the entry recorded for ``task`` in the train-versions manifest.

    The manifest is read from
    ``WORKSPACE / "datasets/dms/manifests/train_versions.yaml"`` and parsed
    with ``yaml.safe_load``.

    Args:
        task: Top-level key to look up in the manifest.

    Returns:
        The value stored under ``task``.  An empty dict is returned when the
        manifest file does not exist, when the document is empty or its top
        level is not a mapping, or when ``task`` is missing or mapped to null.
    """
    manifest = WORKSPACE / "datasets/dms/manifests/train_versions.yaml"
    if not manifest.is_file():
        return {}
    data = yaml.safe_load(manifest.read_text(encoding="utf-8"))
    # safe_load yields None for an empty document and can yield a list or
    # scalar for other top-level shapes; none of those support ``.get``.
    if not isinstance(data, dict):
        return {}
    entry = data.get(task)
    # A key that is present but null (``task:``) would otherwise leak None to
    # callers that expect a dict.
    return {} if entry is None else entry
# Name -> callable lookup table used by ``invoke_tool``.  Each tool is
# registered under its own function name, in the order listed here.
TOOL_REGISTRY: dict[str, Callable[..., Any]] = {
    fn.__name__: fn
    for fn in (
        list_pending_batches,
        get_dataset_catalog,
        submit_build_for_batch,
        submit_train_job,
        get_job_status,
        get_model_versions,
    )
}
def invoke_tool(name: str, **kwargs: Any) -> Any:
    """Call a registered tool by name.

    Args:
        name: Key into ``TOOL_REGISTRY``.
        **kwargs: Keyword arguments forwarded unchanged to the tool.

    Returns:
        Whatever the selected tool returns.

    Raises:
        ValueError: If ``name`` is not present in ``TOOL_REGISTRY``.
    """
    handler = TOOL_REGISTRY.get(name)
    if handler:
        return handler(**kwargs)
    raise ValueError(f"未知 tool: {name}")
def as_langchain_tools() -> list[Any]:
    """Build LangChain tool objects for a subset of this module's tools.

    ``langchain_core`` is imported lazily so the module stays usable without
    it installed.

    Returns:
        A list with the ``@tool``-wrapped ``t_list_pending_batches`` and
        ``t_get_dataset_catalog``; an empty list when ``langchain_core``
        cannot be imported.  Only these two registry entries are wrapped here.
    """
    try:
        from langchain_core.tools import tool
    except ImportError:
        # langchain is an optional dependency: report "no tools" instead of failing.
        return []

    # NOTE: the inner function names and docstrings below are consumed by the
    # ``@tool`` decorator (tool name / description), so they are runtime text
    # and are intentionally left exactly as written.

    @tool
    def t_list_pending_batches() -> dict:
        """列出待处理批次与送标状态。"""
        # English gloss of the description: list pending batches and their
        # labeling-submission status.
        return list_pending_batches()

    @tool
    def t_get_dataset_catalog() -> dict:
        """获取 DMS/Lane 数据目录统计。"""
        # English gloss of the description: get DMS/Lane dataset catalog
        # statistics.
        return get_dataset_catalog()

    return [t_list_pending_batches, t_get_dataset_catalog]