From 071a25168a234da8ab0bf7acc4a849ad50a70d20 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 4 May 2026 12:01:52 -0500 Subject: [PATCH 001/119] Add pluggable execution backend for Parsl, EnsembleLauncher, and Globus Compute Introduce a unified execution module with an abstract ExecutionBackend interface and TaskSpec model, supporting four backends: local (ProcessPoolExecutor), Parsl, EnsembleLauncher, and Globus Compute. Includes config factory with resolution order (args > env > config.toml), HPC configs loader, comprehensive tests, and pytest --run-globus-compute option for live endpoint tests. --- config.toml | 18 + pyproject.toml | 7 + src/chemgraph/execution/__init__.py | 33 + src/chemgraph/execution/base.py | 144 +++ src/chemgraph/execution/config.py | 163 +++ .../execution/ensemble_launcher_backend.py | 199 ++++ .../execution/globus_compute_backend.py | 131 +++ src/chemgraph/execution/local_backend.py | 119 ++ src/chemgraph/execution/parsl_backend.py | 122 ++ src/chemgraph/execution/utils.py | 175 +++ src/chemgraph/hpc_configs/__init__.py | 1 + src/chemgraph/hpc_configs/loader.py | 65 ++ src/chemgraph/hpc_configs/local_parsl.py | 60 + src/chemgraph/mcp/graspa_mcp_hpc.py | 124 ++ src/chemgraph/mcp/mace_mcp_hpc.py | 178 +++ src/chemgraph/mcp/xanes_mcp_hpc.py | 227 ++++ tests/conftest.py | 22 +- tests/test_execution.py | 1017 +++++++++++++++++ 18 files changed, 2799 insertions(+), 6 deletions(-) create mode 100644 src/chemgraph/execution/__init__.py create mode 100644 src/chemgraph/execution/base.py create mode 100644 src/chemgraph/execution/config.py create mode 100644 src/chemgraph/execution/ensemble_launcher_backend.py create mode 100644 src/chemgraph/execution/globus_compute_backend.py create mode 100644 src/chemgraph/execution/local_backend.py create mode 100644 src/chemgraph/execution/parsl_backend.py create mode 100644 src/chemgraph/execution/utils.py create mode 100644 src/chemgraph/hpc_configs/__init__.py create mode 100644 
src/chemgraph/hpc_configs/loader.py create mode 100644 src/chemgraph/hpc_configs/local_parsl.py create mode 100644 src/chemgraph/mcp/graspa_mcp_hpc.py create mode 100644 src/chemgraph/mcp/mace_mcp_hpc.py create mode 100644 src/chemgraph/mcp/xanes_mcp_hpc.py create mode 100644 tests/test_execution.py diff --git a/config.toml b/config.toml index 49319a6f..6295bc7d 100644 --- a/config.toml +++ b/config.toml @@ -91,6 +91,24 @@ enable_function_calling = true enable_parallel = false num_workers = 2 +[execution] +backend = "local" +system = "local" + +[execution.local] +max_workers = 4 + +[execution.parsl] +worker_init = "export TMPDIR=/tmp" + +[execution.ensemble_launcher] +comm_name = "async_zmq" +task_executor_name = "async_processpool" +nlevels = 0 + +[execution.globus_compute] +endpoint_id = "" + [environments.development] model = "gpt-4o-mini" verbose = true diff --git a/pyproject.toml b/pyproject.toml index ee72fd8e..692b0388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,12 @@ ui = [ parsl = [ "parsl", ] +ensemble_launcher = [ + "ensemble-launcher", +] +globus_compute = [ + "globus-compute-sdk", +] xanes = [ "mp-api; python_version >= '3.11'", "parsl" @@ -104,6 +110,7 @@ skip-magic-trailing-comma = false # Ensure Black-style formatting [tool.pytest.ini_options] markers = [ "llm: marks tests as requiring LLM API access (run with --run-llm)", + "globus_compute: marks tests requiring a live Globus Compute endpoint (run with --run-globus-compute)", "asyncio: marks async tests", ] filterwarnings = [ diff --git a/src/chemgraph/execution/__init__.py b/src/chemgraph/execution/__init__.py new file mode 100644 index 00000000..0fd6709b --- /dev/null +++ b/src/chemgraph/execution/__init__.py @@ -0,0 +1,33 @@ +"""Pluggable execution backends for ChemGraph HPC workloads. + +This package provides a backend-agnostic interface for submitting +computational tasks to different workflow managers (Parsl, +EnsembleLauncher, Globus Compute, local process pool). 
+ +Quick start +----------- +>>> from chemgraph.execution import get_backend, TaskSpec +>>> backend = get_backend() # reads config.toml / env vars +>>> future = backend.submit(TaskSpec( +... task_id="test-1", +... task_type="python", +... callable=my_function, +... kwargs={"param": 42}, +... )) +>>> result = future.result() +>>> backend.shutdown() + +See Also +-------- +:mod:`chemgraph.execution.base` -- abstract classes +:mod:`chemgraph.execution.config` -- factory function +""" + +from chemgraph.execution.base import ExecutionBackend, TaskSpec +from chemgraph.execution.config import get_backend + +__all__ = [ + "ExecutionBackend", + "TaskSpec", + "get_backend", +] diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py new file mode 100644 index 00000000..e7dc338b --- /dev/null +++ b/src/chemgraph/execution/base.py @@ -0,0 +1,144 @@ +"""Abstract base classes for execution backends. + +This module defines the ``ExecutionBackend`` protocol and the ``TaskSpec`` +data model that all workflow managers (Parsl, EnsembleLauncher, local +process pool, etc.) must implement. Downstream code (MCP servers, tools) +only depends on these abstractions -- never on a concrete backend. +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from concurrent.futures import Future +from typing import Any, Callable, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field + +logger = logging.getLogger(__name__) + + +class TaskSpec(BaseModel): + """Specification for a single unit of work to submit to a backend. + + Supports two execution modes: + + * **python** -- run a Python callable (``callable(*args, **kwargs)``) + * **shell** -- run a shell command string + + Resource hints (``num_nodes``, ``processes_per_node``, ``gpus_per_task``) + are advisory; backends may ignore hints they do not support. 
class ExecutionBackend(ABC):
    """Common contract implemented by every workflow-manager adapter.

    A backend goes through three stages:

    1. ``initialize(system, **kwargs)`` -- bring the backend up
    2. ``submit(task)`` / ``submit_batch(tasks)`` -- hand over work
    3. ``shutdown()`` -- give back whatever the backend holds

    Instances also work as context managers: leaving the ``with`` block
    calls ``shutdown()``.  Entering the block does *not* call
    ``initialize()``.
    """

    def __init__(self) -> None:
        # Readiness flag; starts out False.
        self._initialized: bool = False

    @abstractmethod
    def initialize(self, system: str = "local", **kwargs: Any) -> None:
        """Get the backend ready to accept tasks.

        Parameters
        ----------
        system : str
            Name of the target machine (e.g. ``"polaris"``, ``"aurora"``,
            ``"local"``).  A backend may use it to pick a system-specific
            configuration.
        **kwargs
            Options understood by the concrete backend (worker_init,
            run_dir, ...).
        """

    @abstractmethod
    def submit(self, task: TaskSpec) -> Future:
        """Dispatch one task and return its ``concurrent.futures.Future``.

        The future resolves to whatever the callable/command returns.
        """

    def submit_batch(self, tasks: list[TaskSpec]) -> list[Future]:
        """Dispatch several tasks; futures come back in submission order.

        This base version calls ``submit()`` once per task.  Subclasses
        may replace it with a cheaper bulk operation.
        """
        futures: list[Future] = []
        for spec in tasks:
            futures.append(self.submit(spec))
        return futures

    @abstractmethod
    def shutdown(self) -> None:
        """Give back every resource the backend is holding."""

    # ── Context-manager protocol ────────────────────────────────────────

    def __enter__(self) -> ExecutionBackend:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # noqa: ANN001
        # Returns None, so an exception raised inside the ``with`` body
        # still propagates after the backend has been shut down.
        self.shutdown()
def _load_execution_config(config_path: Optional[str] = None) -> dict[str, Any]:
    """Read the ``[execution]`` table from ``config.toml``.

    Parameters
    ----------
    config_path : str, optional
        Explicit path to a TOML file.  When omitted, two locations are
        probed in order: ``config.toml`` in the current working directory,
        then ``config.toml`` in the repository root of a source checkout.

    Returns
    -------
    dict
        The ``[execution]`` table.  An empty dict is returned when no file
        is found, the file cannot be read or parsed, the section is
        missing, or the section is not a table -- so callers always get
        something they can call ``.get()`` on.
    """
    if config_path is None:
        # Only the CWD itself is checked (no upward walk).
        candidates = [Path.cwd() / "config.toml"]

        # In a source checkout this module is
        # ``<repo>/src/chemgraph/execution/config.py``, which makes
        # ``parents[3]`` the repository root.  A shallow install path may
        # not have that many ancestors, so guard the index.
        ancestors = Path(__file__).resolve().parents
        if len(ancestors) > 3:
            candidates.append(ancestors[3] / "config.toml")

        config_path = next((str(c) for c in candidates if c.is_file()), None)

    if config_path is None:
        return {}

    try:
        try:
            import toml
        except ImportError:
            # ``toml`` is a third-party package; fall back to the standard
            # library parser (Python 3.11+), which requires a binary file.
            import tomllib

            with open(config_path, "rb") as fh:
                full_config = tomllib.load(fh)
        else:
            full_config = toml.load(config_path)
    except Exception as exc:  # noqa: BLE001
        # Deliberately best-effort: a broken or unreadable config file must
        # not prevent falling back to defaults.
        logger.warning("Could not read [execution] from %s: %s", config_path, exc)
        return {}

    section = full_config.get("execution", {})
    if not isinstance(section, dict):
        # e.g. ``execution = "local"`` at top level instead of a table.
        logger.warning(
            "Ignoring 'execution' in %s: expected a table, got %s",
            config_path,
            type(section).__name__,
        )
        return {}
    return section
def get_backend(
    config_path: Optional[str] = None,
    backend_name: Optional[str] = None,
    system: Optional[str] = None,
    **kwargs: Any,
) -> ExecutionBackend:
    """Create and initialise an :class:`ExecutionBackend`.

    Resolution order for ``backend_name``:

    1. Explicit ``backend_name`` argument
    2. ``CHEMGRAPH_EXECUTION_BACKEND`` environment variable
    3. ``config.toml`` ``[execution] backend`` key
    4. ``"local"`` (safe fallback)

    Resolution order for ``system``:

    1. Explicit ``system`` argument
    2. ``COMPUTE_SYSTEM`` environment variable
    3. ``config.toml`` ``[execution] system`` key
    4. ``"local"``

    At every step an empty value (``None`` or ``""``) counts as "not set"
    and resolution moves on to the next source.

    Parameters
    ----------
    config_path : str, optional
        Path to ``config.toml``.  Auto-detected when omitted.
    backend_name : str, optional
        Force a specific backend.
    system : str, optional
        Target HPC system name.
    **kwargs
        Extra keyword arguments forwarded to
        :meth:`ExecutionBackend.initialize`.  They override keys of the
        same name from the ``[execution.<backend>]`` config table.

    Returns
    -------
    ExecutionBackend
        A ready-to-use backend instance.

    Raises
    ------
    ValueError
        If the resolved backend name is not supported, or the
        ``[execution.<backend>]`` config entry is not a table.
    Exception
        Whatever the backend's ``initialize()`` raises is re-raised after
        a best-effort ``shutdown()`` of the half-initialised backend.
    """
    cfg = _load_execution_config(config_path)

    # -- resolve backend name -------------------------------------------------
    resolved_backend = (
        backend_name
        or os.getenv("CHEMGRAPH_EXECUTION_BACKEND")
        or cfg.get("backend")
        or "local"  # also covers ``backend = ""`` in config.toml
    )
    resolved_backend = str(resolved_backend).lower().strip()

    if resolved_backend not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown execution backend '{resolved_backend}'. "
            f"Supported: {', '.join(SUPPORTED_BACKENDS)}"
        )

    # -- resolve system -------------------------------------------------------
    resolved_system = (
        system or os.getenv("COMPUTE_SYSTEM") or cfg.get("system") or "local"
    )

    # -- merge backend-specific config ----------------------------------------
    backend_cfg = cfg.get(resolved_backend, {})
    if not isinstance(backend_cfg, dict):
        raise ValueError(
            f"Config entry [execution.{resolved_backend}] must be a table, "
            f"got {type(backend_cfg).__name__}."
        )
    merged_kwargs = {**backend_cfg, **kwargs}  # explicit kwargs win

    # -- instantiate ----------------------------------------------------------
    logger.info(
        "Creating execution backend '%s' for system '%s'",
        resolved_backend,
        resolved_system,
    )

    # Backend modules are imported lazily so that optional dependencies
    # (parsl, ensemble-launcher, globus-compute-sdk) are only required for
    # the backend that is actually selected.
    if resolved_backend == "parsl":
        from chemgraph.execution.parsl_backend import ParslBackend

        backend = ParslBackend()

    elif resolved_backend == "ensemble_launcher":
        from chemgraph.execution.ensemble_launcher_backend import (
            EnsembleLauncherBackend,
        )

        backend = EnsembleLauncherBackend()

    elif resolved_backend == "globus_compute":
        from chemgraph.execution.globus_compute_backend import (
            GlobusComputeBackend,
        )

        backend = GlobusComputeBackend()

    elif resolved_backend == "local":
        from chemgraph.execution.local_backend import LocalBackend

        backend = LocalBackend()

    else:
        # Should be unreachable thanks to the validation above.
        raise ValueError(f"Unsupported backend: {resolved_backend}")

    try:
        backend.initialize(system=resolved_system, **merged_kwargs)
    except BaseException:
        # The caller never receives the instance on failure, so this is the
        # only place that can release anything initialize() already started.
        try:
            backend.shutdown()
        except Exception:  # noqa: BLE001
            logger.warning(
                "Error cleaning up backend '%s' after failed initialization.",
                resolved_backend,
                exc_info=True,
            )
        raise
    return backend
def initialize(self, system: str = "local", **kwargs: Any) -> None:
    """Start an EnsembleLauncher orchestrator and connect a client to it.

    Parameters
    ----------
    system : str
        Passed to ``SystemConfig`` as ``name``.
    **kwargs
        Recognised keys: ``comm_name``, ``task_executor_name``,
        ``nlevels``, ``max_workers``, ``checkpoint_dir``, ``nodes`` and
        ``startup_delay``.  Other keys are ignored.

    Raises
    ------
    ImportError
        If the ``ensemble_launcher`` package cannot be imported.
    Exception
        Any error raised while waiting for or connecting the client is
        re-raised after the already-started orchestrator has been stopped.
    """
    try:
        from ensemble_launcher import EnsembleLauncher
        from ensemble_launcher.config import LauncherConfig, SystemConfig
        from ensemble_launcher.orchestrator import ClusterClient
    except ImportError as exc:
        raise ImportError(
            "EnsembleLauncher is required for the EnsembleLauncherBackend. "
            "Install it with: pip install ensemble-launcher"
        ) from exc

    # -- extract parameters ------------------------------------------------
    comm_name = kwargs.get("comm_name", "async_zmq")
    task_executor = kwargs.get("task_executor_name", "async_processpool")
    nlevels = kwargs.get("nlevels", 0)
    ncpus = kwargs.get("max_workers", os.cpu_count() or 4)
    checkpoint_dir = kwargs.get(
        "checkpoint_dir",
        # Random suffix keeps concurrent backends in the same CWD apart.
        os.path.join(os.getcwd(), f".el_ckpt_{uuid.uuid4().hex[:8]}"),
    )
    nodes = kwargs.get("nodes", [socket.gethostname()])
    startup_delay = kwargs.get("startup_delay", 2.0)

    self._checkpoint_dir = checkpoint_dir

    # -- configure ---------------------------------------------------------
    system_config = SystemConfig(
        name=system,
        ncpus=ncpus,
        cpus=list(range(ncpus)),
    )

    launcher_config = LauncherConfig(
        task_executor_name=task_executor,
        comm_name=comm_name,
        nlevels=nlevels,
        cluster=True,
        checkpoint_dir=checkpoint_dir,
    )

    # -- start orchestrator ------------------------------------------------
    launcher = EnsembleLauncher(
        ensemble_file={},
        system_config=system_config,
        launcher_config=launcher_config,
        Nodes=nodes,
    )
    launcher.start()

    # -- connect client ----------------------------------------------------
    try:
        # Fixed grace period for the orchestrator to come up.
        # NOTE(review): a readiness probe would be more robust if the
        # library offers one -- confirm against the EnsembleLauncher API.
        time.sleep(startup_delay)
        client = ClusterClient(checkpoint_dir=checkpoint_dir)
        client.start()
    except BaseException:
        # Do not leave a running orchestrator behind when the client
        # cannot be connected (or the wait is interrupted).
        try:
            launcher.stop()
        except Exception:  # noqa: BLE001
            logger.warning(
                "Error stopping EnsembleLauncher orchestrator.", exc_info=True
            )
            # Keep the handle so a later shutdown() can retry the stop.
            self._el = launcher
        raise

    self._el = launcher
    self._client = client
    self._initialized = True
    logger.info(
        "EnsembleLauncherBackend initialized (system='%s', "
        "comm='%s', executor='%s', nodes=%s)",
        system,
        comm_name,
        task_executor,
        nodes,
    )
def shutdown(self) -> None:
    """Tear down the client, then stop the orchestrator.

    Each step is attempted independently and failures are logged rather
    than raised.  A handle is only cleared once its teardown succeeded, so
    calling ``shutdown()`` again retries whatever failed.
    """
    self._initialized = False
    clean = True

    client = self._client
    if client is not None:
        try:
            client.teardown()
        except Exception:
            clean = False
            logger.warning(
                "Error tearing down EnsembleLauncher client.", exc_info=True
            )
        else:
            self._client = None

    orchestrator = self._el
    if orchestrator is not None:
        try:
            orchestrator.stop()
        except Exception:
            clean = False
            logger.warning(
                "Error stopping EnsembleLauncher orchestrator.", exc_info=True
            )
        else:
            self._el = None

    if not clean:
        logger.warning(
            "EnsembleLauncherBackend partially shut down. "
            "Call shutdown() again to retry failed teardown."
        )
        return
    logger.info("EnsembleLauncherBackend shut down.")
def initialize(self, system: str = "local", **kwargs: Any) -> None:
    """Create the Globus Compute ``Executor`` for the configured endpoint.

    ``system`` is only used in the log message.  ``endpoint_id`` must be
    supplied via ``kwargs``; ``amqp_port`` is optional and is converted
    with ``int()`` before being handed to the executor.

    Raises
    ------
    ImportError
        If ``globus_compute_sdk`` cannot be imported.
    ValueError
        If ``endpoint_id`` is missing or empty.
    """
    try:
        from globus_compute_sdk import Executor
    except ImportError as exc:
        raise ImportError(
            "globus-compute-sdk is required for the GlobusComputeBackend. "
            "Install it with: pip install chemgraphagent[globus_compute]"
        ) from exc

    endpoint = kwargs.get("endpoint_id")
    if not endpoint:
        raise ValueError(
            "GlobusComputeBackend requires an 'endpoint_id'. "
            "Set it in config.toml under [execution.globus_compute] "
            "or pass it directly to get_backend()."
        )

    port = kwargs.get("amqp_port")
    if port is None:
        # Leave the port out so the SDK applies its own default.
        options: dict[str, Any] = {"endpoint_id": endpoint}
    else:
        options = {"endpoint_id": endpoint, "amqp_port": int(port)}

    self._executor = Executor(**options)
    self._initialized = True
    logger.info(
        "GlobusComputeBackend initialized (system='%s', endpoint='%s')",
        system,
        endpoint,
    )
def _run_shell_task(
    command: str,
    working_dir: str | None,
    stdout_path: str | None,
    stderr_path: str | None,
) -> int:
    """Execute a shell command in a child process.

    Parameters
    ----------
    command : str
        Command line handed to the system shell.
    working_dir : str or None
        Directory the command runs in; ``None`` inherits the current one.
    stdout_path, stderr_path : str or None
        Files that receive the command's stdout / stderr.  A stream whose
        path is ``None`` (or empty) is inherited from this process.  When
        both paths name the same file it is opened once and stderr is
        merged into stdout, so the two streams cannot overwrite each other.

    Returns
    -------
    int
        The exit code.  Because the command runs with ``check=True`` this
        is always ``0``.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status.
    OSError
        If an output file cannot be opened.
    """
    import contextlib
    import os

    same_target = bool(
        stdout_path
        and stderr_path
        and os.path.abspath(stdout_path) == os.path.abspath(stderr_path)
    )

    # ExitStack closes exactly the files that were opened, and avoids the
    # parenthesised multi-item ``with`` form that needs Python 3.10+.
    with contextlib.ExitStack() as stack:
        stdout_fh = (
            stack.enter_context(open(stdout_path, "wb")) if stdout_path else None
        )
        if same_target:
            stderr_fh = subprocess.STDOUT
        elif stderr_path:
            stderr_fh = stack.enter_context(open(stderr_path, "wb"))
        else:
            stderr_fh = None

        # NOTE: shell=True is intentional -- a shell task *is* a shell
        # command line.  ``command`` must come from a trusted source.
        result = subprocess.run(
            command,
            shell=True,
            cwd=working_dir,
            stdout=stdout_fh,
            stderr=stderr_fh,
            check=True,
        )
    return result.returncode
def initialize(self, system: str = "polaris", **kwargs: Any) -> None:
    """Load the Parsl configuration for ``system`` and build the app wrappers.

    Parameters
    ----------
    system : str
        System name handed to ``load_parsl_config``.
        NOTE(review): the default here is ``"polaris"`` while the abstract
        base declares ``"local"`` -- confirm this difference is intended.
    **kwargs
        Only ``run_dir`` and ``worker_init`` are read; each is forwarded to
        ``load_parsl_config`` when it is not ``None``.  Any other key is
        ignored.

    Raises
    ------
    ImportError
        If ``parsl`` cannot be imported.
    """
    try:
        import parsl
        from parsl import bash_app, python_app
    except ImportError as exc:
        raise ImportError(
            "Parsl is required for the ParslBackend. "
            "Install it with: pip install chemgraphagent[parsl]"
        ) from exc

    # Imported here rather than at module top, after the Parsl check above.
    from chemgraph.hpc_configs.loader import load_parsl_config

    # ``kwargs`` is this call's own dict, so pop() does not touch the
    # caller's mapping.
    run_dir = kwargs.pop("run_dir", None)
    worker_init = kwargs.pop("worker_init", None)

    # Build kwargs for the config loader; unset options are left out so the
    # loader's own defaults apply.
    loader_kwargs: dict[str, Any] = {}
    if run_dir is not None:
        loader_kwargs["run_dir"] = run_dir
    if worker_init is not None:
        loader_kwargs["worker_init"] = worker_init

    config = load_parsl_config(system, **loader_kwargs)
    parsl.load(config)

    # Create generic app wrappers ------------------------------------------
    # These are created once and reused for all submitted tasks.

    @python_app
    def _generic_python_app(fn, args, kwargs):
        """Execute an arbitrary callable on a Parsl worker."""
        # ``args`` / ``kwargs`` arrive as a plain tuple and dict and are
        # unpacked only here.
        return fn(*args, **kwargs)

    @bash_app
    def _generic_bash_app(command, stdout=None, stderr=None):
        """Execute a shell command string on a Parsl worker."""
        # The function only returns the command string; ``stdout`` and
        # ``stderr`` are accepted so submit() can pass redirection targets.
        return command

    self._python_app = _generic_python_app
    self._bash_app = _generic_bash_app

    self._initialized = True
    logger.info("ParslBackend initialized for system '%s'", system)
def shutdown(self) -> None:
    """Stop the Parsl executors and unload the DataFlowKernel.

    Does nothing when the backend is not initialized.  Errors during
    teardown are logged, not raised, and the backend is marked
    uninitialized either way.
    """
    if not self._initialized:
        return

    try:
        import parsl

        try:
            # ``parsl.clear()`` on its own only drops the reference to the
            # loaded DataFlowKernel; ``cleanup()`` is what actually shuts
            # the executors down.
            parsl.dfk().cleanup()
        finally:
            # Always unload, so a later ``parsl.load()`` is possible even
            # if cleanup failed.
            parsl.clear()
        logger.info("ParslBackend shut down.")
    except Exception:
        logger.warning("Error during Parsl shutdown.", exc_info=True)
    finally:
        # The app wrappers are bound to the kernel that was just unloaded.
        self._python_app = None
        self._bash_app = None
        self._initialized = False
+ """ + structure_files: list[Path] = [] + output_dir: Path = Path.cwd() + + if isinstance(input_source, list): + structure_files = [Path(p) for p in input_source] + missing = [p for p in structure_files if not p.exists()] + if missing: + raise ValueError(f"The following input files are missing: {missing}") + if structure_files: + output_dir = structure_files[0].parent + else: + input_dir = Path(input_source) + if not input_dir.is_dir(): + raise ValueError(f"'{input_dir}' is not a valid directory.") + + if extensions: + structure_files = sorted( + p for p in input_dir.iterdir() if p.suffix in extensions + ) + else: + structure_files = sorted(p for p in input_dir.iterdir() if p.is_file()) + + output_dir = input_dir + + if not structure_files: + raise ValueError("No structure files found to simulate.") + + return structure_files, output_dir + + +async def gather_futures( + pending: list[tuple[dict, Future]], + post_fn: Optional[Callable[[dict, Any], dict]] = None, +) -> list[dict]: + """Await a list of ``(metadata, future)`` pairs concurrently. + + Each future is converted to an asyncio-awaitable via + :func:`asyncio.wrap_future` and gathered concurrently. + + Parameters + ---------- + pending : list[tuple[dict, Future]] + Each element is ``(task_metadata_dict, concurrent_futures_Future)``. + post_fn : callable, optional + If provided, called as ``post_fn(metadata, result)`` after a + successful future resolution. Must return a ``dict`` to include + in the results list. When *None*, the raw result is merged with + metadata. + + Returns + ------- + list[dict] + One result dict per task (successful or failed). 
+ """ + + async def _wait(meta: dict, fut: Future) -> dict: + try: + result = await asyncio.wrap_future(fut) + if post_fn is not None: + return post_fn(meta, result) + # Default: merge metadata with result (if result is a dict) + if isinstance(result, dict): + merged = {**meta, **result} + merged.setdefault("status", "success") + return merged + return {**meta, "result": result, "status": "success"} + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": str(e), + } + + return list( + await asyncio.gather(*(_wait(meta, fut) for meta, fut in pending)) + ) + + +def write_results_jsonl( + results: list[dict], + output_path: Path, + append: bool = True, +) -> tuple[int, int]: + """Write results to a JSONL file and return (success_count, total_count). + + Parameters + ---------- + results : list[dict] + Each dict should contain a ``"status"`` key. + output_path : Path + Path to the JSONL file. + append : bool + If *True* (default), append to an existing file. + + Returns + ------- + success_count : int + total_count : int + """ + mode = "a" if append else "w" + success_count = 0 + + with open(output_path, mode, encoding="utf-8") as f: + for res in results: + if res.get("status") == "success": + success_count += 1 + f.write(json.dumps(res) + "\n") + + return success_count, len(results) + + +def make_per_structure_output( + struct_path: Path, + base_output: Path, +) -> Path: + """Generate a per-structure output filename. + + Given ``struct_path = "/data/MOF-5.cif"`` and + ``base_output = "/results/output.json"``, returns + ``"/results/MOF-5_output.json"``. 
+ """ + base_suffix = base_output.suffix or ".json" + base_stem = base_output.stem + return base_output.with_name(f"{struct_path.stem}_{base_stem}{base_suffix}") diff --git a/src/chemgraph/hpc_configs/__init__.py b/src/chemgraph/hpc_configs/__init__.py new file mode 100644 index 00000000..32d8bc92 --- /dev/null +++ b/src/chemgraph/hpc_configs/__init__.py @@ -0,0 +1 @@ +"""HPC configuration factories for workflow managers.""" diff --git a/src/chemgraph/hpc_configs/loader.py b/src/chemgraph/hpc_configs/loader.py new file mode 100644 index 00000000..4a25d5e7 --- /dev/null +++ b/src/chemgraph/hpc_configs/loader.py @@ -0,0 +1,65 @@ +"""Unified loader for HPC-specific Parsl configurations. + +This consolidates the ``load_parsl_config()`` function that was +previously duplicated across ``graspa_mcp_parsl.py`` and +``xanes_mcp_parsl.py``. +""" + +from __future__ import annotations + +import logging +import os + +logger = logging.getLogger(__name__) + + +def load_parsl_config(system_name: str, run_dir: str | None = None, **kwargs): + """Dynamically import and return a Parsl ``Config`` for the given HPC system. + + Parameters + ---------- + system_name : str + Target system name. Supported: ``"local"``, ``"polaris"``, + ``"aurora"``. + run_dir : str, optional + Parsl run directory. Defaults to the current working directory. + **kwargs + Extra keyword arguments forwarded to the system-specific + config factory (e.g. ``worker_init``, ``max_workers``). + + Returns + ------- + parsl.config.Config + A ready-to-use Parsl configuration object. + + Raises + ------ + ValueError + If *system_name* is not recognised. 
import logging
import os
from typing import Optional

from parsl.config import Config
from parsl.executors import HighThroughputExecutor
from parsl.providers import LocalProvider

logger = logging.getLogger(__name__)

_DEFAULT_MAX_WORKERS = 4


def get_local_config(
    run_dir: Optional[str] = None,
    max_workers: int = _DEFAULT_MAX_WORKERS,
    worker_init: str = "export TMPDIR=/tmp",
) -> Config:
    """Build a Parsl ``Config`` for execution on the local machine.

    The configuration holds a single ``HighThroughputExecutor`` labelled
    ``"local_htex"`` backed by a ``LocalProvider`` that starts with one
    block and is capped at one block.

    Parameters
    ----------
    run_dir : str, optional
        Parsl run directory.  Defaults to the current working directory.
    max_workers : int, optional
        Passed to the executor as ``max_workers_per_node``.  Default: 4.
    worker_init : str, optional
        Passed to the provider as ``worker_init``.

    Returns
    -------
    parsl.config.Config
        The assembled configuration.
    """
    resolved_run_dir = os.getcwd() if run_dir is None else run_dir

    logger.info("Creating local Parsl config with %d workers", max_workers)

    local_provider = LocalProvider(
        init_blocks=1,
        min_blocks=0,
        max_blocks=1,
        worker_init=worker_init,
    )
    htex = HighThroughputExecutor(
        label="local_htex",
        max_workers_per_node=max_workers,
        provider=local_provider,
    )
    return Config(executors=[htex], run_dir=resolved_run_dir)
def _run_graspa_single(job: dict) -> dict:
    """Execute a single gRASPA simulation (runs on the worker).

    Parameters
    ----------
    job : dict
        Keyword arguments for ``graspa_input_schema``.  A value that is
        not a ``dict`` is passed to ``run_graspa_core`` unchanged.

    Returns
    -------
    dict
        Whatever ``run_graspa_core`` returns for the job.
    """
    # Imported inside the function so the submitted callable resolves its
    # own dependencies wherever the backend executes it.
    from chemgraph.schemas.graspa_schema import graspa_input_schema
    from chemgraph.tools.graspa_tools import run_graspa_core

    params = graspa_input_schema(**job) if isinstance(job, dict) else job
    return run_graspa_core(params)


@mcp.tool(
    name="run_graspa_ensemble",
    description="Run an ensemble of graspa calculations for multiple input files.",
)
async def run_graspa_ensemble(
    params: graspa_input_schema_ensemble,
):
    """Run an ensemble of gRASPA calculations over all structure files
    using the configured execution backend.

    One task is submitted per (structure, condition) pair.  When more than
    one condition is requested, each pair gets its own output file
    (``<structure>_<base>_<T>K_<P>Pa<suffix>``) so that conditions of the
    same structure do not share a single ``output_result_file``.  With a
    single condition the name is ``<structure>_<base><suffix>``.

    Parameters
    ----------
    params : graspa_input_schema_ensemble
        Input parameters for the ensemble of gRASPA calculations.

    Returns
    -------
    str
        Human-readable summary with task counts and the path of the JSONL
        log written next to the input structures.
    """
    structure_files, output_dir = resolve_structure_files(
        params.input_structures,
        extensions={".cif"},
    )

    # Base output file name, used as the pattern for per-task outputs.
    base_output = Path(params.output_result_file).resolve()

    conditions = list(params.conditions)
    disambiguate = len(conditions) > 1

    pending_tasks = []

    for struct_path in structure_files:
        mof_name = struct_path.stem
        struct_output = make_per_structure_output(struct_path, base_output)

        for condition in conditions:
            tag = f"{condition.temperature}K_{condition.pressure}Pa"
            if disambiguate:
                job_output = struct_output.with_name(
                    f"{struct_output.stem}_{tag}{struct_output.suffix}"
                )
            else:
                job_output = struct_output

            job = {
                "input_structure_file": str(struct_path),
                "output_result_file": str(job_output),
                "temperature": condition.temperature,
                "pressure": condition.pressure,
                "adsorbate": params.adsorbate,
                "n_cycles": params.n_cycles,
            }

            task = TaskSpec(
                task_id=f"graspa_{mof_name}_{tag}",
                task_type="python",
                callable=_run_graspa_single,
                kwargs={"job": job},
            )
            fut = backend.submit(task)

            task_meta = {
                "structure": mof_name,
                "temperature": condition.temperature,
                "pressure": condition.pressure,
                # Recorded so each log line points at its own result file.
                "output_result_file": str(job_output),
            }
            pending_tasks.append((task_meta, fut))

    results = await gather_futures(pending_tasks)

    summary_log_path = output_dir / "simulation_results.jsonl"
    success_count, total_count = write_results_jsonl(results, summary_log_path)

    return (
        f"Ensemble execution completed. Ran {total_count} tasks "
        f"({success_count} successful). "
        f"Detailed results appended to '{summary_log_path}'."
    )


if __name__ == "__main__":
    run_mcp_server(mcp, default_port=9001)
+ +Key improvements over the original: +- No hardcoded Polaris config or user-specific conda paths. +- Ensemble tool is now async (non-blocking event loop). +- Uses shared utilities for structure resolution and result gathering. +""" + +import json +import logging +from pathlib import Path + +from mcp.server.fastmcp import FastMCP + +from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.utils import ( + gather_futures, + make_per_structure_output, + resolve_structure_files, +) +from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.tools.parsl_tools import ( + mace_input_schema, + mace_input_schema_ensemble, + run_mace_core, +) + +logger = logging.getLogger(__name__) + +# ── Initialise execution backend ──────────────────────────────────────── +backend = get_backend() + +# ── MCP server ────────────────────────────────────────────────────────── +mcp = FastMCP( + name="ChemGraph MACE Tools", + instructions=""" + You expose tools for running MACE simulations and reading their results. + The available tools are: + 1. run_mace_single: run a single MACE calculation using the specified + input schema. + 2. run_mace_ensemble: run MACE calculations over all structures in a + directory using the configured execution backend. + 3. extract_output_json: load simulation results from a JSON file. + + Guidelines: + - Use each tool only when its input schema matches the user request. + - Do not guess numerical values; report tool errors exactly as they occur. + - Keep responses compact -- full results are written to the output files + defined in the schemas. + - When returning paths, use absolute paths. + - Energies are in eV and wall times are in seconds. 
def _run_mace_single(job: dict) -> dict:
    """Execute a single MACE simulation (runs on the worker).

    Parameters
    ----------
    job : dict
        Keyword arguments for ``mace_input_schema``.  A value that is not
        a ``dict`` is passed to ``run_mace_core`` unchanged.
    """
    # Imported inside the function so the submitted callable resolves its
    # own dependencies wherever the backend executes it.
    from chemgraph.tools.parsl_tools import mace_input_schema, run_mace_core

    params = mace_input_schema(**job) if isinstance(job, dict) else job
    return run_mace_core(params)


@mcp.tool(
    name="run_mace_single",
    description="Run a single MACE calculation",
)
def run_mace_single(params: mace_input_schema):
    """Run one MACE calculation in-process and return the core result."""
    return run_mace_core(params)


def _mace_post_fn(meta: dict, result) -> dict:
    """Post-process a completed MACE task into a compact per-job record.

    For a ``dict`` result the ``"status"`` (default ``"unknown"``) and
    ``"single_point_energy"`` entries are lifted to the top level; any
    other result type is reported as ``"success"`` with no energy.  The
    untouched result is kept under ``"raw_result"``.
    """
    if isinstance(result, dict):
        status = result.get("status", "unknown")
        energy = result.get("single_point_energy")
    else:
        status = "success"
        energy = None
    return {
        "structure": meta["structure"],
        "output_result_file": meta["output_result_file"],
        "status": status,
        "single_point_energy": energy,
        "raw_result": result,
    }


@mcp.tool(
    name="run_mace_ensemble",
    description="Run an ensemble of MACE calculations",
)
async def run_mace_ensemble(params: mace_input_schema_ensemble):
    """Run an ensemble of MACE calculations over all structure files in a
    directory using the configured execution backend.

    Parameters
    ----------
    params : mace_input_schema_ensemble
        Input parameters for the ensemble of MACE calculations.

    Returns
    -------
    dict
        ``status`` (``"success"`` when no job reported ``"failure"``,
        ``"failure"`` when every job did, ``"partial"`` otherwise),
        ``n_structures``, ``n_failed`` and the per-job ``results``.
    """
    structure_files, _output_dir = resolve_structure_files(
        params.input_structure_directory,
    )

    # Base output file name used as a pattern for per-structure outputs.
    # Resolved up front so the path does not depend on the working
    # directory of whichever process ends up running the job.
    base_output = Path(params.output_result_file).resolve()

    pending_tasks = []
    for struct_path in structure_files:
        per_struct_output = make_per_structure_output(struct_path, base_output)

        job = {
            "input_structure_file": str(struct_path),
            "output_result_file": str(per_struct_output),
            "driver": params.driver,
            "model": params.model,
            "device": params.device,
            "temperature": params.temperature,
            "pressure": params.pressure,
            "fmax": params.fmax,
            "steps": params.steps,
            "optimizer": params.optimizer,
        }

        task = TaskSpec(
            task_id=f"mace_{struct_path.stem}",
            task_type="python",
            callable=_run_mace_single,
            kwargs={"job": job},
        )
        fut = backend.submit(task)

        task_meta = {
            "structure": struct_path.name,
            "output_result_file": str(per_struct_output),
        }
        pending_tasks.append((task_meta, fut))

    results = await gather_futures(pending_tasks, post_fn=_mace_post_fn)

    # Only "failure" is counted: that is the status gather_futures assigns
    # to raised/cancelled tasks.  Other statuses come from the MACE core
    # and are passed through without being judged here.
    n_failed = sum(1 for res in results if res.get("status") == "failure")
    if n_failed == 0:
        overall = "success"
    elif n_failed == len(results):
        overall = "failure"
    else:
        overall = "partial"

    return {
        "status": overall,
        "n_structures": len(structure_files),
        "n_failed": n_failed,
        "results": results,
    }


@mcp.tool(
    name="extract_output_json",
    description="Load output from a JSON file.",
)
def extract_output_json(json_file: str) -> dict:
    """Load simulation results from a JSON file produced by run_ase.

    Parameters
    ----------
    json_file : str
        Path to the JSON file containing ASE simulation results.

    Returns
    -------
    dict
        Parsed results from the JSON file.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(json_file, "r", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    run_mcp_server(mcp, default_port=9004)
+ +Replaces ``xanes_mcp_parsl.py`` by using the :mod:`chemgraph.execution` +abstraction layer. The execution backend (Parsl, EnsembleLauncher, +local) is selected at startup via ``config.toml`` or the +``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +""" + +import logging +from pathlib import Path + +from mcp.server.fastmcp import FastMCP + +from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.utils import ( + gather_futures, + resolve_structure_files, + write_results_jsonl, +) +from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.schemas.xanes_schema import ( + mp_query_schema, + xanes_input_schema, + xanes_input_schema_ensemble, +) + +logger = logging.getLogger(__name__) + +# ── Initialise execution backend ──────────────────────────────────────── +backend = get_backend() + +# ── MCP server ────────────────────────────────────────────────────────── +mcp = FastMCP( + name="ChemGraph XANES Tools", + instructions=""" + You expose tools for running XANES/FDMNES simulations. + The available tools are: + 1. run_xanes_single: run a single FDMNES calculation for one structure. + 2. run_xanes_ensemble: run FDMNES calculations over multiple structures + using the configured execution backend. + 3. fetch_mp_structures: fetch optimized structures from Materials Project. + 4. plot_xanes: generate normalized XANES plots for completed calculations. + + Guidelines: + - Use each tool only when its input schema matches the user request. + - Do not guess numerical values; report tool errors exactly as they occur. + - Keep responses compact -- full results are in the output directories. + - When returning paths, use absolute paths. + - Energies are in eV. 
@mcp.tool(
    name="run_xanes_single",
    description="Run a single XANES/FDMNES calculation for one input structure.",
)
def run_xanes_single(params: xanes_input_schema):
    """Run a single FDMNES calculation using the core engine."""
    from chemgraph.tools.xanes_tools import run_xanes_core

    return run_xanes_core(params)


def _xanes_post_fn(meta: dict, _result) -> dict:
    """Post-process a completed FDMNES task: extract convergence data.

    The task result itself is ignored; success is decided by whether
    ``extract_conv`` can read ``meta["run_dir"]``.  The returned record is
    the metadata plus ``status`` and, on success, ``n_conv_files``.
    """
    from chemgraph.tools.xanes_tools import extract_conv

    try:
        conv_data = extract_conv(meta["run_dir"])
        return {
            **meta,
            "status": "success",
            "n_conv_files": len(conv_data),
        }
    except Exception as e:
        return {
            **meta,
            "status": "failure",
            "error_type": type(e).__name__,
            "message": f"Post-processing failed: {e}",
        }


@mcp.tool(
    name="run_xanes_ensemble",
    description="Run an ensemble of XANES/FDMNES calculations using the configured backend.",
)
async def run_xanes_ensemble(params: xanes_input_schema_ensemble):
    """Run ensemble XANES calculations over all structure files.

    For each structure file:
    1. Reads the structure via ASE.
    2. Creates FDMNES input files in a per-structure subdirectory.
    3. Submits a shell task to run FDMNES.
    4. Gathers results and writes a JSONL summary log.

    A structure whose inputs cannot be prepared is recorded as a failure
    in the summary log and does not prevent the remaining structures from
    being submitted.

    Parameters
    ----------
    params : xanes_input_schema_ensemble
        Input parameters for the ensemble calculation.

    Returns
    -------
    str
        Human-readable summary with task counts and the JSONL log path.
    """
    import shlex

    from ase.io import read as ase_read

    from chemgraph.tools.xanes_tools import write_fdmnes_input

    structure_files, output_dir = resolve_structure_files(
        params.input_structures,
        extensions={".cif", ".xyz", ".poscar"},
    )

    # Create a batch runs directory
    runs_dir = output_dir / "fdmnes_batch_runs"
    runs_dir.mkdir(parents=True, exist_ok=True)

    # The executable path arrives as tool input and ends up in a shell
    # command line, so it is quoted as one literal word.  NOTE: this means
    # shell expansions inside the value (e.g. "$HOME") are no longer
    # performed; pass an already-expanded path.
    quoted_exe = shlex.quote(str(params.fdmnes_exe))

    pending_tasks = []
    prep_failures = []

    for i, struct_path in enumerate(structure_files):
        run_dir = runs_dir / f"run_{i}"
        run_dir.mkdir(parents=True, exist_ok=True)

        task_meta = {
            "structure": struct_path.name,
            "run_dir": str(run_dir),
        }

        # Read structure and write FDMNES inputs.  A bad file is reported
        # for that structure only; tasks already submitted keep running
        # and are still gathered below.
        try:
            atoms = ase_read(str(struct_path))
            z_abs = (
                params.z_absorber
                if params.z_absorber is not None
                else int(max(atoms.get_atomic_numbers()))
            )
            write_fdmnes_input(
                ase_atoms=atoms,
                z_absorber=z_abs,
                input_file_dir=run_dir,
                radius=params.radius,
                magnetism=params.magnetism,
            )
        except Exception as exc:
            logger.warning(
                "Skipping '%s': input preparation failed", struct_path, exc_info=True
            )
            prep_failures.append(
                {
                    **task_meta,
                    "status": "failure",
                    "error_type": type(exc).__name__,
                    "message": f"Input preparation failed: {exc}",
                }
            )
            continue

        task_meta["z_absorber"] = z_abs

        # Submit shell task
        task = TaskSpec(
            task_id=f"xanes_{struct_path.stem}_{i}",
            task_type="shell",
            command=f"cd {shlex.quote(str(run_dir))} && {quoted_exe}",
            working_dir=str(run_dir),
            stdout=str(run_dir / "fdmnes_stdout.txt"),
            stderr=str(run_dir / "fdmnes_stderr.txt"),
        )
        fut = backend.submit(task)
        pending_tasks.append((task_meta, fut))

    results = prep_failures + await gather_futures(
        pending_tasks, post_fn=_xanes_post_fn
    )

    summary_log_path = output_dir / "xanes_results.jsonl"
    success_count, total_count = write_results_jsonl(results, summary_log_path)

    return (
        f"Ensemble execution completed. Ran {total_count} tasks "
        f"({success_count} successful). "
        f"Detailed results appended to '{summary_log_path}'."
    )


@mcp.tool(
    name="fetch_mp_structures",
    description="Fetch optimized structures from Materials Project.",
)
def fetch_mp_structures(params: mp_query_schema):
    """Fetch structures from Materials Project and save as CIF files and pickle database.

    Returns
    -------
    dict
        ``status``, the queried ``chemsys``, the data directory used as
        ``output_dir``, and the ``n_structures``, ``structure_files`` and
        ``pickle_file`` entries reported by
        ``fetch_materials_project_data``.
    """
    from chemgraph.tools.xanes_tools import (
        _get_data_dir,
        fetch_materials_project_data,
    )

    data_dir = _get_data_dir()
    result = fetch_materials_project_data(params, data_dir)
    return {
        "status": "success",
        "n_structures": result["n_structures"],
        "chemsys": params.chemsys,
        "output_dir": str(data_dir),
        "structure_files": result["structure_files"],
        "pickle_file": result["pickle_file"],
    }


@mcp.tool(
    name="plot_xanes",
    description="Generate normalized XANES plots for completed FDMNES calculations.",
)
def plot_xanes(runs_dir: str):
    """Generate XANES plots for all completed runs in a directory.

    Parameters
    ----------
    runs_dir : str
        Path to the ``fdmnes_batch_runs`` directory containing ``run_*``
        subdirectories with FDMNES outputs.

    Returns
    -------
    dict
        ``status`` plus the ``n_plots``, ``n_failed``, ``plot_files`` and
        ``failed`` entries reported by ``plot_xanes_results``.

    Raises
    ------
    ValueError
        If *runs_dir* is not an existing directory.
    """
    from chemgraph.tools.xanes_tools import (
        _get_data_dir,
        plot_xanes_results,
    )

    runs_path = Path(runs_dir)
    if not runs_path.is_dir():
        raise ValueError(f"'{runs_dir}' is not a valid directory.")

    data_dir = _get_data_dir()
    result = plot_xanes_results(data_dir, runs_path)
    return {
        "status": "success",
        "n_plots": result["n_plots"],
        "n_failed": result["n_failed"],
        "plot_files": result["plot_files"],
        "failed": result["failed"],
    }


if __name__ == "__main__":
    run_mcp_server(mcp, default_port=9007)
+ +Tests cover: +- TaskSpec validation +- LocalBackend: python and shell tasks +- GlobusComputeBackend: python and shell tasks (mocked SDK) +- Backend factory (get_backend) +- Shared utilities: resolve_structure_files, gather_futures, write_results_jsonl +""" + +import asyncio +import json +import os +import sys +import tempfile +from concurrent.futures import Future +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from chemgraph.execution.base import ExecutionBackend, TaskSpec +from chemgraph.execution.local_backend import LocalBackend +from chemgraph.execution.utils import ( + gather_futures, + make_per_structure_output, + resolve_structure_files, + write_results_jsonl, +) + + +# ── TaskSpec tests ────────────────────────────────────────────────────── + + +class TestTaskSpec: + def test_python_task_minimal(self): + spec = TaskSpec(task_id="t1", task_type="python", callable=abs, args=(42,)) + assert spec.task_id == "t1" + assert spec.task_type == "python" + assert spec.callable is abs + assert spec.args == (42,) + + def test_shell_task_minimal(self): + spec = TaskSpec(task_id="t2", task_type="shell", command="echo hello") + assert spec.task_type == "shell" + assert spec.command == "echo hello" + + def test_defaults(self): + spec = TaskSpec(task_id="t3") + assert spec.task_type == "python" + assert spec.callable is None + assert spec.args == () + assert spec.kwargs == {} + assert spec.num_nodes == 1 + assert spec.processes_per_node == 1 + assert spec.gpus_per_task == 0 + + +# ── LocalBackend tests ────────────────────────────────────────────────── + + +def _square(x): + return x * x + + +def _add(a, b): + return a + b + + +def _failing_fn(): + raise ValueError("intentional test error") + + +class TestLocalBackend: + def test_python_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=2) + try: + task = TaskSpec( + task_id="sq", + task_type="python", + callable=_square, + args=(7,), + ) + 
fut = backend.submit(task) + assert isinstance(fut, Future) + assert fut.result(timeout=10) == 49 + finally: + backend.shutdown() + + def test_python_task_kwargs(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=2) + try: + task = TaskSpec( + task_id="add", + task_type="python", + callable=_add, + kwargs={"a": 3, "b": 5}, + ) + assert backend.submit(task).result(timeout=10) == 8 + finally: + backend.shutdown() + + def test_shell_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False + ) as f: + stdout_path = f.name + + task = TaskSpec( + task_id="echo", + task_type="shell", + command="echo hello_world", + stdout=stdout_path, + ) + fut = backend.submit(task) + fut.result(timeout=10) + + with open(stdout_path) as f: + assert "hello_world" in f.read() + finally: + backend.shutdown() + os.unlink(stdout_path) + + def test_submit_batch(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=4) + try: + tasks = [ + TaskSpec( + task_id=f"sq_{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(5) + ] + futures = backend.submit_batch(tasks) + assert len(futures) == 5 + results = [f.result(timeout=10) for f in futures] + assert results == [0, 1, 4, 9, 16] + finally: + backend.shutdown() + + def test_failing_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec( + task_id="fail", + task_type="python", + callable=_failing_fn, + ) + fut = backend.submit(task) + with pytest.raises(ValueError, match="intentional test error"): + fut.result(timeout=10) + finally: + backend.shutdown() + + def test_context_manager(self): + with LocalBackend() as backend: + backend.initialize(system="local", max_workers=1) + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + assert 
backend.submit(task).result(timeout=10) == 9 + + def test_not_initialized_raises(self): + backend = LocalBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_python_task_missing_callable(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + finally: + backend.shutdown() + + def test_shell_task_missing_command(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec(task_id="no_cmd", task_type="shell") + with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + finally: + backend.shutdown() + + +# ── GlobusComputeBackend tests ────────────────────────────────────────── + + +def _make_mock_gc_modules(): + """Create mock globus_compute_sdk module and its classes.""" + mock_sdk = MagicMock() + + # Mock Executor: instances track submit calls and return Futures + mock_executor_instance = MagicMock() + mock_future = Future() + mock_future.set_result(42) + mock_executor_instance.submit.return_value = mock_future + mock_sdk.Executor.return_value = mock_executor_instance + + # Mock ShellFunction + mock_shell_fn_instance = MagicMock() + mock_sdk.ShellFunction.return_value = mock_shell_fn_instance + + return mock_sdk, mock_executor_instance + + +class TestGlobusComputeBackend: + def _patch_and_import(self, mock_sdk): + """Patch globus_compute_sdk into sys.modules and import the backend.""" + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + # Force re-import to pick up the mock + import importlib + + import chemgraph.execution.globus_compute_backend as gc_mod + + importlib.reload(gc_mod) + return gc_mod.GlobusComputeBackend + + def test_initialize_success(self): + mock_sdk, mock_executor = 
_make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(system="polaris", endpoint_id="test-uuid-1234") + + assert backend._initialized is True + mock_sdk.Executor.assert_called_once_with(endpoint_id="test-uuid-1234") + + def test_initialize_with_amqp_port(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize( + system="polaris", + endpoint_id="test-uuid", + amqp_port=443, + ) + + mock_sdk.Executor.assert_called_once_with( + endpoint_id="test-uuid", amqp_port=443 + ) + + def test_initialize_missing_endpoint_id(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ValueError, match="endpoint_id"): + backend.initialize(system="polaris") + + def test_initialize_empty_endpoint_id(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ValueError, match="endpoint_id"): + backend.initialize(system="polaris", endpoint_id="") + + def test_initialize_import_error(self): + """Verify helpful error when globus-compute-sdk is not installed.""" + with patch.dict(sys.modules, {"globus_compute_sdk": None}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ImportError, match="globus-compute-sdk"): + 
backend.initialize(endpoint_id="test-uuid") + + def test_submit_python_task(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="py1", + task_type="python", + callable=_square, + args=(7,), + ) + fut = backend.submit(task) + + assert isinstance(fut, Future) + mock_executor.submit.assert_called_once_with(_square, 7) + + def test_submit_python_task_with_kwargs(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="py2", + task_type="python", + callable=_add, + args=(3,), + kwargs={"b": 5}, + ) + backend.submit(task) + + mock_executor.submit.assert_called_once_with(_add, 3, b=5) + + def test_submit_shell_task(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="sh1", + task_type="shell", + command="echo hello", + ) + backend.submit(task) + + # ShellFunction should be constructed with the command + mock_sdk.ShellFunction.assert_called_once_with("echo hello") + # And then submitted via the executor + shell_fn_instance = mock_sdk.ShellFunction.return_value + mock_executor.submit.assert_called_once_with(shell_fn_instance) + + def test_submit_not_initialized(self): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = 
GlobusComputeBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_submit_python_missing_callable(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + + def test_submit_shell_missing_command(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="no_cmd", task_type="shell") + with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + + def test_shutdown(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + assert backend._initialized is True + + backend.shutdown() + + assert backend._initialized is False + assert backend._executor is None + mock_executor.shutdown.assert_called_once() + + def test_shutdown_idempotent(self): + """Calling shutdown() when not initialized should not raise.""" + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.shutdown() # should be a no-op + assert backend._initialized is False + + def test_context_manager(self): + mock_sdk, mock_executor = 
_make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + with GlobusComputeBackend() as backend: + backend.initialize(endpoint_id="test-uuid") + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + backend.submit(task) + + # After exiting context, shutdown should have been called + mock_executor.shutdown.assert_called_once() + + +class TestGetBackendGlobusCompute: + def test_factory_creates_globus_compute_backend(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend( + backend_name="globus_compute", + endpoint_id="factory-test-uuid", + ) + try: + assert isinstance(backend, GlobusComputeBackend) + assert backend._initialized is True + finally: + backend.shutdown() + + def test_factory_via_env_var(self): + mock_sdk, _ = _make_mock_gc_modules() + with ( + patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}), + patch.dict( + os.environ, + {"CHEMGRAPH_EXECUTION_BACKEND": "globus_compute"}, + ), + ): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend(endpoint_id="env-test-uuid") + try: + assert isinstance(backend, GlobusComputeBackend) + finally: + backend.shutdown() + + +# ── Factory tests ─────────────────────────────────────────────────────── + + +class TestGetBackend: + def test_local_backend_via_env(self): + with patch.dict(os.environ, {"CHEMGRAPH_EXECUTION_BACKEND": "local"}): + from chemgraph.execution.config import get_backend + + backend = get_backend() + try: + assert isinstance(backend, LocalBackend) + finally: + backend.shutdown() + + def 
test_explicit_backend_name(self): + from chemgraph.execution.config import get_backend + + backend = get_backend(backend_name="local", max_workers=2) + try: + assert isinstance(backend, LocalBackend) + finally: + backend.shutdown() + + def test_unsupported_backend_raises(self): + from chemgraph.execution.config import get_backend + + with pytest.raises(ValueError, match="Unknown execution backend"): + get_backend(backend_name="nonexistent") + + +# ── Utility tests ─────────────────────────────────────────────────────── + + +class TestResolveStructureFiles: + def test_from_directory(self, tmp_path): + for name in ["a.cif", "b.cif", "c.txt"]: + (tmp_path / name).write_text("dummy") + + files, out_dir = resolve_structure_files(str(tmp_path), extensions={".cif"}) + assert len(files) == 2 + assert out_dir == tmp_path + assert all(f.suffix == ".cif" for f in files) + + def test_from_file_list(self, tmp_path): + paths = [] + for name in ["x.xyz", "y.xyz"]: + p = tmp_path / name + p.write_text("dummy") + paths.append(str(p)) + + files, out_dir = resolve_structure_files(paths) + assert len(files) == 2 + assert out_dir == tmp_path + + def test_missing_file_raises(self, tmp_path): + with pytest.raises(ValueError, match="missing"): + resolve_structure_files([str(tmp_path / "noexist.cif")]) + + def test_empty_dir_raises(self, tmp_path): + with pytest.raises(ValueError, match="No structure files"): + resolve_structure_files(str(tmp_path), extensions={".cif"}) + + def test_invalid_dir_raises(self): + with pytest.raises(ValueError, match="not a valid directory"): + resolve_structure_files("/nonexistent/path") + + +class TestMakePerStructureOutput: + def test_basic(self): + result = make_per_structure_output( + Path("/data/MOF-5.cif"), + Path("/results/output.json"), + ) + assert result == Path("/results/MOF-5_output.json") + + def test_no_suffix(self): + result = make_per_structure_output( + Path("/data/struct.xyz"), + Path("/results/result"), + ) + assert result == 
Path("/results/struct_result.json") + + +class TestGatherFutures: + @pytest.mark.asyncio + async def test_successful_futures(self): + loop = asyncio.get_event_loop() + + def _make_resolved(val): + f = Future() + f.set_result(val) + return f + + pending = [ + ({"name": "a"}, _make_resolved({"status": "success", "energy": -1.0})), + ({"name": "b"}, _make_resolved({"status": "success", "energy": -2.0})), + ] + results = await gather_futures(pending) + assert len(results) == 2 + assert results[0]["name"] == "a" + assert results[0]["energy"] == -1.0 + + @pytest.mark.asyncio + async def test_failed_future(self): + f = Future() + f.set_exception(RuntimeError("boom")) + + pending = [({"name": "fail"}, f)] + results = await gather_futures(pending) + assert results[0]["status"] == "failure" + assert results[0]["error_type"] == "RuntimeError" + assert "boom" in results[0]["message"] + + @pytest.mark.asyncio + async def test_with_post_fn(self): + f = Future() + f.set_result(42) + + def post(meta, result): + return {**meta, "doubled": result * 2, "status": "success"} + + results = await gather_futures([({"id": "x"}, f)], post_fn=post) + assert results[0]["doubled"] == 84 + + +class TestWriteResultsJsonl: + def test_write_and_count(self, tmp_path): + results = [ + {"status": "success", "value": 1}, + {"status": "failure", "error": "bad"}, + {"status": "success", "value": 2}, + ] + path = tmp_path / "results.jsonl" + success, total = write_results_jsonl(results, path) + assert success == 2 + assert total == 3 + + lines = path.read_text().strip().split("\n") + assert len(lines) == 3 + assert json.loads(lines[0])["value"] == 1 + + def test_append_mode(self, tmp_path): + path = tmp_path / "results.jsonl" + write_results_jsonl([{"status": "success"}], path) + write_results_jsonl([{"status": "success"}], path, append=True) + + lines = path.read_text().strip().split("\n") + assert len(lines) == 2 + + +# ── Layer 2: GlobusComputeBackend unit-test gap coverage ──────────────── + + +class 
TestGlobusComputeBackendGaps: + """Additional mocked tests covering gaps in the original test suite.""" + + def test_submit_unsupported_task_type(self): + """The else branch in submit() should raise for unknown task_type.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="bad_type", + task_type="python", + callable=_square, + args=(1,), + ) + # Bypass Pydantic validation to force an invalid task_type + object.__setattr__(task, "task_type", "mpi") + + with pytest.raises(ValueError, match="unsupported task_type"): + backend.submit(task) + + def test_submit_batch_delegates(self): + """submit_batch (inherited from base) should call submit() N times.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + tasks = [ + TaskSpec( + task_id=f"t{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(3) + ] + futures = backend.submit_batch(tasks) + + assert len(futures) == 3 + assert mock_executor.submit.call_count == 3 + + def test_amqp_port_string_coercion(self): + """amqp_port from config.toml arrives as a string; must be coerced to int.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid", amqp_port="443") + + mock_sdk.Executor.assert_called_once_with( + endpoint_id="test-uuid", amqp_port=443 + ) + + def 
test_shutdown_executor_raises(self): + """If executor.shutdown() raises, the error is swallowed and state resets.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + mock_executor.shutdown.side_effect = RuntimeError("connection lost") + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + # Should NOT raise + backend.shutdown() + + assert backend._initialized is False + assert backend._executor is None + + +class TestGetBackendGlobusComputeGaps: + """Additional factory tests for config merging and TOML-driven creation.""" + + def test_factory_kwargs_override_config(self, tmp_path): + """Explicit kwargs should override values from config.toml.""" + config_file = tmp_path / "config.toml" + config_file.write_text( + "[execution]\n" + 'backend = "globus_compute"\n\n' + "[execution.globus_compute]\n" + 'endpoint_id = "config-uuid"\n' + "amqp_port = 5671\n" + ) + + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend( + config_path=str(config_file), + endpoint_id="kwarg-uuid", + ) + try: + assert isinstance(backend, GlobusComputeBackend) + # kwarg-uuid should win over config-uuid; amqp_port from config + mock_sdk.Executor.assert_called_once_with( + endpoint_id="kwarg-uuid", + amqp_port=5671, + ) + finally: + backend.shutdown() + + def test_factory_config_toml_driven(self, tmp_path): + """get_backend() with only a config.toml path should work end-to-end.""" + config_file = tmp_path / "config.toml" + config_file.write_text( + "[execution]\n" + 'backend = "globus_compute"\n\n' + "[execution.globus_compute]\n" + 'endpoint_id = "toml-uuid"\n' + ) + + mock_sdk, _ = 
_make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend(config_path=str(config_file)) + try: + assert isinstance(backend, GlobusComputeBackend) + assert backend._initialized is True + mock_sdk.Executor.assert_called_once_with(endpoint_id="toml-uuid") + finally: + backend.shutdown() + + +# ── Layer 3: Globus Compute integration tests (real endpoint) ─────────── + + +@pytest.fixture +def globus_backend(): + """Provide an initialized GlobusComputeBackend connected to a real endpoint. + + Skips the test if GLOBUS_COMPUTE_ENDPOINT_ID is not set or the SDK is + not installed. + """ + endpoint_id = os.environ.get("GLOBUS_COMPUTE_ENDPOINT_ID") + if not endpoint_id: + pytest.skip("GLOBUS_COMPUTE_ENDPOINT_ID env var not set") + + try: + from chemgraph.execution.config import get_backend + except ImportError: + pytest.skip("chemgraph.execution not available") + + try: + backend = get_backend( + backend_name="globus_compute", endpoint_id=endpoint_id + ) + except ImportError: + pytest.skip("globus-compute-sdk not installed") + + yield backend + backend.shutdown() + + +def _gc_double(x): + """Trivial function for Globus Compute integration tests.""" + return x * 2 + + +def _gc_square(x): + """Square function for Globus Compute integration tests.""" + return x * x + + +def _gc_identity(x): + """Identity function for Globus Compute integration tests.""" + return x + + +@pytest.mark.globus_compute +class TestGlobusComputeIntegration: + """Integration tests that submit work to a real Globus Compute endpoint. + + These are skipped by default. 
Run with:: + + GLOBUS_COMPUTE_ENDPOINT_ID=YOUR_ENDPOINT_UUID pytest --run-globus-compute -k Integration + """ + + def test_python_task_roundtrip(self, globus_backend): + """Submit a trivial Python callable and verify the result.""" + task = TaskSpec( + task_id="roundtrip", + task_type="python", + callable=_gc_double, + args=(21,), + ) + fut = globus_backend.submit(task) + result = fut.result(timeout=120) + assert result == 42 + + def test_shell_task_roundtrip(self, globus_backend): + """Submit a shell command and verify the output.""" + task = TaskSpec( + task_id="shell_rt", + task_type="shell", + command="echo hello_globus", + ) + fut = globus_backend.submit(task) + result = fut.result(timeout=120) + # ShellFunction returns a ShellResult; stdout should contain the string + assert "hello_globus" in str(result) + + def test_batch_submission(self, globus_backend): + """Submit a batch of tasks and verify all results.""" + tasks = [ + TaskSpec( + task_id=f"batch_{i}", + task_type="python", + callable=_gc_square, + args=(i,), + ) + for i in range(5) + ] + futures = globus_backend.submit_batch(tasks) + results = [f.result(timeout=120) for f in futures] + assert results == [0, 1, 4, 9, 16] + + @pytest.mark.asyncio + async def test_gather_futures_with_real_endpoint(self, globus_backend): + """Verify gather_futures works with real ComputeFuture objects.""" + tasks = [ + TaskSpec( + task_id=f"gf_{i}", + task_type="python", + callable=_gc_identity, + args=(i,), + ) + for i in range(3) + ] + futs = globus_backend.submit_batch(tasks) + pending = [({"index": i}, f) for i, f in enumerate(futs)] + + results = await gather_futures(pending) + assert len(results) == 3 + assert all("index" in r for r in results) + + +# ── Layer 4: Edge-case and error-handling tests ───────────────────────── + + +class TestGlobusComputeEdgeCases: + """Mocked tests for error paths and edge conditions.""" + + def test_submit_after_shutdown(self): + """Submitting after shutdown() should raise RuntimeError.""" + mock_sdk, _ 
= _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + backend.shutdown() + + task = TaskSpec(task_id="late", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_double_initialize(self): + """Calling initialize() twice should succeed and create a new executor.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="uuid-1") + backend.initialize(endpoint_id="uuid-2") + + assert backend._initialized is True + assert mock_sdk.Executor.call_count == 2 + backend.shutdown() + + def test_context_manager_with_exception(self): + """shutdown() must be called even when the body raises.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + with pytest.raises(ValueError, match="intentional"): + with GlobusComputeBackend() as backend: + backend.initialize(endpoint_id="test-uuid") + raise ValueError("intentional") + + mock_executor.shutdown.assert_called_once() + + def test_executor_submit_raises_propagates(self): + """Errors from executor.submit() should propagate to the caller.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + mock_executor.submit.side_effect = RuntimeError("endpoint unavailable") + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") 
+ + task = TaskSpec(task_id="err", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="endpoint unavailable"): + backend.submit(task) + + def test_submit_with_resource_hints(self): + """Resource hints are advisory and should not break submission.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="hints", + task_type="python", + callable=_square, + args=(5,), + num_nodes=4, + processes_per_node=32, + gpus_per_task=4, + ) + fut = backend.submit(task) + assert isinstance(fut, Future) + # Resource hints should NOT be passed to executor.submit + mock_executor.submit.assert_called_once_with(_square, 5) + + def test_failed_future_result(self): + """A future that resolves to an exception should be retrievable.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + failed_future = Future() + failed_future.set_exception(RuntimeError("task exploded")) + mock_executor.submit.return_value = failed_future + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="fail", callable=_square, args=(1,)) + fut = backend.submit(task) + + with pytest.raises(RuntimeError, match="task exploded"): + fut.result(timeout=5) From 487874046f1d07624b967b4266c68b7d91d0ff1e Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 4 May 2026 12:02:08 -0500 Subject: [PATCH 002/119] Fix unreachable code in aurora_parsl and EnsembleLauncher shutdown state Remove dead num_nodes=1 after raise in aurora_parsl.py and fix misleading error message. 
Set _initialized=False at start of EnsembleLauncherBackend.shutdown() to prevent submitting to a partially torn-down backend. --- src/chemgraph/hpc_configs/aurora_parsl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/chemgraph/hpc_configs/aurora_parsl.py b/src/chemgraph/hpc_configs/aurora_parsl.py index f2a0f761..ece27183 100644 --- a/src/chemgraph/hpc_configs/aurora_parsl.py +++ b/src/chemgraph/hpc_configs/aurora_parsl.py @@ -22,9 +22,9 @@ def get_aurora_config( node_list = f.readlines() num_nodes = len(node_list) else: - # Fallback for testing/local runs without PBS - raise ValueError("Warning: PBS_NODEFILE not found. Defaulting to 1 node.") - num_nodes = 1 + raise ValueError( + "PBS_NODEFILE not found. Cannot determine node count for Aurora." + ) config = Config( executors=[ From a8feddbe020c4417c2a886d8a47de628e74047a6 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 14 May 2026 13:42:27 -0500 Subject: [PATCH 003/119] Update Globus config --- src/chemgraph/execution/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py index 71d3de90..80b3c458 100644 --- a/src/chemgraph/execution/config.py +++ b/src/chemgraph/execution/config.py @@ -124,6 +124,12 @@ def get_backend( backend_cfg = cfg.get(resolved_backend, {}) merged_kwargs = {**backend_cfg, **kwargs} + # Globus Compute: fall back to GLOBUS_COMPUTE_ENDPOINT_ID env var + if resolved_backend == "globus_compute" and "endpoint_id" not in merged_kwargs: + env_id = os.getenv("GLOBUS_COMPUTE_ENDPOINT_ID") + if env_id: + merged_kwargs["endpoint_id"] = env_id + # -- instantiate ---------------------------------------------------------- logger.info( "Creating execution backend '%s' for system '%s'", From 3144c6a7ff2605e9efd6f7b13bda6caad25adae5 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 14 May 2026 13:43:00 -0500 Subject: [PATCH 004/119] Add inline structure for file transferring between local and 
globus remote --- src/chemgraph/mcp/mace_mcp_hpc.py | 88 +++++++++++++++++++++++++++++-- 1 file changed, 83 insertions(+), 5 deletions(-) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index eba86858..dc966c49 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -11,8 +11,10 @@ - Uses shared utilities for structure resolution and result gathering. """ +import asyncio import json import logging +import os from pathlib import Path from mcp.server.fastmcp import FastMCP @@ -27,7 +29,6 @@ from chemgraph.tools.parsl_tools import ( mace_input_schema, mace_input_schema_ensemble, - run_mace_core, ) logger = logging.getLogger(__name__) @@ -59,19 +60,85 @@ def _run_mace_single(job: dict) -> dict: - """Execute a single MACE simulation (runs on the worker).""" + """Execute a single MACE simulation (runs on the worker). + + When the ``job`` dict contains an ``inline_structure`` key (with + ``numbers``, ``positions``, and optional ``cell``/``pbc``), the + structure is materialised as a temporary XYZ file on the worker + filesystem before running MACE. This allows local-agent / + remote-worker workflows where the original file only exists on the + submitting machine. 
+ """ + import os + import tempfile + from chemgraph.tools.parsl_tools import mace_input_schema, run_mace_core + inline = job.pop("inline_structure", None) + if inline is not None: + from ase import Atoms + from ase.io import write as ase_write + + atoms = Atoms( + numbers=inline["numbers"], + positions=inline["positions"], + cell=inline.get("cell"), + pbc=inline.get("pbc"), + ) + tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") + xyz_path = os.path.join(tmpdir, "structure.xyz") + ase_write(xyz_path, atoms) + job["input_structure_file"] = xyz_path + + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + tmpdir, job.get("output_result_file", "output.json") + ) + params = mace_input_schema(**job) if isinstance(job, dict) else job - return run_mace_core(params) + result = run_mace_core(params) + + # Embed full output JSON when running with inline structure so the + # caller does not need to read a file on the remote filesystem. + if inline is not None: + out_file = job.get("output_result_file", "") + if os.path.isfile(out_file): + import json as _json + + with open(out_file, "r") as fh: + result["full_output"] = _json.load(fh) + + return result @mcp.tool( name="run_mace_single", description="Run a single MACE calculation", ) -def run_mace_single(params: mace_input_schema): - return run_mace_core(params) +async def run_mace_single(params: mace_input_schema): + """Run a single MACE calculation using the configured execution backend.""" + job = params.model_dump() + + # Read the local structure file and embed it so the job is + # self-contained and can run on any worker (local or remote). 
+ input_file = job.get("input_structure_file") + if input_file and os.path.isfile(input_file): + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(input_file) + atomsdata = atoms_to_atomsdata(atoms) + job["inline_structure"] = atomsdata.model_dump() + + task = TaskSpec( + task_id="mace_single", + task_type="python", + callable=_run_mace_single, + kwargs={"job": job}, + ) + fut = backend.submit(task) + return await asyncio.wrap_future(fut) def _mace_post_fn(meta: dict, result) -> dict: @@ -129,6 +196,17 @@ async def run_mace_ensemble(params: mace_input_schema_ensemble): "optimizer": params.optimizer, } + # Embed structure data so the job works on remote workers that + # cannot access the local filesystem. + if struct_path.is_file(): + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(str(struct_path)) + atomsdata = atoms_to_atomsdata(atoms) + job["inline_structure"] = atomsdata.model_dump() + task = TaskSpec( task_id=f"mace_{struct_path.stem}", task_type="python", From e032fb81d92f3fcc85abd3109140bbb89292bd76 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 14 May 2026 18:09:04 -0500 Subject: [PATCH 005/119] Add async job tracking for Globus Compute MCP tools When backend=globus_compute, MCP tools now return immediately after submitting jobs to the remote HPC endpoint instead of blocking until completion. A new JobTracker tracks submitted futures across tool calls, and new MCP tools (check_job_status, get_job_results, list_jobs, cancel_job) let the LLM agent poll for progress and retrieve results. Non-Globus backends (local, Parsl, EnsembleLauncher) are unchanged and continue to block until results are ready. 
Key changes: - Add is_async_remote property to ExecutionBackend (True for Globus) - Add check_endpoint_status() health check to GlobusComputeBackend - Add JobTracker with batch registration, status, results, cleanup - Add submit_or_gather() utility that branches on backend type - Add optional timeout parameter to gather_futures() - Add register_job_tools() to wire job tools into any MCP server - Integrate tracker into MACE, XANES, and gRASPA MCP servers --- src/chemgraph/execution/__init__.py | 2 + src/chemgraph/execution/base.py | 8 + .../execution/globus_compute_backend.py | 30 ++ src/chemgraph/execution/job_tracker.py | 296 +++++++++++++ src/chemgraph/execution/utils.py | 75 +++- src/chemgraph/mcp/graspa_mcp_hpc.py | 37 +- src/chemgraph/mcp/job_tools.py | 107 +++++ src/chemgraph/mcp/mace_mcp_hpc.py | 39 +- src/chemgraph/mcp/xanes_mcp_hpc.py | 36 +- tests/test_job_tracker.py | 394 ++++++++++++++++++ 10 files changed, 994 insertions(+), 30 deletions(-) create mode 100644 src/chemgraph/execution/job_tracker.py create mode 100644 src/chemgraph/mcp/job_tools.py create mode 100644 tests/test_job_tracker.py diff --git a/src/chemgraph/execution/__init__.py b/src/chemgraph/execution/__init__.py index 0fd6709b..bd6d0ccf 100644 --- a/src/chemgraph/execution/__init__.py +++ b/src/chemgraph/execution/__init__.py @@ -25,9 +25,11 @@ from chemgraph.execution.base import ExecutionBackend, TaskSpec from chemgraph.execution.config import get_backend +from chemgraph.execution.job_tracker import JobTracker __all__ = [ "ExecutionBackend", + "JobTracker", "TaskSpec", "get_backend", ] diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py index e7dc338b..ccfb4f2d 100644 --- a/src/chemgraph/execution/base.py +++ b/src/chemgraph/execution/base.py @@ -102,6 +102,14 @@ class ExecutionBackend(ABC): def __init__(self) -> None: self._initialized: bool = False + @property + def is_async_remote(self) -> bool: + """Whether this backend submits to a remote queue where jobs 
may + take minutes to hours. When ``True``, MCP tools should return + immediately after submission and provide separate status/result + retrieval tools instead of blocking until completion.""" + return False + @abstractmethod def initialize(self, system: str = "local", **kwargs: Any) -> None: """Prepare the backend for accepting work. diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py index 0c2a9634..2ec2bba1 100644 --- a/src/chemgraph/execution/globus_compute_backend.py +++ b/src/chemgraph/execution/globus_compute_backend.py @@ -51,6 +51,11 @@ class GlobusComputeBackend(ExecutionBackend): def __init__(self) -> None: super().__init__() self._executor = None + self._endpoint_id: str | None = None + + @property + def is_async_remote(self) -> bool: + return True # ── lifecycle ──────────────────────────────────────────────────────── @@ -77,6 +82,7 @@ def initialize(self, system: str = "local", **kwargs: Any) -> None: if amqp_port is not None: executor_kwargs["amqp_port"] = int(amqp_port) + self._endpoint_id = endpoint_id self._executor = Executor(**executor_kwargs) self._initialized = True logger.info( @@ -118,6 +124,30 @@ def submit(self, task: TaskSpec) -> Future: f"Task '{task.task_id}': unsupported task_type '{task.task_type}'." ) + # ── health check ──────────────────────────────────────────────────── + + def check_endpoint_status(self) -> dict: + """Check the status of the configured Globus Compute endpoint. + + Returns a dict with ``endpoint_id`` and ``status`` fields. + Useful as a pre-flight check before submitting tasks. 
+ """ + try: + from globus_compute_sdk import Client + + client = Client() + status = client.get_endpoint_status(self._endpoint_id) + return { + "endpoint_id": self._endpoint_id, + "status": status, + } + except Exception as e: + return { + "endpoint_id": self._endpoint_id, + "status": "error", + "error": str(e), + } + # ── teardown ──────────────────────────────────────────────────────── def shutdown(self) -> None: diff --git a/src/chemgraph/execution/job_tracker.py b/src/chemgraph/execution/job_tracker.py new file mode 100644 index 00000000..87b473c0 --- /dev/null +++ b/src/chemgraph/execution/job_tracker.py @@ -0,0 +1,296 @@ +"""In-memory job tracker for async remote execution backends. + +Tracks ``concurrent.futures.Future`` objects returned by +:meth:`ExecutionBackend.submit` so that MCP tools can return +immediately after submission and provide separate status / result +retrieval endpoints. + +Each MCP server process creates its own ``JobTracker`` instance +(mirroring the existing ``backend = get_backend()`` pattern). +""" + +from __future__ import annotations + +import logging +import threading +import uuid +from concurrent.futures import Future +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class TrackedTask: + """A single task within a tracked batch.""" + + task_id: str + meta: dict + future: Future + result: Optional[dict] = None + + +@dataclass +class TrackedBatch: + """A group of tasks submitted together.""" + + batch_id: str + tool_name: str + submitted_at: datetime + tasks: list[TrackedTask] = field(default_factory=list) + post_fn: Optional[Callable[[dict, Any], dict]] = None + + +class JobTracker: + """Track submitted job batches and their futures. + + Thread-safe: all public methods acquire an internal lock. 
+ """ + + def __init__(self) -> None: + self._batches: dict[str, TrackedBatch] = {} + self._lock = threading.Lock() + + # ── registration ─────────────────────────────────────────────────── + + def register_batch( + self, + tool_name: str, + pending_tasks: list[tuple[dict, Future]], + post_fn: Optional[Callable[[dict, Any], dict]] = None, + ) -> str: + """Register a batch of submitted tasks and return a batch ID. + + Parameters + ---------- + tool_name : str + Name of the MCP tool that submitted the batch. + pending_tasks : list[tuple[dict, Future]] + Each element is ``(metadata_dict, future)``. + post_fn : callable, optional + Post-processing function applied when collecting results. + Called as ``post_fn(metadata, raw_result) -> dict``. + + Returns + ------- + str + A UUID batch identifier. + """ + batch_id = uuid.uuid4().hex[:12] + tracked = [ + TrackedTask( + task_id=meta.get("task_id", meta.get("structure", f"task_{i}")), + meta=meta, + future=fut, + ) + for i, (meta, fut) in enumerate(pending_tasks) + ] + batch = TrackedBatch( + batch_id=batch_id, + tool_name=tool_name, + submitted_at=datetime.now(timezone.utc), + tasks=tracked, + post_fn=post_fn, + ) + with self._lock: + self._batches[batch_id] = batch + + logger.info( + "Registered batch '%s' (%s) with %d tasks", + batch_id, + tool_name, + len(tracked), + ) + return batch_id + + # ── status ───────────────────────────────────────────────────────── + + def get_status(self, batch_id: str) -> dict: + """Return the current status of a batch. + + Returns + ------- + dict + Keys: ``batch_id``, ``tool_name``, ``submitted_at``, + ``status``, ``total_tasks``, ``completed_tasks``, + ``failed_tasks``, ``pending_tasks``, ``progress_pct``. 
+ """ + with self._lock: + batch = self._batches.get(batch_id) + if batch is None: + return {"error": f"Unknown batch_id: '{batch_id}'"} + + total = len(batch.tasks) + done = 0 + failed = 0 + + for t in batch.tasks: + if t.future.done(): + done += 1 + # Cache the result on first check + if t.result is None: + try: + raw = t.future.result(timeout=0) + if batch.post_fn is not None: + t.result = batch.post_fn(t.meta, raw) + elif isinstance(raw, dict): + merged = {**t.meta, **raw} + merged.setdefault("status", "success") + t.result = merged + else: + t.result = { + **t.meta, + "result": raw, + "status": "success", + } + except Exception as e: + t.result = { + **t.meta, + "status": "failure", + "error_type": type(e).__name__, + "message": str(e), + } + if t.result.get("status") == "failure": + failed += 1 + + pending = total - done + if pending == total: + status = "pending" + elif pending > 0: + status = "running" + elif failed == total: + status = "failed" + elif failed > 0: + status = "partial" + else: + status = "completed" + + return { + "batch_id": batch_id, + "tool_name": batch.tool_name, + "submitted_at": batch.submitted_at.isoformat(), + "status": status, + "total_tasks": total, + "completed_tasks": done - failed, + "failed_tasks": failed, + "pending_tasks": pending, + "progress_pct": round(done / total * 100, 1) if total else 0.0, + } + + # ── results ──────────────────────────────────────────────────────── + + def get_results( + self, batch_id: str, include_partial: bool = False + ) -> dict: + """Collect results from a batch. + + Parameters + ---------- + batch_id : str + The batch identifier. + include_partial : bool + If ``True``, return results for completed tasks even if some + are still pending. If ``False`` (default) and the batch is + not fully resolved, return a status message instead. + + Returns + ------- + dict + Contains ``status``, ``results`` list, and summary counts. 
+ """ + status_info = self.get_status(batch_id) + if "error" in status_info: + return status_info + + with self._lock: + batch = self._batches.get(batch_id) + if batch is None: + return {"error": f"Unknown batch_id: '{batch_id}'"} + + if not include_partial and status_info["pending_tasks"] > 0: + return { + **status_info, + "message": ( + f"{status_info['pending_tasks']} of " + f"{status_info['total_tasks']} tasks still pending. " + f"Call check_job_status('{batch_id}') to monitor, " + f"or use include_partial=True to get partial results." + ), + } + + results = [] + for t in batch.tasks: + if t.result is not None: + results.append(t.result) + + return { + **status_info, + "results": results, + } + + # ── listing ──────────────────────────────────────────────────────── + + def list_batches(self) -> list[dict]: + """Return a summary of all tracked batches.""" + with self._lock: + batch_ids = list(self._batches.keys()) + + summaries = [] + for bid in batch_ids: + summaries.append(self.get_status(bid)) + return summaries + + # ── cancellation ─────────────────────────────────────────────────── + + def cancel_batch(self, batch_id: str) -> dict: + """Attempt to cancel pending tasks in a batch. + + Returns a dict with the number of successfully cancelled tasks. + Note: ``Future.cancel()`` only succeeds if the task has not yet + started executing. 
+ """ + with self._lock: + batch = self._batches.get(batch_id) + if batch is None: + return {"error": f"Unknown batch_id: '{batch_id}'"} + + cancelled = 0 + already_done = 0 + for t in batch.tasks: + if t.future.done(): + already_done += 1 + elif t.future.cancel(): + cancelled += 1 + + return { + "batch_id": batch_id, + "cancelled": cancelled, + "already_done": already_done, + "could_not_cancel": len(batch.tasks) - cancelled - already_done, + } + + # ── cleanup ──────────────────────────────────────────────────────── + + def cleanup(self, max_age_hours: float = 24) -> int: + """Remove completed batches older than *max_age_hours*. + + Returns the number of batches removed. + """ + now = datetime.now(timezone.utc) + to_remove: list[str] = [] + + with self._lock: + for bid, batch in self._batches.items(): + age_hours = (now - batch.submitted_at).total_seconds() / 3600 + if age_hours > max_age_hours and all( + t.future.done() for t in batch.tasks + ): + to_remove.append(bid) + for bid in to_remove: + del self._batches[bid] + + if to_remove: + logger.info("Cleaned up %d old batches", len(to_remove)) + return len(to_remove) diff --git a/src/chemgraph/execution/utils.py b/src/chemgraph/execution/utils.py index 70759a71..ba941fd6 100644 --- a/src/chemgraph/execution/utils.py +++ b/src/chemgraph/execution/utils.py @@ -16,7 +16,11 @@ import logging from concurrent.futures import Future from pathlib import Path -from typing import Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + from chemgraph.execution.base import ExecutionBackend + from chemgraph.execution.job_tracker import JobTracker logger = logging.getLogger(__name__) @@ -81,6 +85,7 @@ def resolve_structure_files( async def gather_futures( pending: list[tuple[dict, Future]], post_fn: Optional[Callable[[dict, Any], dict]] = None, + timeout: Optional[float] = None, ) -> list[dict]: """Await a list of ``(metadata, future)`` pairs concurrently. 
@@ -96,11 +101,20 @@ async def gather_futures( successful future resolution. Must return a ``dict`` to include in the results list. When *None*, the raw result is merged with metadata. + timeout : float, optional + Maximum seconds to wait for all futures to resolve. If the + timeout expires, an :class:`asyncio.TimeoutError` is raised. + When *None* (default), wait indefinitely. Returns ------- list[dict] One result dict per task (successful or failed). + + Raises + ------ + asyncio.TimeoutError + If *timeout* is set and exceeded before all futures complete. """ async def _wait(meta: dict, fut: Future) -> dict: @@ -122,9 +136,62 @@ async def _wait(meta: dict, fut: Future) -> dict: "message": str(e), } - return list( - await asyncio.gather(*(_wait(meta, fut) for meta, fut in pending)) - ) + coro = asyncio.gather(*(_wait(meta, fut) for meta, fut in pending)) + if timeout is not None: + return list(await asyncio.wait_for(coro, timeout=timeout)) + return list(await coro) + + +async def submit_or_gather( + backend: ExecutionBackend, + pending: list[tuple[dict, Future]], + tracker: JobTracker, + tool_name: str, + post_fn: Optional[Callable[[dict, Any], dict]] = None, +) -> dict: + """Gather results or register for async tracking, depending on the backend. + + When ``backend.is_async_remote`` is ``True``, the pending futures are + registered with the *tracker* and a submission confirmation is + returned immediately (non-blocking). Otherwise, results are gathered + synchronously via :func:`gather_futures`. + + Parameters + ---------- + backend : ExecutionBackend + The active execution backend. + pending : list[tuple[dict, Future]] + Each element is ``(metadata_dict, future)``. + tracker : JobTracker + The job tracker instance to register batches with. + tool_name : str + Name of the MCP tool submitting the batch. + post_fn : callable, optional + Post-processing function for results. 
+ + Returns + ------- + dict + Either ``{"status": "submitted", "batch_id": ..., ...}`` for + async backends, or ``{"status": "completed", "results": ...}`` + for synchronous backends. + """ + if backend.is_async_remote: + batch_id = tracker.register_batch(tool_name, pending, post_fn=post_fn) + return { + "status": "submitted", + "batch_id": batch_id, + "n_tasks": len(pending), + "message": ( + f"Submitted {len(pending)} task(s) to remote HPC endpoint. " + f"Use check_job_status(batch_id='{batch_id}') to monitor " + f"progress, and get_job_results(batch_id='{batch_id}') to " + f"retrieve results once complete." + ), + } + + results = await gather_futures(pending, post_fn=post_fn) + return {"status": "completed", "results": results} def write_results_jsonl( diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py index 9ee276bc..87eeb231 100644 --- a/src/chemgraph/mcp/graspa_mcp_hpc.py +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -12,12 +12,14 @@ from mcp.server.fastmcp import FastMCP from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.job_tracker import JobTracker from chemgraph.execution.utils import ( - gather_futures, make_per_structure_output, resolve_structure_files, + submit_or_gather, write_results_jsonl, ) +from chemgraph.mcp.job_tools import register_job_tools from chemgraph.mcp.server_utils import run_mcp_server from chemgraph.schemas.graspa_schema import graspa_input_schema_ensemble @@ -25,6 +27,7 @@ # ── Initialise execution backend ──────────────────────────────────────── backend = get_backend() +tracker = JobTracker() # ── MCP server ────────────────────────────────────────────────────────── mcp = FastMCP( @@ -34,6 +37,10 @@ The available tools are: 1. run_graspa_ensemble: run graspa calculations over all structures in a directory using the configured execution backend. + 2. check_job_status: check progress of a submitted HPC job batch. + 3. 
get_job_results: retrieve results from a completed job batch. + 4. list_jobs: list all tracked job batches. + 5. cancel_job: cancel pending tasks in a job batch. Guidelines: - Use each tool only when its input schema matches the user request. @@ -42,8 +49,11 @@ defined in the schemas. - When returning paths, use absolute paths. - Energies are in eV and wall times are in seconds. + - When a tool returns status='submitted' with a batch_id, use + check_job_status to poll for progress before calling get_job_results. """, ) +register_job_tools(mcp, tracker, backend) def _run_graspa_single(job: dict) -> dict: @@ -108,17 +118,24 @@ async def run_graspa_ensemble( } pending_tasks.append((task_meta, fut)) - results = await gather_futures(pending_tasks) - - summary_log_path = output_dir / "simulation_results.jsonl" - success_count, total_count = write_results_jsonl(results, summary_log_path) - - return ( - f"Ensemble execution completed. Ran {total_count} tasks " - f"({success_count} successful). " - f"Detailed results appended to '{summary_log_path}'." + result = await submit_or_gather( + backend, pending_tasks, tracker, "run_graspa_ensemble", ) + if result["status"] == "completed": + summary_log_path = output_dir / "simulation_results.jsonl" + success_count, total_count = write_results_jsonl( + result["results"], summary_log_path, + ) + return ( + f"Ensemble execution completed. Ran {total_count} tasks " + f"({success_count} successful). " + f"Detailed results appended to '{summary_log_path}'." + ) + + # Async remote: return submission confirmation + return result + if __name__ == "__main__": run_mcp_server(mcp, default_port=9001) diff --git a/src/chemgraph/mcp/job_tools.py b/src/chemgraph/mcp/job_tools.py new file mode 100644 index 00000000..6974aef1 --- /dev/null +++ b/src/chemgraph/mcp/job_tools.py @@ -0,0 +1,107 @@ +"""Shared MCP tools for job status tracking and result retrieval. 
+ +Call :func:`register_job_tools` to add ``check_job_status``, +``get_job_results``, ``list_jobs``, ``cancel_job``, and (optionally) +``check_endpoint_status`` to any :class:`~mcp.server.fastmcp.FastMCP` +server instance. + +These tools are only useful when the execution backend is async-remote +(e.g. Globus Compute), but are registered unconditionally so the LLM +agent always has a consistent tool surface. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from mcp.server.fastmcp import FastMCP + + from chemgraph.execution.base import ExecutionBackend + from chemgraph.execution.job_tracker import JobTracker + + +def register_job_tools( + mcp: FastMCP, + tracker: JobTracker, + backend: ExecutionBackend, +) -> None: + """Register job-management MCP tools on *mcp*. + + Parameters + ---------- + mcp : FastMCP + The MCP server to register tools on. + tracker : JobTracker + The job tracker for this server process. + backend : ExecutionBackend + The active execution backend (used for endpoint health checks). + """ + + @mcp.tool( + name="check_job_status", + description=( + "Check the status of a previously submitted HPC job batch. " + "Returns progress information including how many tasks are " + "complete, failed, or still pending. Use this to poll " + "long-running remote compute jobs." + ), + ) + def check_job_status(batch_id: str) -> dict: + """Check the status of a submitted job batch.""" + return tracker.get_status(batch_id) + + @mcp.tool( + name="get_job_results", + description=( + "Retrieve results from a completed (or partially completed) " + "HPC job batch. By default, returns results only when all " + "tasks are done. Set include_partial=True to get results " + "for tasks that have finished so far." 
+ ), + ) + def get_job_results( + batch_id: str, + include_partial: bool = False, + ) -> dict: + """Retrieve results from a job batch.""" + return tracker.get_results(batch_id, include_partial=include_partial) + + @mcp.tool( + name="list_jobs", + description=( + "List all tracked job batches with their current status. " + "Shows batch IDs, tool names, submission times, and progress." + ), + ) + def list_jobs() -> list[dict]: + """List all tracked job batches.""" + batches = tracker.list_batches() + if not batches: + return [{"message": "No job batches tracked."}] + return batches + + @mcp.tool( + name="cancel_job", + description=( + "Cancel pending tasks in a job batch. Only tasks that have " + "not yet started executing can be cancelled." + ), + ) + def cancel_job(batch_id: str) -> dict: + """Cancel pending tasks in a job batch.""" + return tracker.cancel_batch(batch_id) + + if backend.is_async_remote and hasattr(backend, "check_endpoint_status"): + + @mcp.tool( + name="check_endpoint_status", + description=( + "Check whether the remote HPC compute endpoint is " + "reachable and accepting tasks. Use this as a pre-flight " + "check before submitting jobs." 
+ ), + ) + def check_endpoint_status() -> dict: + """Check the remote compute endpoint status.""" + return backend.check_endpoint_status() diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index dc966c49..a664a1e7 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -20,11 +20,13 @@ from mcp.server.fastmcp import FastMCP from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.job_tracker import JobTracker from chemgraph.execution.utils import ( - gather_futures, make_per_structure_output, resolve_structure_files, + submit_or_gather, ) +from chemgraph.mcp.job_tools import register_job_tools from chemgraph.mcp.server_utils import run_mcp_server from chemgraph.tools.parsl_tools import ( mace_input_schema, @@ -35,6 +37,7 @@ # ── Initialise execution backend ──────────────────────────────────────── backend = get_backend() +tracker = JobTracker() # ── MCP server ────────────────────────────────────────────────────────── mcp = FastMCP( @@ -47,6 +50,10 @@ 2. run_mace_ensemble: run MACE calculations over all structures in a directory using the configured execution backend. 3. extract_output_json: load simulation results from a JSON file. + 4. check_job_status: check progress of a submitted HPC job batch. + 5. get_job_results: retrieve results from a completed job batch. + 6. list_jobs: list all tracked job batches. + 7. cancel_job: cancel pending tasks in a job batch. Guidelines: - Use each tool only when its input schema matches the user request. @@ -55,8 +62,11 @@ defined in the schemas. - When returning paths, use absolute paths. - Energies are in eV and wall times are in seconds. + - When a tool returns status='submitted' with a batch_id, use + check_job_status to poll for progress before calling get_job_results. 
""", ) +register_job_tools(mcp, tracker, backend) def _run_mace_single(job: dict) -> dict: @@ -138,6 +148,13 @@ async def run_mace_single(params: mace_input_schema): kwargs={"job": job}, ) fut = backend.submit(task) + + if backend.is_async_remote: + task_meta = {"task_id": "mace_single"} + return await submit_or_gather( + backend, [(task_meta, fut)], tracker, "run_mace_single" + ) + return await asyncio.wrap_future(fut) @@ -221,13 +238,21 @@ async def run_mace_ensemble(params: mace_input_schema_ensemble): } pending_tasks.append((task_meta, fut)) - results = await gather_futures(pending_tasks, post_fn=_mace_post_fn) + result = await submit_or_gather( + backend, pending_tasks, tracker, "run_mace_ensemble", + post_fn=_mace_post_fn, + ) - return { - "status": "success", - "n_structures": len(structure_files), - "results": results, - } + if result["status"] == "completed": + return { + "status": "success", + "n_structures": len(structure_files), + "results": result["results"], + } + + # Async remote: return submission confirmation + result["n_structures"] = len(structure_files) + return result @mcp.tool( diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 3ed81fa7..4b0d219b 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -12,11 +12,13 @@ from mcp.server.fastmcp import FastMCP from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.job_tracker import JobTracker from chemgraph.execution.utils import ( - gather_futures, resolve_structure_files, + submit_or_gather, write_results_jsonl, ) +from chemgraph.mcp.job_tools import register_job_tools from chemgraph.mcp.server_utils import run_mcp_server from chemgraph.schemas.xanes_schema import ( mp_query_schema, @@ -28,6 +30,7 @@ # ── Initialise execution backend ──────────────────────────────────────── backend = get_backend() +tracker = JobTracker() # ── MCP server ────────────────────────────────────────────────────────── 
mcp = FastMCP( @@ -40,6 +43,10 @@ using the configured execution backend. 3. fetch_mp_structures: fetch optimized structures from Materials Project. 4. plot_xanes: generate normalized XANES plots for completed calculations. + 5. check_job_status: check progress of a submitted HPC job batch. + 6. get_job_results: retrieve results from a completed job batch. + 7. list_jobs: list all tracked job batches. + 8. cancel_job: cancel pending tasks in a job batch. Guidelines: - Use each tool only when its input schema matches the user request. @@ -47,8 +54,11 @@ - Keep responses compact -- full results are in the output directories. - When returning paths, use absolute paths. - Energies are in eV. + - When a tool returns status='submitted' with a batch_id, use + check_job_status to poll for progress before calling get_job_results. """, ) +register_job_tools(mcp, tracker, backend) @mcp.tool( @@ -155,16 +165,24 @@ async def run_xanes_ensemble(params: xanes_input_schema_ensemble): } pending_tasks.append((task_meta, fut)) - results = await gather_futures(pending_tasks, post_fn=_xanes_post_fn) + result = await submit_or_gather( + backend, pending_tasks, tracker, "run_xanes_ensemble", + post_fn=_xanes_post_fn, + ) - summary_log_path = output_dir / "xanes_results.jsonl" - success_count, total_count = write_results_jsonl(results, summary_log_path) + if result["status"] == "completed": + summary_log_path = output_dir / "xanes_results.jsonl" + success_count, total_count = write_results_jsonl( + result["results"], summary_log_path, + ) + return ( + f"Ensemble execution completed. Ran {total_count} tasks " + f"({success_count} successful). " + f"Detailed results appended to '{summary_log_path}'." + ) - return ( - f"Ensemble execution completed. Ran {total_count} tasks " - f"({success_count} successful). " - f"Detailed results appended to '{summary_log_path}'." 
- ) + # Async remote: return submission confirmation + return result @mcp.tool( diff --git a/tests/test_job_tracker.py b/tests/test_job_tracker.py new file mode 100644 index 00000000..cee3d081 --- /dev/null +++ b/tests/test_job_tracker.py @@ -0,0 +1,394 @@ +"""Tests for the JobTracker and submit_or_gather utilities.""" + +import asyncio +from concurrent.futures import Future +from unittest.mock import MagicMock + +import pytest + +from chemgraph.execution.job_tracker import JobTracker +from chemgraph.execution.utils import gather_futures, submit_or_gather + + +# ── Helpers ──────────────────────────────────────────────────────────── + + +def _make_done_future(result): + """Create a Future that is already resolved with *result*.""" + fut = Future() + fut.set_result(result) + return fut + + +def _make_failed_future(exc): + """Create a Future that is already resolved with an exception.""" + fut = Future() + fut.set_exception(exc) + return fut + + +def _make_pending_future(): + """Create a Future that is not yet resolved.""" + return Future() + + +# ── JobTracker.register_batch ────────────────────────────────────────── + + +class TestRegisterBatch: + def test_returns_batch_id(self): + tracker = JobTracker() + fut = _make_pending_future() + batch_id = tracker.register_batch( + "test_tool", [({"key": "val"}, fut)] + ) + assert isinstance(batch_id, str) + assert len(batch_id) == 12 + + def test_stores_tasks(self): + tracker = JobTracker() + futs = [_make_pending_future() for _ in range(3)] + pending = [({"idx": i}, f) for i, f in enumerate(futs)] + batch_id = tracker.register_batch("test_tool", pending) + + status = tracker.get_status(batch_id) + assert status["total_tasks"] == 3 + + def test_multiple_batches_unique_ids(self): + tracker = JobTracker() + ids = set() + for _ in range(10): + bid = tracker.register_batch( + "tool", [({"x": 1}, _make_pending_future())] + ) + ids.add(bid) + assert len(ids) == 10 + + +# ── JobTracker.get_status 
────────────────────────────────────────────── + + +class TestGetStatus: + def test_all_pending(self): + tracker = JobTracker() + pending = [({"i": i}, _make_pending_future()) for i in range(3)] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "pending" + assert status["total_tasks"] == 3 + assert status["completed_tasks"] == 0 + assert status["pending_tasks"] == 3 + assert status["progress_pct"] == 0.0 + + def test_all_completed(self): + tracker = JobTracker() + pending = [ + ({"i": i}, _make_done_future({"val": i})) for i in range(3) + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + assert status["completed_tasks"] == 3 + assert status["failed_tasks"] == 0 + assert status["pending_tasks"] == 0 + assert status["progress_pct"] == 100.0 + + def test_partial_done(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 0})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "running" + assert status["completed_tasks"] == 1 + assert status["pending_tasks"] == 1 + assert status["progress_pct"] == 50.0 + + def test_all_failed(self): + tracker = JobTracker() + pending = [ + ({"i": i}, _make_failed_future(ValueError(f"err_{i}"))) + for i in range(2) + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "failed" + assert status["failed_tasks"] == 2 + + def test_mixed_success_and_failure(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 0})), + ({"i": 1}, _make_failed_future(RuntimeError("boom"))), + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "partial" + assert 
status["completed_tasks"] == 1 + assert status["failed_tasks"] == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + status = tracker.get_status("nonexistent") + assert "error" in status + + def test_with_post_fn(self): + def post_fn(meta, result): + return {"custom": True, "status": "success", **meta} + + tracker = JobTracker() + pending = [({"i": 0}, _make_done_future({"raw": 1}))] + batch_id = tracker.register_batch("tool", pending, post_fn=post_fn) + + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + + +# ── JobTracker.get_results ───────────────────────────────────────────── + + +class TestGetResults: + def test_returns_results_when_complete(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_done_future({"val": 20})), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id) + assert "results" in result + assert len(result["results"]) == 2 + + def test_blocks_when_pending_and_partial_false(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id, include_partial=False) + assert "results" not in result + assert "message" in result + assert "still pending" in result["message"] + + def test_returns_partial_when_requested(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id, include_partial=True) + assert "results" in result + assert len(result["results"]) == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + result = tracker.get_results("nonexistent") + assert "error" in result + + +# ── JobTracker.list_batches ──────────────────────────────────────────── + + 
+class TestListBatches: + def test_empty(self): + tracker = JobTracker() + assert tracker.list_batches() == [] + + def test_multiple_batches(self): + tracker = JobTracker() + tracker.register_batch("tool_a", [({"x": 1}, _make_pending_future())]) + tracker.register_batch("tool_b", [({"x": 2}, _make_done_future(42))]) + + batches = tracker.list_batches() + assert len(batches) == 2 + tool_names = {b["tool_name"] for b in batches} + assert tool_names == {"tool_a", "tool_b"} + + +# ── JobTracker.cancel_batch ──────────────────────────────────────────── + + +class TestCancelBatch: + def test_cancel_pending(self): + tracker = JobTracker() + fut = _make_pending_future() + batch_id = tracker.register_batch("tool", [({"i": 0}, fut)]) + + result = tracker.cancel_batch(batch_id) + # Future.cancel() may or may not succeed depending on state, + # but the call should not raise + assert "batch_id" in result + + def test_cancel_already_done(self): + tracker = JobTracker() + fut = _make_done_future({"val": 1}) + batch_id = tracker.register_batch("tool", [({"i": 0}, fut)]) + + result = tracker.cancel_batch(batch_id) + assert result["already_done"] == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + result = tracker.cancel_batch("nonexistent") + assert "error" in result + + +# ── JobTracker.cleanup ───────────────────────────────────────────────── + + +class TestCleanup: + def test_removes_old_completed(self): + tracker = JobTracker() + batch_id = tracker.register_batch( + "tool", [({"i": 0}, _make_done_future(1))] + ) + + # Force the submitted_at to be old + batch = tracker._batches[batch_id] + from datetime import timedelta + + batch.submitted_at -= timedelta(hours=25) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 1 + assert tracker.list_batches() == [] + + def test_keeps_recent(self): + tracker = JobTracker() + tracker.register_batch("tool", [({"i": 0}, _make_done_future(1))]) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 
0 + assert len(tracker.list_batches()) == 1 + + def test_keeps_pending(self): + tracker = JobTracker() + batch_id = tracker.register_batch( + "tool", [({"i": 0}, _make_pending_future())] + ) + + batch = tracker._batches[batch_id] + from datetime import timedelta + + batch.submitted_at -= timedelta(hours=25) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 0 + + +# ── gather_futures with timeout ──────────────────────────────────────── + + +class TestGatherFuturesTimeout: + def test_completes_within_timeout(self): + pending = [ + ({"i": 0}, _make_done_future({"val": 1})), + ({"i": 1}, _make_done_future({"val": 2})), + ] + results = asyncio.get_event_loop().run_until_complete( + gather_futures(pending, timeout=5.0) + ) + assert len(results) == 2 + + def test_timeout_raises(self): + pending = [({"i": 0}, _make_pending_future())] + with pytest.raises(asyncio.TimeoutError): + asyncio.get_event_loop().run_until_complete( + gather_futures(pending, timeout=0.1) + ) + + def test_no_timeout_default(self): + pending = [({"i": 0}, _make_done_future(42))] + results = asyncio.get_event_loop().run_until_complete( + gather_futures(pending) + ) + assert len(results) == 1 + + +# ── submit_or_gather ─────────────────────────────────────────────────── + + +class TestSubmitOrGather: + def test_sync_backend_returns_completed(self): + backend = MagicMock() + backend.is_async_remote = False + + tracker = JobTracker() + pending = [({"i": 0}, _make_done_future({"val": 10}))] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, pending, tracker, "test_tool") + ) + assert result["status"] == "completed" + assert "results" in result + assert len(result["results"]) == 1 + + def test_async_backend_returns_submitted(self): + backend = MagicMock() + backend.is_async_remote = True + + tracker = JobTracker() + pending = [({"i": 0}, _make_pending_future())] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, 
pending, tracker, "test_tool") + ) + assert result["status"] == "submitted" + assert "batch_id" in result + assert result["n_tasks"] == 1 + assert "check_job_status" in result["message"] + + def test_async_backend_batch_trackable(self): + backend = MagicMock() + backend.is_async_remote = True + + tracker = JobTracker() + fut = _make_done_future({"val": 99}) + pending = [({"i": 0}, fut)] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, pending, tracker, "test_tool") + ) + batch_id = result["batch_id"] + + # Verify the batch is tracked and status works + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + + # Verify results can be retrieved + results = tracker.get_results(batch_id) + assert "results" in results + assert len(results["results"]) == 1 + + def test_async_backend_with_post_fn(self): + backend = MagicMock() + backend.is_async_remote = True + + def post_fn(meta, result): + return {"processed": True, "status": "success"} + + tracker = JobTracker() + fut = _make_done_future({"raw": 1}) + pending = [({"i": 0}, fut)] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather( + backend, pending, tracker, "test_tool", post_fn=post_fn, + ) + ) + batch_id = result["batch_id"] + + results = tracker.get_results(batch_id) + assert results["results"][0]["processed"] is True From ae7963d3d5bbcadef1a18871c99c7fcc9c8335f6 Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Thu, 21 May 2026 16:43:22 -0500 Subject: [PATCH 006/119] Modified the EL backend implemenations, and added a EL backend test --- src/chemgraph/execution/base.py | 3 +- src/chemgraph/execution/config.py | 8 + .../execution/ensemble_launcher_backend.py | 149 ++++++++---- tests/test_execution.py | 221 +++++++++++++++++- 4 files changed, 325 insertions(+), 56 deletions(-) diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py index ccfb4f2d..c182b4cf 100644 --- a/src/chemgraph/execution/base.py +++ 
b/src/chemgraph/execution/base.py @@ -11,7 +11,7 @@ import logging from abc import ABC, abstractmethod from concurrent.futures import Future -from typing import Any, Callable, Literal, Optional +from typing import Any, Callable, Dict, Literal, Optional from pydantic import BaseModel, ConfigDict, Field @@ -85,6 +85,7 @@ class TaskSpec(BaseModel): default=0, description="Number of GPUs requested per task.", ) + env: Dict[str, str] = Field(default_factory=dict) class ExecutionBackend(ABC): diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py index 80b3c458..60921f92 100644 --- a/src/chemgraph/execution/config.py +++ b/src/chemgraph/execution/config.py @@ -144,10 +144,18 @@ def get_backend( elif resolved_backend == "ensemble_launcher": from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, EnsembleLauncherBackend, + get_launcher_config, ) backend = EnsembleLauncherBackend() + assert system in SYSTEM_CONFIG_REGISTRY, ( + f"Unknown system: only know {SYSTEM_CONFIG_REGISTRY.keys()}" + ) + merged_kwargs = {} + merged_kwargs["system_config"] = SYSTEM_CONFIG_REGISTRY[system] + merged_kwargs["launcher_config"] = get_launcher_config(**backend_cfg) elif resolved_backend == "globus_compute": from chemgraph.execution.globus_compute_backend import ( diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index 23462f5b..f78d41b1 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -13,17 +13,86 @@ import logging import os -import socket import time import uuid from concurrent.futures import Future -from typing import Any +from typing import List, Literal, Optional, Union from chemgraph.execution.base import ExecutionBackend, TaskSpec +try: + from ensemble_launcher import EnsembleLauncher + from ensemble_launcher.config import ( + LauncherConfig, + MPIConfig, + PolicyConfig, + 
SystemConfig, + ) + from ensemble_launcher.helper_functions import get_nodes + from ensemble_launcher.orchestrator import ClusterClient +except ImportError as exc: + raise ImportError( + "EnsembleLauncher is required for the EnsembleLauncherBackend. " + "Install it with: pip install ensemble-launcher" + ) from exc + logger = logging.getLogger(__name__) +def get_local_system_config(): + system_config = SystemConfig( + name="local", + ncpus=os.cpu_count(), + cpus=list(range(os.cpu_count())), + ) + return system_config + + +def get_polaris_system_config(): + system_config = SystemConfig( + name="polaris", + ncpus=32, + cpus=list(range(32)), + ngpus=4, + gpus=list(range(4)), + ) + return system_config + + +def get_aurora_system_config(): + system_config = SystemConfig( + name="aurora", + ncpus=102, + cpus=list(range(1, 52)) + list(range(53, 104)), + ngpus=12, + gpus=list(range(12)), + ) + return system_config + + +def get_launcher_config( + task_executor_name: Union[str, List] = "async_processpool", + child_executor_policy: str = "fixed_leafs_children_policy", + policy_config: Optional[PolicyConfig] = None, + checkpoint_dir=f"{os.getcwd()}/.ckpt_{uuid.uuid4().hex[:6]}", + mpi_flavour: Literal["test", "mpich"] = "test", +): + if policy_config is None: + policy_config = PolicyConfig(nlevels=2, leaf_nodes=len(get_nodes())) + return LauncherConfig( + child_executor_name="async_mpi", + task_executor_name=task_executor_name, + return_stdout=True, + worker_logs=True, + master_logs=True, + children_scheduler_policy=child_executor_policy, + policy_config=policy_config, + cluster=True, + checkpoint_dir=checkpoint_dir, + mpi_config=MPIConfig(flavor=mpi_flavour), + ) + + class EnsembleLauncherBackend(ExecutionBackend): """Execution backend that delegates work to EnsembleLauncher. 
@@ -60,75 +129,42 @@ def __init__(self) -> None: self._client = None self._checkpoint_dir: str | None = None - def initialize(self, system: str = "local", **kwargs: Any) -> None: - try: - from ensemble_launcher import EnsembleLauncher - from ensemble_launcher.config import LauncherConfig, SystemConfig - from ensemble_launcher.orchestrator import ClusterClient - except ImportError as exc: - raise ImportError( - "EnsembleLauncher is required for the EnsembleLauncherBackend. " - "Install it with: pip install ensemble-launcher" - ) from exc - - # -- extract parameters ------------------------------------------------ - comm_name = kwargs.get("comm_name", "async_zmq") - task_executor = kwargs.get("task_executor_name", "async_processpool") - nlevels = kwargs.get("nlevels", 0) - ncpus = kwargs.get("max_workers", os.cpu_count() or 4) - checkpoint_dir = kwargs.get( - "checkpoint_dir", - os.path.join(os.getcwd(), f".el_ckpt_{uuid.uuid4().hex[:8]}"), - ) - nodes = kwargs.get("nodes", [socket.gethostname()]) - startup_delay = kwargs.get("startup_delay", 2.0) - - self._checkpoint_dir = checkpoint_dir - - # -- configure --------------------------------------------------------- - system_config = SystemConfig( - name=system, - ncpus=ncpus, - cpus=list(range(ncpus)), - ) - - launcher_config = LauncherConfig( - task_executor_name=task_executor, - comm_name=comm_name, - nlevels=nlevels, - cluster=True, - checkpoint_dir=checkpoint_dir, - ) + def initialize( + self, + system: str, + system_config: SystemConfig, + launcher_config: LauncherConfig, + startup_delay: float = 1.0, + ) -> None: + os.makedirs(launcher_config.checkpoint_dir, exist_ok=True) # -- start orchestrator ------------------------------------------------ self._el = EnsembleLauncher( ensemble_file={}, system_config=system_config, launcher_config=launcher_config, - Nodes=nodes, ) self._el.start() time.sleep(startup_delay) # -- connect client ---------------------------------------------------- - self._client = 
ClusterClient(checkpoint_dir=checkpoint_dir) + self._client = ClusterClient(checkpoint_dir=launcher_config.checkpoint_dir) self._client.start() self._initialized = True logger.info( "EnsembleLauncherBackend initialized (system='%s', " "comm='%s', executor='%s', nodes=%s)", - system, - comm_name, - task_executor, - nodes, + system_config.name, + launcher_config.comm_name, + launcher_config.task_executor_name, + len(self._el.nodes), ) def submit(self, task: TaskSpec) -> Future: if not self._initialized or self._client is None: raise RuntimeError( - "EnsembleLauncherBackend is not initialized. " - "Call initialize() first." + "EnsembleLauncherBackend is not initialized. Call initialize() first." ) from ensemble_launcher.ensemble import Task as ELTask @@ -145,6 +181,7 @@ def submit(self, task: TaskSpec) -> Future: executable=task.callable, args=task.args or (), kwargs=task.kwargs or {}, + env=task.env, ) return self._client.submit(el_task) @@ -157,7 +194,8 @@ def submit(self, task: TaskSpec) -> Future: task_id=task.task_id, nnodes=task.num_nodes, ppn=task.processes_per_node, - cmd_template=task.command, + executable=task.command, + env=task.env, ) return self._client.submit(el_task) @@ -197,3 +235,14 @@ def shutdown(self) -> None: "EnsembleLauncherBackend partially shut down. " "Call shutdown() again to retry failed teardown." 
) + + +SYSTEM_CONFIG_REGISTRY = { + "local": get_local_system_config(), + "aurora": get_aurora_system_config(), + "polaris": get_polaris_system_config(), +} + +if __name__ == "__main__": + el_backend = EnsembleLauncherBackend() + el_backend.initialize() diff --git a/tests/test_execution.py b/tests/test_execution.py index 5f1617bc..c662547c 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -19,7 +19,7 @@ import pytest -from chemgraph.execution.base import ExecutionBackend, TaskSpec +from chemgraph.execution.base import TaskSpec from chemgraph.execution.local_backend import LocalBackend from chemgraph.execution.utils import ( gather_futures, @@ -28,7 +28,6 @@ write_results_jsonl, ) - # ── TaskSpec tests ────────────────────────────────────────────────────── @@ -199,6 +198,220 @@ def test_shell_task_missing_command(self): backend.shutdown() +# ── EnsembleLauncherBackend tests ────────────────────────────────────────── + + +class TestELBackend: + @classmethod + def setup_class(cls): + project_root = str(Path(__file__).resolve().parent.parent) + existing = os.environ.get("PYTHONPATH", "") + os.environ["PYTHONPATH"] = ( + f"{project_root}:{existing}" if existing else project_root + ) + + def test_python_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="sq", + task_type="python", + callable=_square, + args=(7,), + ) + fut = backend.submit(task) + assert isinstance(fut, Future) + assert fut.result(timeout=10) == 49 + finally: + backend.shutdown() + + def test_python_task_kwargs(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = 
EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="add", + task_type="python", + callable=_add, + kwargs={"a": 3, "b": 5}, + ) + assert backend.submit(task).result(timeout=10) == 8 + finally: + backend.shutdown() + + def test_shell_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="echo", + task_type="shell", + command="echo hello_world", + ) + fut = backend.submit(task) + result = fut.result(timeout=10) + assert result is not None + finally: + backend.shutdown() + + def test_submit_batch(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + tasks = [ + TaskSpec( + task_id=f"sq_{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(5) + ] + futures = backend.submit_batch(tasks) + assert len(futures) == 5 + results = [f.result(timeout=10) for f in futures] + assert results == [0, 1, 4, 9, 16] + finally: + backend.shutdown() + + def test_failing_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="fail", + 
task_type="python", + callable=_failing_fn, + ) + fut = backend.submit(task) + with pytest.raises(Exception, match="intentional test error"): + fut.result(timeout=10) + finally: + backend.shutdown() + + def test_context_manager(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + with EnsembleLauncherBackend() as backend: + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + assert backend.submit(task).result(timeout=10) == 9 + + def test_not_initialized_raises(self): + from chemgraph.execution.ensemble_launcher_backend import ( + EnsembleLauncherBackend, + ) + + backend = EnsembleLauncherBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_python_task_missing_callable(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + finally: + backend.shutdown() + + def test_shell_task_missing_command(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec(task_id="no_cmd", task_type="shell") + 
with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + finally: + backend.shutdown() + + # ── GlobusComputeBackend tests ────────────────────────────────────────── @@ -807,9 +1020,7 @@ def globus_backend(): pytest.skip("chemgraph.execution not available") try: - backend = get_backend( - backend_name="globus_compute", endpoint_id=endpoint_id - ) + backend = get_backend(backend_name="globus_compute", endpoint_id=endpoint_id) except ImportError: pytest.skip("globus-compute-sdk not installed") From 1febc5ac90c5c42583d02065d57becf02d29efda Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Fri, 22 May 2026 12:46:29 -0500 Subject: [PATCH 007/119] Add CGFastMCP backend framework, EL client-only mode, and pickle fix - Add CGFastMCP: FastMCP subclass with integrated execution backend, lazy init, built-in job tools, @tool() and @ensemble_tool() decorators - Refactor EnsembleLauncherBackend with client-only mode (shared orchestrator via checkpoint_dir) and managed mode - Update get_backend() to route client_only vs managed EL initialization - Rewrite mace_mcp_hpc.py to use CGFastMCP decorators - Clean up parsl_tools.py: remove dead code, use stdlib logging - Fix __main__ pickle issue via _fix_module_for_pickle + sys.modules alias - Add client-only mode demo cell to notebook 3 Co-Authored-By: Claude Opus 4.6 --- notebooks/3_Demo_using_MCP.ipynb | 303 ++++++++++------ src/chemgraph/execution/config.py | 20 +- .../execution/ensemble_launcher_backend.py | 164 +++++---- src/chemgraph/mcp/cg_fastmcp.py | 339 ++++++++++++++++++ src/chemgraph/mcp/mace_mcp_hpc.py | 261 ++------------ src/chemgraph/tools/parsl_tools.py | 24 +- 6 files changed, 694 insertions(+), 417 deletions(-) create mode 100644 src/chemgraph/mcp/cg_fastmcp.py diff --git a/notebooks/3_Demo_using_MCP.ipynb b/notebooks/3_Demo_using_MCP.ipynb index ce37b46d..caf11cb0 100644 --- a/notebooks/3_Demo_using_MCP.ipynb +++ b/notebooks/3_Demo_using_MCP.ipynb @@ -2,190 +2,269 @@ "cells": [ { 
"cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "3b97dfba-13c9-49a4-bdce-efd5900dcafa", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/tpham2/work/projects/ChemGraph/env/chemgraph_env/lib/python3.10/site-packages/google/api_core/_python_version_support.py:266: FutureWarning: You are using a Python version (3.10.19) which Google will stop supporting in new releases of google.api_core once it reaches its end of life (2026-10-04). Please upgrade to the latest Python version, or at least Python 3.11, to continue receiving updates for google.api_core past that date.\n", - " warnings.warn(message, FutureWarning)\n", - "WARNING:root:fairchem is not installed. .\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "2026-01-22 11:50:08,686 - chemgraph.models.openai - INFO - OpenAI API key not found in environment variables.\n" + "Done creating client\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "INFO:chemgraph.models.openai:OpenAI API key not found in environment variables.\n" - ] - }, - { - "name": "stdin", - "output_type": "stream", - "text": [ - "Please enter your OpenAI API key: ········\n" + "2026-05-22 12:34:00,370 - chemgraph.graphs.single_agent - INFO - Constructing single agent graph\n", + "2026-05-22 12:34:00,372 - chemgraph.graphs.single_agent - INFO - Graph construction completed\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2026-01-22 11:50:10,594 - chemgraph.models.openai - INFO - Loading OpenAI model: gpt-4o-mini\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:chemgraph.models.openai:Loading OpenAI model: gpt-4o-mini\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2026-01-22 11:50:10,710 - chemgraph.models.openai - INFO - Requested model: gpt-4o-mini\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - 
"INFO:chemgraph.models.openai:Requested model: gpt-4o-mini\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2026-01-22 11:50:10,711 - chemgraph.models.openai - INFO - OpenAI model loaded successfully\n" + "Done getting tools\n", + "================================\u001b[1m Human Message \u001b[0m=================================\n", + "\n", + "Run a mace calculations with the same file, use energy for driver and small model. a cif file are located at /Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "Tool Calls:\n", + " run_mace_single (chatcmpl-tool-a42c48d32a55e54d)\n", + " Call ID: chatcmpl-tool-a42c48d32a55e54d\n", + " Args:\n", + " params: {'input_structure_file': '/Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif', 'driver': 'energy', 'model': 'small'}\n", + "=================================\u001b[1m Tool Message \u001b[0m=================================\n", + "Name: run_mace_single\n", + "\n", + "{\n", + " \"status\": \"success\",\n", + " \"message\": \"Simulation completed. 
Results saved to /Users/hari/projects/ChemGraph/notebooks/output.json\",\n", + " \"single_point_energy\": -295.75144320599975,\n", + " \"unit\": \"eV\"\n", + "}\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "\n", + "The MACE single‑point energy calculation completed successfully.\n", + "\n", + "**Result**\n", + "- **Energy:** -295.75144320599975 eV \n", + "- **Output file:** `/Users/hari/projects/ChemGraph/notebooks/output.json`\n", + "\n", + "If you need any other properties (e.g., forces, charge distribution) or would like to run additional calculations (geometry optimization, vibrational analysis, etc.), just let me know!\n", + "Done\n" ] - }, + } + ], + "source": [ + "import subprocess, time, os\n", + "from langchain_mcp_adapters.client import MultiServerMCPClient\n", + "from chemgraph.agent.llm_agent import ChemGraph\n", + "\n", + "prompt_single = \"Run a mace calculations with the same file, use energy for driver and small model. a cif file are located at /Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif\"\n", + "\n", + "os.environ[\"ALCF_ACCESS_TOKEN\"]=\" None: super().__init__() - self._el = None - self._client = None - self._checkpoint_dir: str | None = None + self._orchestrator: Optional[EnsembleLauncher] = None + self._client: Optional[ClusterClient] = None def initialize( self, - system: str, - system_config: SystemConfig, - launcher_config: LauncherConfig, - startup_delay: float = 1.0, + system: str = "local", + *, + client_only: bool = False, + checkpoint_dir: Optional[str] = None, + node_id: str = "global", + system_config: Optional[SystemConfig] = None, + launcher_config: Optional[LauncherConfig] = None, + startup_delay: float = 10.0, + **kwargs, ) -> None: + """Prepare the backend for accepting work. + + Parameters + ---------- + client_only : bool + When ``True``, connect to a running orchestrator via + *checkpoint_dir* — no orchestrator is started. 
+ checkpoint_dir : str + Path to the orchestrator's checkpoint directory. Required + when *client_only* is ``True``. + node_id : str + Orchestrator node to connect to (default ``"global"``). + system_config, launcher_config + Required for **managed** mode (``client_only=False``). + The backend starts its own orchestrator with these. + startup_delay : float + Seconds to wait for the orchestrator to become ready + (managed mode only). + """ + if client_only: + # -- client-only mode ---------------------------------------------- + if checkpoint_dir is None: + raise ValueError( + "client_only=True requires a checkpoint_dir pointing " + "to a running orchestrator." + ) + self._client = ClusterClient( + checkpoint_dir=checkpoint_dir, node_id=node_id + ) + self._client.start() + self._initialized = True + logger.info( + "EnsembleLauncherBackend initialized in client-only mode " + "(checkpoint_dir='%s', node_id='%s')", + checkpoint_dir, + node_id, + ) + else: + # -- managed mode: start orchestrator first ------------------------ + if system_config is None or launcher_config is None: + raise ValueError( + "Managed mode requires system_config and launcher_config " + "(or set client_only=True with a checkpoint_dir)." 
+ ) + os.makedirs(launcher_config.checkpoint_dir, exist_ok=True) + self._orchestrator = EnsembleLauncher( + ensemble_file={}, + system_config=system_config, + launcher_config=launcher_config, + ) + self._orchestrator.start() + time.sleep(startup_delay) - os.makedirs(launcher_config.checkpoint_dir, exist_ok=True) - # -- start orchestrator ------------------------------------------------ - self._el = EnsembleLauncher( - ensemble_file={}, - system_config=system_config, - launcher_config=launcher_config, - ) - self._el.start() - time.sleep(startup_delay) - - # -- connect client ---------------------------------------------------- - self._client = ClusterClient(checkpoint_dir=launcher_config.checkpoint_dir) - self._client.start() - - self._initialized = True - logger.info( - "EnsembleLauncherBackend initialized (system='%s', " - "comm='%s', executor='%s', nodes=%s)", - system_config.name, - launcher_config.comm_name, - launcher_config.task_executor_name, - len(self._el.nodes), - ) + self._client = ClusterClient( + checkpoint_dir=launcher_config.checkpoint_dir, + node_id=node_id, + ) + self._client.start() + self._initialized = True + logger.info( + "EnsembleLauncherBackend initialized in managed mode " + "(system='%s', comm='%s', executor='%s', nodes=%s)", + system_config.name, + launcher_config.comm_name, + launcher_config.task_executor_name, + len(self._orchestrator.nodes), + ) def submit(self, task: TaskSpec) -> Future: if not self._initialized or self._client is None: @@ -217,18 +252,19 @@ def shutdown(self) -> None: "Error tearing down EnsembleLauncher client.", exc_info=True ) - el_ok = True - if self._el is not None: + orchestrator_ok = True + if self._orchestrator is not None: try: - self._el.stop() - self._el = None + self._orchestrator.stop() + self._orchestrator = None except Exception: - el_ok = False + orchestrator_ok = False logger.warning( - "Error stopping EnsembleLauncher orchestrator.", exc_info=True + "Error stopping EnsembleLauncher orchestrator.", + 
exc_info=True, ) - if client_ok and el_ok: + if client_ok and orchestrator_ok: logger.info("EnsembleLauncherBackend shut down.") else: logger.warning( diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py new file mode 100644 index 00000000..e25b23f9 --- /dev/null +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -0,0 +1,339 @@ +"""Backend-aware FastMCP subclass for ChemGraph. + +:class:`CGFastMCP` extends :class:`FastMCP` with an execution backend. +Tools registered via :meth:`tool` are automatically submitted to the +backend as :class:`~chemgraph.execution.base.TaskSpec` instances — +the tool author writes a plain function and the framework handles +submission, future resolution, and async job tracking. + +Tools that do **not** need the backend (e.g. JSON loaders, plotting +utilities) should be registered with :meth:`add_tool` (inherited from +FastMCP) which bypasses the backend wrapper entirely. +""" + +import asyncio +import functools +import inspect +import logging +from typing import Any, Callable, Dict, Optional + +from mcp.server.fastmcp import FastMCP +from mcp.types import ToolAnnotations + +logger = logging.getLogger(__name__) + + +class CGFastMCP(FastMCP): + """FastMCP with an integrated execution backend. + + Parameters + ---------- + **kwargs + Forwarded to :class:`FastMCP` (``name``, ``instructions``, etc.). + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._backend = None + self._tracker = None + self._backend_kwargs: Optional[dict[str, Any]] = None + + # ── Backend lifecycle ─────────────────────────────────────────────── + + def init_backend(self, **kwargs: Any) -> None: + """Register backend configuration for lazy initialisation. + + The backend is not created until the first tool invocation, + so the MCP server can start accepting connections immediately. + All keyword arguments are forwarded to + :func:`~chemgraph.execution.config.get_backend`. 
+ """ + self._backend_kwargs = kwargs + self._register_job_tools() + logger.info("CGFastMCP backend configured (lazy init).") + + def _ensure_backend(self) -> None: + """Create the backend on first use.""" + if self._backend is not None: + return + if self._backend_kwargs is None: + raise RuntimeError( + "Backend not configured. Call init_backend() first." + ) + from chemgraph.execution import JobTracker, get_backend + + self._backend = get_backend(**self._backend_kwargs) + self._tracker = JobTracker() + logger.info( + "CGFastMCP backend initialised: %s", type(self._backend).__name__ + ) + + def shutdown_backend(self) -> None: + """Shut down the execution backend and release resources.""" + if self._backend is not None: + try: + self._backend.shutdown() + except Exception: + logger.warning("Error during backend shutdown.", exc_info=True) + self._backend = None + self._tracker = None + self._backend_kwargs = None + logger.info("CGFastMCP backend shut down.") + + # ── Job tracking tools ───────────────────────────────────────────── + + def _register_job_tools(self) -> None: + """Register job-management tools (status, results, cancel).""" + + @self.add_tool + def check_job_status(batch_id: str) -> dict: + """Check the status of a submitted job batch.""" + self._ensure_backend() + return self._tracker.get_status(batch_id) + + @self.add_tool + def get_job_results( + batch_id: str, include_partial: bool = False + ) -> dict: + """Retrieve results from a completed job batch.""" + self._ensure_backend() + return self._tracker.get_results( + batch_id, include_partial=include_partial + ) + + @self.add_tool + def list_jobs() -> list[dict]: + """List all tracked job batches.""" + self._ensure_backend() + batches = self._tracker.list_batches() + if not batches: + return [{"message": "No job batches tracked."}] + return batches + + @self.add_tool + def cancel_job(batch_id: str) -> dict: + """Cancel pending tasks in a job batch.""" + self._ensure_backend() + return 
self._tracker.cancel_batch(batch_id) + + @self.add_tool + def check_endpoint_status() -> dict: + """Check whether the remote compute endpoint is reachable.""" + self._ensure_backend() + if hasattr(self._backend, "check_endpoint_status"): + return self._backend.check_endpoint_status() + return {"status": "not_applicable", + "message": "This backend does not support endpoint status checks."} + + # ── Internal helpers ────────────────────────────────────────────── + + @staticmethod + def _fix_module_for_pickle(fn: Callable) -> None: + """Ensure *fn* is picklable when the MCP server runs as ``__main__``.""" + if fn.__module__ == "__main__": + import sys + + spec = getattr(sys.modules.get("__main__"), "__spec__", None) + if spec and spec.name: + fn.__module__ = spec.name + if spec.name not in sys.modules: + sys.modules[spec.name] = sys.modules["__main__"] + + # ── Tool registration ─────────────────────────────────────────────── + + def tool( + self, + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + structured_output: Optional[bool] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a tool that runs on the execution backend. + + Same calling convention as :meth:`FastMCP.tool` — **parens are + required** (``@mcp.tool()``, not ``@mcp.tool``). + + The additional parameters (``num_nodes``, ``processes_per_node``, + ``gpus_per_task``, ``env``, ``working_dir``) are forwarded to the + :class:`~chemgraph.execution.base.TaskSpec` that wraps the + decorated function when it is invoked. + + Parameters + ---------- + name, title, description, annotations, structured_output + Passed through to :meth:`FastMCP.add_tool`. + num_nodes : int + Number of compute nodes (default ``1``). 
+ processes_per_node : int + Processes per node (default ``1``). + gpus_per_task : int + GPUs per task (default ``0``). + env : dict, optional + Extra environment variables for the worker. + working_dir : str, optional + Working directory for the task. + """ + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if title is not None: + fastmcp_kwargs["title"] = title + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + if structured_output is not None: + fastmcp_kwargs["structured_output"] = structured_output + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + def decorator(fn: Callable) -> Callable: + wrapper = self._make_backend_wrapper(fn, task_spec_kwargs) + self.add_tool(wrapper, **fastmcp_kwargs) + return fn + + return decorator + + # ── Ensemble tool registration ───────────────────────────────────── + + def ensemble_tool( + self, + name: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a fan-out tool that submits ``list[params]`` to the backend. + + Decorates ``fn(params: Schema) -> result``. The MCP tool schema + becomes ``list[Schema]`` — the LLM provides a list of jobs and + the framework submits each as a + :class:`~chemgraph.execution.base.TaskSpec`, then gathers results + via :func:`~chemgraph.execution.utils.submit_or_gather`. 
+ + Parameters + ---------- + name, description, annotations + Passed through to :meth:`FastMCP.add_tool`. + num_nodes, processes_per_node, gpus_per_task, env, working_dir + Forwarded to :class:`~chemgraph.execution.base.TaskSpec`. + """ + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + + def decorator(fn: Callable) -> Callable: + self._fix_module_for_pickle(fn) + sig = inspect.signature(fn) + param = list(sig.parameters.values())[0] + param_type = param.annotation + + async def wrapper(params): + self._ensure_backend() + pending = [] + for i, p in enumerate(params): + task = TaskSpec( + task_id=f"{fn.__name__}_{i}", + task_type="python", + callable=fn, + kwargs={param.name: p}, + **task_spec_kwargs, + ) + fut = self._backend.submit(task) + pending.append(({"index": i}, fut)) + + return await submit_or_gather( + self._backend, + pending, + self._tracker, + name or fn.__name__, + ) + + wrapper.__name__ = name or fn.__name__ + wrapper.__doc__ = fn.__doc__ + wrapper.__module__ = fn.__module__ + wrapper.__qualname__ = fn.__qualname__ + + new_param = inspect.Parameter( + "params", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=list[param_type], + ) + wrapper.__signature__ = inspect.Signature( + parameters=[new_param] + ) + + self.add_tool(wrapper, **fastmcp_kwargs) + return fn + + return decorator + + # ── Internal ──────────────────────────────────────────────────────── + + def _make_backend_wrapper( + self, fn: 
Callable, task_spec_kwargs: dict[str, Any] + ) -> Callable: + """Build an async wrapper that submits *fn* to the backend.""" + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather + + self._fix_module_for_pickle(fn) + + @functools.wraps(fn) + async def wrapper(**kwargs: Any) -> Any: + self._ensure_backend() + task = TaskSpec( + task_id=fn.__name__, + task_type="python", + callable=fn, + kwargs=kwargs, + **task_spec_kwargs, + ) + fut = self._backend.submit(task) + + if self._backend.is_async_remote: + return await submit_or_gather( + self._backend, + [({"task_id": fn.__name__}, fut)], + self._tracker, + fn.__name__, + ) + + return await asyncio.wrap_future(fut) + + return wrapper diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index a664a1e7..70496186 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -1,46 +1,18 @@ """Backend-agnostic MACE MCP server. -Replaces ``mace_mcp_parsl.py`` by using the :mod:`chemgraph.execution` -abstraction layer. The execution backend (Parsl, EnsembleLauncher, -local) is selected at startup via ``config.toml`` or the -``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP` so that tool +functions are plain computation — the framework handles backend +submission, future resolution, and async job tracking. -Key improvements over the original: -- No hardcoded Polaris config or user-specific conda paths. -- Ensemble tool is now async (non-blocking event loop). -- Uses shared utilities for structure resolution and result gathering. +Nothing is initialised at import time so that worker subprocesses +(e.g. EnsembleLauncher) can safely re-import this module. 
""" -import asyncio -import json -import logging -import os -from pathlib import Path +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.schemas.mace_parsl_schema import mace_input_schema +from chemgraph.tools.parsl_tools import extract_output_json, run_mace_core -from mcp.server.fastmcp import FastMCP - -from chemgraph.execution import TaskSpec, get_backend -from chemgraph.execution.job_tracker import JobTracker -from chemgraph.execution.utils import ( - make_per_structure_output, - resolve_structure_files, - submit_or_gather, -) -from chemgraph.mcp.job_tools import register_job_tools -from chemgraph.mcp.server_utils import run_mcp_server -from chemgraph.tools.parsl_tools import ( - mace_input_schema, - mace_input_schema_ensemble, -) - -logger = logging.getLogger(__name__) - -# ── Initialise execution backend ──────────────────────────────────────── -backend = get_backend() -tracker = JobTracker() - -# ── MCP server ────────────────────────────────────────────────────────── -mcp = FastMCP( +mcp = CGFastMCP( name="ChemGraph MACE Tools", instructions=""" You expose tools for running MACE simulations and reading their results. @@ -66,216 +38,45 @@ check_job_status to poll for progress before calling get_job_results. """, ) -register_job_tools(mcp, tracker, backend) - - -def _run_mace_single(job: dict) -> dict: - """Execute a single MACE simulation (runs on the worker). - - When the ``job`` dict contains an ``inline_structure`` key (with - ``numbers``, ``positions``, and optional ``cell``/``pbc``), the - structure is materialised as a temporary XYZ file on the worker - filesystem before running MACE. This allows local-agent / - remote-worker workflows where the original file only exists on the - submitting machine. 
- """ - import os - import tempfile - - from chemgraph.tools.parsl_tools import mace_input_schema, run_mace_core - - inline = job.pop("inline_structure", None) - if inline is not None: - from ase import Atoms - from ase.io import write as ase_write - - atoms = Atoms( - numbers=inline["numbers"], - positions=inline["positions"], - cell=inline.get("cell"), - pbc=inline.get("pbc"), - ) - tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") - xyz_path = os.path.join(tmpdir, "structure.xyz") - ase_write(xyz_path, atoms) - job["input_structure_file"] = xyz_path - - if not os.path.isabs(job.get("output_result_file", "")): - job["output_result_file"] = os.path.join( - tmpdir, job.get("output_result_file", "output.json") - ) - - params = mace_input_schema(**job) if isinstance(job, dict) else job - result = run_mace_core(params) - - # Embed full output JSON when running with inline structure so the - # caller does not need to read a file on the remote filesystem. - if inline is not None: - out_file = job.get("output_result_file", "") - if os.path.isfile(out_file): - import json as _json - - with open(out_file, "r") as fh: - result["full_output"] = _json.load(fh) - - return result @mcp.tool( name="run_mace_single", description="Run a single MACE calculation", ) -async def run_mace_single(params: mace_input_schema): - """Run a single MACE calculation using the configured execution backend.""" - job = params.model_dump() - - # Read the local structure file and embed it so the job is - # self-contained and can run on any worker (local or remote). 
- input_file = job.get("input_structure_file") - if input_file and os.path.isfile(input_file): - from ase.io import read as ase_read - - from chemgraph.tools.ase_core import atoms_to_atomsdata - - atoms = ase_read(input_file) - atomsdata = atoms_to_atomsdata(atoms) - job["inline_structure"] = atomsdata.model_dump() - - task = TaskSpec( - task_id="mace_single", - task_type="python", - callable=_run_mace_single, - kwargs={"job": job}, - ) - fut = backend.submit(task) +def run_mace_single(params: mace_input_schema): + """Run a single MACE calculation on the execution backend.""" + import sys - if backend.is_async_remote: - task_meta = {"task_id": "mace_single"} - return await submit_or_gather( - backend, [(task_meta, fut)], tracker, "run_mace_single" - ) + old_stdout = sys.stdout + sys.stdout = sys.stderr + try: + return run_mace_core(params) + finally: + sys.stdout = old_stdout - return await asyncio.wrap_future(fut) - -def _mace_post_fn(meta: dict, result) -> dict: - """Post-process a completed MACE task.""" - status = result.get("status", "unknown") if isinstance(result, dict) else "success" - energy = result.get("single_point_energy") if isinstance(result, dict) else None - return { - "structure": meta["structure"], - "output_result_file": meta["output_result_file"], - "status": status, - "single_point_energy": energy, - "raw_result": result, - } - - -@mcp.tool( +@mcp.ensemble_tool( name="run_mace_ensemble", - description="Run an ensemble of MACE calculations", + description="Run an ensemble of MACE calculations for multiple inputs.", ) -async def run_mace_ensemble(params: mace_input_schema_ensemble): - """Run an ensemble of MACE calculations over all structure files in a - directory using the configured execution backend. - - Parameters - ---------- - params : mace_input_schema_ensemble - Input parameters for the ensemble of MACE calculations. - - Returns - ------- - dict - Summary of all jobs with minimal per-job results. 
- """ - structure_files, _output_dir = resolve_structure_files( - params.input_structure_directory, - ) - - # Base output file name used as a pattern for per-structure outputs - base_output = Path(params.output_result_file) - - pending_tasks = [] - for struct_path in structure_files: - per_struct_output = make_per_structure_output(struct_path, base_output) - - job = { - "input_structure_file": str(struct_path), - "output_result_file": str(per_struct_output), - "driver": params.driver, - "model": params.model, - "device": params.device, - "temperature": params.temperature, - "pressure": params.pressure, - "fmax": params.fmax, - "steps": params.steps, - "optimizer": params.optimizer, - } - - # Embed structure data so the job works on remote workers that - # cannot access the local filesystem. - if struct_path.is_file(): - from ase.io import read as ase_read +def _run_mace_worker(params: mace_input_schema): + return run_mace_core(params) - from chemgraph.tools.ase_core import atoms_to_atomsdata - atoms = ase_read(str(struct_path)) - atomsdata = atoms_to_atomsdata(atoms) - job["inline_structure"] = atomsdata.model_dump() - - task = TaskSpec( - task_id=f"mace_{struct_path.stem}", - task_type="python", - callable=_run_mace_single, - kwargs={"job": job}, - ) - fut = backend.submit(task) - - task_meta = { - "structure": struct_path.name, - "output_result_file": str(per_struct_output), - } - pending_tasks.append((task_meta, fut)) - - result = await submit_or_gather( - backend, pending_tasks, tracker, "run_mace_ensemble", - post_fn=_mace_post_fn, - ) - - if result["status"] == "completed": - return { - "status": "success", - "n_structures": len(structure_files), - "results": result["results"], - } - - # Async remote: return submission confirmation - result["n_structures"] = len(structure_files) - return result - - -@mcp.tool( +mcp.add_tool( + extract_output_json, name="extract_output_json", description="Load output from a JSON file.", ) -def extract_output_json(json_file: 
str) -> dict: - """Load simulation results from a JSON file produced by run_ase. - Parameters - ---------- - json_file : str - Path to the JSON file containing ASE simulation results. - Returns - ------- - dict - Parsed results from the JSON file. - """ - with open(json_file, "r") as f: - data = json.load(f) - return data +if __name__ == "__main__": + from chemgraph.mcp.server_utils import run_mcp_server + mcp.init_backend() -if __name__ == "__main__": - run_mcp_server(mcp, default_port=9004) + try: + run_mcp_server(mcp, default_port=9004) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/tools/parsl_tools.py b/src/chemgraph/tools/parsl_tools.py index 9c5f887a..2657af1a 100644 --- a/src/chemgraph/tools/parsl_tools.py +++ b/src/chemgraph/tools/parsl_tools.py @@ -6,23 +6,25 @@ from __future__ import annotations -from chemgraph.tools.ase_core import run_ase_core +import logging + from chemgraph.schemas.ase_input import ASEInputSchema from chemgraph.schemas.mace_parsl_schema import ( mace_input_schema, - mace_input_schema_ensemble, mace_output_schema, ) +from chemgraph.tools.ase_core import run_ase_core # Re-export schemas so existing ``from chemgraph.tools.parsl_tools import …`` # statements continue to work. __all__ = [ "mace_input_schema", - "mace_input_schema_ensemble", "mace_output_schema", "run_mace_core", + "extract_output_json", ] +logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Core execution — delegates to the unified implementation @@ -64,5 +66,17 @@ def run_mace_core(params: mace_input_schema) -> dict: dict Simulation result payload. 
""" - ase_params = _mace_input_to_ase_input(params) - return run_ase_core(ase_params) + try: + ase_params = _mace_input_to_ase_input(params) + return run_ase_core(ase_params) + except Exception as e: + print(f"Running ase failed with error:{e}") + return None + + +def extract_output_json(json_file: str) -> dict: + """Load simulation results from a JSON file produced by run_ase.""" + import json + + with open(json_file, "r") as f: + return json.load(f) From d60a1e1b3c51b089f56788d3c18cadc2b4bd4220 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:26:51 -0500 Subject: [PATCH 008/119] Fix PR #127 blockers: silent failure, decorator IndexError, hard EL import - parsl_tools.run_mace_core: stop swallowing exceptions and returning None. run_ase_core already returns a structured failure dict on simulation errors, and programmer errors should propagate. - cg_fastmcp.ensemble_tool: raise TypeError with a clear message when the decorated function does not have exactly one parameter, instead of crashing with IndexError at decoration time. - ensemble_launcher_backend: soft-import ensemble_launcher and defer the failure to construction / call time. SYSTEM_CONFIG_REGISTRY is now a lazy view backed by builder functions so the module loads cleanly without EL installed, restoring the deferred-error behaviour callers of chemgraph.execution.config expected. 
--- .../execution/ensemble_launcher_backend.py | 70 ++++++++++++++----- src/chemgraph/mcp/cg_fastmcp.py | 9 ++- src/chemgraph/tools/parsl_tools.py | 8 +-- 3 files changed, 63 insertions(+), 24 deletions(-) diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index e3863d38..56210d97 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -30,16 +30,31 @@ ) from ensemble_launcher.helper_functions import get_nodes from ensemble_launcher.orchestrator import ClusterClient -except ImportError as exc: - raise ImportError( - "EnsembleLauncher is required for the EnsembleLauncherBackend. " - "Install it with: pip install ensemble-launcher" - ) from exc + + _ENSEMBLE_LAUNCHER_AVAILABLE = True +except ImportError: + EnsembleLauncher = None + LauncherConfig = None + MPIConfig = None + PolicyConfig = None + SystemConfig = None + get_nodes = None + ClusterClient = None + _ENSEMBLE_LAUNCHER_AVAILABLE = False logger = logging.getLogger(__name__) +def _require_ensemble_launcher() -> None: + if not _ENSEMBLE_LAUNCHER_AVAILABLE: + raise ImportError( + "EnsembleLauncher is required for the EnsembleLauncherBackend. 
" + "Install it with: pip install ensemble-launcher" + ) + + def get_local_system_config(): + _require_ensemble_launcher() system_config = SystemConfig( name="local", ncpus=os.cpu_count(), @@ -49,6 +64,7 @@ def get_local_system_config(): def get_polaris_system_config(): + _require_ensemble_launcher() system_config = SystemConfig( name="polaris", ncpus=32, @@ -60,6 +76,7 @@ def get_polaris_system_config(): def get_aurora_system_config(): + _require_ensemble_launcher() system_config = SystemConfig( name="aurora", ncpus=102, @@ -73,10 +90,11 @@ def get_aurora_system_config(): def get_launcher_config( task_executor_name: Union[str, List] = "async_processpool", child_executor_policy: str = "fixed_leafs_children_policy", - policy_config: Optional[PolicyConfig] = None, + policy_config=None, checkpoint_dir=f"{os.getcwd()}/.ckpt_{uuid.uuid4().hex[:6]}", mpi_flavour: Literal["test", "mpich"] = "test", ): + _require_ensemble_launcher() if policy_config is None: policy_config = PolicyConfig(nlevels=2, leaf_nodes=len(get_nodes())) return LauncherConfig( @@ -112,9 +130,10 @@ class EnsembleLauncherBackend(ExecutionBackend): """ def __init__(self) -> None: + _require_ensemble_launcher() super().__init__() - self._orchestrator: Optional[EnsembleLauncher] = None - self._client: Optional[ClusterClient] = None + self._orchestrator = None + self._client = None def initialize( self, @@ -123,8 +142,8 @@ def initialize( client_only: bool = False, checkpoint_dir: Optional[str] = None, node_id: str = "global", - system_config: Optional[SystemConfig] = None, - launcher_config: Optional[LauncherConfig] = None, + system_config=None, + launcher_config=None, startup_delay: float = 10.0, **kwargs, ) -> None: @@ -273,12 +292,29 @@ def shutdown(self) -> None: ) -SYSTEM_CONFIG_REGISTRY = { - "local": get_local_system_config(), - "aurora": get_aurora_system_config(), - "polaris": get_polaris_system_config(), +_SYSTEM_CONFIG_BUILDERS = { + "local": get_local_system_config, + "aurora": 
get_aurora_system_config, + "polaris": get_polaris_system_config, } -if __name__ == "__main__": - el_backend = EnsembleLauncherBackend() - el_backend.initialize() + +class _LazyRegistry: + """Built-on-first-access mapping of system name -> SystemConfig. + + Avoids importing ``ensemble_launcher`` at module load time. + """ + + def __contains__(self, key: str) -> bool: + return key in _SYSTEM_CONFIG_BUILDERS + + def __getitem__(self, key: str): + if key not in _SYSTEM_CONFIG_BUILDERS: + raise KeyError(key) + return _SYSTEM_CONFIG_BUILDERS[key]() + + def keys(self): + return _SYSTEM_CONFIG_BUILDERS.keys() + + +SYSTEM_CONFIG_REGISTRY = _LazyRegistry() diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index e25b23f9..a343fa3a 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -260,7 +260,14 @@ def ensemble_tool( def decorator(fn: Callable) -> Callable: self._fix_module_for_pickle(fn) sig = inspect.signature(fn) - param = list(sig.parameters.values())[0] + params = list(sig.parameters.values()) + if len(params) != 1: + raise TypeError( + f"@ensemble_tool expects a function with exactly one " + f"parameter (the per-item schema), got {len(params)} " + f"on {fn.__qualname__}." + ) + param = params[0] param_type = param.annotation async def wrapper(params): diff --git a/src/chemgraph/tools/parsl_tools.py b/src/chemgraph/tools/parsl_tools.py index 2657af1a..1d83ba61 100644 --- a/src/chemgraph/tools/parsl_tools.py +++ b/src/chemgraph/tools/parsl_tools.py @@ -66,12 +66,8 @@ def run_mace_core(params: mace_input_schema) -> dict: dict Simulation result payload. 
""" - try: - ase_params = _mace_input_to_ase_input(params) - return run_ase_core(ase_params) - except Exception as e: - print(f"Running ase failed with error:{e}") - return None + ase_params = _mace_input_to_ase_input(params) + return run_ase_core(ase_params) def extract_output_json(json_file: str) -> dict: From 4387d13c378899d455646b446516fb69d6bb5552 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:33:52 -0500 Subject: [PATCH 009/119] Add JobTracker persistence and Globus task UUID round-trip - persist_file parameter: when set, batch metadata and Globus Compute task UUIDs are written to JSON after registration and after results are cached, and loaded on init. Allows MCP servers to recover job state across restarts. - TrackedTask.globus_task_id and TrackedTask.future are both optional; loaded-from-disk batches have no in-memory Future and are queried via the Globus Compute Client directly in get_status. - Lazy Globus Compute Client with a separate gc_lock for thread safety. - _wait_for_globus_task_ids polls each ComputeFuture briefly after submission to capture the Globus task_id assigned asynchronously by the Executor background thread. - cancel_batch / cleanup_old_batches handle the no-future case. --- src/chemgraph/execution/job_tracker.py | 237 +++++++++++++++++++++++-- 1 file changed, 226 insertions(+), 11 deletions(-) diff --git a/src/chemgraph/execution/job_tracker.py b/src/chemgraph/execution/job_tracker.py index 87b473c0..23f6c837 100644 --- a/src/chemgraph/execution/job_tracker.py +++ b/src/chemgraph/execution/job_tracker.py @@ -7,16 +7,23 @@ Each MCP server process creates its own ``JobTracker`` instance (mirroring the existing ``backend = get_backend()`` pattern). + +When a *persist_file* is provided, batch metadata and Globus Compute +task UUIDs are written to a JSON file so that a future session can +reload them and query Globus Compute directly for results. 
""" from __future__ import annotations +import json import logging import threading +import time import uuid from concurrent.futures import Future from dataclasses import dataclass, field from datetime import datetime, timezone +from pathlib import Path from typing import Any, Callable, Optional logger = logging.getLogger(__name__) @@ -28,7 +35,8 @@ class TrackedTask: task_id: str meta: dict - future: Future + future: Optional[Future] = None + globus_task_id: Optional[str] = None result: Optional[dict] = None @@ -47,11 +55,112 @@ class JobTracker: """Track submitted job batches and their futures. Thread-safe: all public methods acquire an internal lock. + + Parameters + ---------- + persist_file : Path or str, optional + Path to a JSON file for persisting batch metadata across + sessions. When set, batches are saved after registration and + after results are cached. On init, existing batches are loaded. """ - def __init__(self) -> None: + def __init__(self, persist_file: Optional[Path | str] = None) -> None: self._batches: dict[str, TrackedBatch] = {} self._lock = threading.Lock() + self._gc_lock = threading.Lock() + self._persist_file = Path(persist_file) if persist_file else None + self._gc_client = None # lazily initialised Globus Compute Client + + if self._persist_file is not None: + self._load() + + # ── Globus Compute client (lazy) ────────────────────────────────── + + def _get_gc_client(self): + """Return a Globus Compute ``Client`` (created once, reused).""" + if self._gc_client is not None: + return self._gc_client + with self._gc_lock: + if self._gc_client is None: + try: + from globus_compute_sdk import Client + + self._gc_client = Client() + except Exception: + logger.warning( + "Could not create Globus Compute Client", + exc_info=True, + ) + return None + return self._gc_client + + # ── persistence ─────────────────────────────────────────────────── + + def _save(self) -> None: + """Write current batch metadata to *persist_file*.""" + if 
self._persist_file is None: + return + + data: dict[str, Any] = {} + with self._lock: + for bid, batch in self._batches.items(): + data[bid] = { + "tool_name": batch.tool_name, + "submitted_at": batch.submitted_at.isoformat(), + "tasks": [ + { + "task_id": t.task_id, + "meta": t.meta, + "globus_task_id": t.globus_task_id, + "result": t.result, + } + for t in batch.tasks + ], + } + + self._persist_file.parent.mkdir(parents=True, exist_ok=True) + tmp = self._persist_file.with_suffix(".tmp") + with open(tmp, "w") as f: + json.dump(data, f, indent=2) + tmp.replace(self._persist_file) + + def _load(self) -> None: + """Load batch metadata from *persist_file* (if it exists).""" + if self._persist_file is None or not self._persist_file.is_file(): + return + + try: + with open(self._persist_file) as f: + data = json.load(f) + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Could not load job tracker state: %s", exc) + return + + with self._lock: + for bid, info in data.items(): + if bid in self._batches: + continue # don't overwrite live batches + + tasks = [ + TrackedTask( + task_id=t["task_id"], + meta=t.get("meta", {}), + future=None, + globus_task_id=t.get("globus_task_id"), + result=t.get("result"), + ) + for t in info.get("tasks", []) + ] + self._batches[bid] = TrackedBatch( + batch_id=bid, + tool_name=info["tool_name"], + submitted_at=datetime.fromisoformat(info["submitted_at"]), + tasks=tasks, + ) + + logger.info( + "Loaded %d batches from %s", len(data), self._persist_file + ) # ── registration ─────────────────────────────────────────────────── @@ -103,13 +212,61 @@ def register_batch( tool_name, len(tracked), ) + + # Wait briefly for the Executor background thread to set task_ids + # on the ComputeFutures. Typically takes ~1-2 s; we cap at 3 s + # so the MCP tool response isn't delayed excessively. 
+ self._wait_for_globus_task_ids(tracked, timeout=3.0) + self._save() return batch_id + def _wait_for_globus_task_ids( + self, tasks: list[TrackedTask], timeout: float = 3.0 + ) -> None: + """Wait up to *timeout* seconds for Globus ``task_id`` to appear + on each ComputeFuture, then store them for persistence.""" + deadline = time.monotonic() + timeout + pending = [t for t in tasks if t.future is not None and t.globus_task_id is None] + + while pending and time.monotonic() < deadline: + still_pending = [] + for t in pending: + gc_id = getattr(t.future, "task_id", None) + if gc_id is not None: + t.globus_task_id = str(gc_id) + else: + still_pending.append(t) + pending = still_pending + if pending: + time.sleep(0.25) + + if pending: + logger.debug( + "%d tasks did not receive a Globus task_id within %.1fs", + len(pending), + timeout, + ) + + def _try_capture_globus_task_ids(self, tasks: list[TrackedTask]) -> bool: + """Non-blocking: extract ``task_id`` from any ComputeFuture that + has one available. Returns True if any new IDs were captured.""" + captured = False + for t in tasks: + if t.globus_task_id is None and t.future is not None: + gc_id = getattr(t.future, "task_id", None) + if gc_id is not None: + t.globus_task_id = str(gc_id) + captured = True + return captured + # ── status ───────────────────────────────────────────────────────── def get_status(self, batch_id: str) -> dict: """Return the current status of a batch. + For tasks loaded from disk (no in-memory ``Future``), queries + Globus Compute directly if a ``globus_task_id`` is available. + Returns ------- dict @@ -125,11 +282,16 @@ def get_status(self, batch_id: str) -> dict: total = len(batch.tasks) done = 0 failed = 0 + # Lazily capture Globus Compute task UUIDs (set asynchronously + # by the Executor background thread after submission). 
+ dirty = self._try_capture_globus_task_ids(batch.tasks) for t in batch.tasks: - if t.future.done(): - done += 1 - # Cache the result on first check + task_done = False + + # --- live future path --- + if t.future is not None and t.future.done(): + task_done = True if t.result is None: try: raw = t.future.result(timeout=0) @@ -152,8 +314,55 @@ def get_status(self, batch_id: str) -> dict: "error_type": type(e).__name__, "message": str(e), } - if t.result.get("status") == "failure": - failed += 1 + dirty = True + + # --- loaded-from-disk path (no future, use Globus client) --- + elif t.future is None and t.result is None and t.globus_task_id: + gc = self._get_gc_client() + if gc is not None: + try: + task_info = gc.get_task(t.globus_task_id) + if not task_info.get("pending", True): + task_done = True + if "result" in task_info: + raw = task_info["result"] + if isinstance(raw, dict): + merged = {**t.meta, **raw} + merged.setdefault("status", "success") + t.result = merged + else: + t.result = { + **t.meta, + "result": raw, + "status": "success", + } + elif "exception" in task_info: + t.result = { + **t.meta, + "status": "failure", + "error_type": "RemoteException", + "message": str(task_info["exception"]), + } + dirty = True + except Exception as e: + logger.warning( + "Failed to query Globus task %s: %s", + t.globus_task_id, + e, + exc_info=True, + ) + + # --- already have a cached result --- + elif t.result is not None: + task_done = True + + if task_done: + done += 1 + if t.result is not None and t.result.get("status") == "failure": + failed += 1 + + if dirty: + self._save() pending = total - done if pending == total: @@ -259,7 +468,9 @@ def cancel_batch(self, batch_id: str) -> dict: cancelled = 0 already_done = 0 for t in batch.tasks: - if t.future.done(): + if t.future is None: + already_done += 1 + elif t.future.done(): already_done += 1 elif t.future.cancel(): cancelled += 1 @@ -284,13 +495,17 @@ def cleanup(self, max_age_hours: float = 24) -> int: with 
self._lock: for bid, batch in self._batches.items(): age_hours = (now - batch.submitted_at).total_seconds() / 3600 - if age_hours > max_age_hours and all( - t.future.done() for t in batch.tasks - ): + all_done = all( + (t.future is not None and t.future.done()) + or t.result is not None + for t in batch.tasks + ) + if age_hours > max_age_hours and all_done: to_remove.append(bid) for bid in to_remove: del self._batches[bid] if to_remove: logger.info("Cleaned up %d old batches", len(to_remove)) + self._save() return len(to_remove) From 1b760327b3e4c7e3a341392dd901057b36b6a4f5 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:42:03 -0500 Subject: [PATCH 010/119] Extend CGFastMCP: tracker kwargs, pre-submit hook, schema_fanout_tool - init_backend now accepts tracker_kwargs= and forwards it to JobTracker(...) in _ensure_backend. Callers can pass persist_file= so MCP servers recover job state across restarts. - set_pre_submit_hook(hook): hook receives each TaskSpec before backend.submit() and returns a (possibly mutated) one. Lets a server centralise transport concerns -- inline-structure embedding for local-submit-to-remote-worker, remote-path rewriting -- instead of repeating that logic in every tool body. Wired into the @tool, @ensemble_tool, and @schema_fanout_tool submit paths. - @schema_fanout_tool(worker=...): the decorated function is an expander (ensemble schema -> list of per-item args). The framework calls worker(item) on the backend for each item and gathers results. Preserves the ensemble schema as the agent-facing API (one tool call, server-side fanout), complementing @ensemble_tool which exposes list[Schema] for callers that want client-side enumeration. 
--- src/chemgraph/mcp/cg_fastmcp.py | 172 +++++++++++++++++++++++++++++++- 1 file changed, 168 insertions(+), 4 deletions(-) diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index a343fa3a..3a84c9f7 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -37,18 +37,33 @@ def __init__(self, **kwargs: Any) -> None: self._backend = None self._tracker = None self._backend_kwargs: Optional[dict[str, Any]] = None + self._tracker_kwargs: dict[str, Any] = {} + self._pre_submit_hook: Optional[Callable] = None # ── Backend lifecycle ─────────────────────────────────────────────── - def init_backend(self, **kwargs: Any) -> None: + def init_backend( + self, + *, + tracker_kwargs: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: """Register backend configuration for lazy initialisation. The backend is not created until the first tool invocation, so the MCP server can start accepting connections immediately. - All keyword arguments are forwarded to - :func:`~chemgraph.execution.config.get_backend`. + + Parameters + ---------- + tracker_kwargs : dict, optional + Forwarded to :class:`~chemgraph.execution.job_tracker.JobTracker` + on first use. Use this to pass ``persist_file`` for cross-session + job state recovery. + **kwargs + Forwarded to :func:`~chemgraph.execution.config.get_backend`. 
""" self._backend_kwargs = kwargs + self._tracker_kwargs = tracker_kwargs or {} self._register_job_tools() logger.info("CGFastMCP backend configured (lazy init).") @@ -63,7 +78,7 @@ def _ensure_backend(self) -> None: from chemgraph.execution import JobTracker, get_backend self._backend = get_backend(**self._backend_kwargs) - self._tracker = JobTracker() + self._tracker = JobTracker(**self._tracker_kwargs) logger.info( "CGFastMCP backend initialised: %s", type(self._backend).__name__ ) @@ -78,8 +93,31 @@ def shutdown_backend(self) -> None: self._backend = None self._tracker = None self._backend_kwargs = None + self._tracker_kwargs = {} logger.info("CGFastMCP backend shut down.") + # ── Pre-submit transport hook ────────────────────────────────────── + + def set_pre_submit_hook(self, hook: Optional[Callable]) -> None: + """Register a hook that transforms each TaskSpec before submission. + + The hook receives the :class:`~chemgraph.execution.base.TaskSpec` + and must return one (possibly the same instance). Used for + transport concerns that should apply to every backend-submitted + tool on this server -- e.g. embedding a local structure file + into ``kwargs`` so a remote worker can materialise it, or + rewriting a local path to a pre-staged remote path. + + Pass ``None`` to clear the hook. 
+ """ + self._pre_submit_hook = hook + + def _apply_pre_submit_hook(self, task): + """Run the registered pre-submit hook (no-op when unset).""" + if self._pre_submit_hook is None: + return task + return self._pre_submit_hook(task) + # ── Job tracking tools ───────────────────────────────────────────── def _register_job_tools(self) -> None: @@ -281,6 +319,7 @@ async def wrapper(params): kwargs={param.name: p}, **task_spec_kwargs, ) + task = self._apply_pre_submit_hook(task) fut = self._backend.submit(task) pending.append(({"index": i}, fut)) @@ -310,6 +349,130 @@ async def wrapper(params): return decorator + # ── Schema-driven fanout tool ────────────────────────────────────── + + def schema_fanout_tool( + self, + *, + worker: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a fan-out tool driven by a single *ensemble* schema. + + The decorated function is an **expander**: it receives the + ensemble schema and returns a list of per-item arguments. The + framework calls ``worker(item)`` on the backend for each item, + gathers the results, and returns a batch summary -- same shape + as :meth:`ensemble_tool`. + + Unlike :meth:`ensemble_tool` (whose tool signature is + ``list[Schema]``), this preserves the ensemble schema as the + agent-facing API, so the LLM makes a single tool call against + e.g. ``input_structure_directory`` and server-side expansion + produces the per-file jobs. + + Parameters + ---------- + worker : Callable + The per-item function executed on the backend. Must take + a single positional argument (the item produced by the + expander). + name, description, annotations + Passed through to :meth:`FastMCP.add_tool`. 
+ num_nodes, processes_per_node, gpus_per_task, env, working_dir + Forwarded to each :class:`~chemgraph.execution.base.TaskSpec`. + """ + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + + # Worker is what actually runs on the backend, so it must be + # picklable from the MCP server's __main__ module. + self._fix_module_for_pickle(worker) + + worker_sig = inspect.signature(worker) + worker_params = list(worker_sig.parameters.values()) + if len(worker_params) != 1: + raise TypeError( + f"schema_fanout_tool worker must take exactly one " + f"parameter, got {len(worker_params)} on " + f"{worker.__qualname__}." + ) + worker_param_name = worker_params[0].name + + def decorator(expander: Callable) -> Callable: + sig = inspect.signature(expander) + params = list(sig.parameters.values()) + if len(params) != 1: + raise TypeError( + f"@schema_fanout_tool expander must take exactly one " + f"parameter (the ensemble schema), got {len(params)} " + f"on {expander.__qualname__}." 
+ ) + param = params[0] + tool_name = name or expander.__name__ + + async def wrapper(**kwargs): + self._ensure_backend() + ensemble_params = kwargs[param.name] + items = expander(ensemble_params) + pending = [] + for i, item in enumerate(items): + task = TaskSpec( + task_id=f"{tool_name}_{i}", + task_type="python", + callable=worker, + kwargs={worker_param_name: item}, + **task_spec_kwargs, + ) + task = self._apply_pre_submit_hook(task) + fut = self._backend.submit(task) + pending.append(({"index": i}, fut)) + + return await submit_or_gather( + self._backend, + pending, + self._tracker, + tool_name, + ) + + wrapper.__name__ = tool_name + wrapper.__doc__ = expander.__doc__ + wrapper.__module__ = expander.__module__ + wrapper.__qualname__ = expander.__qualname__ + # Preserve the expander's signature so FastMCP advertises the + # ensemble schema to the LLM, not the worker's per-item one. + wrapper.__signature__ = sig + + self.add_tool(wrapper, **fastmcp_kwargs) + return expander + + return decorator + # ── Internal ──────────────────────────────────────────────────────── def _make_backend_wrapper( @@ -331,6 +494,7 @@ async def wrapper(**kwargs: Any) -> Any: kwargs=kwargs, **task_spec_kwargs, ) + task = self._apply_pre_submit_hook(task) fut = self._backend.submit(task) if self._backend.is_async_remote: From 04bcc8a214b4621823ef72b089805743e7bafd8f Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:49:14 -0500 Subject: [PATCH 011/119] Add remote_structure_directory schemas, GC executor recovery, XANES persistence - mace_input_schema_ensemble / graspa_input_schema_ensemble: new remote_structure_directory field for pre-staged HPC files (paired with the upcoming transfer_files tool). input_structure_directory now defaults to empty string so callers can pass either. - mace_input_schema/_ensemble model description spells out that 'mace_mp' is the calculator type, not a model name -- LLMs were confusing the two. 
- Nullable schema fields (driver, model, wall_time) typed as str|None / float|None for correct OpenAPI schema generation. - GlobusComputeBackend._ensure_executor re-creates the Executor when it has been shut down (e.g. after a remote task failure). Uses getattr() so we don't depend on the SDK's private _stopped attr existing. - check_endpoint_status logs exc_info on failure for easier debugging. - xanes_mcp_hpc: JobTracker(persist_file=~/.chemgraph/xanes_jobs.json) so XANES job state survives MCP server restarts. Instructions updated to tell the LLM to surface batch_ids to the user. --- pyproject.toml | 6 +++ .../execution/globus_compute_backend.py | 14 +++++++ src/chemgraph/mcp/xanes_mcp_hpc.py | 12 ++++-- src/chemgraph/schemas/graspa_schema.py | 11 +++++- src/chemgraph/schemas/mace_parsl_schema.py | 37 ++++++++++++++----- 5 files changed, 67 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 692b0388..c98d967d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,9 @@ ensemble_launcher = [ globus_compute = [ "globus-compute-sdk", ] +academy = [ + "academy-py", +] xanes = [ "mp-api; python_version >= '3.11'", "parsl" @@ -111,6 +114,9 @@ skip-magic-trailing-comma = false # Ensure Black-style formatting markers = [ "llm: marks tests as requiring LLM API access (run with --run-llm)", "globus_compute: marks tests requiring a live Globus Compute endpoint (run with --run-globus-compute)", + "parsl: marks tests requiring a live Parsl deployment (run with --run-parsl)", + "ensemble_launcher: marks tests requiring a live EnsembleLauncher deployment (run with --run-ensemble-launcher)", + "academy: marks tests requiring Academy agent infrastructure (run with --run-academy)", "asyncio: marks async tests", ] filterwarnings = [ diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py index 2ec2bba1..f73bd5af 100644 --- a/src/chemgraph/execution/globus_compute_backend.py +++ 
b/src/chemgraph/execution/globus_compute_backend.py @@ -93,12 +93,23 @@ def initialize(self, system: str = "local", **kwargs: Any) -> None: # ── task submission ───────────────────────────────────────────────── + def _ensure_executor(self) -> None: + """Re-create the Executor if it was shut down (e.g. after a + task failure).""" + from globus_compute_sdk import Executor + + if self._executor is None or getattr(self._executor, "_stopped", False): + logger.info("Re-creating Globus Compute Executor") + self._executor = Executor(endpoint_id=self._endpoint_id) + def submit(self, task: TaskSpec) -> Future: if not self._initialized or self._executor is None: raise RuntimeError( "GlobusComputeBackend is not initialized. Call initialize() first." ) + self._ensure_executor() + if task.task_type == "python": if task.callable is None: raise ValueError( @@ -142,6 +153,9 @@ def check_endpoint_status(self) -> dict: "status": status, } except Exception as e: + logger.warning( + "Endpoint status check failed: %s", e, exc_info=True, + ) return { "endpoint_id": self._endpoint_id, "status": "error", diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 4b0d219b..4abb94e0 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -30,7 +30,9 @@ # ── Initialise execution backend ──────────────────────────────────────── backend = get_backend() -tracker = JobTracker() + +_jobs_file = Path("~/.chemgraph/xanes_jobs.json").expanduser() +tracker = JobTracker(persist_file=_jobs_file) # ── MCP server ────────────────────────────────────────────────────────── mcp = FastMCP( @@ -54,8 +56,12 @@ - Keep responses compact -- full results are in the output directories. - When returning paths, use absolute paths. - Energies are in eV. - - When a tool returns status='submitted' with a batch_id, use - check_job_status to poll for progress before calling get_job_results. 
+ - When a tool returns status='submitted' with a batch_id, call + get_job_results(batch_id) to retrieve results. If the job is + still pending, report the batch_id to the user so they can + check later. Job state is persisted across sessions -- the + user can call list_jobs or get_job_results in a future session + to retrieve results. """, ) register_job_tools(mcp, tracker, backend) diff --git a/src/chemgraph/schemas/graspa_schema.py b/src/chemgraph/schemas/graspa_schema.py index 9cd08231..996ec12b 100644 --- a/src/chemgraph/schemas/graspa_schema.py +++ b/src/chemgraph/schemas/graspa_schema.py @@ -46,7 +46,16 @@ class graspa_input_schema(BaseModel): class graspa_input_schema_ensemble(BaseModel): input_structures: Union[str, list[str]] = Field( - description="Path to a directory of CIF files OR a specific list of file paths." + default="", + description="Path to a directory of CIF files OR a specific list of file paths. Required unless remote_structure_directory is provided.", + ) + remote_structure_directory: str | None = Field( + default=None, + description=( + "Path to pre-staged CIF files on the remote HPC filesystem. " + "When provided, workers read structures directly from this path. " + "Use the transfer_files tool to stage files first." + ), ) output_result_file: str = Field( default="raspa.log", diff --git a/src/chemgraph/schemas/mace_parsl_schema.py b/src/chemgraph/schemas/mace_parsl_schema.py index e04ddba6..63dc8008 100644 --- a/src/chemgraph/schemas/mace_parsl_schema.py +++ b/src/chemgraph/schemas/mace_parsl_schema.py @@ -17,14 +17,18 @@ class mace_input_schema(BaseModel): default="output.json", description="Path to a JSON file where simulation results will be saved.", ) - driver: str = Field( + driver: str | None = Field( default=None, description="Specifies the type of simulation to run. 
Options: 'energy' for single-point energy calculations, 'opt' for geometry optimization, 'vib' for vibrational frequency analysis, and 'thermo' for thermochemical properties (including enthalpy, entropy, and Gibbs free energy).", ) model: str = Field( default="medium-mpa-0", - description="Path to the model. Default is medium-mpa-0." - "Options are 'small', 'medium', 'large', 'small-0b', 'medium-0b', 'small-0b2', 'medium-0b2','large-0b2', 'medium-0b3', 'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', 'mace-matpes-r2scan-0'", + description="MACE foundation model name (NOT the calculator type). " + "Options: 'small', 'medium', 'large', 'small-0b', 'medium-0b', " + "'small-0b2', 'medium-0b2', 'large-0b2', 'medium-0b3', " + "'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', " + "'mace-matpes-r2scan-0'. Default is 'medium-mpa-0'. " + "Do NOT pass 'mace_mp' — that is the calculator type, not a model name.", ) device: str = Field( default="cpu", @@ -54,20 +58,35 @@ class mace_input_schema(BaseModel): class mace_input_schema_ensemble(BaseModel): input_structure_directory: str = Field( - description="Path to a folder of input structures containing the atomic structure for the simulations." + default="", + description="Path to a local folder of input structures. Required unless remote_structure_directory is provided.", + ) + remote_structure_directory: str | None = Field( + default=None, + description=( + "Path to pre-staged structure files on the remote HPC filesystem. " + "When provided, workers read structures directly from this path " + "instead of using inline structure embedding. Use the " + "transfer_files tool to stage files first, then pass the " + "remote directory here." + ), ) output_result_file: str = Field( default="output.json", description="Path to a JSON file where simulation results will be saved.", ) - driver: str = Field( + driver: str | None = Field( default=None, description="Specifies the type of simulation to run. 
Options: 'energy' for single-point energy calculations, 'opt' for geometry optimization, 'vib' for vibrational frequency analysis, and 'thermo' for thermochemical properties (including enthalpy, entropy, and Gibbs free energy).", ) model: str = Field( default="medium-mpa-0", - description="Path to the model. Default is medium-mpa-0." - "Options are 'small', 'medium', 'large', 'small-0b', 'medium-0b', 'small-0b2', 'medium-0b2','large-0b2', 'medium-0b3', 'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', 'mace-matpes-r2scan-0'", + description="MACE foundation model name (NOT the calculator type). " + "Options: 'small', 'medium', 'large', 'small-0b', 'medium-0b', " + "'small-0b2', 'medium-0b2', 'large-0b2', 'medium-0b3', " + "'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', " + "'mace-matpes-r2scan-0'. Default is 'medium-mpa-0'. " + "Do NOT pass 'mace_mp' — that is the calculator type, not a model name.", ) device: str = Field( default="cpu", @@ -102,7 +121,7 @@ class mace_output_schema(BaseModel): output_result_file: str = Field( description="Path to a JSON file where simulation results is saved.", ) - model: str = Field( + model: str | None = Field( default=None, description="Path to the model. Default is medium-mpa-0." ) device: str = Field( @@ -143,7 +162,7 @@ class mace_output_schema(BaseModel): default="", description="Error captured during the simulation", ) - wall_time: float = Field( + wall_time: float | None = Field( default=None, description="Total wall time (in seconds) taken to complete the simulation.", ) From aaef2a9b1384e8f930a4b3eb165c1db6d34d2e89 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:55:27 -0500 Subject: [PATCH 012/119] Add Globus Transfer manager and MCP file-staging tools - execution/globus_transfer.py: GlobusTransferManager wraps the globus_sdk TransferClient with token caching, batched transfer_files / wait_for_transfer / check_transfer_status / list_remote_directory. Lazy globus_sdk import, lazy auth. 
- execution/config.get_transfer_manager(): builds a manager from [execution.globus_transfer] in config.toml with env var overrides (GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID, _DESTINATION_ENDPOINT_ID, _DESTINATION_BASE_PATH). Returns None when not configured so MCP servers can skip registration silently. - mcp/transfer_tools.register_transfer_tools(): registers transfer_files, check_transfer_status, list_remote_files on a FastMCP/CGFastMCP server. Uses mcp.add_tool() (not the backend-submitting @tool() decorator) because these are orchestration tools, not compute tasks -- they call the Globus Transfer API directly from the MCP server process. - get_backend() globus_compute endpoint_id fallback now treats empty-string endpoint_id as unset, matching the GLOBUS_COMPUTE_ ENDPOINT_ID env-var override behaviour. --- src/chemgraph/execution/config.py | 62 +++- src/chemgraph/execution/globus_transfer.py | 325 +++++++++++++++++++++ src/chemgraph/mcp/transfer_tools.py | 186 ++++++++++++ 3 files changed, 572 insertions(+), 1 deletion(-) create mode 100644 src/chemgraph/execution/globus_transfer.py create mode 100644 src/chemgraph/mcp/transfer_tools.py diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py index fb4fac4b..dc650a26 100644 --- a/src/chemgraph/execution/config.py +++ b/src/chemgraph/execution/config.py @@ -125,7 +125,7 @@ def get_backend( merged_kwargs = {**backend_cfg, **kwargs} # Globus Compute: fall back to GLOBUS_COMPUTE_ENDPOINT_ID env var - if resolved_backend == "globus_compute" and "endpoint_id" not in merged_kwargs: + if resolved_backend == "globus_compute" and not merged_kwargs.get("endpoint_id"): env_id = os.getenv("GLOBUS_COMPUTE_ENDPOINT_ID") if env_id: merged_kwargs["endpoint_id"] = env_id @@ -183,3 +183,63 @@ def get_backend( backend.initialize(system=resolved_system, **merged_kwargs) return backend + + +def get_transfer_manager( + config_path: Optional[str] = None, + **kwargs: Any, +): + """Create a 
:class:`GlobusTransferManager` from config, or ``None``. + + Reads the ``[execution.globus_transfer]`` section from + ``config.toml``. Returns ``None`` when the required endpoint IDs + are not configured, so callers can skip transfer-tool registration. + + Environment variable overrides + ------------------------------ + ``GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID`` + ``GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID`` + ``GLOBUS_TRANSFER_DESTINATION_BASE_PATH`` + """ + cfg = _load_execution_config(config_path) + transfer_cfg = cfg.get("globus_transfer", {}) + merged = {**transfer_cfg, **kwargs} + + for key, env_var in ( + ("source_endpoint_id", "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID"), + ("destination_endpoint_id", "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID"), + ("destination_base_path", "GLOBUS_TRANSFER_DESTINATION_BASE_PATH"), + ): + if not merged.get(key): + env_val = os.getenv(env_var) + if env_val: + merged[key] = env_val + + required = ( + "source_endpoint_id", + "destination_endpoint_id", + "destination_base_path", + ) + if not all(merged.get(k) for k in required): + logger.debug( + "Globus Transfer not configured (missing %s). 
" + "Transfer tools will not be registered.", + [k for k in required if not merged.get(k)], + ) + return None + + from chemgraph.execution.globus_transfer import GlobusTransferManager + + manager = GlobusTransferManager( + source_endpoint_id=merged["source_endpoint_id"], + destination_endpoint_id=merged["destination_endpoint_id"], + destination_base_path=merged["destination_base_path"], + source_base_path=merged.get("source_base_path"), + client_id=merged.get("client_id"), + ) + logger.info( + "GlobusTransferManager created: %s -> %s", + merged["source_endpoint_id"], + merged["destination_endpoint_id"], + ) + return manager diff --git a/src/chemgraph/execution/globus_transfer.py b/src/chemgraph/execution/globus_transfer.py new file mode 100644 index 00000000..d8081ab3 --- /dev/null +++ b/src/chemgraph/execution/globus_transfer.py @@ -0,0 +1,325 @@ +"""Globus Transfer file-staging manager. + +Transfers files between a local Globus collection and a remote HPC +collection using the `Globus Transfer API +`_. This avoids encoding large +input files (e.g. atomic structures) inside Globus Compute function +payloads. + +**Prerequisites** + +1. Install ``globus_sdk`` (already a core dependency). +2. Have *Globus Connect Personal* running on the submitting machine + **or** use a managed Globus endpoint. +3. Configure endpoint IDs and base path in ``config.toml``:: + + [execution.globus_transfer] + source_endpoint_id = "" + destination_endpoint_id = "" + destination_base_path = "/eagle/projects/MyProject/staging" +""" + +from __future__ import annotations + +import logging +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +logger = logging.getLogger(__name__) + +# Globus Transfer API scope +TRANSFER_SCOPE = "urn:globus:auth:scope:transfer.api.globus.org:all" + +# Default Globus native-app client ID (Globus Tutorial client). 
+# Projects should register their own app at https://app.globus.org. +_DEFAULT_CLIENT_ID = "61338d24-54d5-408f-a10d-66c06b59f6d2" + + +@dataclass +class TransferResult: + """Metadata returned after submitting a Globus Transfer task.""" + + task_id: str + source_endpoint_id: str + destination_endpoint_id: str + file_mapping: dict[str, str] # local_path -> remote_path + remote_directory: str + submitted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + label: str = "" + + +class GlobusTransferManager: + """Manage file transfers between local and remote Globus collections. + + Parameters + ---------- + source_endpoint_id : str + UUID of the Globus collection on the submitting machine. + destination_endpoint_id : str + UUID of the Globus collection on the HPC system. + destination_base_path : str + Root directory on the destination where staged files are placed. + Each transfer batch creates a subdirectory underneath. + source_base_path : str, optional + If provided, local paths are resolved relative to this directory. + client_id : str, optional + Globus app client ID for OAuth. Defaults to the Globus Tutorial + client. 
+ """ + + def __init__( + self, + source_endpoint_id: str, + destination_endpoint_id: str, + destination_base_path: str, + source_base_path: Optional[str] = None, + client_id: Optional[str] = None, + ) -> None: + self.source_endpoint_id = source_endpoint_id + self.destination_endpoint_id = destination_endpoint_id + self.destination_base_path = destination_base_path.rstrip("/") + self.source_base_path = source_base_path + self._client_id = client_id or _DEFAULT_CLIENT_ID + self._transfer_client = None + + # ── authentication ────────────────────────────────────────────────── + + def _get_transfer_client(self): + """Lazily create an authenticated ``TransferClient``.""" + if self._transfer_client is not None: + return self._transfer_client + + try: + import globus_sdk + except ImportError as exc: + raise ImportError( + "globus_sdk is required for Globus Transfer. " + "Install it with: pip install globus-sdk" + ) from exc + + client = globus_sdk.NativeAppAuthClient(self._client_id) + client.oauth2_start_flow( + requested_scopes=TRANSFER_SCOPE, + refresh_tokens=True, + ) + + # Try loading cached tokens first + token_file = ( + Path.home() / ".globus" / "chemgraph_transfer_tokens.json" + ) + tokens = self._load_tokens(token_file) + + if tokens is None: + # Interactive login required + authorize_url = client.oauth2_get_authorize_url() + logger.info( + "Globus Transfer authentication required.\n" + "Go to this URL and login:\n %s", + authorize_url, + ) + print( + "\nGlobus Transfer authentication required.\n" + f"Go to this URL and login:\n {authorize_url}\n" + ) + auth_code = input("Enter the authorization code: ").strip() + token_response = client.oauth2_exchange_code_for_tokens(auth_code) + tokens = token_response.by_resource_server["transfer.api.globus.org"] + self._save_tokens(token_file, tokens) + else: + # Refresh if expired + if tokens.get("expires_at_seconds", 0) < time.time(): + try: + token_response = client.oauth2_refresh_tokens( + 
globus_sdk.RefreshTokenAuthorizer( + tokens["refresh_token"], client + ) + ) + tokens = token_response.by_resource_server[ + "transfer.api.globus.org" + ] + self._save_tokens(token_file, tokens) + except Exception: + logger.warning( + "Token refresh failed, falling back to existing token." + ) + + authorizer = globus_sdk.AccessTokenAuthorizer(tokens["access_token"]) + self._transfer_client = globus_sdk.TransferClient(authorizer=authorizer) + return self._transfer_client + + @staticmethod + def _load_tokens(path: Path) -> Optional[dict]: + if not path.is_file(): + return None + import json + + try: + with open(path) as f: + return json.load(f) + except (json.JSONDecodeError, KeyError): + return None + + @staticmethod + def _save_tokens(path: Path, tokens: dict) -> None: + import json + + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(dict(tokens), f, indent=2) + path.chmod(0o600) + + # ── transfers ─────────────────────────────────────────────────────── + + def transfer_files( + self, + local_paths: list[str], + remote_subdir: Optional[str] = None, + label: Optional[str] = None, + ) -> TransferResult: + """Submit a Globus Transfer task to stage files on the remote endpoint. + + Parameters + ---------- + local_paths : list[str] + Absolute paths to local files to transfer. + remote_subdir : str, optional + Subdirectory name under ``destination_base_path``. A UUID-based + name is generated if omitted. + label : str, optional + Human-readable label for the transfer task. + + Returns + ------- + TransferResult + Metadata including the Globus task ID and local-to-remote + path mapping. 
+ """ + import globus_sdk + + tc = self._get_transfer_client() + + if remote_subdir is None: + remote_subdir = f"batch_{uuid.uuid4().hex[:12]}" + + remote_dir = f"{self.destination_base_path}/{remote_subdir}" + transfer_label = label or f"ChemGraph file staging ({remote_subdir})" + + tdata = globus_sdk.TransferData( + tc, + self.source_endpoint_id, + self.destination_endpoint_id, + label=transfer_label, + sync_level="checksum", + ) + + file_mapping: dict[str, str] = {} + for local_path in local_paths: + p = Path(local_path).resolve() + remote_path = f"{remote_dir}/{p.name}" + tdata.add_item(str(p), remote_path) + file_mapping[str(p)] = remote_path + + result = tc.submit_transfer(tdata) + task_id = result["task_id"] + + logger.info( + "Globus Transfer submitted: task_id=%s, %d files -> %s", + task_id, + len(local_paths), + remote_dir, + ) + + return TransferResult( + task_id=task_id, + source_endpoint_id=self.source_endpoint_id, + destination_endpoint_id=self.destination_endpoint_id, + file_mapping=file_mapping, + remote_directory=remote_dir, + label=transfer_label, + ) + + def check_transfer_status(self, task_id: str) -> dict[str, Any]: + """Check the status of a Globus Transfer task. + + Returns + ------- + dict + Keys: ``task_id``, ``status``, ``nice_status``, ``bytes_transferred``, + ``files``, ``files_transferred``. + """ + tc = self._get_transfer_client() + task = tc.get_task(task_id) + return { + "task_id": task_id, + "status": task["status"], + "nice_status": task.get("nice_status", ""), + "bytes_transferred": task.get("bytes_transferred", 0), + "files": task.get("files", 0), + "files_transferred": task.get("files_transferred", 0), + } + + def wait_for_transfer( + self, + task_id: str, + timeout: float = 300, + poll_interval: float = 5, + ) -> dict[str, Any]: + """Block until a transfer completes, fails, or times out. + + Parameters + ---------- + timeout : float + Maximum seconds to wait (default 300). 
+ poll_interval : float + Seconds between status checks (default 5). + + Returns + ------- + dict + Final transfer status. + """ + deadline = time.time() + timeout + while time.time() < deadline: + status = self.check_transfer_status(task_id) + if status["status"] in ("SUCCEEDED", "FAILED"): + return status + time.sleep(poll_interval) + + status = self.check_transfer_status(task_id) + status["timed_out"] = True + return status + + def list_remote_directory(self, path: str) -> list[dict[str, Any]]: + """List files in a directory on the destination endpoint. + + Returns + ------- + list[dict] + Each dict has ``name``, ``type`` ("file" or "dir"), and ``size``. + """ + tc = self._get_transfer_client() + entries = [] + for entry in tc.operation_ls(self.destination_endpoint_id, path=path): + entries.append( + { + "name": entry["name"], + "type": entry["type"], + "size": entry.get("size", 0), + } + ) + return entries + + def get_remote_path( + self, + local_path: str, + remote_subdir: Optional[str] = None, + ) -> str: + """Compute the remote path for a local file.""" + filename = Path(local_path).name + if remote_subdir: + return f"{self.destination_base_path}/{remote_subdir}/{filename}" + return f"{self.destination_base_path}/{filename}" diff --git a/src/chemgraph/mcp/transfer_tools.py b/src/chemgraph/mcp/transfer_tools.py new file mode 100644 index 00000000..79ae2323 --- /dev/null +++ b/src/chemgraph/mcp/transfer_tools.py @@ -0,0 +1,186 @@ +"""Shared MCP tools for Globus Transfer file staging. + +Call :func:`register_transfer_tools` to add ``transfer_files``, +``check_transfer_status``, and ``list_remote_files`` to any +:class:`~mcp.server.fastmcp.FastMCP` (or +:class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`) server instance. + +These tools allow an LLM agent to stage input files on a remote HPC +filesystem *before* submitting compute jobs, avoiding the overhead of +encoding large files inside Globus Compute function payloads. 
+ +Note +---- +Transfer tools are orchestration tools (they call the Globus Transfer +API directly from the MCP server process), not compute tools, so they +are registered via :meth:`FastMCP.add_tool` rather than CGFastMCP's +backend-submitting ``@tool()`` decorator. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from mcp.server.fastmcp import FastMCP + + from chemgraph.execution.globus_transfer import GlobusTransferManager + +logger = logging.getLogger(__name__) + + +def register_transfer_tools( + mcp: FastMCP, + transfer_manager: GlobusTransferManager, +) -> None: + """Register file-transfer MCP tools on *mcp*. + + Parameters + ---------- + mcp : FastMCP + The MCP server to register tools on. May be a plain ``FastMCP`` + or a :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`; ``add_tool`` + is inherited so the same registration works either way. + transfer_manager : GlobusTransferManager + The configured transfer manager instance. + """ + + def transfer_files( + source_paths: Union[str, list[str]], + extensions: Optional[list[str]] = None, + remote_subdir: Optional[str] = None, + wait: bool = True, + label: Optional[str] = None, + ) -> dict: + """Transfer files to the remote HPC endpoint via Globus Transfer. + + Parameters + ---------- + source_paths : str or list[str] + A directory path (all matching files transferred) or a list + of individual file paths. + extensions : list[str], optional + When *source_paths* is a directory, only transfer files with + these extensions (e.g. ``[".cif", ".xyz"]``). Ignored when + *source_paths* is a list. + remote_subdir : str, optional + Subdirectory name on the remote endpoint. Auto-generated if + omitted. + wait : bool + If True (default), block until the transfer completes. + label : str, optional + Human-readable label for the transfer task. 
+ """ + if isinstance(source_paths, str): + src = Path(source_paths) + if src.is_dir(): + if extensions: + ext_set = { + e if e.startswith(".") else f".{e}" for e in extensions + } + files = sorted( + str(f) + for f in src.iterdir() + if f.is_file() and f.suffix.lower() in ext_set + ) + else: + files = sorted( + str(f) for f in src.iterdir() if f.is_file() + ) + if not files: + return { + "status": "error", + "message": f"No files found in {source_paths}" + + ( + f" with extensions {extensions}" + if extensions + else "" + ), + } + elif src.is_file(): + files = [str(src.resolve())] + else: + return { + "status": "error", + "message": f"Path not found: {source_paths}", + } + else: + files = [str(Path(p).resolve()) for p in source_paths] + + transfer_result = transfer_manager.transfer_files( + local_paths=files, + remote_subdir=remote_subdir, + label=label, + ) + + response = { + "task_id": transfer_result.task_id, + "remote_directory": transfer_result.remote_directory, + "file_count": len(files), + "file_mapping": transfer_result.file_mapping, + } + + if wait: + status = transfer_manager.wait_for_transfer(transfer_result.task_id) + response["status"] = ( + "completed" + if status["status"] == "SUCCEEDED" + else status["status"] + ) + response.update( + { + k: status[k] + for k in ("bytes_transferred", "files_transferred") + if k in status + } + ) + else: + response["status"] = "submitted" + + return response + + def check_transfer_status(task_id: str) -> dict: + """Check the status of a Globus Transfer task. + + Use to poll a non-blocking transfer submitted with ``wait=False``. + """ + return transfer_manager.check_transfer_status(task_id) + + def list_remote_files(remote_path: str) -> list[dict]: + """List files in a directory on the remote HPC endpoint. + + Useful to verify that files were staged correctly before + running ensemble calculations. 
+ """ + return transfer_manager.list_remote_directory(remote_path) + + mcp.add_tool( + transfer_files, + name="transfer_files", + description=( + "Transfer local files to the remote HPC filesystem via " + "Globus Transfer. Use this to pre-stage structure files " + "before running ensemble calculations with " + "remote_structure_directory. Returns the remote directory " + "path and a mapping of local-to-remote file paths." + ), + ) + mcp.add_tool( + check_transfer_status, + name="check_transfer_status", + description=( + "Check the status of a Globus Transfer task. Use this to " + "poll a non-blocking transfer submitted with wait=False." + ), + ) + mcp.add_tool( + list_remote_files, + name="list_remote_files", + description=( + "List files in a directory on the remote HPC endpoint. " + "Useful to verify that files were staged correctly before " + "running ensemble calculations." + ), + ) From d2b3b2d1370c8b7eecc5daa1412ab8bb9c916394 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 13:06:35 -0500 Subject: [PATCH 013/119] Reintegrate MACE MCP transport, persistence, and Globus Transfer on CGFastMCP run_mace_single and run_mace_ensemble were collapsed to bare run_mace_core(params) calls in PR #127, dropping inline-structure embedding, remote-path support, JobTracker persistence, and the Globus Transfer registration that 51ba171 had built. This restores all of that on top of the new CGFastMCP framework. - Worker is now a separate function _mace_worker(job: dict) that handles two transport keys on the worker FS: remote_structure_file (use the path directly) and inline_structure (materialise an AtomsData dict to a temp XYZ). Embeds full_output back into the result for inline calls so callers do not need remote FS access. - Pre-submit hook _mace_transport_hook centralises the schema -> job-dict conversion, mace_mp -> medium-mpa-0 model normalisation, and inline embedding (when the input file exists on the submitting host). 
Hook rewrites task.callable from run_mace_single to _mace_worker so the LLM still sees a clean schema-shaped tool. - run_mace_ensemble switches to @schema_fanout_tool with a server-side expander, preserving the directory-driven UX (single LLM call instead of N). Local mode enumerates files via resolve_structure_files; remote mode submits a backend probe to ls remote_structure_directory and builds remote_structure_file per item. - extract_output_json registered via mcp.add_tool() (orchestration, no backend wrap). transfer_files/check_transfer_status/ list_remote_files registered conditionally when get_transfer_manager() finds [execution.globus_transfer] config. - __main__ now wires tracker_kwargs={persist_file: ~/.chemgraph/ mace_jobs.json} so MACE batches survive MCP server restarts. - Drop `from __future__ import annotations`: forward refs break FastMCP's signature introspection because the wrapper's __globals__ is cg_fastmcp's, not the tool module's. --- src/chemgraph/mcp/mace_mcp_hpc.py | 310 ++++++++++++++++++++++++++---- 1 file changed, 276 insertions(+), 34 deletions(-) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index 70496186..58750b46 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -1,80 +1,322 @@ """Backend-agnostic MACE MCP server. -Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP` so that tool -functions are plain computation — the framework handles backend -submission, future resolution, and async job tracking. +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. -Nothing is initialised at import time so that worker subprocesses -(e.g. EnsembleLauncher) can safely re-import this module. +Transport (local-file embedding, pre-staged remote-path passthrough) +lives in a single pre-submit hook so the tool bodies stay simple. 
The +hook rewrites :class:`~chemgraph.execution.base.TaskSpec` instances +before submission to attach an inline structure when the input file +exists on the submitting host, leaving the path untouched when it +does not (assumed to be remote). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. """ +import logging +import os +from pathlib import Path + +from chemgraph.execution.base import TaskSpec +from chemgraph.execution.config import get_transfer_manager +from chemgraph.execution.utils import ( + make_per_structure_output, + resolve_structure_files, +) from chemgraph.mcp.cg_fastmcp import CGFastMCP -from chemgraph.schemas.mace_parsl_schema import mace_input_schema +from chemgraph.mcp.transfer_tools import register_transfer_tools +from chemgraph.schemas.mace_parsl_schema import ( + mace_input_schema, + mace_input_schema_ensemble, +) from chemgraph.tools.parsl_tools import extract_output_json, run_mace_core +logger = logging.getLogger(__name__) + +_JOBS_FILE = Path("~/.chemgraph/mace_jobs.json").expanduser() +_MACE_MP_ALIASES = {"mace_mp", "mace-mp", "MACE-MP", "mace_MP"} + mcp = CGFastMCP( name="ChemGraph MACE Tools", instructions=""" You expose tools for running MACE simulations and reading their results. The available tools are: - 1. run_mace_single: run a single MACE calculation using the specified - input schema. - 2. run_mace_ensemble: run MACE calculations over all structures in a - directory using the configured execution backend. + 1. run_mace_single: run a single MACE calculation. + 2. run_mace_ensemble: run MACE calculations over every structure in a + directory (local or pre-staged remote). 3. extract_output_json: load simulation results from a JSON file. - 4. check_job_status: check progress of a submitted HPC job batch. - 5. get_job_results: retrieve results from a completed job batch. - 6. list_jobs: list all tracked job batches. - 7. 
cancel_job: cancel pending tasks in a job batch. + 4. check_job_status / get_job_results / list_jobs / cancel_job: HPC + job batch management. Job state persists across sessions. + 5. transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on the + remote HPC filesystem before running ensembles in remote mode. Guidelines: - Use each tool only when its input schema matches the user request. - - Do not guess numerical values; report tool errors exactly as they occur. - - Keep responses compact -- full results are written to the output files - defined in the schemas. + - Do not guess numerical values; report tool errors exactly as they + occur. + - Keep responses compact -- full results are written to the output + files defined in the schemas. - When returning paths, use absolute paths. - Energies are in eV and wall times are in seconds. - - When a tool returns status='submitted' with a batch_id, use - check_job_status to poll for progress before calling get_job_results. + - When a tool returns status='submitted' with a batch_id, call + get_job_results(batch_id) to retrieve results. If still pending, + report the batch_id so the user can check later -- job state is + persisted across sessions. + - For the `model` field, pass a MACE foundation model name (e.g. + 'medium-mpa-0'). 'mace_mp' is the calculator type, not a model + name -- do not pass it. """, ) +# ── Worker (runs on the backend) ─────────────────────────────────────── + + +def _mace_worker(job: dict) -> dict: + """Execute a single MACE simulation on a backend worker. + + Accepts a *job dict* (not the schema) so the pre-submit hook can + attach transport keys ``inline_structure`` / ``remote_structure_file`` + before submission. + """ + import json + import tempfile + + job = dict(job) + + # Pre-staged remote file: use the path directly on the worker FS. 
+ remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "output.json"), + ) + + # Inline structure: materialise on the worker's filesystem. + inline = job.pop("inline_structure", None) + if inline is not None: + from ase import Atoms + from ase.io import write as ase_write + + atoms = Atoms( + numbers=inline["numbers"], + positions=inline["positions"], + cell=inline.get("cell"), + pbc=inline.get("pbc"), + ) + tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") + xyz_path = os.path.join(tmpdir, "structure.xyz") + ase_write(xyz_path, atoms) + job["input_structure_file"] = xyz_path + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + tmpdir, job.get("output_result_file", "output.json") + ) + + params = mace_input_schema(**job) + result = run_mace_core(params) + + # When inline, embed full output so the caller doesn't need to read + # a file on the remote filesystem to recover the results. 
def _embed_inline_if_local(job: dict) -> None:
    """Attach an ``inline_structure`` to *job* when its input file is local.

    Mutates *job* in place.  Nothing happens when a transport key
    (``remote_structure_file`` or ``inline_structure``) is already set,
    or when ``input_structure_file`` is missing or is not an existing
    file on the submitting host; in the latter case the path is left
    untouched so the worker can read it directly.
    """
    if job.get("remote_structure_file") or job.get("inline_structure"):
        return
    input_file = job.get("input_structure_file")
    if not input_file or not os.path.isfile(input_file):
        return  # not present here -- treated as a worker-side path

    from ase.io import read as ase_read

    from chemgraph.tools.ase_core import atoms_to_atomsdata

    atoms = ase_read(input_file)
    job["inline_structure"] = atoms_to_atomsdata(atoms).model_dump()


def _normalize_model(job: dict) -> None:
    """Replace a calculator-type alias in ``job["model"]`` with a model name.

    ``mace_mp`` names the calculator type, not a foundation model.  Any
    spelling of an entry in ``_MACE_MP_ALIASES`` is rewritten in place
    to ``"medium-mpa-0"``.  The comparison ignores case and surrounding
    whitespace and treats ``-`` and ``_`` as the same separator, so
    variants such as ``"MACE_MP"`` or ``"Mace-MP"`` are caught too.

    A missing or non-string ``model`` is left untouched.
    """
    model = job.get("model")
    if not isinstance(model, str):
        return

    def _canonical(name: str) -> str:
        return name.strip().casefold().replace("-", "_")

    canonical_aliases = {_canonical(alias) for alias in _MACE_MP_ALIASES}
    if model in _MACE_MP_ALIASES or _canonical(model) in canonical_aliases:
        job["model"] = "medium-mpa-0"


def _mace_transport_hook(task: TaskSpec) -> TaskSpec:
    """Pre-submit hook that routes MACE tasks to ``_mace_worker``.

    * ``run_mace_single`` tasks: the ``params`` kwarg (an object with
      ``model_dump`` or a plain mapping) is converted to a job dict and
      the task is rewritten to call ``_mace_worker(job=...)``.  A task
      without ``params`` is returned unchanged.
    * ``_mace_worker`` tasks: the ``job`` kwarg is copied and
      re-attached.

    In both cases the model name is normalised and a locally readable
    structure is embedded.  Tasks for any other callable pass through
    untouched.  The task object is mutated in place and returned.
    """
    if task.callable is run_mace_single:
        params = task.kwargs.get("params")
        if params is None:
            return task
        job = (
            params.model_dump() if hasattr(params, "model_dump") else dict(params)
        )
        _normalize_model(job)
        _embed_inline_if_local(job)
        task.callable = _mace_worker
        task.kwargs = {"job": job}
    elif task.callable is _mace_worker:
        # Copy so the caller's dict is not mutated by the helpers.
        job = dict(task.kwargs.get("job", {}))
        _normalize_model(job)
        _embed_inline_if_local(job)
        task.kwargs = {"job": job}
    return task


mcp.set_pre_submit_hook(_mace_transport_hook)
def _ls_remote_files(path: str) -> list[str]:
    """Return the sorted names of the files directly inside *path*.

    Used as the callable of the backend probe task, so it runs on the
    filesystem that holds *path*.  Sub-directories are omitted and
    symlinks are followed, matching ``os.path.isfile``.

    ``os.scandir`` is used instead of ``os.listdir`` + ``os.path.isfile``
    so the file-type check can reuse the information returned by the
    directory read rather than issuing one extra ``stat`` per entry.

    Raises
    ------
    OSError
        If *path* does not exist or is not a readable directory.
    """
    names = []
    with os.scandir(path) as entries:
        for entry in entries:
            try:
                is_file = entry.is_file()
            except OSError:
                # os.path.isfile reports False for entries that cannot
                # be stat'ed; keep that behaviour.
                is_file = False
            if is_file:
                names.append(entry.name)
    return sorted(names)
+ """ + shared = { + "output_result_file": params.output_result_file, + "driver": params.driver, + "model": params.model, + "device": params.device, + "temperature": params.temperature, + "pressure": params.pressure, + "fmax": params.fmax, + "steps": params.steps, + "optimizer": params.optimizer, + } + base_output = Path(params.output_result_file) + + if params.remote_structure_directory: + remote_dir = params.remote_structure_directory + mcp._ensure_backend() + probe = TaskSpec( + task_id="ls_remote_dir", + task_type="python", + callable=_ls_remote_files, + kwargs={"path": remote_dir}, + ) + fut = mcp._backend.submit(probe) + try: + file_names = fut.result(timeout=30) + except Exception as exc: + raise RuntimeError( + f"Could not list remote directory {remote_dir}: {exc}" + ) from exc + + jobs = [] + for fname in file_names: + per_output = make_per_structure_output(Path(fname), base_output) + job = {**shared} + job["remote_structure_file"] = f"{remote_dir}/{fname}" + job["output_result_file"] = str(per_output) + jobs.append(job) + return jobs + + if not params.input_structure_directory: + raise ValueError( + "Either input_structure_directory or remote_structure_directory " + "must be provided." + ) + + structure_files, _ = resolve_structure_files(params.input_structure_directory) + return [ + { + **shared, + "input_structure_file": str(f), + "output_result_file": str(make_per_structure_output(f, base_output)), + } + for f in structure_files + ] + + +@mcp.schema_fanout_tool( name="run_mace_ensemble", - description="Run an ensemble of MACE calculations for multiple inputs.", + description=( + "Run MACE calculations over every structure in a directory. " + "Local mode uses input_structure_directory; remote mode uses " + "remote_structure_directory (pre-stage files first with " + "transfer_files)." 
+ ), + worker=_mace_worker, ) -def _run_mace_worker(params: mace_input_schema): - return run_mace_core(params) +def run_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: + return _expand_mace_ensemble(params) + + +# ── Orchestration tools (no backend involvement) ─────────────────────── mcp.add_tool( extract_output_json, name="extract_output_json", - description="Load output from a JSON file.", + description="Load simulation results from an output JSON file.", ) +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on MACE MCP server.") + + if __name__ == "__main__": from chemgraph.mcp.server_utils import run_mcp_server - mcp.init_backend() + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) try: run_mcp_server(mcp, default_port=9004) From b7ac17c716874b7a01348b25a5bad9dda47fd25d Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 13:58:39 -0500 Subject: [PATCH 014/119] Add academy module: distributed multi-agent screening via Academy Wraps the Academy distributed agent framework with ChemGraph LLM agents for federated HPC screening workflows. Decoupled from the existing pipeline -- no chemgraph.cli / chemgraph.agent / chemgraph.eval references; only the lazily imported chemgraph.agent.llm_agent.ChemGraph. - ChemGraphAgent: Academy Agent wrapping a single ChemGraph instance, exposes run_query / get_info actions. - ScreeningAgent: iterates a molecule batch, writes per-result JSONs for fault-tolerant aggregation. Failed-molecule records now store str(exc) so the actual exception message survives. - CoordinatorAgent: polls a results dir, optionally analyses results via an LLM, suggests follow-up molecules. 
- AcademyConfig + build_manager: bridge config.toml to Academy Manager / Exchange / Launcher (local, Redis, Parsl, Globus Compute). - RateLimiter: stdlib async token-bucket for shared per-provider LLM quotas across agents. Lazy imports in __init__.py let the package load without the optional academy-py dependency; ChemGraphAgent / ScreeningAgent / CoordinatorAgent raise ModuleNotFoundError on access if academy-py is missing, while AcademyConfig and RateLimiter remain usable. pyproject's academy optional-dep + pytest marker are already in HEAD (commit 04bcc8a). tests/test_academy.py and scripts/academy_example/ remain untracked and will land in follow-ups. --- src/chemgraph/academy/__init__.py | 45 +++++++ src/chemgraph/academy/agent.py | 123 ++++++++++++++++++ src/chemgraph/academy/config.py | 175 +++++++++++++++++++++++++ src/chemgraph/academy/coordinator.py | 179 ++++++++++++++++++++++++++ src/chemgraph/academy/rate_limiter.py | 135 +++++++++++++++++++ src/chemgraph/academy/screening.py | 151 ++++++++++++++++++++++ 6 files changed, 808 insertions(+) create mode 100644 src/chemgraph/academy/__init__.py create mode 100644 src/chemgraph/academy/agent.py create mode 100644 src/chemgraph/academy/config.py create mode 100644 src/chemgraph/academy/coordinator.py create mode 100644 src/chemgraph/academy/rate_limiter.py create mode 100644 src/chemgraph/academy/screening.py diff --git a/src/chemgraph/academy/__init__.py b/src/chemgraph/academy/__init__.py new file mode 100644 index 00000000..90e5bf12 --- /dev/null +++ b/src/chemgraph/academy/__init__.py @@ -0,0 +1,45 @@ +"""Academy Agents integration for ChemGraph. + +Provides agent classes and utilities for deploying ChemGraph workflows +across federated HPC infrastructure using the Academy framework. 
class ChemGraphAgent(Agent):
    """Academy agent that owns one lazily created :class:`ChemGraph`.

    The constructor only records settings; the ``ChemGraph`` object is
    built in :meth:`agent_on_startup` and dropped again in
    :meth:`agent_on_shutdown`.

    Parameters
    ----------
    model_name : str
        Name of the LLM model handed to ``ChemGraph``.
    workflow_type : str
        ChemGraph workflow identifier handed to ``ChemGraph``.
    log_dir : str or None
        Base directory for logs.  When truthy, a sub-directory named
        after this agent's short UUID is created beneath it.
    rate_limiter : object or None
        Optional limiter; its ``acquire(model_name)`` coroutine is
        awaited before every query.
    **chemgraph_kwargs
        Forwarded unchanged to the ``ChemGraph`` constructor.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o-mini",
        workflow_type: str = "single_agent",
        log_dir: Optional[str] = None,
        rate_limiter: Any = None,
        **chemgraph_kwargs: Any,
    ) -> None:
        super().__init__()
        # Short identifier used for the log sub-directory and log lines.
        self._agent_uuid = uuid.uuid4().hex[:8]
        # Populated by agent_on_startup().
        self._cg: Optional[ChemGraph] = None
        self._model_name = model_name
        self._workflow_type = workflow_type
        self._log_dir = log_dir
        self._rate_limiter = rate_limiter
        self._chemgraph_kwargs = chemgraph_kwargs

    async def agent_on_startup(self) -> None:
        """Create the per-agent log directory (if any) and the ChemGraph."""
        base_dir = self._log_dir
        if base_dir:
            per_agent_dir = os.path.join(base_dir, self._agent_uuid)
            os.makedirs(per_agent_dir, exist_ok=True)
        else:
            # Falsy values (None / "") are passed through unchanged.
            per_agent_dir = base_dir

        self._cg = ChemGraph(
            model_name=self._model_name,
            workflow_type=self._workflow_type,
            log_dir=per_agent_dir,
            enable_memory=False,
            **self._chemgraph_kwargs,
        )
        logger.info(
            "ChemGraphAgent %s started: model=%s workflow=%s",
            self._agent_uuid,
            self._model_name,
            self._workflow_type,
        )

    async def agent_on_shutdown(self) -> None:
        """Log the shutdown and release the ChemGraph instance."""
        logger.info("ChemGraphAgent %s shutting down", self._agent_uuid)
        self._cg = None

    @action
    async def run_query(
        self,
        query: str,
        *,
        config: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Run *query* through the wrapped ChemGraph and return its result.

        Parameters
        ----------
        query : str
            Text passed to ``ChemGraph.run``.
        config : dict, optional
            Config passed to ``ChemGraph.run``.  When omitted or empty,
            a config with a freshly generated ``thread_id`` is used.

        Returns
        -------
        dict
            Whatever ``ChemGraph.run`` returns.

        Raises
        ------
        RuntimeError
            If the ChemGraph instance has not been created yet.
        """
        if self._cg is None:
            raise RuntimeError("Agent not initialised (call agent_on_startup first)")

        limiter = self._rate_limiter
        if limiter is not None:
            await limiter.acquire(self._model_name)

        if config:
            thread_cfg = config
        else:
            thread_cfg = {"configurable": {"thread_id": uuid.uuid4().hex[:8]}}
        return await self._cg.run(query=query, config=thread_cfg)

    @action
    async def get_info(self) -> dict[str, str]:
        """Return this agent's UUID, model name and workflow type."""
        info = {
            "agent_uuid": self._agent_uuid,
            "model_name": self._model_name,
            "workflow_type": self._workflow_type,
        }
        return info
def load_academy_config(config_path: str = "config.toml") -> AcademyConfig:
    """Load the ``[academy]`` section from a TOML config file.

    Keys that match an :class:`AcademyConfig` field are passed to the
    constructor; missing keys keep the dataclass defaults.  Every other
    key is collected into ``extra``.

    ``extra`` is itself a dataclass field, so it is handled separately:
    an ``[academy.extra]`` table is merged into the collected unknown
    keys (top-level unknown keys win on a name clash), and a non-table
    ``extra`` value is stored under ``extra["extra"]``.  Passing it on
    as a regular field would hand ``extra`` to the constructor twice
    and raise ``TypeError``.

    Parameters
    ----------
    config_path : str
        Path of the TOML file to read.

    Returns
    -------
    AcademyConfig
        The parsed configuration, or an all-defaults instance when
        *config_path* does not exist.
    """
    try:
        data = toml.load(config_path)
    except FileNotFoundError:
        logger.warning("Config file %s not found, using defaults", config_path)
        return AcademyConfig()

    section = dict(data.get("academy", {}))
    declared_extra = section.pop("extra", None)

    # Iterating __dataclass_fields__ yields the field names.
    known_keys = set(AcademyConfig.__dataclass_fields__) - {"extra"}
    known = {k: v for k, v in section.items() if k in known_keys}
    extra = {k: v for k, v in section.items() if k not in known_keys}

    if isinstance(declared_extra, dict):
        extra = {**declared_extra, **extra}
    elif declared_extra is not None:
        extra["extra"] = declared_extra

    return AcademyConfig(**known, extra=extra)
def _build_executor(cfg: AcademyConfig) -> Any:
    """Return a new executor for the launcher named in ``cfg.launcher``.

    ``"thread"`` and ``"process"`` map to the ``concurrent.futures``
    pools sized by ``cfg.num_agents``.  ``"parsl"`` and
    ``"globus_compute"`` map to the executors in ``academy.executor``,
    imported on demand.

    Raises
    ------
    ImportError
        If the optional ``academy.executor`` class cannot be imported.
    ValueError
        If ``cfg.launcher`` is not one of the four supported names.
    """
    launcher = cfg.launcher

    if launcher in ("thread", "process"):
        import concurrent.futures as futures

        pool_cls = (
            futures.ThreadPoolExecutor
            if launcher == "thread"
            else futures.ProcessPoolExecutor
        )
        return pool_cls(max_workers=cfg.num_agents)

    if launcher == "parsl":
        try:
            from academy.executor import ParslExecutor
        except ImportError as exc:
            hint = "Parsl launcher requires: pip install chemgraphagent[academy,parsl]"
            raise ImportError(hint) from exc
        return ParslExecutor()

    if launcher == "globus_compute":
        try:
            from academy.executor import GlobusComputeExecutor
        except ImportError as exc:
            hint = (
                "Globus Compute launcher requires: "
                "pip install chemgraphagent[academy,globus_compute]"
            )
            raise ImportError(hint) from exc
        return GlobusComputeExecutor(endpoint_id=cfg.globus_endpoint_id)

    raise ValueError(f"Unsupported launcher type: {launcher}")
class CoordinatorAgent(Agent):
    """Collect screening results and drive the analysis step.

    Parameters
    ----------
    results_dir : str
        Directory scanned for ``*.json`` result files.  It is created
        on startup if missing.
    worker_handles : list[Handle] or None
        Handles whose ``get_progress()`` action is polled by
        :meth:`poll_progress`.
    analysis_model : str
        Model name given to the ChemGraph instance built in
        :meth:`analyse`.
    analysis_workflow : str
        Workflow type given to that ChemGraph instance.
    **analysis_kwargs
        Extra keyword arguments forwarded to that ChemGraph instance.
    """

    def __init__(
        self,
        results_dir: str,
        worker_handles: list[Handle] | None = None,
        analysis_model: str = "gpt-4o",
        analysis_workflow: str = "single_agent",
        **analysis_kwargs: Any,
    ) -> None:
        super().__init__()
        self._results_dir = results_dir
        self._worker_handles = worker_handles or []
        self._analysis_model = analysis_model
        self._analysis_workflow = analysis_workflow
        self._analysis_kwargs = analysis_kwargs
        # Cache of the records read by the last collect_results() call.
        self._collected: list[dict[str, Any]] = []
        # Result of the last analyse() call, if any.
        self._analysis_result: Optional[dict[str, Any]] = None

    async def agent_on_startup(self) -> None:
        """Ensure the results directory exists and log the configuration."""
        os.makedirs(self._results_dir, exist_ok=True)
        logger.info(
            "CoordinatorAgent started: watching %s, %d workers",
            self._results_dir,
            len(self._worker_handles),
        )

    # ------------------------------------------------------------------
    # Progress monitoring
    # ------------------------------------------------------------------

    @action
    async def poll_progress(self) -> dict[str, Any]:
        """Query every worker handle and sum the reported counters.

        A worker whose ``get_progress()`` call raises is reported as
        ``{"error": <message>}`` in ``per_worker`` and excluded from the
        ``total`` / ``completed`` / ``failed`` sums.
        """
        per_worker = []
        for handle in self._worker_handles:
            try:
                per_worker.append(await handle.get_progress())
            except Exception as exc:  # one failing worker must not abort the poll
                per_worker.append({"error": str(exc)})

        healthy = [p for p in per_worker if "error" not in p]
        return {
            "workers": len(per_worker),
            "total": sum(p.get("total", 0) for p in healthy),
            "completed": sum(p.get("completed", 0) for p in healthy),
            "failed": sum(p.get("failed", 0) for p in healthy),
            "per_worker": per_worker,
        }

    # ------------------------------------------------------------------
    # Result collection
    # ------------------------------------------------------------------

    @action
    async def collect_results(self) -> list[dict[str, Any]]:
        """Read every ``*.json`` file in the results directory.

        Files are read in sorted path order.  A file that cannot be
        opened, decoded or parsed, or whose top-level JSON value is not
        an object, is skipped with a warning so that one bad file does
        not break aggregation.  The records are cached on the agent and
        returned.
        """
        pattern = os.path.join(self._results_dir, "*.json")
        results: list[dict[str, Any]] = []
        for path in sorted(glob.glob(pattern)):
            try:
                with open(path, encoding="utf-8") as fh:
                    record = json.load(fh)
            except (ValueError, OSError):
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                logger.warning("Skipping corrupt result file: %s", path)
                continue
            if not isinstance(record, dict):
                # Later steps call .get() on each record.
                logger.warning("Skipping non-object result file: %s", path)
                continue
            results.append(record)

        self._collected = results
        logger.info("Collected %d results from %s", len(results), self._results_dir)
        return results

    # ------------------------------------------------------------------
    # LLM-powered analysis
    # ------------------------------------------------------------------

    @action
    async def analyse(self, query: Optional[str] = None) -> dict[str, Any]:
        """Run a ChemGraph workflow over the successful results.

        Results are collected first when the cache is empty.

        Parameters
        ----------
        query : str, optional
            Custom analysis prompt.  When omitted, a default prompt is
            built that embeds the successful records as JSON and asks
            for a ranking of candidates.

        Returns
        -------
        dict
            The value returned by ``ChemGraph.run``, or
            ``{"error": ...}`` when no record has status ``"success"``.
        """
        from chemgraph.agent.llm_agent import ChemGraph

        if not self._collected:
            await self.collect_results()

        successes = [r for r in self._collected if r.get("status") == "success"]
        if not successes:
            return {"error": "No successful results to analyse"}

        if query is None:
            summary = json.dumps(successes, default=str, indent=2)
            query = (
                "You are analysing computational chemistry screening results. "
                f"Here are {len(successes)} results:\n\n{summary}\n\n"
                "Identify the top candidates based on energy, stability, "
                "or other relevant properties. Rank them and explain why."
            )

        cg = ChemGraph(
            model_name=self._analysis_model,
            workflow_type=self._analysis_workflow,
            enable_memory=False,
            **self._analysis_kwargs,
        )
        self._analysis_result = await cg.run(query=query)
        return self._analysis_result

    @action
    async def get_analysis(self) -> dict[str, Any] | None:
        """Return the most recent analysis result, or ``None``."""
        return self._analysis_result

    # ------------------------------------------------------------------
    # Wave dispatch
    # ------------------------------------------------------------------

    @action
    async def suggest_followup_molecules(
        self,
        top_n: int = 10,
    ) -> list[str]:
        """Return up to *top_n* SMILES strings for a follow-up wave.

        Current heuristic: the SMILES of the first *top_n* successful
        records, in collection order.  No ranking by energy or by the
        analysis result is performed yet.

        Records without a ``smiles`` entry are skipped rather than
        raising ``KeyError``, and a non-positive *top_n* yields an
        empty list.  Results are collected first when the cache is
        empty.
        """
        if not self._collected:
            await self.collect_results()

        candidates = [
            r["smiles"]
            for r in self._collected
            if r.get("status") == "success" and r.get("smiles")
        ]
        return candidates[: max(top_n, 0)]
import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger(__name__)


@dataclass
class _ProviderBucket:
    """Token-bucket state for a single LLM provider.

    ``tokens`` is allowed to go negative.  Each unit of debt is a
    request that has already reserved a future slot in
    :meth:`RateLimiter.acquire` and is sleeping until that slot arrives.
    """

    rpm: float  # refill rate in requests per minute; also the burst capacity
    tokens: float = 0.0  # always overwritten by __post_init__
    last_refill: float = field(default_factory=time.monotonic)
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    def __post_init__(self) -> None:
        # Start with a full bucket.
        self.tokens = self.rpm


class RateLimiter:
    """Async token-bucket rate limiter keyed by LLM provider.

    Parameters
    ----------
    default_rpm : int
        Default requests-per-minute for providers not explicitly
        configured (default ``60``).  Must be positive.
    provider_rpm : dict[str, int] or None
        Per-provider overrides.  Keys are provider prefixes or model
        names (e.g. ``"openai"``, ``"anthropic"``, ``"gpt-4o"``).
        Every value must be positive.

    Raises
    ------
    ValueError
        If ``default_rpm`` or any ``provider_rpm`` value is not
        positive.  (Previously such a value surfaced later as a
        ``ZeroDivisionError`` inside :meth:`acquire`.)

    Usage
    -----
    ::

        limiter = RateLimiter(default_rpm=60, provider_rpm={"openai": 500})
        await limiter.acquire("gpt-4o")  # blocks if bucket empty
    """

    # Map model-name prefixes to canonical provider keys so that
    # ``acquire("gpt-4o")`` matches a rule set for ``"openai"``.
    _PREFIX_MAP: dict[str, str] = {
        "gpt-": "openai",
        "o1": "openai",
        "o3": "openai",
        "o4": "openai",
        "argo:": "argo",
        "claude-": "anthropic",
        "gemini-": "google",
        "groq:": "groq",
        "llama": "alcf",
    }

    def __init__(
        self,
        default_rpm: int = 60,
        provider_rpm: Optional[dict[str, int]] = None,
    ) -> None:
        if default_rpm <= 0:
            raise ValueError(f"default_rpm must be positive, got {default_rpm!r}")
        for name, rpm in (provider_rpm or {}).items():
            if rpm <= 0:
                raise ValueError(
                    f"provider_rpm[{name!r}] must be positive, got {rpm!r}"
                )

        self._default_rpm = default_rpm
        self._provider_rpm: dict[str, int] = provider_rpm or {}
        self._buckets: dict[str, _ProviderBucket] = {}

    def _resolve_provider(self, model_name: str) -> str:
        """Map a model name to a canonical provider key.

        An exact key in ``provider_rpm`` wins; otherwise the first
        matching entry of ``_PREFIX_MAP`` is used; otherwise the model
        name itself becomes the key.
        """
        # Direct match first.
        if model_name in self._provider_rpm:
            return model_name

        # Prefix match (case-insensitive).
        lower = model_name.lower()
        for prefix, provider in self._PREFIX_MAP.items():
            if lower.startswith(prefix):
                return provider

        return model_name

    def _get_bucket(self, provider: str) -> _ProviderBucket:
        """Get or create the bucket for *provider*."""
        if provider not in self._buckets:
            rpm = self._provider_rpm.get(provider, self._default_rpm)
            self._buckets[provider] = _ProviderBucket(rpm=rpm)
        return self._buckets[provider]

    async def acquire(self, model_name: str) -> None:
        """Wait until a request token is available for *model_name*.

        The bucket is refilled for the elapsed time and one token is
        consumed immediately, even if that drives the balance negative.
        A negative balance means this caller has reserved the next free
        slot; it then sleeps exactly long enough for the refill to pay
        off the debt.

        Reserving under the lock is what keeps concurrent callers
        honest: N coroutines hitting an empty bucket are released one
        refill interval apart, instead of all sleeping for the same
        interval and passing together.

        If the sleeping task is cancelled its reserved slot is not
        handed back, which errs on the side of sending fewer requests.
        """
        provider = self._resolve_provider(model_name)
        bucket = self._get_bucket(provider)
        tokens_per_second = bucket.rpm / 60.0

        async with bucket.lock:
            now = time.monotonic()
            refill = (now - bucket.last_refill) * tokens_per_second
            # Cap at rpm so a long idle period cannot bank an unbounded burst.
            bucket.tokens = min(bucket.rpm, bucket.tokens + refill)
            bucket.last_refill = now

            # Consume (or reserve) one token.
            bucket.tokens -= 1.0
            wait_seconds = max(0.0, -bucket.tokens) / tokens_per_second

        if wait_seconds <= 0.0:
            return

        logger.debug(
            "Rate limit: waiting %.1fs for provider %s (rpm=%d)",
            wait_seconds,
            provider,
            bucket.rpm,
        )
        # Sleep outside the lock so later callers can queue their own slot.
        await asyncio.sleep(wait_seconds)
import asyncio
import json
import logging
import os
import time
from typing import Any, Optional

from academy.agent import Agent, action, loop

from chemgraph.academy.agent import ChemGraphAgent

logger = logging.getLogger(__name__)


def _result_file_path(
    results_dir: str, agent_uuid: Any, index: int, smiles: str
) -> str:
    """Build a collision-free result-file path for one molecule.

    The sanitised SMILES alone is not unique.  ``/`` and ``\\`` both
    become ``_`` (so the E/Z pair ``F/C=C/F`` and ``F/C=C\\F`` map to
    the same text), long SMILES are truncated to 50 characters, and the
    input list may contain duplicates.  The molecule's position in the
    agent's list is therefore part of the name; the SMILES fragment is
    kept only for readability.
    """
    safe_name = smiles.replace("/", "_").replace("\\", "_")[:50]
    return os.path.join(results_dir, f"{agent_uuid}_{index:05d}_{safe_name}.json")


class ScreeningAgent(ChemGraphAgent):
    """Agent that screens a batch of molecules using a ChemGraph workflow.

    Parameters
    ----------
    molecules : list[str]
        SMILES strings to screen.
    query_template : str
        Query template with ``{smiles}`` placeholder, e.g.
        ``"Optimize the geometry of {smiles} and compute its energy."``.
    results_dir : str or None
        Directory to write per-molecule JSON result files for
        downstream aggregation.  If ``None``, results are only
        returned via the exchange.
    model_name, workflow_type, log_dir, rate_limiter, **chemgraph_kwargs
        Forwarded to :class:`ChemGraphAgent`.
    """

    def __init__(
        self,
        molecules: list[str],
        query_template: str,
        results_dir: Optional[str] = None,
        model_name: str = "gpt-4o-mini",
        workflow_type: str = "single_agent",
        log_dir: Optional[str] = None,
        rate_limiter: Any = None,
        **chemgraph_kwargs: Any,
    ) -> None:
        super().__init__(
            model_name=model_name,
            workflow_type=workflow_type,
            log_dir=log_dir,
            rate_limiter=rate_limiter,
            **chemgraph_kwargs,
        )
        self._molecules = molecules
        self._query_template = query_template
        self._results_dir = results_dir
        self.results: list[dict[str, Any]] = []
        self.completed: int = 0
        self.failed: int = 0

    async def agent_on_startup(self) -> None:
        """Run the parent start-up hook and create ``results_dir`` if set."""
        await super().agent_on_startup()
        if self._results_dir:
            os.makedirs(self._results_dir, exist_ok=True)
        logger.info(
            "ScreeningAgent %s: %d molecules to process",
            self._agent_uuid,
            len(self._molecules),
        )

    @action
    async def get_progress(self) -> dict[str, Any]:
        """Return screening progress."""
        return {
            "agent_uuid": self._agent_uuid,
            "total": len(self._molecules),
            "completed": self.completed,
            "failed": self.failed,
        }

    @loop
    async def screening_loop(self, shutdown: asyncio.Event) -> None:
        """Iterate over assigned molecules and run one query per molecule.

        Each molecule yields a record with ``status`` ``"success"`` or
        ``"error"``.  The record is appended to ``self.results`` and,
        when ``results_dir`` is set, written to its own JSON file.  The
        loop stops early when *shutdown* is set and always ends by
        requesting agent shutdown.
        """
        for index, smiles in enumerate(self._molecules):
            if shutdown.is_set():
                logger.info(
                    "ScreeningAgent %s: shutdown requested, stopping",
                    self._agent_uuid,
                )
                break

            query = self._query_template.format(smiles=smiles)
            t0 = time.monotonic()
            try:
                result = await self.run_query(query)
                elapsed = time.monotonic() - t0
                record = {
                    "smiles": smiles,
                    "status": "success",
                    "result": result,
                    "elapsed_seconds": round(elapsed, 2),
                    "agent_uuid": self._agent_uuid,
                }
                self.completed += 1
            except Exception as exc:
                # One bad molecule must not abort the rest of the batch.
                elapsed = time.monotonic() - t0
                logger.exception(
                    "ScreeningAgent %s: failed on %s",
                    self._agent_uuid,
                    smiles,
                )
                record = {
                    "smiles": smiles,
                    "status": "error",
                    "error": str(exc),
                    "elapsed_seconds": round(elapsed, 2),
                    "agent_uuid": self._agent_uuid,
                }
                self.failed += 1

            self.results.append(record)

            # Write individual result file for aggregation.
            if self._results_dir:
                path = _result_file_path(
                    self._results_dir, self._agent_uuid, index, smiles
                )
                try:
                    with open(path, "w", encoding="utf-8") as f:
                        json.dump(record, f, default=str)
                except OSError:
                    # The record is still in self.results; losing one file
                    # should not cost the remaining molecules.
                    logger.exception(
                        "ScreeningAgent %s: could not write result file %s",
                        self._agent_uuid,
                        path,
                    )

        logger.info(
            "ScreeningAgent %s: finished (%d ok, %d failed)",
            self._agent_uuid,
            self.completed,
            self.failed,
        )
        # Signal that this agent is done.
        self.agent_shutdown()

    @action
    async def get_results(self) -> list[dict[str, Any]]:
        """Return all collected results so far."""
        return self.results
@staticmethod
def _looks_like_stopped_executor(exc: BaseException) -> bool:
    """Heuristic: did a submit fail because the Executor is shut down?

    The SDK does not expose a stable exception type for this state, so
    the exception text is matched case-insensitively against a few
    substrings.  Both ``"shutdown"`` and ``"shut down"`` are accepted:
    ``concurrent.futures`` reports "cannot schedule new futures after
    shutdown" as one word, and the Globus Compute SDK is believed to
    use the same one-word spelling (TODO confirm against the pinned SDK
    version).  Matching only the two-word form would leave the
    rebuild-and-retry path unreachable for those errors.

    Parameters
    ----------
    exc : BaseException
        The exception raised by ``Executor.submit``.

    Returns
    -------
    bool
        ``True`` when the message suggests a stopped executor.
    """
    msg = str(exc).lower()
    markers = ("shutdown", "shut down", "stopped", "closed")
    return any(marker in msg for marker in markers)
Without this the + # second add_item silently overwrites the first on the + # destination collection. file_mapping: dict[str, str] = {} + used_names: dict[str, int] = {} for local_path in local_paths: p = Path(local_path).resolve() - remote_path = f"{remote_dir}/{p.name}" + base = p.name + count = used_names.get(base, 0) + if count == 0: + remote_name = base + else: + stem, dot, suffix = base.partition(".") + remote_name = ( + f"{stem}_{count}.{suffix}" if dot else f"{stem}_{count}" + ) + used_names[base] = count + 1 + remote_path = f"{remote_dir}/{remote_name}" tdata.add_item(str(p), remote_path) file_mapping[str(p)] = remote_path diff --git a/src/chemgraph/execution/job_tracker.py b/src/chemgraph/execution/job_tracker.py index 23f6c837..4efa41b0 100644 --- a/src/chemgraph/execution/job_tracker.py +++ b/src/chemgraph/execution/job_tracker.py @@ -136,21 +136,28 @@ def _load(self) -> None: logger.warning("Could not load job tracker state: %s", exc) return + orphaned: list[tuple[str, str]] = [] # (batch_id, task_id) with self._lock: for bid, info in data.items(): if bid in self._batches: continue # don't overwrite live batches - tasks = [ - TrackedTask( + tasks = [] + for t in info.get("tasks", []): + tracked = TrackedTask( task_id=t["task_id"], meta=t.get("meta", {}), future=None, globus_task_id=t.get("globus_task_id"), result=t.get("result"), ) - for t in info.get("tasks", []) - ] + # Tasks loaded from disk with no globus_task_id and + # no cached result are orphaned -- get_status cannot + # query Globus for them (see line ~320). 
+ if tracked.globus_task_id is None and tracked.result is None: + orphaned.append((bid, tracked.task_id)) + tasks.append(tracked) + self._batches[bid] = TrackedBatch( batch_id=bid, tool_name=info["tool_name"], @@ -161,6 +168,13 @@ def _load(self) -> None: logger.info( "Loaded %d batches from %s", len(data), self._persist_file ) + if orphaned: + logger.warning( + "%d task(s) reloaded without a Globus task_id -- their " + "results cannot be recovered. Examples: %s", + len(orphaned), + ", ".join(f"{b}/{t}" for b, t in orphaned[:5]), + ) # ── registration ─────────────────────────────────────────────────── @@ -241,8 +255,14 @@ def _wait_for_globus_task_ids( time.sleep(0.25) if pending: - logger.debug( - "%d tasks did not receive a Globus task_id within %.1fs", + # Promoted from debug -> warning: tasks without a task_id + # at this point will be lost across a server restart, so the + # user should see this immediately rather than only in the + # post-mortem orphan warning at reload time. + logger.warning( + "%d task(s) did not receive a Globus task_id within %.1fs; " + "they will be unrecoverable if the server restarts before " + "the next get_status call", len(pending), timeout, ) diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index 3a84c9f7..155dd76d 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -113,10 +113,30 @@ def set_pre_submit_hook(self, hook: Optional[Callable]) -> None: self._pre_submit_hook = hook def _apply_pre_submit_hook(self, task): - """Run the registered pre-submit hook (no-op when unset).""" + """Run the registered pre-submit hook (no-op when unset). + + Hook exceptions are wrapped in a ``ValueError`` naming the hook + and the offending task_id, so they surface to the agent as a + structured error instead of an opaque traceback. 
+ """ if self._pre_submit_hook is None: return task - return self._pre_submit_hook(task) + try: + return self._pre_submit_hook(task) + except Exception as exc: + hook_name = getattr( + self._pre_submit_hook, "__name__", repr(self._pre_submit_hook) + ) + task_id = getattr(task, "task_id", "") + logger.warning( + "Pre-submit hook %s failed for task %s", + hook_name, + task_id, + exc_info=True, + ) + raise ValueError( + f"Pre-submit hook '{hook_name}' failed for task '{task_id}': {exc}" + ) from exc # ── Job tracking tools ───────────────────────────────────────────── From 27fff5dd4c83c73b2030751b1a1f3a091d4781db Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 15:04:24 -0500 Subject: [PATCH 016/119] Migrate XANES and gRASPA MCP servers to CGFastMCP Both servers now mirror the mace_mcp_hpc.py pattern: - CGFastMCP with lazy backend initialisation via init_backend(); the worker subprocesses re-importing the module no longer instantiate a backend at import time. - Job-management tools (check_job_status, get_job_results, list_jobs, cancel_job, check_endpoint_status) are auto-registered by CGFastMCP._register_job_tools; the external register_job_tools call is dropped. - __main__ wires init_backend(tracker_kwargs={"persist_file": ...}) and pairs run_mcp_server with shutdown_backend in finally. This also closes a real bug in graspa_mcp_hpc.py, which was instantiating JobTracker() with no persist_file and silently losing job state across restarts despite the server's instructions promising persistence. - Globus Transfer tools (transfer_files, check_transfer_status, list_remote_files) are registered on both servers when the transfer manager is configured, matching the existing MACE behaviour. - gRASPA expander now supports remote_structure_directory the same way MACE does: a one-shot probe task lists CIFs on the remote endpoint and the worker reads them directly from the staged path. 
- Ensemble flows use the schema_fanout_tool decorator; per-job structure metadata is propagated through the worker output (since the framework meta is only the index). Legacy *_mcp_parsl.py modules now raise a DeprecationWarning at import pointing to the *_hpc.py replacement; they remain functional because scripts/mcp_xanes_example/ still imports xanes_mcp_parsl. --- src/chemgraph/mcp/graspa_mcp_hpc.py | 280 +++++++++++++++++--------- src/chemgraph/mcp/graspa_mcp_parsl.py | 11 + src/chemgraph/mcp/mace_mcp_parsl.py | 11 + src/chemgraph/mcp/xanes_mcp_hpc.py | 250 +++++++++++++---------- src/chemgraph/mcp/xanes_mcp_parsl.py | 11 + 5 files changed, 363 insertions(+), 200 deletions(-) diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py index 87eeb231..be7737a6 100644 --- a/src/chemgraph/mcp/graspa_mcp_hpc.py +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -1,141 +1,231 @@ """Backend-agnostic gRASPA MCP server. -Replaces ``graspa_mcp_parsl.py`` by using the :mod:`chemgraph.execution` -abstraction layer. The execution backend (Parsl, EnsembleLauncher, -local) is selected at startup via ``config.toml`` or the -``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. + +The ensemble expander emits one job per ``(structure, condition)`` pair +and supports both local input directories and pre-staged remote +directories (mirrors the MACE server's local/remote modes). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. 
""" import logging +import os from pathlib import Path -from mcp.server.fastmcp import FastMCP - -from chemgraph.execution import TaskSpec, get_backend -from chemgraph.execution.job_tracker import JobTracker +from chemgraph.execution.base import TaskSpec +from chemgraph.execution.config import get_transfer_manager from chemgraph.execution.utils import ( make_per_structure_output, resolve_structure_files, - submit_or_gather, - write_results_jsonl, ) -from chemgraph.mcp.job_tools import register_job_tools -from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.transfer_tools import register_transfer_tools from chemgraph.schemas.graspa_schema import graspa_input_schema_ensemble logger = logging.getLogger(__name__) -# ── Initialise execution backend ──────────────────────────────────────── -backend = get_backend() -tracker = JobTracker() +_JOBS_FILE = Path("~/.chemgraph/graspa_jobs.json").expanduser() -# ── MCP server ────────────────────────────────────────────────────────── -mcp = FastMCP( +mcp = CGFastMCP( name="ChemGraph Graspa Tools", instructions=""" - You expose tools for running graspa simulations and reading their results. - The available tools are: - 1. run_graspa_ensemble: run graspa calculations over all structures in a - directory using the configured execution backend. - 2. check_job_status: check progress of a submitted HPC job batch. - 3. get_job_results: retrieve results from a completed job batch. - 4. list_jobs: list all tracked job batches. - 5. cancel_job: cancel pending tasks in a job batch. + You expose tools for running gRASPA simulations and reading + their results. The available tools are: + 1. run_graspa_ensemble: run gRASPA calculations over every + structure in a directory at one or more (T, P) conditions. + Local mode uses input_structures; remote mode uses + remote_structure_directory (pre-stage files first with + transfer_files). + 2. 
check_job_status / get_job_results / list_jobs / cancel_job: + HPC job batch management. Job state persists across sessions. + 3. transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on + the remote HPC filesystem before running ensembles in remote + mode. Guidelines: - - Use each tool only when its input schema matches the user request. - - Do not guess numerical values; report tool errors exactly as they occur. - - Keep responses compact -- full results are written to the output files - defined in the schemas. + - Use each tool only when its input schema matches the user + request. + - Do not guess numerical values; report tool errors exactly as + they occur. + - Keep responses compact -- full results are written to the + output files defined in the schemas. - When returning paths, use absolute paths. - Energies are in eV and wall times are in seconds. - When a tool returns status='submitted' with a batch_id, use - check_job_status to poll for progress before calling get_job_results. + check_job_status to poll for progress before calling + get_job_results. Job state is persisted across sessions. 
""", ) -register_job_tools(mcp, tracker, backend) -def _run_graspa_single(job: dict) -> dict: - """Execute a single gRASPA simulation (runs on the worker).""" +# ── Worker (runs on the backend) ─────────────────────────────────────── + + +def _graspa_worker(job: dict) -> dict: + """Execute a single gRASPA simulation on a backend worker.""" from chemgraph.schemas.graspa_schema import graspa_input_schema from chemgraph.tools.graspa_tools import run_graspa_core - params = graspa_input_schema(**job) if isinstance(job, dict) else job - return run_graspa_core(params) - + job = dict(job) + structure = job.pop("_structure_name", None) + temperature = job.get("temperature") + pressure = job.get("pressure") + + remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "raspa.log"), + ) -@mcp.tool( - name="run_graspa_ensemble", - description="Run an ensemble of graspa calculations for multiple input files.", -) -async def run_graspa_ensemble( - params: graspa_input_schema_ensemble, -): - """Run an ensemble of gRASPA calculations over all structure files - using the configured execution backend. - - Parameters - ---------- - params : graspa_input_schema_ensemble - Input parameters for the ensemble of gRASPA calculations. 
- """ - structure_files, output_dir = resolve_structure_files( - params.input_structures, - extensions={".cif"}, + params = graspa_input_schema(**job) + result = run_graspa_core(params) + + if isinstance(result, dict): + merged = { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + **result, + } + merged.setdefault("status", "success") + return merged + return { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + "result": result, + "status": "success", + } + + +# ── Ensemble fanout ──────────────────────────────────────────────────── + + +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) ) - # Base output file name - base_output = Path(params.output_result_file).resolve() - pending_tasks = [] +def _expand_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: + """Server-side expansion of an ensemble request into per-job dicts. + Local mode: enumerates ``input_structures`` on this host. + Remote mode: submits a one-shot probe task to the backend to list + files under ``remote_structure_directory``, then builds per-file + jobs that the worker reads directly from the remote filesystem. + """ + base_output = Path(params.output_result_file) + + if params.remote_structure_directory: + remote_dir = params.remote_structure_directory + mcp._ensure_backend() + probe = TaskSpec( + task_id="ls_remote_dir", + task_type="python", + callable=_ls_remote_files, + kwargs={"path": remote_dir}, + ) + fut = mcp._backend.submit(probe) + try: + file_names = fut.result(timeout=30) + except Exception as exc: + raise RuntimeError( + f"Could not list remote directory {remote_dir}: {exc}" + ) from exc + + # Filter to CIF files (gRASPA expects CIFs). 
+ file_names = [f for f in file_names if f.lower().endswith(".cif")] + if not file_names: + raise ValueError( + f"No CIF files found under remote directory {remote_dir}." + ) + + jobs = [] + for fname in file_names: + mof_name = Path(fname).stem + for condition in params.conditions: + per_output = make_per_structure_output(Path(fname), base_output) + jobs.append( + { + "_structure_name": mof_name, + "remote_structure_file": f"{remote_dir}/{fname}", + "output_result_file": str(per_output), + "temperature": condition.temperature, + "pressure": condition.pressure, + "adsorbate": params.adsorbate, + "n_cycles": params.n_cycles, + } + ) + return jobs + + if not params.input_structures: + raise ValueError( + "Either input_structures or remote_structure_directory " + "must be provided." + ) + + structure_files, _ = resolve_structure_files( + params.input_structures, extensions={".cif"} + ) + jobs = [] for struct_path in structure_files: mof_name = struct_path.stem for condition in params.conditions: - per_struct_output = make_per_structure_output(struct_path, base_output) - job = { - "input_structure_file": str(struct_path), - "output_result_file": str(per_struct_output), - "temperature": condition.temperature, - "pressure": condition.pressure, - "adsorbate": params.adsorbate, - "n_cycles": params.n_cycles, - } - - task = TaskSpec( - task_id=f"graspa_{mof_name}_{condition.temperature}K_{condition.pressure}Pa", - task_type="python", - callable=_run_graspa_single, - kwargs={"job": job}, + per_output = make_per_structure_output(struct_path, base_output) + jobs.append( + { + "_structure_name": mof_name, + "input_structure_file": str(struct_path), + "output_result_file": str(per_output), + "temperature": condition.temperature, + "pressure": condition.pressure, + "adsorbate": params.adsorbate, + "n_cycles": params.n_cycles, + } ) - fut = backend.submit(task) + return jobs - task_meta = { - "structure": mof_name, - "temperature": condition.temperature, - "pressure": 
condition.pressure, - } - pending_tasks.append((task_meta, fut)) - result = await submit_or_gather( - backend, pending_tasks, tracker, "run_graspa_ensemble", - ) +@mcp.schema_fanout_tool( + name="run_graspa_ensemble", + description=( + "Run gRASPA calculations over every structure in a directory at " + "one or more (temperature, pressure) conditions. Local mode " + "uses input_structures; remote mode uses " + "remote_structure_directory (pre-stage files first with " + "transfer_files)." + ), + worker=_graspa_worker, +) +def run_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: + return _expand_graspa_ensemble(params) - if result["status"] == "completed": - summary_log_path = output_dir / "simulation_results.jsonl" - success_count, total_count = write_results_jsonl( - result["results"], summary_log_path, - ) - return ( - f"Ensemble execution completed. Ran {total_count} tasks " - f"({success_count} successful). " - f"Detailed results appended to '{summary_log_path}'." - ) - # Async remote: return submission confirmation - return result +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on gRASPA MCP server.") if __name__ == "__main__": - run_mcp_server(mcp, default_port=9001) + from chemgraph.mcp.server_utils import run_mcp_server + + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) + + try: + run_mcp_server(mcp, default_port=9001) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/mcp/graspa_mcp_parsl.py b/src/chemgraph/mcp/graspa_mcp_parsl.py index c063b5a5..0bdc21ec 100644 --- a/src/chemgraph/mcp/graspa_mcp_parsl.py +++ b/src/chemgraph/mcp/graspa_mcp_parsl.py @@ -1,8 +1,19 @@ import asyncio import json import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.graspa_mcp_parsl is deprecated; use 
" + "chemgraph.mcp.graspa_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP import parsl diff --git a/src/chemgraph/mcp/mace_mcp_parsl.py b/src/chemgraph/mcp/mace_mcp_parsl.py index 7f293e18..7f610b12 100644 --- a/src/chemgraph/mcp/mace_mcp_parsl.py +++ b/src/chemgraph/mcp/mace_mcp_parsl.py @@ -1,6 +1,17 @@ import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.mace_mcp_parsl is deprecated; use " + "chemgraph.mcp.mace_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP from parsl.config import Config from parsl.executors import HighThroughputExecutor diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 4abb94e0..8583ae65 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -1,25 +1,29 @@ """Backend-agnostic XANES/FDMNES MCP server. -Replaces ``xanes_mcp_parsl.py`` by using the :mod:`chemgraph.execution` -abstraction layer. The execution backend (Parsl, EnsembleLauncher, -local) is selected at startup via ``config.toml`` or the -``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. + +The ensemble expander runs server-side and prepares per-structure +FDMNES input files in ``runs_dir``; the worker (which runs on the +backend) executes FDMNES via subprocess and extracts convergence data. 
+This assumes the server and worker share a filesystem (true for any +Globus Compute endpoint on the same HPC where the MCP server runs; +Globus Transfer staging is a separate concern). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. """ import logging +import subprocess from pathlib import Path -from mcp.server.fastmcp import FastMCP - -from chemgraph.execution import TaskSpec, get_backend -from chemgraph.execution.job_tracker import JobTracker -from chemgraph.execution.utils import ( - resolve_structure_files, - submit_or_gather, - write_results_jsonl, -) -from chemgraph.mcp.job_tools import register_job_tools -from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.execution.config import get_transfer_manager +from chemgraph.execution.utils import resolve_structure_files +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.transfer_tools import register_transfer_tools from chemgraph.schemas.xanes_schema import ( mp_query_schema, xanes_input_schema, @@ -28,14 +32,9 @@ logger = logging.getLogger(__name__) -# ── Initialise execution backend ──────────────────────────────────────── -backend = get_backend() +_JOBS_FILE = Path("~/.chemgraph/xanes_jobs.json").expanduser() -_jobs_file = Path("~/.chemgraph/xanes_jobs.json").expanduser() -tracker = JobTracker(persist_file=_jobs_file) - -# ── MCP server ────────────────────────────────────────────────────────── -mcp = FastMCP( +mcp = CGFastMCP( name="ChemGraph XANES Tools", instructions=""" You expose tools for running XANES/FDMNES simulations. @@ -45,10 +44,11 @@ using the configured execution backend. 3. fetch_mp_structures: fetch optimized structures from Materials Project. 4. plot_xanes: generate normalized XANES plots for completed calculations. - 5. check_job_status: check progress of a submitted HPC job batch. - 6. get_job_results: retrieve results from a completed job batch. - 7. 
list_jobs: list all tracked job batches. - 8. cancel_job: cancel pending tasks in a job batch. + 5. check_job_status / get_job_results / list_jobs / cancel_job: HPC + job batch management. Job state persists across sessions. + 6. transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on the + remote HPC filesystem before running ensembles. Guidelines: - Use each tool only when its input schema matches the user request. @@ -64,7 +64,20 @@ to retrieve results. """, ) -register_job_tools(mcp, tracker, backend) + + +# ── Single-structure tool ────────────────────────────────────────────── + + +def _xanes_single_worker(params: xanes_input_schema) -> dict: + """Run a single FDMNES calculation on a backend worker.""" + from chemgraph.tools.xanes_tools import run_xanes_core + + result = run_xanes_core(params) + if isinstance(result, dict): + result.setdefault("status", "success") + return result + return {"status": "success", "result": result} @mcp.tool( @@ -72,18 +85,63 @@ description="Run a single XANES/FDMNES calculation for one input structure.", ) def run_xanes_single(params: xanes_input_schema): - """Run a single FDMNES calculation using the core engine.""" - from chemgraph.tools.xanes_tools import run_xanes_core + """Run a single FDMNES calculation using the core engine. + + The CGFastMCP wrapper submits this call to the configured backend; + the body is the direct-call fallback when no backend is active. + """ + return _xanes_single_worker(params) - return run_xanes_core(params) +# ── Ensemble fanout ──────────────────────────────────────────────────── -def _xanes_post_fn(meta: dict, _result) -> dict: - """Post-process a completed FDMNES task: extract convergence data.""" + +def _xanes_ensemble_worker(item: dict) -> dict: + """Execute one prepared FDMNES run on the backend. 
+ + The expander has already written ``input_fdmnes.txt`` (or the + equivalent) into ``item['run_dir']``; this worker runs the binary + via subprocess and then extracts convergence data. + """ from chemgraph.tools.xanes_tools import extract_conv + run_dir = item["run_dir"] + fdmnes_exe = item["fdmnes_exe"] + meta = { + "structure": item.get("structure"), + "run_dir": run_dir, + "z_absorber": item.get("z_absorber"), + } + + stdout_path = Path(run_dir) / "fdmnes_stdout.txt" + stderr_path = Path(run_dir) / "fdmnes_stderr.txt" try: - conv_data = extract_conv(meta["run_dir"]) + with open(stdout_path, "w") as out, open(stderr_path, "w") as err: + proc = subprocess.run( + [fdmnes_exe], + cwd=run_dir, + stdout=out, + stderr=err, + check=False, + ) + if proc.returncode != 0: + return { + **meta, + "status": "failure", + "error_type": "FDMNESExitCode", + "message": f"FDMNES exited with code {proc.returncode}", + "returncode": proc.returncode, + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"FDMNES launch failed: {e}", + } + + try: + conv_data = extract_conv(run_dir) return { **meta, "status": "success", @@ -98,24 +156,9 @@ def _xanes_post_fn(meta: dict, _result) -> dict: } -@mcp.tool( - name="run_xanes_ensemble", - description="Run an ensemble of XANES/FDMNES calculations using the configured backend.", -) -async def run_xanes_ensemble(params: xanes_input_schema_ensemble): - """Run ensemble XANES calculations over all structure files. - - For each structure file: - 1. Reads the structure via ASE. - 2. Creates FDMNES input files in a per-structure subdirectory. - 3. Submits a shell task to run FDMNES. - 4. Gathers results and writes a JSONL summary log. - - Parameters - ---------- - params : xanes_input_schema_ensemble - Input parameters for the ensemble calculation. 
- """ +def _expand_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: + """Server-side expansion: prepare per-structure run dirs and return + one item per structure for the worker to execute.""" from ase.io import read as ase_read from chemgraph.tools.xanes_tools import write_fdmnes_input @@ -125,19 +168,14 @@ async def run_xanes_ensemble(params: xanes_input_schema_ensemble): extensions={".cif", ".xyz", ".poscar"}, ) - # Create a batch runs directory runs_dir = output_dir / "fdmnes_batch_runs" runs_dir.mkdir(parents=True, exist_ok=True) - fdmnes_exe = params.fdmnes_exe - - pending_tasks = [] - + items: list[dict] = [] for i, struct_path in enumerate(structure_files): run_dir = runs_dir / f"run_{i}" run_dir.mkdir(parents=True, exist_ok=True) - # Read structure and write FDMNES inputs atoms = ase_read(str(struct_path)) z_abs = ( params.z_absorber @@ -153,48 +191,34 @@ async def run_xanes_ensemble(params: xanes_input_schema_ensemble): magnetism=params.magnetism, ) - # Submit shell task - task = TaskSpec( - task_id=f"xanes_{struct_path.stem}_{i}", - task_type="shell", - command=f'cd "{run_dir}" && "{fdmnes_exe}"', - working_dir=str(run_dir), - stdout=str(run_dir / "fdmnes_stdout.txt"), - stderr=str(run_dir / "fdmnes_stderr.txt"), + items.append( + { + "structure": struct_path.name, + "run_dir": str(run_dir), + "z_absorber": z_abs, + "fdmnes_exe": params.fdmnes_exe, + } ) - fut = backend.submit(task) - task_meta = { - "structure": struct_path.name, - "run_dir": str(run_dir), - "z_absorber": z_abs, - } - pending_tasks.append((task_meta, fut)) + return items - result = await submit_or_gather( - backend, pending_tasks, tracker, "run_xanes_ensemble", - post_fn=_xanes_post_fn, - ) - if result["status"] == "completed": - summary_log_path = output_dir / "xanes_results.jsonl" - success_count, total_count = write_results_jsonl( - result["results"], summary_log_path, - ) - return ( - f"Ensemble execution completed. 
Ran {total_count} tasks " - f"({success_count} successful). " - f"Detailed results appended to '{summary_log_path}'." - ) +@mcp.schema_fanout_tool( + name="run_xanes_ensemble", + description=( + "Run FDMNES/XANES calculations over every structure in an input " + "directory (or list of files). Each structure is prepared " + "server-side and submitted to the configured execution backend." + ), + worker=_xanes_ensemble_worker, +) +def run_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: + return _expand_xanes_ensemble(params) - # Async remote: return submission confirmation - return result + +# ── Orchestration tools (no backend involvement) ─────────────────────── -@mcp.tool( - name="fetch_mp_structures", - description="Fetch optimized structures from Materials Project.", -) def fetch_mp_structures(params: mp_query_schema): """Fetch structures from Materials Project and save as CIF files and pickle database.""" from chemgraph.tools.xanes_tools import ( @@ -214,19 +238,8 @@ def fetch_mp_structures(params: mp_query_schema): } -@mcp.tool( - name="plot_xanes", - description="Generate normalized XANES plots for completed FDMNES calculations.", -) def plot_xanes(runs_dir: str): - """Generate XANES plots for all completed runs in a directory. - - Parameters - ---------- - runs_dir : str - Path to the ``fdmnes_batch_runs`` directory containing ``run_*`` - subdirectories with FDMNES outputs. 
- """ + """Generate XANES plots for all completed runs in a directory.""" from chemgraph.tools.xanes_tools import ( _get_data_dir, plot_xanes_results, @@ -247,5 +260,32 @@ def plot_xanes(runs_dir: str): } +mcp.add_tool( + fetch_mp_structures, + name="fetch_mp_structures", + description="Fetch optimized structures from Materials Project.", +) +mcp.add_tool( + plot_xanes, + name="plot_xanes", + description="Generate normalized XANES plots for completed FDMNES calculations.", +) + + +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on XANES MCP server.") + + if __name__ == "__main__": - run_mcp_server(mcp, default_port=9007) + from chemgraph.mcp.server_utils import run_mcp_server + + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) + + try: + run_mcp_server(mcp, default_port=9007) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/mcp/xanes_mcp_parsl.py b/src/chemgraph/mcp/xanes_mcp_parsl.py index 0c26bd1b..8725effe 100644 --- a/src/chemgraph/mcp/xanes_mcp_parsl.py +++ b/src/chemgraph/mcp/xanes_mcp_parsl.py @@ -1,8 +1,19 @@ import asyncio import json import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.xanes_mcp_parsl is deprecated; use " + "chemgraph.mcp.xanes_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). 
This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP import parsl From d28274c2cfc6758873d0831abae2c7b7f1f43117 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 4 May 2026 12:01:52 -0500 Subject: [PATCH 017/119] Add pluggable execution backend for Parsl, EnsembleLauncher, and Globus Compute Introduce a unified execution module with an abstract ExecutionBackend interface and TaskSpec model, supporting four backends: local (ProcessPoolExecutor), Parsl, EnsembleLauncher, and Globus Compute. Includes config factory with resolution order (args > env > config.toml), HPC configs loader, comprehensive tests, and pytest --run-globus-compute option for live endpoint tests. --- pyproject.toml | 7 + src/chemgraph/execution/__init__.py | 33 + src/chemgraph/execution/base.py | 144 +++ src/chemgraph/execution/config.py | 163 +++ .../execution/ensemble_launcher_backend.py | 199 ++++ .../execution/globus_compute_backend.py | 131 +++ src/chemgraph/execution/local_backend.py | 119 ++ src/chemgraph/execution/parsl_backend.py | 122 ++ src/chemgraph/execution/utils.py | 175 +++ src/chemgraph/hpc_configs/__init__.py | 1 + src/chemgraph/hpc_configs/loader.py | 65 ++ src/chemgraph/hpc_configs/local_parsl.py | 60 + src/chemgraph/mcp/graspa_mcp_hpc.py | 124 ++ src/chemgraph/mcp/mace_mcp_hpc.py | 178 +++ src/chemgraph/mcp/xanes_mcp_hpc.py | 227 ++++ tests/conftest.py | 22 +- tests/test_execution.py | 1017 +++++++++++++++++ 17 files changed, 2781 insertions(+), 6 deletions(-) create mode 100644 src/chemgraph/execution/__init__.py create mode 100644 src/chemgraph/execution/base.py create mode 100644 src/chemgraph/execution/config.py create mode 100644 src/chemgraph/execution/ensemble_launcher_backend.py create mode 100644 src/chemgraph/execution/globus_compute_backend.py create mode 100644 src/chemgraph/execution/local_backend.py create mode 100644 src/chemgraph/execution/parsl_backend.py create mode 100644 
src/chemgraph/execution/utils.py create mode 100644 src/chemgraph/hpc_configs/__init__.py create mode 100644 src/chemgraph/hpc_configs/loader.py create mode 100644 src/chemgraph/hpc_configs/local_parsl.py create mode 100644 src/chemgraph/mcp/graspa_mcp_hpc.py create mode 100644 src/chemgraph/mcp/mace_mcp_hpc.py create mode 100644 src/chemgraph/mcp/xanes_mcp_hpc.py create mode 100644 tests/test_execution.py diff --git a/pyproject.toml b/pyproject.toml index d61f3918..28896e0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,12 @@ ui = [ parsl = [ "parsl", ] +ensemble_launcher = [ + "ensemble-launcher", +] +globus_compute = [ + "globus-compute-sdk", +] xanes = [ "mp-api; python_version >= '3.11'", "parsl" @@ -108,6 +114,7 @@ skip-magic-trailing-comma = false # Ensure Black-style formatting testpaths = ["tests"] markers = [ "llm: marks tests as requiring LLM API access (run with --run-llm)", + "globus_compute: marks tests requiring a live Globus Compute endpoint (run with --run-globus-compute)", "asyncio: marks async tests", ] filterwarnings = [ diff --git a/src/chemgraph/execution/__init__.py b/src/chemgraph/execution/__init__.py new file mode 100644 index 00000000..0fd6709b --- /dev/null +++ b/src/chemgraph/execution/__init__.py @@ -0,0 +1,33 @@ +"""Pluggable execution backends for ChemGraph HPC workloads. + +This package provides a backend-agnostic interface for submitting +computational tasks to different workflow managers (Parsl, +EnsembleLauncher, Globus Compute, local process pool). + +Quick start +----------- +>>> from chemgraph.execution import get_backend, TaskSpec +>>> backend = get_backend() # reads config.toml / env vars +>>> future = backend.submit(TaskSpec( +... task_id="test-1", +... task_type="python", +... callable=my_function, +... kwargs={"param": 42}, +... 
)) +>>> result = future.result() +>>> backend.shutdown() + +See Also +-------- +:mod:`chemgraph.execution.base` -- abstract classes +:mod:`chemgraph.execution.config` -- factory function +""" + +from chemgraph.execution.base import ExecutionBackend, TaskSpec +from chemgraph.execution.config import get_backend + +__all__ = [ + "ExecutionBackend", + "TaskSpec", + "get_backend", +] diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py new file mode 100644 index 00000000..e7dc338b --- /dev/null +++ b/src/chemgraph/execution/base.py @@ -0,0 +1,144 @@ +"""Abstract base classes for execution backends. + +This module defines the ``ExecutionBackend`` protocol and the ``TaskSpec`` +data model that all workflow managers (Parsl, EnsembleLauncher, local +process pool, etc.) must implement. Downstream code (MCP servers, tools) +only depends on these abstractions -- never on a concrete backend. +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from concurrent.futures import Future +from typing import Any, Callable, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field + +logger = logging.getLogger(__name__) + + +class TaskSpec(BaseModel): + """Specification for a single unit of work to submit to a backend. + + Supports two execution modes: + + * **python** -- run a Python callable (``callable(*args, **kwargs)``) + * **shell** -- run a shell command string + + Resource hints (``num_nodes``, ``processes_per_node``, ``gpus_per_task``) + are advisory; backends may ignore hints they do not support. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + task_id: str = Field( + description="Unique identifier for this task within the batch." 
+ ) + task_type: Literal["python", "shell"] = Field( + default="python", + description="Execution mode: 'python' for a callable, 'shell' for a command.", + ) + + # ── Python task fields ────────────────────────────────────────────── + callable: Optional[Callable[..., Any]] = Field( + default=None, + description="Python callable to execute (required when task_type='python').", + ) + args: tuple = Field( + default=(), + description="Positional arguments forwarded to the callable.", + ) + kwargs: dict = Field( + default_factory=dict, + description="Keyword arguments forwarded to the callable.", + ) + + # ── Shell task fields ─────────────────────────────────────────────── + command: Optional[str] = Field( + default=None, + description="Shell command to execute (required when task_type='shell').", + ) + working_dir: Optional[str] = Field( + default=None, + description="Working directory for the shell command.", + ) + stdout: Optional[str] = Field( + default=None, + description="Path to capture stdout (shell tasks).", + ) + stderr: Optional[str] = Field( + default=None, + description="Path to capture stderr (shell tasks).", + ) + + # ── Resource hints ────────────────────────────────────────────────── + num_nodes: int = Field( + default=1, + description="Number of compute nodes requested.", + ) + processes_per_node: int = Field( + default=1, + description="Number of processes (ranks) per node.", + ) + gpus_per_task: int = Field( + default=0, + description="Number of GPUs requested per task.", + ) + + +class ExecutionBackend(ABC): + """Abstract interface that every workflow-manager adapter must implement. + + Lifecycle + --------- + 1. ``initialize(system, **kwargs)`` -- start the backend + 2. ``submit(task)`` / ``submit_batch(tasks)`` -- dispatch work + 3. ``shutdown()`` -- release resources + + The class also supports the context-manager protocol (``with`` statement). 
+ """ + + def __init__(self) -> None: + self._initialized: bool = False + + @abstractmethod + def initialize(self, system: str = "local", **kwargs: Any) -> None: + """Prepare the backend for accepting work. + + Parameters + ---------- + system : str + Target HPC system name (e.g. ``"polaris"``, ``"aurora"``, + ``"local"``). Backends may use this to load system-specific + configurations. + **kwargs + Backend-specific options (worker_init, run_dir, etc.). + """ + + @abstractmethod + def submit(self, task: TaskSpec) -> Future: + """Submit a single task and return a ``concurrent.futures.Future``. + + The future resolves to whatever the callable/command returns. + """ + + def submit_batch(self, tasks: list[TaskSpec]) -> list[Future]: + """Submit multiple tasks, returning futures in submission order. + + The default implementation simply loops over ``submit()``. + Backends may override this for optimized batch submission. + """ + return [self.submit(t) for t in tasks] + + @abstractmethod + def shutdown(self) -> None: + """Release all resources held by the backend.""" + + # ── Context-manager protocol ──────────────────────────────────────── + + def __enter__(self) -> ExecutionBackend: + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # noqa: ANN001 + self.shutdown() diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py new file mode 100644 index 00000000..71d3de90 --- /dev/null +++ b/src/chemgraph/execution/config.py @@ -0,0 +1,163 @@ +"""Execution backend configuration and factory. + +Reads the ``[execution]`` section from ``config.toml`` (or env-var +overrides) and returns an initialised :class:`ExecutionBackend` instance. + +Environment variables +--------------------- +``CHEMGRAPH_EXECUTION_BACKEND`` + Override the backend name (``"parsl"``, ``"ensemble_launcher"``, + ``"globus_compute"``, ``"local"``). +``COMPUTE_SYSTEM`` + Override the target HPC system (``"polaris"``, ``"aurora"``, + ``"local"``). 
+""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Any, Optional + +from chemgraph.execution.base import ExecutionBackend + +logger = logging.getLogger(__name__) + +# Supported backend names (keep in sync with the ``elif`` chain below) +SUPPORTED_BACKENDS = ("parsl", "ensemble_launcher", "globus_compute", "local") + + +def _load_execution_config(config_path: Optional[str] = None) -> dict[str, Any]: + """Read the ``[execution]`` table from ``config.toml``. + + Returns an empty dict if the section is missing or the file is not + found, so callers always get sensible defaults. + """ + if config_path is None: + # Walk upward from CWD to find config.toml (same heuristic the + # rest of ChemGraph uses). + candidate = Path.cwd() / "config.toml" + if candidate.is_file(): + config_path = str(candidate) + else: + # Try the repo root (two levels up from this file). + repo_root = Path(__file__).resolve().parents[3] + candidate = repo_root / "config.toml" + if candidate.is_file(): + config_path = str(candidate) + + if config_path is None: + return {} + + try: + import toml + + full_config = toml.load(config_path) + return full_config.get("execution", {}) + except Exception as exc: # noqa: BLE001 + logger.warning("Could not read [execution] from %s: %s", config_path, exc) + return {} + + +def get_backend( + config_path: Optional[str] = None, + backend_name: Optional[str] = None, + system: Optional[str] = None, + **kwargs: Any, +) -> ExecutionBackend: + """Create and initialise an :class:`ExecutionBackend`. + + Resolution order for ``backend_name``: + + 1. Explicit ``backend_name`` argument + 2. ``CHEMGRAPH_EXECUTION_BACKEND`` environment variable + 3. ``config.toml`` ``[execution] backend`` key + 4. ``"local"`` (safe fallback) + + Resolution order for ``system``: + + 1. Explicit ``system`` argument + 2. ``COMPUTE_SYSTEM`` environment variable + 3. ``config.toml`` ``[execution] system`` key + 4. 
``"local"`` + + Parameters + ---------- + config_path : str, optional + Path to ``config.toml``. Auto-detected when omitted. + backend_name : str, optional + Force a specific backend. + system : str, optional + Target HPC system name. + **kwargs + Extra keyword arguments forwarded to + :meth:`ExecutionBackend.initialize`. + + Returns + ------- + ExecutionBackend + A ready-to-use backend instance. + """ + cfg = _load_execution_config(config_path) + + # -- resolve backend name ------------------------------------------------- + resolved_backend = ( + backend_name + or os.getenv("CHEMGRAPH_EXECUTION_BACKEND") + or cfg.get("backend", "local") + ) + resolved_backend = resolved_backend.lower().strip() + + if resolved_backend not in SUPPORTED_BACKENDS: + raise ValueError( + f"Unknown execution backend '{resolved_backend}'. " + f"Supported: {', '.join(SUPPORTED_BACKENDS)}" + ) + + # -- resolve system ------------------------------------------------------- + resolved_system = ( + system or os.getenv("COMPUTE_SYSTEM") or cfg.get("system", "local") + ) + + # -- merge backend-specific config ---------------------------------------- + backend_cfg = cfg.get(resolved_backend, {}) + merged_kwargs = {**backend_cfg, **kwargs} + + # -- instantiate ---------------------------------------------------------- + logger.info( + "Creating execution backend '%s' for system '%s'", + resolved_backend, + resolved_system, + ) + + if resolved_backend == "parsl": + from chemgraph.execution.parsl_backend import ParslBackend + + backend = ParslBackend() + + elif resolved_backend == "ensemble_launcher": + from chemgraph.execution.ensemble_launcher_backend import ( + EnsembleLauncherBackend, + ) + + backend = EnsembleLauncherBackend() + + elif resolved_backend == "globus_compute": + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + + elif resolved_backend == "local": + from chemgraph.execution.local_backend import LocalBackend + 
+ backend = LocalBackend() + + else: + # Should be unreachable thanks to the validation above. + raise ValueError(f"Unsupported backend: {resolved_backend}") + + backend.initialize(system=resolved_system, **merged_kwargs) + return backend diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py new file mode 100644 index 00000000..23462f5b --- /dev/null +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -0,0 +1,199 @@ +"""EnsembleLauncher execution backend. + +Wraps `EnsembleLauncher `_ +to conform to the :class:`ExecutionBackend` interface. Uses the +cluster-mode API (``EnsembleLauncher.start()`` + ``ClusterClient``) so +that tasks can be submitted dynamically. + +EnsembleLauncher must be installed separately +(``pip install chemgraphagent[ensemble_launcher]``). +""" + +from __future__ import annotations + +import logging +import os +import socket +import time +import uuid +from concurrent.futures import Future +from typing import Any + +from chemgraph.execution.base import ExecutionBackend, TaskSpec + +logger = logging.getLogger(__name__) + + +class EnsembleLauncherBackend(ExecutionBackend): + """Execution backend that delegates work to EnsembleLauncher. + + The backend starts an EnsembleLauncher orchestrator in cluster mode + and submits tasks through a :class:`ClusterClient`. + + Configuration + ------------- + The following ``kwargs`` are accepted by :meth:`initialize`: + + ``comm_name`` : str + Communication backend (``"zmq"``, ``"async_zmq"``, ``"multiprocessing"``). + Default: ``"async_zmq"``. + ``task_executor_name`` : str + Task executor (``"multiprocessing"``, ``"mpi"``, + ``"async_processpool"``). Default: ``"async_processpool"``. + ``nlevels`` : int + Hierarchy depth. Default: ``0`` (single-node). + ``max_workers`` : int + Number of CPUs to expose. Default: ``os.cpu_count()``. + ``checkpoint_dir`` : str + Directory for orchestrator checkpoint files. 
Auto-generated + when omitted. + ``nodes`` : list[str] + List of compute node hostnames. Defaults to ``[hostname]``. + ``startup_delay`` : float + Seconds to wait after ``el.start()`` for the orchestrator to be + ready. Default: ``2.0``. + """ + + def __init__(self) -> None: + super().__init__() + self._el = None + self._client = None + self._checkpoint_dir: str | None = None + + def initialize(self, system: str = "local", **kwargs: Any) -> None: + try: + from ensemble_launcher import EnsembleLauncher + from ensemble_launcher.config import LauncherConfig, SystemConfig + from ensemble_launcher.orchestrator import ClusterClient + except ImportError as exc: + raise ImportError( + "EnsembleLauncher is required for the EnsembleLauncherBackend. " + "Install it with: pip install ensemble-launcher" + ) from exc + + # -- extract parameters ------------------------------------------------ + comm_name = kwargs.get("comm_name", "async_zmq") + task_executor = kwargs.get("task_executor_name", "async_processpool") + nlevels = kwargs.get("nlevels", 0) + ncpus = kwargs.get("max_workers", os.cpu_count() or 4) + checkpoint_dir = kwargs.get( + "checkpoint_dir", + os.path.join(os.getcwd(), f".el_ckpt_{uuid.uuid4().hex[:8]}"), + ) + nodes = kwargs.get("nodes", [socket.gethostname()]) + startup_delay = kwargs.get("startup_delay", 2.0) + + self._checkpoint_dir = checkpoint_dir + + # -- configure --------------------------------------------------------- + system_config = SystemConfig( + name=system, + ncpus=ncpus, + cpus=list(range(ncpus)), + ) + + launcher_config = LauncherConfig( + task_executor_name=task_executor, + comm_name=comm_name, + nlevels=nlevels, + cluster=True, + checkpoint_dir=checkpoint_dir, + ) + + # -- start orchestrator ------------------------------------------------ + self._el = EnsembleLauncher( + ensemble_file={}, + system_config=system_config, + launcher_config=launcher_config, + Nodes=nodes, + ) + self._el.start() + time.sleep(startup_delay) + + # -- connect client 
---------------------------------------------------- + self._client = ClusterClient(checkpoint_dir=checkpoint_dir) + self._client.start() + + self._initialized = True + logger.info( + "EnsembleLauncherBackend initialized (system='%s', " + "comm='%s', executor='%s', nodes=%s)", + system, + comm_name, + task_executor, + nodes, + ) + + def submit(self, task: TaskSpec) -> Future: + if not self._initialized or self._client is None: + raise RuntimeError( + "EnsembleLauncherBackend is not initialized. " + "Call initialize() first." + ) + + from ensemble_launcher.ensemble import Task as ELTask + + if task.task_type == "python": + if task.callable is None: + raise ValueError( + f"Task '{task.task_id}': task_type='python' requires a callable." + ) + el_task = ELTask( + task_id=task.task_id, + nnodes=task.num_nodes, + ppn=task.processes_per_node, + executable=task.callable, + args=task.args or (), + kwargs=task.kwargs or {}, + ) + return self._client.submit(el_task) + + elif task.task_type == "shell": + if task.command is None: + raise ValueError( + f"Task '{task.task_id}': task_type='shell' requires a command." + ) + el_task = ELTask( + task_id=task.task_id, + nnodes=task.num_nodes, + ppn=task.processes_per_node, + cmd_template=task.command, + ) + return self._client.submit(el_task) + + else: + raise ValueError( + f"Task '{task.task_id}': unsupported task_type '{task.task_type}'." 
+ ) + + def shutdown(self) -> None: + self._initialized = False + client_ok = True + if self._client is not None: + try: + self._client.teardown() + self._client = None + except Exception: + client_ok = False + logger.warning( + "Error tearing down EnsembleLauncher client.", exc_info=True + ) + + el_ok = True + if self._el is not None: + try: + self._el.stop() + self._el = None + except Exception: + el_ok = False + logger.warning( + "Error stopping EnsembleLauncher orchestrator.", exc_info=True + ) + + if client_ok and el_ok: + logger.info("EnsembleLauncherBackend shut down.") + else: + logger.warning( + "EnsembleLauncherBackend partially shut down. " + "Call shutdown() again to retry failed teardown." + ) diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py new file mode 100644 index 00000000..0c2a9634 --- /dev/null +++ b/src/chemgraph/execution/globus_compute_backend.py @@ -0,0 +1,131 @@ +"""Globus Compute execution backend. + +Wraps the `Globus Compute SDK `_ +to conform to the :class:`ExecutionBackend` interface. Python tasks are +dispatched via :meth:`Executor.submit` and shell tasks via +:class:`ShellFunction`. + +Unlike the Parsl and EnsembleLauncher backends, Globus Compute does **not** +require an active PBS/Slurm allocation at submit time. A persistent +Globus Compute *endpoint* daemon running on the HPC login node +automatically provisions and manages batch jobs as tasks arrive. + +**Prerequisites** + +1. Install the SDK: ``pip install chemgraphagent[globus_compute]`` +2. On the HPC system, configure and start an endpoint:: + + globus-compute-endpoint configure chemgraph-polaris + globus-compute-endpoint start chemgraph-polaris + # -> prints the endpoint UUID + +3. Set ``endpoint_id`` in ``config.toml`` or pass it to + :func:`~chemgraph.execution.config.get_backend`. 
+""" + +from __future__ import annotations + +import logging +from concurrent.futures import Future +from typing import Any + +from chemgraph.execution.base import ExecutionBackend, TaskSpec + +logger = logging.getLogger(__name__) + + +class GlobusComputeBackend(ExecutionBackend): + """Execution backend that delegates work to Globus Compute. + + Configuration + ------------- + The following ``kwargs`` are accepted by :meth:`initialize`: + + ``endpoint_id`` : str **required** + UUID of the Globus Compute endpoint to submit tasks to. + ``amqp_port`` : int, optional + Port for the AMQP result-streaming connection. Defaults to the + SDK default (5671). Set to ``443`` if outbound 5671 is blocked. + """ + + def __init__(self) -> None: + super().__init__() + self._executor = None + + # ── lifecycle ──────────────────────────────────────────────────────── + + def initialize(self, system: str = "local", **kwargs: Any) -> None: + try: + from globus_compute_sdk import Executor + except ImportError as exc: + raise ImportError( + "globus-compute-sdk is required for the GlobusComputeBackend. " + "Install it with: pip install chemgraphagent[globus_compute]" + ) from exc + + endpoint_id = kwargs.get("endpoint_id") + if not endpoint_id: + raise ValueError( + "GlobusComputeBackend requires an 'endpoint_id'. " + "Set it in config.toml under [execution.globus_compute] " + "or pass it directly to get_backend()." 
+ ) + + executor_kwargs: dict[str, Any] = {"endpoint_id": endpoint_id} + + amqp_port = kwargs.get("amqp_port") + if amqp_port is not None: + executor_kwargs["amqp_port"] = int(amqp_port) + + self._executor = Executor(**executor_kwargs) + self._initialized = True + logger.info( + "GlobusComputeBackend initialized (system='%s', endpoint='%s')", + system, + endpoint_id, + ) + + # ── task submission ───────────────────────────────────────────────── + + def submit(self, task: TaskSpec) -> Future: + if not self._initialized or self._executor is None: + raise RuntimeError( + "GlobusComputeBackend is not initialized. Call initialize() first." + ) + + if task.task_type == "python": + if task.callable is None: + raise ValueError( + f"Task '{task.task_id}': task_type='python' requires a callable." + ) + # Executor.submit() returns a ComputeFuture (a + # concurrent.futures.Future subclass), fully compatible + # with asyncio.wrap_future() used by gather_futures(). + return self._executor.submit(task.callable, *task.args, **task.kwargs) + + elif task.task_type == "shell": + if task.command is None: + raise ValueError( + f"Task '{task.task_id}': task_type='shell' requires a command." + ) + from globus_compute_sdk import ShellFunction + + shell_fn = ShellFunction(task.command) + return self._executor.submit(shell_fn) + + else: + raise ValueError( + f"Task '{task.task_id}': unsupported task_type '{task.task_type}'." 
+ ) + + # ── teardown ──────────────────────────────────────────────────────── + + def shutdown(self) -> None: + if self._executor is not None: + try: + self._executor.shutdown() + logger.info("GlobusComputeBackend shut down.") + except Exception: + logger.warning("Error during Globus Compute shutdown.", exc_info=True) + self._executor = None + self._initialized = False diff --git a/src/chemgraph/execution/local_backend.py b/src/chemgraph/execution/local_backend.py new file mode 100644 index 00000000..c6a66abe --- /dev/null +++ b/src/chemgraph/execution/local_backend.py @@ -0,0 +1,119 @@ +"""Local execution backend using ``concurrent.futures.ProcessPoolExecutor``. + +Ideal for development, testing, and single-node runs where no HPC +workflow manager is needed. Requires zero external dependencies beyond +the Python standard library. +""" + +from __future__ import annotations + +import logging +import subprocess +from concurrent.futures import Future, ProcessPoolExecutor +from typing import Any + +from chemgraph.execution.base import ExecutionBackend, TaskSpec + +logger = logging.getLogger(__name__) + +# Default number of worker processes (can be overridden via config). +_DEFAULT_MAX_WORKERS = 4 + + +def _run_shell_task( + command: str, + working_dir: str | None, + stdout_path: str | None, + stderr_path: str | None, +) -> int: + """Execute a shell command in a child process. + + Returns the process exit code. stdout/stderr are captured to + files when paths are provided. 
+ """ + import contextlib + + with ( + open(stdout_path, "w") if stdout_path else contextlib.nullcontext() as stdout_fh, + open(stderr_path, "w") if stderr_path else contextlib.nullcontext() as stderr_fh, + ): + result = subprocess.run( + command, + shell=True, + cwd=working_dir, + stdout=stdout_fh, + stderr=stderr_fh, + check=True, + ) + return result.returncode + + +def _run_python_task( + fn: Any, # Callable -- typed as Any for pickling + args: tuple, + kwargs: dict, +) -> Any: + """Execute a Python callable in a child process.""" + return fn(*args, **kwargs) + + +class LocalBackend(ExecutionBackend): + """Execution backend backed by :class:`ProcessPoolExecutor`. + + Configuration + ------------- + ``max_workers`` : int + Maximum number of concurrent worker processes (default: 4). + """ + + def __init__(self) -> None: + super().__init__() + self._pool: ProcessPoolExecutor | None = None + + def initialize(self, system: str = "local", **kwargs: Any) -> None: + max_workers = kwargs.get("max_workers", _DEFAULT_MAX_WORKERS) + self._pool = ProcessPoolExecutor(max_workers=max_workers) + self._initialized = True + logger.info( + "LocalBackend initialized with %d workers", max_workers + ) + + def submit(self, task: TaskSpec) -> Future: + if not self._initialized or self._pool is None: + raise RuntimeError( + "LocalBackend is not initialized. Call initialize() first." + ) + + if task.task_type == "python": + if task.callable is None: + raise ValueError( + f"Task '{task.task_id}': task_type='python' requires a callable." + ) + return self._pool.submit( + _run_python_task, task.callable, task.args, task.kwargs + ) + + elif task.task_type == "shell": + if task.command is None: + raise ValueError( + f"Task '{task.task_id}': task_type='shell' requires a command." 
+ ) + return self._pool.submit( + _run_shell_task, + task.command, + task.working_dir, + task.stdout, + task.stderr, + ) + + else: + raise ValueError( + f"Task '{task.task_id}': unsupported task_type '{task.task_type}'." + ) + + def shutdown(self) -> None: + if self._pool is not None: + logger.info("Shutting down LocalBackend process pool.") + self._pool.shutdown(wait=True) + self._pool = None + self._initialized = False diff --git a/src/chemgraph/execution/parsl_backend.py b/src/chemgraph/execution/parsl_backend.py new file mode 100644 index 00000000..f2e4fe37 --- /dev/null +++ b/src/chemgraph/execution/parsl_backend.py @@ -0,0 +1,122 @@ +"""Parsl execution backend. + +Wraps `Parsl `_ to conform to the +:class:`ExecutionBackend` interface. Python tasks are dispatched via +``@python_app`` and shell tasks via ``@bash_app``. + +Parsl must be installed separately (``pip install chemgraphagent[parsl]``). +""" + +from __future__ import annotations + +import logging +from concurrent.futures import Future +from typing import Any + +from chemgraph.execution.base import ExecutionBackend, TaskSpec + +logger = logging.getLogger(__name__) + + +class ParslBackend(ExecutionBackend): + """Execution backend that delegates work to Parsl. + + Configuration + ------------- + The ``system`` argument passed to :meth:`initialize` is forwarded to + :func:`chemgraph.hpc_configs.loader.load_parsl_config` which returns + the appropriate ``parsl.config.Config``. + + Extra ``kwargs`` are forwarded to the config loader (e.g. + ``worker_init``). + """ + + def __init__(self) -> None: + super().__init__() + self._python_app = None + self._bash_app = None + + def initialize(self, system: str = "polaris", **kwargs: Any) -> None: + try: + import parsl + from parsl import bash_app, python_app + except ImportError as exc: + raise ImportError( + "Parsl is required for the ParslBackend. 
" + "Install it with: pip install chemgraphagent[parsl]" + ) from exc + + from chemgraph.hpc_configs.loader import load_parsl_config + + run_dir = kwargs.pop("run_dir", None) + worker_init = kwargs.pop("worker_init", None) + + # Build kwargs for the config loader + loader_kwargs: dict[str, Any] = {} + if run_dir is not None: + loader_kwargs["run_dir"] = run_dir + if worker_init is not None: + loader_kwargs["worker_init"] = worker_init + + config = load_parsl_config(system, **loader_kwargs) + parsl.load(config) + + # Create generic app wrappers ------------------------------------------ + # These are created once and reused for all submitted tasks. + + @python_app + def _generic_python_app(fn, args, kwargs): + """Execute an arbitrary callable on a Parsl worker.""" + return fn(*args, **kwargs) + + @bash_app + def _generic_bash_app(command, stdout=None, stderr=None): + """Execute a shell command string on a Parsl worker.""" + return command + + self._python_app = _generic_python_app + self._bash_app = _generic_bash_app + + self._initialized = True + logger.info("ParslBackend initialized for system '%s'", system) + + def submit(self, task: TaskSpec) -> Future: + if not self._initialized: + raise RuntimeError( + "ParslBackend is not initialized. Call initialize() first." + ) + + if task.task_type == "python": + if task.callable is None: + raise ValueError( + f"Task '{task.task_id}': task_type='python' requires a callable." + ) + return self._python_app(task.callable, task.args, task.kwargs) + + elif task.task_type == "shell": + if task.command is None: + raise ValueError( + f"Task '{task.task_id}': task_type='shell' requires a command." + ) + bash_kwargs: dict[str, Any] = {"command": task.command} + if task.stdout: + bash_kwargs["stdout"] = task.stdout + if task.stderr: + bash_kwargs["stderr"] = task.stderr + return self._bash_app(**bash_kwargs) + + else: + raise ValueError( + f"Task '{task.task_id}': unsupported task_type '{task.task_type}'." 
+ ) + + def shutdown(self) -> None: + if self._initialized: + try: + import parsl + + parsl.clear() + logger.info("ParslBackend shut down.") + except Exception: + logger.warning("Error during Parsl shutdown.", exc_info=True) + self._initialized = False diff --git a/src/chemgraph/execution/utils.py b/src/chemgraph/execution/utils.py new file mode 100644 index 00000000..70759a71 --- /dev/null +++ b/src/chemgraph/execution/utils.py @@ -0,0 +1,175 @@ +"""Shared utilities for ensemble execution in MCP servers. + +Consolidates patterns that were previously duplicated across +``graspa_mcp_parsl.py``, ``xanes_mcp_parsl.py``, and +``mace_mcp_parsl.py``: + +* Structure file resolution from directory or file list +* Async future gathering with error handling +* JSONL result writing +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from concurrent.futures import Future +from pathlib import Path +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +def resolve_structure_files( + input_source: str | list[str], + extensions: set[str] | None = None, +) -> tuple[list[Path], Path]: + """Resolve a directory path or file list into a list of structure files. + + Parameters + ---------- + input_source : str or list[str] + Either a directory path (all matching files will be collected) + or an explicit list of file paths. + extensions : set[str], optional + File extensions to include when scanning a directory (e.g. + ``{".cif", ".xyz"}``). If *None*, all files are included. + + Returns + ------- + structure_files : list[Path] + Sorted list of resolved file paths. + output_dir : Path + The parent directory (useful for placing output files). + + Raises + ------ + ValueError + If no files are found or if listed files do not exist. 
+ """ + structure_files: list[Path] = [] + output_dir: Path = Path.cwd() + + if isinstance(input_source, list): + structure_files = [Path(p) for p in input_source] + missing = [p for p in structure_files if not p.exists()] + if missing: + raise ValueError(f"The following input files are missing: {missing}") + if structure_files: + output_dir = structure_files[0].parent + else: + input_dir = Path(input_source) + if not input_dir.is_dir(): + raise ValueError(f"'{input_dir}' is not a valid directory.") + + if extensions: + structure_files = sorted( + p for p in input_dir.iterdir() if p.suffix in extensions + ) + else: + structure_files = sorted(p for p in input_dir.iterdir() if p.is_file()) + + output_dir = input_dir + + if not structure_files: + raise ValueError("No structure files found to simulate.") + + return structure_files, output_dir + + +async def gather_futures( + pending: list[tuple[dict, Future]], + post_fn: Optional[Callable[[dict, Any], dict]] = None, +) -> list[dict]: + """Await a list of ``(metadata, future)`` pairs concurrently. + + Each future is converted to an asyncio-awaitable via + :func:`asyncio.wrap_future` and gathered concurrently. + + Parameters + ---------- + pending : list[tuple[dict, Future]] + Each element is ``(task_metadata_dict, concurrent_futures_Future)``. + post_fn : callable, optional + If provided, called as ``post_fn(metadata, result)`` after a + successful future resolution. Must return a ``dict`` to include + in the results list. When *None*, the raw result is merged with + metadata. + + Returns + ------- + list[dict] + One result dict per task (successful or failed). 
+ """ + + async def _wait(meta: dict, fut: Future) -> dict: + try: + result = await asyncio.wrap_future(fut) + if post_fn is not None: + return post_fn(meta, result) + # Default: merge metadata with result (if result is a dict) + if isinstance(result, dict): + merged = {**meta, **result} + merged.setdefault("status", "success") + return merged + return {**meta, "result": result, "status": "success"} + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": str(e), + } + + return list( + await asyncio.gather(*(_wait(meta, fut) for meta, fut in pending)) + ) + + +def write_results_jsonl( + results: list[dict], + output_path: Path, + append: bool = True, +) -> tuple[int, int]: + """Write results to a JSONL file and return (success_count, total_count). + + Parameters + ---------- + results : list[dict] + Each dict should contain a ``"status"`` key. + output_path : Path + Path to the JSONL file. + append : bool + If *True* (default), append to an existing file. + + Returns + ------- + success_count : int + total_count : int + """ + mode = "a" if append else "w" + success_count = 0 + + with open(output_path, mode, encoding="utf-8") as f: + for res in results: + if res.get("status") == "success": + success_count += 1 + f.write(json.dumps(res) + "\n") + + return success_count, len(results) + + +def make_per_structure_output( + struct_path: Path, + base_output: Path, +) -> Path: + """Generate a per-structure output filename. + + Given ``struct_path = "/data/MOF-5.cif"`` and + ``base_output = "/results/output.json"``, returns + ``"/results/MOF-5_output.json"``. 
+ """ + base_suffix = base_output.suffix or ".json" + base_stem = base_output.stem + return base_output.with_name(f"{struct_path.stem}_{base_stem}{base_suffix}") diff --git a/src/chemgraph/hpc_configs/__init__.py b/src/chemgraph/hpc_configs/__init__.py new file mode 100644 index 00000000..32d8bc92 --- /dev/null +++ b/src/chemgraph/hpc_configs/__init__.py @@ -0,0 +1 @@ +"""HPC configuration factories for workflow managers.""" diff --git a/src/chemgraph/hpc_configs/loader.py b/src/chemgraph/hpc_configs/loader.py new file mode 100644 index 00000000..4a25d5e7 --- /dev/null +++ b/src/chemgraph/hpc_configs/loader.py @@ -0,0 +1,65 @@ +"""Unified loader for HPC-specific Parsl configurations. + +This consolidates the ``load_parsl_config()`` function that was +previously duplicated across ``graspa_mcp_parsl.py`` and +``xanes_mcp_parsl.py``. +""" + +from __future__ import annotations + +import logging +import os + +logger = logging.getLogger(__name__) + + +def load_parsl_config(system_name: str, run_dir: str | None = None, **kwargs): + """Dynamically import and return a Parsl ``Config`` for the given HPC system. + + Parameters + ---------- + system_name : str + Target system name. Supported: ``"local"``, ``"polaris"``, + ``"aurora"``. + run_dir : str, optional + Parsl run directory. Defaults to the current working directory. + **kwargs + Extra keyword arguments forwarded to the system-specific + config factory (e.g. ``worker_init``, ``max_workers``). + + Returns + ------- + parsl.config.Config + A ready-to-use Parsl configuration object. + + Raises + ------ + ValueError + If *system_name* is not recognised. 
+ """ + system_name = system_name.lower().strip() + if run_dir is None: + run_dir = os.getcwd() + + logger.info("Loading Parsl config for system: %s", system_name) + + if system_name == "local": + from chemgraph.hpc_configs.local_parsl import get_local_config + + return get_local_config(run_dir=run_dir, **kwargs) + + elif system_name == "polaris": + from chemgraph.hpc_configs.polaris_parsl import get_polaris_config + + return get_polaris_config(run_dir=run_dir, **kwargs) + + elif system_name == "aurora": + from chemgraph.hpc_configs.aurora_parsl import get_aurora_config + + return get_aurora_config(run_dir=run_dir, **kwargs) + + else: + raise ValueError( + f"Unknown HPC system: '{system_name}'. " + f"Supported systems: local, polaris, aurora" + ) diff --git a/src/chemgraph/hpc_configs/local_parsl.py b/src/chemgraph/hpc_configs/local_parsl.py new file mode 100644 index 00000000..b4c05f01 --- /dev/null +++ b/src/chemgraph/hpc_configs/local_parsl.py @@ -0,0 +1,60 @@ +"""Local Parsl configuration for development and single-node runs. + +Uses ``HighThroughputExecutor`` with a ``LocalProvider`` (no MPI +launcher, no PBS/Slurm dependency). Suitable for laptops, CI runners, +and single-node workstations where the Parsl backend is desired but no +HPC scheduler is available. +""" + +from __future__ import annotations + +import logging +import os + +from parsl.config import Config +from parsl.executors import HighThroughputExecutor +from parsl.providers import LocalProvider + +logger = logging.getLogger(__name__) + +_DEFAULT_MAX_WORKERS = 4 + + +def get_local_config( + run_dir: str | None = None, + max_workers: int = _DEFAULT_MAX_WORKERS, + worker_init: str = "export TMPDIR=/tmp", +) -> Config: + """Generate a Parsl configuration for local execution. + + Parameters + ---------- + run_dir : str, optional + Parsl run directory. Defaults to the current working directory. + max_workers : int, optional + Maximum number of concurrent workers. Default: 4. 
+ worker_init : str, optional + Shell commands executed on each worker before tasks. + """ + if run_dir is None: + run_dir = os.getcwd() + + logger.info("Creating local Parsl config with %d workers", max_workers) + + config = Config( + executors=[ + HighThroughputExecutor( + label="local_htex", + max_workers_per_node=max_workers, + provider=LocalProvider( + init_blocks=1, + min_blocks=0, + max_blocks=1, + worker_init=worker_init, + ), + ), + ], + run_dir=run_dir, + ) + + return config diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py new file mode 100644 index 00000000..9ee276bc --- /dev/null +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -0,0 +1,124 @@ +"""Backend-agnostic gRASPA MCP server. + +Replaces ``graspa_mcp_parsl.py`` by using the :mod:`chemgraph.execution` +abstraction layer. The execution backend (Parsl, EnsembleLauncher, +local) is selected at startup via ``config.toml`` or the +``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +""" + +import logging +from pathlib import Path + +from mcp.server.fastmcp import FastMCP + +from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.utils import ( + gather_futures, + make_per_structure_output, + resolve_structure_files, + write_results_jsonl, +) +from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.schemas.graspa_schema import graspa_input_schema_ensemble + +logger = logging.getLogger(__name__) + +# ── Initialise execution backend ──────────────────────────────────────── +backend = get_backend() + +# ── MCP server ────────────────────────────────────────────────────────── +mcp = FastMCP( + name="ChemGraph Graspa Tools", + instructions=""" + You expose tools for running graspa simulations and reading their results. + The available tools are: + 1. run_graspa_ensemble: run graspa calculations over all structures in a + directory using the configured execution backend. 
+ + Guidelines: + - Use each tool only when its input schema matches the user request. + - Do not guess numerical values; report tool errors exactly as they occur. + - Keep responses compact -- full results are written to the output files + defined in the schemas. + - When returning paths, use absolute paths. + - Energies are in eV and wall times are in seconds. + """, +) + + +def _run_graspa_single(job: dict) -> dict: + """Execute a single gRASPA simulation (runs on the worker).""" + from chemgraph.schemas.graspa_schema import graspa_input_schema + from chemgraph.tools.graspa_tools import run_graspa_core + + params = graspa_input_schema(**job) if isinstance(job, dict) else job + return run_graspa_core(params) + + +@mcp.tool( + name="run_graspa_ensemble", + description="Run an ensemble of graspa calculations for multiple input files.", +) +async def run_graspa_ensemble( + params: graspa_input_schema_ensemble, +): + """Run an ensemble of gRASPA calculations over all structure files + using the configured execution backend. + + Parameters + ---------- + params : graspa_input_schema_ensemble + Input parameters for the ensemble of gRASPA calculations. 
+ """ + structure_files, output_dir = resolve_structure_files( + params.input_structures, + extensions={".cif"}, + ) + + # Base output file name + base_output = Path(params.output_result_file).resolve() + + pending_tasks = [] + + for struct_path in structure_files: + mof_name = struct_path.stem + for condition in params.conditions: + per_struct_output = make_per_structure_output(struct_path, base_output) + job = { + "input_structure_file": str(struct_path), + "output_result_file": str(per_struct_output), + "temperature": condition.temperature, + "pressure": condition.pressure, + "adsorbate": params.adsorbate, + "n_cycles": params.n_cycles, + } + + task = TaskSpec( + task_id=f"graspa_{mof_name}_{condition.temperature}K_{condition.pressure}Pa", + task_type="python", + callable=_run_graspa_single, + kwargs={"job": job}, + ) + fut = backend.submit(task) + + task_meta = { + "structure": mof_name, + "temperature": condition.temperature, + "pressure": condition.pressure, + } + pending_tasks.append((task_meta, fut)) + + results = await gather_futures(pending_tasks) + + summary_log_path = output_dir / "simulation_results.jsonl" + success_count, total_count = write_results_jsonl(results, summary_log_path) + + return ( + f"Ensemble execution completed. Ran {total_count} tasks " + f"({success_count} successful). " + f"Detailed results appended to '{summary_log_path}'." + ) + + +if __name__ == "__main__": + run_mcp_server(mcp, default_port=9001) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py new file mode 100644 index 00000000..eba86858 --- /dev/null +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -0,0 +1,178 @@ +"""Backend-agnostic MACE MCP server. + +Replaces ``mace_mcp_parsl.py`` by using the :mod:`chemgraph.execution` +abstraction layer. The execution backend (Parsl, EnsembleLauncher, +local) is selected at startup via ``config.toml`` or the +``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. 
+ +Key improvements over the original: +- No hardcoded Polaris config or user-specific conda paths. +- Ensemble tool is now async (non-blocking event loop). +- Uses shared utilities for structure resolution and result gathering. +""" + +import json +import logging +from pathlib import Path + +from mcp.server.fastmcp import FastMCP + +from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.utils import ( + gather_futures, + make_per_structure_output, + resolve_structure_files, +) +from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.tools.parsl_tools import ( + mace_input_schema, + mace_input_schema_ensemble, + run_mace_core, +) + +logger = logging.getLogger(__name__) + +# ── Initialise execution backend ──────────────────────────────────────── +backend = get_backend() + +# ── MCP server ────────────────────────────────────────────────────────── +mcp = FastMCP( + name="ChemGraph MACE Tools", + instructions=""" + You expose tools for running MACE simulations and reading their results. + The available tools are: + 1. run_mace_single: run a single MACE calculation using the specified + input schema. + 2. run_mace_ensemble: run MACE calculations over all structures in a + directory using the configured execution backend. + 3. extract_output_json: load simulation results from a JSON file. + + Guidelines: + - Use each tool only when its input schema matches the user request. + - Do not guess numerical values; report tool errors exactly as they occur. + - Keep responses compact -- full results are written to the output files + defined in the schemas. + - When returning paths, use absolute paths. + - Energies are in eV and wall times are in seconds. 
+ """, +) + + +def _run_mace_single(job: dict) -> dict: + """Execute a single MACE simulation (runs on the worker).""" + from chemgraph.tools.parsl_tools import mace_input_schema, run_mace_core + + params = mace_input_schema(**job) if isinstance(job, dict) else job + return run_mace_core(params) + + +@mcp.tool( + name="run_mace_single", + description="Run a single MACE calculation", +) +def run_mace_single(params: mace_input_schema): + return run_mace_core(params) + + +def _mace_post_fn(meta: dict, result) -> dict: + """Post-process a completed MACE task.""" + status = result.get("status", "unknown") if isinstance(result, dict) else "success" + energy = result.get("single_point_energy") if isinstance(result, dict) else None + return { + "structure": meta["structure"], + "output_result_file": meta["output_result_file"], + "status": status, + "single_point_energy": energy, + "raw_result": result, + } + + +@mcp.tool( + name="run_mace_ensemble", + description="Run an ensemble of MACE calculations", +) +async def run_mace_ensemble(params: mace_input_schema_ensemble): + """Run an ensemble of MACE calculations over all structure files in a + directory using the configured execution backend. + + Parameters + ---------- + params : mace_input_schema_ensemble + Input parameters for the ensemble of MACE calculations. + + Returns + ------- + dict + Summary of all jobs with minimal per-job results. 
+ """ + structure_files, _output_dir = resolve_structure_files( + params.input_structure_directory, + ) + + # Base output file name used as a pattern for per-structure outputs + base_output = Path(params.output_result_file) + + pending_tasks = [] + for struct_path in structure_files: + per_struct_output = make_per_structure_output(struct_path, base_output) + + job = { + "input_structure_file": str(struct_path), + "output_result_file": str(per_struct_output), + "driver": params.driver, + "model": params.model, + "device": params.device, + "temperature": params.temperature, + "pressure": params.pressure, + "fmax": params.fmax, + "steps": params.steps, + "optimizer": params.optimizer, + } + + task = TaskSpec( + task_id=f"mace_{struct_path.stem}", + task_type="python", + callable=_run_mace_single, + kwargs={"job": job}, + ) + fut = backend.submit(task) + + task_meta = { + "structure": struct_path.name, + "output_result_file": str(per_struct_output), + } + pending_tasks.append((task_meta, fut)) + + results = await gather_futures(pending_tasks, post_fn=_mace_post_fn) + + return { + "status": "success", + "n_structures": len(structure_files), + "results": results, + } + + +@mcp.tool( + name="extract_output_json", + description="Load output from a JSON file.", +) +def extract_output_json(json_file: str) -> dict: + """Load simulation results from a JSON file produced by run_ase. + + Parameters + ---------- + json_file : str + Path to the JSON file containing ASE simulation results. + + Returns + ------- + dict + Parsed results from the JSON file. + """ + with open(json_file, "r") as f: + data = json.load(f) + return data + + +if __name__ == "__main__": + run_mcp_server(mcp, default_port=9004) diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py new file mode 100644 index 00000000..3ed81fa7 --- /dev/null +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -0,0 +1,227 @@ +"""Backend-agnostic XANES/FDMNES MCP server. 
+ +Replaces ``xanes_mcp_parsl.py`` by using the :mod:`chemgraph.execution` +abstraction layer. The execution backend (Parsl, EnsembleLauncher, +local) is selected at startup via ``config.toml`` or the +``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +""" + +import logging +from pathlib import Path + +from mcp.server.fastmcp import FastMCP + +from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.utils import ( + gather_futures, + resolve_structure_files, + write_results_jsonl, +) +from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.schemas.xanes_schema import ( + mp_query_schema, + xanes_input_schema, + xanes_input_schema_ensemble, +) + +logger = logging.getLogger(__name__) + +# ── Initialise execution backend ──────────────────────────────────────── +backend = get_backend() + +# ── MCP server ────────────────────────────────────────────────────────── +mcp = FastMCP( + name="ChemGraph XANES Tools", + instructions=""" + You expose tools for running XANES/FDMNES simulations. + The available tools are: + 1. run_xanes_single: run a single FDMNES calculation for one structure. + 2. run_xanes_ensemble: run FDMNES calculations over multiple structures + using the configured execution backend. + 3. fetch_mp_structures: fetch optimized structures from Materials Project. + 4. plot_xanes: generate normalized XANES plots for completed calculations. + + Guidelines: + - Use each tool only when its input schema matches the user request. + - Do not guess numerical values; report tool errors exactly as they occur. + - Keep responses compact -- full results are in the output directories. + - When returning paths, use absolute paths. + - Energies are in eV. 
+ """, +) + + +@mcp.tool( + name="run_xanes_single", + description="Run a single XANES/FDMNES calculation for one input structure.", +) +def run_xanes_single(params: xanes_input_schema): + """Run a single FDMNES calculation using the core engine.""" + from chemgraph.tools.xanes_tools import run_xanes_core + + return run_xanes_core(params) + + +def _xanes_post_fn(meta: dict, _result) -> dict: + """Post-process a completed FDMNES task: extract convergence data.""" + from chemgraph.tools.xanes_tools import extract_conv + + try: + conv_data = extract_conv(meta["run_dir"]) + return { + **meta, + "status": "success", + "n_conv_files": len(conv_data), + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"Post-processing failed: {e}", + } + + +@mcp.tool( + name="run_xanes_ensemble", + description="Run an ensemble of XANES/FDMNES calculations using the configured backend.", +) +async def run_xanes_ensemble(params: xanes_input_schema_ensemble): + """Run ensemble XANES calculations over all structure files. + + For each structure file: + 1. Reads the structure via ASE. + 2. Creates FDMNES input files in a per-structure subdirectory. + 3. Submits a shell task to run FDMNES. + 4. Gathers results and writes a JSONL summary log. + + Parameters + ---------- + params : xanes_input_schema_ensemble + Input parameters for the ensemble calculation. 
+ """ + from ase.io import read as ase_read + + from chemgraph.tools.xanes_tools import write_fdmnes_input + + structure_files, output_dir = resolve_structure_files( + params.input_structures, + extensions={".cif", ".xyz", ".poscar"}, + ) + + # Create a batch runs directory + runs_dir = output_dir / "fdmnes_batch_runs" + runs_dir.mkdir(parents=True, exist_ok=True) + + fdmnes_exe = params.fdmnes_exe + + pending_tasks = [] + + for i, struct_path in enumerate(structure_files): + run_dir = runs_dir / f"run_{i}" + run_dir.mkdir(parents=True, exist_ok=True) + + # Read structure and write FDMNES inputs + atoms = ase_read(str(struct_path)) + z_abs = ( + params.z_absorber + if params.z_absorber is not None + else int(max(atoms.get_atomic_numbers())) + ) + + write_fdmnes_input( + ase_atoms=atoms, + z_absorber=z_abs, + input_file_dir=run_dir, + radius=params.radius, + magnetism=params.magnetism, + ) + + # Submit shell task + task = TaskSpec( + task_id=f"xanes_{struct_path.stem}_{i}", + task_type="shell", + command=f'cd "{run_dir}" && "{fdmnes_exe}"', + working_dir=str(run_dir), + stdout=str(run_dir / "fdmnes_stdout.txt"), + stderr=str(run_dir / "fdmnes_stderr.txt"), + ) + fut = backend.submit(task) + + task_meta = { + "structure": struct_path.name, + "run_dir": str(run_dir), + "z_absorber": z_abs, + } + pending_tasks.append((task_meta, fut)) + + results = await gather_futures(pending_tasks, post_fn=_xanes_post_fn) + + summary_log_path = output_dir / "xanes_results.jsonl" + success_count, total_count = write_results_jsonl(results, summary_log_path) + + return ( + f"Ensemble execution completed. Ran {total_count} tasks " + f"({success_count} successful). " + f"Detailed results appended to '{summary_log_path}'." 
+ ) + + +@mcp.tool( + name="fetch_mp_structures", + description="Fetch optimized structures from Materials Project.", +) +def fetch_mp_structures(params: mp_query_schema): + """Fetch structures from Materials Project and save as CIF files and pickle database.""" + from chemgraph.tools.xanes_tools import ( + _get_data_dir, + fetch_materials_project_data, + ) + + data_dir = _get_data_dir() + result = fetch_materials_project_data(params, data_dir) + return { + "status": "success", + "n_structures": result["n_structures"], + "chemsys": params.chemsys, + "output_dir": str(data_dir), + "structure_files": result["structure_files"], + "pickle_file": result["pickle_file"], + } + + +@mcp.tool( + name="plot_xanes", + description="Generate normalized XANES plots for completed FDMNES calculations.", +) +def plot_xanes(runs_dir: str): + """Generate XANES plots for all completed runs in a directory. + + Parameters + ---------- + runs_dir : str + Path to the ``fdmnes_batch_runs`` directory containing ``run_*`` + subdirectories with FDMNES outputs. 
+ """ + from chemgraph.tools.xanes_tools import ( + _get_data_dir, + plot_xanes_results, + ) + + runs_path = Path(runs_dir) + if not runs_path.is_dir(): + raise ValueError(f"'{runs_dir}' is not a valid directory.") + + data_dir = _get_data_dir() + result = plot_xanes_results(data_dir, runs_path) + return { + "status": "success", + "n_plots": result["n_plots"], + "n_failed": result["n_failed"], + "plot_files": result["plot_files"], + "failed": result["failed"], + } + + +if __name__ == "__main__": + run_mcp_server(mcp, default_port=9007) diff --git a/tests/conftest.py b/tests/conftest.py index 083d138e..0de3313d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,12 +27,22 @@ def pytest_addoption(parser): parser.addoption( "--run-llm", action="store_true", default=False, help="run tests that call LLM APIs" ) + parser.addoption( + "--run-globus-compute", action="store_true", default=False, + help="run tests that require a live Globus Compute endpoint" + ) def pytest_collection_modifyitems(config, items): - if config.getoption("--run-llm"): - # --run-llm given in cli: do not skip llm tests - return - skip_llm = pytest.mark.skip(reason="need --run-llm option to run") + skip_llm = None + if not config.getoption("--run-llm"): + skip_llm = pytest.mark.skip(reason="need --run-llm option to run") + + skip_globus = None + if not config.getoption("--run-globus-compute"): + skip_globus = pytest.mark.skip(reason="need --run-globus-compute option to run") + for item in items: - if "llm" in item.keywords: - item.add_marker(skip_llm) \ No newline at end of file + if skip_llm and "llm" in item.keywords: + item.add_marker(skip_llm) + if skip_globus and "globus_compute" in item.keywords: + item.add_marker(skip_globus) \ No newline at end of file diff --git a/tests/test_execution.py b/tests/test_execution.py new file mode 100644 index 00000000..5f1617bc --- /dev/null +++ b/tests/test_execution.py @@ -0,0 +1,1017 @@ +"""Tests for the chemgraph.execution abstraction layer. 
+ +Tests cover: +- TaskSpec validation +- LocalBackend: python and shell tasks +- GlobusComputeBackend: python and shell tasks (mocked SDK) +- Backend factory (get_backend) +- Shared utilities: resolve_structure_files, gather_futures, write_results_jsonl +""" + +import asyncio +import json +import os +import sys +import tempfile +from concurrent.futures import Future +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from chemgraph.execution.base import ExecutionBackend, TaskSpec +from chemgraph.execution.local_backend import LocalBackend +from chemgraph.execution.utils import ( + gather_futures, + make_per_structure_output, + resolve_structure_files, + write_results_jsonl, +) + + +# ── TaskSpec tests ────────────────────────────────────────────────────── + + +class TestTaskSpec: + def test_python_task_minimal(self): + spec = TaskSpec(task_id="t1", task_type="python", callable=abs, args=(42,)) + assert spec.task_id == "t1" + assert spec.task_type == "python" + assert spec.callable is abs + assert spec.args == (42,) + + def test_shell_task_minimal(self): + spec = TaskSpec(task_id="t2", task_type="shell", command="echo hello") + assert spec.task_type == "shell" + assert spec.command == "echo hello" + + def test_defaults(self): + spec = TaskSpec(task_id="t3") + assert spec.task_type == "python" + assert spec.callable is None + assert spec.args == () + assert spec.kwargs == {} + assert spec.num_nodes == 1 + assert spec.processes_per_node == 1 + assert spec.gpus_per_task == 0 + + +# ── LocalBackend tests ────────────────────────────────────────────────── + + +def _square(x): + return x * x + + +def _add(a, b): + return a + b + + +def _failing_fn(): + raise ValueError("intentional test error") + + +class TestLocalBackend: + def test_python_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=2) + try: + task = TaskSpec( + task_id="sq", + task_type="python", + callable=_square, + args=(7,), + ) + 
fut = backend.submit(task) + assert isinstance(fut, Future) + assert fut.result(timeout=10) == 49 + finally: + backend.shutdown() + + def test_python_task_kwargs(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=2) + try: + task = TaskSpec( + task_id="add", + task_type="python", + callable=_add, + kwargs={"a": 3, "b": 5}, + ) + assert backend.submit(task).result(timeout=10) == 8 + finally: + backend.shutdown() + + def test_shell_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False + ) as f: + stdout_path = f.name + + task = TaskSpec( + task_id="echo", + task_type="shell", + command="echo hello_world", + stdout=stdout_path, + ) + fut = backend.submit(task) + fut.result(timeout=10) + + with open(stdout_path) as f: + assert "hello_world" in f.read() + finally: + backend.shutdown() + os.unlink(stdout_path) + + def test_submit_batch(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=4) + try: + tasks = [ + TaskSpec( + task_id=f"sq_{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(5) + ] + futures = backend.submit_batch(tasks) + assert len(futures) == 5 + results = [f.result(timeout=10) for f in futures] + assert results == [0, 1, 4, 9, 16] + finally: + backend.shutdown() + + def test_failing_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec( + task_id="fail", + task_type="python", + callable=_failing_fn, + ) + fut = backend.submit(task) + with pytest.raises(ValueError, match="intentional test error"): + fut.result(timeout=10) + finally: + backend.shutdown() + + def test_context_manager(self): + with LocalBackend() as backend: + backend.initialize(system="local", max_workers=1) + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + assert 
backend.submit(task).result(timeout=10) == 9 + + def test_not_initialized_raises(self): + backend = LocalBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_python_task_missing_callable(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + finally: + backend.shutdown() + + def test_shell_task_missing_command(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec(task_id="no_cmd", task_type="shell") + with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + finally: + backend.shutdown() + + +# ── GlobusComputeBackend tests ────────────────────────────────────────── + + +def _make_mock_gc_modules(): + """Create mock globus_compute_sdk module and its classes.""" + mock_sdk = MagicMock() + + # Mock Executor: instances track submit calls and return Futures + mock_executor_instance = MagicMock() + mock_future = Future() + mock_future.set_result(42) + mock_executor_instance.submit.return_value = mock_future + mock_sdk.Executor.return_value = mock_executor_instance + + # Mock ShellFunction + mock_shell_fn_instance = MagicMock() + mock_sdk.ShellFunction.return_value = mock_shell_fn_instance + + return mock_sdk, mock_executor_instance + + +class TestGlobusComputeBackend: + def _patch_and_import(self, mock_sdk): + """Patch globus_compute_sdk into sys.modules and import the backend.""" + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + # Force re-import to pick up the mock + import importlib + + import chemgraph.execution.globus_compute_backend as gc_mod + + importlib.reload(gc_mod) + return gc_mod.GlobusComputeBackend + + def test_initialize_success(self): + mock_sdk, mock_executor = 
_make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(system="polaris", endpoint_id="test-uuid-1234") + + assert backend._initialized is True + mock_sdk.Executor.assert_called_once_with(endpoint_id="test-uuid-1234") + + def test_initialize_with_amqp_port(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize( + system="polaris", + endpoint_id="test-uuid", + amqp_port=443, + ) + + mock_sdk.Executor.assert_called_once_with( + endpoint_id="test-uuid", amqp_port=443 + ) + + def test_initialize_missing_endpoint_id(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ValueError, match="endpoint_id"): + backend.initialize(system="polaris") + + def test_initialize_empty_endpoint_id(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ValueError, match="endpoint_id"): + backend.initialize(system="polaris", endpoint_id="") + + def test_initialize_import_error(self): + """Verify helpful error when globus-compute-sdk is not installed.""" + with patch.dict(sys.modules, {"globus_compute_sdk": None}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ImportError, match="globus-compute-sdk"): + 
backend.initialize(endpoint_id="test-uuid") + + def test_submit_python_task(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="py1", + task_type="python", + callable=_square, + args=(7,), + ) + fut = backend.submit(task) + + assert isinstance(fut, Future) + mock_executor.submit.assert_called_once_with(_square, 7) + + def test_submit_python_task_with_kwargs(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="py2", + task_type="python", + callable=_add, + args=(3,), + kwargs={"b": 5}, + ) + backend.submit(task) + + mock_executor.submit.assert_called_once_with(_add, 3, b=5) + + def test_submit_shell_task(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="sh1", + task_type="shell", + command="echo hello", + ) + backend.submit(task) + + # ShellFunction should be constructed with the command + mock_sdk.ShellFunction.assert_called_once_with("echo hello") + # And then submitted via the executor + shell_fn_instance = mock_sdk.ShellFunction.return_value + mock_executor.submit.assert_called_once_with(shell_fn_instance) + + def test_submit_not_initialized(self): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = 
GlobusComputeBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_submit_python_missing_callable(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + + def test_submit_shell_missing_command(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="no_cmd", task_type="shell") + with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + + def test_shutdown(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + assert backend._initialized is True + + backend.shutdown() + + assert backend._initialized is False + assert backend._executor is None + mock_executor.shutdown.assert_called_once() + + def test_shutdown_idempotent(self): + """Calling shutdown() when not initialized should not raise.""" + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.shutdown() # should be a no-op + assert backend._initialized is False + + def test_context_manager(self): + mock_sdk, mock_executor = 
_make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + with GlobusComputeBackend() as backend: + backend.initialize(endpoint_id="test-uuid") + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + backend.submit(task) + + # After exiting context, shutdown should have been called + mock_executor.shutdown.assert_called_once() + + +class TestGetBackendGlobusCompute: + def test_factory_creates_globus_compute_backend(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend( + backend_name="globus_compute", + endpoint_id="factory-test-uuid", + ) + try: + assert isinstance(backend, GlobusComputeBackend) + assert backend._initialized is True + finally: + backend.shutdown() + + def test_factory_via_env_var(self): + mock_sdk, _ = _make_mock_gc_modules() + with ( + patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}), + patch.dict( + os.environ, + {"CHEMGRAPH_EXECUTION_BACKEND": "globus_compute"}, + ), + ): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend(endpoint_id="env-test-uuid") + try: + assert isinstance(backend, GlobusComputeBackend) + finally: + backend.shutdown() + + +# ── Factory tests ─────────────────────────────────────────────────────── + + +class TestGetBackend: + def test_local_backend_via_env(self): + with patch.dict(os.environ, {"CHEMGRAPH_EXECUTION_BACKEND": "local"}): + from chemgraph.execution.config import get_backend + + backend = get_backend() + try: + assert isinstance(backend, LocalBackend) + finally: + backend.shutdown() + + def 
test_explicit_backend_name(self): + from chemgraph.execution.config import get_backend + + backend = get_backend(backend_name="local", max_workers=2) + try: + assert isinstance(backend, LocalBackend) + finally: + backend.shutdown() + + def test_unsupported_backend_raises(self): + from chemgraph.execution.config import get_backend + + with pytest.raises(ValueError, match="Unknown execution backend"): + get_backend(backend_name="nonexistent") + + +# ── Utility tests ─────────────────────────────────────────────────────── + + +class TestResolveStructureFiles: + def test_from_directory(self, tmp_path): + for name in ["a.cif", "b.cif", "c.txt"]: + (tmp_path / name).write_text("dummy") + + files, out_dir = resolve_structure_files(str(tmp_path), extensions={".cif"}) + assert len(files) == 2 + assert out_dir == tmp_path + assert all(f.suffix == ".cif" for f in files) + + def test_from_file_list(self, tmp_path): + paths = [] + for name in ["x.xyz", "y.xyz"]: + p = tmp_path / name + p.write_text("dummy") + paths.append(str(p)) + + files, out_dir = resolve_structure_files(paths) + assert len(files) == 2 + assert out_dir == tmp_path + + def test_missing_file_raises(self, tmp_path): + with pytest.raises(ValueError, match="missing"): + resolve_structure_files([str(tmp_path / "noexist.cif")]) + + def test_empty_dir_raises(self, tmp_path): + with pytest.raises(ValueError, match="No structure files"): + resolve_structure_files(str(tmp_path), extensions={".cif"}) + + def test_invalid_dir_raises(self): + with pytest.raises(ValueError, match="not a valid directory"): + resolve_structure_files("/nonexistent/path") + + +class TestMakePerStructureOutput: + def test_basic(self): + result = make_per_structure_output( + Path("/data/MOF-5.cif"), + Path("/results/output.json"), + ) + assert result == Path("/results/MOF-5_output.json") + + def test_no_suffix(self): + result = make_per_structure_output( + Path("/data/struct.xyz"), + Path("/results/result"), + ) + assert result == 
Path("/results/struct_result.json") + + +class TestGatherFutures: + @pytest.mark.asyncio + async def test_successful_futures(self): + loop = asyncio.get_event_loop() + + def _make_resolved(val): + f = Future() + f.set_result(val) + return f + + pending = [ + ({"name": "a"}, _make_resolved({"status": "success", "energy": -1.0})), + ({"name": "b"}, _make_resolved({"status": "success", "energy": -2.0})), + ] + results = await gather_futures(pending) + assert len(results) == 2 + assert results[0]["name"] == "a" + assert results[0]["energy"] == -1.0 + + @pytest.mark.asyncio + async def test_failed_future(self): + f = Future() + f.set_exception(RuntimeError("boom")) + + pending = [({"name": "fail"}, f)] + results = await gather_futures(pending) + assert results[0]["status"] == "failure" + assert results[0]["error_type"] == "RuntimeError" + assert "boom" in results[0]["message"] + + @pytest.mark.asyncio + async def test_with_post_fn(self): + f = Future() + f.set_result(42) + + def post(meta, result): + return {**meta, "doubled": result * 2, "status": "success"} + + results = await gather_futures([({"id": "x"}, f)], post_fn=post) + assert results[0]["doubled"] == 84 + + +class TestWriteResultsJsonl: + def test_write_and_count(self, tmp_path): + results = [ + {"status": "success", "value": 1}, + {"status": "failure", "error": "bad"}, + {"status": "success", "value": 2}, + ] + path = tmp_path / "results.jsonl" + success, total = write_results_jsonl(results, path) + assert success == 2 + assert total == 3 + + lines = path.read_text().strip().split("\n") + assert len(lines) == 3 + assert json.loads(lines[0])["value"] == 1 + + def test_append_mode(self, tmp_path): + path = tmp_path / "results.jsonl" + write_results_jsonl([{"status": "success"}], path) + write_results_jsonl([{"status": "success"}], path, append=True) + + lines = path.read_text().strip().split("\n") + assert len(lines) == 2 + + +# ── Layer 2: GlobusComputeBackend unit-test gap coverage ──────────────── + + +class 
TestGlobusComputeBackendGaps: + """Additional mocked tests covering gaps in the original test suite.""" + + def test_submit_unsupported_task_type(self): + """The else branch in submit() should raise for unknown task_type.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="bad_type", + task_type="python", + callable=_square, + args=(1,), + ) + # Bypass Pydantic validation to force an invalid task_type + object.__setattr__(task, "task_type", "mpi") + + with pytest.raises(ValueError, match="unsupported task_type"): + backend.submit(task) + + def test_submit_batch_delegates(self): + """submit_batch (inherited from base) should call submit() N times.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + tasks = [ + TaskSpec( + task_id=f"t{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(3) + ] + futures = backend.submit_batch(tasks) + + assert len(futures) == 3 + assert mock_executor.submit.call_count == 3 + + def test_amqp_port_string_coercion(self): + """amqp_port from config.toml arrives as a string; must be coerced to int.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid", amqp_port="443") + + mock_sdk.Executor.assert_called_once_with( + endpoint_id="test-uuid", amqp_port=443 + ) + + def 
test_shutdown_executor_raises(self): + """If executor.shutdown() raises, the error is swallowed and state resets.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + mock_executor.shutdown.side_effect = RuntimeError("connection lost") + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + # Should NOT raise + backend.shutdown() + + assert backend._initialized is False + assert backend._executor is None + + +class TestGetBackendGlobusComputeGaps: + """Additional factory tests for config merging and TOML-driven creation.""" + + def test_factory_kwargs_override_config(self, tmp_path): + """Explicit kwargs should override values from config.toml.""" + config_file = tmp_path / "config.toml" + config_file.write_text( + "[execution]\n" + 'backend = "globus_compute"\n\n' + "[execution.globus_compute]\n" + 'endpoint_id = "config-uuid"\n' + "amqp_port = 5671\n" + ) + + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend( + config_path=str(config_file), + endpoint_id="kwarg-uuid", + ) + try: + assert isinstance(backend, GlobusComputeBackend) + # kwarg-uuid should win over config-uuid; amqp_port from config + mock_sdk.Executor.assert_called_once_with( + endpoint_id="kwarg-uuid", + amqp_port=5671, + ) + finally: + backend.shutdown() + + def test_factory_config_toml_driven(self, tmp_path): + """get_backend() with only a config.toml path should work end-to-end.""" + config_file = tmp_path / "config.toml" + config_file.write_text( + "[execution]\n" + 'backend = "globus_compute"\n\n' + "[execution.globus_compute]\n" + 'endpoint_id = "toml-uuid"\n' + ) + + mock_sdk, _ = 
_make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend(config_path=str(config_file)) + try: + assert isinstance(backend, GlobusComputeBackend) + assert backend._initialized is True + mock_sdk.Executor.assert_called_once_with(endpoint_id="toml-uuid") + finally: + backend.shutdown() + + +# ── Layer 3: Globus Compute integration tests (real endpoint) ─────────── + + +@pytest.fixture +def globus_backend(): + """Provide an initialized GlobusComputeBackend connected to a real endpoint. + + Skips the test if GLOBUS_COMPUTE_ENDPOINT_ID is not set or the SDK is + not installed. + """ + endpoint_id = os.environ.get("GLOBUS_COMPUTE_ENDPOINT_ID") + if not endpoint_id: + pytest.skip("GLOBUS_COMPUTE_ENDPOINT_ID env var not set") + + try: + from chemgraph.execution.config import get_backend + except ImportError: + pytest.skip("chemgraph.execution not available") + + try: + backend = get_backend( + backend_name="globus_compute", endpoint_id=endpoint_id + ) + except ImportError: + pytest.skip("globus-compute-sdk not installed") + + yield backend + backend.shutdown() + + +def _gc_double(x): + """Trivial function for Globus Compute integration tests.""" + return x * 2 + + +def _gc_square(x): + """Square function for Globus Compute integration tests.""" + return x * x + + +def _gc_identity(x): + """Identity function for Globus Compute integration tests.""" + return x + + +@pytest.mark.globus_compute +class TestGlobusComputeIntegration: + """Integration tests that submit work to a real Globus Compute endpoint. + + These are skipped by default. 
Run with:: + + GLOBUS_COMPUTE_ENDPOINT_ID=YOUR_ENDPOINT_UUID pytest --run-globus-compute -k Integration + """ + + def test_python_task_roundtrip(self, globus_backend): + """Submit a trivial Python callable and verify the result.""" + task = TaskSpec( + task_id="roundtrip", + task_type="python", + callable=_gc_double, + args=(21,), + ) + fut = globus_backend.submit(task) + result = fut.result(timeout=120) + assert result == 42 + + def test_shell_task_roundtrip(self, globus_backend): + """Submit a shell command and verify the output.""" + task = TaskSpec( + task_id="shell_rt", + task_type="shell", + command="echo hello_globus", + ) + fut = globus_backend.submit(task) + result = fut.result(timeout=120) + # ShellFunction returns a ShellResult; stdout should contain the string + assert "hello_globus" in str(result) + + def test_batch_submission(self, globus_backend): + """Submit a batch of tasks and verify all results.""" + tasks = [ + TaskSpec( + task_id=f"batch_{i}", + task_type="python", + callable=_gc_square, + args=(i,), + ) + for i in range(5) + ] + futures = globus_backend.submit_batch(tasks) + results = [f.result(timeout=120) for f in futures] + assert results == [0, 1, 4, 9, 16] + + @pytest.mark.asyncio + async def test_gather_futures_with_real_endpoint(self, globus_backend): + """Verify gather_futures works with real ComputeFuture objects.""" + tasks = [ + TaskSpec( + task_id=f"gf_{i}", + task_type="python", + callable=_gc_identity, + args=(i,), + ) + for i in range(3) + ] + futs = globus_backend.submit_batch(tasks) + pending = [({"index": i}, f) for i, f in enumerate(futs)] + + results = await gather_futures(pending) + assert len(results) == 3 + assert all("index" in r for r in results) + + +# ── Layer 4: Edge-case and error-handling tests ───────────────────────── + + +class TestGlobusComputeEdgeCases: + """Mocked tests for error paths and edge conditions.""" + + def test_submit_after_shutdown(self): + """Submitting after shutdown() should raise RuntimeError.""" + mock_sdk, _
= _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + backend.shutdown() + + task = TaskSpec(task_id="late", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_double_initialize(self): + """Calling initialize() twice should succeed and create a new executor.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="uuid-1") + backend.initialize(endpoint_id="uuid-2") + + assert backend._initialized is True + assert mock_sdk.Executor.call_count == 2 + backend.shutdown() + + def test_context_manager_with_exception(self): + """shutdown() must be called even when the body raises.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + with pytest.raises(ValueError, match="intentional"): + with GlobusComputeBackend() as backend: + backend.initialize(endpoint_id="test-uuid") + raise ValueError("intentional") + + mock_executor.shutdown.assert_called_once() + + def test_executor_submit_raises_propagates(self): + """Errors from executor.submit() should propagate to the caller.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + mock_executor.submit.side_effect = RuntimeError("endpoint unavailable") + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") 
+ + task = TaskSpec(task_id="err", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="endpoint unavailable"): + backend.submit(task) + + def test_submit_with_resource_hints(self): + """Resource hints are advisory and should not break submission.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="hints", + task_type="python", + callable=_square, + args=(5,), + num_nodes=4, + processes_per_node=32, + gpus_per_task=4, + ) + fut = backend.submit(task) + assert isinstance(fut, Future) + # Resource hints should NOT be passed to executor.submit + mock_executor.submit.assert_called_once_with(_square, 5) + + def test_failed_future_result(self): + """A future that resolves to an exception should be retrievable.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + failed_future = Future() + failed_future.set_exception(RuntimeError("task exploded")) + mock_executor.submit.return_value = failed_future + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="fail", callable=_square, args=(1,)) + fut = backend.submit(task) + + with pytest.raises(RuntimeError, match="task exploded"): + fut.result(timeout=5) From db2a41cc0f924593528273a7817a068a287b8907 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 4 May 2026 12:02:08 -0500 Subject: [PATCH 018/119] Fix unreachable code in aurora_parsl and EnsembleLauncher shutdown state Remove dead num_nodes=1 after raise in aurora_parsl.py and fix misleading error message. 
Set _initialized=False at start of EnsembleLauncherBackend.shutdown() to prevent submitting to a partially torn-down backend. --- src/chemgraph/hpc_configs/aurora_parsl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/chemgraph/hpc_configs/aurora_parsl.py b/src/chemgraph/hpc_configs/aurora_parsl.py index 61793aaf..2f7ac354 100644 --- a/src/chemgraph/hpc_configs/aurora_parsl.py +++ b/src/chemgraph/hpc_configs/aurora_parsl.py @@ -34,9 +34,9 @@ def get_aurora_config( node_list = f.readlines() num_nodes = len(node_list) else: - # Fallback for testing/local runs without PBS - raise ValueError("Warning: PBS_NODEFILE not found. Defaulting to 1 node.") - num_nodes = 1 + raise ValueError( + "PBS_NODEFILE not found. Cannot determine node count for Aurora." + ) config = Config( executors=[ From 39f28a1b86a1ab97f5d2879df3d6146172873fe6 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 14 May 2026 13:42:27 -0500 Subject: [PATCH 019/119] Update Globus config --- src/chemgraph/execution/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py index 71d3de90..80b3c458 100644 --- a/src/chemgraph/execution/config.py +++ b/src/chemgraph/execution/config.py @@ -124,6 +124,12 @@ def get_backend( backend_cfg = cfg.get(resolved_backend, {}) merged_kwargs = {**backend_cfg, **kwargs} + # Globus Compute: use GLOBUS_COMPUTE_ENDPOINT_ID when endpoint_id is missing or empty (config.toml ships endpoint_id = "") + if resolved_backend == "globus_compute" and not merged_kwargs.get("endpoint_id"): + env_id = os.getenv("GLOBUS_COMPUTE_ENDPOINT_ID") + if env_id: + merged_kwargs["endpoint_id"] = env_id + # -- instantiate ---------------------------------------------------------- logger.info( "Creating execution backend '%s' for system '%s'", From b13fc530f2d0d455cc6071584913fa8107c1608b Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 14 May 2026 13:43:00 -0500 Subject: [PATCH 020/119] Add inline structure for file transferring between local and
globus remote --- src/chemgraph/mcp/mace_mcp_hpc.py | 88 +++++++++++++++++++++++++++++-- 1 file changed, 83 insertions(+), 5 deletions(-) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index eba86858..dc966c49 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -11,8 +11,10 @@ - Uses shared utilities for structure resolution and result gathering. """ +import asyncio import json import logging +import os from pathlib import Path from mcp.server.fastmcp import FastMCP @@ -27,7 +29,6 @@ from chemgraph.tools.parsl_tools import ( mace_input_schema, mace_input_schema_ensemble, - run_mace_core, ) logger = logging.getLogger(__name__) @@ -59,19 +60,85 @@ def _run_mace_single(job: dict) -> dict: - """Execute a single MACE simulation (runs on the worker).""" + """Execute a single MACE simulation (runs on the worker). + + When the ``job`` dict contains an ``inline_structure`` key (with + ``numbers``, ``positions``, and optional ``cell``/``pbc``), the + structure is materialised as a temporary XYZ file on the worker + filesystem before running MACE. This allows local-agent / + remote-worker workflows where the original file only exists on the + submitting machine. 
+ """ + import os + import tempfile + from chemgraph.tools.parsl_tools import mace_input_schema, run_mace_core + inline = job.pop("inline_structure", None) + if inline is not None: + from ase import Atoms + from ase.io import write as ase_write + + atoms = Atoms( + numbers=inline["numbers"], + positions=inline["positions"], + cell=inline.get("cell"), + pbc=inline.get("pbc"), + ) + tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") + xyz_path = os.path.join(tmpdir, "structure.xyz") + ase_write(xyz_path, atoms) + job["input_structure_file"] = xyz_path + + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + tmpdir, job.get("output_result_file", "output.json") + ) + params = mace_input_schema(**job) if isinstance(job, dict) else job - return run_mace_core(params) + result = run_mace_core(params) + + # Embed full output JSON when running with inline structure so the + # caller does not need to read a file on the remote filesystem. + if inline is not None: + out_file = job.get("output_result_file", "") + if os.path.isfile(out_file): + import json as _json + + with open(out_file, "r") as fh: + result["full_output"] = _json.load(fh) + + return result @mcp.tool( name="run_mace_single", description="Run a single MACE calculation", ) -def run_mace_single(params: mace_input_schema): - return run_mace_core(params) +async def run_mace_single(params: mace_input_schema): + """Run a single MACE calculation using the configured execution backend.""" + job = params.model_dump() + + # Read the local structure file and embed it so the job is + # self-contained and can run on any worker (local or remote). 
+ input_file = job.get("input_structure_file") + if input_file and os.path.isfile(input_file): + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(input_file) + atomsdata = atoms_to_atomsdata(atoms) + job["inline_structure"] = atomsdata.model_dump() + + task = TaskSpec( + task_id="mace_single", + task_type="python", + callable=_run_mace_single, + kwargs={"job": job}, + ) + fut = backend.submit(task) + return await asyncio.wrap_future(fut) def _mace_post_fn(meta: dict, result) -> dict: @@ -129,6 +196,17 @@ async def run_mace_ensemble(params: mace_input_schema_ensemble): "optimizer": params.optimizer, } + # Embed structure data so the job works on remote workers that + # cannot access the local filesystem. + if struct_path.is_file(): + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(str(struct_path)) + atomsdata = atoms_to_atomsdata(atoms) + job["inline_structure"] = atomsdata.model_dump() + task = TaskSpec( task_id=f"mace_{struct_path.stem}", task_type="python", From b605ec059df347dcc95619d968309e2cd7d67e5e Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 14 May 2026 18:09:04 -0500 Subject: [PATCH 021/119] Add async job tracking for Globus Compute MCP tools When backend=globus_compute, MCP tools now return immediately after submitting jobs to the remote HPC endpoint instead of blocking until completion. A new JobTracker tracks submitted futures across tool calls, and new MCP tools (check_job_status, get_job_results, list_jobs, cancel_job) let the LLM agent poll for progress and retrieve results. Non-Globus backends (local, Parsl, EnsembleLauncher) are unchanged and continue to block until results are ready. 
Key changes: - Add is_async_remote property to ExecutionBackend (True for Globus) - Add check_endpoint_status() health check to GlobusComputeBackend - Add JobTracker with batch registration, status, results, cleanup - Add submit_or_gather() utility that branches on backend type - Add optional timeout parameter to gather_futures() - Add register_job_tools() to wire job tools into any MCP server - Integrate tracker into MACE, XANES, and gRASPA MCP servers --- src/chemgraph/execution/__init__.py | 2 + src/chemgraph/execution/base.py | 8 + .../execution/globus_compute_backend.py | 30 ++ src/chemgraph/execution/job_tracker.py | 296 +++++++++++++ src/chemgraph/execution/utils.py | 75 +++- src/chemgraph/mcp/graspa_mcp_hpc.py | 37 +- src/chemgraph/mcp/job_tools.py | 107 +++++ src/chemgraph/mcp/mace_mcp_hpc.py | 39 +- src/chemgraph/mcp/xanes_mcp_hpc.py | 36 +- tests/test_job_tracker.py | 394 ++++++++++++++++++ 10 files changed, 994 insertions(+), 30 deletions(-) create mode 100644 src/chemgraph/execution/job_tracker.py create mode 100644 src/chemgraph/mcp/job_tools.py create mode 100644 tests/test_job_tracker.py diff --git a/src/chemgraph/execution/__init__.py b/src/chemgraph/execution/__init__.py index 0fd6709b..bd6d0ccf 100644 --- a/src/chemgraph/execution/__init__.py +++ b/src/chemgraph/execution/__init__.py @@ -25,9 +25,11 @@ from chemgraph.execution.base import ExecutionBackend, TaskSpec from chemgraph.execution.config import get_backend +from chemgraph.execution.job_tracker import JobTracker __all__ = [ "ExecutionBackend", + "JobTracker", "TaskSpec", "get_backend", ] diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py index e7dc338b..ccfb4f2d 100644 --- a/src/chemgraph/execution/base.py +++ b/src/chemgraph/execution/base.py @@ -102,6 +102,14 @@ class ExecutionBackend(ABC): def __init__(self) -> None: self._initialized: bool = False + @property + def is_async_remote(self) -> bool: + """Whether this backend submits to a remote queue where jobs 
may + take minutes to hours. When ``True``, MCP tools should return + immediately after submission and provide separate status/result + retrieval tools instead of blocking until completion.""" + return False + @abstractmethod def initialize(self, system: str = "local", **kwargs: Any) -> None: """Prepare the backend for accepting work. diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py index 0c2a9634..2ec2bba1 100644 --- a/src/chemgraph/execution/globus_compute_backend.py +++ b/src/chemgraph/execution/globus_compute_backend.py @@ -51,6 +51,11 @@ class GlobusComputeBackend(ExecutionBackend): def __init__(self) -> None: super().__init__() self._executor = None + self._endpoint_id: str | None = None + + @property + def is_async_remote(self) -> bool: + return True # ── lifecycle ──────────────────────────────────────────────────────── @@ -77,6 +82,7 @@ def initialize(self, system: str = "local", **kwargs: Any) -> None: if amqp_port is not None: executor_kwargs["amqp_port"] = int(amqp_port) + self._endpoint_id = endpoint_id self._executor = Executor(**executor_kwargs) self._initialized = True logger.info( @@ -118,6 +124,30 @@ def submit(self, task: TaskSpec) -> Future: f"Task '{task.task_id}': unsupported task_type '{task.task_type}'." ) + # ── health check ──────────────────────────────────────────────────── + + def check_endpoint_status(self) -> dict: + """Check the status of the configured Globus Compute endpoint. + + Returns a dict with ``endpoint_id`` and ``status`` fields. + Useful as a pre-flight check before submitting tasks. 
+ """ + try: + from globus_compute_sdk import Client + + client = Client() + status = client.get_endpoint_status(self._endpoint_id) + return { + "endpoint_id": self._endpoint_id, + "status": status, + } + except Exception as e: + return { + "endpoint_id": self._endpoint_id, + "status": "error", + "error": str(e), + } + # ── teardown ──────────────────────────────────────────────────────── def shutdown(self) -> None: diff --git a/src/chemgraph/execution/job_tracker.py b/src/chemgraph/execution/job_tracker.py new file mode 100644 index 00000000..87b473c0 --- /dev/null +++ b/src/chemgraph/execution/job_tracker.py @@ -0,0 +1,296 @@ +"""In-memory job tracker for async remote execution backends. + +Tracks ``concurrent.futures.Future`` objects returned by +:meth:`ExecutionBackend.submit` so that MCP tools can return +immediately after submission and provide separate status / result +retrieval endpoints. + +Each MCP server process creates its own ``JobTracker`` instance +(mirroring the existing ``backend = get_backend()`` pattern). +""" + +from __future__ import annotations + +import logging +import threading +import uuid +from concurrent.futures import Future +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class TrackedTask: + """A single task within a tracked batch.""" + + task_id: str + meta: dict + future: Future + result: Optional[dict] = None + + +@dataclass +class TrackedBatch: + """A group of tasks submitted together.""" + + batch_id: str + tool_name: str + submitted_at: datetime + tasks: list[TrackedTask] = field(default_factory=list) + post_fn: Optional[Callable[[dict, Any], dict]] = None + + +class JobTracker: + """Track submitted job batches and their futures. + + Thread-safe: all public methods acquire an internal lock. 
+ """ + + def __init__(self) -> None: + self._batches: dict[str, TrackedBatch] = {} + self._lock = threading.Lock() + + # ── registration ─────────────────────────────────────────────────── + + def register_batch( + self, + tool_name: str, + pending_tasks: list[tuple[dict, Future]], + post_fn: Optional[Callable[[dict, Any], dict]] = None, + ) -> str: + """Register a batch of submitted tasks and return a batch ID. + + Parameters + ---------- + tool_name : str + Name of the MCP tool that submitted the batch. + pending_tasks : list[tuple[dict, Future]] + Each element is ``(metadata_dict, future)``. + post_fn : callable, optional + Post-processing function applied when collecting results. + Called as ``post_fn(metadata, raw_result) -> dict``. + + Returns + ------- + str + A UUID batch identifier. + """ + batch_id = uuid.uuid4().hex[:12] + tracked = [ + TrackedTask( + task_id=meta.get("task_id", meta.get("structure", f"task_{i}")), + meta=meta, + future=fut, + ) + for i, (meta, fut) in enumerate(pending_tasks) + ] + batch = TrackedBatch( + batch_id=batch_id, + tool_name=tool_name, + submitted_at=datetime.now(timezone.utc), + tasks=tracked, + post_fn=post_fn, + ) + with self._lock: + self._batches[batch_id] = batch + + logger.info( + "Registered batch '%s' (%s) with %d tasks", + batch_id, + tool_name, + len(tracked), + ) + return batch_id + + # ── status ───────────────────────────────────────────────────────── + + def get_status(self, batch_id: str) -> dict: + """Return the current status of a batch. + + Returns + ------- + dict + Keys: ``batch_id``, ``tool_name``, ``submitted_at``, + ``status``, ``total_tasks``, ``completed_tasks``, + ``failed_tasks``, ``pending_tasks``, ``progress_pct``. 
+ """ + with self._lock: + batch = self._batches.get(batch_id) + if batch is None: + return {"error": f"Unknown batch_id: '{batch_id}'"} + + total = len(batch.tasks) + done = 0 + failed = 0 + + for t in batch.tasks: + if t.future.done(): + done += 1 + # Cache the result on first check + if t.result is None: + try: + raw = t.future.result(timeout=0) + if batch.post_fn is not None: + t.result = batch.post_fn(t.meta, raw) + elif isinstance(raw, dict): + merged = {**t.meta, **raw} + merged.setdefault("status", "success") + t.result = merged + else: + t.result = { + **t.meta, + "result": raw, + "status": "success", + } + except Exception as e: + t.result = { + **t.meta, + "status": "failure", + "error_type": type(e).__name__, + "message": str(e), + } + if t.result.get("status") == "failure": + failed += 1 + + pending = total - done + if pending == total: + status = "pending" + elif pending > 0: + status = "running" + elif failed == total: + status = "failed" + elif failed > 0: + status = "partial" + else: + status = "completed" + + return { + "batch_id": batch_id, + "tool_name": batch.tool_name, + "submitted_at": batch.submitted_at.isoformat(), + "status": status, + "total_tasks": total, + "completed_tasks": done - failed, + "failed_tasks": failed, + "pending_tasks": pending, + "progress_pct": round(done / total * 100, 1) if total else 0.0, + } + + # ── results ──────────────────────────────────────────────────────── + + def get_results( + self, batch_id: str, include_partial: bool = False + ) -> dict: + """Collect results from a batch. + + Parameters + ---------- + batch_id : str + The batch identifier. + include_partial : bool + If ``True``, return results for completed tasks even if some + are still pending. If ``False`` (default) and the batch is + not fully resolved, return a status message instead. + + Returns + ------- + dict + Contains ``status``, ``results`` list, and summary counts. 
+ """ + status_info = self.get_status(batch_id) + if "error" in status_info: + return status_info + + with self._lock: + batch = self._batches.get(batch_id) + if batch is None: + return {"error": f"Unknown batch_id: '{batch_id}'"} + + if not include_partial and status_info["pending_tasks"] > 0: + return { + **status_info, + "message": ( + f"{status_info['pending_tasks']} of " + f"{status_info['total_tasks']} tasks still pending. " + f"Call check_job_status('{batch_id}') to monitor, " + f"or use include_partial=True to get partial results." + ), + } + + results = [] + for t in batch.tasks: + if t.result is not None: + results.append(t.result) + + return { + **status_info, + "results": results, + } + + # ── listing ──────────────────────────────────────────────────────── + + def list_batches(self) -> list[dict]: + """Return a summary of all tracked batches.""" + with self._lock: + batch_ids = list(self._batches.keys()) + + summaries = [] + for bid in batch_ids: + summaries.append(self.get_status(bid)) + return summaries + + # ── cancellation ─────────────────────────────────────────────────── + + def cancel_batch(self, batch_id: str) -> dict: + """Attempt to cancel pending tasks in a batch. + + Returns a dict with the number of successfully cancelled tasks. + Note: ``Future.cancel()`` only succeeds if the task has not yet + started executing. 
+ """ + with self._lock: + batch = self._batches.get(batch_id) + if batch is None: + return {"error": f"Unknown batch_id: '{batch_id}'"} + + cancelled = 0 + already_done = 0 + for t in batch.tasks: + if t.future.done(): + already_done += 1 + elif t.future.cancel(): + cancelled += 1 + + return { + "batch_id": batch_id, + "cancelled": cancelled, + "already_done": already_done, + "could_not_cancel": len(batch.tasks) - cancelled - already_done, + } + + # ── cleanup ──────────────────────────────────────────────────────── + + def cleanup(self, max_age_hours: float = 24) -> int: + """Remove completed batches older than *max_age_hours*. + + Returns the number of batches removed. + """ + now = datetime.now(timezone.utc) + to_remove: list[str] = [] + + with self._lock: + for bid, batch in self._batches.items(): + age_hours = (now - batch.submitted_at).total_seconds() / 3600 + if age_hours > max_age_hours and all( + t.future.done() for t in batch.tasks + ): + to_remove.append(bid) + for bid in to_remove: + del self._batches[bid] + + if to_remove: + logger.info("Cleaned up %d old batches", len(to_remove)) + return len(to_remove) diff --git a/src/chemgraph/execution/utils.py b/src/chemgraph/execution/utils.py index 70759a71..ba941fd6 100644 --- a/src/chemgraph/execution/utils.py +++ b/src/chemgraph/execution/utils.py @@ -16,7 +16,11 @@ import logging from concurrent.futures import Future from pathlib import Path -from typing import Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + from chemgraph.execution.base import ExecutionBackend + from chemgraph.execution.job_tracker import JobTracker logger = logging.getLogger(__name__) @@ -81,6 +85,7 @@ def resolve_structure_files( async def gather_futures( pending: list[tuple[dict, Future]], post_fn: Optional[Callable[[dict, Any], dict]] = None, + timeout: Optional[float] = None, ) -> list[dict]: """Await a list of ``(metadata, future)`` pairs concurrently. 
@@ -96,11 +101,20 @@ async def gather_futures( successful future resolution. Must return a ``dict`` to include in the results list. When *None*, the raw result is merged with metadata. + timeout : float, optional + Maximum seconds to wait for all futures to resolve. If the + timeout expires, an :class:`asyncio.TimeoutError` is raised. + When *None* (default), wait indefinitely. Returns ------- list[dict] One result dict per task (successful or failed). + + Raises + ------ + asyncio.TimeoutError + If *timeout* is set and exceeded before all futures complete. """ async def _wait(meta: dict, fut: Future) -> dict: @@ -122,9 +136,62 @@ async def _wait(meta: dict, fut: Future) -> dict: "message": str(e), } - return list( - await asyncio.gather(*(_wait(meta, fut) for meta, fut in pending)) - ) + coro = asyncio.gather(*(_wait(meta, fut) for meta, fut in pending)) + if timeout is not None: + return list(await asyncio.wait_for(coro, timeout=timeout)) + return list(await coro) + + +async def submit_or_gather( + backend: ExecutionBackend, + pending: list[tuple[dict, Future]], + tracker: JobTracker, + tool_name: str, + post_fn: Optional[Callable[[dict, Any], dict]] = None, +) -> dict: + """Gather results or register for async tracking, depending on the backend. + + When ``backend.is_async_remote`` is ``True``, the pending futures are + registered with the *tracker* and a submission confirmation is + returned immediately (non-blocking). Otherwise, results are gathered + synchronously via :func:`gather_futures`. + + Parameters + ---------- + backend : ExecutionBackend + The active execution backend. + pending : list[tuple[dict, Future]] + Each element is ``(metadata_dict, future)``. + tracker : JobTracker + The job tracker instance to register batches with. + tool_name : str + Name of the MCP tool submitting the batch. + post_fn : callable, optional + Post-processing function for results. 
+ + Returns + ------- + dict + Either ``{"status": "submitted", "batch_id": ..., ...}`` for + async backends, or ``{"status": "completed", "results": ...}`` + for synchronous backends. + """ + if backend.is_async_remote: + batch_id = tracker.register_batch(tool_name, pending, post_fn=post_fn) + return { + "status": "submitted", + "batch_id": batch_id, + "n_tasks": len(pending), + "message": ( + f"Submitted {len(pending)} task(s) to remote HPC endpoint. " + f"Use check_job_status(batch_id='{batch_id}') to monitor " + f"progress, and get_job_results(batch_id='{batch_id}') to " + f"retrieve results once complete." + ), + } + + results = await gather_futures(pending, post_fn=post_fn) + return {"status": "completed", "results": results} def write_results_jsonl( diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py index 9ee276bc..87eeb231 100644 --- a/src/chemgraph/mcp/graspa_mcp_hpc.py +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -12,12 +12,14 @@ from mcp.server.fastmcp import FastMCP from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.job_tracker import JobTracker from chemgraph.execution.utils import ( - gather_futures, make_per_structure_output, resolve_structure_files, + submit_or_gather, write_results_jsonl, ) +from chemgraph.mcp.job_tools import register_job_tools from chemgraph.mcp.server_utils import run_mcp_server from chemgraph.schemas.graspa_schema import graspa_input_schema_ensemble @@ -25,6 +27,7 @@ # ── Initialise execution backend ──────────────────────────────────────── backend = get_backend() +tracker = JobTracker() # ── MCP server ────────────────────────────────────────────────────────── mcp = FastMCP( @@ -34,6 +37,10 @@ The available tools are: 1. run_graspa_ensemble: run graspa calculations over all structures in a directory using the configured execution backend. + 2. check_job_status: check progress of a submitted HPC job batch. + 3. 
get_job_results: retrieve results from a completed job batch. + 4. list_jobs: list all tracked job batches. + 5. cancel_job: cancel pending tasks in a job batch. Guidelines: - Use each tool only when its input schema matches the user request. @@ -42,8 +49,11 @@ defined in the schemas. - When returning paths, use absolute paths. - Energies are in eV and wall times are in seconds. + - When a tool returns status='submitted' with a batch_id, use + check_job_status to poll for progress before calling get_job_results. """, ) +register_job_tools(mcp, tracker, backend) def _run_graspa_single(job: dict) -> dict: @@ -108,17 +118,24 @@ async def run_graspa_ensemble( } pending_tasks.append((task_meta, fut)) - results = await gather_futures(pending_tasks) - - summary_log_path = output_dir / "simulation_results.jsonl" - success_count, total_count = write_results_jsonl(results, summary_log_path) - - return ( - f"Ensemble execution completed. Ran {total_count} tasks " - f"({success_count} successful). " - f"Detailed results appended to '{summary_log_path}'." + result = await submit_or_gather( + backend, pending_tasks, tracker, "run_graspa_ensemble", ) + if result["status"] == "completed": + summary_log_path = output_dir / "simulation_results.jsonl" + success_count, total_count = write_results_jsonl( + result["results"], summary_log_path, + ) + return ( + f"Ensemble execution completed. Ran {total_count} tasks " + f"({success_count} successful). " + f"Detailed results appended to '{summary_log_path}'." + ) + + # Async remote: return submission confirmation + return result + if __name__ == "__main__": run_mcp_server(mcp, default_port=9001) diff --git a/src/chemgraph/mcp/job_tools.py b/src/chemgraph/mcp/job_tools.py new file mode 100644 index 00000000..6974aef1 --- /dev/null +++ b/src/chemgraph/mcp/job_tools.py @@ -0,0 +1,107 @@ +"""Shared MCP tools for job status tracking and result retrieval. 
+ +Call :func:`register_job_tools` to add ``check_job_status``, +``get_job_results``, ``list_jobs``, ``cancel_job``, and (optionally) +``check_endpoint_status`` to any :class:`~mcp.server.fastmcp.FastMCP` +server instance. + +These tools are only useful when the execution backend is async-remote +(e.g. Globus Compute), but are registered unconditionally so the LLM +agent always has a consistent tool surface. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from mcp.server.fastmcp import FastMCP + + from chemgraph.execution.base import ExecutionBackend + from chemgraph.execution.job_tracker import JobTracker + + +def register_job_tools( + mcp: FastMCP, + tracker: JobTracker, + backend: ExecutionBackend, +) -> None: + """Register job-management MCP tools on *mcp*. + + Parameters + ---------- + mcp : FastMCP + The MCP server to register tools on. + tracker : JobTracker + The job tracker for this server process. + backend : ExecutionBackend + The active execution backend (used for endpoint health checks). + """ + + @mcp.tool( + name="check_job_status", + description=( + "Check the status of a previously submitted HPC job batch. " + "Returns progress information including how many tasks are " + "complete, failed, or still pending. Use this to poll " + "long-running remote compute jobs." + ), + ) + def check_job_status(batch_id: str) -> dict: + """Check the status of a submitted job batch.""" + return tracker.get_status(batch_id) + + @mcp.tool( + name="get_job_results", + description=( + "Retrieve results from a completed (or partially completed) " + "HPC job batch. By default, returns results only when all " + "tasks are done. Set include_partial=True to get results " + "for tasks that have finished so far." 
+ ), + ) + def get_job_results( + batch_id: str, + include_partial: bool = False, + ) -> dict: + """Retrieve results from a job batch.""" + return tracker.get_results(batch_id, include_partial=include_partial) + + @mcp.tool( + name="list_jobs", + description=( + "List all tracked job batches with their current status. " + "Shows batch IDs, tool names, submission times, and progress." + ), + ) + def list_jobs() -> list[dict]: + """List all tracked job batches.""" + batches = tracker.list_batches() + if not batches: + return [{"message": "No job batches tracked."}] + return batches + + @mcp.tool( + name="cancel_job", + description=( + "Cancel pending tasks in a job batch. Only tasks that have " + "not yet started executing can be cancelled." + ), + ) + def cancel_job(batch_id: str) -> dict: + """Cancel pending tasks in a job batch.""" + return tracker.cancel_batch(batch_id) + + if backend.is_async_remote and hasattr(backend, "check_endpoint_status"): + + @mcp.tool( + name="check_endpoint_status", + description=( + "Check whether the remote HPC compute endpoint is " + "reachable and accepting tasks. Use this as a pre-flight " + "check before submitting jobs." 
+ ), + ) + def check_endpoint_status() -> dict: + """Check the remote compute endpoint status.""" + return backend.check_endpoint_status() diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index dc966c49..a664a1e7 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -20,11 +20,13 @@ from mcp.server.fastmcp import FastMCP from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.job_tracker import JobTracker from chemgraph.execution.utils import ( - gather_futures, make_per_structure_output, resolve_structure_files, + submit_or_gather, ) +from chemgraph.mcp.job_tools import register_job_tools from chemgraph.mcp.server_utils import run_mcp_server from chemgraph.tools.parsl_tools import ( mace_input_schema, @@ -35,6 +37,7 @@ # ── Initialise execution backend ──────────────────────────────────────── backend = get_backend() +tracker = JobTracker() # ── MCP server ────────────────────────────────────────────────────────── mcp = FastMCP( @@ -47,6 +50,10 @@ 2. run_mace_ensemble: run MACE calculations over all structures in a directory using the configured execution backend. 3. extract_output_json: load simulation results from a JSON file. + 4. check_job_status: check progress of a submitted HPC job batch. + 5. get_job_results: retrieve results from a completed job batch. + 6. list_jobs: list all tracked job batches. + 7. cancel_job: cancel pending tasks in a job batch. Guidelines: - Use each tool only when its input schema matches the user request. @@ -55,8 +62,11 @@ defined in the schemas. - When returning paths, use absolute paths. - Energies are in eV and wall times are in seconds. + - When a tool returns status='submitted' with a batch_id, use + check_job_status to poll for progress before calling get_job_results. 
""", ) +register_job_tools(mcp, tracker, backend) def _run_mace_single(job: dict) -> dict: @@ -138,6 +148,13 @@ async def run_mace_single(params: mace_input_schema): kwargs={"job": job}, ) fut = backend.submit(task) + + if backend.is_async_remote: + task_meta = {"task_id": "mace_single"} + return await submit_or_gather( + backend, [(task_meta, fut)], tracker, "run_mace_single" + ) + return await asyncio.wrap_future(fut) @@ -221,13 +238,21 @@ async def run_mace_ensemble(params: mace_input_schema_ensemble): } pending_tasks.append((task_meta, fut)) - results = await gather_futures(pending_tasks, post_fn=_mace_post_fn) + result = await submit_or_gather( + backend, pending_tasks, tracker, "run_mace_ensemble", + post_fn=_mace_post_fn, + ) - return { - "status": "success", - "n_structures": len(structure_files), - "results": results, - } + if result["status"] == "completed": + return { + "status": "success", + "n_structures": len(structure_files), + "results": result["results"], + } + + # Async remote: return submission confirmation + result["n_structures"] = len(structure_files) + return result @mcp.tool( diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 3ed81fa7..4b0d219b 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -12,11 +12,13 @@ from mcp.server.fastmcp import FastMCP from chemgraph.execution import TaskSpec, get_backend +from chemgraph.execution.job_tracker import JobTracker from chemgraph.execution.utils import ( - gather_futures, resolve_structure_files, + submit_or_gather, write_results_jsonl, ) +from chemgraph.mcp.job_tools import register_job_tools from chemgraph.mcp.server_utils import run_mcp_server from chemgraph.schemas.xanes_schema import ( mp_query_schema, @@ -28,6 +30,7 @@ # ── Initialise execution backend ──────────────────────────────────────── backend = get_backend() +tracker = JobTracker() # ── MCP server ────────────────────────────────────────────────────────── 
mcp = FastMCP( @@ -40,6 +43,10 @@ using the configured execution backend. 3. fetch_mp_structures: fetch optimized structures from Materials Project. 4. plot_xanes: generate normalized XANES plots for completed calculations. + 5. check_job_status: check progress of a submitted HPC job batch. + 6. get_job_results: retrieve results from a completed job batch. + 7. list_jobs: list all tracked job batches. + 8. cancel_job: cancel pending tasks in a job batch. Guidelines: - Use each tool only when its input schema matches the user request. @@ -47,8 +54,11 @@ - Keep responses compact -- full results are in the output directories. - When returning paths, use absolute paths. - Energies are in eV. + - When a tool returns status='submitted' with a batch_id, use + check_job_status to poll for progress before calling get_job_results. """, ) +register_job_tools(mcp, tracker, backend) @mcp.tool( @@ -155,16 +165,24 @@ async def run_xanes_ensemble(params: xanes_input_schema_ensemble): } pending_tasks.append((task_meta, fut)) - results = await gather_futures(pending_tasks, post_fn=_xanes_post_fn) + result = await submit_or_gather( + backend, pending_tasks, tracker, "run_xanes_ensemble", + post_fn=_xanes_post_fn, + ) - summary_log_path = output_dir / "xanes_results.jsonl" - success_count, total_count = write_results_jsonl(results, summary_log_path) + if result["status"] == "completed": + summary_log_path = output_dir / "xanes_results.jsonl" + success_count, total_count = write_results_jsonl( + result["results"], summary_log_path, + ) + return ( + f"Ensemble execution completed. Ran {total_count} tasks " + f"({success_count} successful). " + f"Detailed results appended to '{summary_log_path}'." + ) - return ( - f"Ensemble execution completed. Ran {total_count} tasks " - f"({success_count} successful). " - f"Detailed results appended to '{summary_log_path}'." 
- ) + # Async remote: return submission confirmation + return result @mcp.tool( diff --git a/tests/test_job_tracker.py b/tests/test_job_tracker.py new file mode 100644 index 00000000..cee3d081 --- /dev/null +++ b/tests/test_job_tracker.py @@ -0,0 +1,394 @@ +"""Tests for the JobTracker and submit_or_gather utilities.""" + +import asyncio +from concurrent.futures import Future +from unittest.mock import MagicMock + +import pytest + +from chemgraph.execution.job_tracker import JobTracker +from chemgraph.execution.utils import gather_futures, submit_or_gather + + +# ── Helpers ──────────────────────────────────────────────────────────── + + +def _make_done_future(result): + """Create a Future that is already resolved with *result*.""" + fut = Future() + fut.set_result(result) + return fut + + +def _make_failed_future(exc): + """Create a Future that is already resolved with an exception.""" + fut = Future() + fut.set_exception(exc) + return fut + + +def _make_pending_future(): + """Create a Future that is not yet resolved.""" + return Future() + + +# ── JobTracker.register_batch ────────────────────────────────────────── + + +class TestRegisterBatch: + def test_returns_batch_id(self): + tracker = JobTracker() + fut = _make_pending_future() + batch_id = tracker.register_batch( + "test_tool", [({"key": "val"}, fut)] + ) + assert isinstance(batch_id, str) + assert len(batch_id) == 12 + + def test_stores_tasks(self): + tracker = JobTracker() + futs = [_make_pending_future() for _ in range(3)] + pending = [({"idx": i}, f) for i, f in enumerate(futs)] + batch_id = tracker.register_batch("test_tool", pending) + + status = tracker.get_status(batch_id) + assert status["total_tasks"] == 3 + + def test_multiple_batches_unique_ids(self): + tracker = JobTracker() + ids = set() + for _ in range(10): + bid = tracker.register_batch( + "tool", [({"x": 1}, _make_pending_future())] + ) + ids.add(bid) + assert len(ids) == 10 + + +# ── JobTracker.get_status 
────────────────────────────────────────────── + + +class TestGetStatus: + def test_all_pending(self): + tracker = JobTracker() + pending = [({"i": i}, _make_pending_future()) for i in range(3)] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "pending" + assert status["total_tasks"] == 3 + assert status["completed_tasks"] == 0 + assert status["pending_tasks"] == 3 + assert status["progress_pct"] == 0.0 + + def test_all_completed(self): + tracker = JobTracker() + pending = [ + ({"i": i}, _make_done_future({"val": i})) for i in range(3) + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + assert status["completed_tasks"] == 3 + assert status["failed_tasks"] == 0 + assert status["pending_tasks"] == 0 + assert status["progress_pct"] == 100.0 + + def test_partial_done(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 0})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "running" + assert status["completed_tasks"] == 1 + assert status["pending_tasks"] == 1 + assert status["progress_pct"] == 50.0 + + def test_all_failed(self): + tracker = JobTracker() + pending = [ + ({"i": i}, _make_failed_future(ValueError(f"err_{i}"))) + for i in range(2) + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "failed" + assert status["failed_tasks"] == 2 + + def test_mixed_success_and_failure(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 0})), + ({"i": 1}, _make_failed_future(RuntimeError("boom"))), + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "partial" + assert 
status["completed_tasks"] == 1 + assert status["failed_tasks"] == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + status = tracker.get_status("nonexistent") + assert "error" in status + + def test_with_post_fn(self): + def post_fn(meta, result): + return {"custom": True, "status": "success", **meta} + + tracker = JobTracker() + pending = [({"i": 0}, _make_done_future({"raw": 1}))] + batch_id = tracker.register_batch("tool", pending, post_fn=post_fn) + + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + + +# ── JobTracker.get_results ───────────────────────────────────────────── + + +class TestGetResults: + def test_returns_results_when_complete(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_done_future({"val": 20})), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id) + assert "results" in result + assert len(result["results"]) == 2 + + def test_blocks_when_pending_and_partial_false(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id, include_partial=False) + assert "results" not in result + assert "message" in result + assert "still pending" in result["message"] + + def test_returns_partial_when_requested(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id, include_partial=True) + assert "results" in result + assert len(result["results"]) == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + result = tracker.get_results("nonexistent") + assert "error" in result + + +# ── JobTracker.list_batches ──────────────────────────────────────────── + + 
+class TestListBatches: + def test_empty(self): + tracker = JobTracker() + assert tracker.list_batches() == [] + + def test_multiple_batches(self): + tracker = JobTracker() + tracker.register_batch("tool_a", [({"x": 1}, _make_pending_future())]) + tracker.register_batch("tool_b", [({"x": 2}, _make_done_future(42))]) + + batches = tracker.list_batches() + assert len(batches) == 2 + tool_names = {b["tool_name"] for b in batches} + assert tool_names == {"tool_a", "tool_b"} + + +# ── JobTracker.cancel_batch ──────────────────────────────────────────── + + +class TestCancelBatch: + def test_cancel_pending(self): + tracker = JobTracker() + fut = _make_pending_future() + batch_id = tracker.register_batch("tool", [({"i": 0}, fut)]) + + result = tracker.cancel_batch(batch_id) + # Future.cancel() may or may not succeed depending on state, + # but the call should not raise + assert "batch_id" in result + + def test_cancel_already_done(self): + tracker = JobTracker() + fut = _make_done_future({"val": 1}) + batch_id = tracker.register_batch("tool", [({"i": 0}, fut)]) + + result = tracker.cancel_batch(batch_id) + assert result["already_done"] == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + result = tracker.cancel_batch("nonexistent") + assert "error" in result + + +# ── JobTracker.cleanup ───────────────────────────────────────────────── + + +class TestCleanup: + def test_removes_old_completed(self): + tracker = JobTracker() + batch_id = tracker.register_batch( + "tool", [({"i": 0}, _make_done_future(1))] + ) + + # Force the submitted_at to be old + batch = tracker._batches[batch_id] + from datetime import timedelta + + batch.submitted_at -= timedelta(hours=25) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 1 + assert tracker.list_batches() == [] + + def test_keeps_recent(self): + tracker = JobTracker() + tracker.register_batch("tool", [({"i": 0}, _make_done_future(1))]) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 
0 + assert len(tracker.list_batches()) == 1 + + def test_keeps_pending(self): + tracker = JobTracker() + batch_id = tracker.register_batch( + "tool", [({"i": 0}, _make_pending_future())] + ) + + batch = tracker._batches[batch_id] + from datetime import timedelta + + batch.submitted_at -= timedelta(hours=25) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 0 + + +# ── gather_futures with timeout ──────────────────────────────────────── + + +class TestGatherFuturesTimeout: + def test_completes_within_timeout(self): + pending = [ + ({"i": 0}, _make_done_future({"val": 1})), + ({"i": 1}, _make_done_future({"val": 2})), + ] + results = asyncio.get_event_loop().run_until_complete( + gather_futures(pending, timeout=5.0) + ) + assert len(results) == 2 + + def test_timeout_raises(self): + pending = [({"i": 0}, _make_pending_future())] + with pytest.raises(asyncio.TimeoutError): + asyncio.get_event_loop().run_until_complete( + gather_futures(pending, timeout=0.1) + ) + + def test_no_timeout_default(self): + pending = [({"i": 0}, _make_done_future(42))] + results = asyncio.get_event_loop().run_until_complete( + gather_futures(pending) + ) + assert len(results) == 1 + + +# ── submit_or_gather ─────────────────────────────────────────────────── + + +class TestSubmitOrGather: + def test_sync_backend_returns_completed(self): + backend = MagicMock() + backend.is_async_remote = False + + tracker = JobTracker() + pending = [({"i": 0}, _make_done_future({"val": 10}))] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, pending, tracker, "test_tool") + ) + assert result["status"] == "completed" + assert "results" in result + assert len(result["results"]) == 1 + + def test_async_backend_returns_submitted(self): + backend = MagicMock() + backend.is_async_remote = True + + tracker = JobTracker() + pending = [({"i": 0}, _make_pending_future())] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, 
pending, tracker, "test_tool") + ) + assert result["status"] == "submitted" + assert "batch_id" in result + assert result["n_tasks"] == 1 + assert "check_job_status" in result["message"] + + def test_async_backend_batch_trackable(self): + backend = MagicMock() + backend.is_async_remote = True + + tracker = JobTracker() + fut = _make_done_future({"val": 99}) + pending = [({"i": 0}, fut)] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, pending, tracker, "test_tool") + ) + batch_id = result["batch_id"] + + # Verify the batch is tracked and status works + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + + # Verify results can be retrieved + results = tracker.get_results(batch_id) + assert "results" in results + assert len(results["results"]) == 1 + + def test_async_backend_with_post_fn(self): + backend = MagicMock() + backend.is_async_remote = True + + def post_fn(meta, result): + return {"processed": True, "status": "success"} + + tracker = JobTracker() + fut = _make_done_future({"raw": 1}) + pending = [({"i": 0}, fut)] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather( + backend, pending, tracker, "test_tool", post_fn=post_fn, + ) + ) + batch_id = result["batch_id"] + + results = tracker.get_results(batch_id) + assert results["results"][0]["processed"] is True From 35a2d6526994f51e9fb55586ca312026e0920707 Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Thu, 21 May 2026 16:43:22 -0500 Subject: [PATCH 022/119] Modified the EL backend implemenations, and added a EL backend test --- src/chemgraph/execution/base.py | 3 +- src/chemgraph/execution/config.py | 8 + .../execution/ensemble_launcher_backend.py | 149 ++++++++---- tests/test_execution.py | 221 +++++++++++++++++- 4 files changed, 325 insertions(+), 56 deletions(-) diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py index ccfb4f2d..c182b4cf 100644 --- a/src/chemgraph/execution/base.py +++ 
b/src/chemgraph/execution/base.py @@ -11,7 +11,7 @@ import logging from abc import ABC, abstractmethod from concurrent.futures import Future -from typing import Any, Callable, Literal, Optional +from typing import Any, Callable, Dict, Literal, Optional from pydantic import BaseModel, ConfigDict, Field @@ -85,6 +85,7 @@ class TaskSpec(BaseModel): default=0, description="Number of GPUs requested per task.", ) + env: Dict[str, str] = Field(default_factory=dict) class ExecutionBackend(ABC): diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py index 80b3c458..60921f92 100644 --- a/src/chemgraph/execution/config.py +++ b/src/chemgraph/execution/config.py @@ -144,10 +144,18 @@ def get_backend( elif resolved_backend == "ensemble_launcher": from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, EnsembleLauncherBackend, + get_launcher_config, ) backend = EnsembleLauncherBackend() + assert system in SYSTEM_CONFIG_REGISTRY, ( + f"Unknown system: only know {SYSTEM_CONFIG_REGISTRY.keys()}" + ) + merged_kwargs = {} + merged_kwargs["system_config"] = SYSTEM_CONFIG_REGISTRY[system] + merged_kwargs["launcher_config"] = get_launcher_config(**backend_cfg) elif resolved_backend == "globus_compute": from chemgraph.execution.globus_compute_backend import ( diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index 23462f5b..f78d41b1 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -13,17 +13,86 @@ import logging import os -import socket import time import uuid from concurrent.futures import Future -from typing import Any +from typing import List, Literal, Optional, Union from chemgraph.execution.base import ExecutionBackend, TaskSpec +try: + from ensemble_launcher import EnsembleLauncher + from ensemble_launcher.config import ( + LauncherConfig, + MPIConfig, + PolicyConfig, + 
SystemConfig, + ) + from ensemble_launcher.helper_functions import get_nodes + from ensemble_launcher.orchestrator import ClusterClient +except ImportError as exc: + raise ImportError( + "EnsembleLauncher is required for the EnsembleLauncherBackend. " + "Install it with: pip install ensemble-launcher" + ) from exc + logger = logging.getLogger(__name__) +def get_local_system_config(): + system_config = SystemConfig( + name="local", + ncpus=os.cpu_count(), + cpus=list(range(os.cpu_count())), + ) + return system_config + + +def get_polaris_system_config(): + system_config = SystemConfig( + name="polaris", + ncpus=32, + cpus=list(range(32)), + ngpus=4, + gpus=list(range(4)), + ) + return system_config + + +def get_aurora_system_config(): + system_config = SystemConfig( + name="aurora", + ncpus=102, + cpus=list(range(1, 52)) + list(range(53, 104)), + ngpus=12, + gpus=list(range(12)), + ) + return system_config + + +def get_launcher_config( + task_executor_name: Union[str, List] = "async_processpool", + child_executor_policy: str = "fixed_leafs_children_policy", + policy_config: Optional[PolicyConfig] = None, + checkpoint_dir=f"{os.getcwd()}/.ckpt_{uuid.uuid4().hex[:6]}", + mpi_flavour: Literal["test", "mpich"] = "test", +): + if policy_config is None: + policy_config = PolicyConfig(nlevels=2, leaf_nodes=len(get_nodes())) + return LauncherConfig( + child_executor_name="async_mpi", + task_executor_name=task_executor_name, + return_stdout=True, + worker_logs=True, + master_logs=True, + children_scheduler_policy=child_executor_policy, + policy_config=policy_config, + cluster=True, + checkpoint_dir=checkpoint_dir, + mpi_config=MPIConfig(flavor=mpi_flavour), + ) + + class EnsembleLauncherBackend(ExecutionBackend): """Execution backend that delegates work to EnsembleLauncher. 
@@ -60,75 +129,42 @@ def __init__(self) -> None: self._client = None self._checkpoint_dir: str | None = None - def initialize(self, system: str = "local", **kwargs: Any) -> None: - try: - from ensemble_launcher import EnsembleLauncher - from ensemble_launcher.config import LauncherConfig, SystemConfig - from ensemble_launcher.orchestrator import ClusterClient - except ImportError as exc: - raise ImportError( - "EnsembleLauncher is required for the EnsembleLauncherBackend. " - "Install it with: pip install ensemble-launcher" - ) from exc - - # -- extract parameters ------------------------------------------------ - comm_name = kwargs.get("comm_name", "async_zmq") - task_executor = kwargs.get("task_executor_name", "async_processpool") - nlevels = kwargs.get("nlevels", 0) - ncpus = kwargs.get("max_workers", os.cpu_count() or 4) - checkpoint_dir = kwargs.get( - "checkpoint_dir", - os.path.join(os.getcwd(), f".el_ckpt_{uuid.uuid4().hex[:8]}"), - ) - nodes = kwargs.get("nodes", [socket.gethostname()]) - startup_delay = kwargs.get("startup_delay", 2.0) - - self._checkpoint_dir = checkpoint_dir - - # -- configure --------------------------------------------------------- - system_config = SystemConfig( - name=system, - ncpus=ncpus, - cpus=list(range(ncpus)), - ) - - launcher_config = LauncherConfig( - task_executor_name=task_executor, - comm_name=comm_name, - nlevels=nlevels, - cluster=True, - checkpoint_dir=checkpoint_dir, - ) + def initialize( + self, + system: str, + system_config: SystemConfig, + launcher_config: LauncherConfig, + startup_delay: float = 1.0, + ) -> None: + os.makedirs(launcher_config.checkpoint_dir, exist_ok=True) # -- start orchestrator ------------------------------------------------ self._el = EnsembleLauncher( ensemble_file={}, system_config=system_config, launcher_config=launcher_config, - Nodes=nodes, ) self._el.start() time.sleep(startup_delay) # -- connect client ---------------------------------------------------- - self._client = 
ClusterClient(checkpoint_dir=checkpoint_dir) + self._client = ClusterClient(checkpoint_dir=launcher_config.checkpoint_dir) self._client.start() self._initialized = True logger.info( "EnsembleLauncherBackend initialized (system='%s', " "comm='%s', executor='%s', nodes=%s)", - system, - comm_name, - task_executor, - nodes, + system_config.name, + launcher_config.comm_name, + launcher_config.task_executor_name, + len(self._el.nodes), ) def submit(self, task: TaskSpec) -> Future: if not self._initialized or self._client is None: raise RuntimeError( - "EnsembleLauncherBackend is not initialized. " - "Call initialize() first." + "EnsembleLauncherBackend is not initialized. Call initialize() first." ) from ensemble_launcher.ensemble import Task as ELTask @@ -145,6 +181,7 @@ def submit(self, task: TaskSpec) -> Future: executable=task.callable, args=task.args or (), kwargs=task.kwargs or {}, + env=task.env, ) return self._client.submit(el_task) @@ -157,7 +194,8 @@ def submit(self, task: TaskSpec) -> Future: task_id=task.task_id, nnodes=task.num_nodes, ppn=task.processes_per_node, - cmd_template=task.command, + executable=task.command, + env=task.env, ) return self._client.submit(el_task) @@ -197,3 +235,14 @@ def shutdown(self) -> None: "EnsembleLauncherBackend partially shut down. " "Call shutdown() again to retry failed teardown." 
) + + +SYSTEM_CONFIG_REGISTRY = { + "local": get_local_system_config(), + "aurora": get_aurora_system_config(), + "polaris": get_polaris_system_config(), +} + +if __name__ == "__main__": + el_backend = EnsembleLauncherBackend() + el_backend.initialize() diff --git a/tests/test_execution.py b/tests/test_execution.py index 5f1617bc..c662547c 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -19,7 +19,7 @@ import pytest -from chemgraph.execution.base import ExecutionBackend, TaskSpec +from chemgraph.execution.base import TaskSpec from chemgraph.execution.local_backend import LocalBackend from chemgraph.execution.utils import ( gather_futures, @@ -28,7 +28,6 @@ write_results_jsonl, ) - # ── TaskSpec tests ────────────────────────────────────────────────────── @@ -199,6 +198,220 @@ def test_shell_task_missing_command(self): backend.shutdown() +# ── EnsembleLauncherBackend tests ────────────────────────────────────────── + + +class TestELBackend: + @classmethod + def setup_class(cls): + project_root = str(Path(__file__).resolve().parent.parent) + existing = os.environ.get("PYTHONPATH", "") + os.environ["PYTHONPATH"] = ( + f"{project_root}:{existing}" if existing else project_root + ) + + def test_python_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="sq", + task_type="python", + callable=_square, + args=(7,), + ) + fut = backend.submit(task) + assert isinstance(fut, Future) + assert fut.result(timeout=10) == 49 + finally: + backend.shutdown() + + def test_python_task_kwargs(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = 
EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="add", + task_type="python", + callable=_add, + kwargs={"a": 3, "b": 5}, + ) + assert backend.submit(task).result(timeout=10) == 8 + finally: + backend.shutdown() + + def test_shell_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="echo", + task_type="shell", + command="echo hello_world", + ) + fut = backend.submit(task) + result = fut.result(timeout=10) + assert result is not None + finally: + backend.shutdown() + + def test_submit_batch(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + tasks = [ + TaskSpec( + task_id=f"sq_{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(5) + ] + futures = backend.submit_batch(tasks) + assert len(futures) == 5 + results = [f.result(timeout=10) for f in futures] + assert results == [0, 1, 4, 9, 16] + finally: + backend.shutdown() + + def test_failing_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="fail", + 
task_type="python", + callable=_failing_fn, + ) + fut = backend.submit(task) + with pytest.raises(Exception, match="intentional test error"): + fut.result(timeout=10) + finally: + backend.shutdown() + + def test_context_manager(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + with EnsembleLauncherBackend() as backend: + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + assert backend.submit(task).result(timeout=10) == 9 + + def test_not_initialized_raises(self): + from chemgraph.execution.ensemble_launcher_backend import ( + EnsembleLauncherBackend, + ) + + backend = EnsembleLauncherBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_python_task_missing_callable(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + finally: + backend.shutdown() + + def test_shell_task_missing_command(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec(task_id="no_cmd", task_type="shell") + 
with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + finally: + backend.shutdown() + + # ── GlobusComputeBackend tests ────────────────────────────────────────── @@ -807,9 +1020,7 @@ def globus_backend(): pytest.skip("chemgraph.execution not available") try: - backend = get_backend( - backend_name="globus_compute", endpoint_id=endpoint_id - ) + backend = get_backend(backend_name="globus_compute", endpoint_id=endpoint_id) except ImportError: pytest.skip("globus-compute-sdk not installed") From 212a054e53f941ba3f2842bd056f851a48a963e0 Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Fri, 22 May 2026 12:46:29 -0500 Subject: [PATCH 023/119] Add CGFastMCP backend framework, EL client-only mode, and pickle fix - Add CGFastMCP: FastMCP subclass with integrated execution backend, lazy init, built-in job tools, @tool() and @ensemble_tool() decorators - Refactor EnsembleLauncherBackend with client-only mode (shared orchestrator via checkpoint_dir) and managed mode - Update get_backend() to route client_only vs managed EL initialization - Rewrite mace_mcp_hpc.py to use CGFastMCP decorators - Clean up parsl_tools.py: remove dead code, use stdlib logging - Fix __main__ pickle issue via _fix_module_for_pickle + sys.modules alias - Add client-only mode demo cell to notebook 3 Co-Authored-By: Claude Opus 4.6 --- notebooks/3_Demo_using_MCP.ipynb | 303 ++++++++++------ src/chemgraph/execution/config.py | 20 +- .../execution/ensemble_launcher_backend.py | 164 +++++---- src/chemgraph/mcp/cg_fastmcp.py | 339 ++++++++++++++++++ src/chemgraph/mcp/mace_mcp_hpc.py | 261 ++------------ src/chemgraph/tools/parsl_tools.py | 24 +- 6 files changed, 694 insertions(+), 417 deletions(-) create mode 100644 src/chemgraph/mcp/cg_fastmcp.py diff --git a/notebooks/3_Demo_using_MCP.ipynb b/notebooks/3_Demo_using_MCP.ipynb index ce37b46d..caf11cb0 100644 --- a/notebooks/3_Demo_using_MCP.ipynb +++ b/notebooks/3_Demo_using_MCP.ipynb @@ -2,190 +2,269 @@ "cells": [ { 
"cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "3b97dfba-13c9-49a4-bdce-efd5900dcafa", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/tpham2/work/projects/ChemGraph/env/chemgraph_env/lib/python3.10/site-packages/google/api_core/_python_version_support.py:266: FutureWarning: You are using a Python version (3.10.19) which Google will stop supporting in new releases of google.api_core once it reaches its end of life (2026-10-04). Please upgrade to the latest Python version, or at least Python 3.11, to continue receiving updates for google.api_core past that date.\n", - " warnings.warn(message, FutureWarning)\n", - "WARNING:root:fairchem is not installed. .\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "2026-01-22 11:50:08,686 - chemgraph.models.openai - INFO - OpenAI API key not found in environment variables.\n" + "Done creating client\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "INFO:chemgraph.models.openai:OpenAI API key not found in environment variables.\n" - ] - }, - { - "name": "stdin", - "output_type": "stream", - "text": [ - "Please enter your OpenAI API key: ········\n" + "2026-05-22 12:34:00,370 - chemgraph.graphs.single_agent - INFO - Constructing single agent graph\n", + "2026-05-22 12:34:00,372 - chemgraph.graphs.single_agent - INFO - Graph construction completed\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2026-01-22 11:50:10,594 - chemgraph.models.openai - INFO - Loading OpenAI model: gpt-4o-mini\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:chemgraph.models.openai:Loading OpenAI model: gpt-4o-mini\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2026-01-22 11:50:10,710 - chemgraph.models.openai - INFO - Requested model: gpt-4o-mini\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - 
"INFO:chemgraph.models.openai:Requested model: gpt-4o-mini\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2026-01-22 11:50:10,711 - chemgraph.models.openai - INFO - OpenAI model loaded successfully\n" + "Done getting tools\n", + "================================\u001b[1m Human Message \u001b[0m=================================\n", + "\n", + "Run a mace calculations with the same file, use energy for driver and small model. a cif file are located at /Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "Tool Calls:\n", + " run_mace_single (chatcmpl-tool-a42c48d32a55e54d)\n", + " Call ID: chatcmpl-tool-a42c48d32a55e54d\n", + " Args:\n", + " params: {'input_structure_file': '/Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif', 'driver': 'energy', 'model': 'small'}\n", + "=================================\u001b[1m Tool Message \u001b[0m=================================\n", + "Name: run_mace_single\n", + "\n", + "{\n", + " \"status\": \"success\",\n", + " \"message\": \"Simulation completed. 
Results saved to /Users/hari/projects/ChemGraph/notebooks/output.json\",\n", + " \"single_point_energy\": -295.75144320599975,\n", + " \"unit\": \"eV\"\n", + "}\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "\n", + "The MACE single‑point energy calculation completed successfully.\n", + "\n", + "**Result**\n", + "- **Energy:** -295.75144320599975 eV \n", + "- **Output file:** `/Users/hari/projects/ChemGraph/notebooks/output.json`\n", + "\n", + "If you need any other properties (e.g., forces, charge distribution) or would like to run additional calculations (geometry optimization, vibrational analysis, etc.), just let me know!\n", + "Done\n" ] - }, + } + ], + "source": [ + "import subprocess, time, os\n", + "from langchain_mcp_adapters.client import MultiServerMCPClient\n", + "from chemgraph.agent.llm_agent import ChemGraph\n", + "\n", + "prompt_single = \"Run a mace calculations with the same file, use energy for driver and small model. a cif file are located at /Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif\"\n", + "\n", + "os.environ[\"ALCF_ACCESS_TOKEN\"]=\" None: super().__init__() - self._el = None - self._client = None - self._checkpoint_dir: str | None = None + self._orchestrator: Optional[EnsembleLauncher] = None + self._client: Optional[ClusterClient] = None def initialize( self, - system: str, - system_config: SystemConfig, - launcher_config: LauncherConfig, - startup_delay: float = 1.0, + system: str = "local", + *, + client_only: bool = False, + checkpoint_dir: Optional[str] = None, + node_id: str = "global", + system_config: Optional[SystemConfig] = None, + launcher_config: Optional[LauncherConfig] = None, + startup_delay: float = 10.0, + **kwargs, ) -> None: + """Prepare the backend for accepting work. + + Parameters + ---------- + client_only : bool + When ``True``, connect to a running orchestrator via + *checkpoint_dir* — no orchestrator is started. 
+ checkpoint_dir : str + Path to the orchestrator's checkpoint directory. Required + when *client_only* is ``True``. + node_id : str + Orchestrator node to connect to (default ``"global"``). + system_config, launcher_config + Required for **managed** mode (``client_only=False``). + The backend starts its own orchestrator with these. + startup_delay : float + Seconds to wait for the orchestrator to become ready + (managed mode only). + """ + if client_only: + # -- client-only mode ---------------------------------------------- + if checkpoint_dir is None: + raise ValueError( + "client_only=True requires a checkpoint_dir pointing " + "to a running orchestrator." + ) + self._client = ClusterClient( + checkpoint_dir=checkpoint_dir, node_id=node_id + ) + self._client.start() + self._initialized = True + logger.info( + "EnsembleLauncherBackend initialized in client-only mode " + "(checkpoint_dir='%s', node_id='%s')", + checkpoint_dir, + node_id, + ) + else: + # -- managed mode: start orchestrator first ------------------------ + if system_config is None or launcher_config is None: + raise ValueError( + "Managed mode requires system_config and launcher_config " + "(or set client_only=True with a checkpoint_dir)." 
+ ) + os.makedirs(launcher_config.checkpoint_dir, exist_ok=True) + self._orchestrator = EnsembleLauncher( + ensemble_file={}, + system_config=system_config, + launcher_config=launcher_config, + ) + self._orchestrator.start() + time.sleep(startup_delay) - os.makedirs(launcher_config.checkpoint_dir, exist_ok=True) - # -- start orchestrator ------------------------------------------------ - self._el = EnsembleLauncher( - ensemble_file={}, - system_config=system_config, - launcher_config=launcher_config, - ) - self._el.start() - time.sleep(startup_delay) - - # -- connect client ---------------------------------------------------- - self._client = ClusterClient(checkpoint_dir=launcher_config.checkpoint_dir) - self._client.start() - - self._initialized = True - logger.info( - "EnsembleLauncherBackend initialized (system='%s', " - "comm='%s', executor='%s', nodes=%s)", - system_config.name, - launcher_config.comm_name, - launcher_config.task_executor_name, - len(self._el.nodes), - ) + self._client = ClusterClient( + checkpoint_dir=launcher_config.checkpoint_dir, + node_id=node_id, + ) + self._client.start() + self._initialized = True + logger.info( + "EnsembleLauncherBackend initialized in managed mode " + "(system='%s', comm='%s', executor='%s', nodes=%s)", + system_config.name, + launcher_config.comm_name, + launcher_config.task_executor_name, + len(self._orchestrator.nodes), + ) def submit(self, task: TaskSpec) -> Future: if not self._initialized or self._client is None: @@ -217,18 +252,19 @@ def shutdown(self) -> None: "Error tearing down EnsembleLauncher client.", exc_info=True ) - el_ok = True - if self._el is not None: + orchestrator_ok = True + if self._orchestrator is not None: try: - self._el.stop() - self._el = None + self._orchestrator.stop() + self._orchestrator = None except Exception: - el_ok = False + orchestrator_ok = False logger.warning( - "Error stopping EnsembleLauncher orchestrator.", exc_info=True + "Error stopping EnsembleLauncher orchestrator.", + 
exc_info=True, ) - if client_ok and el_ok: + if client_ok and orchestrator_ok: logger.info("EnsembleLauncherBackend shut down.") else: logger.warning( diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py new file mode 100644 index 00000000..e25b23f9 --- /dev/null +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -0,0 +1,339 @@ +"""Backend-aware FastMCP subclass for ChemGraph. + +:class:`CGFastMCP` extends :class:`FastMCP` with an execution backend. +Tools registered via :meth:`tool` are automatically submitted to the +backend as :class:`~chemgraph.execution.base.TaskSpec` instances — +the tool author writes a plain function and the framework handles +submission, future resolution, and async job tracking. + +Tools that do **not** need the backend (e.g. JSON loaders, plotting +utilities) should be registered with :meth:`add_tool` (inherited from +FastMCP) which bypasses the backend wrapper entirely. +""" + +import asyncio +import functools +import inspect +import logging +from typing import Any, Callable, Dict, Optional + +from mcp.server.fastmcp import FastMCP +from mcp.types import ToolAnnotations + +logger = logging.getLogger(__name__) + + +class CGFastMCP(FastMCP): + """FastMCP with an integrated execution backend. + + Parameters + ---------- + **kwargs + Forwarded to :class:`FastMCP` (``name``, ``instructions``, etc.). + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._backend = None + self._tracker = None + self._backend_kwargs: Optional[dict[str, Any]] = None + + # ── Backend lifecycle ─────────────────────────────────────────────── + + def init_backend(self, **kwargs: Any) -> None: + """Register backend configuration for lazy initialisation. + + The backend is not created until the first tool invocation, + so the MCP server can start accepting connections immediately. + All keyword arguments are forwarded to + :func:`~chemgraph.execution.config.get_backend`. 
+ """ + self._backend_kwargs = kwargs + self._register_job_tools() + logger.info("CGFastMCP backend configured (lazy init).") + + def _ensure_backend(self) -> None: + """Create the backend on first use.""" + if self._backend is not None: + return + if self._backend_kwargs is None: + raise RuntimeError( + "Backend not configured. Call init_backend() first." + ) + from chemgraph.execution import JobTracker, get_backend + + self._backend = get_backend(**self._backend_kwargs) + self._tracker = JobTracker() + logger.info( + "CGFastMCP backend initialised: %s", type(self._backend).__name__ + ) + + def shutdown_backend(self) -> None: + """Shut down the execution backend and release resources.""" + if self._backend is not None: + try: + self._backend.shutdown() + except Exception: + logger.warning("Error during backend shutdown.", exc_info=True) + self._backend = None + self._tracker = None + self._backend_kwargs = None + logger.info("CGFastMCP backend shut down.") + + # ── Job tracking tools ───────────────────────────────────────────── + + def _register_job_tools(self) -> None: + """Register job-management tools (status, results, cancel).""" + + @self.add_tool + def check_job_status(batch_id: str) -> dict: + """Check the status of a submitted job batch.""" + self._ensure_backend() + return self._tracker.get_status(batch_id) + + @self.add_tool + def get_job_results( + batch_id: str, include_partial: bool = False + ) -> dict: + """Retrieve results from a completed job batch.""" + self._ensure_backend() + return self._tracker.get_results( + batch_id, include_partial=include_partial + ) + + @self.add_tool + def list_jobs() -> list[dict]: + """List all tracked job batches.""" + self._ensure_backend() + batches = self._tracker.list_batches() + if not batches: + return [{"message": "No job batches tracked."}] + return batches + + @self.add_tool + def cancel_job(batch_id: str) -> dict: + """Cancel pending tasks in a job batch.""" + self._ensure_backend() + return 
self._tracker.cancel_batch(batch_id) + + @self.add_tool + def check_endpoint_status() -> dict: + """Check whether the remote compute endpoint is reachable.""" + self._ensure_backend() + if hasattr(self._backend, "check_endpoint_status"): + return self._backend.check_endpoint_status() + return {"status": "not_applicable", + "message": "This backend does not support endpoint status checks."} + + # ── Internal helpers ────────────────────────────────────────────── + + @staticmethod + def _fix_module_for_pickle(fn: Callable) -> None: + """Ensure *fn* is picklable when the MCP server runs as ``__main__``.""" + if fn.__module__ == "__main__": + import sys + + spec = getattr(sys.modules.get("__main__"), "__spec__", None) + if spec and spec.name: + fn.__module__ = spec.name + if spec.name not in sys.modules: + sys.modules[spec.name] = sys.modules["__main__"] + + # ── Tool registration ─────────────────────────────────────────────── + + def tool( + self, + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + structured_output: Optional[bool] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a tool that runs on the execution backend. + + Same calling convention as :meth:`FastMCP.tool` — **parens are + required** (``@mcp.tool()``, not ``@mcp.tool``). + + The additional parameters (``num_nodes``, ``processes_per_node``, + ``gpus_per_task``, ``env``, ``working_dir``) are forwarded to the + :class:`~chemgraph.execution.base.TaskSpec` that wraps the + decorated function when it is invoked. + + Parameters + ---------- + name, title, description, annotations, structured_output + Passed through to :meth:`FastMCP.add_tool`. + num_nodes : int + Number of compute nodes (default ``1``). 
+ processes_per_node : int + Processes per node (default ``1``). + gpus_per_task : int + GPUs per task (default ``0``). + env : dict, optional + Extra environment variables for the worker. + working_dir : str, optional + Working directory for the task. + """ + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if title is not None: + fastmcp_kwargs["title"] = title + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + if structured_output is not None: + fastmcp_kwargs["structured_output"] = structured_output + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + def decorator(fn: Callable) -> Callable: + wrapper = self._make_backend_wrapper(fn, task_spec_kwargs) + self.add_tool(wrapper, **fastmcp_kwargs) + return fn + + return decorator + + # ── Ensemble tool registration ───────────────────────────────────── + + def ensemble_tool( + self, + name: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a fan-out tool that submits ``list[params]`` to the backend. + + Decorates ``fn(params: Schema) -> result``. The MCP tool schema + becomes ``list[Schema]`` — the LLM provides a list of jobs and + the framework submits each as a + :class:`~chemgraph.execution.base.TaskSpec`, then gathers results + via :func:`~chemgraph.execution.utils.submit_or_gather`. 
+ + Parameters + ---------- + name, description, annotations + Passed through to :meth:`FastMCP.add_tool`. + num_nodes, processes_per_node, gpus_per_task, env, working_dir + Forwarded to :class:`~chemgraph.execution.base.TaskSpec`. + """ + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + + def decorator(fn: Callable) -> Callable: + self._fix_module_for_pickle(fn) + sig = inspect.signature(fn) + param = list(sig.parameters.values())[0] + param_type = param.annotation + + async def wrapper(params): + self._ensure_backend() + pending = [] + for i, p in enumerate(params): + task = TaskSpec( + task_id=f"{fn.__name__}_{i}", + task_type="python", + callable=fn, + kwargs={param.name: p}, + **task_spec_kwargs, + ) + fut = self._backend.submit(task) + pending.append(({"index": i}, fut)) + + return await submit_or_gather( + self._backend, + pending, + self._tracker, + name or fn.__name__, + ) + + wrapper.__name__ = name or fn.__name__ + wrapper.__doc__ = fn.__doc__ + wrapper.__module__ = fn.__module__ + wrapper.__qualname__ = fn.__qualname__ + + new_param = inspect.Parameter( + "params", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=list[param_type], + ) + wrapper.__signature__ = inspect.Signature( + parameters=[new_param] + ) + + self.add_tool(wrapper, **fastmcp_kwargs) + return fn + + return decorator + + # ── Internal ──────────────────────────────────────────────────────── + + def _make_backend_wrapper( + self, fn: 
Callable, task_spec_kwargs: dict[str, Any] + ) -> Callable: + """Build an async wrapper that submits *fn* to the backend.""" + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather + + self._fix_module_for_pickle(fn) + + @functools.wraps(fn) + async def wrapper(**kwargs: Any) -> Any: + self._ensure_backend() + task = TaskSpec( + task_id=fn.__name__, + task_type="python", + callable=fn, + kwargs=kwargs, + **task_spec_kwargs, + ) + fut = self._backend.submit(task) + + if self._backend.is_async_remote: + return await submit_or_gather( + self._backend, + [({"task_id": fn.__name__}, fut)], + self._tracker, + fn.__name__, + ) + + return await asyncio.wrap_future(fut) + + return wrapper diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index a664a1e7..70496186 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -1,46 +1,18 @@ """Backend-agnostic MACE MCP server. -Replaces ``mace_mcp_parsl.py`` by using the :mod:`chemgraph.execution` -abstraction layer. The execution backend (Parsl, EnsembleLauncher, -local) is selected at startup via ``config.toml`` or the -``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP` so that tool +functions are plain computation — the framework handles backend +submission, future resolution, and async job tracking. -Key improvements over the original: -- No hardcoded Polaris config or user-specific conda paths. -- Ensemble tool is now async (non-blocking event loop). -- Uses shared utilities for structure resolution and result gathering. +Nothing is initialised at import time so that worker subprocesses +(e.g. EnsembleLauncher) can safely re-import this module. 
""" -import asyncio -import json -import logging -import os -from pathlib import Path +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.schemas.mace_parsl_schema import mace_input_schema +from chemgraph.tools.parsl_tools import extract_output_json, run_mace_core -from mcp.server.fastmcp import FastMCP - -from chemgraph.execution import TaskSpec, get_backend -from chemgraph.execution.job_tracker import JobTracker -from chemgraph.execution.utils import ( - make_per_structure_output, - resolve_structure_files, - submit_or_gather, -) -from chemgraph.mcp.job_tools import register_job_tools -from chemgraph.mcp.server_utils import run_mcp_server -from chemgraph.tools.parsl_tools import ( - mace_input_schema, - mace_input_schema_ensemble, -) - -logger = logging.getLogger(__name__) - -# ── Initialise execution backend ──────────────────────────────────────── -backend = get_backend() -tracker = JobTracker() - -# ── MCP server ────────────────────────────────────────────────────────── -mcp = FastMCP( +mcp = CGFastMCP( name="ChemGraph MACE Tools", instructions=""" You expose tools for running MACE simulations and reading their results. @@ -66,216 +38,45 @@ check_job_status to poll for progress before calling get_job_results. """, ) -register_job_tools(mcp, tracker, backend) - - -def _run_mace_single(job: dict) -> dict: - """Execute a single MACE simulation (runs on the worker). - - When the ``job`` dict contains an ``inline_structure`` key (with - ``numbers``, ``positions``, and optional ``cell``/``pbc``), the - structure is materialised as a temporary XYZ file on the worker - filesystem before running MACE. This allows local-agent / - remote-worker workflows where the original file only exists on the - submitting machine. 
- """ - import os - import tempfile - - from chemgraph.tools.parsl_tools import mace_input_schema, run_mace_core - - inline = job.pop("inline_structure", None) - if inline is not None: - from ase import Atoms - from ase.io import write as ase_write - - atoms = Atoms( - numbers=inline["numbers"], - positions=inline["positions"], - cell=inline.get("cell"), - pbc=inline.get("pbc"), - ) - tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") - xyz_path = os.path.join(tmpdir, "structure.xyz") - ase_write(xyz_path, atoms) - job["input_structure_file"] = xyz_path - - if not os.path.isabs(job.get("output_result_file", "")): - job["output_result_file"] = os.path.join( - tmpdir, job.get("output_result_file", "output.json") - ) - - params = mace_input_schema(**job) if isinstance(job, dict) else job - result = run_mace_core(params) - - # Embed full output JSON when running with inline structure so the - # caller does not need to read a file on the remote filesystem. - if inline is not None: - out_file = job.get("output_result_file", "") - if os.path.isfile(out_file): - import json as _json - - with open(out_file, "r") as fh: - result["full_output"] = _json.load(fh) - - return result @mcp.tool( name="run_mace_single", description="Run a single MACE calculation", ) -async def run_mace_single(params: mace_input_schema): - """Run a single MACE calculation using the configured execution backend.""" - job = params.model_dump() - - # Read the local structure file and embed it so the job is - # self-contained and can run on any worker (local or remote). 
- input_file = job.get("input_structure_file") - if input_file and os.path.isfile(input_file): - from ase.io import read as ase_read - - from chemgraph.tools.ase_core import atoms_to_atomsdata - - atoms = ase_read(input_file) - atomsdata = atoms_to_atomsdata(atoms) - job["inline_structure"] = atomsdata.model_dump() - - task = TaskSpec( - task_id="mace_single", - task_type="python", - callable=_run_mace_single, - kwargs={"job": job}, - ) - fut = backend.submit(task) +def run_mace_single(params: mace_input_schema): + """Run a single MACE calculation on the execution backend.""" + import sys - if backend.is_async_remote: - task_meta = {"task_id": "mace_single"} - return await submit_or_gather( - backend, [(task_meta, fut)], tracker, "run_mace_single" - ) + old_stdout = sys.stdout + sys.stdout = sys.stderr + try: + return run_mace_core(params) + finally: + sys.stdout = old_stdout - return await asyncio.wrap_future(fut) - -def _mace_post_fn(meta: dict, result) -> dict: - """Post-process a completed MACE task.""" - status = result.get("status", "unknown") if isinstance(result, dict) else "success" - energy = result.get("single_point_energy") if isinstance(result, dict) else None - return { - "structure": meta["structure"], - "output_result_file": meta["output_result_file"], - "status": status, - "single_point_energy": energy, - "raw_result": result, - } - - -@mcp.tool( +@mcp.ensemble_tool( name="run_mace_ensemble", - description="Run an ensemble of MACE calculations", + description="Run an ensemble of MACE calculations for multiple inputs.", ) -async def run_mace_ensemble(params: mace_input_schema_ensemble): - """Run an ensemble of MACE calculations over all structure files in a - directory using the configured execution backend. - - Parameters - ---------- - params : mace_input_schema_ensemble - Input parameters for the ensemble of MACE calculations. - - Returns - ------- - dict - Summary of all jobs with minimal per-job results. 
- """ - structure_files, _output_dir = resolve_structure_files( - params.input_structure_directory, - ) - - # Base output file name used as a pattern for per-structure outputs - base_output = Path(params.output_result_file) - - pending_tasks = [] - for struct_path in structure_files: - per_struct_output = make_per_structure_output(struct_path, base_output) - - job = { - "input_structure_file": str(struct_path), - "output_result_file": str(per_struct_output), - "driver": params.driver, - "model": params.model, - "device": params.device, - "temperature": params.temperature, - "pressure": params.pressure, - "fmax": params.fmax, - "steps": params.steps, - "optimizer": params.optimizer, - } - - # Embed structure data so the job works on remote workers that - # cannot access the local filesystem. - if struct_path.is_file(): - from ase.io import read as ase_read +def _run_mace_worker(params: mace_input_schema): + return run_mace_core(params) - from chemgraph.tools.ase_core import atoms_to_atomsdata - atoms = ase_read(str(struct_path)) - atomsdata = atoms_to_atomsdata(atoms) - job["inline_structure"] = atomsdata.model_dump() - - task = TaskSpec( - task_id=f"mace_{struct_path.stem}", - task_type="python", - callable=_run_mace_single, - kwargs={"job": job}, - ) - fut = backend.submit(task) - - task_meta = { - "structure": struct_path.name, - "output_result_file": str(per_struct_output), - } - pending_tasks.append((task_meta, fut)) - - result = await submit_or_gather( - backend, pending_tasks, tracker, "run_mace_ensemble", - post_fn=_mace_post_fn, - ) - - if result["status"] == "completed": - return { - "status": "success", - "n_structures": len(structure_files), - "results": result["results"], - } - - # Async remote: return submission confirmation - result["n_structures"] = len(structure_files) - return result - - -@mcp.tool( +mcp.add_tool( + extract_output_json, name="extract_output_json", description="Load output from a JSON file.", ) -def extract_output_json(json_file: 
str) -> dict: - """Load simulation results from a JSON file produced by run_ase. - Parameters - ---------- - json_file : str - Path to the JSON file containing ASE simulation results. - Returns - ------- - dict - Parsed results from the JSON file. - """ - with open(json_file, "r") as f: - data = json.load(f) - return data +if __name__ == "__main__": + from chemgraph.mcp.server_utils import run_mcp_server + mcp.init_backend() -if __name__ == "__main__": - run_mcp_server(mcp, default_port=9004) + try: + run_mcp_server(mcp, default_port=9004) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/tools/parsl_tools.py b/src/chemgraph/tools/parsl_tools.py index 908ac29c..f2f10fce 100644 --- a/src/chemgraph/tools/parsl_tools.py +++ b/src/chemgraph/tools/parsl_tools.py @@ -6,23 +6,25 @@ from __future__ import annotations -from chemgraph.tools.ase_core import run_ase_core +import logging + from chemgraph.schemas.ase_input import ASEInputSchema from chemgraph.schemas.mace_parsl_schema import ( mace_input_schema, - mace_input_schema_ensemble, mace_output_schema, ) +from chemgraph.tools.ase_core import run_ase_core # Re-export schemas so existing ``from chemgraph.tools.parsl_tools import …`` # statements continue to work. __all__ = [ "mace_input_schema", - "mace_input_schema_ensemble", "mace_output_schema", "run_mace_core", + "extract_output_json", ] +logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Core execution — delegates to the unified implementation @@ -75,5 +77,17 @@ def run_mace_core(params: mace_input_schema) -> dict: dict Simulation result payload. 
""" - ase_params = _mace_input_to_ase_input(params) - return run_ase_core(ase_params) + try: + ase_params = _mace_input_to_ase_input(params) + return run_ase_core(ase_params) + except Exception as e: + print(f"Running ase failed with error:{e}") + return None + + +def extract_output_json(json_file: str) -> dict: + """Load simulation results from a JSON file produced by run_ase.""" + import json + + with open(json_file, "r") as f: + return json.load(f) From 6422b61d524869ad79f9c56f12ca6b17498bc38c Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:26:51 -0500 Subject: [PATCH 024/119] Fix PR #127 blockers: silent failure, decorator IndexError, hard EL import - parsl_tools.run_mace_core: stop swallowing exceptions and returning None. run_ase_core already returns a structured failure dict on simulation errors, and programmer errors should propagate. - cg_fastmcp.ensemble_tool: raise TypeError with a clear message when the decorated function does not have exactly one parameter, instead of crashing with IndexError at decoration time. - ensemble_launcher_backend: soft-import ensemble_launcher and defer the failure to construction / call time. SYSTEM_CONFIG_REGISTRY is now a lazy view backed by builder functions so the module loads cleanly without EL installed, restoring the deferred-error behaviour callers of chemgraph.execution.config expected. 
--- .../execution/ensemble_launcher_backend.py | 70 ++++++++++++++----- src/chemgraph/mcp/cg_fastmcp.py | 9 ++- src/chemgraph/tools/parsl_tools.py | 8 +-- 3 files changed, 63 insertions(+), 24 deletions(-) diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index e3863d38..56210d97 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -30,16 +30,31 @@ ) from ensemble_launcher.helper_functions import get_nodes from ensemble_launcher.orchestrator import ClusterClient -except ImportError as exc: - raise ImportError( - "EnsembleLauncher is required for the EnsembleLauncherBackend. " - "Install it with: pip install ensemble-launcher" - ) from exc + + _ENSEMBLE_LAUNCHER_AVAILABLE = True +except ImportError: + EnsembleLauncher = None + LauncherConfig = None + MPIConfig = None + PolicyConfig = None + SystemConfig = None + get_nodes = None + ClusterClient = None + _ENSEMBLE_LAUNCHER_AVAILABLE = False logger = logging.getLogger(__name__) +def _require_ensemble_launcher() -> None: + if not _ENSEMBLE_LAUNCHER_AVAILABLE: + raise ImportError( + "EnsembleLauncher is required for the EnsembleLauncherBackend. 
" + "Install it with: pip install ensemble-launcher" + ) + + def get_local_system_config(): + _require_ensemble_launcher() system_config = SystemConfig( name="local", ncpus=os.cpu_count(), @@ -49,6 +64,7 @@ def get_local_system_config(): def get_polaris_system_config(): + _require_ensemble_launcher() system_config = SystemConfig( name="polaris", ncpus=32, @@ -60,6 +76,7 @@ def get_polaris_system_config(): def get_aurora_system_config(): + _require_ensemble_launcher() system_config = SystemConfig( name="aurora", ncpus=102, @@ -73,10 +90,11 @@ def get_aurora_system_config(): def get_launcher_config( task_executor_name: Union[str, List] = "async_processpool", child_executor_policy: str = "fixed_leafs_children_policy", - policy_config: Optional[PolicyConfig] = None, + policy_config=None, checkpoint_dir=f"{os.getcwd()}/.ckpt_{uuid.uuid4().hex[:6]}", mpi_flavour: Literal["test", "mpich"] = "test", ): + _require_ensemble_launcher() if policy_config is None: policy_config = PolicyConfig(nlevels=2, leaf_nodes=len(get_nodes())) return LauncherConfig( @@ -112,9 +130,10 @@ class EnsembleLauncherBackend(ExecutionBackend): """ def __init__(self) -> None: + _require_ensemble_launcher() super().__init__() - self._orchestrator: Optional[EnsembleLauncher] = None - self._client: Optional[ClusterClient] = None + self._orchestrator = None + self._client = None def initialize( self, @@ -123,8 +142,8 @@ def initialize( client_only: bool = False, checkpoint_dir: Optional[str] = None, node_id: str = "global", - system_config: Optional[SystemConfig] = None, - launcher_config: Optional[LauncherConfig] = None, + system_config=None, + launcher_config=None, startup_delay: float = 10.0, **kwargs, ) -> None: @@ -273,12 +292,29 @@ def shutdown(self) -> None: ) -SYSTEM_CONFIG_REGISTRY = { - "local": get_local_system_config(), - "aurora": get_aurora_system_config(), - "polaris": get_polaris_system_config(), +_SYSTEM_CONFIG_BUILDERS = { + "local": get_local_system_config, + "aurora": 
get_aurora_system_config, + "polaris": get_polaris_system_config, } -if __name__ == "__main__": - el_backend = EnsembleLauncherBackend() - el_backend.initialize() + +class _LazyRegistry: + """Built-on-first-access mapping of system name -> SystemConfig. + + Avoids importing ``ensemble_launcher`` at module load time. + """ + + def __contains__(self, key: str) -> bool: + return key in _SYSTEM_CONFIG_BUILDERS + + def __getitem__(self, key: str): + if key not in _SYSTEM_CONFIG_BUILDERS: + raise KeyError(key) + return _SYSTEM_CONFIG_BUILDERS[key]() + + def keys(self): + return _SYSTEM_CONFIG_BUILDERS.keys() + + +SYSTEM_CONFIG_REGISTRY = _LazyRegistry() diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index e25b23f9..a343fa3a 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -260,7 +260,14 @@ def ensemble_tool( def decorator(fn: Callable) -> Callable: self._fix_module_for_pickle(fn) sig = inspect.signature(fn) - param = list(sig.parameters.values())[0] + params = list(sig.parameters.values()) + if len(params) != 1: + raise TypeError( + f"@ensemble_tool expects a function with exactly one " + f"parameter (the per-item schema), got {len(params)} " + f"on {fn.__qualname__}." + ) + param = params[0] param_type = param.annotation async def wrapper(params): diff --git a/src/chemgraph/tools/parsl_tools.py b/src/chemgraph/tools/parsl_tools.py index f2f10fce..86e43a7f 100644 --- a/src/chemgraph/tools/parsl_tools.py +++ b/src/chemgraph/tools/parsl_tools.py @@ -77,12 +77,8 @@ def run_mace_core(params: mace_input_schema) -> dict: dict Simulation result payload. 
""" - try: - ase_params = _mace_input_to_ase_input(params) - return run_ase_core(ase_params) - except Exception as e: - print(f"Running ase failed with error:{e}") - return None + ase_params = _mace_input_to_ase_input(params) + return run_ase_core(ase_params) def extract_output_json(json_file: str) -> dict: From a8a3a87ba507461c2cba8ce5d9e275c8c68602f9 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:33:52 -0500 Subject: [PATCH 025/119] Add JobTracker persistence and Globus task UUID round-trip - persist_file parameter: when set, batch metadata and Globus Compute task UUIDs are written to JSON after registration and after results are cached, and loaded on init. Allows MCP servers to recover job state across restarts. - TrackedTask.globus_task_id and TrackedTask.future are both optional; loaded-from-disk batches have no in-memory Future and are queried via the Globus Compute Client directly in get_status. - Lazy Globus Compute Client with a separate gc_lock for thread safety. - _wait_for_globus_task_ids polls each ComputeFuture briefly after submission to capture the Globus task_id assigned asynchronously by the Executor background thread. - cancel_batch / cleanup_old_batches handle the no-future case. --- src/chemgraph/execution/job_tracker.py | 237 +++++++++++++++++++++++-- 1 file changed, 226 insertions(+), 11 deletions(-) diff --git a/src/chemgraph/execution/job_tracker.py b/src/chemgraph/execution/job_tracker.py index 87b473c0..23f6c837 100644 --- a/src/chemgraph/execution/job_tracker.py +++ b/src/chemgraph/execution/job_tracker.py @@ -7,16 +7,23 @@ Each MCP server process creates its own ``JobTracker`` instance (mirroring the existing ``backend = get_backend()`` pattern). + +When a *persist_file* is provided, batch metadata and Globus Compute +task UUIDs are written to a JSON file so that a future session can +reload them and query Globus Compute directly for results. 
""" from __future__ import annotations +import json import logging import threading +import time import uuid from concurrent.futures import Future from dataclasses import dataclass, field from datetime import datetime, timezone +from pathlib import Path from typing import Any, Callable, Optional logger = logging.getLogger(__name__) @@ -28,7 +35,8 @@ class TrackedTask: task_id: str meta: dict - future: Future + future: Optional[Future] = None + globus_task_id: Optional[str] = None result: Optional[dict] = None @@ -47,11 +55,112 @@ class JobTracker: """Track submitted job batches and their futures. Thread-safe: all public methods acquire an internal lock. + + Parameters + ---------- + persist_file : Path or str, optional + Path to a JSON file for persisting batch metadata across + sessions. When set, batches are saved after registration and + after results are cached. On init, existing batches are loaded. """ - def __init__(self) -> None: + def __init__(self, persist_file: Optional[Path | str] = None) -> None: self._batches: dict[str, TrackedBatch] = {} self._lock = threading.Lock() + self._gc_lock = threading.Lock() + self._persist_file = Path(persist_file) if persist_file else None + self._gc_client = None # lazily initialised Globus Compute Client + + if self._persist_file is not None: + self._load() + + # ── Globus Compute client (lazy) ────────────────────────────────── + + def _get_gc_client(self): + """Return a Globus Compute ``Client`` (created once, reused).""" + if self._gc_client is not None: + return self._gc_client + with self._gc_lock: + if self._gc_client is None: + try: + from globus_compute_sdk import Client + + self._gc_client = Client() + except Exception: + logger.warning( + "Could not create Globus Compute Client", + exc_info=True, + ) + return None + return self._gc_client + + # ── persistence ─────────────────────────────────────────────────── + + def _save(self) -> None: + """Write current batch metadata to *persist_file*.""" + if 
self._persist_file is None: + return + + data: dict[str, Any] = {} + with self._lock: + for bid, batch in self._batches.items(): + data[bid] = { + "tool_name": batch.tool_name, + "submitted_at": batch.submitted_at.isoformat(), + "tasks": [ + { + "task_id": t.task_id, + "meta": t.meta, + "globus_task_id": t.globus_task_id, + "result": t.result, + } + for t in batch.tasks + ], + } + + self._persist_file.parent.mkdir(parents=True, exist_ok=True) + tmp = self._persist_file.with_suffix(".tmp") + with open(tmp, "w") as f: + json.dump(data, f, indent=2) + tmp.replace(self._persist_file) + + def _load(self) -> None: + """Load batch metadata from *persist_file* (if it exists).""" + if self._persist_file is None or not self._persist_file.is_file(): + return + + try: + with open(self._persist_file) as f: + data = json.load(f) + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Could not load job tracker state: %s", exc) + return + + with self._lock: + for bid, info in data.items(): + if bid in self._batches: + continue # don't overwrite live batches + + tasks = [ + TrackedTask( + task_id=t["task_id"], + meta=t.get("meta", {}), + future=None, + globus_task_id=t.get("globus_task_id"), + result=t.get("result"), + ) + for t in info.get("tasks", []) + ] + self._batches[bid] = TrackedBatch( + batch_id=bid, + tool_name=info["tool_name"], + submitted_at=datetime.fromisoformat(info["submitted_at"]), + tasks=tasks, + ) + + logger.info( + "Loaded %d batches from %s", len(data), self._persist_file + ) # ── registration ─────────────────────────────────────────────────── @@ -103,13 +212,61 @@ def register_batch( tool_name, len(tracked), ) + + # Wait briefly for the Executor background thread to set task_ids + # on the ComputeFutures. Typically takes ~1-2 s; we cap at 3 s + # so the MCP tool response isn't delayed excessively. 
+ self._wait_for_globus_task_ids(tracked, timeout=3.0) + self._save() return batch_id + def _wait_for_globus_task_ids( + self, tasks: list[TrackedTask], timeout: float = 3.0 + ) -> None: + """Wait up to *timeout* seconds for Globus ``task_id`` to appear + on each ComputeFuture, then store them for persistence.""" + deadline = time.monotonic() + timeout + pending = [t for t in tasks if t.future is not None and t.globus_task_id is None] + + while pending and time.monotonic() < deadline: + still_pending = [] + for t in pending: + gc_id = getattr(t.future, "task_id", None) + if gc_id is not None: + t.globus_task_id = str(gc_id) + else: + still_pending.append(t) + pending = still_pending + if pending: + time.sleep(0.25) + + if pending: + logger.debug( + "%d tasks did not receive a Globus task_id within %.1fs", + len(pending), + timeout, + ) + + def _try_capture_globus_task_ids(self, tasks: list[TrackedTask]) -> bool: + """Non-blocking: extract ``task_id`` from any ComputeFuture that + has one available. Returns True if any new IDs were captured.""" + captured = False + for t in tasks: + if t.globus_task_id is None and t.future is not None: + gc_id = getattr(t.future, "task_id", None) + if gc_id is not None: + t.globus_task_id = str(gc_id) + captured = True + return captured + # ── status ───────────────────────────────────────────────────────── def get_status(self, batch_id: str) -> dict: """Return the current status of a batch. + For tasks loaded from disk (no in-memory ``Future``), queries + Globus Compute directly if a ``globus_task_id`` is available. + Returns ------- dict @@ -125,11 +282,16 @@ def get_status(self, batch_id: str) -> dict: total = len(batch.tasks) done = 0 failed = 0 + # Lazily capture Globus Compute task UUIDs (set asynchronously + # by the Executor background thread after submission). 
+ dirty = self._try_capture_globus_task_ids(batch.tasks) for t in batch.tasks: - if t.future.done(): - done += 1 - # Cache the result on first check + task_done = False + + # --- live future path --- + if t.future is not None and t.future.done(): + task_done = True if t.result is None: try: raw = t.future.result(timeout=0) @@ -152,8 +314,55 @@ def get_status(self, batch_id: str) -> dict: "error_type": type(e).__name__, "message": str(e), } - if t.result.get("status") == "failure": - failed += 1 + dirty = True + + # --- loaded-from-disk path (no future, use Globus client) --- + elif t.future is None and t.result is None and t.globus_task_id: + gc = self._get_gc_client() + if gc is not None: + try: + task_info = gc.get_task(t.globus_task_id) + if not task_info.get("pending", True): + task_done = True + if "result" in task_info: + raw = task_info["result"] + if isinstance(raw, dict): + merged = {**t.meta, **raw} + merged.setdefault("status", "success") + t.result = merged + else: + t.result = { + **t.meta, + "result": raw, + "status": "success", + } + elif "exception" in task_info: + t.result = { + **t.meta, + "status": "failure", + "error_type": "RemoteException", + "message": str(task_info["exception"]), + } + dirty = True + except Exception as e: + logger.warning( + "Failed to query Globus task %s: %s", + t.globus_task_id, + e, + exc_info=True, + ) + + # --- already have a cached result --- + elif t.result is not None: + task_done = True + + if task_done: + done += 1 + if t.result is not None and t.result.get("status") == "failure": + failed += 1 + + if dirty: + self._save() pending = total - done if pending == total: @@ -259,7 +468,9 @@ def cancel_batch(self, batch_id: str) -> dict: cancelled = 0 already_done = 0 for t in batch.tasks: - if t.future.done(): + if t.future is None: + already_done += 1 + elif t.future.done(): already_done += 1 elif t.future.cancel(): cancelled += 1 @@ -284,13 +495,17 @@ def cleanup(self, max_age_hours: float = 24) -> int: with 
self._lock: for bid, batch in self._batches.items(): age_hours = (now - batch.submitted_at).total_seconds() / 3600 - if age_hours > max_age_hours and all( - t.future.done() for t in batch.tasks - ): + all_done = all( + (t.future is not None and t.future.done()) + or t.result is not None + for t in batch.tasks + ) + if age_hours > max_age_hours and all_done: to_remove.append(bid) for bid in to_remove: del self._batches[bid] if to_remove: logger.info("Cleaned up %d old batches", len(to_remove)) + self._save() return len(to_remove) From 78c9c33790e68b421bfe9fb892742409bbc8a680 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:42:03 -0500 Subject: [PATCH 026/119] Extend CGFastMCP: tracker kwargs, pre-submit hook, schema_fanout_tool - init_backend now accepts tracker_kwargs= and forwards it to JobTracker(...) in _ensure_backend. Callers can pass persist_file= so MCP servers recover job state across restarts. - set_pre_submit_hook(hook): hook receives each TaskSpec before backend.submit() and returns a (possibly mutated) one. Lets a server centralise transport concerns -- inline-structure embedding for local-submit-to-remote-worker, remote-path rewriting -- instead of repeating that logic in every tool body. Wired into the @tool, @ensemble_tool, and @schema_fanout_tool submit paths. - @schema_fanout_tool(worker=...): the decorated function is an expander (ensemble schema -> list of per-item args). The framework calls worker(item) on the backend for each item and gathers results. Preserves the ensemble schema as the agent-facing API (one tool call, server-side fanout), complementing @ensemble_tool which exposes list[Schema] for callers that want client-side enumeration. 
--- src/chemgraph/mcp/cg_fastmcp.py | 172 +++++++++++++++++++++++++++++++- 1 file changed, 168 insertions(+), 4 deletions(-) diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index a343fa3a..3a84c9f7 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -37,18 +37,33 @@ def __init__(self, **kwargs: Any) -> None: self._backend = None self._tracker = None self._backend_kwargs: Optional[dict[str, Any]] = None + self._tracker_kwargs: dict[str, Any] = {} + self._pre_submit_hook: Optional[Callable] = None # ── Backend lifecycle ─────────────────────────────────────────────── - def init_backend(self, **kwargs: Any) -> None: + def init_backend( + self, + *, + tracker_kwargs: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: """Register backend configuration for lazy initialisation. The backend is not created until the first tool invocation, so the MCP server can start accepting connections immediately. - All keyword arguments are forwarded to - :func:`~chemgraph.execution.config.get_backend`. + + Parameters + ---------- + tracker_kwargs : dict, optional + Forwarded to :class:`~chemgraph.execution.job_tracker.JobTracker` + on first use. Use this to pass ``persist_file`` for cross-session + job state recovery. + **kwargs + Forwarded to :func:`~chemgraph.execution.config.get_backend`. 
""" self._backend_kwargs = kwargs + self._tracker_kwargs = tracker_kwargs or {} self._register_job_tools() logger.info("CGFastMCP backend configured (lazy init).") @@ -63,7 +78,7 @@ def _ensure_backend(self) -> None: from chemgraph.execution import JobTracker, get_backend self._backend = get_backend(**self._backend_kwargs) - self._tracker = JobTracker() + self._tracker = JobTracker(**self._tracker_kwargs) logger.info( "CGFastMCP backend initialised: %s", type(self._backend).__name__ ) @@ -78,8 +93,31 @@ def shutdown_backend(self) -> None: self._backend = None self._tracker = None self._backend_kwargs = None + self._tracker_kwargs = {} logger.info("CGFastMCP backend shut down.") + # ── Pre-submit transport hook ────────────────────────────────────── + + def set_pre_submit_hook(self, hook: Optional[Callable]) -> None: + """Register a hook that transforms each TaskSpec before submission. + + The hook receives the :class:`~chemgraph.execution.base.TaskSpec` + and must return one (possibly the same instance). Used for + transport concerns that should apply to every backend-submitted + tool on this server -- e.g. embedding a local structure file + into ``kwargs`` so a remote worker can materialise it, or + rewriting a local path to a pre-staged remote path. + + Pass ``None`` to clear the hook. 
+ """ + self._pre_submit_hook = hook + + def _apply_pre_submit_hook(self, task): + """Run the registered pre-submit hook (no-op when unset).""" + if self._pre_submit_hook is None: + return task + return self._pre_submit_hook(task) + # ── Job tracking tools ───────────────────────────────────────────── def _register_job_tools(self) -> None: @@ -281,6 +319,7 @@ async def wrapper(params): kwargs={param.name: p}, **task_spec_kwargs, ) + task = self._apply_pre_submit_hook(task) fut = self._backend.submit(task) pending.append(({"index": i}, fut)) @@ -310,6 +349,130 @@ async def wrapper(params): return decorator + # ── Schema-driven fanout tool ────────────────────────────────────── + + def schema_fanout_tool( + self, + *, + worker: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a fan-out tool driven by a single *ensemble* schema. + + The decorated function is an **expander**: it receives the + ensemble schema and returns a list of per-item arguments. The + framework calls ``worker(item)`` on the backend for each item, + gathers the results, and returns a batch summary -- same shape + as :meth:`ensemble_tool`. + + Unlike :meth:`ensemble_tool` (whose tool signature is + ``list[Schema]``), this preserves the ensemble schema as the + agent-facing API, so the LLM makes a single tool call against + e.g. ``input_structure_directory`` and server-side expansion + produces the per-file jobs. + + Parameters + ---------- + worker : Callable + The per-item function executed on the backend. Must take + a single positional argument (the item produced by the + expander). + name, description, annotations + Passed through to :meth:`FastMCP.add_tool`. 
+ num_nodes, processes_per_node, gpus_per_task, env, working_dir + Forwarded to each :class:`~chemgraph.execution.base.TaskSpec`. + """ + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + + # Worker is what actually runs on the backend, so it must be + # picklable from the MCP server's __main__ module. + self._fix_module_for_pickle(worker) + + worker_sig = inspect.signature(worker) + worker_params = list(worker_sig.parameters.values()) + if len(worker_params) != 1: + raise TypeError( + f"schema_fanout_tool worker must take exactly one " + f"parameter, got {len(worker_params)} on " + f"{worker.__qualname__}." + ) + worker_param_name = worker_params[0].name + + def decorator(expander: Callable) -> Callable: + sig = inspect.signature(expander) + params = list(sig.parameters.values()) + if len(params) != 1: + raise TypeError( + f"@schema_fanout_tool expander must take exactly one " + f"parameter (the ensemble schema), got {len(params)} " + f"on {expander.__qualname__}." 
+ ) + param = params[0] + tool_name = name or expander.__name__ + + async def wrapper(**kwargs): + self._ensure_backend() + ensemble_params = kwargs[param.name] + items = expander(ensemble_params) + pending = [] + for i, item in enumerate(items): + task = TaskSpec( + task_id=f"{tool_name}_{i}", + task_type="python", + callable=worker, + kwargs={worker_param_name: item}, + **task_spec_kwargs, + ) + task = self._apply_pre_submit_hook(task) + fut = self._backend.submit(task) + pending.append(({"index": i}, fut)) + + return await submit_or_gather( + self._backend, + pending, + self._tracker, + tool_name, + ) + + wrapper.__name__ = tool_name + wrapper.__doc__ = expander.__doc__ + wrapper.__module__ = expander.__module__ + wrapper.__qualname__ = expander.__qualname__ + # Preserve the expander's signature so FastMCP advertises the + # ensemble schema to the LLM, not the worker's per-item one. + wrapper.__signature__ = sig + + self.add_tool(wrapper, **fastmcp_kwargs) + return expander + + return decorator + # ── Internal ──────────────────────────────────────────────────────── def _make_backend_wrapper( @@ -331,6 +494,7 @@ async def wrapper(**kwargs: Any) -> Any: kwargs=kwargs, **task_spec_kwargs, ) + task = self._apply_pre_submit_hook(task) fut = self._backend.submit(task) if self._backend.is_async_remote: From 2d6a28341dcbfc4789c54b5b5ab7d553c0e0691c Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:49:14 -0500 Subject: [PATCH 027/119] Add remote_structure_directory schemas, GC executor recovery, XANES persistence - mace_input_schema_ensemble / graspa_input_schema_ensemble: new remote_structure_directory field for pre-staged HPC files (paired with the upcoming transfer_files tool). input_structure_directory now defaults to empty string so callers can pass either. - mace_input_schema/_ensemble model description spells out that 'mace_mp' is the calculator type, not a model name -- LLMs were confusing the two. 
- Nullable schema fields (driver, model, wall_time) typed as str|None / float|None for correct OpenAPI schema generation. - GlobusComputeBackend._ensure_executor re-creates the Executor when it has been shut down (e.g. after a remote task failure). Uses getattr() so we don't depend on the SDK's private _stopped attr existing. - check_endpoint_status logs exc_info on failure for easier debugging. - xanes_mcp_hpc: JobTracker(persist_file=~/.chemgraph/xanes_jobs.json) so XANES job state survives MCP server restarts. Instructions updated to tell the LLM to surface batch_ids to the user. --- pyproject.toml | 6 +++ .../execution/globus_compute_backend.py | 14 +++++++ src/chemgraph/mcp/xanes_mcp_hpc.py | 12 ++++-- src/chemgraph/schemas/graspa_schema.py | 11 +++++- src/chemgraph/schemas/mace_parsl_schema.py | 37 ++++++++++++++----- 5 files changed, 67 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 28896e0f..428edf96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,9 @@ ensemble_launcher = [ globus_compute = [ "globus-compute-sdk", ] +academy = [ + "academy-py", +] xanes = [ "mp-api; python_version >= '3.11'", "parsl" @@ -115,6 +118,9 @@ testpaths = ["tests"] markers = [ "llm: marks tests as requiring LLM API access (run with --run-llm)", "globus_compute: marks tests requiring a live Globus Compute endpoint (run with --run-globus-compute)", + "parsl: marks tests requiring a live Parsl deployment (run with --run-parsl)", + "ensemble_launcher: marks tests requiring a live EnsembleLauncher deployment (run with --run-ensemble-launcher)", + "academy: marks tests requiring Academy agent infrastructure (run with --run-academy)", "asyncio: marks async tests", ] filterwarnings = [ diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py index 2ec2bba1..f73bd5af 100644 --- a/src/chemgraph/execution/globus_compute_backend.py +++ b/src/chemgraph/execution/globus_compute_backend.py 
@@ -93,12 +93,23 @@ def initialize(self, system: str = "local", **kwargs: Any) -> None: # ── task submission ───────────────────────────────────────────────── + def _ensure_executor(self) -> None: + """Re-create the Executor if it was shut down (e.g. after a + task failure).""" + from globus_compute_sdk import Executor + + if self._executor is None or getattr(self._executor, "_stopped", False): + logger.info("Re-creating Globus Compute Executor") + self._executor = Executor(endpoint_id=self._endpoint_id) + def submit(self, task: TaskSpec) -> Future: if not self._initialized or self._executor is None: raise RuntimeError( "GlobusComputeBackend is not initialized. Call initialize() first." ) + self._ensure_executor() + if task.task_type == "python": if task.callable is None: raise ValueError( @@ -142,6 +153,9 @@ def check_endpoint_status(self) -> dict: "status": status, } except Exception as e: + logger.warning( + "Endpoint status check failed: %s", e, exc_info=True, + ) return { "endpoint_id": self._endpoint_id, "status": "error", diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 4b0d219b..4abb94e0 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -30,7 +30,9 @@ # ── Initialise execution backend ──────────────────────────────────────── backend = get_backend() -tracker = JobTracker() + +_jobs_file = Path("~/.chemgraph/xanes_jobs.json").expanduser() +tracker = JobTracker(persist_file=_jobs_file) # ── MCP server ────────────────────────────────────────────────────────── mcp = FastMCP( @@ -54,8 +56,12 @@ - Keep responses compact -- full results are in the output directories. - When returning paths, use absolute paths. - Energies are in eV. - - When a tool returns status='submitted' with a batch_id, use - check_job_status to poll for progress before calling get_job_results. + - When a tool returns status='submitted' with a batch_id, call + get_job_results(batch_id) to retrieve results. 
If the job is + still pending, report the batch_id to the user so they can + check later. Job state is persisted across sessions -- the + user can call list_jobs or get_job_results in a future session + to retrieve results. """, ) register_job_tools(mcp, tracker, backend) diff --git a/src/chemgraph/schemas/graspa_schema.py b/src/chemgraph/schemas/graspa_schema.py index 9cd08231..996ec12b 100644 --- a/src/chemgraph/schemas/graspa_schema.py +++ b/src/chemgraph/schemas/graspa_schema.py @@ -46,7 +46,16 @@ class graspa_input_schema(BaseModel): class graspa_input_schema_ensemble(BaseModel): input_structures: Union[str, list[str]] = Field( - description="Path to a directory of CIF files OR a specific list of file paths." + default="", + description="Path to a directory of CIF files OR a specific list of file paths. Required unless remote_structure_directory is provided.", + ) + remote_structure_directory: str | None = Field( + default=None, + description=( + "Path to pre-staged CIF files on the remote HPC filesystem. " + "When provided, workers read structures directly from this path. " + "Use the transfer_files tool to stage files first." + ), ) output_result_file: str = Field( default="raspa.log", diff --git a/src/chemgraph/schemas/mace_parsl_schema.py b/src/chemgraph/schemas/mace_parsl_schema.py index e04ddba6..63dc8008 100644 --- a/src/chemgraph/schemas/mace_parsl_schema.py +++ b/src/chemgraph/schemas/mace_parsl_schema.py @@ -17,14 +17,18 @@ class mace_input_schema(BaseModel): default="output.json", description="Path to a JSON file where simulation results will be saved.", ) - driver: str = Field( + driver: str | None = Field( default=None, description="Specifies the type of simulation to run. 
Options: 'energy' for single-point energy calculations, 'opt' for geometry optimization, 'vib' for vibrational frequency analysis, and 'thermo' for thermochemical properties (including enthalpy, entropy, and Gibbs free energy).", ) model: str = Field( default="medium-mpa-0", - description="Path to the model. Default is medium-mpa-0." - "Options are 'small', 'medium', 'large', 'small-0b', 'medium-0b', 'small-0b2', 'medium-0b2','large-0b2', 'medium-0b3', 'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', 'mace-matpes-r2scan-0'", + description="MACE foundation model name (NOT the calculator type). " + "Options: 'small', 'medium', 'large', 'small-0b', 'medium-0b', " + "'small-0b2', 'medium-0b2', 'large-0b2', 'medium-0b3', " + "'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', " + "'mace-matpes-r2scan-0'. Default is 'medium-mpa-0'. " + "Do NOT pass 'mace_mp' — that is the calculator type, not a model name.", ) device: str = Field( default="cpu", @@ -54,20 +58,35 @@ class mace_input_schema(BaseModel): class mace_input_schema_ensemble(BaseModel): input_structure_directory: str = Field( - description="Path to a folder of input structures containing the atomic structure for the simulations." + default="", + description="Path to a local folder of input structures. Required unless remote_structure_directory is provided.", + ) + remote_structure_directory: str | None = Field( + default=None, + description=( + "Path to pre-staged structure files on the remote HPC filesystem. " + "When provided, workers read structures directly from this path " + "instead of using inline structure embedding. Use the " + "transfer_files tool to stage files first, then pass the " + "remote directory here." + ), ) output_result_file: str = Field( default="output.json", description="Path to a JSON file where simulation results will be saved.", ) - driver: str = Field( + driver: str | None = Field( default=None, description="Specifies the type of simulation to run. 
Options: 'energy' for single-point energy calculations, 'opt' for geometry optimization, 'vib' for vibrational frequency analysis, and 'thermo' for thermochemical properties (including enthalpy, entropy, and Gibbs free energy).", ) model: str = Field( default="medium-mpa-0", - description="Path to the model. Default is medium-mpa-0." - "Options are 'small', 'medium', 'large', 'small-0b', 'medium-0b', 'small-0b2', 'medium-0b2','large-0b2', 'medium-0b3', 'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', 'mace-matpes-r2scan-0'", + description="MACE foundation model name (NOT the calculator type). " + "Options: 'small', 'medium', 'large', 'small-0b', 'medium-0b', " + "'small-0b2', 'medium-0b2', 'large-0b2', 'medium-0b3', " + "'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', " + "'mace-matpes-r2scan-0'. Default is 'medium-mpa-0'. " + "Do NOT pass 'mace_mp' — that is the calculator type, not a model name.", ) device: str = Field( default="cpu", @@ -102,7 +121,7 @@ class mace_output_schema(BaseModel): output_result_file: str = Field( description="Path to a JSON file where simulation results is saved.", ) - model: str = Field( + model: str | None = Field( default=None, description="Path to the model. Default is medium-mpa-0." ) device: str = Field( @@ -143,7 +162,7 @@ class mace_output_schema(BaseModel): default="", description="Error captured during the simulation", ) - wall_time: float = Field( + wall_time: float | None = Field( default=None, description="Total wall time (in seconds) taken to complete the simulation.", ) From f1863d365bf6cb15bddf7db7b2f04afe9982fd26 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 12:55:27 -0500 Subject: [PATCH 028/119] Add Globus Transfer manager and MCP file-staging tools - execution/globus_transfer.py: GlobusTransferManager wraps the globus_sdk TransferClient with token caching, batched transfer_files / wait_for_transfer / check_transfer_status / list_remote_directory. Lazy globus_sdk import, lazy auth. 
- execution/config.get_transfer_manager(): builds a manager from [execution.globus_transfer] in config.toml with env var overrides (GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID, _DESTINATION_ENDPOINT_ID, _DESTINATION_BASE_PATH). Returns None when not configured so MCP servers can skip registration silently. - mcp/transfer_tools.register_transfer_tools(): registers transfer_files, check_transfer_status, list_remote_files on a FastMCP/CGFastMCP server. Uses mcp.add_tool() (not the backend-submitting @tool() decorator) because these are orchestration tools, not compute tasks -- they call the Globus Transfer API directly from the MCP server process. - get_backend() globus_compute endpoint_id fallback now treats empty-string endpoint_id as unset, matching the GLOBUS_COMPUTE_ ENDPOINT_ID env-var override behaviour. --- src/chemgraph/execution/config.py | 62 +++- src/chemgraph/execution/globus_transfer.py | 325 +++++++++++++++++++++ src/chemgraph/mcp/transfer_tools.py | 186 ++++++++++++ 3 files changed, 572 insertions(+), 1 deletion(-) create mode 100644 src/chemgraph/execution/globus_transfer.py create mode 100644 src/chemgraph/mcp/transfer_tools.py diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py index fb4fac4b..dc650a26 100644 --- a/src/chemgraph/execution/config.py +++ b/src/chemgraph/execution/config.py @@ -125,7 +125,7 @@ def get_backend( merged_kwargs = {**backend_cfg, **kwargs} # Globus Compute: fall back to GLOBUS_COMPUTE_ENDPOINT_ID env var - if resolved_backend == "globus_compute" and "endpoint_id" not in merged_kwargs: + if resolved_backend == "globus_compute" and not merged_kwargs.get("endpoint_id"): env_id = os.getenv("GLOBUS_COMPUTE_ENDPOINT_ID") if env_id: merged_kwargs["endpoint_id"] = env_id @@ -183,3 +183,63 @@ def get_backend( backend.initialize(system=resolved_system, **merged_kwargs) return backend + + +def get_transfer_manager( + config_path: Optional[str] = None, + **kwargs: Any, +): + """Create a 
:class:`GlobusTransferManager` from config, or ``None``. + + Reads the ``[execution.globus_transfer]`` section from + ``config.toml``. Returns ``None`` when the required endpoint IDs + are not configured, so callers can skip transfer-tool registration. + + Environment variable overrides + ------------------------------ + ``GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID`` + ``GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID`` + ``GLOBUS_TRANSFER_DESTINATION_BASE_PATH`` + """ + cfg = _load_execution_config(config_path) + transfer_cfg = cfg.get("globus_transfer", {}) + merged = {**transfer_cfg, **kwargs} + + for key, env_var in ( + ("source_endpoint_id", "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID"), + ("destination_endpoint_id", "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID"), + ("destination_base_path", "GLOBUS_TRANSFER_DESTINATION_BASE_PATH"), + ): + if not merged.get(key): + env_val = os.getenv(env_var) + if env_val: + merged[key] = env_val + + required = ( + "source_endpoint_id", + "destination_endpoint_id", + "destination_base_path", + ) + if not all(merged.get(k) for k in required): + logger.debug( + "Globus Transfer not configured (missing %s). 
" + "Transfer tools will not be registered.", + [k for k in required if not merged.get(k)], + ) + return None + + from chemgraph.execution.globus_transfer import GlobusTransferManager + + manager = GlobusTransferManager( + source_endpoint_id=merged["source_endpoint_id"], + destination_endpoint_id=merged["destination_endpoint_id"], + destination_base_path=merged["destination_base_path"], + source_base_path=merged.get("source_base_path"), + client_id=merged.get("client_id"), + ) + logger.info( + "GlobusTransferManager created: %s -> %s", + merged["source_endpoint_id"], + merged["destination_endpoint_id"], + ) + return manager diff --git a/src/chemgraph/execution/globus_transfer.py b/src/chemgraph/execution/globus_transfer.py new file mode 100644 index 00000000..d8081ab3 --- /dev/null +++ b/src/chemgraph/execution/globus_transfer.py @@ -0,0 +1,325 @@ +"""Globus Transfer file-staging manager. + +Transfers files between a local Globus collection and a remote HPC +collection using the `Globus Transfer API +`_. This avoids encoding large +input files (e.g. atomic structures) inside Globus Compute function +payloads. + +**Prerequisites** + +1. Install ``globus_sdk`` (already a core dependency). +2. Have *Globus Connect Personal* running on the submitting machine + **or** use a managed Globus endpoint. +3. Configure endpoint IDs and base path in ``config.toml``:: + + [execution.globus_transfer] + source_endpoint_id = "" + destination_endpoint_id = "" + destination_base_path = "/eagle/projects/MyProject/staging" +""" + +from __future__ import annotations + +import logging +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional + +logger = logging.getLogger(__name__) + +# Globus Transfer API scope +TRANSFER_SCOPE = "urn:globus:auth:scope:transfer.api.globus.org:all" + +# Default Globus native-app client ID (Globus Tutorial client). 
+# Projects should register their own app at https://app.globus.org. +_DEFAULT_CLIENT_ID = "61338d24-54d5-408f-a10d-66c06b59f6d2" + + +@dataclass +class TransferResult: + """Metadata returned after submitting a Globus Transfer task.""" + + task_id: str + source_endpoint_id: str + destination_endpoint_id: str + file_mapping: dict[str, str] # local_path -> remote_path + remote_directory: str + submitted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + label: str = "" + + +class GlobusTransferManager: + """Manage file transfers between local and remote Globus collections. + + Parameters + ---------- + source_endpoint_id : str + UUID of the Globus collection on the submitting machine. + destination_endpoint_id : str + UUID of the Globus collection on the HPC system. + destination_base_path : str + Root directory on the destination where staged files are placed. + Each transfer batch creates a subdirectory underneath. + source_base_path : str, optional + If provided, local paths are resolved relative to this directory. + client_id : str, optional + Globus app client ID for OAuth. Defaults to the Globus Tutorial + client. 
+ """ + + def __init__( + self, + source_endpoint_id: str, + destination_endpoint_id: str, + destination_base_path: str, + source_base_path: Optional[str] = None, + client_id: Optional[str] = None, + ) -> None: + self.source_endpoint_id = source_endpoint_id + self.destination_endpoint_id = destination_endpoint_id + self.destination_base_path = destination_base_path.rstrip("/") + self.source_base_path = source_base_path + self._client_id = client_id or _DEFAULT_CLIENT_ID + self._transfer_client = None + + # ── authentication ────────────────────────────────────────────────── + + def _get_transfer_client(self): + """Lazily create an authenticated ``TransferClient``.""" + if self._transfer_client is not None: + return self._transfer_client + + try: + import globus_sdk + except ImportError as exc: + raise ImportError( + "globus_sdk is required for Globus Transfer. " + "Install it with: pip install globus-sdk" + ) from exc + + client = globus_sdk.NativeAppAuthClient(self._client_id) + client.oauth2_start_flow( + requested_scopes=TRANSFER_SCOPE, + refresh_tokens=True, + ) + + # Try loading cached tokens first + token_file = ( + Path.home() / ".globus" / "chemgraph_transfer_tokens.json" + ) + tokens = self._load_tokens(token_file) + + if tokens is None: + # Interactive login required + authorize_url = client.oauth2_get_authorize_url() + logger.info( + "Globus Transfer authentication required.\n" + "Go to this URL and login:\n %s", + authorize_url, + ) + print( + "\nGlobus Transfer authentication required.\n" + f"Go to this URL and login:\n {authorize_url}\n" + ) + auth_code = input("Enter the authorization code: ").strip() + token_response = client.oauth2_exchange_code_for_tokens(auth_code) + tokens = token_response.by_resource_server["transfer.api.globus.org"] + self._save_tokens(token_file, tokens) + else: + # Refresh if expired + if tokens.get("expires_at_seconds", 0) < time.time(): + try: + token_response = client.oauth2_refresh_tokens( + 
globus_sdk.RefreshTokenAuthorizer( + tokens["refresh_token"], client + ) + ) + tokens = token_response.by_resource_server[ + "transfer.api.globus.org" + ] + self._save_tokens(token_file, tokens) + except Exception: + logger.warning( + "Token refresh failed, falling back to existing token." + ) + + authorizer = globus_sdk.AccessTokenAuthorizer(tokens["access_token"]) + self._transfer_client = globus_sdk.TransferClient(authorizer=authorizer) + return self._transfer_client + + @staticmethod + def _load_tokens(path: Path) -> Optional[dict]: + if not path.is_file(): + return None + import json + + try: + with open(path) as f: + return json.load(f) + except (json.JSONDecodeError, KeyError): + return None + + @staticmethod + def _save_tokens(path: Path, tokens: dict) -> None: + import json + + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w") as f: + json.dump(dict(tokens), f, indent=2) + path.chmod(0o600) + + # ── transfers ─────────────────────────────────────────────────────── + + def transfer_files( + self, + local_paths: list[str], + remote_subdir: Optional[str] = None, + label: Optional[str] = None, + ) -> TransferResult: + """Submit a Globus Transfer task to stage files on the remote endpoint. + + Parameters + ---------- + local_paths : list[str] + Absolute paths to local files to transfer. + remote_subdir : str, optional + Subdirectory name under ``destination_base_path``. A UUID-based + name is generated if omitted. + label : str, optional + Human-readable label for the transfer task. + + Returns + ------- + TransferResult + Metadata including the Globus task ID and local-to-remote + path mapping. 
+ """ + import globus_sdk + + tc = self._get_transfer_client() + + if remote_subdir is None: + remote_subdir = f"batch_{uuid.uuid4().hex[:12]}" + + remote_dir = f"{self.destination_base_path}/{remote_subdir}" + transfer_label = label or f"ChemGraph file staging ({remote_subdir})" + + tdata = globus_sdk.TransferData( + tc, + self.source_endpoint_id, + self.destination_endpoint_id, + label=transfer_label, + sync_level="checksum", + ) + + file_mapping: dict[str, str] = {} + for local_path in local_paths: + p = Path(local_path).resolve() + remote_path = f"{remote_dir}/{p.name}" + tdata.add_item(str(p), remote_path) + file_mapping[str(p)] = remote_path + + result = tc.submit_transfer(tdata) + task_id = result["task_id"] + + logger.info( + "Globus Transfer submitted: task_id=%s, %d files -> %s", + task_id, + len(local_paths), + remote_dir, + ) + + return TransferResult( + task_id=task_id, + source_endpoint_id=self.source_endpoint_id, + destination_endpoint_id=self.destination_endpoint_id, + file_mapping=file_mapping, + remote_directory=remote_dir, + label=transfer_label, + ) + + def check_transfer_status(self, task_id: str) -> dict[str, Any]: + """Check the status of a Globus Transfer task. + + Returns + ------- + dict + Keys: ``task_id``, ``status``, ``nice_status``, ``bytes_transferred``, + ``files``, ``files_transferred``. + """ + tc = self._get_transfer_client() + task = tc.get_task(task_id) + return { + "task_id": task_id, + "status": task["status"], + "nice_status": task.get("nice_status", ""), + "bytes_transferred": task.get("bytes_transferred", 0), + "files": task.get("files", 0), + "files_transferred": task.get("files_transferred", 0), + } + + def wait_for_transfer( + self, + task_id: str, + timeout: float = 300, + poll_interval: float = 5, + ) -> dict[str, Any]: + """Block until a transfer completes, fails, or times out. + + Parameters + ---------- + timeout : float + Maximum seconds to wait (default 300). 
+ poll_interval : float + Seconds between status checks (default 5). + + Returns + ------- + dict + Final transfer status. + """ + deadline = time.time() + timeout + while time.time() < deadline: + status = self.check_transfer_status(task_id) + if status["status"] in ("SUCCEEDED", "FAILED"): + return status + time.sleep(poll_interval) + + status = self.check_transfer_status(task_id) + status["timed_out"] = True + return status + + def list_remote_directory(self, path: str) -> list[dict[str, Any]]: + """List files in a directory on the destination endpoint. + + Returns + ------- + list[dict] + Each dict has ``name``, ``type`` ("file" or "dir"), and ``size``. + """ + tc = self._get_transfer_client() + entries = [] + for entry in tc.operation_ls(self.destination_endpoint_id, path=path): + entries.append( + { + "name": entry["name"], + "type": entry["type"], + "size": entry.get("size", 0), + } + ) + return entries + + def get_remote_path( + self, + local_path: str, + remote_subdir: Optional[str] = None, + ) -> str: + """Compute the remote path for a local file.""" + filename = Path(local_path).name + if remote_subdir: + return f"{self.destination_base_path}/{remote_subdir}/{filename}" + return f"{self.destination_base_path}/{filename}" diff --git a/src/chemgraph/mcp/transfer_tools.py b/src/chemgraph/mcp/transfer_tools.py new file mode 100644 index 00000000..79ae2323 --- /dev/null +++ b/src/chemgraph/mcp/transfer_tools.py @@ -0,0 +1,186 @@ +"""Shared MCP tools for Globus Transfer file staging. + +Call :func:`register_transfer_tools` to add ``transfer_files``, +``check_transfer_status``, and ``list_remote_files`` to any +:class:`~mcp.server.fastmcp.FastMCP` (or +:class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`) server instance. + +These tools allow an LLM agent to stage input files on a remote HPC +filesystem *before* submitting compute jobs, avoiding the overhead of +encoding large files inside Globus Compute function payloads. 
+ +Note +---- +Transfer tools are orchestration tools (they call the Globus Transfer +API directly from the MCP server process), not compute tools, so they +are registered via :meth:`FastMCP.add_tool` rather than CGFastMCP's +backend-submitting ``@tool()`` decorator. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from mcp.server.fastmcp import FastMCP + + from chemgraph.execution.globus_transfer import GlobusTransferManager + +logger = logging.getLogger(__name__) + + +def register_transfer_tools( + mcp: FastMCP, + transfer_manager: GlobusTransferManager, +) -> None: + """Register file-transfer MCP tools on *mcp*. + + Parameters + ---------- + mcp : FastMCP + The MCP server to register tools on. May be a plain ``FastMCP`` + or a :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`; ``add_tool`` + is inherited so the same registration works either way. + transfer_manager : GlobusTransferManager + The configured transfer manager instance. + """ + + def transfer_files( + source_paths: Union[str, list[str]], + extensions: Optional[list[str]] = None, + remote_subdir: Optional[str] = None, + wait: bool = True, + label: Optional[str] = None, + ) -> dict: + """Transfer files to the remote HPC endpoint via Globus Transfer. + + Parameters + ---------- + source_paths : str or list[str] + A directory path (all matching files transferred) or a list + of individual file paths. + extensions : list[str], optional + When *source_paths* is a directory, only transfer files with + these extensions (e.g. ``[".cif", ".xyz"]``). Ignored when + *source_paths* is a list. + remote_subdir : str, optional + Subdirectory name on the remote endpoint. Auto-generated if + omitted. + wait : bool + If True (default), block until the transfer completes. + label : str, optional + Human-readable label for the transfer task. 
+ """ + if isinstance(source_paths, str): + src = Path(source_paths) + if src.is_dir(): + if extensions: + ext_set = { + e if e.startswith(".") else f".{e}" for e in extensions + } + files = sorted( + str(f) + for f in src.iterdir() + if f.is_file() and f.suffix.lower() in ext_set + ) + else: + files = sorted( + str(f) for f in src.iterdir() if f.is_file() + ) + if not files: + return { + "status": "error", + "message": f"No files found in {source_paths}" + + ( + f" with extensions {extensions}" + if extensions + else "" + ), + } + elif src.is_file(): + files = [str(src.resolve())] + else: + return { + "status": "error", + "message": f"Path not found: {source_paths}", + } + else: + files = [str(Path(p).resolve()) for p in source_paths] + + transfer_result = transfer_manager.transfer_files( + local_paths=files, + remote_subdir=remote_subdir, + label=label, + ) + + response = { + "task_id": transfer_result.task_id, + "remote_directory": transfer_result.remote_directory, + "file_count": len(files), + "file_mapping": transfer_result.file_mapping, + } + + if wait: + status = transfer_manager.wait_for_transfer(transfer_result.task_id) + response["status"] = ( + "completed" + if status["status"] == "SUCCEEDED" + else status["status"] + ) + response.update( + { + k: status[k] + for k in ("bytes_transferred", "files_transferred") + if k in status + } + ) + else: + response["status"] = "submitted" + + return response + + def check_transfer_status(task_id: str) -> dict: + """Check the status of a Globus Transfer task. + + Use to poll a non-blocking transfer submitted with ``wait=False``. + """ + return transfer_manager.check_transfer_status(task_id) + + def list_remote_files(remote_path: str) -> list[dict]: + """List files in a directory on the remote HPC endpoint. + + Useful to verify that files were staged correctly before + running ensemble calculations. 
+ """ + return transfer_manager.list_remote_directory(remote_path) + + mcp.add_tool( + transfer_files, + name="transfer_files", + description=( + "Transfer local files to the remote HPC filesystem via " + "Globus Transfer. Use this to pre-stage structure files " + "before running ensemble calculations with " + "remote_structure_directory. Returns the remote directory " + "path and a mapping of local-to-remote file paths." + ), + ) + mcp.add_tool( + check_transfer_status, + name="check_transfer_status", + description=( + "Check the status of a Globus Transfer task. Use this to " + "poll a non-blocking transfer submitted with wait=False." + ), + ) + mcp.add_tool( + list_remote_files, + name="list_remote_files", + description=( + "List files in a directory on the remote HPC endpoint. " + "Useful to verify that files were staged correctly before " + "running ensemble calculations." + ), + ) From 890acaae8679d7a2e6faaa8546bda01c26e741bf Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 13:06:35 -0500 Subject: [PATCH 029/119] Reintegrate MACE MCP transport, persistence, and Globus Transfer on CGFastMCP run_mace_single and run_mace_ensemble were collapsed to bare run_mace_core(params) calls in PR #127, dropping inline-structure embedding, remote-path support, JobTracker persistence, and the Globus Transfer registration that 51ba171 had built. This restores all of that on top of the new CGFastMCP framework. - Worker is now a separate function _mace_worker(job: dict) that handles two transport keys on the worker FS: remote_structure_file (use the path directly) and inline_structure (materialise an AtomsData dict to a temp XYZ). Embeds full_output back into the result for inline calls so callers do not need remote FS access. - Pre-submit hook _mace_transport_hook centralises the schema -> job-dict conversion, mace_mp -> medium-mpa-0 model normalisation, and inline embedding (when the input file exists on the submitting host). 
Hook rewrites task.callable from run_mace_single to _mace_worker so the LLM still sees a clean schema-shaped tool. - run_mace_ensemble switches to @schema_fanout_tool with a server-side expander, preserving the directory-driven UX (single LLM call instead of N). Local mode enumerates files via resolve_structure_files; remote mode submits a backend probe to ls remote_structure_directory and builds remote_structure_file per item. - extract_output_json registered via mcp.add_tool() (orchestration, no backend wrap). transfer_files/check_transfer_status/ list_remote_files registered conditionally when get_transfer_manager() finds [execution.globus_transfer] config. - __main__ now wires tracker_kwargs={persist_file: ~/.chemgraph/ mace_jobs.json} so MACE batches survive MCP server restarts. - Drop `from __future__ import annotations`: forward refs break FastMCP's signature introspection because the wrapper's __globals__ is cg_fastmcp's, not the tool module's. --- src/chemgraph/mcp/mace_mcp_hpc.py | 310 ++++++++++++++++++++++++++---- 1 file changed, 276 insertions(+), 34 deletions(-) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index 70496186..58750b46 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -1,80 +1,322 @@ """Backend-agnostic MACE MCP server. -Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP` so that tool -functions are plain computation — the framework handles backend -submission, future resolution, and async job tracking. +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. -Nothing is initialised at import time so that worker subprocesses -(e.g. EnsembleLauncher) can safely re-import this module. +Transport (local-file embedding, pre-staged remote-path passthrough) +lives in a single pre-submit hook so the tool bodies stay simple. 
The +hook rewrites :class:`~chemgraph.execution.base.TaskSpec` instances +before submission to attach an inline structure when the input file +exists on the submitting host, leaving the path untouched when it +does not (assumed to be remote). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. """ +import logging +import os +from pathlib import Path + +from chemgraph.execution.base import TaskSpec +from chemgraph.execution.config import get_transfer_manager +from chemgraph.execution.utils import ( + make_per_structure_output, + resolve_structure_files, +) from chemgraph.mcp.cg_fastmcp import CGFastMCP -from chemgraph.schemas.mace_parsl_schema import mace_input_schema +from chemgraph.mcp.transfer_tools import register_transfer_tools +from chemgraph.schemas.mace_parsl_schema import ( + mace_input_schema, + mace_input_schema_ensemble, +) from chemgraph.tools.parsl_tools import extract_output_json, run_mace_core +logger = logging.getLogger(__name__) + +_JOBS_FILE = Path("~/.chemgraph/mace_jobs.json").expanduser() +_MACE_MP_ALIASES = {"mace_mp", "mace-mp", "MACE-MP", "mace_MP"} + mcp = CGFastMCP( name="ChemGraph MACE Tools", instructions=""" You expose tools for running MACE simulations and reading their results. The available tools are: - 1. run_mace_single: run a single MACE calculation using the specified - input schema. - 2. run_mace_ensemble: run MACE calculations over all structures in a - directory using the configured execution backend. + 1. run_mace_single: run a single MACE calculation. + 2. run_mace_ensemble: run MACE calculations over every structure in a + directory (local or pre-staged remote). 3. extract_output_json: load simulation results from a JSON file. - 4. check_job_status: check progress of a submitted HPC job batch. - 5. get_job_results: retrieve results from a completed job batch. - 6. list_jobs: list all tracked job batches. - 7. 
cancel_job: cancel pending tasks in a job batch. + 4. check_job_status / get_job_results / list_jobs / cancel_job: HPC + job batch management. Job state persists across sessions. + 5. transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on the + remote HPC filesystem before running ensembles in remote mode. Guidelines: - Use each tool only when its input schema matches the user request. - - Do not guess numerical values; report tool errors exactly as they occur. - - Keep responses compact -- full results are written to the output files - defined in the schemas. + - Do not guess numerical values; report tool errors exactly as they + occur. + - Keep responses compact -- full results are written to the output + files defined in the schemas. - When returning paths, use absolute paths. - Energies are in eV and wall times are in seconds. - - When a tool returns status='submitted' with a batch_id, use - check_job_status to poll for progress before calling get_job_results. + - When a tool returns status='submitted' with a batch_id, call + get_job_results(batch_id) to retrieve results. If still pending, + report the batch_id so the user can check later -- job state is + persisted across sessions. + - For the `model` field, pass a MACE foundation model name (e.g. + 'medium-mpa-0'). 'mace_mp' is the calculator type, not a model + name -- do not pass it. """, ) +# ── Worker (runs on the backend) ─────────────────────────────────────── + + +def _mace_worker(job: dict) -> dict: + """Execute a single MACE simulation on a backend worker. + + Accepts a *job dict* (not the schema) so the pre-submit hook can + attach transport keys ``inline_structure`` / ``remote_structure_file`` + before submission. + """ + import json + import tempfile + + job = dict(job) + + # Pre-staged remote file: use the path directly on the worker FS. 
+ remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "output.json"), + ) + + # Inline structure: materialise on the worker's filesystem. + inline = job.pop("inline_structure", None) + if inline is not None: + from ase import Atoms + from ase.io import write as ase_write + + atoms = Atoms( + numbers=inline["numbers"], + positions=inline["positions"], + cell=inline.get("cell"), + pbc=inline.get("pbc"), + ) + tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") + xyz_path = os.path.join(tmpdir, "structure.xyz") + ase_write(xyz_path, atoms) + job["input_structure_file"] = xyz_path + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + tmpdir, job.get("output_result_file", "output.json") + ) + + params = mace_input_schema(**job) + result = run_mace_core(params) + + # When inline, embed full output so the caller doesn't need to read + # a file on the remote filesystem to recover the results. 
+ if inline is not None and isinstance(result, dict): + out_file = job.get("output_result_file", "") + if os.path.isfile(out_file): + with open(out_file) as fh: + result["full_output"] = json.load(fh) + + return result + + +# ── Pre-submit transport hook ────────────────────────────────────────── + + +def _embed_inline_if_local(job: dict) -> None: + """Mutate *job* in-place: attach inline_structure when the input + file is readable on the submitting host (and no other transport + key has already been set).""" + if job.get("remote_structure_file") or job.get("inline_structure"): + return + input_file = job.get("input_structure_file") + if not input_file or not os.path.isfile(input_file): + return # remote path -- worker will read it directly + + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(input_file) + job["inline_structure"] = atoms_to_atomsdata(atoms).model_dump() + + +def _normalize_model(job: dict) -> None: + """Map calculator-type aliases to a valid foundation model name.""" + if job.get("model") in _MACE_MP_ALIASES: + job["model"] = "medium-mpa-0" + + +def _mace_transport_hook(task: TaskSpec) -> TaskSpec: + """Route single-tool calls to the dict-based worker and embed + local structures on whichever path is taken.""" + if task.callable is run_mace_single: + params = task.kwargs.get("params") + if params is None: + return task + job = ( + params.model_dump() if hasattr(params, "model_dump") else dict(params) + ) + _normalize_model(job) + _embed_inline_if_local(job) + task.callable = _mace_worker + task.kwargs = {"job": job} + elif task.callable is _mace_worker: + job = dict(task.kwargs.get("job", {})) + _normalize_model(job) + _embed_inline_if_local(job) + task.kwargs = {"job": job} + return task + + +mcp.set_pre_submit_hook(_mace_transport_hook) + + +# ── Single-structure tool ────────────────────────────────────────────── + + @mcp.tool( name="run_mace_single", description="Run a single 
MACE calculation", ) -def run_mace_single(params: mace_input_schema): - """Run a single MACE calculation on the execution backend.""" - import sys +def run_mace_single(params: mace_input_schema) -> dict: + """Run a single MACE calculation on the configured backend. - old_stdout = sys.stdout - sys.stdout = sys.stderr - try: - return run_mace_core(params) - finally: - sys.stdout = old_stdout + The pre-submit hook rewrites this call to invoke ``_mace_worker`` + on the backend with a job dict that may carry an embedded inline + structure (when the input file exists locally) or a remote path + (when it does not). + """ + # Direct-call fallback path (no hook registered) -- normalises and + # delegates to the same worker. + job = params.model_dump() + _normalize_model(job) + return _mace_worker(job) + + +# ── Ensemble fanout ──────────────────────────────────────────────────── + + +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) + ) -@mcp.ensemble_tool( +def _expand_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: + """Server-side expansion of an ensemble request into per-file jobs. + + Local mode: enumerates ``input_structure_directory`` on this host. + Remote mode: submits a one-shot probe task to the backend to list + files under ``remote_structure_directory``, then builds per-file + jobs that the worker reads directly from the remote filesystem. 
+ """ + shared = { + "output_result_file": params.output_result_file, + "driver": params.driver, + "model": params.model, + "device": params.device, + "temperature": params.temperature, + "pressure": params.pressure, + "fmax": params.fmax, + "steps": params.steps, + "optimizer": params.optimizer, + } + base_output = Path(params.output_result_file) + + if params.remote_structure_directory: + remote_dir = params.remote_structure_directory + mcp._ensure_backend() + probe = TaskSpec( + task_id="ls_remote_dir", + task_type="python", + callable=_ls_remote_files, + kwargs={"path": remote_dir}, + ) + fut = mcp._backend.submit(probe) + try: + file_names = fut.result(timeout=30) + except Exception as exc: + raise RuntimeError( + f"Could not list remote directory {remote_dir}: {exc}" + ) from exc + + jobs = [] + for fname in file_names: + per_output = make_per_structure_output(Path(fname), base_output) + job = {**shared} + job["remote_structure_file"] = f"{remote_dir}/{fname}" + job["output_result_file"] = str(per_output) + jobs.append(job) + return jobs + + if not params.input_structure_directory: + raise ValueError( + "Either input_structure_directory or remote_structure_directory " + "must be provided." + ) + + structure_files, _ = resolve_structure_files(params.input_structure_directory) + return [ + { + **shared, + "input_structure_file": str(f), + "output_result_file": str(make_per_structure_output(f, base_output)), + } + for f in structure_files + ] + + +@mcp.schema_fanout_tool( name="run_mace_ensemble", - description="Run an ensemble of MACE calculations for multiple inputs.", + description=( + "Run MACE calculations over every structure in a directory. " + "Local mode uses input_structure_directory; remote mode uses " + "remote_structure_directory (pre-stage files first with " + "transfer_files)." 
+ ), + worker=_mace_worker, ) -def _run_mace_worker(params: mace_input_schema): - return run_mace_core(params) +def run_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: + return _expand_mace_ensemble(params) + + +# ── Orchestration tools (no backend involvement) ─────────────────────── mcp.add_tool( extract_output_json, name="extract_output_json", - description="Load output from a JSON file.", + description="Load simulation results from an output JSON file.", ) +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on MACE MCP server.") + + if __name__ == "__main__": from chemgraph.mcp.server_utils import run_mcp_server - mcp.init_backend() + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) try: run_mcp_server(mcp, default_port=9004) From ee3d727fc783a6c296d8ab561d0c6901b25198c8 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 13:58:39 -0500 Subject: [PATCH 030/119] Add academy module: distributed multi-agent screening via Academy Wraps the Academy distributed agent framework with ChemGraph LLM agents for federated HPC screening workflows. Decoupled from the existing pipeline -- no chemgraph.cli / chemgraph.agent / chemgraph.eval references; only the lazily imported chemgraph.agent.llm_agent.ChemGraph. - ChemGraphAgent: Academy Agent wrapping a single ChemGraph instance, exposes run_query / get_info actions. - ScreeningAgent: iterates a molecule batch, writes per-result JSONs for fault-tolerant aggregation. Failed-molecule records now store str(exc) so the actual exception message survives. - CoordinatorAgent: polls a results dir, optionally analyses results via an LLM, suggests follow-up molecules. 
- AcademyConfig + build_manager: bridge config.toml to Academy Manager / Exchange / Launcher (local, Redis, Parsl, Globus Compute). - RateLimiter: stdlib async token-bucket for shared per-provider LLM quotas across agents. Lazy imports in __init__.py let the package load without the optional academy-py dependency; ChemGraphAgent / ScreeningAgent / CoordinatorAgent raise ModuleNotFoundError on access if academy-py is missing, while AcademyConfig and RateLimiter remain usable. pyproject's academy optional-dep + pytest marker are already in HEAD (commit 04bcc8a). tests/test_academy.py and scripts/academy_example/ remain untracked and will land in follow-ups. --- src/chemgraph/academy/__init__.py | 45 +++++++ src/chemgraph/academy/agent.py | 123 ++++++++++++++++++ src/chemgraph/academy/config.py | 175 +++++++++++++++++++++++++ src/chemgraph/academy/coordinator.py | 179 ++++++++++++++++++++++++++ src/chemgraph/academy/rate_limiter.py | 135 +++++++++++++++++++ src/chemgraph/academy/screening.py | 151 ++++++++++++++++++++++ 6 files changed, 808 insertions(+) create mode 100644 src/chemgraph/academy/__init__.py create mode 100644 src/chemgraph/academy/agent.py create mode 100644 src/chemgraph/academy/config.py create mode 100644 src/chemgraph/academy/coordinator.py create mode 100644 src/chemgraph/academy/rate_limiter.py create mode 100644 src/chemgraph/academy/screening.py diff --git a/src/chemgraph/academy/__init__.py b/src/chemgraph/academy/__init__.py new file mode 100644 index 00000000..90e5bf12 --- /dev/null +++ b/src/chemgraph/academy/__init__.py @@ -0,0 +1,45 @@ +"""Academy Agents integration for ChemGraph. + +Provides agent classes and utilities for deploying ChemGraph workflows +across federated HPC infrastructure using the Academy framework. 
+ +Requires the ``academy`` optional extra:: + + pip install chemgraphagent[academy] + +Modules that depend on ``academy-py`` (agent, screening, coordinator) +use lazy imports so that the rate limiter and config utilities remain +usable without the optional dependency. +""" + +from __future__ import annotations + +from chemgraph.academy.config import AcademyConfig, build_manager +from chemgraph.academy.rate_limiter import RateLimiter + + +def __getattr__(name: str): # noqa: N807 + """Lazy-import Academy-dependent classes.""" + if name == "ChemGraphAgent": + from chemgraph.academy.agent import ChemGraphAgent + + return ChemGraphAgent + if name == "ScreeningAgent": + from chemgraph.academy.screening import ScreeningAgent + + return ScreeningAgent + if name == "CoordinatorAgent": + from chemgraph.academy.coordinator import CoordinatorAgent + + return CoordinatorAgent + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "ChemGraphAgent", + "AcademyConfig", + "build_manager", + "RateLimiter", + "ScreeningAgent", + "CoordinatorAgent", +] diff --git a/src/chemgraph/academy/agent.py b/src/chemgraph/academy/agent.py new file mode 100644 index 00000000..1ec04b3e --- /dev/null +++ b/src/chemgraph/academy/agent.py @@ -0,0 +1,123 @@ +"""Base Academy Agent wrapping a ChemGraph instance. + +Each ``ChemGraphAgent`` holds one ``ChemGraph`` object and exposes its +``run()`` method as an Academy ``@action`` so it can be invoked remotely +by peer agents, coordinators, or the Manager user handle. +""" + +from __future__ import annotations + +import logging +import os +import uuid +from typing import Any, Optional + +from academy.agent import Agent, action + +from chemgraph.agent.llm_agent import ChemGraph + +logger = logging.getLogger(__name__) + + +class ChemGraphAgent(Agent): + """Academy Agent wrapping a single :class:`ChemGraph` instance. + + Parameters + ---------- + model_name : str + LLM model to use (e.g. 
``"gpt-4o"``, ``"claude-sonnet-4"``). + workflow_type : str + ChemGraph workflow (e.g. ``"single_agent"``, ``"multi_agent"``). + log_dir : str or None + Base directory for agent logs. A per-agent subdirectory is + created automatically. + rate_limiter : RateLimiter or None + Shared rate limiter for LLM API calls. + chemgraph_kwargs : dict + Extra keyword arguments forwarded to the :class:`ChemGraph` + constructor (e.g. ``base_url``, ``api_key``, ``recursion_limit``). + """ + + def __init__( + self, + model_name: str = "gpt-4o-mini", + workflow_type: str = "single_agent", + log_dir: Optional[str] = None, + rate_limiter: Any = None, + **chemgraph_kwargs: Any, + ) -> None: + super().__init__() + self._model_name = model_name + self._workflow_type = workflow_type + self._log_dir = log_dir + self._rate_limiter = rate_limiter + self._chemgraph_kwargs = chemgraph_kwargs + self._cg: Optional[ChemGraph] = None + self._agent_uuid = uuid.uuid4().hex[:8] + + async def agent_on_startup(self) -> None: + """Initialise the ChemGraph instance on the remote worker.""" + agent_log_dir = self._log_dir + if agent_log_dir: + agent_log_dir = os.path.join(agent_log_dir, self._agent_uuid) + os.makedirs(agent_log_dir, exist_ok=True) + + self._cg = ChemGraph( + model_name=self._model_name, + workflow_type=self._workflow_type, + log_dir=agent_log_dir, + enable_memory=False, + **self._chemgraph_kwargs, + ) + logger.info( + "ChemGraphAgent %s started: model=%s workflow=%s", + self._agent_uuid, + self._model_name, + self._workflow_type, + ) + + async def agent_on_shutdown(self) -> None: + """Clean up resources.""" + logger.info("ChemGraphAgent %s shutting down", self._agent_uuid) + self._cg = None + + @action + async def run_query( + self, + query: str, + *, + config: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Execute a ChemGraph query and return the result. + + Parameters + ---------- + query : str + The natural-language chemistry query. 
+ config : dict, optional + LangGraph config (thread_id, etc.). + + Returns + ------- + dict + The workflow result (serialised state or last message, + depending on the ChemGraph ``return_option``). + """ + if self._cg is None: + raise RuntimeError("Agent not initialised (call agent_on_startup first)") + + if self._rate_limiter is not None: + await self._rate_limiter.acquire(self._model_name) + + thread_cfg = config or {"configurable": {"thread_id": uuid.uuid4().hex[:8]}} + result = await self._cg.run(query=query, config=thread_cfg) + return result + + @action + async def get_info(self) -> dict[str, str]: + """Return metadata about this agent instance.""" + return { + "agent_uuid": self._agent_uuid, + "model_name": self._model_name, + "workflow_type": self._workflow_type, + } diff --git a/src/chemgraph/academy/config.py b/src/chemgraph/academy/config.py new file mode 100644 index 00000000..5f7a98b3 --- /dev/null +++ b/src/chemgraph/academy/config.py @@ -0,0 +1,175 @@ +"""Bridge between ChemGraph config.toml and Academy Manager/Exchange/Launcher. + +Reads the ``[academy]`` section from ``config.toml`` and builds the +corresponding Academy objects. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any, Literal, Optional + +import toml + +logger = logging.getLogger(__name__) + +# Exchange and launcher types supported by this bridge. +ExchangeType = Literal["local", "redis", "hybrid"] +LauncherType = Literal["thread", "process", "parsl", "globus_compute"] + + +@dataclass +class AcademyConfig: + """Parsed ``[academy]`` configuration section. + + Attributes + ---------- + exchange : ExchangeType + Message exchange backend (default ``"local"``). + launcher : LauncherType + Agent deployment mechanism (default ``"thread"``). + num_agents : int + Number of worker agents to spawn (default ``1``). + redis_hostname : str + Redis host when ``exchange="redis"`` (default ``"localhost"``). 
+ redis_port : int + Redis port (default ``6379``). + parsl_system : str + HPC system name for Parsl config (default ``"local"``). + globus_endpoint_id : str + Globus Compute endpoint UUID. + max_concurrency : int + Max concurrent LLM calls per provider (default ``50``). + log_dir : str or None + Base log directory for agent output. + extra : dict + Any additional keys from the config section. + """ + + exchange: ExchangeType = "local" + launcher: LauncherType = "thread" + num_agents: int = 1 + redis_hostname: str = "localhost" + redis_port: int = 6379 + parsl_system: str = "local" + globus_endpoint_id: str = "" + max_concurrency: int = 50 + log_dir: Optional[str] = None + extra: dict = field(default_factory=dict) + + +def load_academy_config(config_path: str = "config.toml") -> AcademyConfig: + """Load the ``[academy]`` section from a TOML config file. + + Missing keys are filled with defaults. Unknown keys are stored + in ``extra``. + """ + try: + data = toml.load(config_path) + except FileNotFoundError: + logger.warning("Config file %s not found, using defaults", config_path) + return AcademyConfig() + + section = data.get("academy", {}) + + known_keys = {f.name for f in AcademyConfig.__dataclass_fields__.values()} + known = {k: v for k, v in section.items() if k in known_keys} + extra = {k: v for k, v in section.items() if k not in known_keys} + + return AcademyConfig(**known, extra=extra) + + +def _build_exchange_factory(cfg: AcademyConfig) -> Any: + """Create the Academy ExchangeFactory matching the config.""" + if cfg.exchange == "local": + from academy.exchange import LocalExchangeFactory + + return LocalExchangeFactory() + + if cfg.exchange == "redis": + from academy.exchange import RedisExchangeFactory + + return RedisExchangeFactory( + hostname=cfg.redis_hostname, + port=cfg.redis_port, + ) + + if cfg.exchange == "hybrid": + from academy.exchange import HybridExchangeFactory + + return HybridExchangeFactory() + + raise ValueError(f"Unsupported exchange 
type: {cfg.exchange}") + + +def _build_executor(cfg: AcademyConfig) -> Any: + """Create the executor matching the configured launcher type.""" + if cfg.launcher == "thread": + from concurrent.futures import ThreadPoolExecutor + + return ThreadPoolExecutor(max_workers=cfg.num_agents) + + if cfg.launcher == "process": + from concurrent.futures import ProcessPoolExecutor + + return ProcessPoolExecutor(max_workers=cfg.num_agents) + + if cfg.launcher == "parsl": + try: + from academy.executor import ParslExecutor + except ImportError as exc: + raise ImportError( + "Parsl launcher requires: pip install chemgraphagent[academy,parsl]" + ) from exc + return ParslExecutor() + + if cfg.launcher == "globus_compute": + try: + from academy.executor import GlobusComputeExecutor + except ImportError as exc: + raise ImportError( + "Globus Compute launcher requires: " + "pip install chemgraphagent[academy,globus_compute]" + ) from exc + return GlobusComputeExecutor(endpoint_id=cfg.globus_endpoint_id) + + raise ValueError(f"Unsupported launcher type: {cfg.launcher}") + + +async def build_manager( + cfg: AcademyConfig | None = None, + config_path: str = "config.toml", +) -> Any: + """Build an Academy Manager from ChemGraph config. + + Returns an async context manager. Usage:: + + async with await build_manager() as manager: + handle = await manager.launch(ScreeningAgent, ...) + result = await handle.screen_molecule("CCO", "optimize") + + Parameters + ---------- + cfg : AcademyConfig, optional + Pre-loaded config. If ``None``, loads from *config_path*. + config_path : str + Path to config.toml (used only when *cfg* is ``None``). + + Returns + ------- + Manager + An Academy Manager ready for ``async with``. 
+ """ + from academy.manager import Manager + + if cfg is None: + cfg = load_academy_config(config_path) + + factory = _build_exchange_factory(cfg) + executor = _build_executor(cfg) + + return await Manager.from_exchange_factory( + factory=factory, + executors=executor, + ) diff --git a/src/chemgraph/academy/coordinator.py b/src/chemgraph/academy/coordinator.py new file mode 100644 index 00000000..12f9fc76 --- /dev/null +++ b/src/chemgraph/academy/coordinator.py @@ -0,0 +1,179 @@ +"""Coordinator agent for multi-wave screening campaigns. + +The coordinator manages a fleet of :class:`ScreeningAgent` instances, +collects results, and optionally uses a ChemGraph LLM workflow to +analyse the collected data and spawn follow-up screening waves. +""" + +from __future__ import annotations + +import asyncio +import glob +import json +import logging +import os +import time +from typing import Any, Optional + +from academy.agent import Agent, action, timer +from academy.handle import Handle + +logger = logging.getLogger(__name__) + + +class CoordinatorAgent(Agent): + """Collects screening results and orchestrates follow-up waves. + + Parameters + ---------- + results_dir : str + Directory where :class:`ScreeningAgent` instances write their + per-molecule JSON result files. + worker_handles : list[Handle] or None + Handles to active screening agents (for progress polling). + analysis_model : str + LLM model for analysing aggregated results. + analysis_workflow : str + ChemGraph workflow type for the analysis step. + analysis_kwargs : dict + Extra kwargs for the analysis ChemGraph instance. 
+ """ + + def __init__( + self, + results_dir: str, + worker_handles: list[Handle] | None = None, + analysis_model: str = "gpt-4o", + analysis_workflow: str = "single_agent", + **analysis_kwargs: Any, + ) -> None: + super().__init__() + self._results_dir = results_dir + self._worker_handles = worker_handles or [] + self._analysis_model = analysis_model + self._analysis_workflow = analysis_workflow + self._analysis_kwargs = analysis_kwargs + self._collected: list[dict[str, Any]] = [] + self._analysis_result: Optional[dict[str, Any]] = None + + async def agent_on_startup(self) -> None: + os.makedirs(self._results_dir, exist_ok=True) + logger.info( + "CoordinatorAgent started: watching %s, %d workers", + self._results_dir, + len(self._worker_handles), + ) + + # ------------------------------------------------------------------ + # Progress monitoring + # ------------------------------------------------------------------ + + @action + async def poll_progress(self) -> dict[str, Any]: + """Query all workers for their screening progress.""" + progress = [] + for handle in self._worker_handles: + try: + p = await handle.get_progress() + progress.append(p) + except Exception as exc: + progress.append({"error": str(exc)}) + total = sum(p.get("total", 0) for p in progress if "error" not in p) + completed = sum(p.get("completed", 0) for p in progress if "error" not in p) + failed = sum(p.get("failed", 0) for p in progress if "error" not in p) + return { + "workers": len(progress), + "total": total, + "completed": completed, + "failed": failed, + "per_worker": progress, + } + + # ------------------------------------------------------------------ + # Result collection + # ------------------------------------------------------------------ + + @action + async def collect_results(self) -> list[dict[str, Any]]: + """Read all result JSON files from the shared results directory.""" + pattern = os.path.join(self._results_dir, "*.json") + files = sorted(glob.glob(pattern)) + results = 
[] + for path in files: + try: + with open(path) as f: + results.append(json.load(f)) + except (json.JSONDecodeError, OSError): + logger.warning("Skipping corrupt result file: %s", path) + self._collected = results + logger.info("Collected %d results from %s", len(results), self._results_dir) + return results + + # ------------------------------------------------------------------ + # LLM-powered analysis + # ------------------------------------------------------------------ + + @action + async def analyse(self, query: Optional[str] = None) -> dict[str, Any]: + """Use a ChemGraph agent to analyse collected results. + + Parameters + ---------- + query : str, optional + Custom analysis query. Defaults to a standard prompt + asking the LLM to rank candidates. + """ + from chemgraph.agent.llm_agent import ChemGraph + + if not self._collected: + await self.collect_results() + + successes = [r for r in self._collected if r.get("status") == "success"] + if not successes: + return {"error": "No successful results to analyse"} + + summary = json.dumps(successes, default=str, indent=2) + if query is None: + query = ( + "You are analysing computational chemistry screening results. " + f"Here are {len(successes)} results:\n\n{summary}\n\n" + "Identify the top candidates based on energy, stability, " + "or other relevant properties. Rank them and explain why." 
+ ) + + cg = ChemGraph( + model_name=self._analysis_model, + workflow_type=self._analysis_workflow, + enable_memory=False, + **self._analysis_kwargs, + ) + self._analysis_result = await cg.run(query=query) + return self._analysis_result + + @action + async def get_analysis(self) -> dict[str, Any] | None: + """Return the most recent analysis result.""" + return self._analysis_result + + # ------------------------------------------------------------------ + # Wave dispatch + # ------------------------------------------------------------------ + + @action + async def suggest_followup_molecules( + self, + top_n: int = 10, + ) -> list[str]: + """Extract top candidate SMILES from analysis for a follow-up wave. + + Returns a list of SMILES strings identified as promising by + the analysis step. Falls back to returning the top-N by + lowest energy if no analysis is available. + """ + if not self._collected: + await self.collect_results() + + successes = [r for r in self._collected if r.get("status") == "success"] + # Simple heuristic: return the SMILES of completed molecules. + # A real implementation would parse energies from results. + return [r["smiles"] for r in successes[:top_n]] diff --git a/src/chemgraph/academy/rate_limiter.py b/src/chemgraph/academy/rate_limiter.py new file mode 100644 index 00000000..9c521f55 --- /dev/null +++ b/src/chemgraph/academy/rate_limiter.py @@ -0,0 +1,135 @@ +"""Token-bucket rate limiter for LLM API calls. + +Academy is LLM-agnostic, so rate limiting must be handled at the +ChemGraph layer. This module provides a shared :class:`RateLimiter` +that agents ``await`` before each LLM call to stay within per-provider +API quotas. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import time +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass +class _ProviderBucket: + """Token bucket state for a single LLM provider.""" + + rpm: float + tokens: float = 0.0 + last_refill: float = field(default_factory=time.monotonic) + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + def __post_init__(self) -> None: + # Start with a full bucket. + self.tokens = self.rpm + + +class RateLimiter: + """Async token-bucket rate limiter keyed by LLM provider. + + Parameters + ---------- + default_rpm : int + Default requests-per-minute for providers not explicitly + configured (default ``60``). + provider_rpm : dict[str, int] or None + Per-provider overrides. Keys are provider prefixes or model + names (e.g. ``"openai"``, ``"anthropic"``, ``"gpt-4o"``). + + Usage + ----- + :: + + limiter = RateLimiter(default_rpm=60, provider_rpm={"openai": 500}) + await limiter.acquire("gpt-4o") # blocks if bucket empty + """ + + # Map model-name prefixes to canonical provider keys so that + # ``acquire("gpt-4o")`` matches a rule set for ``"openai"``. + _PREFIX_MAP: dict[str, str] = { + "gpt-": "openai", + "o1": "openai", + "o3": "openai", + "o4": "openai", + "argo:": "argo", + "claude-": "anthropic", + "gemini-": "google", + "groq:": "groq", + "llama": "alcf", + } + + def __init__( + self, + default_rpm: int = 60, + provider_rpm: dict[str, int] | None = None, + ) -> None: + self._default_rpm = default_rpm + self._provider_rpm: dict[str, int] = provider_rpm or {} + self._buckets: dict[str, _ProviderBucket] = {} + + def _resolve_provider(self, model_name: str) -> str: + """Map a model name to a canonical provider key.""" + # Direct match first. + if model_name in self._provider_rpm: + return model_name + + # Prefix match. 
+ lower = model_name.lower() + for prefix, provider in self._PREFIX_MAP.items(): + if lower.startswith(prefix): + return provider + + return model_name + + def _get_bucket(self, provider: str) -> _ProviderBucket: + """Get or create the bucket for *provider*.""" + if provider not in self._buckets: + rpm = self._provider_rpm.get(provider, self._default_rpm) + self._buckets[provider] = _ProviderBucket(rpm=rpm) + return self._buckets[provider] + + async def acquire(self, model_name: str) -> None: + """Wait until a request token is available for *model_name*. + + Refills the token bucket based on elapsed time, then consumes + one token. If the bucket is empty, sleeps until a token + becomes available. + """ + provider = self._resolve_provider(model_name) + bucket = self._get_bucket(provider) + + async with bucket.lock: + now = time.monotonic() + elapsed = now - bucket.last_refill + # Refill at rpm / 60 tokens per second. + refill = elapsed * (bucket.rpm / 60.0) + bucket.tokens = min(bucket.rpm, bucket.tokens + refill) + bucket.last_refill = now + + if bucket.tokens >= 1.0: + bucket.tokens -= 1.0 + return + + # Need to wait for a token. + deficit = 1.0 - bucket.tokens + wait_seconds = deficit / (bucket.rpm / 60.0) + logger.debug( + "Rate limit: waiting %.1fs for provider %s (rpm=%d)", + wait_seconds, + provider, + bucket.rpm, + ) + + # Sleep outside the lock so other providers aren't blocked. + await asyncio.sleep(wait_seconds) + + # Consume after waking. + async with bucket.lock: + bucket.tokens = 0.0 + bucket.last_refill = time.monotonic() diff --git a/src/chemgraph/academy/screening.py b/src/chemgraph/academy/screening.py new file mode 100644 index 00000000..09891642 --- /dev/null +++ b/src/chemgraph/academy/screening.py @@ -0,0 +1,151 @@ +"""Screening agent for batch molecule processing. + +Wraps :class:`ChemGraphAgent` with a ``@loop`` that iterates over an +assigned list of molecules and publishes results via the Academy +exchange. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import time +from typing import Any, Optional + +from academy.agent import Agent, action, loop + +from chemgraph.academy.agent import ChemGraphAgent + +logger = logging.getLogger(__name__) + + +class ScreeningAgent(ChemGraphAgent): + """Agent that screens a batch of molecules using a ChemGraph workflow. + + Parameters + ---------- + molecules : list[str] + SMILES strings to screen. + query_template : str + Query template with ``{smiles}`` placeholder, e.g. + ``"Optimize the geometry of {smiles} and compute its energy."``. + results_dir : str or None + Directory to write per-molecule JSON result files for + downstream aggregation. If ``None``, results are only + returned via the exchange. + model_name, workflow_type, log_dir, rate_limiter, **chemgraph_kwargs + Forwarded to :class:`ChemGraphAgent`. + """ + + def __init__( + self, + molecules: list[str], + query_template: str, + results_dir: Optional[str] = None, + model_name: str = "gpt-4o-mini", + workflow_type: str = "single_agent", + log_dir: Optional[str] = None, + rate_limiter: Any = None, + **chemgraph_kwargs: Any, + ) -> None: + super().__init__( + model_name=model_name, + workflow_type=workflow_type, + log_dir=log_dir, + rate_limiter=rate_limiter, + **chemgraph_kwargs, + ) + self._molecules = molecules + self._query_template = query_template + self._results_dir = results_dir + self.results: list[dict[str, Any]] = [] + self.completed: int = 0 + self.failed: int = 0 + + async def agent_on_startup(self) -> None: + await super().agent_on_startup() + if self._results_dir: + os.makedirs(self._results_dir, exist_ok=True) + logger.info( + "ScreeningAgent %s: %d molecules to process", + self._agent_uuid, + len(self._molecules), + ) + + @action + async def get_progress(self) -> dict[str, Any]: + """Return screening progress.""" + return { + "agent_uuid": self._agent_uuid, + "total": len(self._molecules), + "completed": 
self.completed, + "failed": self.failed, + } + + @loop + async def screening_loop(self, shutdown: asyncio.Event) -> None: + """Iterate over assigned molecules and run queries.""" + for smiles in self._molecules: + if shutdown.is_set(): + logger.info( + "ScreeningAgent %s: shutdown requested, stopping", + self._agent_uuid, + ) + break + + query = self._query_template.format(smiles=smiles) + t0 = time.monotonic() + try: + result = await self.run_query(query) + elapsed = time.monotonic() - t0 + record = { + "smiles": smiles, + "status": "success", + "result": result, + "elapsed_seconds": round(elapsed, 2), + "agent_uuid": self._agent_uuid, + } + self.completed += 1 + except Exception as exc: + elapsed = time.monotonic() - t0 + logger.exception( + "ScreeningAgent %s: failed on %s", + self._agent_uuid, + smiles, + ) + record = { + "smiles": smiles, + "status": "error", + "error": str(exc), + "elapsed_seconds": round(elapsed, 2), + "agent_uuid": self._agent_uuid, + } + self.failed += 1 + + self.results.append(record) + + # Write individual result file for aggregation. + if self._results_dir: + safe_name = smiles.replace("/", "_").replace("\\", "_")[:50] + path = os.path.join( + self._results_dir, + f"{self._agent_uuid}_{safe_name}.json", + ) + with open(path, "w") as f: + json.dump(record, f, default=str) + + logger.info( + "ScreeningAgent %s: finished (%d ok, %d failed)", + self._agent_uuid, + self.completed, + self.failed, + ) + # Signal that this agent is done. + self.agent_shutdown() + + @action + async def get_results(self) -> list[dict[str, Any]]: + """Return all collected results so far.""" + return self.results From e3ac8b00000f4285fef93f1097eed3d2277de02a Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 15:04:12 -0500 Subject: [PATCH 031/119] Fix bugs in HPC execution layer - globus_transfer.py: disambiguate same-basename inputs with a numeric suffix so two files that share a name (e.g. 
/a/in.cif and /b/in.cif) don't silently overwrite each other on the remote collection. - job_tracker.py: promote the "no Globus task_id within timeout" message to a warning at submit time, and emit a per-task warning at reload time for batches restored without a task_id (those tasks cannot be queried via the Globus Compute API and would otherwise be silently orphaned across server restarts). - globus_compute_backend.py: catch "executor stopped" exceptions in submit(), rebuild the Executor, and retry once. The previous _ensure_executor relied on the SDK's private _stopped attribute, which fails silently if the SDK exposes the shutdown state differently. - cg_fastmcp.py: wrap _apply_pre_submit_hook in try/except and re-raise hook failures as a ValueError naming the hook and task_id so they surface as a structured tool error instead of an opaque traceback. --- .../execution/globus_compute_backend.py | 38 ++++++++++++++++++- src/chemgraph/execution/globus_transfer.py | 17 ++++++++- src/chemgraph/execution/job_tracker.py | 32 +++++++++++++--- src/chemgraph/mcp/cg_fastmcp.py | 24 +++++++++++- 4 files changed, 100 insertions(+), 11 deletions(-) diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py index f73bd5af..6c810e46 100644 --- a/src/chemgraph/execution/globus_compute_backend.py +++ b/src/chemgraph/execution/globus_compute_backend.py @@ -102,6 +102,16 @@ def _ensure_executor(self) -> None: logger.info("Re-creating Globus Compute Executor") self._executor = Executor(endpoint_id=self._endpoint_id) + @staticmethod + def _looks_like_stopped_executor(exc: BaseException) -> bool: + """Heuristic: did a submit fail because the Executor is shut down? + + The SDK does not expose a stable exception type for this state; + we match on common substrings observed in practice. 
+ """ + msg = str(exc).lower() + return "shut down" in msg or "stopped" in msg or "closed" in msg + def submit(self, task: TaskSpec) -> Future: if not self._initialized or self._executor is None: raise RuntimeError( @@ -118,7 +128,19 @@ def submit(self, task: TaskSpec) -> Future: # Executor.submit() returns a ComputeFuture (a # concurrent.futures.Future subclass), fully compatible # with asyncio.wrap_future() used by gather_futures(). - return self._executor.submit(task.callable, *task.args, **task.kwargs) + try: + return self._executor.submit(task.callable, *task.args, **task.kwargs) + except Exception as exc: + if not self._looks_like_stopped_executor(exc): + raise + logger.warning( + "Submit raised %s -- rebuilding Globus Compute Executor " + "and retrying once.", + type(exc).__name__, + ) + self._executor = None + self._ensure_executor() + return self._executor.submit(task.callable, *task.args, **task.kwargs) elif task.task_type == "shell": if task.command is None: @@ -128,7 +150,19 @@ def submit(self, task: TaskSpec) -> Future: from globus_compute_sdk import ShellFunction shell_fn = ShellFunction(task.command) - return self._executor.submit(shell_fn) + try: + return self._executor.submit(shell_fn) + except Exception as exc: + if not self._looks_like_stopped_executor(exc): + raise + logger.warning( + "Submit raised %s -- rebuilding Globus Compute Executor " + "and retrying once.", + type(exc).__name__, + ) + self._executor = None + self._ensure_executor() + return self._executor.submit(shell_fn) else: raise ValueError( diff --git a/src/chemgraph/execution/globus_transfer.py b/src/chemgraph/execution/globus_transfer.py index d8081ab3..3355a8d1 100644 --- a/src/chemgraph/execution/globus_transfer.py +++ b/src/chemgraph/execution/globus_transfer.py @@ -216,10 +216,25 @@ def transfer_files( sync_level="checksum", ) + # Disambiguate same-basename inputs (e.g. /a/in.cif and /b/in.cif) + # by suffixing duplicates with _1, _2, ... 
Without this the + # second add_item silently overwrites the first on the + # destination collection. file_mapping: dict[str, str] = {} + used_names: dict[str, int] = {} for local_path in local_paths: p = Path(local_path).resolve() - remote_path = f"{remote_dir}/{p.name}" + base = p.name + count = used_names.get(base, 0) + if count == 0: + remote_name = base + else: + stem, dot, suffix = base.partition(".") + remote_name = ( + f"{stem}_{count}.{suffix}" if dot else f"{stem}_{count}" + ) + used_names[base] = count + 1 + remote_path = f"{remote_dir}/{remote_name}" tdata.add_item(str(p), remote_path) file_mapping[str(p)] = remote_path diff --git a/src/chemgraph/execution/job_tracker.py b/src/chemgraph/execution/job_tracker.py index 23f6c837..4efa41b0 100644 --- a/src/chemgraph/execution/job_tracker.py +++ b/src/chemgraph/execution/job_tracker.py @@ -136,21 +136,28 @@ def _load(self) -> None: logger.warning("Could not load job tracker state: %s", exc) return + orphaned: list[tuple[str, str]] = [] # (batch_id, task_id) with self._lock: for bid, info in data.items(): if bid in self._batches: continue # don't overwrite live batches - tasks = [ - TrackedTask( + tasks = [] + for t in info.get("tasks", []): + tracked = TrackedTask( task_id=t["task_id"], meta=t.get("meta", {}), future=None, globus_task_id=t.get("globus_task_id"), result=t.get("result"), ) - for t in info.get("tasks", []) - ] + # Tasks loaded from disk with no globus_task_id and + # no cached result are orphaned -- get_status cannot + # query Globus for them (see line ~320). 
+ if tracked.globus_task_id is None and tracked.result is None: + orphaned.append((bid, tracked.task_id)) + tasks.append(tracked) + self._batches[bid] = TrackedBatch( batch_id=bid, tool_name=info["tool_name"], @@ -161,6 +168,13 @@ def _load(self) -> None: logger.info( "Loaded %d batches from %s", len(data), self._persist_file ) + if orphaned: + logger.warning( + "%d task(s) reloaded without a Globus task_id -- their " + "results cannot be recovered. Examples: %s", + len(orphaned), + ", ".join(f"{b}/{t}" for b, t in orphaned[:5]), + ) # ── registration ─────────────────────────────────────────────────── @@ -241,8 +255,14 @@ def _wait_for_globus_task_ids( time.sleep(0.25) if pending: - logger.debug( - "%d tasks did not receive a Globus task_id within %.1fs", + # Promoted from debug -> warning: tasks without a task_id + # at this point will be lost across a server restart, so the + # user should see this immediately rather than only in the + # post-mortem orphan warning at reload time. + logger.warning( + "%d task(s) did not receive a Globus task_id within %.1fs; " + "they will be unrecoverable if the server restarts before " + "the next get_status call", len(pending), timeout, ) diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index 3a84c9f7..155dd76d 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -113,10 +113,30 @@ def set_pre_submit_hook(self, hook: Optional[Callable]) -> None: self._pre_submit_hook = hook def _apply_pre_submit_hook(self, task): - """Run the registered pre-submit hook (no-op when unset).""" + """Run the registered pre-submit hook (no-op when unset). + + Hook exceptions are wrapped in a ``ValueError`` naming the hook + and the offending task_id, so they surface to the agent as a + structured error instead of an opaque traceback. 
+ """ if self._pre_submit_hook is None: return task - return self._pre_submit_hook(task) + try: + return self._pre_submit_hook(task) + except Exception as exc: + hook_name = getattr( + self._pre_submit_hook, "__name__", repr(self._pre_submit_hook) + ) + task_id = getattr(task, "task_id", "") + logger.warning( + "Pre-submit hook %s failed for task %s", + hook_name, + task_id, + exc_info=True, + ) + raise ValueError( + f"Pre-submit hook '{hook_name}' failed for task '{task_id}': {exc}" + ) from exc # ── Job tracking tools ───────────────────────────────────────────── From 2d25f89444275b783b9a9b9f2db9d05d6b407746 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Mon, 1 Jun 2026 15:04:24 -0500 Subject: [PATCH 032/119] Migrate XANES and gRASPA MCP servers to CGFastMCP Both servers now mirror the mace_mcp_hpc.py pattern: - CGFastMCP with lazy backend initialisation via init_backend(); the worker subprocesses re-importing the module no longer instantiate a backend at import time. - Job-management tools (check_job_status, get_job_results, list_jobs, cancel_job, check_endpoint_status) are auto-registered by CGFastMCP._register_job_tools; the external register_job_tools call is dropped. - __main__ wires init_backend(tracker_kwargs={"persist_file": ...}) and pairs run_mcp_server with shutdown_backend in finally. This also closes a real bug in graspa_mcp_hpc.py, which was instantiating JobTracker() with no persist_file and silently losing job state across restarts despite the server's instructions promising persistence. - Globus Transfer tools (transfer_files, check_transfer_status, list_remote_files) are registered on both servers when the transfer manager is configured, matching the existing MACE behaviour. - gRASPA expander now supports remote_structure_directory the same way MACE does: a one-shot probe task lists CIFs on the remote endpoint and the worker reads them directly from the staged path. 
- Ensemble flows use the schema_fanout_tool decorator; per-job structure metadata is propagated through the worker output (since the framework meta is only the index). Legacy *_mcp_parsl.py modules now raise a DeprecationWarning at import pointing to the *_hpc.py replacement; they remain functional because scripts/mcp_xanes_example/ still imports xanes_mcp_parsl. --- src/chemgraph/mcp/graspa_mcp_hpc.py | 280 +++++++++++++++++--------- src/chemgraph/mcp/graspa_mcp_parsl.py | 11 + src/chemgraph/mcp/mace_mcp_parsl.py | 11 + src/chemgraph/mcp/xanes_mcp_hpc.py | 250 +++++++++++++---------- src/chemgraph/mcp/xanes_mcp_parsl.py | 11 + 5 files changed, 363 insertions(+), 200 deletions(-) diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py index 87eeb231..be7737a6 100644 --- a/src/chemgraph/mcp/graspa_mcp_hpc.py +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -1,141 +1,231 @@ """Backend-agnostic gRASPA MCP server. -Replaces ``graspa_mcp_parsl.py`` by using the :mod:`chemgraph.execution` -abstraction layer. The execution backend (Parsl, EnsembleLauncher, -local) is selected at startup via ``config.toml`` or the -``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. + +The ensemble expander emits one job per ``(structure, condition)`` pair +and supports both local input directories and pre-staged remote +directories (mirrors the MACE server's local/remote modes). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. 
""" import logging +import os from pathlib import Path -from mcp.server.fastmcp import FastMCP - -from chemgraph.execution import TaskSpec, get_backend -from chemgraph.execution.job_tracker import JobTracker +from chemgraph.execution.base import TaskSpec +from chemgraph.execution.config import get_transfer_manager from chemgraph.execution.utils import ( make_per_structure_output, resolve_structure_files, - submit_or_gather, - write_results_jsonl, ) -from chemgraph.mcp.job_tools import register_job_tools -from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.transfer_tools import register_transfer_tools from chemgraph.schemas.graspa_schema import graspa_input_schema_ensemble logger = logging.getLogger(__name__) -# ── Initialise execution backend ──────────────────────────────────────── -backend = get_backend() -tracker = JobTracker() +_JOBS_FILE = Path("~/.chemgraph/graspa_jobs.json").expanduser() -# ── MCP server ────────────────────────────────────────────────────────── -mcp = FastMCP( +mcp = CGFastMCP( name="ChemGraph Graspa Tools", instructions=""" - You expose tools for running graspa simulations and reading their results. - The available tools are: - 1. run_graspa_ensemble: run graspa calculations over all structures in a - directory using the configured execution backend. - 2. check_job_status: check progress of a submitted HPC job batch. - 3. get_job_results: retrieve results from a completed job batch. - 4. list_jobs: list all tracked job batches. - 5. cancel_job: cancel pending tasks in a job batch. + You expose tools for running gRASPA simulations and reading + their results. The available tools are: + 1. run_graspa_ensemble: run gRASPA calculations over every + structure in a directory at one or more (T, P) conditions. + Local mode uses input_structures; remote mode uses + remote_structure_directory (pre-stage files first with + transfer_files). + 2. 
check_job_status / get_job_results / list_jobs / cancel_job: + HPC job batch management. Job state persists across sessions. + 3. transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on + the remote HPC filesystem before running ensembles in remote + mode. Guidelines: - - Use each tool only when its input schema matches the user request. - - Do not guess numerical values; report tool errors exactly as they occur. - - Keep responses compact -- full results are written to the output files - defined in the schemas. + - Use each tool only when its input schema matches the user + request. + - Do not guess numerical values; report tool errors exactly as + they occur. + - Keep responses compact -- full results are written to the + output files defined in the schemas. - When returning paths, use absolute paths. - Energies are in eV and wall times are in seconds. - When a tool returns status='submitted' with a batch_id, use - check_job_status to poll for progress before calling get_job_results. + check_job_status to poll for progress before calling + get_job_results. Job state is persisted across sessions. 
""", ) -register_job_tools(mcp, tracker, backend) -def _run_graspa_single(job: dict) -> dict: - """Execute a single gRASPA simulation (runs on the worker).""" +# ── Worker (runs on the backend) ─────────────────────────────────────── + + +def _graspa_worker(job: dict) -> dict: + """Execute a single gRASPA simulation on a backend worker.""" from chemgraph.schemas.graspa_schema import graspa_input_schema from chemgraph.tools.graspa_tools import run_graspa_core - params = graspa_input_schema(**job) if isinstance(job, dict) else job - return run_graspa_core(params) - + job = dict(job) + structure = job.pop("_structure_name", None) + temperature = job.get("temperature") + pressure = job.get("pressure") + + remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "raspa.log"), + ) -@mcp.tool( - name="run_graspa_ensemble", - description="Run an ensemble of graspa calculations for multiple input files.", -) -async def run_graspa_ensemble( - params: graspa_input_schema_ensemble, -): - """Run an ensemble of gRASPA calculations over all structure files - using the configured execution backend. - - Parameters - ---------- - params : graspa_input_schema_ensemble - Input parameters for the ensemble of gRASPA calculations. 
- """ - structure_files, output_dir = resolve_structure_files( - params.input_structures, - extensions={".cif"}, + params = graspa_input_schema(**job) + result = run_graspa_core(params) + + if isinstance(result, dict): + merged = { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + **result, + } + merged.setdefault("status", "success") + return merged + return { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + "result": result, + "status": "success", + } + + +# ── Ensemble fanout ──────────────────────────────────────────────────── + + +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) ) - # Base output file name - base_output = Path(params.output_result_file).resolve() - pending_tasks = [] +def _expand_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: + """Server-side expansion of an ensemble request into per-job dicts. + Local mode: enumerates ``input_structures`` on this host. + Remote mode: submits a one-shot probe task to the backend to list + files under ``remote_structure_directory``, then builds per-file + jobs that the worker reads directly from the remote filesystem. + """ + base_output = Path(params.output_result_file) + + if params.remote_structure_directory: + remote_dir = params.remote_structure_directory + mcp._ensure_backend() + probe = TaskSpec( + task_id="ls_remote_dir", + task_type="python", + callable=_ls_remote_files, + kwargs={"path": remote_dir}, + ) + fut = mcp._backend.submit(probe) + try: + file_names = fut.result(timeout=30) + except Exception as exc: + raise RuntimeError( + f"Could not list remote directory {remote_dir}: {exc}" + ) from exc + + # Filter to CIF files (gRASPA expects CIFs). 
+ file_names = [f for f in file_names if f.lower().endswith(".cif")] + if not file_names: + raise ValueError( + f"No CIF files found under remote directory {remote_dir}." + ) + + jobs = [] + for fname in file_names: + mof_name = Path(fname).stem + for condition in params.conditions: + per_output = make_per_structure_output(Path(fname), base_output) + jobs.append( + { + "_structure_name": mof_name, + "remote_structure_file": f"{remote_dir}/{fname}", + "output_result_file": str(per_output), + "temperature": condition.temperature, + "pressure": condition.pressure, + "adsorbate": params.adsorbate, + "n_cycles": params.n_cycles, + } + ) + return jobs + + if not params.input_structures: + raise ValueError( + "Either input_structures or remote_structure_directory " + "must be provided." + ) + + structure_files, _ = resolve_structure_files( + params.input_structures, extensions={".cif"} + ) + jobs = [] for struct_path in structure_files: mof_name = struct_path.stem for condition in params.conditions: - per_struct_output = make_per_structure_output(struct_path, base_output) - job = { - "input_structure_file": str(struct_path), - "output_result_file": str(per_struct_output), - "temperature": condition.temperature, - "pressure": condition.pressure, - "adsorbate": params.adsorbate, - "n_cycles": params.n_cycles, - } - - task = TaskSpec( - task_id=f"graspa_{mof_name}_{condition.temperature}K_{condition.pressure}Pa", - task_type="python", - callable=_run_graspa_single, - kwargs={"job": job}, + per_output = make_per_structure_output(struct_path, base_output) + jobs.append( + { + "_structure_name": mof_name, + "input_structure_file": str(struct_path), + "output_result_file": str(per_output), + "temperature": condition.temperature, + "pressure": condition.pressure, + "adsorbate": params.adsorbate, + "n_cycles": params.n_cycles, + } ) - fut = backend.submit(task) + return jobs - task_meta = { - "structure": mof_name, - "temperature": condition.temperature, - "pressure": 
condition.pressure, - } - pending_tasks.append((task_meta, fut)) - result = await submit_or_gather( - backend, pending_tasks, tracker, "run_graspa_ensemble", - ) +@mcp.schema_fanout_tool( + name="run_graspa_ensemble", + description=( + "Run gRASPA calculations over every structure in a directory at " + "one or more (temperature, pressure) conditions. Local mode " + "uses input_structures; remote mode uses " + "remote_structure_directory (pre-stage files first with " + "transfer_files)." + ), + worker=_graspa_worker, +) +def run_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: + return _expand_graspa_ensemble(params) - if result["status"] == "completed": - summary_log_path = output_dir / "simulation_results.jsonl" - success_count, total_count = write_results_jsonl( - result["results"], summary_log_path, - ) - return ( - f"Ensemble execution completed. Ran {total_count} tasks " - f"({success_count} successful). " - f"Detailed results appended to '{summary_log_path}'." - ) - # Async remote: return submission confirmation - return result +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on gRASPA MCP server.") if __name__ == "__main__": - run_mcp_server(mcp, default_port=9001) + from chemgraph.mcp.server_utils import run_mcp_server + + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) + + try: + run_mcp_server(mcp, default_port=9001) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/mcp/graspa_mcp_parsl.py b/src/chemgraph/mcp/graspa_mcp_parsl.py index 3b55690a..378dc5ad 100644 --- a/src/chemgraph/mcp/graspa_mcp_parsl.py +++ b/src/chemgraph/mcp/graspa_mcp_parsl.py @@ -1,8 +1,19 @@ import asyncio import json import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.graspa_mcp_parsl is deprecated; use 
" + "chemgraph.mcp.graspa_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP import parsl diff --git a/src/chemgraph/mcp/mace_mcp_parsl.py b/src/chemgraph/mcp/mace_mcp_parsl.py index 4b3f03fc..42ae67c7 100644 --- a/src/chemgraph/mcp/mace_mcp_parsl.py +++ b/src/chemgraph/mcp/mace_mcp_parsl.py @@ -1,6 +1,17 @@ import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.mace_mcp_parsl is deprecated; use " + "chemgraph.mcp.mace_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP from parsl.config import Config from parsl.executors import HighThroughputExecutor diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 4abb94e0..8583ae65 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -1,25 +1,29 @@ """Backend-agnostic XANES/FDMNES MCP server. -Replaces ``xanes_mcp_parsl.py`` by using the :mod:`chemgraph.execution` -abstraction layer. The execution backend (Parsl, EnsembleLauncher, -local) is selected at startup via ``config.toml`` or the -``CHEMGRAPH_EXECUTION_BACKEND`` environment variable. +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. + +The ensemble expander runs server-side and prepares per-structure +FDMNES input files in ``runs_dir``; the worker (which runs on the +backend) executes FDMNES via subprocess and extracts convergence data. 
+This assumes the server and worker share a filesystem (true for any +Globus Compute endpoint on the same HPC where the MCP server runs; +Globus Transfer staging is a separate concern). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. """ import logging +import subprocess from pathlib import Path -from mcp.server.fastmcp import FastMCP - -from chemgraph.execution import TaskSpec, get_backend -from chemgraph.execution.job_tracker import JobTracker -from chemgraph.execution.utils import ( - resolve_structure_files, - submit_or_gather, - write_results_jsonl, -) -from chemgraph.mcp.job_tools import register_job_tools -from chemgraph.mcp.server_utils import run_mcp_server +from chemgraph.execution.config import get_transfer_manager +from chemgraph.execution.utils import resolve_structure_files +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.transfer_tools import register_transfer_tools from chemgraph.schemas.xanes_schema import ( mp_query_schema, xanes_input_schema, @@ -28,14 +32,9 @@ logger = logging.getLogger(__name__) -# ── Initialise execution backend ──────────────────────────────────────── -backend = get_backend() +_JOBS_FILE = Path("~/.chemgraph/xanes_jobs.json").expanduser() -_jobs_file = Path("~/.chemgraph/xanes_jobs.json").expanduser() -tracker = JobTracker(persist_file=_jobs_file) - -# ── MCP server ────────────────────────────────────────────────────────── -mcp = FastMCP( +mcp = CGFastMCP( name="ChemGraph XANES Tools", instructions=""" You expose tools for running XANES/FDMNES simulations. @@ -45,10 +44,11 @@ using the configured execution backend. 3. fetch_mp_structures: fetch optimized structures from Materials Project. 4. plot_xanes: generate normalized XANES plots for completed calculations. - 5. check_job_status: check progress of a submitted HPC job batch. - 6. get_job_results: retrieve results from a completed job batch. - 7. 
list_jobs: list all tracked job batches. - 8. cancel_job: cancel pending tasks in a job batch. + 5. check_job_status / get_job_results / list_jobs / cancel_job: HPC + job batch management. Job state persists across sessions. + 6. transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on the + remote HPC filesystem before running ensembles. Guidelines: - Use each tool only when its input schema matches the user request. @@ -64,7 +64,20 @@ to retrieve results. """, ) -register_job_tools(mcp, tracker, backend) + + +# ── Single-structure tool ────────────────────────────────────────────── + + +def _xanes_single_worker(params: xanes_input_schema) -> dict: + """Run a single FDMNES calculation on a backend worker.""" + from chemgraph.tools.xanes_tools import run_xanes_core + + result = run_xanes_core(params) + if isinstance(result, dict): + result.setdefault("status", "success") + return result + return {"status": "success", "result": result} @mcp.tool( @@ -72,18 +85,63 @@ description="Run a single XANES/FDMNES calculation for one input structure.", ) def run_xanes_single(params: xanes_input_schema): - """Run a single FDMNES calculation using the core engine.""" - from chemgraph.tools.xanes_tools import run_xanes_core + """Run a single FDMNES calculation using the core engine. + + The CGFastMCP wrapper submits this call to the configured backend; + the body is the direct-call fallback when no backend is active. + """ + return _xanes_single_worker(params) - return run_xanes_core(params) +# ── Ensemble fanout ──────────────────────────────────────────────────── -def _xanes_post_fn(meta: dict, _result) -> dict: - """Post-process a completed FDMNES task: extract convergence data.""" + +def _xanes_ensemble_worker(item: dict) -> dict: + """Execute one prepared FDMNES run on the backend. 
+ + The expander has already written ``input_fdmnes.txt`` (or the + equivalent) into ``item['run_dir']``; this worker runs the binary + via subprocess and then extracts convergence data. + """ from chemgraph.tools.xanes_tools import extract_conv + run_dir = item["run_dir"] + fdmnes_exe = item["fdmnes_exe"] + meta = { + "structure": item.get("structure"), + "run_dir": run_dir, + "z_absorber": item.get("z_absorber"), + } + + stdout_path = Path(run_dir) / "fdmnes_stdout.txt" + stderr_path = Path(run_dir) / "fdmnes_stderr.txt" try: - conv_data = extract_conv(meta["run_dir"]) + with open(stdout_path, "w") as out, open(stderr_path, "w") as err: + proc = subprocess.run( + [fdmnes_exe], + cwd=run_dir, + stdout=out, + stderr=err, + check=False, + ) + if proc.returncode != 0: + return { + **meta, + "status": "failure", + "error_type": "FDMNESExitCode", + "message": f"FDMNES exited with code {proc.returncode}", + "returncode": proc.returncode, + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"FDMNES launch failed: {e}", + } + + try: + conv_data = extract_conv(run_dir) return { **meta, "status": "success", @@ -98,24 +156,9 @@ def _xanes_post_fn(meta: dict, _result) -> dict: } -@mcp.tool( - name="run_xanes_ensemble", - description="Run an ensemble of XANES/FDMNES calculations using the configured backend.", -) -async def run_xanes_ensemble(params: xanes_input_schema_ensemble): - """Run ensemble XANES calculations over all structure files. - - For each structure file: - 1. Reads the structure via ASE. - 2. Creates FDMNES input files in a per-structure subdirectory. - 3. Submits a shell task to run FDMNES. - 4. Gathers results and writes a JSONL summary log. - - Parameters - ---------- - params : xanes_input_schema_ensemble - Input parameters for the ensemble calculation. 
- """ +def _expand_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: + """Server-side expansion: prepare per-structure run dirs and return + one item per structure for the worker to execute.""" from ase.io import read as ase_read from chemgraph.tools.xanes_tools import write_fdmnes_input @@ -125,19 +168,14 @@ async def run_xanes_ensemble(params: xanes_input_schema_ensemble): extensions={".cif", ".xyz", ".poscar"}, ) - # Create a batch runs directory runs_dir = output_dir / "fdmnes_batch_runs" runs_dir.mkdir(parents=True, exist_ok=True) - fdmnes_exe = params.fdmnes_exe - - pending_tasks = [] - + items: list[dict] = [] for i, struct_path in enumerate(structure_files): run_dir = runs_dir / f"run_{i}" run_dir.mkdir(parents=True, exist_ok=True) - # Read structure and write FDMNES inputs atoms = ase_read(str(struct_path)) z_abs = ( params.z_absorber @@ -153,48 +191,34 @@ async def run_xanes_ensemble(params: xanes_input_schema_ensemble): magnetism=params.magnetism, ) - # Submit shell task - task = TaskSpec( - task_id=f"xanes_{struct_path.stem}_{i}", - task_type="shell", - command=f'cd "{run_dir}" && "{fdmnes_exe}"', - working_dir=str(run_dir), - stdout=str(run_dir / "fdmnes_stdout.txt"), - stderr=str(run_dir / "fdmnes_stderr.txt"), + items.append( + { + "structure": struct_path.name, + "run_dir": str(run_dir), + "z_absorber": z_abs, + "fdmnes_exe": params.fdmnes_exe, + } ) - fut = backend.submit(task) - task_meta = { - "structure": struct_path.name, - "run_dir": str(run_dir), - "z_absorber": z_abs, - } - pending_tasks.append((task_meta, fut)) + return items - result = await submit_or_gather( - backend, pending_tasks, tracker, "run_xanes_ensemble", - post_fn=_xanes_post_fn, - ) - if result["status"] == "completed": - summary_log_path = output_dir / "xanes_results.jsonl" - success_count, total_count = write_results_jsonl( - result["results"], summary_log_path, - ) - return ( - f"Ensemble execution completed. 
Ran {total_count} tasks " - f"({success_count} successful). " - f"Detailed results appended to '{summary_log_path}'." - ) +@mcp.schema_fanout_tool( + name="run_xanes_ensemble", + description=( + "Run FDMNES/XANES calculations over every structure in an input " + "directory (or list of files). Each structure is prepared " + "server-side and submitted to the configured execution backend." + ), + worker=_xanes_ensemble_worker, +) +def run_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: + return _expand_xanes_ensemble(params) - # Async remote: return submission confirmation - return result + +# ── Orchestration tools (no backend involvement) ─────────────────────── -@mcp.tool( - name="fetch_mp_structures", - description="Fetch optimized structures from Materials Project.", -) def fetch_mp_structures(params: mp_query_schema): """Fetch structures from Materials Project and save as CIF files and pickle database.""" from chemgraph.tools.xanes_tools import ( @@ -214,19 +238,8 @@ def fetch_mp_structures(params: mp_query_schema): } -@mcp.tool( - name="plot_xanes", - description="Generate normalized XANES plots for completed FDMNES calculations.", -) def plot_xanes(runs_dir: str): - """Generate XANES plots for all completed runs in a directory. - - Parameters - ---------- - runs_dir : str - Path to the ``fdmnes_batch_runs`` directory containing ``run_*`` - subdirectories with FDMNES outputs. 
- """ + """Generate XANES plots for all completed runs in a directory.""" from chemgraph.tools.xanes_tools import ( _get_data_dir, plot_xanes_results, @@ -247,5 +260,32 @@ def plot_xanes(runs_dir: str): } +mcp.add_tool( + fetch_mp_structures, + name="fetch_mp_structures", + description="Fetch optimized structures from Materials Project.", +) +mcp.add_tool( + plot_xanes, + name="plot_xanes", + description="Generate normalized XANES plots for completed FDMNES calculations.", +) + + +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on XANES MCP server.") + + if __name__ == "__main__": - run_mcp_server(mcp, default_port=9007) + from chemgraph.mcp.server_utils import run_mcp_server + + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) + + try: + run_mcp_server(mcp, default_port=9007) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/mcp/xanes_mcp_parsl.py b/src/chemgraph/mcp/xanes_mcp_parsl.py index 0ec794c1..b5f8729a 100644 --- a/src/chemgraph/mcp/xanes_mcp_parsl.py +++ b/src/chemgraph/mcp/xanes_mcp_parsl.py @@ -1,8 +1,19 @@ import asyncio import json import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.xanes_mcp_parsl is deprecated; use " + "chemgraph.mcp.xanes_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). 
This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP import parsl From bc54083c7871ca5af2f3c3cec02206133c1303a6 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Tue, 2 Jun 2026 09:28:27 -0500 Subject: [PATCH 033/119] Forward HPC env vars to MCP stdio subprocess + document EL config cli/mcp_utils.py: when launching an MCP server over stdio, forward an explicit allowlist of environment variables (CHEMGRAPH_EXECUTION_BACKEND, GLOBUS_COMPUTE_ENDPOINT_ID, GLOBUS_TRANSFER_*, ALCF_ACCESS_TOKEN, plus shell + virtualenv essentials). The MCP SDK's stdio transport otherwise inherits only a hard-coded whitelist of standard system variables, so a user who exported a Globus Compute endpoint ID in their shell would find the spawned server unable to see it. config.toml: rewrite the [execution.ensemble_launcher] section to reflect the keys actually consumed by get_launcher_config() (task_executor_name, mpi_flavour) and document the client-only mode (client_only, checkpoint_dir, node_id) as a commented alternative. Add a commented [execution.globus_transfer] template so the file staging endpoints are discoverable without having to read the source. --- src/chemgraph/cli/mcp_utils.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/chemgraph/cli/mcp_utils.py b/src/chemgraph/cli/mcp_utils.py index ce287af9..752849e3 100644 --- a/src/chemgraph/cli/mcp_utils.py +++ b/src/chemgraph/cli/mcp_utils.py @@ -6,6 +6,7 @@ from __future__ import annotations +import os import shlex import time from typing import List, Optional @@ -15,6 +16,35 @@ from chemgraph.cli.formatting import console from chemgraph.utils.async_utils import run_async_callable +# Env vars that the MCP stdio subprocess may need. The MCP SDK's stdio +# transport inherits only a hard-coded whitelist of standard system vars +# (PATH, HOME, etc.) 
by default -- ChemGraph- and Globus-specific keys +# must be passed through explicitly or the spawned MCP server has no way +# to see what the user exported in their shell. +_FORWARDED_ENV_VARS = ( + # Shell essentials (so python and the user's HOME resolve correctly) + "PATH", + "HOME", + "USER", + "TMPDIR", + "LANG", + "LC_ALL", + "VIRTUAL_ENV", + "CONDA_PREFIX", + "CONDA_DEFAULT_ENV", + # ChemGraph runtime selection + "CHEMGRAPH_EXECUTION_BACKEND", + "CHEMGRAPH_LOG_DIR", + # Globus Compute + "GLOBUS_COMPUTE_ENDPOINT_ID", + # Globus Transfer + "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID", + "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID", + "GLOBUS_TRANSFER_DESTINATION_BASE_PATH", + # ALCF inference endpoints + "ALCF_ACCESS_TOKEN", +) + def load_mcp_tools_from_config( url: Optional[str] = None, @@ -65,11 +95,13 @@ def load_mcp_tools_from_config( transport_label = f"streamable_http @ {url}" elif command: parts = shlex.split(command) + env = {k: os.environ[k] for k in _FORWARDED_ENV_VARS if k in os.environ} connections = { server_name: { "command": parts[0], "args": parts[1:], "transport": "stdio", + "env": env, } } transport_label = f"stdio: {command}" From e85c6759007df45577149aa06ae3f8d5c8bdb8b8 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 4 Jun 2026 08:51:21 -0500 Subject: [PATCH 034/119] Silence LocalBackend worker stdout under stdio MCP transport Stdio MCP servers use stdout as the JSON-RPC channel; ProcessPoolExecutor workers inherit that fd. An unguarded worker print (e.g. mace/tools/cg.py's "cuequivariance ... will be disabled" notice) corrupts the protocol stream and aborts the client session on teardown with a JSONRPCMessage ValidationError. LocalBackend.initialize now accepts a silence_worker_stdout kwarg and also reads CHEMGRAPH_LOCAL_SILENCE_STDOUT=1. 
When set, it passes a module-level _silence_worker_stdout initializer to ProcessPoolExecutor that runs os.dup2(stderr_fd, stdout_fd) in each child, so worker prints land on stderr (logged) instead of the JSON-RPC pipe. server_utils.run_mcp_server now sets the env var via setdefault before mcp.run(transport='stdio'), so every stdio-launched MCP server gets the fix automatically. Override with CHEMGRAPH_LOCAL_SILENCE_STDOUT=0 to restore raw stdout for debugging. Default behavior unchanged for direct LocalBackend users (notebooks, CLI): env var defaults to off; the 9 TestLocalBackend pytest cases still pass. --- src/chemgraph/execution/local_backend.py | 40 ++++++++++++++++++++++-- src/chemgraph/mcp/server_utils.py | 5 +++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/chemgraph/execution/local_backend.py b/src/chemgraph/execution/local_backend.py index c6a66abe..e9250914 100644 --- a/src/chemgraph/execution/local_backend.py +++ b/src/chemgraph/execution/local_backend.py @@ -8,7 +8,9 @@ from __future__ import annotations import logging +import os import subprocess +import sys from concurrent.futures import Future, ProcessPoolExecutor from typing import Any @@ -20,6 +22,24 @@ _DEFAULT_MAX_WORKERS = 4 +def _silence_worker_stdout() -> None: + """ProcessPoolExecutor *initializer*: redirect this worker's stdout fd to stderr. + + Used when ``LocalBackend`` runs inside a stdio MCP server, where the + parent process's stdout is the JSON-RPC channel. Worker children inherit + that fd by default, so any unguarded print (e.g. ``mace/tools/cg.py``'s + "cuequivariance ... will be disabled" notice) corrupts the protocol + stream. dup2 redirects this child's stdout fd to its stderr fd so prints + are logged but never reach the client. + """ + try: + os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) + except (OSError, ValueError, AttributeError): + # Best-effort: skip silently if the fds aren't real (e.g. 
in some + # test or notebook contexts where stderr is captured). + pass + + def _run_shell_task( command: str, working_dir: str | None, @@ -72,10 +92,26 @@ def __init__(self) -> None: def initialize(self, system: str = "local", **kwargs: Any) -> None: max_workers = kwargs.get("max_workers", _DEFAULT_MAX_WORKERS) - self._pool = ProcessPoolExecutor(max_workers=max_workers) + + # Opt-in: silence worker stdout (redirect fd to stderr) so prints + # from worker callables don't pollute a parent's stdout. Required + # when LocalBackend runs under stdio MCP, where the parent's stdout + # IS the JSON-RPC channel. Off by default so notebook/CLI users + # still see prints. Explicit kwarg wins; otherwise env var. + silence = kwargs.get("silence_worker_stdout") + if silence is None: + silence = os.environ.get("CHEMGRAPH_LOCAL_SILENCE_STDOUT") == "1" + + pool_kwargs: dict[str, Any] = {"max_workers": max_workers} + if silence: + pool_kwargs["initializer"] = _silence_worker_stdout + + self._pool = ProcessPoolExecutor(**pool_kwargs) self._initialized = True logger.info( - "LocalBackend initialized with %d workers", max_workers + "LocalBackend initialized with %d workers (silence_worker_stdout=%s)", + max_workers, + bool(silence), ) def submit(self, task: TaskSpec) -> Future: diff --git a/src/chemgraph/mcp/server_utils.py b/src/chemgraph/mcp/server_utils.py index 91fce11e..71cc5c6d 100644 --- a/src/chemgraph/mcp/server_utils.py +++ b/src/chemgraph/mcp/server_utils.py @@ -84,6 +84,11 @@ def run_mcp_server( uvicorn.run(app, host=args.host, port=args.port) else: logging.info("Starting %s via stdio transport...", mcp.name) + # Under stdio, the server's stdout IS the JSON-RPC channel. Any + # unguarded print from a worker (e.g. mace's "cuequivariance ... + # will be disabled" notice) would corrupt it. setdefault so the + # user can override with CHEMGRAPH_LOCAL_SILENCE_STDOUT=0. 
+ os.environ.setdefault("CHEMGRAPH_LOCAL_SILENCE_STDOUT", "1") # FastMCP.run(transport='stdio') handles the stdio loop mcp.run(transport="stdio") From 84c87dd42903b365c12a1bc1c56b5960b7e51dce Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 4 Jun 2026 08:52:31 -0500 Subject: [PATCH 035/119] Add smoke and demo scripts for execution backends scripts/smoke/ -- pass/fail validators, one per backend (local, globus_compute, globus_transfer, parsl_in_job, ensemble_launcher_in_job) plus a shared _smoke_utils.py and README. Trivial water payload, prints [PASS]/[FAIL] per check, exits nonzero on failure. Drives the production get_backend() / GlobusTransferManager / _mace_worker code paths. scripts/demo/ -- real-chemistry demonstrations, 10 scripts covering each backend with both direct (no LLM) and agent (LLM + MCP) flavours. Each demo runs a 5-molecule (H2O, CH4, NH3, CO2, ethanol) MACE driver='thermo' screen and prints an electronic energy / enthalpy / Gibbs free energy table plus CSV. Shared _demo_chemistry.py helper handles inline-vs-remote structure embedding and JSON round-trip. In-job (Parsl + EnsembleLauncher) scripts target qsub interactive allocations on Polaris/Aurora; EL scripts include a client-only mode for the orchestrator-connection pathway added in bc54083c. Validated locally: smoke_local.py 7/7 pass; demo_local_direct.py screens the 5 molecules (water G=-13.69 eV, ethanol G=-44.91 eV, ethanol lowest); demo_local_agent.py round-trips via stdio MCP with clean teardown. 
--- scripts/demo/README.md | 187 ++++++++++++ scripts/demo/_demo_chemistry.py | 275 +++++++++++++++++ .../demo_ensemble_launcher_in_job_agent.py | 127 ++++++++ .../demo_ensemble_launcher_in_job_direct.py | 91 ++++++ scripts/demo/demo_globus_compute_agent.py | 133 ++++++++ scripts/demo/demo_globus_compute_direct.py | 97 ++++++ scripts/demo/demo_globus_transfer_agent.py | 169 ++++++++++ scripts/demo/demo_globus_transfer_direct.py | 182 +++++++++++ scripts/demo/demo_local_agent.py | 135 ++++++++ scripts/demo/demo_local_direct.py | 73 +++++ scripts/demo/demo_parsl_in_job_agent.py | 127 ++++++++ scripts/demo/demo_parsl_in_job_direct.py | 101 ++++++ scripts/demo/structures/ammonia.xyz | 6 + scripts/demo/structures/co2.xyz | 5 + scripts/demo/structures/ethanol.xyz | 11 + scripts/demo/structures/methane.xyz | 7 + scripts/demo/structures/water.xyz | 5 + scripts/smoke/README.md | 125 ++++++++ scripts/smoke/_smoke_utils.py | 110 +++++++ .../smoke/smoke_ensemble_launcher_in_job.py | 288 ++++++++++++++++++ scripts/smoke/smoke_globus_compute.py | 226 ++++++++++++++ scripts/smoke/smoke_globus_transfer.py | 191 ++++++++++++ scripts/smoke/smoke_local.py | 178 +++++++++++ scripts/smoke/smoke_parsl_in_job.py | 234 ++++++++++++++ scripts/smoke/water.xyz | 5 + 25 files changed, 3088 insertions(+) create mode 100644 scripts/demo/README.md create mode 100644 scripts/demo/_demo_chemistry.py create mode 100644 scripts/demo/demo_ensemble_launcher_in_job_agent.py create mode 100644 scripts/demo/demo_ensemble_launcher_in_job_direct.py create mode 100644 scripts/demo/demo_globus_compute_agent.py create mode 100644 scripts/demo/demo_globus_compute_direct.py create mode 100644 scripts/demo/demo_globus_transfer_agent.py create mode 100644 scripts/demo/demo_globus_transfer_direct.py create mode 100644 scripts/demo/demo_local_agent.py create mode 100644 scripts/demo/demo_local_direct.py create mode 100644 scripts/demo/demo_parsl_in_job_agent.py create mode 100644 
scripts/demo/demo_parsl_in_job_direct.py create mode 100644 scripts/demo/structures/ammonia.xyz create mode 100644 scripts/demo/structures/co2.xyz create mode 100644 scripts/demo/structures/ethanol.xyz create mode 100644 scripts/demo/structures/methane.xyz create mode 100644 scripts/demo/structures/water.xyz create mode 100644 scripts/smoke/README.md create mode 100644 scripts/smoke/_smoke_utils.py create mode 100644 scripts/smoke/smoke_ensemble_launcher_in_job.py create mode 100644 scripts/smoke/smoke_globus_compute.py create mode 100644 scripts/smoke/smoke_globus_transfer.py create mode 100644 scripts/smoke/smoke_local.py create mode 100644 scripts/smoke/smoke_parsl_in_job.py create mode 100644 scripts/smoke/water.xyz diff --git a/scripts/demo/README.md b/scripts/demo/README.md new file mode 100644 index 00000000..17c7167b --- /dev/null +++ b/scripts/demo/README.md @@ -0,0 +1,187 @@ +# ChemGraph execution-layer demonstration scripts + +Real-chemistry demos that exercise each `ExecutionBackend` end-to-end. +A 5-molecule library (H2O, CH4, NH3, CO2, ethanol) is screened for +thermochemistry with MACE-MP (`driver="thermo"` → optimize geometry + +vibrational frequencies + ideal-gas thermo at 298.15 K). Each script +writes a CSV of electronic energy, enthalpy, entropy, Gibbs free +energy per molecule and prints a fixed-width summary table. 
+ +These complement `scripts/smoke/`: + +| Directory | Purpose | Pass criterion | +|-----------|---------|---------------| +| `scripts/smoke/` | Regression validators on a trivial water payload | Exit 0 with every `[PASS]` | +| `scripts/demo/` | Realistic chemistry showcases | Useful property table; demos *fail loud* but their value is the output, not a green check | + +## Layout + +``` +scripts/demo/ +├── README.md (this file) +├── _demo_chemistry.py shared helpers (workload, formatting, agent prompt) +├── structures/ 5 .xyz fixtures (5-11 lines each) +│ ├── water.xyz methane.xyz ammonia.xyz co2.xyz ethanol.xyz +├── demo_local_direct.py laptop, no LLM, no HPC +├── demo_local_agent.py laptop, LLM, no HPC +├── demo_globus_compute_direct.py laptop, no LLM, live GC endpoint +├── demo_globus_compute_agent.py laptop, LLM, live GC endpoint +├── demo_globus_transfer_direct.py laptop, no LLM, Globus Transfer + GC +├── demo_globus_transfer_agent.py laptop, LLM, Globus Transfer + GC +├── demo_parsl_in_job_direct.py inside qsub -I on Polaris/Aurora, no LLM +├── demo_parsl_in_job_agent.py inside qsub -I, LLM +├── demo_ensemble_launcher_in_job_direct.py inside qsub -I, no LLM +└── demo_ensemble_launcher_in_job_agent.py inside qsub -I, LLM +``` + +Direct demos call `chemgraph.execution.config.get_backend()` and +`backend.submit_batch(...)` directly. Agent demos spawn +`python -m chemgraph.mcp.mace_mcp_hpc` as a stdio subprocess and drive +it with a ChemGraph LLM agent over `langchain-mcp-adapters`. 
+ +## Environment-variable matrix + +| Variable | Required by | Notes | +|----------|-------------|-------| +| `GLOBUS_COMPUTE_ENDPOINT_ID` | `demo_globus_compute_*`, `demo_globus_transfer_*` | UUID from `globus-compute-endpoint start chemgraph-` | +| `GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID` | `demo_globus_transfer_*` | Globus Connect Personal on the laptop | +| `GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID` | `demo_globus_transfer_*` | HPC collection UUID (ALCF data portal) | +| `GLOBUS_TRANSFER_DESTINATION_BASE_PATH` | `demo_globus_transfer_*` | e.g. `/eagle/projects/MyProj/staging` | +| `COMPUTE_SYSTEM` | `demo_parsl_in_job_*`, `demo_ensemble_launcher_in_job_*` | `polaris` or `aurora` | +| `PBS_NODEFILE` | both in-job demos | Set automatically inside `qsub` — demos abort if missing | +| `CG_AMQP_PORT=443` | optional, Aurora | Use when outbound 5671 is blocked | +| LLM API key (e.g. `OPENAI_API_KEY`) | all `*_agent.py` | Match the `--model` flag | + +## Running + +### Laptop, no creds + +```bash +source .cg_env/bin/activate +python scripts/demo/demo_local_direct.py +# ~20s for the 5 molecules on CPU; writes demo_local_out/{demo_local.csv,*_thermo.json} +``` + +Sample output: +``` +=== Local backend thermo screen (cpu) === +molecule energy/eV enthalpy/eV S/(eV/K) G/eV #freqs wall/s conv +--------------------------------------------------------------------------------------------- +water -13.7861 -13.1063 0.001958 -13.6900 9 3.0 True +methane -23.1669 -21.8802 0.001931 -22.4559 15 3.6 True +ammonia -18.9970 -17.9888 0.001996 -18.5839 12 3.3 True +co2 -22.5459 -22.1320 0.002209 -22.7906 9 2.9 True +ethanol -46.2767 -44.0648 0.002820 -44.9056 27 3.3 True +``` + +### Laptop + LLM + +```bash +export OPENAI_API_KEY=... +python scripts/demo/demo_local_agent.py --model gpt-4o-mini +``` + +Agent will call `run_mace_single` 5 times via the MCP subprocess and +respond with a markdown table. 
+ +### Laptop → live Globus Compute endpoint + +```bash +export GLOBUS_COMPUTE_ENDPOINT_ID="<endpoint-uuid>" +export COMPUTE_SYSTEM=polaris # for logging +python scripts/demo/demo_globus_compute_direct.py # ~5-15 min first run (model download on remote) +python scripts/demo/demo_globus_compute_agent.py --model gpt-4o-mini +``` + +For Aurora add `--device xpu`, plus `--amqp-port 443` (direct demo) or `export CG_AMQP_PORT=443` (agent demo, which has no `--amqp-port` flag). + +### Laptop → Globus Transfer + Globus Compute + +```bash +export GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID="<source-collection-uuid>" +export GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID="<destination-collection-uuid>" +export GLOBUS_TRANSFER_DESTINATION_BASE_PATH=/eagle/projects/MyProj/staging +python scripts/demo/demo_globus_transfer_direct.py +python scripts/demo/demo_globus_transfer_agent.py --model gpt-4o-mini +``` + +The direct demo stages the 5 `.xyz` fixtures, then runs MACE in +*remote-path* mode (worker reads from the staged dir, no inline +embedding). The agent demo asks the LLM to call `transfer_files` and +then `run_mace_ensemble` itself. + +Remote-path mode has one quirk: `_mace_worker` only attaches +`full_output` back to the caller when an `inline_structure` is set +(see `src/chemgraph/mcp/mace_mcp_hpc.py:127-131`). So in +`demo_globus_transfer_direct.py` the printed table will have blank +thermo columns — the full JSON results sit on the HPC under +`<remote-output-dir>/<molecule>_thermo.json`. Pull them back with a +follow-up Globus Transfer if needed. 
+ +### Inside a PBS allocation on Polaris + +```bash +qsub -I -A <project> -l select=1 -l walltime=01:00:00 -q debug -l filesystems=home:eagle +# Now on the compute node: +module load conda +conda activate base +source ~/chemgraph/venv/bin/activate +export COMPUTE_SYSTEM=polaris +cd ~/chemgraph/ChemGraph +python scripts/demo/demo_parsl_in_job_direct.py +python scripts/demo/demo_ensemble_launcher_in_job_direct.py +``` + +### Inside a PBS allocation on Aurora + +```bash +qsub -I -A <project> -l select=1,walltime=01:00:00 -q debug -l filesystems=home:flare +module load frameworks +source ~/chemgraph/venv/bin/activate +export COMPUTE_SYSTEM=aurora +cd ~/chemgraph/ChemGraph +python scripts/demo/demo_parsl_in_job_direct.py --device xpu +python scripts/demo/demo_ensemble_launcher_in_job_direct.py --device xpu +``` + +Agent variants on either system require an LLM key and follow the +same pattern as `demo_local_agent.py`. + +## Tips + +- `--molecules water methane` to run on a subset (faster iteration). +- `--output-dir /custom/path` to redirect CSV + per-molecule JSON. +- The first run on a fresh endpoint / fresh venv will be slow because + MACE-MP downloads a ~hundred-MB model. Subsequent runs hit the cache + at `~/.cache/mace/`. + +## Known caveats + +- **`langchain-mcp-adapters` must be pinned to `0.1.14`** for the + `*_agent.py` scripts to import. Versions `>=0.2.0` import + `langchain_core.messages.content` (a 1.x API) which doesn't exist in + `langchain-core 0.3.x` — and `langgraph 0.4.7` (pinned in + `pyproject.toml`) constrains us to `langchain-core 0.3.x`. Fix in + `.cg_env`: + ```bash + pip install 'langchain-mcp-adapters==0.1.14' + ``` + This is an **env-only pin** — `pyproject.toml` still lists + `langchain-mcp-adapters` unpinned, so a fresh `pip install -e .` + will regress to `>=0.2`. Re-run the pin command after any clean env + rebuild. The durable fix (one-line edit to `pyproject.toml`) was + deferred per user request. 
+- `ensemble-launcher` is not on PyPI for Python 3.12; the in-job EL + demos only work on HPC where `scripts/hpc_setup/install_remote.sh` + builds it from source. + +## See also + +- `scripts/smoke/` — pass/fail regression validators (trivial payload). +- `scripts/hpc_setup/{README.md,e2e_test_runbook.md}` — install + ChemGraph + start a Globus Compute endpoint on Polaris/Aurora. +- `scripts/globus_compute_example/` — looser tutorial-style examples, + predecessors of these demos. +- `src/chemgraph/execution/` — the production backends the demos call. +- `src/chemgraph/mcp/mace_mcp_hpc.py` — the MCP server every agent + demo spawns as a subprocess. diff --git a/scripts/demo/_demo_chemistry.py b/scripts/demo/_demo_chemistry.py new file mode 100644 index 00000000..82f714ee --- /dev/null +++ b/scripts/demo/_demo_chemistry.py @@ -0,0 +1,275 @@ +"""Shared chemistry-screening helpers for scripts/demo/*. + +Each demo script in this directory is a thin wrapper around +``submit_and_collect`` -- the actual chemistry workload (a 5-molecule +thermochemistry screen) lives here so we don't duplicate it across +backends. + +Workload +-------- +For each of {water, methane, ammonia, CO2, ethanol} a single MACE +``driver="thermo"`` job is submitted via the configured +``ExecutionBackend``. This drives ``chemgraph.mcp.mace_mcp_hpc._mace_worker`` +under the hood, which itself wraps ``chemgraph.tools.parsl_tools.run_mace_core`` +-> ``chemgraph.tools.ase_core.run_ase_core``. The ``thermo`` driver +optimises the geometry, computes vibrational frequencies, then derives +ideal-gas thermochemistry at the requested temperature/pressure +(``src/chemgraph/tools/ase_core.py:556-602``). + +Two modes +--------- +* ``inline=False`` -- the worker reads ``input_structure_file`` from a + shared filesystem (local, Parsl on a compute node, EL). The demo + reads the on-disk ``output_result_file`` JSON after the future + resolves. 
+* ``inline=True`` -- the structure is embedded in the payload via + ``atoms_to_atomsdata`` (Globus Compute, where the worker has no + access to the laptop FS). The worker materialises the structure in a + temp dir, runs MACE, then attaches the on-disk JSON back to the + result as ``full_output`` (see ``mace_mcp_hpc.py:127-131``). The demo + reads from ``raw["full_output"]``. +""" + +from __future__ import annotations + +import csv +import json +import os +from pathlib import Path +from typing import Any + +MOLECULE_NAMES: list[str] = ["water", "methane", "ammonia", "co2", "ethanol"] +_HERE = Path(__file__).resolve().parent +_STRUCTURES_DIR = _HERE / "structures" + + +def molecule_xyz_path(name: str) -> Path: + """Absolute path to the .xyz fixture for *name*.""" + p = _STRUCTURES_DIR / f"{name}.xyz" + if not p.is_file(): + raise FileNotFoundError(f"Missing structure fixture: {p}") + return p + + +def structures_dir() -> Path: + """Directory holding the per-molecule .xyz fixtures.""" + return _STRUCTURES_DIR + + +def build_thermo_job( + name: str, + *, + device: str, + output_dir: Path, + inline: bool, + model: str = "medium-mpa-0", + temperature: float = 298.15, + pressure: float = 101325.0, + fmax: float = 0.01, + steps: int = 200, +) -> dict: + """Build the job dict consumed by ``_mace_worker`` for one molecule. + + For ``inline=True`` the structure is embedded and the + ``output_result_file`` is left relative so the worker writes into + its own temp dir (and the on-disk JSON is shipped back to the + caller via the ``full_output`` key). + """ + xyz = molecule_xyz_path(name) + job: dict[str, Any] = { + "input_structure_file": str(xyz), + "driver": "thermo", + "model": model, + "device": device, + "temperature": temperature, + "pressure": pressure, + "fmax": fmax, + "steps": steps, + "optimizer": "lbfgs", + } + if inline: + # Worker resolves the (relative) output path against its own + # tempdir -- see mace_mcp_hpc._mace_worker:117-120. 
+ job["output_result_file"] = f"{name}_thermo.json" + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(str(xyz)) + job["inline_structure"] = atoms_to_atomsdata(atoms).model_dump() + else: + Path(output_dir).mkdir(parents=True, exist_ok=True) + job["output_result_file"] = str( + (Path(output_dir) / f"{name}_thermo.json").resolve() + ) + return job + + +def _read_full_output(raw: dict, job: dict, *, inline: bool) -> dict: + """Return the full ASEOutputSchema dict for one finished job. + + Inline jobs carry the JSON back inline via ``full_output``. + Non-inline jobs leave it on the shared filesystem at + ``job["output_result_file"]``. + """ + if inline and isinstance(raw.get("full_output"), dict): + return raw["full_output"] + out_file = job.get("output_result_file") + if out_file and os.path.isfile(out_file): + with open(out_file) as fh: + return json.load(fh) + return {} + + +def _extract_properties(name: str, raw: dict, job: dict, *, inline: bool) -> dict: + """Pull the chemistry summary fields out of one job's result.""" + full = _read_full_output(raw, job, inline=inline) + thermo = full.get("thermochemistry") or {} + vib = full.get("vibrational_frequencies") or {} + return { + "molecule": name, + "status": raw.get("status", "?"), + "n_atoms": len(full.get("final_structure", {}).get("numbers", [])) + if isinstance(full.get("final_structure"), dict) + else None, + "energy_eV": full.get("single_point_energy"), + "enthalpy_eV": thermo.get("enthalpy"), + "entropy_eV_per_K": thermo.get("entropy"), + "gibbs_free_energy_eV": thermo.get("gibbs_free_energy"), + "n_frequencies": ( + len(vib.get("frequencies", [])) + if isinstance(vib, dict) and isinstance(vib.get("frequencies"), list) + else None + ), + "converged": full.get("converged"), + "wall_time_s": full.get("wall_time"), + } + + +def submit_and_collect( + backend, + molecule_names: list[str] | None = None, + *, + device: str, + output_dir: Path | 
str, + inline: bool, + timeout: float = 6000.0, +) -> list[dict]: + """Submit one MACE thermo job per molecule, gather and summarise. + + Returns a list of per-molecule property dicts in submission order. + Raises if any future fails -- demos should *fail loud*, not swallow. + """ + from chemgraph.execution.base import TaskSpec + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + + names = molecule_names or MOLECULE_NAMES + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + jobs = [ + build_thermo_job(name, device=device, output_dir=output_dir, inline=inline) + for name in names + ] + tasks = [ + TaskSpec( + task_id=f"demo-thermo-{name}", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + for name, job in zip(names, jobs) + ] + print( + f"\nSubmitting {len(tasks)} thermo jobs to backend={type(backend).__name__} " + f"(device={device}, inline={inline})..." + ) + futures = backend.submit_batch(tasks) + + results: list[dict] = [] + for name, job, fut in zip(names, jobs, futures): + print(f" waiting on {name}...", flush=True) + raw = fut.result(timeout=timeout) + if not isinstance(raw, dict): + raise RuntimeError(f"{name}: non-dict result {type(raw).__name__}: {raw!r}") + if raw.get("status") != "success": + raise RuntimeError(f"{name}: backend returned status={raw.get('status')!r}: {raw}") + results.append(_extract_properties(name, raw, job, inline=inline)) + return results + + +def write_csv(results: list[dict], csv_path: Path | str) -> Path: + """Write the property table to *csv_path*. 
Returns the path.""" + csv_path = Path(csv_path) + csv_path.parent.mkdir(parents=True, exist_ok=True) + if not results: + csv_path.write_text("") + return csv_path + fieldnames = list(results[0].keys()) + with open(csv_path, "w", newline="") as fh: + writer = csv.DictWriter(fh, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(results) + return csv_path + + +def print_summary(results: list[dict], title: str = "") -> None: + """Print a fixed-width table of the screening results.""" + if title: + print(f"\n=== {title} ===") + if not results: + print("(no results)") + return + header = ( + f"{'molecule':<10} {'energy/eV':>12} {'enthalpy/eV':>13} " + f"{'S/(eV/K)':>12} {'G/eV':>12} {'#freqs':>7} {'wall/s':>8} {'conv':>5}" + ) + print(header) + print("-" * len(header)) + + def fmt(val, w, p=4): + if val is None: + return f"{'-':>{w}}" + if isinstance(val, float): + return f"{val:>{w}.{p}f}" + return f"{val!s:>{w}}" + + for r in results: + print( + f"{r['molecule']:<10} " + f"{fmt(r.get('energy_eV'), 12)} " + f"{fmt(r.get('enthalpy_eV'), 13)} " + f"{fmt(r.get('entropy_eV_per_K'), 12, 6)} " + f"{fmt(r.get('gibbs_free_energy_eV'), 12)} " + f"{fmt(r.get('n_frequencies'), 7, 0)} " + f"{fmt(r.get('wall_time_s'), 8, 1)} " + f"{fmt(r.get('converged'), 5)}" + ) + print() + + +def agent_prompt(device: str = "cpu") -> str: + """Standard natural-language prompt used by all *_agent.py demos. + + The structure paths reference the demo's own ``structures/`` so the + agent can call ``run_mace_single`` directly without staging. + Replace the file paths if you adapt this for a different layout. 
+ """ + files = ", ".join(str(molecule_xyz_path(n)) for n in MOLECULE_NAMES) + return ( + f"Using the MACE tool with driver='thermo', model='medium-mpa-0', " + f"device='{device}', temperature=298.15 K, pressure=101325 Pa, " + f"compute thermochemistry for each of these five molecules:\n" + f" - water: {molecule_xyz_path('water')}\n" + f" - methane: {molecule_xyz_path('methane')}\n" + f" - ammonia: {molecule_xyz_path('ammonia')}\n" + f" - CO2: {molecule_xyz_path('co2')}\n" + f" - ethanol: {molecule_xyz_path('ethanol')}\n" + f"Call run_mace_single once per molecule (do not batch them yourself). " + f"For each result, retrieve the optimized electronic energy, enthalpy, " + f"entropy and Gibbs free energy by reading the output JSON via " + f"extract_output_json. After all five complete, report a markdown table " + f"with columns: molecule, energy (eV), H (eV), G (eV), then a one-line " + f"observation about which molecule has the lowest Gibbs free energy.\n\n" + f"(Structure paths for reference: {files})" + ) diff --git a/scripts/demo/demo_ensemble_launcher_in_job_agent.py b/scripts/demo/demo_ensemble_launcher_in_job_agent.py new file mode 100644 index 00000000..a059d070 --- /dev/null +++ b/scripts/demo/demo_ensemble_launcher_in_job_agent.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +"""Agent + MCP + EnsembleLauncher demo on an HPC compute node. + +LLM agent on the compute node drives a local ``mace_mcp_hpc`` +subprocess whose backend is ``ensemble_launcher``. Same 5-molecule +thermo screen as the direct demo, but driven natural-language. + +Run inside ``qsub -I`` on Polaris/Aurora. LLM API key required. + +Run:: + + export COMPUTE_SYSTEM=polaris + export OPENAI_API_KEY=... 
+ python scripts/demo/demo_ensemble_launcher_in_job_agent.py --model gpt-4o-mini +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import logging +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import agent_prompt + +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_mcp_adapters.tools import load_mcp_tools + +from chemgraph.agent.llm_agent import ChemGraph + + +def _abort(msg: str) -> None: + print(f"[ABORT] {msg}") + sys.exit(2) + + +async def amain(model: str, system: str, device: str, query: str, verbose: int) -> None: + if verbose: + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") + logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG) + + python = sys.executable + env = { + "CHEMGRAPH_EXECUTION_BACKEND": "ensemble_launcher", + "COMPUTE_SYSTEM": system, + "PATH": os.environ.get("PATH", ""), + "HOME": os.environ.get("HOME", ""), + "VIRTUAL_ENV": os.environ.get("VIRTUAL_ENV", ""), + "PBS_NODEFILE": os.environ.get("PBS_NODEFILE", ""), + "PBS_O_WORKDIR": os.environ.get("PBS_O_WORKDIR", ""), + } + server_configs = { + "ChemGraph MACE (EnsembleLauncher)": { + "transport": "stdio", + "command": python, + "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc"], + "env": env, + }, + } + + print(f"LLM model: {model}") + print(f"System: {system}") + print(f"Device: {device}\n") + print("Query:\n" + "-" * 60) + print(query) + print("-" * 60 + "\n") + + client = MultiServerMCPClient(server_configs) + async with contextlib.AsyncExitStack() as stack: + session = await stack.enter_async_context( + client.session("ChemGraph MACE (EnsembleLauncher)") + ) + tools = await load_mcp_tools(session) + print(f"Loaded {len(tools)} MCP tools: {[t.name for t in tools]}\n") + + cg = ChemGraph( + model_name=model, + workflow_type="single_agent", + 
structured_output=False, + return_option="state", + tools=tools, + ) + + print("Running agent...\n" + "=" * 60) + result = await cg.run(query) + print("=" * 60) + + if isinstance(result, dict) and "messages" in result: + for msg in reversed(result["messages"]): + content = getattr(msg, "content", None) + if not content and isinstance(msg, dict): + content = msg.get("content", "") + if content and not getattr(msg, "tool_calls", None): + print(f"\nAgent response:\n{content}") + break + else: + print(f"\nResult:\n{result}") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="gpt-4o-mini") + parser.add_argument("--system", default=os.environ.get("COMPUTE_SYSTEM")) + parser.add_argument("--device", default=None) + parser.add_argument("--query", default=None) + parser.add_argument("-v", "--verbose", action="count", default=0) + args = parser.parse_args() + + if not os.environ.get("PBS_NODEFILE"): + _abort("PBS_NODEFILE not set. Run inside `qsub -I`.") + if not args.system: + _abort("COMPUTE_SYSTEM env var not set and --system not given.") + system = args.system.lower().strip() + if system not in ("polaris", "aurora"): + _abort(f"Unsupported --system: {system!r}") + device = args.device or ("xpu" if system == "aurora" else "cuda") + query = args.query or agent_prompt(device=device) + asyncio.run(amain(args.model, system, device, query, args.verbose)) + + +if __name__ == "__main__": + main() diff --git a/scripts/demo/demo_ensemble_launcher_in_job_direct.py b/scripts/demo/demo_ensemble_launcher_in_job_direct.py new file mode 100644 index 00000000..dbf2de70 --- /dev/null +++ b/scripts/demo/demo_ensemble_launcher_in_job_direct.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +"""Direct EnsembleLauncherBackend demo on an HPC compute node. + +5-molecule thermo screen via the EnsembleLauncher orchestrator, +managed mode (the backend starts and tears down the orchestrator +itself). 
Must run inside ``qsub -I`` on Polaris or Aurora, in a venv +where ``ensemble_launcher`` is installed (built from source by +``scripts/hpc_setup/install_remote.sh``). + +Run:: + + export COMPUTE_SYSTEM=polaris + python scripts/demo/demo_ensemble_launcher_in_job_direct.py + python scripts/demo/demo_ensemble_launcher_in_job_direct.py --device xpu +""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import ( + MOLECULE_NAMES, + print_summary, + submit_and_collect, + write_csv, +) + + +def _abort(msg: str) -> None: + print(f"[ABORT] {msg}") + sys.exit(2) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--system", default=os.environ.get("COMPUTE_SYSTEM")) + parser.add_argument("--device", default=None) + parser.add_argument("--output-dir", default="demo_el_out") + parser.add_argument("--molecules", nargs="+", default=MOLECULE_NAMES) + parser.add_argument("--timeout", type=float, default=6000.0) + args = parser.parse_args() + + if not os.environ.get("PBS_NODEFILE"): + _abort("PBS_NODEFILE not set. Run inside `qsub -I`.") + if not args.system: + _abort("COMPUTE_SYSTEM env var not set and --system not given.") + system = args.system.lower().strip() + if system not in ("polaris", "aurora"): + _abort(f"Unsupported --system: {system!r}") + device = args.device or ("xpu" if system == "aurora" else "cuda") + + try: + import ensemble_launcher # noqa: F401 + except ImportError as exc: + _abort( + f"ensemble_launcher import failed: {exc}. " + "Install via scripts/hpc_setup/install_remote.sh on HPC." 
+ ) + + print(f"system={system} device={device} mode=managed") + + from chemgraph.execution.config import get_backend + + backend = get_backend(backend_name="ensemble_launcher", system=system) + try: + results = submit_and_collect( + backend, + molecule_names=args.molecules, + device=device, + output_dir=args.output_dir, + inline=False, + timeout=args.timeout, + ) + finally: + backend.shutdown() + + csv_path = write_csv(results, Path(args.output_dir) / "demo_el.csv") + print_summary( + results, + title=f"EnsembleLauncher thermo screen (system={system}, device={device})", + ) + print(f"CSV: {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/demo/demo_globus_compute_agent.py b/scripts/demo/demo_globus_compute_agent.py new file mode 100644 index 00000000..bad23f2f --- /dev/null +++ b/scripts/demo/demo_globus_compute_agent.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +"""Agent + MCP + Globus Compute demo: 5-molecule thermo screen on remote HPC. + +LLM agent on the laptop, MCP server (``mace_mcp_hpc``) as a local +subprocess, work dispatched to a Globus Compute endpoint on Polaris / +Aurora. Mirrors ``scripts/globus_compute_example/run_agent_mcp_remote.py`` +but with a structured 5-molecule chemistry workload instead of a free +prompt. + +Prereqs:: + + export GLOBUS_COMPUTE_ENDPOINT_ID="" + export OPENAI_API_KEY=... 
# or other model creds + +Run:: + + python scripts/demo/demo_globus_compute_agent.py --model gpt-4o-mini + python scripts/demo/demo_globus_compute_agent.py --device xpu --model argo:gpt-4o +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import logging +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import agent_prompt + +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_mcp_adapters.tools import load_mcp_tools + +from chemgraph.agent.llm_agent import ChemGraph + + +async def amain(model: str, device: str, query: str, verbose: int) -> None: + if verbose: + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") + logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG) + + endpoint = os.environ["GLOBUS_COMPUTE_ENDPOINT_ID"] + os.environ["CHEMGRAPH_EXECUTION_BACKEND"] = "globus_compute" + + python = sys.executable + server_configs = { + "ChemGraph MACE (Globus Compute)": { + "transport": "stdio", + "command": python, + "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc"], + "env": { + "CHEMGRAPH_EXECUTION_BACKEND": "globus_compute", + "GLOBUS_COMPUTE_ENDPOINT_ID": endpoint, + # Forward optional knobs if set. + **({"CG_AMQP_PORT": os.environ["CG_AMQP_PORT"]} if "CG_AMQP_PORT" in os.environ else {}), + **({"COMPUTE_SYSTEM": os.environ["COMPUTE_SYSTEM"]} if "COMPUTE_SYSTEM" in os.environ else {}), + "PATH": os.environ.get("PATH", ""), + "HOME": os.environ.get("HOME", ""), + "VIRTUAL_ENV": os.environ.get("VIRTUAL_ENV", ""), + }, + }, + } + + print(f"LLM model: {model}") + print(f"GC endpoint: {endpoint[:8]}... 
({os.environ.get('COMPUTE_SYSTEM', '?')})") + print(f"Device: {device}\n") + print("Query:\n" + "-" * 60) + print(query) + print("-" * 60 + "\n") + + client = MultiServerMCPClient(server_configs) + + async with contextlib.AsyncExitStack() as stack: + session = await stack.enter_async_context( + client.session("ChemGraph MACE (Globus Compute)") + ) + tools = await load_mcp_tools(session) + print(f"Loaded {len(tools)} MCP tools: {[t.name for t in tools]}\n") + + cg = ChemGraph( + model_name=model, + workflow_type="single_agent", + structured_output=False, + return_option="state", + tools=tools, + ) + + print("Running agent...\n" + "=" * 60) + result = await cg.run(query) + print("=" * 60) + + if isinstance(result, dict) and "messages" in result: + for msg in reversed(result["messages"]): + content = getattr(msg, "content", None) + if not content and isinstance(msg, dict): + content = msg.get("content", "") + if content and not getattr(msg, "tool_calls", None): + print(f"\nAgent response:\n{content}") + break + else: + print(f"\nResult:\n{result}") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--model", + default="gpt-4o-mini", + help="LLM model name (default: gpt-4o-mini)", + ) + parser.add_argument( + "--device", + default=os.environ.get("CG_DEMO_DEVICE", "cuda"), + help="MACE device on the remote endpoint (default: cuda; use xpu on Aurora)", + ) + parser.add_argument("--query", default=None, help="Override the default query") + parser.add_argument("-v", "--verbose", action="count", default=0) + args = parser.parse_args() + + if not os.environ.get("GLOBUS_COMPUTE_ENDPOINT_ID"): + print("ERROR: export GLOBUS_COMPUTE_ENDPOINT_ID= first.") + sys.exit(2) + + query = args.query or agent_prompt(device=args.device) + asyncio.run(amain(args.model, args.device, query, args.verbose)) + + +if __name__ == "__main__": + main() diff --git a/scripts/demo/demo_globus_compute_direct.py 
b/scripts/demo/demo_globus_compute_direct.py new file mode 100644 index 00000000..48aceeee --- /dev/null +++ b/scripts/demo/demo_globus_compute_direct.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +"""Direct GlobusComputeBackend demo: thermo screen on a remote HPC endpoint. + +Submits 5 MACE ``driver="thermo"`` jobs to a Globus Compute endpoint +(Polaris/Aurora/etc.) and gathers results back to the laptop. The +structures are embedded inline (``inline=True``) so the workers don't +need to read anything from the laptop's filesystem. + +Prereq env vars:: + + export GLOBUS_COMPUTE_ENDPOINT_ID="" # required + export COMPUTE_SYSTEM=polaris # optional, for logging + # export CG_AMQP_PORT=443 # if 5671 blocked (Aurora) + +Run:: + + python scripts/demo/demo_globus_compute_direct.py + python scripts/demo/demo_globus_compute_direct.py --device xpu # Aurora + python scripts/demo/demo_globus_compute_direct.py --molecules water methane +""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import ( + MOLECULE_NAMES, + print_summary, + submit_and_collect, + write_csv, +) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--output-dir", default="demo_globus_compute_out") + parser.add_argument("--molecules", nargs="+", default=MOLECULE_NAMES) + parser.add_argument( + "--device", + default=os.environ.get("CG_DEMO_DEVICE", "cuda"), + help="MACE device on the remote endpoint (default: cuda; use xpu on Aurora)", + ) + parser.add_argument( + "--amqp-port", + type=int, + default=int(os.environ.get("CG_AMQP_PORT", "0")) or None, + help="Override AMQP port (set to 443 if 5671 is blocked, e.g. 
Aurora)", + ) + parser.add_argument( + "--timeout", + type=float, + default=6000.0, + help="Per-task timeout in seconds (default 6000)", + ) + args = parser.parse_args() + + if not os.environ.get("GLOBUS_COMPUTE_ENDPOINT_ID"): + print("ERROR: export GLOBUS_COMPUTE_ENDPOINT_ID= first.") + sys.exit(2) + + from chemgraph.execution.config import get_backend + + backend_kwargs: dict = {} + if args.amqp_port: + backend_kwargs["amqp_port"] = args.amqp_port + + backend = get_backend(backend_name="globus_compute", **backend_kwargs) + try: + results = submit_and_collect( + backend, + molecule_names=args.molecules, + device=args.device, + output_dir=args.output_dir, + inline=True, + timeout=args.timeout, + ) + finally: + backend.shutdown() + + csv_path = write_csv(results, Path(args.output_dir) / "demo_globus_compute.csv") + print_summary( + results, + title=( + f"Globus Compute thermo screen " + f"(system={os.environ.get('COMPUTE_SYSTEM', '?')}, device={args.device})" + ), + ) + print(f"CSV written to: {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/demo/demo_globus_transfer_agent.py b/scripts/demo/demo_globus_transfer_agent.py new file mode 100644 index 00000000..865c32b7 --- /dev/null +++ b/scripts/demo/demo_globus_transfer_agent.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +"""Agent + MCP + Globus Transfer + Globus Compute demo. + +LLM agent on the laptop drives a local ``mace_mcp_hpc`` subprocess. +With both Compute and Transfer env vars set, the MCP server +auto-registers the transfer tools (``mace_mcp_hpc.py:310-313``). The +agent is told to (a) stage the demo's structures to the remote +collection via ``transfer_files``, then (b) call ``run_mace_ensemble`` +with ``remote_structure_directory`` so MACE runs on the pre-staged +files. Finally it reports a Gibbs-energy table. + +Prereqs:: + + export GLOBUS_COMPUTE_ENDPOINT_ID=... + export GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID=... + export GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID=... 
+ export GLOBUS_TRANSFER_DESTINATION_BASE_PATH=/eagle/projects/MyProj/staging + export OPENAI_API_KEY=... # or any supported model + +Run:: + + python scripts/demo/demo_globus_transfer_agent.py --model gpt-4o-mini +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import logging +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import MOLECULE_NAMES, structures_dir + +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_mcp_adapters.tools import load_mcp_tools + +from chemgraph.agent.llm_agent import ChemGraph + + +_TRANSFER_AGENT_PROMPT_TMPL = """\ +The following five molecule structure files live on the local filesystem: +{listing} + +Workflow: +1. Call `transfer_files` with `source_paths` set to that list of absolute + paths (you may pass them as one batch) to stage them on the remote + HPC endpoint. Use `wait=true` so the call blocks until SUCCEEDED. +2. From the transfer result, take the `remote_directory` value. +3. Call `run_mace_ensemble` with: + - remote_structure_directory = + - driver = "thermo" + - model = "medium-mpa-0" + - device = "{device}" + - temperature = 298.15 + - pressure = 101325 + This dispatches one MACE thermo job per file via Globus Compute. +4. If `run_mace_ensemble` returns a `batch_id`, poll `check_job_status` + until completed, then call `get_job_results` to retrieve the per-file + energies and thermochemistry. +5. Report a markdown table with columns: molecule | electronic energy (eV) | + Gibbs free energy (eV). Add a one-line observation about which + molecule has the lowest Gibbs free energy. 
+""" + + +def _agent_prompt(device: str) -> str: + paths = [str(structures_dir() / f"{n}.xyz") for n in MOLECULE_NAMES] + listing = "\n".join(f" - {p}" for p in paths) + return _TRANSFER_AGENT_PROMPT_TMPL.format(listing=listing, device=device) + + +async def amain(model: str, device: str, query: str, verbose: int) -> None: + if verbose: + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") + logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG) + + python = sys.executable + forwarded = { + "CHEMGRAPH_EXECUTION_BACKEND": "globus_compute", + "GLOBUS_COMPUTE_ENDPOINT_ID": os.environ["GLOBUS_COMPUTE_ENDPOINT_ID"], + "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID": os.environ["GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID"], + "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID": os.environ["GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID"], + "GLOBUS_TRANSFER_DESTINATION_BASE_PATH": os.environ["GLOBUS_TRANSFER_DESTINATION_BASE_PATH"], + "PATH": os.environ.get("PATH", ""), + "HOME": os.environ.get("HOME", ""), + "VIRTUAL_ENV": os.environ.get("VIRTUAL_ENV", ""), + } + server_configs = { + "ChemGraph MACE+Transfer": { + "transport": "stdio", + "command": python, + "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc"], + "env": forwarded, + }, + } + + print(f"LLM model: {model}") + print(f"Device: {device}\n") + print("Query:\n" + "-" * 60) + print(query) + print("-" * 60 + "\n") + + client = MultiServerMCPClient(server_configs) + async with contextlib.AsyncExitStack() as stack: + session = await stack.enter_async_context(client.session("ChemGraph MACE+Transfer")) + tools = await load_mcp_tools(session) + names = [t.name for t in tools] + print(f"Loaded {len(tools)} MCP tools: {names}\n") + if "transfer_files" not in names: + print( + "WARNING: transfer_files not registered. Did you export the " + "GLOBUS_TRANSFER_* env vars? mace_mcp_hpc only registers the " + "transfer tools when a transfer manager is configured." 
def main() -> None:
    """Parse CLI options, check the Globus env vars, and run the agent demo.

    Exits with status 2, after printing the offending names, when any of the
    four required ``GLOBUS_*`` environment variables is unset or empty.
    Otherwise the query is ``--query`` when given, else the prompt built by
    ``_agent_prompt`` for the chosen device, and ``amain`` is driven to
    completion with ``asyncio.run``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--model", default="gpt-4o-mini")
    # CG_DEMO_DEVICE overrides the built-in "cuda" default.
    cli.add_argument("--device", default=os.environ.get("CG_DEMO_DEVICE", "cuda"))
    cli.add_argument("--query", default=None)
    cli.add_argument("-v", "--verbose", action="count", default=0)
    opts = cli.parse_args()

    absent = []
    for var_name in (
        "GLOBUS_COMPUTE_ENDPOINT_ID",
        "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID",
        "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID",
        "GLOBUS_TRANSFER_DESTINATION_BASE_PATH",
    ):
        if not os.environ.get(var_name):
            absent.append(var_name)
    if absent:
        print("ERROR: missing env vars: " + ", ".join(absent))
        sys.exit(2)

    prompt = opts.query if opts.query else _agent_prompt(opts.device)
    asyncio.run(amain(opts.model, opts.device, prompt, opts.verbose))
def main() -> None:
    """Stage the fixtures with Globus Transfer, then run MACE thermo jobs.

    Steps:

    1. Parse CLI options and exit with status 2 if any required ``GLOBUS_*``
       environment variable is unset or empty.
    2. Transfer one ``.xyz`` file per requested molecule and wait (up to
       ``--transfer-timeout`` seconds) for the task to report ``SUCCEEDED``;
       exit with status 1 otherwise.
    3. Submit one ``_mace_worker`` task per molecule through the
       ``globus_compute`` backend, each pointing at its pre-staged remote
       file, and wait up to ``--compute-timeout`` seconds per result.
    4. Print a summary table and write ``demo_globus_transfer.csv`` under
       ``--output-dir``.

    The backend is always shut down once it has been created, including when
    directory creation, submission, or result collection fails.

    Raises:
        RuntimeError: if a task result is not a dict whose ``status`` is
            ``"success"``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--output-dir", default="demo_globus_transfer_out")
    parser.add_argument("--molecules", nargs="+", default=MOLECULE_NAMES)
    parser.add_argument("--device", default=os.environ.get("CG_DEMO_DEVICE", "cuda"))
    parser.add_argument(
        "--amqp-port",
        type=int,
        # An unset or "0" CG_AMQP_PORT collapses to None (no override).
        default=int(os.environ.get("CG_AMQP_PORT", "0")) or None,
    )
    parser.add_argument(
        "--transfer-timeout",
        type=float,
        default=6000.0,
        help="Seconds to wait for the Globus Transfer task (default 6000).",
    )
    parser.add_argument(
        "--compute-timeout",
        type=float,
        default=6000.0,
        help="Seconds to wait for each MACE thermo task (default 6000).",
    )
    args = parser.parse_args()

    required = (
        "GLOBUS_COMPUTE_ENDPOINT_ID",
        "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID",
        "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID",
        "GLOBUS_TRANSFER_DESTINATION_BASE_PATH",
    )
    missing = [v for v in required if not os.environ.get(v)]
    if missing:
        print(f"ERROR: missing env vars: {', '.join(missing)}")
        sys.exit(2)

    # Deferred until after the env check (presumably so a misconfigured run
    # fails before the chemgraph stack is imported).
    from chemgraph.execution.base import TaskSpec
    from chemgraph.execution.config import get_backend, get_transfer_manager
    from chemgraph.mcp.mace_mcp_hpc import _mace_worker

    # ── 1. Stage all 5 .xyz files to the remote HPC collection ─────────
    print("\n[1/3] Submitting Globus Transfer for fixtures...")
    tm = get_transfer_manager()
    if tm is None:
        print("ERROR: get_transfer_manager() returned None.")
        sys.exit(2)

    local_paths = [str(molecule_xyz_path(n)) for n in args.molecules]
    transfer = tm.transfer_files(
        local_paths=local_paths,
        label=f"chemgraph-demo-{int(time.time())}",
    )
    print(f" task_id = {transfer.task_id}")
    print(f" remote_dir = {transfer.remote_directory}")
    print(f" waiting up to {args.transfer_timeout}s for SUCCEEDED...")
    status = tm.wait_for_transfer(
        transfer.task_id, timeout=args.transfer_timeout, poll_interval=5
    )
    if status.get("status") != "SUCCEEDED":
        print(f"ERROR: transfer did not succeed: {status}")
        sys.exit(1)
    print(
        f" done: {status['files_transferred']}/{status['files']} files, "
        f"{status['bytes_transferred']} bytes"
    )

    # ── 2. Submit one MACE thermo task per pre-staged file ─────────────
    print(f"\n[2/3] Dispatching {len(args.molecules)} MACE thermo jobs via Globus Compute...")
    backend_kwargs = {}
    if args.amqp_port:
        backend_kwargs["amqp_port"] = args.amqp_port
    backend = get_backend(backend_name="globus_compute", **backend_kwargs)

    results = []
    # Everything that follows backend creation runs under try/finally so the
    # backend is released even when mkdir or submit_batch raises.
    try:
        output_dir = Path(args.output_dir).resolve()
        output_dir.mkdir(parents=True, exist_ok=True)

        jobs = []
        tasks = []
        for name in args.molecules:
            remote_xyz = f"{transfer.remote_directory}/{name}.xyz"
            job = {
                # input_structure_file is ignored when remote_structure_file is set
                # (mace_mcp_hpc._mace_worker:92-99 overrides it). Pass a sentinel.
                "input_structure_file": f"remote::{name}",
                "remote_structure_file": remote_xyz,
                "output_result_file": f"{name}_thermo.json",
                "driver": "thermo",
                "model": "medium-mpa-0",
                "device": args.device,
                "temperature": 298.15,
                "pressure": 101325.0,
                "fmax": 0.01,
                "steps": 200,
                "optimizer": "lbfgs",
            }
            jobs.append(job)
            tasks.append(
                TaskSpec(
                    task_id=f"demo-tr-{name}",
                    task_type="python",
                    callable=_mace_worker,
                    kwargs={"job": job},
                )
            )

        futures = backend.submit_batch(tasks)

        for name, job, fut in zip(args.molecules, jobs, futures):
            print(f" waiting on {name}...", flush=True)
            raw = fut.result(timeout=args.compute_timeout)
            if not isinstance(raw, dict) or raw.get("status") != "success":
                raise RuntimeError(f"{name}: backend returned {raw!r}")
            # Remote-path mode: full_output is NOT attached (only inline triggers
            # the JSON round-trip). Convergence + thermo cannot be read here
            # without staging the JSON back -- see the note in the summary table.
            results.append(_extract_properties(name, raw, job, inline=True))
    finally:
        backend.shutdown()

    # ── 3. Report ──────────────────────────────────────────────────────
    print("\n[3/3] Results (remote-path mode -- full JSON stays on the HPC):")
    print_summary(
        results,
        title=f"Globus Transfer + Compute thermo screen (device={args.device})",
    )
    csv_path = write_csv(results, output_dir / "demo_globus_transfer.csv")
    print(f"CSV (per-call status; thermo values blank in remote-path mode): {csv_path}")
    print(
        f"\nNote: workers wrote full JSON results under {transfer.remote_directory} "
        f"on the HPC. To pull them back, you can run another Globus Transfer "
        f"job in the reverse direction."
    )
def main() -> None:
    """Parse CLI options and run the LocalBackend agent demo.

    The query is ``--query`` when supplied, otherwise the prompt returned by
    ``agent_prompt`` for the selected device. ``amain`` is then driven to
    completion with ``asyncio.run``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    # (flags, keyword options) for each argument, registered in order.
    option_table = (
        (
            ("--model",),
            {
                "default": "argo:gpt-4o",
                "help": "LLM model name (default: argo:gpt-4o). Try argo:gpt-4o, claude-sonnet-4-6, gpt-4o.",
            },
        ),
        (
            ("--device",),
            {
                "default": "cpu",
                "help": "MACE device passed to the agent prompt (default: cpu)",
            },
        ),
        (
            ("--query",),
            {
                "default": None,
                "help": "Override the natural-language query (default: 5-molecule thermo screen)",
            },
        ),
        (
            ("-v", "--verbose"),
            {
                "action": "count",
                "default": 0,
                "help": "Increase verbosity (-v INFO, -vv DEBUG).",
            },
        ),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    opts = cli.parse_args()

    if opts.query:
        prompt = opts.query
    else:
        prompt = agent_prompt(device=opts.device)
    asyncio.run(amain(opts.model, opts.device, prompt, opts.verbose))
def main() -> None:
    """Run the thermo screen on a ``local`` backend and report the results.

    Submits the requested molecules through ``submit_and_collect`` (file
    mode, ``inline=False``, 1200 s timeout), always shuts the backend down,
    then writes ``demo_local.csv`` under ``--output-dir`` and prints the
    summary table plus the output locations.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--output-dir",
        default="demo_local_out",
        help="Where per-molecule JSON + CSV land (default: ./demo_local_out)",
    )
    cli.add_argument(
        "--molecules",
        nargs="+",
        default=MOLECULE_NAMES,
        help=f"Subset to run (default: {MOLECULE_NAMES})",
    )
    cli.add_argument(
        "--device",
        default="cpu",
        help="MACE device (default: cpu; local Mac/CPU)",
    )
    opts = cli.parse_args()

    # Imported after argument parsing, as in the original script.
    from chemgraph.execution.config import get_backend

    local_backend = get_backend(backend_name="local", system="local")
    try:
        rows = submit_and_collect(
            local_backend,
            molecule_names=opts.molecules,
            device=opts.device,
            output_dir=opts.output_dir,
            inline=False,
            timeout=1200,
        )
    finally:
        local_backend.shutdown()

    out_root = Path(opts.output_dir)
    csv_file = write_csv(rows, out_root / "demo_local.csv")
    print_summary(rows, title=f"Local backend thermo screen ({opts.device})")
    print(f"CSV written to: {csv_file}")
    print(f"Per-molecule JSON written under: {Path(opts.output_dir).resolve()}")
+ python scripts/demo/demo_parsl_in_job_agent.py --model gpt-4o-mini + python scripts/demo/demo_parsl_in_job_agent.py --device xpu --model argo:gpt-4o +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import logging +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import agent_prompt + +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_mcp_adapters.tools import load_mcp_tools + +from chemgraph.agent.llm_agent import ChemGraph + + +def _abort(msg: str) -> None: + print(f"[ABORT] {msg}") + sys.exit(2) + + +async def amain(model: str, system: str, device: str, query: str, verbose: int) -> None: + if verbose: + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") + logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG) + + python = sys.executable + env = { + "CHEMGRAPH_EXECUTION_BACKEND": "parsl", + "COMPUTE_SYSTEM": system, + "PATH": os.environ.get("PATH", ""), + "HOME": os.environ.get("HOME", ""), + "VIRTUAL_ENV": os.environ.get("VIRTUAL_ENV", ""), + "PBS_NODEFILE": os.environ.get("PBS_NODEFILE", ""), + "PBS_O_WORKDIR": os.environ.get("PBS_O_WORKDIR", ""), + } + server_configs = { + "ChemGraph MACE (Parsl)": { + "transport": "stdio", + "command": python, + "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc"], + "env": env, + }, + } + + print(f"LLM model: {model}") + print(f"System: {system}") + print(f"Device: {device}\n") + print("Query:\n" + "-" * 60) + print(query) + print("-" * 60 + "\n") + + client = MultiServerMCPClient(server_configs) + async with contextlib.AsyncExitStack() as stack: + session = await stack.enter_async_context(client.session("ChemGraph MACE (Parsl)")) + tools = await load_mcp_tools(session) + print(f"Loaded {len(tools)} MCP tools: {[t.name for t in tools]}\n") + + cg = ChemGraph( + model_name=model, + 
def main() -> None:
    """Validate the PBS/system settings and run the Parsl agent demo.

    Calls ``_abort`` when ``PBS_NODEFILE`` is unset or empty, when no system
    is given (``--system`` or ``COMPUTE_SYSTEM``), or when the system is not
    ``polaris`` or ``aurora``. The device defaults to ``xpu`` on aurora and
    ``cuda`` otherwise; the query defaults to ``agent_prompt`` for that
    device. ``amain`` is then driven to completion with ``asyncio.run``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--model", default="gpt-4o-mini")
    cli.add_argument("--system", default=os.environ.get("COMPUTE_SYSTEM"))
    cli.add_argument("--device", default=None)
    cli.add_argument("--query", default=None)
    cli.add_argument("-v", "--verbose", action="count", default=0)
    opts = cli.parse_args()

    nodefile = os.environ.get("PBS_NODEFILE")
    if not nodefile:
        _abort("PBS_NODEFILE not set. Run inside `qsub -I`.")
    if not opts.system:
        _abort("COMPUTE_SYSTEM env var not set and --system not given.")

    system = opts.system.lower().strip()
    if system != "polaris" and system != "aurora":
        _abort(f"Unsupported --system: {system!r}")

    if opts.device:
        device = opts.device
    elif system == "aurora":
        device = "xpu"
    else:
        device = "cuda"

    prompt = opts.query if opts.query else agent_prompt(device=device)
    asyncio.run(amain(opts.model, system, device, prompt, opts.verbose))
def main() -> None:
    """Run the thermo screen through the Parsl backend inside a PBS job.

    Calls ``_abort`` when ``PBS_NODEFILE`` is unset or does not name an
    existing file, when no system is given (``--system`` or
    ``COMPUTE_SYSTEM``), or when the system is not ``polaris`` or ``aurora``.
    The device defaults to ``xpu`` on aurora and ``cuda`` otherwise.

    Parsl run directory:
        * ``--run-dir`` given: that path is used as-is.
        * otherwise ``$PBS_O_WORKDIR/parsl_demo_runs`` when ``PBS_O_WORKDIR``
          is set, else ``./parsl_demo_runs``.

    The directory is created if needed. The backend is always shut down
    after submission; results are written to ``demo_parsl.csv`` under
    ``--output-dir`` and summarised on stdout.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--system",
        default=os.environ.get("COMPUTE_SYSTEM"),
        help="polaris | aurora (default: $COMPUTE_SYSTEM)",
    )
    parser.add_argument("--device", default=None, help="cuda (Polaris) | xpu (Aurora)")
    parser.add_argument("--output-dir", default="demo_parsl_out")
    parser.add_argument("--molecules", nargs="+", default=MOLECULE_NAMES)
    parser.add_argument(
        "--run-dir",
        default=None,
        help="Parsl run_dir (default: $PBS_O_WORKDIR/parsl_demo_runs or ./parsl_demo_runs).",
    )
    parser.add_argument("--timeout", type=float, default=6000.0)
    args = parser.parse_args()

    pbs_nodefile = os.environ.get("PBS_NODEFILE")
    if not pbs_nodefile or not Path(pbs_nodefile).is_file():
        _abort("PBS_NODEFILE not set or missing. Run inside `qsub -I`.")
    if not args.system:
        _abort("COMPUTE_SYSTEM env var not set and --system not given.")
    system = args.system.lower().strip()
    if system not in ("polaris", "aurora"):
        _abort(f"Unsupported --system: {system!r}")
    device = args.device or ("xpu" if system == "aurora" else "cuda")

    # An explicit --run-dir is honoured verbatim, as its help text promises;
    # only the implicit defaults get the "parsl_demo_runs" suffix.
    if args.run_dir:
        run_dir_path = Path(args.run_dir)
    else:
        pbs_workdir = os.environ.get("PBS_O_WORKDIR")
        base_dir = Path(pbs_workdir) if pbs_workdir else Path.cwd()
        run_dir_path = base_dir / "parsl_demo_runs"
    run_dir_path.mkdir(parents=True, exist_ok=True)
    run_dir = str(run_dir_path)

    print(f"system={system} device={device} run_dir={run_dir}")

    from chemgraph.execution.config import get_backend

    backend = get_backend(backend_name="parsl", system=system, run_dir=run_dir)
    try:
        results = submit_and_collect(
            backend,
            molecule_names=args.molecules,
            device=device,
            output_dir=args.output_dir,
            inline=False,
            timeout=args.timeout,
        )
    finally:
        backend.shutdown()

    csv_path = write_csv(results, Path(args.output_dir) / "demo_parsl.csv")
    print_summary(results, title=f"Parsl thermo screen (system={system}, device={device})")
    print(f"CSV: {csv_path}")
b/scripts/demo/structures/ethanol.xyz @@ -0,0 +1,11 @@ +9 +ethanol +C -0.7480000 0.0150000 -0.0240000 +C 0.6850000 -0.4020000 0.2730000 +O 1.5670000 0.5140000 -0.3270000 +H -0.9270000 0.0430000 -1.1050000 +H -1.4340000 -0.7030000 0.4430000 +H -0.9480000 1.0200000 0.3500000 +H 0.8400000 -0.4500000 1.3580000 +H 0.8640000 -1.3950000 -0.1490000 +H 2.4640000 0.1980000 -0.1140000 diff --git a/scripts/demo/structures/methane.xyz b/scripts/demo/structures/methane.xyz new file mode 100644 index 00000000..89690609 --- /dev/null +++ b/scripts/demo/structures/methane.xyz @@ -0,0 +1,7 @@ +5 +methane +C 0.0000000 0.0000000 0.0000000 +H 0.6290000 0.6290000 0.6290000 +H -0.6290000 -0.6290000 0.6290000 +H -0.6290000 0.6290000 -0.6290000 +H 0.6290000 -0.6290000 -0.6290000 diff --git a/scripts/demo/structures/water.xyz b/scripts/demo/structures/water.xyz new file mode 100644 index 00000000..03120dab --- /dev/null +++ b/scripts/demo/structures/water.xyz @@ -0,0 +1,5 @@ +3 +water +O 0.0000000 0.0000000 0.0000000 +H 0.7570000 0.5860000 0.0000000 +H -0.7570000 0.5860000 0.0000000 diff --git a/scripts/smoke/README.md b/scripts/smoke/README.md new file mode 100644 index 00000000..e52f2ee3 --- /dev/null +++ b/scripts/smoke/README.md @@ -0,0 +1,125 @@ +# ChemGraph execution-layer smoke tests + +Self-contained scripts that exercise each ExecutionBackend live and emit +`[PASS]` / `[FAIL]` per check. Exit code is `0` only if every check passes +(`2` if required env vars are missing → "skip"). Use them for one-shot +validation after install, after a rebase, or before running real workloads. + +These are *not* pytest tests — they hit live infrastructure (process pools, +PBS allocations, Globus Compute endpoints, Globus Transfer). The mocked +unit suite still lives at `tests/test_execution.py`. 
+ +## Script matrix + +| Script | Runs where | Backends | Live deps | +|--------|------------|----------|-----------| +| [`smoke_local.py`](smoke_local.py) | laptop | `local` | none | +| [`smoke_globus_compute.py`](smoke_globus_compute.py) | laptop | `globus_compute` | live GC endpoint | +| [`smoke_globus_transfer.py`](smoke_globus_transfer.py) | laptop | `GlobusTransferManager` (+ optional `globus_compute` MCP) | Globus collections on both ends | +| [`smoke_parsl_in_job.py`](smoke_parsl_in_job.py) | inside `qsub -I` on Polaris/Aurora | `parsl` | PBS allocation | +| [`smoke_ensemble_launcher_in_job.py`](smoke_ensemble_launcher_in_job.py) | inside `qsub -I` on Polaris/Aurora | `ensemble_launcher` (managed + client-only) | PBS allocation, `ensemble_launcher` built from source | + +`_smoke_utils.py` holds shared helpers (`SmokeReporter`, picklable trivial +callables). `water.xyz` is the shared 3-atom fixture. + +## Environment-variable matrix + +| Variable | Required by | Notes | +|----------|-------------|-------| +| `GLOBUS_COMPUTE_ENDPOINT_ID` | `smoke_globus_compute.py`, `smoke_globus_transfer.py --with-mcp` | UUID printed by `globus-compute-endpoint start chemgraph-` | +| `GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID` | `smoke_globus_transfer.py` | Globus Connect Personal collection on the laptop | +| `GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID` | `smoke_globus_transfer.py` | HPC collection UUID (ALCF data portal) | +| `GLOBUS_TRANSFER_DESTINATION_BASE_PATH` | `smoke_globus_transfer.py` | e.g. 
`/eagle/projects/MyProj/staging` (Polaris), `/flare/projects/MyProj/staging` (Aurora) | +| `COMPUTE_SYSTEM` | `smoke_parsl_in_job.py`, `smoke_ensemble_launcher_in_job.py` | `polaris` or `aurora` | +| `PBS_NODEFILE` | both in-job scripts | Set automatically by PBS inside `qsub` — the scripts abort if missing | +| `CG_SMOKE_DEVICE` | optional, MACE device override | Defaults: `cuda` (Polaris/Globus), `xpu` (Aurora) | + +## Running + +### Laptop only (no creds) + +```bash +source .cg_env/bin/activate +python scripts/smoke/smoke_local.py # ~5s + first-run MACE model download +python scripts/smoke/smoke_local.py --quick # ~3s, skips MACE +``` + +### Laptop → live Globus Compute endpoint + +```bash +export GLOBUS_COMPUTE_ENDPOINT_ID="your-endpoint-uuid" +export COMPUTE_SYSTEM=polaris # or aurora +python scripts/smoke/smoke_globus_compute.py +python scripts/smoke/smoke_globus_compute.py --amqp 443 # Aurora (5671 blocked) +``` + +### Laptop → live Globus Transfer + +```bash +export GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID="your-source-collection-uuid" +export GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID="your-destination-collection-uuid" +export GLOBUS_TRANSFER_DESTINATION_BASE_PATH=/eagle/projects/MyProj/staging +python scripts/smoke/smoke_globus_transfer.py # transfer only +python scripts/smoke/smoke_globus_transfer.py --with-mcp # also dispatch MACE ensemble in remote-path mode +``` + +First run triggers an OAuth flow; the token caches at +`~/.globus/chemgraph_transfer_tokens.json` for subsequent runs.
+ +### Inside a PBS allocation on Polaris + +```bash +qsub -I -A MyProj -l select=1 -l walltime=01:00:00 -q debug -l filesystems=home:eagle +# (now on the compute node) +module load conda +conda activate base +source ~/chemgraph/venv/bin/activate +export COMPUTE_SYSTEM=polaris +cd ~/chemgraph/ChemGraph + +python scripts/smoke/smoke_parsl_in_job.py +python scripts/smoke/smoke_ensemble_launcher_in_job.py --mode managed +``` + +### Inside a PBS allocation on Aurora + +```bash +qsub -I -A MyProj -l select=1,walltime=01:00:00 -q debug -l filesystems=home:flare +module load frameworks +source ~/chemgraph/venv/bin/activate +export COMPUTE_SYSTEM=aurora +cd ~/chemgraph/ChemGraph + +python scripts/smoke/smoke_parsl_in_job.py --device xpu +python scripts/smoke/smoke_ensemble_launcher_in_job.py --mode managed --device xpu +``` + +### EnsembleLauncher client-only mode + +Exercises `EnsembleLauncherBackend(client_only=True, ...)` introduced in +commit `bc54083c`. Requires two shells on the same compute node: + +```bash +# Shell A — start the orchestrator +cd $PBS_O_WORKDIR +python -m ensemble_launcher \ + --system $COMPUTE_SYSTEM \ + --checkpoint-dir $PBS_O_WORKDIR/el_ckpt \ + --node-id 0 + +# Shell B — connect this client to it +python scripts/smoke/smoke_ensemble_launcher_in_job.py \ + --mode client-only \ + --checkpoint-dir $PBS_O_WORKDIR/el_ckpt \ + --node-id 0 +``` + +The client-only run leaves the orchestrator in Shell A running; stop it +there with `Ctrl-C` when done.
class SmokeReporter:
    """Collect PASS/FAIL outcomes for a smoke script and derive its exit code.

    Each check is wrapped in ``with reporter.check("name"):``. A block that
    raises an ``Exception`` is counted as failed, reported with its traceback,
    and the exception is swallowed so later checks still run. A block that
    finishes normally is counted as passed.
    """

    def __init__(self, title: str) -> None:
        self.title = title
        self.passed = 0
        self.failed = 0
        # Start of the whole run, used for the total in the summary line.
        self._t0 = time.monotonic()
        print(f"\n=== {title} ===")

    @contextmanager
    def check(self, name: str):
        """Time the wrapped block and record it as a pass or a failure."""
        began = time.monotonic()
        try:
            yield
        except Exception as err:
            took = time.monotonic() - began
            self.failed += 1
            print(f"[FAIL] {name} ({took:.1f}s): {type(err).__name__}: {err}")
            traceback.print_exc()
            # Returning here ends the generator, which suppresses the error.
            return
        took = time.monotonic() - began
        self.passed += 1
        print(f"[PASS] {name} ({took:.1f}s)")

    def summary_and_exit(self) -> None:
        """Print the totals and exit: status 0 if nothing failed, else 1."""
        checks_run = self.passed + self.failed
        total_wall = time.monotonic() - self._t0
        print(
            f"\n--- {self.title}: {self.passed}/{checks_run} passed, "
            f"{self.failed} failed ({total_wall:.1f}s total) ---"
        )
        sys.exit(1 if self.failed else 0)
def trivial_env_probe() -> dict:
    """Return a small dict describing the process that runs this function.

    Always present: ``hostname``, ``python`` (version string), ``pid``,
    ``cwd`` and ``sched_affinity`` (sorted CPU ids, or ``None`` when
    ``os.sched_getaffinity`` is unavailable or fails). When ``torch`` can be
    imported, ``torch``, ``cuda_devices`` and ``xpu_devices`` are added; any
    exception during that step is recorded under ``torch_error`` instead.
    """
    # All imports stay inside the function body, as in the original.
    import os
    import socket
    import sys

    report: dict = {}
    report["hostname"] = socket.gethostname()
    report["python"] = sys.version.split()[0]
    report["pid"] = os.getpid()
    report["cwd"] = os.getcwd()

    try:
        cpu_ids = sorted(os.sched_getaffinity(0))
    except (AttributeError, OSError):
        cpu_ids = None
    report["sched_affinity"] = cpu_ids

    try:
        import torch

        report["torch"] = torch.__version__
        if torch.cuda.is_available():
            report["cuda_devices"] = torch.cuda.device_count()
        else:
            report["cuda_devices"] = 0
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            report["xpu_devices"] = torch.xpu.device_count()
        else:
            report["xpu_devices"] = 0
    except Exception as exc:
        report["torch_error"] = str(exc)
    return report
+ +Two modes +--------- + +``--mode managed`` (default) + The script starts and tears down the EnsembleLauncher orchestrator + in-process via ``get_backend(backend_name="ensemble_launcher", ...)``. + +``--mode client-only`` *(exercises commit bc54083c)* + In a **second shell on the same compute node**, first start the + orchestrator yourself, e.g.:: + + # second shell + python -m ensemble_launcher \\ + --system $COMPUTE_SYSTEM \\ + --checkpoint-dir $PBS_O_WORKDIR/el_ckpt \\ + --node-id 0 + + Then run this script with ``--mode client-only --checkpoint-dir + $PBS_O_WORKDIR/el_ckpt``. It connects to the running orchestrator + via ``ClusterClient`` rather than starting its own. + +Usage +----- +:: + + export COMPUTE_SYSTEM=polaris # or aurora + python scripts/smoke/smoke_ensemble_launcher_in_job.py --mode managed + python scripts/smoke/smoke_ensemble_launcher_in_job.py \\ + --mode client-only --checkpoint-dir $PBS_O_WORKDIR/el_ckpt +""" + +from __future__ import annotations + +import argparse +import os +import sys +import time +from pathlib import Path + +from _smoke_utils import ( + SmokeReporter, + trivial_add, + trivial_hostname, + trivial_square, + water_xyz_path, +) + + +def _abort(msg: str) -> None: + print(f"[ABORT] {msg}") + sys.exit(2) + + +def _wait_for_checkpoint(checkpoint_dir: Path, timeout: float) -> None: + """Wait until the orchestrator has written something to checkpoint_dir. + + The exact ready-marker shape depends on the ensemble_launcher + version; we just wait for the directory to be non-empty. + """ + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if checkpoint_dir.is_dir() and any(checkpoint_dir.iterdir()): + return + time.sleep(1.0) + _abort( + f"No checkpoint files appeared under {checkpoint_dir} within {timeout}s. " + "Start the orchestrator in another shell first; see --help." 
+ ) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--mode", + choices=("managed", "client-only"), + default="managed", + ) + parser.add_argument( + "--system", + default=os.environ.get("COMPUTE_SYSTEM"), + help="polaris | aurora | local (default: $COMPUTE_SYSTEM)", + ) + parser.add_argument( + "--checkpoint-dir", + default=None, + help="(client-only) path the externally-started orchestrator writes to.", + ) + parser.add_argument( + "--node-id", + type=int, + default=0, + help="(client-only) node id assigned by the orchestrator (default 0).", + ) + parser.add_argument( + "--device", + default=None, + help="MACE device: cuda | xpu | cpu (default: cuda on polaris, xpu on aurora)", + ) + parser.add_argument( + "--quick", + action="store_true", + help="Skip MACE inference.", + ) + parser.add_argument( + "--wait-timeout", + type=float, + default=60.0, + help="(client-only) seconds to wait for orchestrator checkpoint to appear.", + ) + args = parser.parse_args() + + pbs_nodefile = os.environ.get("PBS_NODEFILE") + if not pbs_nodefile and args.system not in (None, "local"): + _abort( + "PBS_NODEFILE not set. Run inside a PBS allocation, or use --system local." + ) + + if not args.system: + _abort("COMPUTE_SYSTEM env var not set and --system not given.") + system = args.system.lower().strip() + if system not in ("polaris", "aurora", "local"): + _abort(f"Unsupported --system: {system!r}") + + device = args.device or ("xpu" if system == "aurora" else "cuda") + + try: + import ensemble_launcher # noqa: F401 + except ImportError as exc: + _abort( + f"ensemble_launcher is not importable: {exc}. " + "On HPC, install it via scripts/hpc_setup/install_remote.sh." 
+ ) + + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.config import get_backend + + r = SmokeReporter( + f"smoke_ensemble_launcher_in_job (mode={args.mode}, system={system})" + ) + backend = None + + if args.mode == "managed": + with r.check("get_backend(ensemble_launcher, managed) initialises"): + backend = get_backend(backend_name="ensemble_launcher", system=system) + assert backend is not None + else: + if not args.checkpoint_dir: + _abort("--mode client-only requires --checkpoint-dir.") + ckpt = Path(args.checkpoint_dir).resolve() + with r.check( + f"orchestrator checkpoint dir is populated ({ckpt})" + ): + _wait_for_checkpoint(ckpt, args.wait_timeout) + with r.check("get_backend(ensemble_launcher, client_only=True) connects"): + backend = get_backend( + backend_name="ensemble_launcher", + system=system, + client_only=True, + checkpoint_dir=str(ckpt), + node_id=args.node_id, + ) + assert backend is not None + + if backend is None: + r.summary_and_exit() + return + + with r.check("python TaskSpec returns correct result"): + fut = backend.submit( + TaskSpec( + task_id="el-py", + task_type="python", + callable=trivial_square, + args=(11,), + ) + ) + assert fut.result(timeout=180) == 121 + + with r.check("python TaskSpec ran on a compute node"): + fut = backend.submit( + TaskSpec( + task_id="el-host", + task_type="python", + callable=trivial_hostname, + ) + ) + host = fut.result(timeout=180) + print(f" EL worker hostname = {host!r}") + + with r.check("shell TaskSpec runs"): + fut = backend.submit( + TaskSpec( + task_id="el-sh", + task_type="shell", + command="echo smoke_el_shell_ok", + ) + ) + # EL shell-task return shape depends on the version; just assert + # the future resolves without raising. 
+ fut.result(timeout=180) + + with r.check("submit_batch of 3 python tasks all resolve"): + futures = backend.submit_batch( + [ + TaskSpec( + task_id=f"el-batch-{i}", + task_type="python", + callable=trivial_add, + args=(i, 50), + ) + for i in range(3) + ] + ) + results = [f.result(timeout=240) for f in futures] + assert results == [50, 51, 52], results + + if not args.quick: + with r.check(f"MACE geometry opt on water (device={device}, converged)"): + from ase.io import read as ase_read + + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(str(water_xyz_path())) + inline = atoms_to_atomsdata(atoms).model_dump() + job = { + "input_structure_file": "ignored_by_inline_path", + "output_result_file": "water_smoke_el.json", + "driver": "opt", + "model": "medium-mpa-0", + "device": device, + "temperature": 298.15, + "pressure": 101325.0, + "fmax": 0.01, + "steps": 100, + "optimizer": "lbfgs", + "inline_structure": inline, + } + fut = backend.submit( + TaskSpec( + task_id="el-mace-opt", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + ) + out = fut.result(timeout=900) + assert out.get("status") == "success", f"opt failed: {out}" + energy = next( + (out[k] for k in ("single_point_energy", "energy", "final_energy") if k in out), + None, + ) + assert energy is not None and energy < 0, f"bad MACE result: {out}" + full = out.get("full_output") or {} + if full: + assert full.get("converged") is True, f"opt did not converge: {full.get('converged')!r}" + print( + f" water opt energy = {energy:.6f} eV " + f"(converged={full.get('converged')}, wall={full.get('wall_time')}s)" + ) + else: + print( + f" water opt energy = {energy:.6f} eV " + "(WARNING: full_output missing; convergence not verified inline)" + ) + + with r.check("backend.shutdown() is clean"): + if args.mode == "managed": + backend.shutdown() + else: + # In client-only mode, shutdown should NOT stop the 
orchestrator + # the user started -- it should only disconnect this client. + backend.shutdown() + print(" (client-only: orchestrator left running in the other shell)") + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/smoke_globus_compute.py b/scripts/smoke/smoke_globus_compute.py new file mode 100644 index 00000000..80226972 --- /dev/null +++ b/scripts/smoke/smoke_globus_compute.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python +"""Smoke test for the GlobusComputeBackend. + +Drives the production execution layer against a live Globus Compute +endpoint. Exits 0 on success, nonzero on any failure. + +Prereqs (env vars) +------------------ +- GLOBUS_COMPUTE_ENDPOINT_ID -- UUID printed by ``globus-compute-endpoint start``. +- (optional) COMPUTE_SYSTEM -- "polaris" or "aurora" (used for logging only). + +Run:: + + export GLOBUS_COMPUTE_ENDPOINT_ID="" + python scripts/smoke/smoke_globus_compute.py + python scripts/smoke/smoke_globus_compute.py --quick # skip MACE + python scripts/smoke/smoke_globus_compute.py --amqp 443 # firewalled networks +""" + +from __future__ import annotations + +import argparse +import os + +from _smoke_utils import ( + SmokeReporter, + require_env, + trivial_add, + trivial_env_probe, + trivial_hostname, +) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--quick", + action="store_true", + help="Skip MACE inference (Globus model download on remote endpoint is slow on first run).", + ) + parser.add_argument( + "--amqp", + type=int, + default=None, + help="AMQP port override. 
Set to 443 when outbound 5671 is blocked (Aurora).", + ) + args = parser.parse_args() + + require_env("GLOBUS_COMPUTE_ENDPOINT_ID") + + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.config import get_backend + + backend_kwargs: dict = {} + if args.amqp is not None: + backend_kwargs["amqp_port"] = args.amqp + + r = SmokeReporter( + f"smoke_globus_compute (system={os.environ.get('COMPUTE_SYSTEM', '?')}, " + f"endpoint={os.environ['GLOBUS_COMPUTE_ENDPOINT_ID'][:8]}...)" + ) + backend = None + local_hostname = trivial_hostname() + + with r.check("get_backend(globus_compute) initialises"): + backend = get_backend(backend_name="globus_compute", **backend_kwargs) + assert backend is not None + + if backend is None: + r.summary_and_exit() + return + + with r.check("check_endpoint_status() reports online"): + status = backend.check_endpoint_status() + # The SDK returns either a dict like {"status": "online"} or a + # string; both shapes count as healthy if "online" appears in the + # repr. "error" status means we cannot reach the endpoint. 
+ s = status.get("status") + assert s != "error", f"endpoint unreachable: {status}" + s_repr = str(s).lower() + assert "online" in s_repr or "ok" in s_repr or "running" in s_repr, ( + f"endpoint not online: {status}" + ) + print(f" endpoint status: {status}") + + with r.check("python TaskSpec (trivial_add) round-trips through Globus"): + fut = backend.submit( + TaskSpec( + task_id="gc-add", + task_type="python", + callable=trivial_add, + args=(40, 2), + ) + ) + result = fut.result(timeout=300) + assert result == 42, f"expected 42, got {result!r}" + + with r.check("python TaskSpec ran on the HPC node (hostname differs from laptop)"): + fut = backend.submit( + TaskSpec( + task_id="gc-host", + task_type="python", + callable=trivial_hostname, + ) + ) + remote_host = fut.result(timeout=300) + assert isinstance(remote_host, str) and remote_host, "empty hostname" + assert remote_host != local_hostname, ( + f"task ran on the laptop ({remote_host}), not the endpoint!" + ) + print(f" local={local_hostname!r} remote={remote_host!r}") + + with r.check("env probe: torch + accelerators visible on worker"): + fut = backend.submit( + TaskSpec( + task_id="gc-env", + task_type="python", + callable=trivial_env_probe, + ) + ) + info = fut.result(timeout=300) + assert isinstance(info, dict) + print(f" worker env: {info}") + + with r.check("shell TaskSpec returns SDK ShellResult"): + fut = backend.submit( + TaskSpec( + task_id="gc-sh", + task_type="shell", + command="echo smoke_globus_compute_shell_ok && hostname", + ) + ) + sh = fut.result(timeout=300) + # ShellFunction returns a ShellResult object with .stdout + stdout = getattr(sh, "stdout", str(sh)) + assert "smoke_globus_compute_shell_ok" in stdout, f"unexpected stdout: {stdout!r}" + print(f" remote shell stdout (truncated): {stdout[:120]!r}") + + with r.check("submit_batch of 3 python tasks all resolve"): + futures = backend.submit_batch( + [ + TaskSpec( + task_id=f"gc-batch-{i}", + task_type="python", + callable=trivial_add, + 
args=(i, 10), + ) + for i in range(3) + ] + ) + results = [f.result(timeout=300) for f in futures] + assert results == [10, 11, 12], f"expected [10,11,12], got {results}" + + if not args.quick: + with r.check("MACE geometry opt on water runs on Globus Compute (converged)"): + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + + # Worker pulls the structure from its own filesystem. Since + # the laptop's water.xyz is not on the HPC node, embed it + # inline the same way the pre-submit hook would. The + # ``full_output`` key in the result carries the on-disk JSON + # back to us (mace_mcp_hpc._mace_worker, lines 127-131) so + # we can check converged without a follow-up transfer. + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + from _smoke_utils import water_xyz_path + + atoms = ase_read(str(water_xyz_path())) + inline = atoms_to_atomsdata(atoms).model_dump() + + job = { + "input_structure_file": "ignored_by_inline_path", + "output_result_file": "water_smoke_gc.json", + "driver": "opt", + "model": "medium-mpa-0", + "device": os.environ.get("CG_SMOKE_DEVICE", "cuda"), + "temperature": 298.15, + "pressure": 101325.0, + "fmax": 0.01, + "steps": 100, + "optimizer": "lbfgs", + "inline_structure": inline, + } + fut = backend.submit( + TaskSpec( + task_id="gc-mace-water-opt", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + ) + # First MACE run on the endpoint downloads the model + opt loop. 
+ mace_out = fut.result(timeout=6000) + assert isinstance(mace_out, dict), type(mace_out) + assert mace_out.get("status") == "success", f"opt failed: {mace_out}" + energy = next( + (mace_out[k] for k in ("single_point_energy", "energy", "final_energy") if k in mace_out), + None, + ) + assert energy is not None and energy < 0, f"bad energy: {mace_out}" + + full = mace_out.get("full_output") or {} + if full: + assert full.get("converged") is True, f"opt did not converge: {full.get('converged')!r}" + assert full.get("success") is True, f"opt success=False: {full}" + print( + f" remote opt energy = {energy:.6f} eV " + f"(converged={full.get('converged')}, wall={full.get('wall_time')}s)" + ) + else: + # full_output is attached by _mace_worker only when inline_structure + # is set; we always pass inline above so this branch should not hit. + print( + f" remote opt energy = {energy:.6f} eV " + "(WARNING: full_output not returned; convergence not verified)" + ) + + with r.check("backend.shutdown() is clean"): + backend.shutdown() + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/smoke_globus_transfer.py b/scripts/smoke/smoke_globus_transfer.py new file mode 100644 index 00000000..5126c60c --- /dev/null +++ b/scripts/smoke/smoke_globus_transfer.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +"""Smoke test for GlobusTransferManager (+ optional MCP integration). + +Exercises the production transfer layer from the laptop. Exits 0 on +success, nonzero on any failure. + +Prereqs (env vars) +------------------ +- GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID -- local Globus collection UUID +- GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID -- HPC collection UUID +- GLOBUS_TRANSFER_DESTINATION_BASE_PATH -- e.g. /eagle/projects/MyProj/staging +- (for --with-mcp): GLOBUS_COMPUTE_ENDPOINT_ID and HPC venv with MACE + +First run triggers a Globus OAuth flow. Token caches at +~/.globus/chemgraph_transfer_tokens.json. 
+ +Run:: + + python scripts/smoke/smoke_globus_transfer.py + python scripts/smoke/smoke_globus_transfer.py --keep-remote # don't delete after + python scripts/smoke/smoke_globus_transfer.py --with-mcp # also exercise MCP ensemble in remote mode +""" + +from __future__ import annotations + +import argparse +import os +import time + +from _smoke_utils import SmokeReporter, require_env, water_xyz_path + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--with-mcp", + action="store_true", + help="Also exercise mace_mcp_hpc.run_mace_ensemble(remote_structure_directory=...).", + ) + parser.add_argument( + "--keep-remote", + action="store_true", + help="Don't attempt to delete the staged remote directory at the end.", + ) + parser.add_argument( + "--timeout", + type=float, + default=6000.0, + help="Per-transfer timeout in seconds (default 6000).", + ) + args = parser.parse_args() + + require_env( + "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID", + "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID", + "GLOBUS_TRANSFER_DESTINATION_BASE_PATH", + ) + if args.with_mcp: + require_env("GLOBUS_COMPUTE_ENDPOINT_ID") + + from chemgraph.execution.config import get_transfer_manager + + r = SmokeReporter("smoke_globus_transfer") + mgr = None + transfer_result = None + + with r.check("get_transfer_manager() returns a configured manager"): + mgr = get_transfer_manager() + assert mgr is not None, ( + "get_transfer_manager returned None -- check env vars are exported." 
+ ) + + if mgr is None: + r.summary_and_exit() + return + + with r.check("transfer_files(water.xyz) submits a Globus Transfer task"): + xyz = water_xyz_path() + assert xyz.is_file(), f"fixture missing: {xyz}" + transfer_result = mgr.transfer_files( + local_paths=[str(xyz)], + label=f"chemgraph-smoke-{int(time.time())}", + ) + assert transfer_result.task_id, "no task_id returned" + print(f" task_id = {transfer_result.task_id}") + print(f" remote_dir = {transfer_result.remote_directory}") + + with r.check(f"wait_for_transfer(timeout={args.timeout}s) reaches SUCCEEDED"): + assert transfer_result is not None + status = mgr.wait_for_transfer( + transfer_result.task_id, + timeout=args.timeout, + poll_interval=5, + ) + assert status.get("status") == "SUCCEEDED", f"final status: {status}" + assert status.get("files_transferred", 0) >= 1, status + print( + f" transferred {status['files_transferred']}/{status['files']} files, " + f"{status['bytes_transferred']} bytes" + ) + + with r.check("check_transfer_status() returns SUCCEEDED for completed task"): + assert transfer_result is not None + status = mgr.check_transfer_status(transfer_result.task_id) + assert status["status"] == "SUCCEEDED", status + + with r.check("list_remote_directory() finds the staged file"): + assert transfer_result is not None + entries = mgr.list_remote_directory(transfer_result.remote_directory) + names = {e["name"] for e in entries} + assert "water.xyz" in names, f"water.xyz not in {names!r}" + size = next((e["size"] for e in entries if e["name"] == "water.xyz"), 0) + print(f" remote water.xyz size = {size} bytes") + + if args.with_mcp: + with r.check("MCP run_mace_ensemble(remote_structure_directory=...) succeeds"): + # Drive the MCP server's tool function directly (in-process) -- + # the heavy work is dispatched to Globus Compute by the + # backend that mcp.init_backend() configured. 
+ from chemgraph.mcp.mace_mcp_hpc import ( + _expand_mace_ensemble, + _mace_worker, + mcp, + ) + from chemgraph.execution.base import TaskSpec + from chemgraph.schemas.mace_parsl_schema import ( + mace_input_schema_ensemble, + ) + + # Init the MCP server's backend (reads CHEMGRAPH_EXECUTION_BACKEND + # / GLOBUS_COMPUTE_ENDPOINT_ID exactly like the prod server does). + os.environ.setdefault("CHEMGRAPH_EXECUTION_BACKEND", "globus_compute") + mcp.init_backend() + try: + params = mace_input_schema_ensemble( + remote_structure_directory=transfer_result.remote_directory, + output_result_file="water_smoke_tr.json", + driver="opt", + model="medium-mpa-0", + device=os.environ.get("CG_SMOKE_DEVICE", "cuda"), + ) + jobs = _expand_mace_ensemble(params) + assert jobs, "no jobs expanded from remote dir" + assert all("remote_structure_file" in j for j in jobs), jobs[0] + # Submit each job through the same backend the MCP server uses. + futures = [ + mcp._backend.submit( + TaskSpec( + task_id=f"tr-mace-opt-{i}", + task_type="python", + callable=_mace_worker, + kwargs={"job": j}, + ) + ) + for i, j in enumerate(jobs) + ] + results = [f.result(timeout=6000) for f in futures] + assert all(isinstance(r, dict) for r in results), results + assert all(r.get("status") == "success" for r in results), [ + r.get("status") for r in results + ] + energies = [ + next( + (r[k] for k in ("single_point_energy", "energy", "final_energy") if k in r), + None, + ) + for r in results + ] + assert all(e is not None and e < 0 for e in energies), results + # Remote-path mode does NOT attach full_output (only the + # inline-structure path does -- see mace_mcp_hpc._mace_worker + # lines 127-131). Convergence can be verified after the fact + # by reading the per-structure JSON on the remote filesystem + # (e.g. via Globus Transfer back to the laptop) -- out of + # scope for this smoke test. 
+ print(f" remote MACE opt energies (eV): {energies}") + finally: + mcp.shutdown_backend() + + if not args.keep_remote and transfer_result is not None: + print( + f"\nNOTE: staged directory left in place at {transfer_result.remote_directory}\n" + " (the manager does not implement remote deletion). " + "Clean it up manually if needed." + ) + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/smoke_local.py b/scripts/smoke/smoke_local.py new file mode 100644 index 00000000..6058638e --- /dev/null +++ b/scripts/smoke/smoke_local.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python +"""Smoke test for the LocalBackend. + +Drives the production execution layer end-to-end on the laptop with no +HPC and no credentials. Exits 0 on success, nonzero on any failure. + +Checks +------ +1. ``get_backend(backend_name="local")`` initialises cleanly. +2. Python TaskSpec round-trip (callable returns correct result). +3. Shell TaskSpec round-trip (exit code 0). +4. ``submit_batch`` of three tasks all resolve. +5. ``JobTracker`` register_batch / get_status / get_results round-trip. +6. MACE worker path: build a job dict for ``water.xyz`` and submit it to + the local backend exactly as ``mace_mcp_hpc._mace_transport_hook`` would. 
+ +Run:: + + python scripts/smoke/smoke_local.py + python scripts/smoke/smoke_local.py --quick # skip the MACE check +""" + +from __future__ import annotations + +import argparse + +from _smoke_utils import ( + SmokeReporter, + trivial_add, + trivial_square, + water_xyz_path, +) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--quick", + action="store_true", + help="Skip MACE inference (saves ~30s on first run downloading the model).", + ) + args = parser.parse_args() + + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.config import get_backend + + r = SmokeReporter("smoke_local") + backend = None + + with r.check("get_backend(local) initialises"): + backend = get_backend(backend_name="local", system="local") + assert backend is not None, "backend is None" + + if backend is None: + r.summary_and_exit() + return + + with r.check("python TaskSpec returns correct result"): + fut = backend.submit( + TaskSpec( + task_id="py-1", + task_type="python", + callable=trivial_square, + args=(7,), + ) + ) + result = fut.result(timeout=30) + assert result == 49, f"expected 49, got {result!r}" + + with r.check("shell TaskSpec exits 0"): + fut = backend.submit( + TaskSpec( + task_id="sh-1", + task_type="shell", + command="echo smoke_local_shell_ok", + ) + ) + rc = fut.result(timeout=30) + assert rc == 0, f"expected exit 0, got {rc!r}" + + with r.check("submit_batch of 3 python tasks resolve"): + futures = backend.submit_batch( + [ + TaskSpec( + task_id=f"batch-{i}", + task_type="python", + callable=trivial_add, + args=(i, i + 1), + ) + for i in range(3) + ] + ) + results = [f.result(timeout=30) for f in futures] + assert results == [1, 3, 5], f"expected [1,3,5], got {results}" + + with r.check("JobTracker register_batch / get_results round-trip"): + from chemgraph.execution.job_tracker import JobTracker + + tracker = JobTracker() + fut = backend.submit( + TaskSpec( + task_id="tracked-1", + 
task_type="python", + callable=trivial_square, + args=(6,), + ) + ) + batch_id = tracker.register_batch( + tool_name="smoke_local", + pending_tasks=[({"task_id": "tracked-1"}, fut)], + ) + # Block on the future then ask the tracker for results. + fut.result(timeout=30) + out = tracker.get_results(batch_id) + assert out["status"] == "completed", f"status={out.get('status')}" + assert out["results"][0]["result"] == 36, out["results"] + + if not args.quick: + with r.check("MACE geometry opt: water.xyz on local backend (converged)"): + import json + + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + + xyz = water_xyz_path() + assert xyz.is_file(), f"fixture missing: {xyz}" + out_json = xyz.parent / "water_smoke_output.json" + job = { + "input_structure_file": str(xyz), + "output_result_file": str(out_json), + "driver": "opt", + "model": "medium-mpa-0", + "device": "cpu", + "temperature": 298.15, + "pressure": 101325.0, + "fmax": 0.01, + "steps": 100, + "optimizer": "lbfgs", + } + # Submit through the backend (not in-process) to prove the + # submission pipeline serializes the worker callable and the + # arg dict correctly. + fut = backend.submit( + TaskSpec( + task_id="mace-water-opt", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + ) + # First MACE run downloads the model; allow generous timeout. 
+ mace_out = fut.result(timeout=600) + assert isinstance(mace_out, dict), f"non-dict result: {type(mace_out)}" + assert mace_out.get("status") == "success", f"opt failed: {mace_out}" + energy = next( + (mace_out[k] for k in ("single_point_energy", "energy", "final_energy") if k in mace_out), + None, + ) + assert energy is not None, f"no energy in result keys={list(mace_out)}" + assert energy < 0, f"water energy should be negative, got {energy}" + + assert out_json.is_file(), f"opt output JSON not written: {out_json}" + with open(out_json) as fh: + full = json.load(fh) + assert full.get("converged") is True, f"opt did not converge: {full.get('converged')!r}" + assert full.get("success") is True, f"opt success=False: {full}" + print( + f" water opt energy = {energy:.6f} eV " + f"(converged={full.get('converged')}, wall={full.get('wall_time')}s)" + ) + + with r.check("backend.shutdown() is clean"): + backend.shutdown() + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/smoke_parsl_in_job.py b/scripts/smoke/smoke_parsl_in_job.py new file mode 100644 index 00000000..f755f05d --- /dev/null +++ b/scripts/smoke/smoke_parsl_in_job.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python +"""Smoke test for ParslBackend on an HPC compute node. + +Must run **inside** a PBS interactive allocation on Polaris or Aurora:: + + # Polaris + qsub -I -A -l select=1 -l walltime=01:00:00 -q debug + # Aurora + qsub -I -A -l select=1,walltime=01:00:00 -q debug -l filesystems=home:flare + +Inside the allocation:: + + module load conda # or `module load frameworks` on Aurora + source /bin/activate + export COMPUTE_SYSTEM=polaris # or aurora + python scripts/smoke/smoke_parsl_in_job.py + python scripts/smoke/smoke_parsl_in_job.py --quick + python scripts/smoke/smoke_parsl_in_job.py --device xpu # Aurora + +The script fails fast with a clear message if PBS_NODEFILE is missing. 
+""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +from _smoke_utils import ( + SmokeReporter, + trivial_add, + trivial_env_probe, + trivial_hostname, + trivial_square, + water_xyz_path, +) + + +def _abort(msg: str) -> None: + print(f"[ABORT] {msg}") + sys.exit(2) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--system", + default=os.environ.get("COMPUTE_SYSTEM"), + help="polaris | aurora (default: COMPUTE_SYSTEM env var)", + ) + parser.add_argument( + "--device", + default=None, + help="MACE device: cuda (Polaris default), xpu (Aurora), or cpu.", + ) + parser.add_argument( + "--run-dir", + default=None, + help="Parsl run_dir (default: $PBS_O_WORKDIR/parsl_runs or ./parsl_runs).", + ) + parser.add_argument( + "--quick", + action="store_true", + help="Skip MACE inference.", + ) + args = parser.parse_args() + + pbs_nodefile = os.environ.get("PBS_NODEFILE") + if not pbs_nodefile or not Path(pbs_nodefile).is_file(): + _abort( + "PBS_NODEFILE not set or missing. This script must run inside a " + "PBS interactive allocation (qsub -I ...)." 
+ ) + + if not args.system: + _abort("COMPUTE_SYSTEM env var not set and --system not given.") + system = args.system.lower().strip() + if system not in ("polaris", "aurora"): + _abort(f"Unsupported --system: {system!r} (expected polaris|aurora)") + + device = args.device or ("xpu" if system == "aurora" else "cuda") + nodes = Path(pbs_nodefile).read_text().splitlines() + + run_dir = args.run_dir or os.environ.get("PBS_O_WORKDIR") + if run_dir: + run_dir = str(Path(run_dir) / "parsl_runs_smoke") + else: + run_dir = str(Path.cwd() / "parsl_runs_smoke") + Path(run_dir).mkdir(parents=True, exist_ok=True) + + print(f"system={system} device={device} nodes={len(nodes)} run_dir={run_dir}") + + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.config import get_backend + + r = SmokeReporter(f"smoke_parsl_in_job (system={system}, nodes={len(nodes)})") + backend = None + + with r.check("get_backend(parsl) initialises with HPC config"): + backend = get_backend( + backend_name="parsl", + system=system, + run_dir=run_dir, + ) + assert backend is not None + + if backend is None: + r.summary_and_exit() + return + + with r.check("python TaskSpec returns correct result"): + fut = backend.submit( + TaskSpec( + task_id="p-py", + task_type="python", + callable=trivial_square, + args=(9,), + ) + ) + assert fut.result(timeout=120) == 81 + + with r.check("python TaskSpec ran on a compute node (hostname != login)"): + fut = backend.submit( + TaskSpec( + task_id="p-host", + task_type="python", + callable=trivial_hostname, + ) + ) + host = fut.result(timeout=120) + print(f" parsl worker hostname = {host!r}") + assert isinstance(host, str) and host + + with r.check("worker env: torch + accelerators visible"): + fut = backend.submit( + TaskSpec( + task_id="p-env", + task_type="python", + callable=trivial_env_probe, + ) + ) + info = fut.result(timeout=120) + print(f" worker env: {info}") + # Polaris should show cuda; Aurora should show xpu. 
+ if system == "polaris": + assert info.get("cuda_devices", 0) >= 1, info + elif system == "aurora": + assert info.get("xpu_devices", 0) >= 1, info + + with r.check("shell TaskSpec exits 0"): + fut = backend.submit( + TaskSpec( + task_id="p-sh", + task_type="shell", + command="echo smoke_parsl_shell_ok && hostname", + ) + ) + rc = fut.result(timeout=120) + assert rc == 0, f"exit code = {rc}" + + with r.check("submit_batch of 4 python tasks all resolve"): + futures = backend.submit_batch( + [ + TaskSpec( + task_id=f"p-batch-{i}", + task_type="python", + callable=trivial_add, + args=(i, 100), + ) + for i in range(4) + ] + ) + results = [f.result(timeout=180) for f in futures] + assert results == [100, 101, 102, 103], results + + if not args.quick: + with r.check(f"MACE geometry opt on water (device={device}, converged)"): + from ase.io import read as ase_read + + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(str(water_xyz_path())) + inline = atoms_to_atomsdata(atoms).model_dump() + job = { + "input_structure_file": "ignored_by_inline_path", + "output_result_file": "water_smoke_parsl.json", + "driver": "opt", + "model": "medium-mpa-0", + "device": device, + "temperature": 298.15, + "pressure": 101325.0, + "fmax": 0.01, + "steps": 100, + "optimizer": "lbfgs", + "inline_structure": inline, + } + fut = backend.submit( + TaskSpec( + task_id="p-mace-opt", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + ) + out = fut.result(timeout=900) + assert out.get("status") == "success", f"opt failed: {out}" + energy = next( + (out[k] for k in ("single_point_energy", "energy", "final_energy") if k in out), + None, + ) + assert energy is not None and energy < 0, f"bad MACE result: {out}" + full = out.get("full_output") or {} + if full: + assert full.get("converged") is True, f"opt did not converge: {full.get('converged')!r}" + print( + f" water opt energy = {energy:.6f} eV " 
+ f"(converged={full.get('converged')}, wall={full.get('wall_time')}s)" + ) + else: + print( + f" water opt energy = {energy:.6f} eV " + "(WARNING: full_output missing; convergence not verified inline)" + ) + + with r.check("backend.shutdown() is clean"): + backend.shutdown() + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/water.xyz b/scripts/smoke/water.xyz new file mode 100644 index 00000000..baec6e18 --- /dev/null +++ b/scripts/smoke/water.xyz @@ -0,0 +1,5 @@ +3 +water molecule +O 0.0000000 0.0000000 0.0000000 +H 0.7570000 0.5860000 0.0000000 +H -0.7570000 0.5860000 0.0000000 From a31c148e0bd015632aa7e003a093fb46f73adb8c Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Fri, 5 Jun 2026 12:58:07 -0500 Subject: [PATCH 036/119] Add Crux (ALCF) support to Parsl and EnsembleLauncher backends Crux is a CPU-only AMD EPYC system (no GPUs), so the new configs drop accelerators and use a conda-based worker_init. Wires "crux" through the loader dispatch, the EL SystemConfig registry, and the in-job smoke/demo allowlists (defaulting device to "cpu" instead of cuda/xpu). 
- src/chemgraph/hpc_configs/crux_parsl.py: new HighThroughputExecutor config, requires PBS_NODEFILE, max_workers_per_node=16 - src/chemgraph/hpc_configs/loader.py: dispatch "crux" to get_crux_config - src/chemgraph/execution/ensemble_launcher_backend.py: add get_crux_system_config (ncpus=128, no GPUs) and register it - scripts/smoke/*, scripts/demo/*: accept --system crux and resolve device defaults to cpu for Crux - tests/test_execution.py: TestELSystemConfigCrux asserts registry membership and CPU-only SystemConfig shape --- .../demo_ensemble_launcher_in_job_direct.py | 11 ++- scripts/demo/demo_parsl_in_job_direct.py | 11 ++- .../smoke/smoke_ensemble_launcher_in_job.py | 11 ++- scripts/smoke/smoke_parsl_in_job.py | 18 +++-- src/chemgraph/execution/config.py | 2 +- .../execution/ensemble_launcher_backend.py | 11 +++ src/chemgraph/hpc_configs/crux_parsl.py | 69 +++++++++++++++++++ src/chemgraph/hpc_configs/loader.py | 9 ++- tests/test_execution.py | 25 +++++++ 9 files changed, 153 insertions(+), 14 deletions(-) create mode 100644 src/chemgraph/hpc_configs/crux_parsl.py diff --git a/scripts/demo/demo_ensemble_launcher_in_job_direct.py b/scripts/demo/demo_ensemble_launcher_in_job_direct.py index dbf2de70..d7126924 100644 --- a/scripts/demo/demo_ensemble_launcher_in_job_direct.py +++ b/scripts/demo/demo_ensemble_launcher_in_job_direct.py @@ -50,9 +50,16 @@ def main() -> None: if not args.system: _abort("COMPUTE_SYSTEM env var not set and --system not given.") system = args.system.lower().strip() - if system not in ("polaris", "aurora"): + if system not in ("polaris", "aurora", "crux"): _abort(f"Unsupported --system: {system!r}") - device = args.device or ("xpu" if system == "aurora" else "cuda") + if args.device: + device = args.device + elif system == "aurora": + device = "xpu" + elif system == "crux": + device = "cpu" + else: + device = "cuda" try: import ensemble_launcher # noqa: F401 diff --git a/scripts/demo/demo_parsl_in_job_direct.py 
b/scripts/demo/demo_parsl_in_job_direct.py index 3b4e749d..e7df85ea 100644 --- a/scripts/demo/demo_parsl_in_job_direct.py +++ b/scripts/demo/demo_parsl_in_job_direct.py @@ -64,9 +64,16 @@ def main() -> None: if not args.system: _abort("COMPUTE_SYSTEM env var not set and --system not given.") system = args.system.lower().strip() - if system not in ("polaris", "aurora"): + if system not in ("polaris", "aurora", "crux"): _abort(f"Unsupported --system: {system!r}") - device = args.device or ("xpu" if system == "aurora" else "cuda") + if args.device: + device = args.device + elif system == "aurora": + device = "xpu" + elif system == "crux": + device = "cpu" + else: + device = "cuda" run_dir = args.run_dir or os.environ.get("PBS_O_WORKDIR") if run_dir: diff --git a/scripts/smoke/smoke_ensemble_launcher_in_job.py b/scripts/smoke/smoke_ensemble_launcher_in_job.py index 968f2f4d..b178454d 100644 --- a/scripts/smoke/smoke_ensemble_launcher_in_job.py +++ b/scripts/smoke/smoke_ensemble_launcher_in_job.py @@ -126,10 +126,17 @@ def main() -> None: if not args.system: _abort("COMPUTE_SYSTEM env var not set and --system not given.") system = args.system.lower().strip() - if system not in ("polaris", "aurora", "local"): + if system not in ("polaris", "aurora", "local", "crux"): _abort(f"Unsupported --system: {system!r}") - device = args.device or ("xpu" if system == "aurora" else "cuda") + if args.device: + device = args.device + elif system == "aurora": + device = "xpu" + elif system == "crux": + device = "cpu" + else: + device = "cuda" try: import ensemble_launcher # noqa: F401 diff --git a/scripts/smoke/smoke_parsl_in_job.py b/scripts/smoke/smoke_parsl_in_job.py index f755f05d..75daead3 100644 --- a/scripts/smoke/smoke_parsl_in_job.py +++ b/scripts/smoke/smoke_parsl_in_job.py @@ -76,10 +76,17 @@ def main() -> None: if not args.system: _abort("COMPUTE_SYSTEM env var not set and --system not given.") system = args.system.lower().strip() - if system not in ("polaris", "aurora"): - 
_abort(f"Unsupported --system: {system!r} (expected polaris|aurora)") - - device = args.device or ("xpu" if system == "aurora" else "cuda") + if system not in ("polaris", "aurora", "crux"): + _abort(f"Unsupported --system: {system!r} (expected polaris|aurora|crux)") + + if args.device: + device = args.device + elif system == "aurora": + device = "xpu" + elif system == "crux": + device = "cpu" + else: + device = "cuda" nodes = Path(pbs_nodefile).read_text().splitlines() run_dir = args.run_dir or os.environ.get("PBS_O_WORKDIR") @@ -142,11 +149,12 @@ def main() -> None: ) info = fut.result(timeout=120) print(f" worker env: {info}") - # Polaris should show cuda; Aurora should show xpu. + # Polaris should show cuda; Aurora should show xpu; Crux is CPU-only. if system == "polaris": assert info.get("cuda_devices", 0) >= 1, info elif system == "aurora": assert info.get("xpu_devices", 0) >= 1, info + # Crux: CPU-only; no accelerator assertion. with r.check("shell TaskSpec exits 0"): fut = backend.submit( diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py index dc650a26..6be99a58 100644 --- a/src/chemgraph/execution/config.py +++ b/src/chemgraph/execution/config.py @@ -10,7 +10,7 @@ ``"globus_compute"``, ``"local"``). ``COMPUTE_SYSTEM`` Override the target HPC system (``"polaris"``, ``"aurora"``, - ``"local"``). + ``"crux"``, ``"local"``). 
""" from __future__ import annotations diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index 56210d97..b1631a9a 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -87,6 +87,16 @@ def get_aurora_system_config(): return system_config +def get_crux_system_config(): + _require_ensemble_launcher() + system_config = SystemConfig( + name="crux", + ncpus=128, + cpus=list(range(128)), + ) + return system_config + + def get_launcher_config( task_executor_name: Union[str, List] = "async_processpool", child_executor_policy: str = "fixed_leafs_children_policy", @@ -296,6 +306,7 @@ def shutdown(self) -> None: "local": get_local_system_config, "aurora": get_aurora_system_config, "polaris": get_polaris_system_config, + "crux": get_crux_system_config, } diff --git a/src/chemgraph/hpc_configs/crux_parsl.py b/src/chemgraph/hpc_configs/crux_parsl.py new file mode 100644 index 00000000..07b3051b --- /dev/null +++ b/src/chemgraph/hpc_configs/crux_parsl.py @@ -0,0 +1,69 @@ +import os +from parsl.config import Config +from parsl.providers import LocalProvider +from parsl.executors import HighThroughputExecutor +from parsl.launchers import MpiExecLauncher + + +def get_crux_config( + run_dir=None, + max_workers_per_node: int = 16, +): + """Create a Parsl configuration for ALCF Crux PBS jobs. + + Crux is a CPU-only AMD EPYC system (no accelerators). + + Parameters + ---------- + run_dir : str, optional + Directory used as Parsl's run directory and worker working directory. + max_workers_per_node : int, optional + Number of concurrent workers per node. Defaults to 16 + (≈8 cores per worker on a 128-core node). + + Returns + ------- + parsl.config.Config + Configured Parsl ``Config`` for Crux. 
+ """ + if run_dir is None: + run_dir = os.getcwd() + + worker_init = ( + f"export TMPDIR=/tmp; cd {run_dir}; " + "module load conda; conda activate base" + ) + + node_file = os.getenv("PBS_NODEFILE") + if node_file and os.path.exists(node_file): + with open(node_file, "r", encoding="utf-8") as f: + node_list = f.readlines() + num_nodes = len(node_list) + else: + raise ValueError( + "PBS_NODEFILE not found. Cannot determine node count for Crux." + ) + + config = Config( + executors=[ + HighThroughputExecutor( + label="htex", + heartbeat_period=30, + heartbeat_threshold=240, + max_workers_per_node=max_workers_per_node, + provider=LocalProvider( + nodes_per_block=num_nodes, + launcher=MpiExecLauncher( + bind_cmd="--cpu-bind", overrides="--ppn 1" + ), + init_blocks=1, + worker_init=worker_init, + max_blocks=1, + min_blocks=0, + ), + ) + ], + run_dir=run_dir, + ) + + return config diff --git a/src/chemgraph/hpc_configs/loader.py b/src/chemgraph/hpc_configs/loader.py index 4a25d5e7..7aed297f 100644 --- a/src/chemgraph/hpc_configs/loader.py +++ b/src/chemgraph/hpc_configs/loader.py @@ -20,7 +20,7 @@ def load_parsl_config(system_name: str, run_dir: str | None = None, **kwargs): ---------- system_name : str Target system name. Supported: ``"local"``, ``"polaris"``, - ``"aurora"``. + ``"aurora"``, ``"crux"``. run_dir : str, optional Parsl run directory. Defaults to the current working directory. **kwargs @@ -58,8 +58,13 @@ def load_parsl_config(system_name: str, run_dir: str | None = None, **kwargs): return get_aurora_config(run_dir=run_dir, **kwargs) + elif system_name == "crux": + from chemgraph.hpc_configs.crux_parsl import get_crux_config + + return get_crux_config(run_dir=run_dir, **kwargs) + else: raise ValueError( f"Unknown HPC system: '{system_name}'. 
" - f"Supported systems: local, polaris, aurora" + f"Supported systems: local, polaris, aurora, crux" ) diff --git a/tests/test_execution.py b/tests/test_execution.py index c662547c..dd52f415 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -412,6 +412,31 @@ def test_shell_task_missing_command(self): backend.shutdown() +class TestELSystemConfigCrux: + """EnsembleLauncher SystemConfig builder for Crux (CPU-only).""" + + def test_crux_in_registry(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + ) + + assert "crux" in SYSTEM_CONFIG_REGISTRY + + def test_crux_system_config_cpu_only(self): + pytest.importorskip("ensemble_launcher") + from chemgraph.execution.ensemble_launcher_backend import ( + get_crux_system_config, + ) + + cfg = get_crux_system_config() + assert cfg.name == "crux" + assert cfg.ncpus == 128 + assert len(cfg.cpus) == 128 + # CPU-only: ngpus / gpus must not be populated + assert getattr(cfg, "ngpus", None) in (None, 0) + assert not getattr(cfg, "gpus", None) + + # ── GlobusComputeBackend tests ────────────────────────────────────────── From 2dc37c857c0187d66f35f352188ad202b617a64a Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 12:59:55 -0500 Subject: [PATCH 037/119] Ensure MACE worker creates output directories --- src/chemgraph/mcp/mace_mcp_hpc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index 58750b46..b93dcd4c 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -119,6 +119,10 @@ def _mace_worker(job: dict) -> dict: tmpdir, job.get("output_result_file", "output.json") ) + output_file = job.get("output_result_file") + if output_file: + os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True) + params = mace_input_schema(**job) result = run_mace_core(params) From dfa9950501414d4257d2bf741fa4b6c51c45591d Mon Sep 17 00:00:00 2001 From: 
Jinchu Li Date: Tue, 9 Jun 2026 13:13:24 -0500 Subject: [PATCH 038/119] Fix schema fanout batch return annotation --- src/chemgraph/mcp/cg_fastmcp.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index 155dd76d..13e72d8f 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -484,9 +484,13 @@ async def wrapper(**kwargs): wrapper.__doc__ = expander.__doc__ wrapper.__module__ = expander.__module__ wrapper.__qualname__ = expander.__qualname__ - # Preserve the expander's signature so FastMCP advertises the - # ensemble schema to the LLM, not the worker's per-item one. - wrapper.__signature__ = sig + # Preserve the expander's input signature so FastMCP advertises + # the ensemble schema to the LLM, not the worker's per-item one. + # The wrapper returns a submit_or_gather batch summary, though, + # so it must not inherit the expander's list-of-jobs annotation. + wrapper.__signature__ = sig.replace( + return_annotation=dict[str, Any] + ) self.add_tool(wrapper, **fastmcp_kwargs) return expander From aad6fae5653f37cf3db42d752590368efc158bb3 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:26:55 -0500 Subject: [PATCH 039/119] Ignore local run and model artifacts --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index cacf26a5..ab3026f5 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,8 @@ opencode.json chemgraph_mcp_logs/ vllm/ logs/ +runs/ +**/*.model error_log.txt .env test.csv From 9d1b20a23a3a2a278577cfda9db72f2fd35a9089 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:27:22 -0500 Subject: [PATCH 040/119] Cover schema fanout and MACE output path fixes --- tests/test_mcp.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 66cab765..b4615871 100644 --- a/tests/test_mcp.py +++ 
b/tests/test_mcp.py @@ -1,13 +1,16 @@ """Test suite for MCP servers.""" +import inspect import json from pathlib import Path +from typing import Any import pytest try: from mcp.types import TextContent from fastmcp import Client + from chemgraph.mcp.cg_fastmcp import CGFastMCP from chemgraph.mcp.mcp_tools import mcp from chemgraph.mcp.data_analysis_mcp import mcp as data_mcp except ModuleNotFoundError: @@ -16,6 +19,63 @@ TEST_DIR = Path(__file__).parent +def _fanout_worker(item: dict) -> dict: + return item + + +def test_schema_fanout_tool_advertises_batch_result_signature(monkeypatch): + """Fanout tools expose an ensemble input but return batch summaries.""" + local_mcp = CGFastMCP(name="test") + captured = {} + + def capture_tool(fn, **kwargs): + captured["fn"] = fn + captured["kwargs"] = kwargs + + monkeypatch.setattr(local_mcp, "add_tool", capture_tool) + + @local_mcp.schema_fanout_tool(name="fanout", worker=_fanout_worker) + def fanout(params: dict) -> list[dict]: + return [params] + + sig = inspect.signature(captured["fn"]) + + assert list(sig.parameters) == ["params"] + assert sig.parameters["params"].annotation is dict + assert sig.return_annotation == dict[str, Any] + + +def test_mace_worker_creates_inline_output_parent(monkeypatch): + from ase import Atoms + + from chemgraph.mcp import mace_mcp_hpc + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = Atoms(numbers=[1, 1], positions=[[0, 0, 0], [0, 0, 0.74]]) + output_file = "nested/family/output.json" + + def fake_run_mace_core(params): + output_path = Path(params.output_result_file) + assert output_path.parent.is_dir() + output_path.write_text('{"ok": true}', encoding="utf-8") + return {"status": "success"} + + monkeypatch.setattr(mace_mcp_hpc, "run_mace_core", fake_run_mace_core) + + result = mace_mcp_hpc._mace_worker( + { + "inline_structure": atoms_to_atomsdata(atoms).model_dump(), + "output_result_file": output_file, + "driver": "energy", + "model": "small", + "device": "cpu", + } + 
) + + assert result["status"] == "success" + assert result["full_output"] == {"ok": True} + + @pytest.mark.asyncio async def test_split_cif_dataset(tmp_path): """Test splitting a dataset of CIF files.""" From acc8a3ba8777cb7ec2ca76f9aabce29b8961362b Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:27:40 -0500 Subject: [PATCH 041/119] Create parent directory for generated coordinate files --- src/chemgraph/tools/cheminformatics_core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/chemgraph/tools/cheminformatics_core.py b/src/chemgraph/tools/cheminformatics_core.py index 0ffe13b3..321fc659 100644 --- a/src/chemgraph/tools/cheminformatics_core.py +++ b/src/chemgraph/tools/cheminformatics_core.py @@ -142,6 +142,9 @@ def smiles_to_coordinate_file_core( atoms = Atoms(numbers=numbers, positions=positions) final_output_file = _resolve_path(output_file) + parent = os.path.dirname(os.path.abspath(final_output_file)) + if parent: + os.makedirs(parent, exist_ok=True) ase_write(final_output_file, atoms) return { From 42b15a02e477398b64bd07444ad6163458219315 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:28:00 -0500 Subject: [PATCH 042/119] Read version from chemgraphagent distribution --- src/chemgraph/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/chemgraph/__init__.py b/src/chemgraph/__init__.py index 216a3532..c973a4ea 100644 --- a/src/chemgraph/__init__.py +++ b/src/chemgraph/__init__.py @@ -1,11 +1,10 @@ """ChemGraph package metadata.""" -from importlib.metadata import PackageNotFoundError, packages_distributions, version +from importlib.metadata import PackageNotFoundError, version try: - dist_names = packages_distributions().get("chemgraph", []) - __version__ = version(dist_names[0]) if dist_names else "unknown" -except (PackageNotFoundError, IndexError): + __version__ = version("chemgraphagent") +except PackageNotFoundError: # Local source tree without installed package metadata. 
__version__ = "unknown" From 78805621efa1a45d714ca9a7a165bec4e803b5c5 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:40:42 -0500 Subject: [PATCH 043/119] Parameterize Aurora Parsl worker setup --- src/chemgraph/hpc_configs/aurora_parsl.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/chemgraph/hpc_configs/aurora_parsl.py b/src/chemgraph/hpc_configs/aurora_parsl.py index ece27183..03c9e912 100644 --- a/src/chemgraph/hpc_configs/aurora_parsl.py +++ b/src/chemgraph/hpc_configs/aurora_parsl.py @@ -8,12 +8,19 @@ def get_aurora_config( run_dir=None, + worker_init: str | None = None, + max_workers_per_node: int | None = None, ): if run_dir is None: run_dir = os.getcwd() - # Hard-wired worker_init for aurora - worker_init = f"export TMPDIR=/tmp; cd {run_dir}; module load frameworks" + if worker_init is None: + worker_init = f"export TMPDIR=/tmp; cd {run_dir}; module load frameworks" + + if max_workers_per_node is None: + max_workers_per_node = int( + os.getenv("CHEMGRAPH_PARSL_MAX_WORKERS_PER_NODE", "9") + ) # Get the number of nodes: node_file = os.getenv("PBS_NODEFILE") @@ -33,7 +40,7 @@ def get_aurora_config( heartbeat_period=30, heartbeat_threshold=240, available_accelerators=12, - max_workers_per_node=9, + max_workers_per_node=max_workers_per_node, address=address_by_interface('bond0'), provider=LocalProvider( nodes_per_block=num_nodes, From 591dba78d36b238f87e009c0755439454eeda38d Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:41:29 -0500 Subject: [PATCH 044/119] Clarify MACE model path schema descriptions --- src/chemgraph/schemas/mace_parsl_schema.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/chemgraph/schemas/mace_parsl_schema.py b/src/chemgraph/schemas/mace_parsl_schema.py index 63dc8008..17d5c54f 100644 --- a/src/chemgraph/schemas/mace_parsl_schema.py +++ b/src/chemgraph/schemas/mace_parsl_schema.py @@ -23,11 +23,13 @@ class 
mace_input_schema(BaseModel): ) model: str = Field( default="medium-mpa-0", - description="MACE foundation model name (NOT the calculator type). " + description="MACE foundation model name or absolute local model file path " + "(NOT the calculator type). " "Options: 'small', 'medium', 'large', 'small-0b', 'medium-0b', " "'small-0b2', 'medium-0b2', 'large-0b2', 'medium-0b3', " "'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', " - "'mace-matpes-r2scan-0'. Default is 'medium-mpa-0'. " + "'mace-matpes-r2scan-0', or an absolute path to a local .model file. " + "Default is 'medium-mpa-0'. " "Do NOT pass 'mace_mp' — that is the calculator type, not a model name.", ) device: str = Field( @@ -81,11 +83,13 @@ class mace_input_schema_ensemble(BaseModel): ) model: str = Field( default="medium-mpa-0", - description="MACE foundation model name (NOT the calculator type). " + description="MACE foundation model name or absolute local model file path " + "(NOT the calculator type). " "Options: 'small', 'medium', 'large', 'small-0b', 'medium-0b', " "'small-0b2', 'medium-0b2', 'large-0b2', 'medium-0b3', " "'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', " - "'mace-matpes-r2scan-0'. Default is 'medium-mpa-0'. " + "'mace-matpes-r2scan-0', or an absolute path to a local .model file. " + "Default is 'medium-mpa-0'. " "Do NOT pass 'mace_mp' — that is the calculator type, not a model name.", ) device: str = Field( @@ -122,7 +126,11 @@ class mace_output_schema(BaseModel): description="Path to a JSON file where simulation results is saved.", ) model: str | None = Field( - default=None, description="Path to the model. Default is medium-mpa-0." + default=None, + description=( + "MACE foundation model name or absolute local model file path. " + "Default is medium-mpa-0." 
+ ), ) device: str = Field( default="cpu", From 534a292f9a53e9409b8ec751322dc1cee6c429bf Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:43:08 -0500 Subject: [PATCH 045/119] Add HPC JSON inspection MCP tool --- src/chemgraph/mcp/hpc_misc_mcp.py | 167 ++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 src/chemgraph/mcp/hpc_misc_mcp.py diff --git a/src/chemgraph/mcp/hpc_misc_mcp.py b/src/chemgraph/mcp/hpc_misc_mcp.py new file mode 100644 index 00000000..106e5c52 --- /dev/null +++ b/src/chemgraph/mcp/hpc_misc_mcp.py @@ -0,0 +1,167 @@ +"""FastMCP tools for generic HPC run-artifact inspection.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from mcp.server.fastmcp import FastMCP + + +mcp = FastMCP( + name="ChemGraph HPC Misc Tools", + instructions=""" + You expose small, generic tools for inspecting files produced by HPC + calculations. These tools do not run chemistry; they help agents inspect + run artifacts without relying on simulation-specific readers. + """, +) + + +@mcp.tool( + name="inspect_json", + description=( + "Inspect a JSON file, a directory of JSON files, or a missing expected " + "JSON path. Returns compact summaries and nearby JSON files when the " + "requested path is absent." 
+ ), +) +def inspect_json( + path: str, + glob_pattern: str = "*.json", + max_files: int = 20, + max_preview_chars: int = 1200, + recursive: bool = False, +) -> dict[str, Any]: + """Inspect JSON artifacts without assuming one fixed output-file layout.""" + target = Path(path).expanduser() + if target.is_file(): + return { + "status": "ok", + "kind": "file", + "path": str(target), + "json": _load_json_summary( + target, + max_preview_chars=max_preview_chars, + ), + } + + if target.is_dir(): + files = _json_files( + target, + glob_pattern=glob_pattern, + max_files=max_files, + recursive=recursive, + ) + return { + "status": "ok", + "kind": "directory", + "path": str(target), + "glob_pattern": glob_pattern, + "recursive": recursive, + "file_count_returned": len(files), + "files": [ + { + "path": str(file), + "json": _load_json_summary( + file, + max_preview_chars=max_preview_chars, + ), + } + for file in files + ], + } + + parent = target.parent + nearby = ( + _json_files( + parent, + glob_pattern=glob_pattern, + max_files=max_files, + recursive=False, + ) + if parent.is_dir() + else [] + ) + return { + "status": "not_found", + "kind": "missing", + "path": str(target), + "parent_exists": parent.is_dir(), + "nearby_json_files": [str(file) for file in nearby], + } + + +def _json_files( + directory: Path, + *, + glob_pattern: str, + max_files: int, + recursive: bool, +) -> list[Path]: + if max_files < 1: + return [] + iterator = directory.rglob(glob_pattern) if recursive else directory.glob(glob_pattern) + return sorted(path for path in iterator if path.is_file())[:max_files] + + +def _load_json_summary(path: Path, *, max_preview_chars: int) -> dict[str, Any]: + try: + value = json.loads(path.read_text(encoding="utf-8")) + except Exception as exc: # noqa: BLE001 - report file/read/parse failure. 
+ return { + "status": "error", + "error": repr(exc), + } + return { + "status": "ok", + "summary": _summarize_json(value), + "preview": _json_preview(value, max_chars=max_preview_chars), + } + + +def _summarize_json(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + summary: dict[str, Any] = { + "type": "object", + "keys": sorted(str(key) for key in value.keys())[:40], + } + for key in ("status", "energy", "energy_ev", "driver", "model"): + if key in value: + summary[key] = value[key] + for key in ("results", "failures", "errors"): + nested = value.get(key) + if isinstance(nested, list): + summary[f"{key}_count"] = len(nested) + return summary + if isinstance(value, list): + return { + "type": "array", + "length": len(value), + "first_item": _summarize_json(value[0]) if value else None, + } + return { + "type": type(value).__name__, + "value": value, + } + + +def _json_preview(value: Any, *, max_chars: int) -> Any: + try: + text = json.dumps(value, sort_keys=True) + except TypeError: + text = repr(value) + if len(text) <= max_chars: + return value + return { + "truncated": True, + "chars": len(text), + "text": text[:max_chars], + } + + +if __name__ == "__main__": + from chemgraph.mcp.server_utils import run_mcp_server + + run_mcp_server(mcp, default_port=9020) From bf433ae96bf7b7c3b6518888484e5afcb0c6bbfa Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:45:20 -0500 Subject: [PATCH 046/119] Add Academy canonical event observability --- .../academy/observability/__init__.py | 13 ++ .../observability/communication_proof.py | 117 ++++++++++++++ .../academy/observability/event_log.py | 145 ++++++++++++++++++ .../academy/observability/payloads.py | 130 ++++++++++++++++ .../academy/observability/run_files.py | 66 ++++++++ tests/test_academy_payloads.py | 69 +++++++++ 6 files changed, 540 insertions(+) create mode 100644 src/chemgraph/academy/observability/__init__.py create mode 100644 
src/chemgraph/academy/observability/communication_proof.py create mode 100644 src/chemgraph/academy/observability/event_log.py create mode 100644 src/chemgraph/academy/observability/payloads.py create mode 100644 src/chemgraph/academy/observability/run_files.py create mode 100644 tests/test_academy_payloads.py diff --git a/src/chemgraph/academy/observability/__init__.py b/src/chemgraph/academy/observability/__init__.py new file mode 100644 index 00000000..752e52d4 --- /dev/null +++ b/src/chemgraph/academy/observability/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from chemgraph.academy.observability.event_log import CampaignEvent +from chemgraph.academy.observability.event_log import EventLog +from chemgraph.academy.observability.event_log import read_events +from chemgraph.academy.observability.payloads import typed_payload + +__all__ = [ + 'CampaignEvent', + 'EventLog', + 'read_events', + 'typed_payload', +] diff --git a/src/chemgraph/academy/observability/communication_proof.py b/src/chemgraph/academy/observability/communication_proof.py new file mode 100644 index 00000000..4f61b7bf --- /dev/null +++ b/src/chemgraph/academy/observability/communication_proof.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from typing import Any + +from chemgraph.academy.observability.event_log import CampaignEvent + + +def build_communication_proof( + events: list[CampaignEvent], + placement: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build proof that communication could affect recipient behavior.""" + message_ids: dict[str, dict[str, Any]] = {} + sent_messages: list[dict[str, Any]] = [] + for event in events: + if event.event != "message_sent": + continue + payload = event.payload + message_id = payload.get("message_id") + if not isinstance(message_id, str): + continue + message = { + "message_id": message_id, + "sender": payload.get("sender"), + "recipient": payload.get("recipient"), + "content": payload.get("content"), + 
"evidence_refs": payload.get("evidence_refs", []), + "artifact_refs": payload.get("artifact_refs", []), + "tool_result_ids": payload.get("tool_result_ids", []), + "timestamp": payload.get("timestamp"), + } + message_ids[message_id] = message + sent_messages.append(message) + + agents = (placement or {}).get("agents", {}) + cross_node_messages = [] + if isinstance(agents, dict): + for message in sent_messages: + sender = agents.get(message.get("sender"), {}) + recipient = agents.get(message.get("recipient"), {}) + sender_host = sender.get("short_hostname") or sender.get("hostname") + recipient_host = recipient.get("short_hostname") or recipient.get("hostname") + if sender_host and recipient_host and sender_host != recipient_host: + cross_node_messages.append( + { + **message, + "sender_hostname": sender_host, + "recipient_hostname": recipient_host, + }, + ) + + cited_beliefs = [] + cited_message_ids: set[str] = set() + for event in events: + if event.event != "belief_updated": + continue + refs = event.payload.get("supporting_message_ids", []) + if not isinstance(refs, list): + continue + cited = [ref for ref in refs if isinstance(ref, str) and ref in message_ids] + if not cited: + continue + cited_message_ids.update(cited) + cited_beliefs.append( + { + "agent_id": event.agent_id, + "role": event.role, + "hypothesis": event.payload.get("hypothesis"), + "confidence": event.payload.get("confidence"), + "supporting_message_ids": cited, + }, + ) + + cited_tool_refs = [] + final_report_count = 0 + for event in events: + if event.event != "belief_updated": + continue + final_report_count += 1 + refs = ( + event.payload.get("supporting_tool_result_ids") + or event.payload.get("supporting_artifact_ids") + or [] + ) + if not isinstance(refs, list): + continue + calls = [ + ref + for ref in refs + if isinstance(ref, str) + and (ref.startswith("call-") or ref.startswith("tool-")) + ] + if calls: + cited_tool_refs.append( + { + "agent_id": event.agent_id, + "hypothesis": 
event.payload.get("hypothesis"), + "supporting_artifact_ids": calls, + }, + ) + + return { + "message_count": len(sent_messages), + "received_message_ids_cited_in_beliefs": sorted(cited_message_ids), + "belief_changes_citing_messages": len(cited_beliefs), + "belief_change_examples": cited_beliefs[:10], + "cross_node_message_count": len(cross_node_messages), + "cross_node_message_examples": cross_node_messages[:10], + "tool_refs_cited_in_beliefs": cited_tool_refs[:10], + "final_report_count": final_report_count, + "passes": { + "has_message": bool(sent_messages), + "has_belief_citing_message": bool(cited_beliefs), + "has_cross_node_message": bool(cross_node_messages), + "final_report": final_report_count > 0, + }, + } diff --git a/src/chemgraph/academy/observability/event_log.py b/src/chemgraph/academy/observability/event_log.py new file mode 100644 index 00000000..e42013b6 --- /dev/null +++ b/src/chemgraph/academy/observability/event_log.py @@ -0,0 +1,145 @@ +"""Shared event log for Academy-ChemGraph campaign runs. + +The dynamic campaign layer treats agent messages and ChemGraph job updates as +one append-only event stream. The dashboard and HPC run scripts can consume +this file without knowing which science use case created the event. 
+""" + +from __future__ import annotations + +import json +import time +import uuid +from pathlib import Path +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field + + +EventKind = Literal[ + "campaign_started", + "campaign_planned", + "campaign_finished", + "agent_started", + "agent_stopped", + "agent_decision", + "agent_error", + "message_sent", + "message_received", + "message_delivered", + "message_delivery_failed", + "belief_updated", + "tool_call_started", + "tool_call_finished", + "tool_call_failed", + "chemgraph_batch_submitted", + "chemgraph_job_status", + "chemgraph_job_result", + "chemgraph_transfer_submitted", + "chemgraph_transfer_done", + "round_started", + "round_finished", + "self_wake_scheduled", + "idle_timeout", + "max_decisions_reached", + "daemon_started", + "daemon_stopped", + "bootstrap_message_dispatched", + "llm_tool_calls", + "turn_finished_without_external_action", + "chemgraph_reasoning_turn_started", + "chemgraph_reasoning_turn_finished", + "run_started", + "run_finished", + "workflow_started", + "workflow_finished", + "workflow_node_started", + "workflow_node_finished", + "llm_decision", + "workflow_output", +] + +__all__ = [ + 'CampaignEvent', + 'EventKind', + 'EventLog', + 'read_events', +] + + +class CampaignEvent(BaseModel): + """One durable event emitted by a campaign runtime.""" + + model_config = ConfigDict(extra="forbid") + + event_id: str = Field(default_factory=lambda: f"evt-{uuid.uuid4()}") + timestamp: float = Field(default_factory=time.time) + event: EventKind + run_id: str | None = None + agent_id: str | None = None + role: str | None = None + correlation_id: str | None = None + payload: dict[str, Any] = Field(default_factory=dict) + + +class EventLog: + """Append/read helper for campaign JSONL event logs.""" + + def __init__(self, path: str | Path) -> None: + self.path = Path(path) + + def append(self, event: CampaignEvent) -> CampaignEvent: + """Append *event* and return it.""" + 
self.path.parent.mkdir(parents=True, exist_ok=True) + with self.path.open("a", encoding="utf-8") as handle: + handle.write(event.model_dump_json()) + handle.write("\n") + return event + + def emit( + self, + event: EventKind, + *, + run_id: str | None = None, + agent_id: str | None = None, + role: str | None = None, + correlation_id: str | None = None, + payload: dict[str, Any] | None = None, + ) -> CampaignEvent: + """Build and append a :class:`CampaignEvent`.""" + return self.append( + CampaignEvent( + event=event, + run_id=run_id, + agent_id=agent_id, + role=role, + correlation_id=correlation_id, + payload=payload or {}, + ) + ) + + def read(self) -> list[CampaignEvent]: + """Read all valid JSONL events from the log.""" + return read_events(self.path) + + +def read_events(path: str | Path) -> list[CampaignEvent]: + """Read valid campaign events from *path*. + + Partially written or malformed lines are skipped so live dashboards can + poll while another process is appending. + """ + event_path = Path(path) + if not event_path.exists(): + return [] + events: list[CampaignEvent] = [] + with event_path.open(encoding="utf-8") as handle: + for line in handle: + if not line.strip(): + continue + try: + payload = json.loads(line) + events.append(CampaignEvent.model_validate(payload)) + except (json.JSONDecodeError, ValueError): + continue + return events diff --git a/src/chemgraph/academy/observability/payloads.py b/src/chemgraph/academy/observability/payloads.py new file mode 100644 index 00000000..2844eec8 --- /dev/null +++ b/src/chemgraph/academy/observability/payloads.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from chemgraph.academy.observability.event_log import CampaignEvent + + +class MessageSentPayload(BaseModel): + model_config = ConfigDict(extra="allow") + + message_id: str + sender: str + recipient: str + content: str + kind: str | None = None + tldr: str | None = 
None + artifact_refs: list[str] = Field(default_factory=list) + tool_result_ids: list[str] = Field(default_factory=list) + reason: str | None = None + confidence: float | None = None + round: int | None = None + timestamp: float | None = None + + +class MessageReceivedPayload(MessageSentPayload): + pass + + +class ToolCallStartedPayload(BaseModel): + model_config = ConfigDict(extra="allow") + + tool_result_id: str | None = None + tool_call_id: str | None = None + tool_name: str + arguments: dict[str, Any] = Field(default_factory=dict) + + +class ToolCallFinishedPayload(ToolCallStartedPayload): + status: str + result: Any = None + timestamp: float | None = None + agent_name: str | None = None + + +class ToolCallFailedPayload(ToolCallStartedPayload): + status: str + error: str + + +class WorkflowStartedPayload(BaseModel): + model_config = ConfigDict(extra="allow") + + workflow_type: str + workflow_node: str | None = None + model_name: str | None = None + query: str | None = None + log_dir: str | None = None + round: int | None = None + thread_id: str | None = None + tool_names: list[str] = Field(default_factory=list) + span_id: str | None = None + parent_span_id: str | None = None + + +class WorkflowFinishedPayload(BaseModel): + model_config = ConfigDict(extra="allow") + + workflow_type: str + status: str + error: str | None = None + log_dir: str | None = None + round: int | None = None + thread_id: str | None = None + span_id: str | None = None + parent_span_id: str | None = None + + +class LLMDecisionPayload(BaseModel): + model_config = ConfigDict(extra="allow") + + round: int | None = None + tool_names: list[str] = Field(default_factory=list) + action_tools_called: list[str] = Field(default_factory=list) + science_tools_called: list[str] = Field(default_factory=list) + workflow_span_id: str | None = None + thread_id: str | None = None + + +class AgentStartedPayload(BaseModel): + model_config = ConfigDict(extra="allow") + + role: str | None = None + tool_names: 
list[str] = Field(default_factory=list) + allowed_peers: list[str] = Field(default_factory=list) + placement: dict[str, Any] | None = None + hostname: str | None = None + short_hostname: str | None = None + + +class BeliefUpdatedPayload(BaseModel): + model_config = ConfigDict(extra="allow") + + hypothesis: str | None = None + summary: str | None = None + confidence: float | None = None + supporting_message_ids: list[str] = Field(default_factory=list) + supporting_tool_result_ids: list[str] = Field(default_factory=list) + reason: str | None = None + + +PAYLOAD_MODELS: dict[str, type[BaseModel]] = { + "message_sent": MessageSentPayload, + "message_received": MessageReceivedPayload, + "tool_call_started": ToolCallStartedPayload, + "tool_call_finished": ToolCallFinishedPayload, + "tool_call_failed": ToolCallFailedPayload, + "workflow_started": WorkflowStartedPayload, + "workflow_finished": WorkflowFinishedPayload, + "llm_decision": LLMDecisionPayload, + "llm_tool_calls": LLMDecisionPayload, + "agent_started": AgentStartedPayload, + "belief_updated": BeliefUpdatedPayload, +} + + +def typed_payload(event: CampaignEvent) -> BaseModel | None: + model = PAYLOAD_MODELS.get(event.event) + return model.model_validate(event.payload) if model else None diff --git a/src/chemgraph/academy/observability/run_files.py b/src/chemgraph/academy/observability/run_files.py new file mode 100644 index 00000000..bc128904 --- /dev/null +++ b/src/chemgraph/academy/observability/run_files.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import json +import os +import uuid +from pathlib import Path +from typing import Any + +__all__ = [ + 'append_jsonl', + 'read_json_file', + 'read_jsonl', + 'write_json', + 'write_json_atomic', +] + + +def write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open('w', encoding='utf-8') as fp: + json.dump(payload, fp, indent=2, sort_keys=True) + fp.write('\n') + + +def write_json_atomic(path: Path, 
payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_name(f'.{path.name}.{os.getpid()}.{uuid.uuid4()}.tmp') + tmp.write_text( + json.dumps(payload, indent=2, sort_keys=True) + '\n', + encoding='utf-8', + ) + tmp.replace(path) + + +def append_jsonl(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open('a', encoding='utf-8') as fp: + fp.write(json.dumps(payload, sort_keys=True)) + fp.write('\n') + + +def read_jsonl(path: Path) -> list[dict[str, Any]]: + if not path.exists(): + return [] + rows = [] + with path.open(encoding='utf-8') as fp: + for line in fp: + if not line.strip(): + continue + try: + item = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(item, dict): + rows.append(item) + return rows + + +def read_json_file(path: Path, *, default: dict[str, Any]) -> dict[str, Any]: + if not path.exists(): + return default + try: + payload = json.loads(path.read_text(encoding='utf-8')) + except json.JSONDecodeError: + return default + return payload if isinstance(payload, dict) else default diff --git a/tests/test_academy_payloads.py b/tests/test_academy_payloads.py new file mode 100644 index 00000000..36e2c057 --- /dev/null +++ b/tests/test_academy_payloads.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from chemgraph.academy.observability.event_log import CampaignEvent +from chemgraph.academy.observability.payloads import PAYLOAD_MODELS +from chemgraph.academy.observability.payloads import typed_payload + + +def _payload_for(event: str) -> dict: + payloads = { + "message_sent": { + "message_id": "msg-1", + "sender": "agent-a", + "recipient": "agent-b", + "kind": "message", + "content": "content", + }, + "message_received": { + "message_id": "msg-1", + "sender": "agent-a", + "recipient": "agent-b", + "kind": "message", + "content": "content", + }, + "tool_call_started": { + "tool_result_id": "tool-1", + "tool_name": "tool", + "arguments": {}, + }, 
+ "tool_call_finished": { + "tool_result_id": "tool-1", + "tool_name": "tool", + "arguments": {}, + "status": "ok", + }, + "tool_call_failed": { + "tool_result_id": "tool-1", + "tool_name": "tool", + "arguments": {}, + "status": "failed", + "error": "boom", + }, + "workflow_started": {"workflow_type": "single_agent"}, + "workflow_finished": {"workflow_type": "single_agent", "status": "completed"}, + "llm_decision": {}, + "llm_tool_calls": {}, + "agent_started": {}, + "belief_updated": {}, + } + return payloads[event] + + +def test_payload_models_round_trip() -> None: + for event, model in PAYLOAD_MODELS.items(): + payload = _payload_for(event) + parsed = model.model_validate(payload) + reparsed = model.model_validate(parsed.model_dump()) + assert reparsed.model_dump() == parsed.model_dump() + + +def test_typed_payload_selects_model() -> None: + event = CampaignEvent( + event="message_sent", + payload=_payload_for("message_sent"), + ) + + payload = typed_payload(event) + + assert payload is not None + assert payload.model_dump()["message_id"] == "msg-1" From ccecdd8e2bc4fd8eedb3feb93d5b06082a96eb22 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:50:02 -0500 Subject: [PATCH 047/119] feat(chemgraph): add local workflow tracing --- src/chemgraph/agent/llm_agent.py | 42 +- src/chemgraph/observability/__init__.py | 17 + src/chemgraph/observability/events.py | 119 ++++++ .../observability/langgraph_stream.py | 346 +++++++++++++++ .../observability/workflow_runner.py | 397 ++++++++++++++++++ 5 files changed, 913 insertions(+), 8 deletions(-) create mode 100644 src/chemgraph/observability/__init__.py create mode 100644 src/chemgraph/observability/events.py create mode 100644 src/chemgraph/observability/langgraph_stream.py create mode 100644 src/chemgraph/observability/workflow_runner.py diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 3f20a44b..5b914860 100644 --- a/src/chemgraph/agent/llm_agent.py +++ 
b/src/chemgraph/agent/llm_agent.py @@ -21,6 +21,8 @@ supported_gemini_models, ) +from chemgraph.observability.langgraph_stream import ChemGraphWorkflowCallback +from chemgraph.observability.langgraph_stream import emit_live_message_events from chemgraph.prompt.single_agent_prompt import ( single_agent_prompt, @@ -697,7 +699,13 @@ async def _call_human_input_handler(self, question: str) -> str: return await handler(question) return handler(question) - async def run(self, query: str, config=None, resume_from: Optional[str] = None): + async def run( + self, + query: str, + config=None, + resume_from: Optional[str] = None, + workflow_span_id: Optional[str] = None, + ): """ Async-only runner. Requires `self.workflow.astream(...)`. Streams values, logs new messages, writes state, and returns according to @@ -736,6 +744,12 @@ def _validate_config(cfg): cfg.setdefault("configurable", {}).setdefault("thread_id", "1") cfg["recursion_limit"] = self.recursion_limit + if workflow_span_id: + callbacks = list(cfg.get("callbacks") or []) + callbacks.append( + ChemGraphWorkflowCallback(workflow_span_id=workflow_span_id), + ) + cfg["callbacks"] = callbacks return cfg def _save_state_and_select_return(last_state, cfg): @@ -792,13 +806,25 @@ async def _stream_until_interrupt(stream_input, cfg): } if "messages" in s and s["messages"] != prev_msgs: - new_message = s["messages"][-1] - try: - new_message.pretty_print() - except Exception: - pass - logger.info(new_message) - prev_msgs = s["messages"] + messages = s["messages"] + if workflow_span_id: + emit_live_message_events( + previous_messages=prev_msgs, + current_messages=messages, + workflow_span_id=workflow_span_id, + ) + new_messages = ( + messages[len(prev_msgs) :] + if len(messages) >= len(prev_msgs) + else messages[-1:] + ) + for new_message in new_messages: + try: + new_message.pretty_print() + except Exception: + pass + logger.info(new_message) + prev_msgs = list(messages) last_st = s except GraphInterrupt as gi: # Fallback: 
some LangGraph versions may still raise. diff --git a/src/chemgraph/observability/__init__.py b/src/chemgraph/observability/__init__.py new file mode 100644 index 00000000..d6b910e5 --- /dev/null +++ b/src/chemgraph/observability/__init__.py @@ -0,0 +1,17 @@ +"""Shared observability helpers for ChemGraph runtimes.""" + +from chemgraph.observability.events import WorkflowEventContext +from chemgraph.observability.events import WorkflowEventSink +from chemgraph.observability.events import current_workflow_event_context +from chemgraph.observability.events import emit_workflow_event +from chemgraph.observability.events import new_span_id +from chemgraph.observability.events import workflow_event_context + +__all__ = [ + "WorkflowEventContext", + "WorkflowEventSink", + "current_workflow_event_context", + "emit_workflow_event", + "new_span_id", + "workflow_event_context", +] diff --git a/src/chemgraph/observability/events.py b/src/chemgraph/observability/events.py new file mode 100644 index 00000000..b24d601f --- /dev/null +++ b/src/chemgraph/observability/events.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import contextlib +import contextvars +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterator + +from chemgraph.academy.observability.event_log import EventLog + + +def new_span_id(prefix: str) -> str: + return f"{prefix}-{uuid.uuid4()}" + + +@dataclass(frozen=True) +class WorkflowEventContext: + """Execution context for nested ChemGraph workflow events.""" + + run_id: str | None + run_dir: str | None + agent_id: str | None + role: str | None + parent_span_id: str | None + tool_name: str | None + runtime: str = "chemgraph-langgraph" + + +@dataclass(frozen=True) +class WorkflowEventSink: + """Write normalized workflow events to canonical Academy events.""" + + path: Path + context: WorkflowEventContext + + def emit( + self, + event: str, + payload: dict[str, Any] | None = None, + *, + span_id: str | 
None = None, + parent_span_id: str | None = None, + runtime: str | None = None, + agent_id: str | None = None, + role: str | None = None, + ) -> dict[str, Any]: + ctx = self.context + resolved_agent_id = agent_id or ctx.agent_id + resolved_role = role or ctx.role + body = { + **(payload or {}), + "span_id": span_id, + "parent_span_id": parent_span_id or ctx.parent_span_id, + "runtime": runtime or ctx.runtime, + "run_id": ctx.run_id, + "run_dir": ctx.run_dir, + "agent_id": resolved_agent_id, + "role": resolved_role, + "parent_tool_name": ctx.tool_name, + "nested": True, + } + record = EventLog(self.path).emit( + event, # type: ignore[arg-type] + run_id=ctx.run_id, + agent_id=resolved_agent_id or "system", + role=resolved_role, + correlation_id=span_id, + payload=body, + ) + return record.model_dump(mode="json") + + +_CURRENT_SINK: contextvars.ContextVar[WorkflowEventSink | None] = ( + contextvars.ContextVar("chemgraph_workflow_event_sink", default=None) +) +_CURRENT_CONTEXT: contextvars.ContextVar[WorkflowEventContext | None] = ( + contextvars.ContextVar("chemgraph_workflow_event_context", default=None) +) + + +def current_workflow_event_context() -> WorkflowEventContext | None: + return _CURRENT_CONTEXT.get() + + +def emit_workflow_event( + event: str, + payload: dict[str, Any] | None = None, + *, + span_id: str | None = None, + parent_span_id: str | None = None, + runtime: str | None = None, +) -> dict[str, Any] | None: + sink = _CURRENT_SINK.get() + if sink is None: + return None + return sink.emit( + event, + payload, + span_id=span_id, + parent_span_id=parent_span_id, + runtime=runtime, + ) + + +@contextlib.contextmanager +def workflow_event_context( + *, + jsonl_path: str | Path, + context: WorkflowEventContext, +) -> Iterator[WorkflowEventSink]: + sink = WorkflowEventSink(Path(jsonl_path), context=context) + sink_token = _CURRENT_SINK.set(sink) + context_token = _CURRENT_CONTEXT.set(context) + try: + yield sink + finally: + 
_CURRENT_CONTEXT.reset(context_token) + _CURRENT_SINK.reset(sink_token) diff --git a/src/chemgraph/observability/langgraph_stream.py b/src/chemgraph/observability/langgraph_stream.py new file mode 100644 index 00000000..1fe81923 --- /dev/null +++ b/src/chemgraph/observability/langgraph_stream.py @@ -0,0 +1,346 @@ +"""Live LangGraph/LangChain event emission for ChemGraph workflows.""" + +from __future__ import annotations + +import json +import math +from typing import Any +from uuid import UUID + +from langchain_core.callbacks import BaseCallbackHandler + +from chemgraph.observability.events import emit_workflow_event +from chemgraph.observability.events import new_span_id + + +def _compact(value: Any, *, max_chars: int = 1000) -> Any: + try: + text = json.dumps(value, default=str, sort_keys=True) + except TypeError: + text = str(value) + if len(text) <= max_chars: + try: + return json.loads(text) + except json.JSONDecodeError: + return text + return { + "truncated": True, + "preview": text[:max_chars], + } + + +def _message_type(message: Any) -> str: + if isinstance(message, dict): + return str(message.get("type") or message.get("role") or "") + return str(getattr(message, "type", "") or getattr(message, "role", "")) + + +def _message_content(message: Any) -> Any: + if isinstance(message, dict): + return message.get("content") + return getattr(message, "content", None) + + +def _message_tool_calls(message: Any) -> list[dict[str, Any]]: + calls = ( + message.get("tool_calls") + if isinstance(message, dict) + else getattr(message, "tool_calls", None) + ) + if not isinstance(calls, list): + return [] + normalized = [] + for call in calls: + if isinstance(call, dict): + normalized.append( + { + "name": call.get("name"), + "id": call.get("id"), + "args": _compact(call.get("args") or {}, max_chars=2000), + }, + ) + else: + normalized.append({"name": str(call), "id": None, "args": {}}) + return normalized + + +def _message_usage_metadata(message: Any) -> dict[str, Any]: 
+ usage = ( + message.get("usage_metadata") + if isinstance(message, dict) + else getattr(message, "usage_metadata", None) + ) + if isinstance(usage, dict) and usage: + return usage + response_metadata = ( + message.get("response_metadata") + if isinstance(message, dict) + else getattr(message, "response_metadata", None) + ) + if not isinstance(response_metadata, dict): + return {} + token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") + return token_usage if isinstance(token_usage, dict) else {} + + +def _usage_int(usage: dict[str, Any], *keys: str) -> int | None: + for key in keys: + value = usage.get(key) + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + return None + + +def _text_for_token_estimate(value: Any) -> str: + try: + return json.dumps(value, default=str, sort_keys=True) + except TypeError: + return str(value) + + +def _json_safe(value: Any) -> Any: + try: + return json.loads(json.dumps(value, default=str)) + except TypeError: + return str(value) + + +def _serialize_message(message: Any) -> dict[str, Any]: + if isinstance(message, dict): + return _json_safe(message) + if hasattr(message, "model_dump"): + return _json_safe(message.model_dump(mode="json")) + return { + "type": _message_type(message), + "content": _json_safe(_message_content(message)), + "tool_calls": _message_tool_calls(message), + } + + +def _serialize_messages(messages: list[Any]) -> list[dict[str, Any]]: + return [_serialize_message(message) for message in messages] + + +def _estimate_tokens(text: str) -> int: + try: + import tiktoken # type: ignore[import-not-found] + + encoding = tiktoken.get_encoding("cl100k_base") + return len(encoding.encode(text)) + except Exception: + return max(1, math.ceil(len(text) / 4)) + + +def _llm_token_counts( + *, + previous_messages: list[Any], + message: Any, + tool_calls: list[dict[str, Any]], +) -> dict[str, Any]: + usage = 
_message_usage_metadata(message) + provider_input = _usage_int(usage, "input_tokens", "prompt_tokens") + provider_output = _usage_int(usage, "output_tokens", "completion_tokens") + provider_total = _usage_int(usage, "total_tokens") + if provider_input is not None or provider_output is not None or provider_total is not None: + input_tokens = provider_input + output_tokens = provider_output + if provider_total is None: + provider_total = (input_tokens or 0) + (output_tokens or 0) + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": provider_total, + "source": "provider", + "raw_usage": _compact(usage, max_chars=1000), + } + + input_text = _text_for_token_estimate(previous_messages) + output_text = _text_for_token_estimate( + { + "content": _message_content(message), + "tool_calls": tool_calls, + }, + ) + input_tokens = _estimate_tokens(input_text) + output_tokens = _estimate_tokens(output_text) + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "source": "local_estimate", + "estimate_scope": "langgraph_state_messages", + } + + +def emit_live_message_events( + *, + previous_messages: list[Any], + current_messages: list[Any], + workflow_span_id: str, +) -> int: + """Emit live workflow events for newly streamed LangGraph messages.""" + if len(current_messages) <= len(previous_messages): + return 0 + count = 0 + for index, message in enumerate( + current_messages[len(previous_messages) :], + start=len(previous_messages), + ): + message_type = _message_type(message) + if message_type != "ai": + continue + tool_calls = _message_tool_calls(message) + token_counts = _llm_token_counts( + previous_messages=current_messages[:index], + message=message, + tool_calls=tool_calls, + ) + prompt_messages = _serialize_messages(current_messages[:index]) + if tool_calls: + emit_workflow_event( + "llm_decision", + { + "workflow_node": "ChemGraphAgent", + "message_index": 
index, + "tool_calls": tool_calls, + "token_counts": token_counts, + "prompt_messages": prompt_messages, + }, + span_id=new_span_id("chemgraph-llm-decision"), + parent_span_id=workflow_span_id, + ) + count += 1 + continue + content = _message_content(message) + if content: + emit_workflow_event( + "workflow_output", + { + "workflow_node": "ChemGraphAgent", + "message_index": index, + "content_preview": str(content)[:2000], + "token_counts": token_counts, + "prompt_messages": prompt_messages, + }, + span_id=new_span_id("chemgraph-output"), + parent_span_id=workflow_span_id, + ) + count += 1 + return count + + +def _tool_name(serialized: dict[str, Any] | None, kwargs: dict[str, Any]) -> str: + serialized = serialized or {} + value = ( + serialized.get("name") + or serialized.get("id") + or kwargs.get("name") + or kwargs.get("tool_name") + ) + if isinstance(value, list) and value: + value = value[-1] + return str(value or "tool") + + +def _run_id_text(run_id: UUID | str | None) -> str: + return str(run_id) if run_id is not None else new_span_id("tool-run") + + +class ChemGraphWorkflowCallback(BaseCallbackHandler): + """Emit live tool lifecycle events for a ChemGraph LangGraph run.""" + + def __init__(self, *, workflow_span_id: str) -> None: + self.workflow_span_id = workflow_span_id + self._tool_runs: dict[str, dict[str, Any]] = {} + + def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + inputs: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Any: + tool_run_id = _run_id_text(run_id) + tool_name = _tool_name(serialized, kwargs) + span_id = f"chemgraph-tool-call-{tool_run_id}" + self._tool_runs[tool_run_id] = { + "tool_name": tool_name, + "span_id": span_id, + } + emit_workflow_event( + "tool_call_started", + { + "workflow_node": "tools", + "tool_name": tool_name, + "tool_call_id": tool_run_id, + "parent_tool_run_id": _run_id_text(parent_run_id) + if parent_run_id + else None, + 
"input": _compact(inputs if inputs is not None else input_str), + }, + span_id=span_id, + parent_span_id=self.workflow_span_id, + ) + + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> Any: + tool_run_id = _run_id_text(run_id) + tool_run = self._tool_runs.get(tool_run_id, {}) + tool_name = str(tool_run.get("tool_name") or _tool_name(None, kwargs)) + span_id = str( + tool_run.get("span_id") or f"chemgraph-tool-call-{tool_run_id}", + ) + emit_workflow_event( + "tool_call_finished", + { + "workflow_node": "tools", + "tool_name": tool_name, + "tool_call_id": tool_run_id, + "parent_tool_run_id": _run_id_text(parent_run_id) + if parent_run_id + else None, + "content_preview": str(_compact(output))[:2000], + }, + span_id=span_id, + parent_span_id=self.workflow_span_id, + ) + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> Any: + tool_run_id = _run_id_text(run_id) + tool_run = self._tool_runs.get(tool_run_id, {}) + tool_name = str(tool_run.get("tool_name") or _tool_name(None, kwargs)) + span_id = str( + tool_run.get("span_id") or f"chemgraph-tool-call-{tool_run_id}", + ) + emit_workflow_event( + "tool_call_failed", + { + "workflow_node": "tools", + "tool_name": tool_name, + "tool_call_id": tool_run_id, + "parent_tool_run_id": _run_id_text(parent_run_id) + if parent_run_id + else None, + "error": repr(error), + }, + span_id=span_id, + parent_span_id=self.workflow_span_id, + ) diff --git a/src/chemgraph/observability/workflow_runner.py b/src/chemgraph/observability/workflow_runner.py new file mode 100644 index 00000000..198a5aca --- /dev/null +++ b/src/chemgraph/observability/workflow_runner.py @@ -0,0 +1,397 @@ +"""Observed execution helpers for traditional ChemGraph workflows.""" + +from __future__ import annotations + +import json +import os +import time +from pathlib import Path +from typing import Any, Literal + 
+from chemgraph.agent.llm_agent import ChemGraph +from chemgraph.agent.llm_agent import serialize_state +from chemgraph.observability.events import WorkflowEventContext +from chemgraph.observability.events import current_workflow_event_context +from chemgraph.observability.events import emit_workflow_event +from chemgraph.observability.events import new_span_id +from chemgraph.observability.events import workflow_event_context + + +def _env_first(*names: str) -> str | None: + for name in names: + value = os.environ.get(name) + if value: + return value + return None + + +def normalize_model_name(model_name: str, base_url: str | None) -> str: + value = model_name.strip() + if base_url and "argoapi" in base_url and value.startswith("GPT-"): + return "argo:" + value.lower() + return value + + +def compact_value(value: Any, *, max_chars: int = 8000) -> Any: + try: + text = json.dumps(value, default=str, sort_keys=True) + except TypeError: + text = str(value) + if len(text) <= max_chars: + try: + return json.loads(text) + except json.JSONDecodeError: + return text + return { + "truncated": True, + "preview": text[:max_chars], + } + + +def _write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, default=str) + "\n", encoding="utf-8") + + +def _write_status( + *, + run_dir: Path, + run_id: str, + workflow_span_id: str, + query: str, + workflow_type: str, + model_name: str, + base_url: str | None, + status: str, + started_at: float, + error: str | None = None, +) -> None: + now = time.time() + payload = { + "mode": "chemgraph_workflow", + "run_id": run_id, + "workflow_span_id": workflow_span_id, + "query": query, + "workflow_type": workflow_type, + "model_name": model_name, + "base_url": base_url, + "status": status, + "started": started_at, + "updated": now, + "finished": now if status in {"completed", "failed"} else None, + "events_path": str(run_dir / "events.jsonl"), + } + 
if error: + payload["error"] = error + _write_json(run_dir / "status.json", payload) + + +def _write_manifest( + *, + run_dir: Path, + run_id: str, + workflow_span_id: str, + query: str, + workflow_type: str, + model_name: str, + base_url: str | None, +) -> None: + _write_json( + run_dir / "manifest.json", + { + "mode": "chemgraph_workflow", + "run_id": run_id, + "workflow_span_id": workflow_span_id, + "query": query, + "workflow_type": workflow_type, + "model_name": model_name, + "base_url": base_url, + "events_path": str(run_dir / "events.jsonl"), + }, + ) + + +def _workflow_log_dir(run_dir: Path, workflow_span_id: str) -> str: + path = run_dir / "chemgraph_workflows" / workflow_span_id + path.mkdir(parents=True, exist_ok=True) + return str(path) + + +async def run_observed_chemgraph_workflow( + *, + query: str, + run_dir: str | Path | None = None, + run_id: str | None = None, + workflow_type: str = "single_agent", + model_name: str | None = None, + base_url: str | None = None, + api_key: str | None = None, + argo_user: str | None = None, + return_option: Literal["last_message", "state"] = "state", + recursion_limit: int = 50, + parent_span_id: str | None = None, + agent_id: str = "chemgraph-workflow", + role: str = "TraditionalChemGraphWorkflow", + write_run_files: bool = True, +) -> dict[str, Any]: + """Run a traditional ChemGraph workflow while emitting dashboard events.""" + current_context = current_workflow_event_context() + run_dir_value = run_dir + if run_dir_value is None and current_context and current_context.run_dir: + run_dir_value = current_context.run_dir + if run_dir_value is None: + run_dir_value = "runs/local-chemgraph-workflow" + effective_run_dir = Path(run_dir_value).resolve() + effective_run_dir.mkdir(parents=True, exist_ok=True) + + workflow_span_id = new_span_id("chemgraph-workflow") + effective_run_id = run_id or effective_run_dir.name + base_url = base_url or _env_first( + "CHEMGRAPH_WORKFLOW_BASE_URL", + "ACADEMY_LM_BASE_URL", + ) + 
model_name = normalize_model_name( + model_name + or _env_first("CHEMGRAPH_WORKFLOW_MODEL", "ACADEMY_LM_MODEL") + or "argo:gpt-5.4", + base_url, + ) + api_key = api_key or _env_first( + "CHEMGRAPH_WORKFLOW_API_KEY", + "ACADEMY_LM_API_KEY", + "OPENAI_API_KEY", + ) + argo_user = argo_user or _env_first( + "CHEMGRAPH_WORKFLOW_ARGO_USER", + "ACADEMY_LM_USER", + "ARGO_USER", + ) + + started_at = time.time() + if write_run_files: + _write_manifest( + run_dir=effective_run_dir, + run_id=effective_run_id, + workflow_span_id=workflow_span_id, + query=query, + workflow_type=workflow_type, + model_name=model_name, + base_url=base_url, + ) + _write_status( + run_dir=effective_run_dir, + run_id=effective_run_id, + workflow_span_id=workflow_span_id, + query=query, + workflow_type=workflow_type, + model_name=model_name, + base_url=base_url, + status="running", + started_at=started_at, + ) + + context_manager = ( + workflow_event_context( + jsonl_path=effective_run_dir / "events.jsonl", + context=WorkflowEventContext( + run_id=effective_run_id, + run_dir=str(effective_run_dir), + agent_id=agent_id, + role=role, + parent_span_id=parent_span_id, + tool_name=None, + ), + ) + if current_context is None + else None + ) + + if context_manager is None: + return await _run_observed_chemgraph_workflow_inner( + query=query, + run_dir=effective_run_dir, + run_id=effective_run_id, + workflow_span_id=workflow_span_id, + workflow_type=workflow_type, + model_name=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + return_option=return_option, + recursion_limit=recursion_limit, + write_run_files=write_run_files, + started_at=started_at, + ) + with context_manager: + return await _run_observed_chemgraph_workflow_inner( + query=query, + run_dir=effective_run_dir, + run_id=effective_run_id, + workflow_span_id=workflow_span_id, + workflow_type=workflow_type, + model_name=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + return_option=return_option, 
+ recursion_limit=recursion_limit, + write_run_files=write_run_files, + started_at=started_at, + ) + + +async def _run_observed_chemgraph_workflow_inner( + *, + query: str, + run_dir: Path, + run_id: str, + workflow_span_id: str, + workflow_type: str, + model_name: str, + base_url: str | None, + api_key: str | None, + argo_user: str | None, + return_option: Literal["last_message", "state"], + recursion_limit: int, + write_run_files: bool, + started_at: float, +) -> dict[str, Any]: + log_dir = _workflow_log_dir(run_dir, workflow_span_id) + config = {"configurable": {"thread_id": workflow_span_id}} + + emit_workflow_event( + "run_started", + { + "workflow_type": workflow_type, + "model_name": model_name, + "query": query, + }, + span_id=workflow_span_id, + ) + emit_workflow_event( + "workflow_started", + { + "workflow_type": workflow_type, + "model_name": model_name, + "query": query, + "log_dir": log_dir, + }, + span_id=workflow_span_id, + ) + try: + emit_workflow_event( + "workflow_node_started", + {"workflow_node": "ChemGraph", "phase": "construct"}, + span_id=new_span_id("chemgraph-node"), + parent_span_id=workflow_span_id, + ) + agent = ChemGraph( + model_name=model_name, + workflow_type=workflow_type, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + return_option=return_option, + recursion_limit=recursion_limit, + log_dir=log_dir, + ) + emit_workflow_event( + "workflow_node_finished", + {"workflow_node": "ChemGraph", "phase": "construct"}, + span_id=new_span_id("chemgraph-node"), + parent_span_id=workflow_span_id, + ) + emit_workflow_event( + "workflow_node_started", + {"workflow_node": "LangGraph", "phase": "run"}, + span_id=new_span_id("chemgraph-node"), + parent_span_id=workflow_span_id, + ) + result = await agent.run( + query, + config=config, + workflow_span_id=workflow_span_id, + ) + state_payload = serialize_state(agent.get_state(config=config)) + emit_workflow_event( + "workflow_node_finished", + {"workflow_node": "LangGraph", "phase": 
"run"}, + span_id=new_span_id("chemgraph-node"), + parent_span_id=workflow_span_id, + ) + emit_workflow_event( + "workflow_finished", + { + "workflow_type": workflow_type, + "status": "completed", + "log_dir": log_dir, + }, + span_id=workflow_span_id, + ) + emit_workflow_event( + "run_finished", + { + "workflow_type": workflow_type, + "status": "completed", + }, + span_id=workflow_span_id, + ) + payload = { + "status": "completed", + "workflow_type": workflow_type, + "model_name": model_name, + "span_id": workflow_span_id, + "log_dir": log_dir, + "return_option": return_option, + "result": compact_value(serialize_state(result)), + "state": compact_value(state_payload), + } + if write_run_files: + _write_status( + run_dir=run_dir, + run_id=run_id, + workflow_span_id=workflow_span_id, + query=query, + workflow_type=workflow_type, + model_name=model_name, + base_url=base_url, + status="completed", + started_at=started_at, + ) + _write_json(run_dir / "result.json", payload) + return payload + except Exception as exc: + error = repr(exc) + emit_workflow_event( + "workflow_finished", + { + "workflow_type": workflow_type, + "status": "failed", + "error": error, + "log_dir": log_dir, + }, + span_id=workflow_span_id, + ) + emit_workflow_event( + "run_finished", + { + "workflow_type": workflow_type, + "status": "failed", + "error": error, + }, + span_id=workflow_span_id, + ) + if write_run_files: + _write_status( + run_dir=run_dir, + run_id=run_id, + workflow_span_id=workflow_span_id, + query=query, + workflow_type=workflow_type, + model_name=model_name, + base_url=base_url, + status="failed", + started_at=started_at, + error=error, + ) + raise From ab8c0efdd511e43c0d6cab78e8f7548f2a18fbd5 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 13:54:56 -0500 Subject: [PATCH 048/119] feat(chemgraph): support terminal tools in single-agent graph --- src/chemgraph/agent/llm_agent.py | 5 +- src/chemgraph/graphs/single_agent.py | 82 ++++++++++++++++++++++++---- 2 files 
changed, 74 insertions(+), 13 deletions(-) diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 5b914860..5f9f7759 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -1,7 +1,7 @@ import asyncio import datetime import os -from typing import Callable, List, Optional +from typing import Callable, Collection, List, Optional import uuid from chemgraph.memory.store import SessionStore @@ -172,6 +172,7 @@ def __init__( max_retries: int = 1, human_input_handler: Optional[Callable[[str], str]] = None, human_supervised: bool = False, + terminal_tool_names: Collection[str] = (), ): # Always generate a unique identifier for this instance self.uuid = str(uuid.uuid4())[:8] @@ -302,6 +303,7 @@ def __init__( self.max_retries = max_retries self.human_input_handler = human_input_handler self.human_supervised = human_supervised + self.terminal_tool_names = tuple(terminal_tool_names) # When human supervision is disabled and the caller is using the # default system prompt, strip the ask_human instructions so the @@ -342,6 +344,7 @@ def __init__( self.tools, max_retries=self.max_retries, human_supervised=self.human_supervised, + terminal_tool_names=self.terminal_tool_names, ) elif self.workflow_type == "multi_agent": self.workflow = self.workflow_map[workflow_type]["constructor"]( diff --git a/src/chemgraph/graphs/single_agent.py b/src/chemgraph/graphs/single_agent.py index f5af4abf..d2e193d2 100644 --- a/src/chemgraph/graphs/single_agent.py +++ b/src/chemgraph/graphs/single_agent.py @@ -1,4 +1,5 @@ import json +from collections.abc import Collection from langgraph.graph import StateGraph, START, END from langchain_openai import ChatOpenAI @@ -70,6 +71,36 @@ def _tool_message_content(message): return getattr(message, "content", "") +def _message_tool_calls(message) -> list: + """Extract tool calls from a message-like object.""" + if isinstance(message, dict): + calls = message.get("tool_calls") + else: + calls = 
getattr(message, "tool_calls", None) + return calls if isinstance(calls, list) else [] + + +def _state_messages(state: State): + """Extract messages from a LangGraph state or message list.""" + if isinstance(state, list): + return state + if messages := state.get("messages", []): + return messages + raise ValueError(f"No messages found in input state to tool_edge: {state}") + + +def _tool_result_names_after_latest_ai_tool_call(messages) -> set[str]: + """Return tool-result names appended after the latest AI tool-call message.""" + names: set[str] = set() + for message in reversed(messages): + if _message_tool_calls(message): + return names + name = _tool_message_name(message) + if name: + names.add(str(name)) + return names + + def _is_successful_report_message(message) -> bool: """Return True when message indicates successful generate_html execution.""" if _tool_message_name(message) != "generate_html": @@ -97,19 +128,29 @@ def route_tools(state: State): str Either 'tools' or 'done' based on the state conditions """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state to tool_edge: {state}") - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + messages = _state_messages(state) + ai_message = messages[-1] + if _message_tool_calls(ai_message): if not isinstance(state, list) and _is_repeated_tool_cycle(messages): return "done" return "tools" return "done" +def route_after_tools( + state: State, + terminal_tool_names: Collection[str] = (), +): + """Stop the graph after terminal tools; otherwise continue to the LLM.""" + if not terminal_tool_names: + return "continue" + executed_names = _tool_result_names_after_latest_ai_tool_call( + _state_messages(state), + ) + terminal_names = {str(name) for name in terminal_tool_names} + return "done" if executed_names & terminal_names else "continue" + + def 
route_report_tools(state: State): """Route report tool execution and stop if a report was already generated.""" if isinstance(state, list): @@ -120,14 +161,15 @@ def route_report_tools(state: State): else: raise ValueError(f"No messages found in input state to tool_edge: {state}") - if not (hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0): + tool_calls = _message_tool_calls(ai_message) + if not tool_calls: return "done" # Only allow known report tool calls to reach ToolNode. valid_report_tools = {"generate_html"} requested_tools = { call.get("name") - for call in getattr(ai_message, "tool_calls", []) + for call in tool_calls if isinstance(call, dict) } if not requested_tools or not requested_tools.issubset(valid_report_tools): @@ -334,6 +376,7 @@ def construct_single_agent_graph( tools: list = None, max_retries: int = 1, human_supervised: bool = False, + terminal_tool_names: Collection[str] = (), ): """Construct a geometry optimization graph. @@ -359,6 +402,9 @@ def construct_single_agent_graph( human_supervised : bool, optional Whether to include the ``ask_human`` tool so the agent can pause and request human input, by default False + terminal_tool_names : Collection[str], optional + Tool names that should terminate the graph after successful tool + execution instead of routing back to the LLM, by default empty. 
Returns ------- @@ -414,7 +460,11 @@ def construct_single_agent_graph( route_tools, {"tools": "tools", "done": "ReportAgent"}, ) - graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_conditional_edges( + "tools", + lambda state: route_after_tools(state, terminal_tool_names), + {"continue": "ChemGraphAgent", "done": END}, + ) graph_builder.add_conditional_edges( "ReportAgent", route_report_tools, @@ -431,7 +481,11 @@ def construct_single_agent_graph( route_tools, {"tools": "tools", "done": END}, ) - graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_conditional_edges( + "tools", + lambda state: route_after_tools(state, terminal_tool_names), + {"continue": "ChemGraphAgent", "done": END}, + ) graph = graph_builder.compile(checkpointer=checkpointer) logger.info("Graph construction completed") @@ -462,7 +516,11 @@ def construct_single_agent_graph( route_tools, {"tools": "tools", "done": "ResponseAgent"}, ) - graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_conditional_edges( + "tools", + lambda state: route_after_tools(state, terminal_tool_names), + {"continue": "ChemGraphAgent", "done": END}, + ) graph_builder.add_edge(START, "ChemGraphAgent") graph_builder.add_edge("ResponseAgent", END) From d6f22d2d86f3f1508a2952e52bd1146a315ccc66 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 14:07:36 -0500 Subject: [PATCH 049/119] refactor(academy): reshape persistent agent runtime --- src/chemgraph/academy/__init__.py | 59 +- src/chemgraph/academy/agent.py | 123 - src/chemgraph/academy/config.py | 175 - src/chemgraph/academy/coordinator.py | 179 - src/chemgraph/academy/core/__init__.py | 35 + src/chemgraph/academy/core/agent.py | 281 ++ src/chemgraph/academy/core/campaign.py | 484 +++ src/chemgraph/academy/core/fastmcp.py | 344 ++ src/chemgraph/academy/core/lm.py | 71 + src/chemgraph/academy/core/peer_protocol.py | 54 + src/chemgraph/academy/core/prompt.py | 35 + src/chemgraph/academy/core/tools.py | 520 +++ 
src/chemgraph/academy/core/turn.py | 569 +++ src/chemgraph/academy/dashboard/__init__.py | 19 + src/chemgraph/academy/dashboard/__main__.py | 6 + src/chemgraph/academy/dashboard/server.py | 151 + src/chemgraph/academy/dashboard/static/app.js | 3053 +++++++++++++++++ .../academy/dashboard/static/index.html | 703 ++++ .../academy/observability/run_artifacts.py | 416 +++ src/chemgraph/academy/rate_limiter.py | 135 - src/chemgraph/academy/runtime/__init__.py | 1 + .../academy/runtime/compute_launcher.py | 396 +++ src/chemgraph/academy/runtime/daemon.py | 259 ++ src/chemgraph/academy/runtime/mpi.py | 103 + .../academy/runtime/operator_console.py | 759 ++++ .../academy/runtime/profiles/__init__.py | 38 + .../runtime/profiles/aurora.template.json | 38 + .../runtime/profiles/polaris.template.json | 38 + .../academy/runtime/profiles/system.py | 51 + src/chemgraph/academy/runtime/registration.py | 87 + src/chemgraph/academy/screening.py | 151 - tests/test_academy_compute_launcher.py | 57 + tests/test_academy_dashboard.py | 134 + tests/test_academy_operator_console.py | 68 + tests/test_academy_reasoning_phase2.py | 432 +++ tests/test_tool_adapter_validation.py | 172 + 36 files changed, 9401 insertions(+), 795 deletions(-) delete mode 100644 src/chemgraph/academy/agent.py delete mode 100644 src/chemgraph/academy/config.py delete mode 100644 src/chemgraph/academy/coordinator.py create mode 100644 src/chemgraph/academy/core/__init__.py create mode 100644 src/chemgraph/academy/core/agent.py create mode 100644 src/chemgraph/academy/core/campaign.py create mode 100644 src/chemgraph/academy/core/fastmcp.py create mode 100644 src/chemgraph/academy/core/lm.py create mode 100644 src/chemgraph/academy/core/peer_protocol.py create mode 100644 src/chemgraph/academy/core/prompt.py create mode 100644 src/chemgraph/academy/core/tools.py create mode 100644 src/chemgraph/academy/core/turn.py create mode 100644 src/chemgraph/academy/dashboard/__init__.py create mode 100644 
src/chemgraph/academy/dashboard/__main__.py create mode 100644 src/chemgraph/academy/dashboard/server.py create mode 100644 src/chemgraph/academy/dashboard/static/app.js create mode 100644 src/chemgraph/academy/dashboard/static/index.html create mode 100644 src/chemgraph/academy/observability/run_artifacts.py delete mode 100644 src/chemgraph/academy/rate_limiter.py create mode 100644 src/chemgraph/academy/runtime/__init__.py create mode 100644 src/chemgraph/academy/runtime/compute_launcher.py create mode 100644 src/chemgraph/academy/runtime/daemon.py create mode 100644 src/chemgraph/academy/runtime/mpi.py create mode 100644 src/chemgraph/academy/runtime/operator_console.py create mode 100644 src/chemgraph/academy/runtime/profiles/__init__.py create mode 100644 src/chemgraph/academy/runtime/profiles/aurora.template.json create mode 100644 src/chemgraph/academy/runtime/profiles/polaris.template.json create mode 100644 src/chemgraph/academy/runtime/profiles/system.py create mode 100644 src/chemgraph/academy/runtime/registration.py delete mode 100644 src/chemgraph/academy/screening.py create mode 100644 tests/test_academy_compute_launcher.py create mode 100644 tests/test_academy_dashboard.py create mode 100644 tests/test_academy_operator_console.py create mode 100644 tests/test_academy_reasoning_phase2.py create mode 100644 tests/test_tool_adapter_validation.py diff --git a/src/chemgraph/academy/__init__.py b/src/chemgraph/academy/__init__.py index 90e5bf12..fc3dece9 100644 --- a/src/chemgraph/academy/__init__.py +++ b/src/chemgraph/academy/__init__.py @@ -3,43 +3,38 @@ Provides agent classes and utilities for deploying ChemGraph workflows across federated HPC infrastructure using the Academy framework. -Requires the ``academy`` optional extra:: - - pip install chemgraphagent[academy] - -Modules that depend on ``academy-py`` (agent, screening, coordinator) -use lazy imports so that the rate limiter and config utilities remain -usable without the optional dependency. 
+Requires the ``academy`` optional extra. """ from __future__ import annotations -from chemgraph.academy.config import AcademyConfig, build_manager -from chemgraph.academy.rate_limiter import RateLimiter - - -def __getattr__(name: str): # noqa: N807 - """Lazy-import Academy-dependent classes.""" - if name == "ChemGraphAgent": - from chemgraph.academy.agent import ChemGraphAgent - - return ChemGraphAgent - if name == "ScreeningAgent": - from chemgraph.academy.screening import ScreeningAgent - - return ScreeningAgent - if name == "CoordinatorAgent": - from chemgraph.academy.coordinator import CoordinatorAgent - - return CoordinatorAgent - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +from chemgraph.academy.core.agent import ChemGraphLogicalAgent +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.core.campaign import ExecutionSpec +from chemgraph.academy.core.campaign import ResourceSpec +from chemgraph.academy.core.campaign import ToolSpec +from chemgraph.academy.core.campaign import load_campaign +from chemgraph.academy.core.campaign import resolve_campaign_resources +from chemgraph.academy.observability.event_log import CampaignEvent +from chemgraph.academy.observability.event_log import EventLog +from chemgraph.academy.core.prompt import PromptProfile +from chemgraph.academy.core.prompt import load_prompt_profile __all__ = [ - "ChemGraphAgent", - "AcademyConfig", - "build_manager", - "RateLimiter", - "ScreeningAgent", - "CoordinatorAgent", + "CampaignEvent", + "ChemGraphAgentSpec", + "ChemGraphCampaign", + "ChemGraphDaemonConfig", + "EventLog", + "ExecutionSpec", + "PromptProfile", + "ResourceSpec", + "ChemGraphLogicalAgent", + "load_campaign", + "load_prompt_profile", + "resolve_campaign_resources", + "ToolSpec", ] diff --git a/src/chemgraph/academy/agent.py 
b/src/chemgraph/academy/agent.py deleted file mode 100644 index 1ec04b3e..00000000 --- a/src/chemgraph/academy/agent.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Base Academy Agent wrapping a ChemGraph instance. - -Each ``ChemGraphAgent`` holds one ``ChemGraph`` object and exposes its -``run()`` method as an Academy ``@action`` so it can be invoked remotely -by peer agents, coordinators, or the Manager user handle. -""" - -from __future__ import annotations - -import logging -import os -import uuid -from typing import Any, Optional - -from academy.agent import Agent, action - -from chemgraph.agent.llm_agent import ChemGraph - -logger = logging.getLogger(__name__) - - -class ChemGraphAgent(Agent): - """Academy Agent wrapping a single :class:`ChemGraph` instance. - - Parameters - ---------- - model_name : str - LLM model to use (e.g. ``"gpt-4o"``, ``"claude-sonnet-4"``). - workflow_type : str - ChemGraph workflow (e.g. ``"single_agent"``, ``"multi_agent"``). - log_dir : str or None - Base directory for agent logs. A per-agent subdirectory is - created automatically. - rate_limiter : RateLimiter or None - Shared rate limiter for LLM API calls. - chemgraph_kwargs : dict - Extra keyword arguments forwarded to the :class:`ChemGraph` - constructor (e.g. ``base_url``, ``api_key``, ``recursion_limit``). 
- """ - - def __init__( - self, - model_name: str = "gpt-4o-mini", - workflow_type: str = "single_agent", - log_dir: Optional[str] = None, - rate_limiter: Any = None, - **chemgraph_kwargs: Any, - ) -> None: - super().__init__() - self._model_name = model_name - self._workflow_type = workflow_type - self._log_dir = log_dir - self._rate_limiter = rate_limiter - self._chemgraph_kwargs = chemgraph_kwargs - self._cg: Optional[ChemGraph] = None - self._agent_uuid = uuid.uuid4().hex[:8] - - async def agent_on_startup(self) -> None: - """Initialise the ChemGraph instance on the remote worker.""" - agent_log_dir = self._log_dir - if agent_log_dir: - agent_log_dir = os.path.join(agent_log_dir, self._agent_uuid) - os.makedirs(agent_log_dir, exist_ok=True) - - self._cg = ChemGraph( - model_name=self._model_name, - workflow_type=self._workflow_type, - log_dir=agent_log_dir, - enable_memory=False, - **self._chemgraph_kwargs, - ) - logger.info( - "ChemGraphAgent %s started: model=%s workflow=%s", - self._agent_uuid, - self._model_name, - self._workflow_type, - ) - - async def agent_on_shutdown(self) -> None: - """Clean up resources.""" - logger.info("ChemGraphAgent %s shutting down", self._agent_uuid) - self._cg = None - - @action - async def run_query( - self, - query: str, - *, - config: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Execute a ChemGraph query and return the result. - - Parameters - ---------- - query : str - The natural-language chemistry query. - config : dict, optional - LangGraph config (thread_id, etc.). - - Returns - ------- - dict - The workflow result (serialised state or last message, - depending on the ChemGraph ``return_option``). 
- """ - if self._cg is None: - raise RuntimeError("Agent not initialised (call agent_on_startup first)") - - if self._rate_limiter is not None: - await self._rate_limiter.acquire(self._model_name) - - thread_cfg = config or {"configurable": {"thread_id": uuid.uuid4().hex[:8]}} - result = await self._cg.run(query=query, config=thread_cfg) - return result - - @action - async def get_info(self) -> dict[str, str]: - """Return metadata about this agent instance.""" - return { - "agent_uuid": self._agent_uuid, - "model_name": self._model_name, - "workflow_type": self._workflow_type, - } diff --git a/src/chemgraph/academy/config.py b/src/chemgraph/academy/config.py deleted file mode 100644 index 5f7a98b3..00000000 --- a/src/chemgraph/academy/config.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Bridge between ChemGraph config.toml and Academy Manager/Exchange/Launcher. - -Reads the ``[academy]`` section from ``config.toml`` and builds the -corresponding Academy objects. -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass, field -from typing import Any, Literal, Optional - -import toml - -logger = logging.getLogger(__name__) - -# Exchange and launcher types supported by this bridge. -ExchangeType = Literal["local", "redis", "hybrid"] -LauncherType = Literal["thread", "process", "parsl", "globus_compute"] - - -@dataclass -class AcademyConfig: - """Parsed ``[academy]`` configuration section. - - Attributes - ---------- - exchange : ExchangeType - Message exchange backend (default ``"local"``). - launcher : LauncherType - Agent deployment mechanism (default ``"thread"``). - num_agents : int - Number of worker agents to spawn (default ``1``). - redis_hostname : str - Redis host when ``exchange="redis"`` (default ``"localhost"``). - redis_port : int - Redis port (default ``6379``). - parsl_system : str - HPC system name for Parsl config (default ``"local"``). - globus_endpoint_id : str - Globus Compute endpoint UUID. 
- max_concurrency : int - Max concurrent LLM calls per provider (default ``50``). - log_dir : str or None - Base log directory for agent output. - extra : dict - Any additional keys from the config section. - """ - - exchange: ExchangeType = "local" - launcher: LauncherType = "thread" - num_agents: int = 1 - redis_hostname: str = "localhost" - redis_port: int = 6379 - parsl_system: str = "local" - globus_endpoint_id: str = "" - max_concurrency: int = 50 - log_dir: Optional[str] = None - extra: dict = field(default_factory=dict) - - -def load_academy_config(config_path: str = "config.toml") -> AcademyConfig: - """Load the ``[academy]`` section from a TOML config file. - - Missing keys are filled with defaults. Unknown keys are stored - in ``extra``. - """ - try: - data = toml.load(config_path) - except FileNotFoundError: - logger.warning("Config file %s not found, using defaults", config_path) - return AcademyConfig() - - section = data.get("academy", {}) - - known_keys = {f.name for f in AcademyConfig.__dataclass_fields__.values()} - known = {k: v for k, v in section.items() if k in known_keys} - extra = {k: v for k, v in section.items() if k not in known_keys} - - return AcademyConfig(**known, extra=extra) - - -def _build_exchange_factory(cfg: AcademyConfig) -> Any: - """Create the Academy ExchangeFactory matching the config.""" - if cfg.exchange == "local": - from academy.exchange import LocalExchangeFactory - - return LocalExchangeFactory() - - if cfg.exchange == "redis": - from academy.exchange import RedisExchangeFactory - - return RedisExchangeFactory( - hostname=cfg.redis_hostname, - port=cfg.redis_port, - ) - - if cfg.exchange == "hybrid": - from academy.exchange import HybridExchangeFactory - - return HybridExchangeFactory() - - raise ValueError(f"Unsupported exchange type: {cfg.exchange}") - - -def _build_executor(cfg: AcademyConfig) -> Any: - """Create the executor matching the configured launcher type.""" - if cfg.launcher == "thread": - from 
concurrent.futures import ThreadPoolExecutor - - return ThreadPoolExecutor(max_workers=cfg.num_agents) - - if cfg.launcher == "process": - from concurrent.futures import ProcessPoolExecutor - - return ProcessPoolExecutor(max_workers=cfg.num_agents) - - if cfg.launcher == "parsl": - try: - from academy.executor import ParslExecutor - except ImportError as exc: - raise ImportError( - "Parsl launcher requires: pip install chemgraphagent[academy,parsl]" - ) from exc - return ParslExecutor() - - if cfg.launcher == "globus_compute": - try: - from academy.executor import GlobusComputeExecutor - except ImportError as exc: - raise ImportError( - "Globus Compute launcher requires: " - "pip install chemgraphagent[academy,globus_compute]" - ) from exc - return GlobusComputeExecutor(endpoint_id=cfg.globus_endpoint_id) - - raise ValueError(f"Unsupported launcher type: {cfg.launcher}") - - -async def build_manager( - cfg: AcademyConfig | None = None, - config_path: str = "config.toml", -) -> Any: - """Build an Academy Manager from ChemGraph config. - - Returns an async context manager. Usage:: - - async with await build_manager() as manager: - handle = await manager.launch(ScreeningAgent, ...) - result = await handle.screen_molecule("CCO", "optimize") - - Parameters - ---------- - cfg : AcademyConfig, optional - Pre-loaded config. If ``None``, loads from *config_path*. - config_path : str - Path to config.toml (used only when *cfg* is ``None``). - - Returns - ------- - Manager - An Academy Manager ready for ``async with``. 
- """ - from academy.manager import Manager - - if cfg is None: - cfg = load_academy_config(config_path) - - factory = _build_exchange_factory(cfg) - executor = _build_executor(cfg) - - return await Manager.from_exchange_factory( - factory=factory, - executors=executor, - ) diff --git a/src/chemgraph/academy/coordinator.py b/src/chemgraph/academy/coordinator.py deleted file mode 100644 index 12f9fc76..00000000 --- a/src/chemgraph/academy/coordinator.py +++ /dev/null @@ -1,179 +0,0 @@ -"""Coordinator agent for multi-wave screening campaigns. - -The coordinator manages a fleet of :class:`ScreeningAgent` instances, -collects results, and optionally uses a ChemGraph LLM workflow to -analyse the collected data and spawn follow-up screening waves. -""" - -from __future__ import annotations - -import asyncio -import glob -import json -import logging -import os -import time -from typing import Any, Optional - -from academy.agent import Agent, action, timer -from academy.handle import Handle - -logger = logging.getLogger(__name__) - - -class CoordinatorAgent(Agent): - """Collects screening results and orchestrates follow-up waves. - - Parameters - ---------- - results_dir : str - Directory where :class:`ScreeningAgent` instances write their - per-molecule JSON result files. - worker_handles : list[Handle] or None - Handles to active screening agents (for progress polling). - analysis_model : str - LLM model for analysing aggregated results. - analysis_workflow : str - ChemGraph workflow type for the analysis step. - analysis_kwargs : dict - Extra kwargs for the analysis ChemGraph instance. 
- """ - - def __init__( - self, - results_dir: str, - worker_handles: list[Handle] | None = None, - analysis_model: str = "gpt-4o", - analysis_workflow: str = "single_agent", - **analysis_kwargs: Any, - ) -> None: - super().__init__() - self._results_dir = results_dir - self._worker_handles = worker_handles or [] - self._analysis_model = analysis_model - self._analysis_workflow = analysis_workflow - self._analysis_kwargs = analysis_kwargs - self._collected: list[dict[str, Any]] = [] - self._analysis_result: Optional[dict[str, Any]] = None - - async def agent_on_startup(self) -> None: - os.makedirs(self._results_dir, exist_ok=True) - logger.info( - "CoordinatorAgent started: watching %s, %d workers", - self._results_dir, - len(self._worker_handles), - ) - - # ------------------------------------------------------------------ - # Progress monitoring - # ------------------------------------------------------------------ - - @action - async def poll_progress(self) -> dict[str, Any]: - """Query all workers for their screening progress.""" - progress = [] - for handle in self._worker_handles: - try: - p = await handle.get_progress() - progress.append(p) - except Exception as exc: - progress.append({"error": str(exc)}) - total = sum(p.get("total", 0) for p in progress if "error" not in p) - completed = sum(p.get("completed", 0) for p in progress if "error" not in p) - failed = sum(p.get("failed", 0) for p in progress if "error" not in p) - return { - "workers": len(progress), - "total": total, - "completed": completed, - "failed": failed, - "per_worker": progress, - } - - # ------------------------------------------------------------------ - # Result collection - # ------------------------------------------------------------------ - - @action - async def collect_results(self) -> list[dict[str, Any]]: - """Read all result JSON files from the shared results directory.""" - pattern = os.path.join(self._results_dir, "*.json") - files = sorted(glob.glob(pattern)) - results = 
[] - for path in files: - try: - with open(path) as f: - results.append(json.load(f)) - except (json.JSONDecodeError, OSError): - logger.warning("Skipping corrupt result file: %s", path) - self._collected = results - logger.info("Collected %d results from %s", len(results), self._results_dir) - return results - - # ------------------------------------------------------------------ - # LLM-powered analysis - # ------------------------------------------------------------------ - - @action - async def analyse(self, query: Optional[str] = None) -> dict[str, Any]: - """Use a ChemGraph agent to analyse collected results. - - Parameters - ---------- - query : str, optional - Custom analysis query. Defaults to a standard prompt - asking the LLM to rank candidates. - """ - from chemgraph.agent.llm_agent import ChemGraph - - if not self._collected: - await self.collect_results() - - successes = [r for r in self._collected if r.get("status") == "success"] - if not successes: - return {"error": "No successful results to analyse"} - - summary = json.dumps(successes, default=str, indent=2) - if query is None: - query = ( - "You are analysing computational chemistry screening results. " - f"Here are {len(successes)} results:\n\n{summary}\n\n" - "Identify the top candidates based on energy, stability, " - "or other relevant properties. Rank them and explain why." 
- ) - - cg = ChemGraph( - model_name=self._analysis_model, - workflow_type=self._analysis_workflow, - enable_memory=False, - **self._analysis_kwargs, - ) - self._analysis_result = await cg.run(query=query) - return self._analysis_result - - @action - async def get_analysis(self) -> dict[str, Any] | None: - """Return the most recent analysis result.""" - return self._analysis_result - - # ------------------------------------------------------------------ - # Wave dispatch - # ------------------------------------------------------------------ - - @action - async def suggest_followup_molecules( - self, - top_n: int = 10, - ) -> list[str]: - """Extract top candidate SMILES from analysis for a follow-up wave. - - Returns a list of SMILES strings identified as promising by - the analysis step. Falls back to returning the top-N by - lowest energy if no analysis is available. - """ - if not self._collected: - await self.collect_results() - - successes = [r for r in self._collected if r.get("status") == "success"] - # Simple heuristic: return the SMILES of completed molecules. - # A real implementation would parse energies from results. 
- return [r["smiles"] for r in successes[:top_n]] diff --git a/src/chemgraph/academy/core/__init__.py b/src/chemgraph/academy/core/__init__.py new file mode 100644 index 00000000..6dd05c9f --- /dev/null +++ b/src/chemgraph/academy/core/__init__.py @@ -0,0 +1,35 @@ +"""Core ChemGraph Academy campaign contracts and agent logic.""" + +from chemgraph.academy.core.agent import ChemGraphLogicalAgent +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.core.campaign import ExecutionSpec +from chemgraph.academy.core.campaign import ResourceSpec +from chemgraph.academy.core.campaign import ToolSpec +from chemgraph.academy.core.campaign import load_campaign +from chemgraph.academy.core.campaign import resolve_campaign_resources +from chemgraph.academy.core.lm import LLMSettings +from chemgraph.academy.core.lm import load_lm_config +from chemgraph.academy.core.prompt import PromptProfile +from chemgraph.academy.core.prompt import load_prompt_profile +from chemgraph.academy.core.turn import ChemGraphReasoningRoundEngine +from chemgraph.academy.core.turn import ReasoningTurnResult + +__all__ = [ + "ChemGraphAgentSpec", + "ChemGraphCampaign", + "ChemGraphDaemonConfig", + "ChemGraphLogicalAgent", + "ChemGraphReasoningRoundEngine", + "ExecutionSpec", + "LLMSettings", + "PromptProfile", + "ReasoningTurnResult", + "ResourceSpec", + "ToolSpec", + "load_campaign", + "load_lm_config", + "load_prompt_profile", + "resolve_campaign_resources", +] diff --git a/src/chemgraph/academy/core/agent.py b/src/chemgraph/academy/core/agent.py new file mode 100644 index 00000000..d4b55da7 --- /dev/null +++ b/src/chemgraph/academy/core/agent.py @@ -0,0 +1,281 @@ +"""Persistent logical Academy agent for ChemGraph campaigns.""" + +from __future__ import annotations + +import asyncio +import time +from collections.abc import Mapping +from 
pathlib import Path +from typing import Any + +from academy.agent import Agent, action +from academy.agent import loop +from academy.handle import Handle +from academy.identifier import AgentId + +from chemgraph.academy.core.fastmcp import ( + CampaignFastMCPToolInvoker, +) +from chemgraph.academy.core.peer_protocol import validate_message +from chemgraph.academy.observability.event_log import EventLog +from chemgraph.academy.observability.run_artifacts import write_status_snapshot +from chemgraph.academy.core.turn import build_recent_actions +from chemgraph.academy.core.turn import ChemGraphReasoningRoundEngine +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.lm import LLMSettings +from chemgraph.academy.core.prompt import PromptProfile + + +class ChemGraphLogicalAgent(Agent): + """Persistent Academy logical agent for one ChemGraph campaign role.""" + + def __init__( + self, + spec: ChemGraphAgentSpec, + *, + campaign: ChemGraphCampaign, + llm_settings: LLMSettings, + prompt_profile: PromptProfile, + run_dir: Path, + max_decisions: int, + tool_invoker: CampaignFastMCPToolInvoker, + peer_agent_ids: Mapping[str, AgentId[Any]] | None = None, + placement: dict[str, Any] | None = None, + poll_timeout_s: float = 2.0, + idle_timeout_s: float = 120.0, + status_interval_s: float = 5.0, + ) -> None: + super().__init__() + self.spec = spec + self.campaign = campaign + self.llm_settings = llm_settings + self.prompt_profile = prompt_profile + self.run_dir = run_dir + self.max_decisions = max_decisions + self.tool_invoker = tool_invoker + self.peer_agent_ids = dict(peer_agent_ids or {}) + self.placement = placement or {} + self.poll_timeout_s = poll_timeout_s + self.idle_timeout_s = idle_timeout_s + self.status_interval_s = status_interval_s + + self.peer_names = tuple(spec.allowed_peers) + self.peer_handles: dict[str, Handle[Any]] = {} + self.received_message_history: 
list[dict[str, Any]] = [] + self.outbox: list[dict[str, Any]] = [] + self.tool_results: list[dict[str, Any]] = [] + self.final_result: dict[str, Any] | None = None + self.round_index = 0 + self.finished = False + self.last_error: str | None = None + self._wake_event: asyncio.Event | None = None + self._reasoning_engine: ChemGraphReasoningRoundEngine | None = None + + async def agent_on_startup(self) -> None: + self._wake_event = asyncio.Event() + self.peer_handles = { + name: Handle(agent_id) + for name, agent_id in self.peer_agent_ids.items() + if name in self.peer_names + } + self._reasoning_engine = await ChemGraphReasoningRoundEngine.create( + campaign=self.campaign, + spec=self.spec, + llm_settings=self.llm_settings, + prompt_profile=self.prompt_profile, + run_dir=self.run_dir, + max_decisions=self.max_decisions, + tool_invoker=self.tool_invoker, + peer_names=self.peer_names, + peer_handles=self.peer_handles, + received_message_history=self.received_message_history, + outbox=self.outbox, + tool_results=self.tool_results, + get_final_result=lambda: self.final_result, + get_round_index=lambda: self.round_index, + set_final_result=self._set_final_result, + trace=self._trace, + ) + self._trace( + 'agent_started', + { + 'role': self.spec.role, + 'tool_names': list(self.spec.tool_names), + 'allowed_peers': list(self.spec.allowed_peers), + 'placement': self.placement, + **self.placement, + }, + ) + + @action + async def receive_message(self, message: dict[str, Any]) -> None: + validate_message(message) + self.received_message_history.append(message) + self._trace('message_received', message) + if self._wake_event is not None: + self._wake_event.set() + + @action + async def get_status(self) -> dict[str, Any]: + return await self.report_state() + + @loop + async def deliberate(self, shutdown: asyncio.Event) -> None: + if self._wake_event is None: + raise RuntimeError('agent startup did not initialize wake state') + + decisions_completed = 0 + last_activity = 
time.monotonic() + last_status = 0.0 + + while not shutdown.is_set(): + if self._wake_event.is_set(): + self._wake_event.clear() + decisions_completed, self_wake = await self.run_decision_turn( + decisions_completed, + ) + last_activity = time.monotonic() + if self_wake: + self._wake_event.set() + await self.write_runtime_status() + if decisions_completed >= self.max_decisions: + self._trace( + 'max_decisions_reached', + {'decisions_completed': decisions_completed}, + ) + break + continue + + now = time.monotonic() + if now - last_status >= self.status_interval_s: + await self.write_runtime_status() + last_status = now + + if now - last_activity >= self.idle_timeout_s: + self._trace( + 'idle_timeout', + { + 'idle_timeout_s': self.idle_timeout_s, + 'decisions_completed': decisions_completed, + }, + ) + break + + try: + await asyncio.wait_for( + self._wake_event.wait(), + timeout=self.poll_timeout_s, + ) + except asyncio.TimeoutError: + pass + + self.finished = True + self._trace( + 'daemon_stopped', + { + 'decisions_completed': decisions_completed, + 'shutdown_requested': shutdown.is_set(), + }, + ) + await self.write_runtime_status() + self.agent_shutdown() + + async def write_runtime_status(self) -> None: + write_status_snapshot( + run_dir=self.run_dir, + campaign=self.campaign, + agent_state=await self.report_state(), + placement=self.placement, + ) + + async def run_decision_turn(self, decisions_completed: int) -> tuple[int, bool]: + self.round_index += 1 + try: + self_wake = await self._reasoning_round() + except Exception as exc: + self.last_error = repr(exc) + self._trace('agent_error', {'error': self.last_error}) + raise + return decisions_completed + 1, self_wake + + async def report_state(self) -> dict[str, Any]: + return { + 'agent_name': self.spec.name, + 'role': self.spec.role, + 'status_updated_at': time.time(), + 'round': self.round_index, + 'finished': self.finished, + 'last_error': self.last_error, + 'current_activity': None, + 
'received_message_count': len(self.received_message_history), + 'outbox_count': len(self.outbox), + 'recent_received_messages': self.received_message_history[-10:], + 'recent_outbox': self.outbox[-10:], + 'tool_names': list(self.spec.tool_names), + 'tool_result_count': len(self.tool_results), + 'recent_tool_results': self.tool_results[-8:], + 'recent_actions': build_recent_actions( + outbox=self.outbox, + tool_results=self.tool_results, + limit=12, + ), + 'belief': self.final_result or { + 'hypothesis': None, + 'confidence': 0.0, + 'supporting_message_ids': [], + 'supporting_tool_result_ids': [], + 'reason': None, + }, + 'belief_history': [self.final_result] if self.final_result else [], + } + + async def _reasoning_round(self) -> bool: + if self._reasoning_engine is None: + raise RuntimeError('agent startup did not initialize reasoning engine') + self._trace('round_started', {'round': self.round_index}) + result = await self._reasoning_engine.run_turn() + self._trace( + 'agent_decision', + { + 'mode': 'mpi_daemon', + 'wake_reason': f'daemon round {self.round_index}', + 'rationale': 'LM returned the listed tool calls for this daemon turn.', + 'round': self.round_index, + 'tool_names': list(result.executed_tool_names), + 'action_tools_called': list(result.action_tools_called), + 'science_tools_called': list(result.science_tools_called), + 'workflow_span_id': result.workflow_span_id, + 'thread_id': result.thread_id, + 'engine': 'chemgraph_single_agent', + 'actions': [ + {'action': name} + for name in result.executed_tool_names + ], + }, + ) + self._trace('round_finished', {'round': self.round_index}) + if result.requested_self_wake: + self._trace( + 'self_wake_scheduled', + { + 'round': self.round_index, + 'reason': ( + 'local ChemGraph tool result is now available in ' + 'local_chemgraph_tool_results' + ), + }, + ) + return result.requested_self_wake + + def _set_final_result(self, result: dict[str, Any]) -> None: + self.final_result = result + + def _trace(self, 
event: str, payload: dict[str, Any]) -> None: + EventLog(self.run_dir / 'events.jsonl').emit( + event, # type: ignore[arg-type] + run_id=self.run_dir.name, + agent_id=self.spec.name, + role=self.spec.role, + payload=payload, + ) diff --git a/src/chemgraph/academy/core/campaign.py b/src/chemgraph/academy/core/campaign.py new file mode 100644 index 00000000..7eae1cd3 --- /dev/null +++ b/src/chemgraph/academy/core/campaign.py @@ -0,0 +1,484 @@ +from __future__ import annotations + +import dataclasses +import json +import pathlib +from collections.abc import Mapping +from typing import Any + +from chemgraph.academy.examples import resolve_builtin_campaign +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +_REMOVED_CAMPAIGN_FIELDS = frozenset( + { + 'completion_criteria', + 'parameters', + 'routing_policy', + 'workflow_stages', + }, +) +_RESOURCE_KINDS = frozenset({'directory', 'file', 'json'}) +_RESOURCE_SCOPES = frozenset( + { + 'absolute', + 'campaign_file', + 'external', + 'shared_run', + }, +) + + +class ToolSpec(BaseModel): + """Campaign-declared in-process FastMCP tool available to agents.""" + + model_config = ConfigDict(extra='forbid') + + name: str + module: str + tool: str + description: str = '' + + @field_validator('name', 'module', 'tool') + @classmethod + def _non_empty(cls, value: str) -> str: + value = value.strip() + if not value: + raise ValueError('tool spec fields must be non-empty strings') + return value + + +class ExecutionSpec(BaseModel): + """Execution defaults used when configuring ChemGraph FastMCP backends.""" + + model_config = ConfigDict(extra='forbid') + + backend: str = 'local' + system: str = 'local' + config_path: str | None = None + options: dict[str, Any] = Field(default_factory=dict) + + +class ResourceSpec(BaseModel): + """Campaign-declared resource or artifact handle. + + The runtime resolves only these explicit ``path`` fields. 
It never scans + arbitrary campaign metadata looking for strings that might be paths. + """ + + model_config = ConfigDict(extra='forbid') + + kind: str + path: str | None = None + uri: str | None = None + scope: str = 'campaign_file' + description: str = '' + expose_content: bool = False + + @field_validator('kind') + @classmethod + def _known_resource_kind(cls, value: str) -> str: + value = value.strip() + if value not in _RESOURCE_KINDS: + raise ValueError( + f'resource kind must be one of {sorted(_RESOURCE_KINDS)}', + ) + return value + + @field_validator('scope') + @classmethod + def _known_resource_scope(cls, value: str) -> str: + value = value.strip() + if value not in _RESOURCE_SCOPES: + raise ValueError( + f'resource scope must be one of {sorted(_RESOURCE_SCOPES)}', + ) + return value + + @field_validator('path', 'uri', 'description') + @classmethod + def _strip_optional_resource_field(cls, value: str | None) -> str | None: + if value is None: + return None + value = value.strip() + return value or None + + +@dataclasses.dataclass(frozen=True) +class ChemGraphAgentSpec: + name: str + role: str + mission: str + allowed_peers: tuple[str, ...] + tools: tuple[ToolSpec, ...] + resources: tuple[str, ...] = () + + @property + def tool_names(self) -> tuple[str, ...]: + return tuple(tool.name for tool in self.tools) + + +@dataclasses.dataclass(frozen=True) +class ChemGraphCampaign: + run_id: str + user_task: str + initial_agent: str + prompt_profile: pathlib.Path + agents: tuple[ChemGraphAgentSpec, ...] + tool_catalog: tuple[ToolSpec, ...] 
= () + resources: Mapping[str, ResourceSpec] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass(frozen=True) +class ChemGraphDaemonConfig: + run_dir: pathlib.Path + run_token: str + agent_count: int + campaign_config: pathlib.Path + lm_config: pathlib.Path + max_decisions: int + poll_timeout_s: float + idle_timeout_s: float + startup_timeout_s: float + completion_timeout_s: float + status_interval_s: float + redis_host: str + redis_port: int + redis_namespace: str + clean_redis: bool + rank: int + local_rank: int | None + chemgraph_repo_root: pathlib.Path + + +def namespace_for_run(run_dir: pathlib.Path) -> str: + return f'academy-chemgraph-swarm:{run_dir.name}' + + +def resolve_campaign_resources( + campaign: ChemGraphCampaign, + run_dir: str | pathlib.Path, + *, + shared_dir_name: str = 'shared', +) -> ChemGraphCampaign: + """Resolve explicit shared-run resource paths for one concrete run.""" + shared_root = (pathlib.Path(run_dir).resolve() / shared_dir_name) + resources: dict[str, ResourceSpec] = {} + + for name, spec in campaign.resources.items(): + if spec.path is None: + resources[name] = spec + continue + if spec.scope != 'shared_run': + resources[name] = spec + continue + path = pathlib.Path(spec.path) + resolved = path if path.is_absolute() else shared_root / path + resources[name] = spec.model_copy( + update={ + 'path': str(resolved.resolve()), + 'uri': spec.uri or _file_uri(resolved.resolve()), + }, + ) + + return dataclasses.replace(campaign, resources=resources) + + +def _file_uri(path: pathlib.Path) -> str: + return path.resolve().as_uri() + + +def _resolve_resource_spec( + raw: Mapping[str, Any], + *, + campaign_path: pathlib.Path, +) -> ResourceSpec: + spec = ResourceSpec.model_validate(raw) + if spec.path is None: + return spec + if spec.scope == 'campaign_file': + path = pathlib.Path(spec.path) + resolved = path if path.is_absolute() else campaign_path.parent / path + resolved = resolved.resolve() + return spec.model_copy( + 
update={ + 'path': str(resolved), + 'uri': spec.uri or _file_uri(resolved), + }, + ) + if spec.scope == 'absolute': + path = pathlib.Path(spec.path) + if not path.is_absolute(): + raise RuntimeError( + f'absolute resource path must be absolute: {spec.path}', + ) + resolved = path.resolve() + return spec.model_copy( + update={ + 'path': str(resolved), + 'uri': spec.uri or _file_uri(resolved), + }, + ) + if spec.scope in {'shared_run', 'external'}: + return spec + + raise RuntimeError(f'unsupported resource scope {spec.scope!r}') + + +def load_campaign(path: str | pathlib.Path) -> ChemGraphCampaign: + path = resolve_builtin_campaign(path) + data = _load_jsonc(path) + _reject_removed_campaign_fields(data, campaign_path=path) + prompt_profile = _resolve_campaign_relative_path( + data.get('prompt_profile'), + campaign_path=path, + field_name='prompt_profile', + ) + + tool_catalog = _load_tool_catalog(data) + resources = { + name: _resolve_resource_spec(raw, campaign_path=path) + for name, raw in dict(data.get('resources', {})).items() + } + agents = [] + for item in data['agents']: + agents.append( + ChemGraphAgentSpec( + name=item['name'], + role=item['role'], + mission=item['mission'], + allowed_peers=tuple(item.get('allowed_peers', ())), + tools=_load_declared_tools(item, tool_catalog), + resources=tuple(item.get('resources', ())), + ), + ) + return ChemGraphCampaign( + run_id=data.get('run_id', path.stem), + user_task=data['user_task'], + initial_agent=data.get('initial_agent', agents[0].name), + prompt_profile=prompt_profile, + agents=tuple(agents), + tool_catalog=tuple(tool_catalog.values()), + resources=resources, + ) + + +def _load_jsonc(path: pathlib.Path) -> dict[str, Any]: + """Load a campaign file that may contain JSONC-style comments.""" + data = json.loads(_strip_json_comments(path.read_text(encoding='utf-8'))) + if not isinstance(data, dict): + raise RuntimeError(f'campaign {path} must contain a JSON object') + return data + + +def 
_strip_json_comments(text: str) -> str: + """Remove // and /* */ comments without touching JSON string values.""" + out: list[str] = [] + in_string = False + escape = False + i = 0 + + while i < len(text): + char = text[i] + nxt = text[i + 1] if i + 1 < len(text) else '' + + if in_string: + out.append(char) + if escape: + escape = False + elif char == '\\': + escape = True + elif char == '"': + in_string = False + i += 1 + continue + + if char == '"': + in_string = True + out.append(char) + i += 1 + continue + + if char == '/' and nxt == '/': + i += 2 + while i < len(text) and text[i] not in '\r\n': + i += 1 + continue + + if char == '/' and nxt == '*': + i += 2 + while i < len(text): + if text[i] in '\r\n': + out.append(text[i]) + i += 1 + continue + if text[i] == '*' and i + 1 < len(text) and text[i + 1] == '/': + i += 2 + break + i += 1 + continue + + out.append(char) + i += 1 + + return ''.join(out) + + +def _reject_removed_campaign_fields( + data: Mapping[str, Any], + *, + campaign_path: pathlib.Path, +) -> None: + removed = sorted(_REMOVED_CAMPAIGN_FIELDS.intersection(data)) + if not removed: + return + raise RuntimeError( + f'campaign {campaign_path} uses removed structured orchestration ' + f'field(s): {removed}. 
Put simple natural-language coordination hints ' + 'in agent mission fields and enforce the communication graph with ' + 'allowed_peers.', + ) + + +def _resolve_campaign_relative_path( + raw: Any, + *, + campaign_path: pathlib.Path, + field_name: str, +) -> pathlib.Path: + if not isinstance(raw, str) or not raw.strip(): + raise RuntimeError(f'campaign requires non-empty {field_name!r}') + path = pathlib.Path(raw.strip()) + if not path.is_absolute(): + path = campaign_path.parent / path + return path.resolve() + + +def _load_tool_catalog(data: Mapping[str, Any]) -> dict[str, ToolSpec]: + catalog: dict[str, ToolSpec] = {} + for raw in data.get('tools', ()): + if not isinstance(raw, dict): + raise RuntimeError('campaign top-level tools[] entries must be objects') + spec = ToolSpec.model_validate(raw) + if spec.name in catalog: + raise RuntimeError(f'duplicate campaign tool name: {spec.name}') + catalog[spec.name] = spec + return catalog + + +def _load_declared_tools( + item: Mapping[str, Any], + catalog: Mapping[str, ToolSpec], +) -> tuple[ToolSpec, ...]: + raw_tools = item.get('tools') + if raw_tools is None: + raw_tools = item.get('tool_names', ()) + tools: list[ToolSpec] = [] + for raw in raw_tools: + if isinstance(raw, str): + try: + tools.append(catalog[raw]) + except KeyError as exc: + raise RuntimeError( + f'agent {item.get("name")!r} references unknown campaign tool {raw!r}; ' + 'declare it in top-level tools[] or inline as a FastMCP ToolSpec object', + ) from exc + elif isinstance(raw, dict): + tools.append(ToolSpec.model_validate(raw)) + else: + raise RuntimeError( + f'agent {item.get("name")!r} tools[] entries must be strings or objects', + ) + return tuple(tools) + + +def validate_campaign(campaign: ChemGraphCampaign, agent_count: int) -> None: + if len(campaign.agents) != agent_count: + raise RuntimeError( + f'campaign defines {len(campaign.agents)} agents but ' + f'agent_count={agent_count}', + ) + names = [agent.name for agent in campaign.agents] + if 
len(set(names)) != len(names): + raise RuntimeError('campaign agent names must be unique') + if campaign.initial_agent not in names: + raise RuntimeError( + f'initial_agent {campaign.initial_agent!r} is not an agent', + ) + for agent in campaign.agents: + unknown = sorted(set(agent.allowed_peers).difference(names)) + if unknown: + raise RuntimeError( + f'{agent.name} has unknown allowed peers: {unknown}', + ) + if agent.name in agent.allowed_peers: + raise RuntimeError(f'{agent.name} must not list itself as a peer') + tool_names = list(agent.tool_names) + if len(set(tool_names)) != len(tool_names): + raise RuntimeError(f'{agent.name} has duplicate tool declarations') + unknown_resources = sorted(set(agent.resources).difference(campaign.resources)) + if unknown_resources: + raise RuntimeError( + f'{agent.name} references unknown resources: {unknown_resources}', + ) + + +def selected_agent(campaign: ChemGraphCampaign, rank: int) -> ChemGraphAgentSpec: + if rank < 0 or rank >= len(campaign.agents): + raise RuntimeError( + f'MPI rank {rank} has no agent. Launch exactly ' + f'{len(campaign.agents)} ranks for this campaign.', + ) + return campaign.agents[rank] + + +def campaign_bootstrap_text(campaign: ChemGraphCampaign) -> str: + initial_agent = next( + (agent for agent in campaign.agents if agent.name == campaign.initial_agent), + None, + ) + initial_resources = initial_agent.resources if initial_agent is not None else () + payload: dict[str, Any] = { + 'user_task': campaign.user_task, + 'resources': _resources_payload(campaign, initial_resources), + 'resource_data': _resource_data_payload(campaign, initial_resources), + } + return json.dumps(payload, sort_keys=True) + + +def _resources_payload( + campaign: ChemGraphCampaign, + resource_names: tuple[str, ...] 
| list[str], +) -> dict[str, dict[str, Any]]: + payload: dict[str, dict[str, Any]] = {} + for name in resource_names: + spec = campaign.resources.get(name) + if spec is None: + continue + payload[name] = spec.model_dump(exclude_none=True) + return payload + + +def _resource_data_payload( + campaign: ChemGraphCampaign, + resource_names: tuple[str, ...] | list[str], +) -> dict[str, Any]: + payload: dict[str, Any] = {} + for name in resource_names: + spec = campaign.resources.get(name) + if spec is None or not spec.expose_content: + continue + if spec.kind != 'json' or spec.path is None: + continue + path = pathlib.Path(spec.path) + if not path.exists(): + raise FileNotFoundError(f'campaign resource does not exist: {path}') + payload[name] = json.loads(path.read_text(encoding='utf-8')) + return payload + + +def visible_resources_payload( + campaign: ChemGraphCampaign, + agent: ChemGraphAgentSpec, +) -> dict[str, dict[str, Any]]: + return _resources_payload(campaign, agent.resources) diff --git a/src/chemgraph/academy/core/fastmcp.py b/src/chemgraph/academy/core/fastmcp.py new file mode 100644 index 00000000..b7c148b7 --- /dev/null +++ b/src/chemgraph/academy/core/fastmcp.py @@ -0,0 +1,344 @@ +"""Campaign-scoped in-process FastMCP tool loading and invocation.""" + +from __future__ import annotations + +import importlib +import json +import uuid +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field + +from chemgraph.academy.core.campaign import ExecutionSpec +from chemgraph.academy.core.campaign import ToolSpec + +class ToolInvocation(BaseModel): + """A normalized record of one agent-requested FastMCP tool call.""" + + model_config = ConfigDict(extra="forbid") + + tool_name: str + arguments: dict[str, Any] = Field(default_factory=dict) + agent_id: str | None = None + role: str | None = None + correlation_id: str = Field(default_factory=lambda: 
f"call-{uuid.uuid4()}") + + +class ToolResult(BaseModel): + """Normalized result from a campaign FastMCP tool call.""" + + model_config = ConfigDict(extra="allow") + + tool_name: str + status: str + result: Any = None + error: str | None = None + correlation_id: str + + +def load_fastmcp_tool_module( + module_name: str, + *, + cache: dict[str, Any] | None = None, +) -> Any: + """Return a module's top-level FastMCP object declared by a campaign tool.""" + if cache is not None and module_name in cache: + return cache[module_name] + + module = importlib.import_module(module_name) + try: + server = module.mcp + except AttributeError as exc: + raise RuntimeError( + f"FastMCP tool module {module_name!r} does not expose " + "a top-level 'mcp' object", + ) from exc + + if cache is not None: + cache[module_name] = server + return server + + +async def fastmcp_tool_schemas(specs: list[ToolSpec]) -> list[dict[str, Any]]: + """Build OpenAI tool schemas from declared FastMCP ToolSpecs.""" + schemas: list[dict[str, Any]] = [] + module_cache: dict[str, Any] = {} + tools_cache: dict[str, dict[str, Any]] = {} + for spec in specs: + if spec.module not in tools_cache: + tools = await load_fastmcp_tool_module( + spec.module, + cache=module_cache, + ).list_tools() + tools_cache[spec.module] = { + _fastmcp_tool_name(tool): _fastmcp_tool_payload(tool) + for tool in tools + } + try: + tool_payload = tools_cache[spec.module][spec.tool] + except KeyError as exc: + raise RuntimeError( + f"FastMCP module {spec.module!r} does not expose tool " + f"{spec.tool!r}", + ) from exc + schemas.append(_openai_tool_schema(spec, tool_payload)) + return schemas + + +def _fastmcp_tool_name(tool: Any) -> str: + if isinstance(tool, dict): + return str(tool.get("name", "")) + return str(getattr(tool, "name", "")) + + +def _fastmcp_tool_payload(tool: Any) -> dict[str, Any]: + if isinstance(tool, dict): + return dict(tool) + if hasattr(tool, "model_dump"): + return tool.model_dump(mode="json") + return { + 
"name": getattr(tool, "name", ""), + "description": getattr(tool, "description", ""), + "inputSchema": getattr(tool, "inputSchema", None), + } + + +def _openai_tool_schema( + spec: ToolSpec, + tool_payload: dict[str, Any], +) -> dict[str, Any]: + parameters = _sanitize_input_schema( + tool_payload.get("inputSchema") or {"type": "object", "properties": {}}, + ) + return { + "type": "function", + "function": { + "name": spec.name, + "description": spec.description + or str(tool_payload.get("description") or ""), + "parameters": parameters, + }, + } + + +def _sanitize_input_schema(schema: Any) -> dict[str, Any]: + if hasattr(schema, "model_dump"): + schema = schema.model_dump(mode="json") + if not isinstance(schema, dict): + return {"type": "object", "properties": {}, "additionalProperties": False} + sanitized = json.loads(json.dumps(schema)) + sanitized.setdefault("type", "object") + sanitized.setdefault("properties", {}) + sanitized.setdefault("additionalProperties", False) + return sanitized + + +def serialize_fastmcp_result(result: Any) -> Any: + """Convert FastMCP content blocks to JSON-friendly values.""" + if isinstance(result, dict): + return result + if isinstance(result, (str, int, float, bool)) or result is None: + return result + if hasattr(result, "model_dump"): + return result.model_dump(mode="json") + if isinstance(result, Sequence) and not isinstance(result, (str, bytes)): + values = [serialize_fastmcp_result(item) for item in result] + structured = _first_structured_result(values) + if structured is not None: + return structured + json_text = _first_json_text_result(values) + if json_text is not None: + return json_text + return values + return str(result) + + +def _first_structured_result(values: list[Any]) -> dict[str, Any] | None: + for value in values: + if isinstance(value, dict) and ( + "results" in value + or "batch_id" in value + or "progress_pct" in value + or value.get("status") in {"completed", "submitted"} + ): + return value + if 
isinstance(value, list): + nested = _first_structured_result(value) + if nested is not None: + return nested + if isinstance(value, dict) and isinstance(value.get("text"), str): + try: + parsed = json.loads(value["text"]) + except json.JSONDecodeError: + continue + nested = _first_structured_result([parsed]) + if nested is not None: + return nested + return None + + +def _first_json_text_result(values: list[Any]) -> Any | None: + for value in values: + if isinstance(value, dict) and isinstance(value.get("text"), str): + try: + return json.loads(value["text"]) + except json.JSONDecodeError: + continue + if isinstance(value, list): + nested = _first_json_text_result(value) + if nested is not None: + return nested + return None + + +class CampaignFastMCPToolInvoker: + """Invoke campaign-allowed tools through in-process FastMCP modules.""" + + def __init__( + self, + *, + specs: list[ToolSpec], + execution: ExecutionSpec, + run_dir: str | Path, + ) -> None: + self.specs = {spec.name: spec for spec in specs} + self.execution = execution + self.run_dir = Path(run_dir) + self._module_cache: dict[str, Any] = {} + self._available_cache: dict[str, set[str]] = {} + + def names(self) -> list[str]: + return sorted(self.specs) + + async def verify_allowed_tools(self) -> list[str]: + """Return tools missing from their declared FastMCP module.""" + missing: list[str] = [] + for spec in self.specs.values(): + try: + available = await self._available_tool_names(spec.module) + except Exception: # noqa: BLE001 - caller needs aggregate missing names + missing.append(spec.name) + continue + if spec.tool not in available: + missing.append(spec.name) + return missing + + async def invoke(self, invocation: ToolInvocation) -> ToolResult: + spec = self.specs.get(invocation.tool_name) + if spec is None: + raise KeyError( + f"unknown campaign FastMCP tool: {invocation.tool_name}", + ) + + try: + available = await self._available_tool_names(spec.module) + if spec.tool not in available: + raise 
KeyError( + f"FastMCP module {spec.module!r} does not expose " + f"tool {spec.tool!r}", + ) + mcp = self._fastmcp_module(spec.module) + _configure_fastmcp_backend( + mcp, + module_name=spec.module, + execution=self.execution, + run_dir=self.run_dir, + ) + from chemgraph.observability.events import WorkflowEventContext + from chemgraph.observability.events import workflow_event_context + + context = WorkflowEventContext( + run_id=self.run_dir.name, + run_dir=str(self.run_dir), + agent_id=invocation.agent_id, + role=invocation.role, + parent_span_id=invocation.correlation_id, + tool_name=invocation.tool_name, + ) + with workflow_event_context( + jsonl_path=self.run_dir / "events.jsonl", + context=context, + ): + result = await mcp.call_tool(spec.tool, invocation.arguments) + except Exception as exc: # noqa: BLE001 - preserve tool failure as data + return ToolResult( + tool_name=invocation.tool_name, + status="error", + error=repr(exc), + correlation_id=invocation.correlation_id, + ) + + return ToolResult( + tool_name=invocation.tool_name, + status="success", + result=serialize_fastmcp_result(result), + correlation_id=invocation.correlation_id, + ) + + async def _available_tool_names(self, module_name: str) -> set[str]: + if module_name not in self._available_cache: + tools = await self._fastmcp_module(module_name).list_tools() + self._available_cache[module_name] = { + str(getattr(tool, "name", "")) + if not isinstance(tool, dict) + else str(tool.get("name", "")) + for tool in tools + } + return self._available_cache[module_name] + + def _fastmcp_module(self, module_name: str) -> Any: + return load_fastmcp_tool_module(module_name, cache=self._module_cache) + + +def _configure_fastmcp_backend( + mcp: Any, + *, + module_name: str, + execution: ExecutionSpec, + run_dir: str | Path, +) -> None: + """Configure a CGFastMCP backend without initialising compute resources.""" + if not hasattr(mcp, "init_backend"): + return + if getattr(mcp, "_backend_kwargs", None) is not 
None: + return + + kwargs: dict[str, Any] = dict(execution.options) + if execution.config_path: + kwargs["config_path"] = execution.config_path + if execution.backend: + kwargs["backend_name"] = execution.backend + if execution.system: + kwargs["system"] = execution.system + + tracker_name = module_name.replace(".", "_") + tracker_path = Path(run_dir) / f"{tracker_name}_jobs.json" + mcp.init_backend( + tracker_kwargs={"persist_file": str(tracker_path)}, + **kwargs, + ) + + +async def build_campaign_fastmcp_tool_invoker( + *, + specs: list[ToolSpec], + execution: ExecutionSpec, + run_dir: str | Path, + agent_name: str, +) -> CampaignFastMCPToolInvoker: + """Build and verify one agent's campaign-scoped FastMCP tool invoker.""" + invoker = CampaignFastMCPToolInvoker( + specs=list(specs), + execution=execution, + run_dir=run_dir, + ) + missing = await invoker.verify_allowed_tools() + if missing: + raise RuntimeError( + f"Could not resolve requested FastMCP tools for {agent_name}: {missing}", + ) + return invoker diff --git a/src/chemgraph/academy/core/lm.py b/src/chemgraph/academy/core/lm.py new file mode 100644 index 00000000..52f886c0 --- /dev/null +++ b/src/chemgraph/academy/core/lm.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import dataclasses +import json +from pathlib import Path +from typing import Any + + +@dataclasses.dataclass(frozen=True) +class LLMSettings: + """Configuration for an OpenAI-compatible chat-completions endpoint.""" + + base_url: str + model: str + provider: str + timeout_s: float + temperature: float + max_tokens: int + max_retries: int + retry_delay_s: float + api_key: str | None = None + user: str | None = None + + +def load_lm_config(path: str | Path) -> LLMSettings: + """Load LM settings from a JSON config file.""" + config_path = Path(path) + data = json.loads(config_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise ValueError(f"LM config must be a JSON object: {config_path}") + return 
_settings_from_mapping(data, source=str(config_path)) + + +def _settings_from_mapping(data: dict[str, Any], *, source: str) -> LLMSettings: + required = ( + "base_url", + "model", + "provider", + "timeout_s", + "temperature", + "max_tokens", + "max_retries", + "retry_delay_s", + ) + missing = [name for name in required if data.get(name) is None] + if missing: + raise ValueError(f"LM config {source} is missing required keys: {missing}") + + provider = str(data["provider"]) + if provider != "openai_compatible_tools": + raise ValueError( + f"LM config {source} provider must be 'openai_compatible_tools'", + ) + if not data.get("api_key"): + raise ValueError( + f"LM config {source} requires api_key; use 'dummy' for Argo shim " + "routes that do not require auth", + ) + + return LLMSettings( + base_url=str(data["base_url"]), + model=str(data["model"]), + provider=provider, + api_key=str(data["api_key"]), + user=str(data["user"]) if data.get("user") else None, + timeout_s=float(data["timeout_s"]), + temperature=float(data["temperature"]), + max_tokens=int(data["max_tokens"]), + max_retries=int(data["max_retries"]), + retry_delay_s=float(data["retry_delay_s"]), + ) diff --git a/src/chemgraph/academy/core/peer_protocol.py b/src/chemgraph/academy/core/peer_protocol.py new file mode 100644 index 00000000..503efd80 --- /dev/null +++ b/src/chemgraph/academy/core/peer_protocol.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import uuid +import time +from typing import Any + + +REQUIRED_MESSAGE_KEYS = { + 'message_id', + 'sender', + 'recipient', + 'content', +} + + +def validate_message(message: dict[str, Any]) -> None: + """Validate the generic Academy message envelope.""" + if missing := REQUIRED_MESSAGE_KEYS.difference(message): + raise ValueError(f'message missing keys: {sorted(missing)}') + + +def build_message( + *, + sender: str, + recipient: str, + content: str, + round_index: int | None = None, + kind: str = 'message', + tldr: str | None = None, + 
def load_prompt_profile(path: str | Path) -> PromptProfile:
    """Read a UTF-8 JSON file and validate its content as a ``PromptProfile``."""
    profile_text = Path(path).read_text(encoding='utf-8')
    parsed = json.loads(profile_text)
    return PromptProfile.model_validate(parsed)
@dataclass
class ReasoningToolRuntimeState:
    """Per-turn bookkeeping mutated by the ChemGraph reasoning tools."""

    science_tool_completed: bool = False
    submitted_result: bool = False
    finished_turn: bool = False
    executed_tool_names: list[str] = field(default_factory=list)
    action_tool_names: list[str] = field(default_factory=list)
    science_tool_names: list[str] = field(default_factory=list)
    background_tasks: set[asyncio.Task[Any]] = field(default_factory=set)

    @property
    def tool_completed(self) -> bool:
        """Backward-compatible alias for ``science_tool_completed``."""
        return self.science_tool_completed

    def reset(self) -> None:
        """Clear the flags and name lists; ``background_tasks`` is left as is."""
        for flag in ("science_tool_completed", "submitted_result", "finished_turn"):
            setattr(self, flag, False)
        # Clear in place so any holder of these list objects sees the reset.
        for names in (
            self.executed_tool_names,
            self.action_tool_names,
            self.science_tool_names,
        ):
            names.clear()

    def record_action(self, name: str) -> None:
        """Log ``name`` as an executed tool and as an action tool."""
        for bucket in (self.executed_tool_names, self.action_tool_names):
            bucket.append(name)

    def record_science(self, name: str) -> None:
        """Log ``name`` as an executed tool and as a science tool."""
        for bucket in (self.executed_tool_names, self.science_tool_names):
            bucket.append(name)
result.""" + + model_config = ConfigDict(extra="forbid") + + summary: str = Field(min_length=1, max_length=1200) + artifact_refs: list[str] = Field(default_factory=list) + tool_result_ids: list[str] = Field(default_factory=list) + supporting_message_ids: list[str] = Field(default_factory=list) + confidence: float = Field(ge=0, le=1) + reason: str = Field(min_length=1, max_length=600) + + +class FinishTurnArgs(BaseModel): + """Arguments for ending the current logical-agent turn.""" + + model_config = ConfigDict(extra="forbid") + + reason: str = Field(min_length=1, max_length=600) + + +def _stable_validation_errors(exc: ValidationError) -> list[dict[str, str]]: + """Project Pydantic validation errors to a stable LM-facing shape.""" + return [ + { + "field": ".".join(str(part) for part in error.get("loc", ())), + "message": str(error.get("msg", "invalid value")), + } + for error in exc.errors() + ] + + +def _invalid_args_response( + tool_name: str, + exc: ValidationError, + trace: TraceFn, +) -> dict[str, Any]: + payload = { + "tool_name": tool_name, + "status": "failed", + "error": "invalid_tool_arguments", + "error_type": "invalid_tool_arguments", + "errors": _stable_validation_errors(exc), + } + trace("tool_call_failed", payload) + return {**payload, "status": "error"} + + +def _disallowed_recipient_response( + tool_name: str, + recipient: str, + allowed: tuple[str, ...], + trace: TraceFn, +) -> dict[str, Any]: + payload = { + "tool_name": tool_name, + "status": "failed", + "error": "disallowed_recipient", + "error_type": "disallowed_recipient", + "recipient": recipient, + "allowed_peers": list(allowed), + } + trace("tool_call_failed", payload) + return {**payload, "status": "error"} + + +def _compact_for_lm(value: Any, *, max_chars: int = 4000) -> Any: + """Return a JSON-safe, size-bounded value for tool feedback.""" + try: + text = json.dumps(value, sort_keys=True) + except TypeError: + text = repr(value) + if len(text) <= max_chars: + try: + return 
async def build_chemgraph_reasoning_tools(
    *,
    spec: ChemGraphAgentSpec,
    run_dir: pathlib.Path,
    tool_invoker: CampaignFastMCPToolInvoker,
    peer_names: tuple[str, ...],
    peer_handles: Mapping[str, Handle[Any]],
    outbox: list[dict[str, Any]],
    tool_results: list[dict[str, Any]],
    get_round_index: Callable[[], int],
    set_final_result: SetFinalResultFn,
    trace: TraceFn,
    runtime_state: ReasoningToolRuntimeState,
) -> list[BaseTool]:
    """Build explicit tools for one ChemGraph-backed reasoning turn.

    The returned list holds four action tools (``send_message``, ``ask_peer``,
    ``submit_result``, ``finish_turn``) followed by one science tool per entry
    in ``spec.tools``, each forwarding to ``tool_invoker``.

    Args:
        spec: Agent spec supplying the agent name, role, and tool specs.
        run_dir: Run directory receiving ``messages.jsonl`` and
            ``tool_results.jsonl`` records.
        tool_invoker: Invoker used to execute FastMCP science tools.
        peer_names: Peers this agent is allowed to address.
        peer_handles: Academy handles used to deliver messages, by peer name.
        outbox: List that sent messages are appended to.
        tool_results: List that successful science-tool records are appended to.
        get_round_index: Returns the current round index.
        set_final_result: Receives the payload built by ``submit_result``.
        trace: Event sink called with an event name and payload.
        runtime_state: Per-turn state updated as tools are called.

    Returns:
        The action tools followed by the science tools.

    Raises:
        KeyError: If no FastMCP function schema is found for a tool in
            ``spec.tools``.
    """

    async def _send_message_impl(
        *,
        recipient: str,
        tldr: str,
        content: str,
        artifact_refs: list[str],
        tool_result_ids: list[str],
        reason: str,
        confidence: float,
        kind: str,
    ) -> dict[str, Any]:
        if recipient not in peer_names:
            raise ValueError(
                f"{spec.name} tried to message disallowed peer {recipient}",
            )
        # Verify deliverability before recording anything: otherwise a missing
        # handle leaves a message in the outbox, messages.jsonl, and the trace
        # as "sent" although delivery was never attempted.
        if recipient not in peer_handles:
            raise RuntimeError(f"No Academy handle for allowed peer {recipient}")
        message = build_message(
            sender=spec.name,
            recipient=recipient,
            content=content,
            round_index=get_round_index(),
            kind=kind,
            tldr=tldr,
            artifact_refs=artifact_refs,
            tool_result_ids=tool_result_ids,
            reason=reason,
            confidence=confidence,
        )
        outbox.append(message)
        append_jsonl(run_dir / "messages.jsonl", message)
        trace("message_sent", message)
        task = asyncio.create_task(
            _deliver_message(
                recipient=recipient,
                message=message,
                handle=peer_handles[recipient],
                trace=trace,
            ),
        )
        # Keep a strong reference until the task finishes so it is not
        # garbage-collected mid-delivery.
        runtime_state.background_tasks.add(task)
        task.add_done_callback(runtime_state.background_tasks.discard)
        return {
            "status": "sent",
            "delivery": "queued",
            "message_id": message["message_id"],
            "recipient": recipient,
        }

    async def _deliver_message(
        *,
        recipient: str,
        message: dict[str, Any],
        handle: Handle[Any],
        trace: TraceFn,
    ) -> None:
        try:
            await handle.action("receive_message", message)
        except Exception as exc:  # noqa: BLE001 - preserve async delivery failure.
            trace(
                "message_delivery_failed",
                {
                    "recipient": recipient,
                    "message_id": message["message_id"],
                    "error": repr(exc),
                },
            )
            return
        trace(
            "message_delivered",
            {
                "recipient": recipient,
                "message_id": message["message_id"],
            },
        )

    def _validation_error_handler(
        tool_name: str,
    ) -> Callable[[ValidationError], dict[str, Any]]:
        def handle(exc: ValidationError) -> dict[str, Any]:
            runtime_state.record_action(tool_name)
            return _invalid_args_response(tool_name, exc, trace)

        return handle

    async def send_message(**kwargs: Any) -> dict[str, Any]:
        runtime_state.record_action("send_message")
        try:
            args = SendMessageArgs.model_validate(kwargs)
        except ValidationError as exc:
            return _invalid_args_response("send_message", exc, trace)
        if args.recipient not in peer_names:
            return _disallowed_recipient_response(
                "send_message",
                args.recipient,
                peer_names,
                trace,
            )
        return await _send_message_impl(
            recipient=args.recipient,
            tldr=args.tldr,
            content=args.content,
            artifact_refs=args.artifact_refs,
            tool_result_ids=args.tool_result_ids,
            reason=args.reason,
            confidence=args.confidence,
            kind="message",
        )

    async def ask_peer(**kwargs: Any) -> dict[str, Any]:
        runtime_state.record_action("ask_peer")
        try:
            args = AskPeerArgs.model_validate(kwargs)
        except ValidationError as exc:
            return _invalid_args_response("ask_peer", exc, trace)
        if args.recipient not in peer_names:
            return _disallowed_recipient_response(
                "ask_peer",
                args.recipient,
                peer_names,
                trace,
            )
        return await _send_message_impl(
            recipient=args.recipient,
            tldr=args.tldr,
            content=args.question,
            artifact_refs=[],
            tool_result_ids=[],
            reason=args.reason,
            confidence=0.0,
            kind="question",
        )

    async def submit_result(**kwargs: Any) -> dict[str, Any]:
        runtime_state.record_action("submit_result")
        try:
            args = SubmitResultArgs.model_validate(kwargs)
        except ValidationError as exc:
            return _invalid_args_response("submit_result", exc, trace)
        runtime_state.submitted_result = True
        result = {
            "timestamp": time.time(),
            "round": get_round_index(),
            "hypothesis": args.summary,
            "summary": args.summary,
            "artifact_refs": args.artifact_refs,
            "tool_result_ids": args.tool_result_ids,
            "supporting_message_ids": args.supporting_message_ids,
            "supporting_tool_result_ids": args.tool_result_ids,
            "confidence": args.confidence,
            "reason": args.reason,
        }
        set_final_result(result)
        trace("belief_updated", result)
        return {"status": "submitted", "confidence": result["confidence"]}

    async def finish_turn(**kwargs: Any) -> dict[str, Any]:
        runtime_state.record_action("finish_turn")
        try:
            args = FinishTurnArgs.model_validate(kwargs)
        except ValidationError as exc:
            return _invalid_args_response("finish_turn", exc, trace)
        runtime_state.finished_turn = True
        trace("turn_finished_without_external_action", {"reason": args.reason})
        return {"status": "finished", "reason": args.reason}

    tools: list[BaseTool] = [
        StructuredTool.from_function(
            coroutine=send_message,
            name="send_message",
            description=(
                "Send tool-backed evidence, reasoning, or a request to one "
                "allowed peer. Always provide recipient, tldr, content, "
                "artifact_refs as an array of strings or [], tool_result_ids "
                "as an array of strings or [], a non-empty reason, and numeric "
                "confidence from 0 to 1."
            ),
            args_schema=SendMessageArgs,
            handle_validation_error=_validation_error_handler("send_message"),
            metadata={"chemgraph_academy_tool_kind": "action_tool"},
        ),
        StructuredTool.from_function(
            coroutine=ask_peer,
            name="ask_peer",
            description=(
                "Ask an allowed peer for missing information or a tool result "
                "needed for the molecule workflow. Always provide recipient, "
                "tldr, question, and reason."
            ),
            args_schema=AskPeerArgs,
            handle_validation_error=_validation_error_handler("ask_peer"),
            metadata={"chemgraph_academy_tool_kind": "action_tool"},
        ),
        StructuredTool.from_function(
            coroutine=submit_result,
            name="submit_result",
            description=(
                "Submit this agent's current final answer or report. Cite peer "
                "message IDs and ChemGraph tool result IDs."
            ),
            args_schema=SubmitResultArgs,
            handle_validation_error=_validation_error_handler("submit_result"),
            return_direct=True,
            metadata={"chemgraph_academy_tool_kind": "action_tool"},
        ),
        StructuredTool.from_function(
            coroutine=finish_turn,
            name="finish_turn",
            description=(
                "End this decision turn when no tool, message, or report action "
                "is currently useful."
            ),
            args_schema=FinishTurnArgs,
            handle_validation_error=_validation_error_handler("finish_turn"),
            return_direct=True,
            metadata={"chemgraph_academy_tool_kind": "action_tool"},
        ),
    ]

    fastmcp_schemas = await fastmcp_tool_schemas(list(spec.tools))
    schema_by_name = {
        schema["function"]["name"]: schema["function"]
        for schema in fastmcp_schemas
        if schema.get("type") == "function"
    }

    for tool_spec in spec.tools:
        function_schema = schema_by_name[tool_spec.name]

        # The tool name is bound as a default argument so each closure keeps
        # its own name instead of the loop variable's final value.
        async def run_fastmcp_tool(
            __tool_name: str = tool_spec.name,
            **kwargs: Any,
        ) -> dict[str, Any]:
            runtime_state.record_science(__tool_name)
            if __tool_name not in spec.tool_names:
                raise RuntimeError(
                    f"{spec.name} cannot call unavailable tool {__tool_name}",
                )
            tool_result_id = f"tool-{uuid.uuid4()}"
            started = {
                "tool_result_id": tool_result_id,
                "tool_name": __tool_name,
                "arguments": kwargs,
            }
            trace("tool_call_started", started)
            result_record = await tool_invoker.invoke(
                ToolInvocation(
                    tool_name=__tool_name,
                    arguments=kwargs,
                    agent_id=spec.name,
                    role=spec.role,
                    correlation_id=tool_result_id,
                ),
            )
            if result_record.status != "success":
                failure = {
                    **started,
                    "status": "failed",
                    "error": result_record.error
                    or "tool returned non-success status",
                }
                # Failures are persisted and traced but not added to the
                # in-memory ``tool_results`` list.
                append_jsonl(run_dir / "tool_results.jsonl", failure)
                trace("tool_call_failed", failure)
                raise RuntimeError(f"{__tool_name} failed: {failure['error']}")

            runtime_state.science_tool_completed = True
            record = {
                **started,
                "timestamp": time.time(),
                "agent_name": spec.name,
                "status": "ok",
                "result": result_record.result,
            }
            tool_results.append(record)
            append_jsonl(run_dir / "tool_results.jsonl", record)
            trace("tool_call_finished", record)
            return {
                "status": "ok",
                "tool_result_id": tool_result_id,
                "tool_name": __tool_name,
                "result": _compact_for_lm(result_record.result),
            }

        tools.append(
            StructuredTool.from_function(
                coroutine=run_fastmcp_tool,
                name=tool_spec.name,
                description=function_schema.get("description")
                or tool_spec.description
                or f"Run ChemGraph FastMCP tool {tool_spec.name}.",
                args_schema=function_schema.get("parameters")
                or {"type": "object", "properties": {}},
                metadata={"chemgraph_academy_tool_kind": "science_tool"},
            ),
        )

    return tools
@dataclass(frozen=True)
class ReasoningTurnResult:
    """Summary of one ChemGraph-managed logical-agent reasoning turn.

    Immutable record built at the end of
    ``ChemGraphReasoningRoundEngine.run_turn`` from the workflow state and the
    turn's ``ReasoningToolRuntimeState``.
    """

    # Text derived from the final workflow state via _extract_final_text().
    final_text: str
    # Workflow state returned by the agent run, passed through
    # _ensure_state_dict().
    state: dict[str, Any]
    # len(runtime_state.executed_tool_names): every recorded tool call, action
    # and science alike. Names are recorded when a call starts, so calls that
    # later fail are counted too.
    tool_calls_completed: int
    # Names recorded through record_action(), in call order.
    action_tools_called: tuple[str, ...]
    # Names recorded through record_science(), in call order.
    science_tools_called: tuple[str, ...]
    # All recorded tool names (action and science), in call order.
    executed_tool_names: tuple[str, ...]
    # Mirrors runtime_state.finished_turn (set by the finish_turn tool).
    requested_finish: bool
    # Mirrors runtime_state.science_tool_completed (a science tool succeeded).
    requested_self_wake: bool
    # Span id generated for this turn's workflow events.
    workflow_span_id: str
    # "<agent name>-round-<round index>" identifier passed to the agent run.
    thread_id: str
@classmethod
async def create(
    cls,
    *,
    campaign: ChemGraphCampaign,
    spec: ChemGraphAgentSpec,
    llm_settings: LLMSettings,
    prompt_profile: PromptProfile,
    run_dir: Path,
    max_decisions: int,
    tool_invoker: CampaignFastMCPToolInvoker,
    peer_names: tuple[str, ...],
    peer_handles: Mapping[str, Handle[Any]],
    received_message_history: list[dict[str, Any]],
    outbox: list[dict[str, Any]],
    tool_results: list[dict[str, Any]],
    get_final_result: Callable[[], dict[str, Any] | None],
    get_round_index: Callable[[], int],
    set_final_result: SetFinalResultFn,
    trace: TraceFn,
) -> "ChemGraphReasoningRoundEngine":
    """Build the reasoning tools, then construct an engine wired to them.

    A fresh ``ReasoningToolRuntimeState`` is created here and shared between
    the tools and the engine so the engine can observe what the tools did.
    """
    turn_state = ReasoningToolRuntimeState()

    # Collaborators that both the tool builder and the engine receive.
    shared: dict[str, Any] = {
        "spec": spec,
        "run_dir": run_dir,
        "peer_names": peer_names,
        "outbox": outbox,
        "tool_results": tool_results,
        "get_round_index": get_round_index,
        "trace": trace,
        "runtime_state": turn_state,
    }

    reasoning_tools = await build_chemgraph_reasoning_tools(
        tool_invoker=tool_invoker,
        peer_handles=peer_handles,
        set_final_result=set_final_result,
        **shared,
    )
    return cls(
        campaign=campaign,
        llm_settings=llm_settings,
        prompt_profile=prompt_profile,
        max_decisions=max_decisions,
        tools=reasoning_tools,
        received_message_history=received_message_history,
        get_final_result=get_final_result,
        **shared,
    )
ReasoningTurnResult: + """Run one turn-local ChemGraph workflow for the current wakeup.""" + from chemgraph.agent.llm_agent import ChemGraph + from chemgraph.observability.events import WorkflowEventContext + from chemgraph.observability.events import emit_workflow_event + from chemgraph.observability.events import new_span_id + from chemgraph.observability.events import workflow_event_context + + round_index = self.get_round_index() + thread_id = f"{self.spec.name}-round-{round_index}" + workflow_span_id = new_span_id(f"chemgraph-turn-{self.spec.name}") + parent_span_id = f"academy-round-{self.spec.name}-{round_index}" + query = self.build_wakeup_query(round_index=round_index) + log_dir = ( + self.run_dir + / "chemgraph_turns" + / f"{self.spec.name}-round-{round_index:04d}" + ) + log_dir.mkdir(parents=True, exist_ok=True) + + self.runtime_state.reset() + self.trace( + "chemgraph_reasoning_turn_started", + { + "round": round_index, + "thread_id": thread_id, + "workflow_span_id": workflow_span_id, + "tool_names": [tool.name for tool in self.tools], + }, + ) + context = WorkflowEventContext( + run_id=self.run_dir.name, + run_dir=str(self.run_dir), + agent_id=self.spec.name, + role=self.spec.role, + parent_span_id=parent_span_id, + tool_name=None, + ) + + with workflow_event_context( + jsonl_path=self.run_dir / "events.jsonl", + context=context, + ): + emit_workflow_event( + "workflow_started", + { + "workflow_type": "single_agent", + "workflow_node": "ChemGraphReasoningRoundEngine", + "round": round_index, + "thread_id": thread_id, + "tool_names": [tool.name for tool in self.tools], + "log_dir": str(log_dir), + }, + span_id=workflow_span_id, + parent_span_id=parent_span_id, + ) + agent = ChemGraph( + model_name=self.llm_settings.model, + workflow_type="single_agent", + base_url=self.llm_settings.base_url, + api_key=self.llm_settings.api_key, + argo_user=self.llm_settings.user, + system_prompt=self.prompt_profile.system_prompt, + return_option="state", + 
recursion_limit=self.prompt_profile.langchain_recursion_limit, + tools=self.tools, + terminal_tool_names=("finish_turn", "submit_result"), + enable_memory=False, + log_dir=str(log_dir), + ) + try: + state = await agent.run( + query, + config={"configurable": {"thread_id": thread_id}}, + workflow_span_id=workflow_span_id, + ) + except Exception as exc: + emit_workflow_event( + "workflow_finished", + { + "workflow_type": "single_agent", + "workflow_node": "ChemGraphReasoningRoundEngine", + "round": round_index, + "thread_id": thread_id, + "status": "failed", + "error": repr(exc), + "log_dir": str(log_dir), + }, + span_id=workflow_span_id, + parent_span_id=parent_span_id, + ) + raise + else: + state = _ensure_state_dict(state) + emit_workflow_event( + "workflow_finished", + { + "workflow_type": "single_agent", + "workflow_node": "ChemGraphReasoningRoundEngine", + "round": round_index, + "thread_id": thread_id, + "status": "completed", + "log_dir": str(log_dir), + }, + span_id=workflow_span_id, + parent_span_id=parent_span_id, + ) + + if not self.runtime_state.executed_tool_names: + raise RuntimeError( + "ChemGraph reasoning turn returned without calling an " + "Academy action or science tool; logical agents must call " + "finish_turn when no external action is useful.", + ) + + result = ReasoningTurnResult( + final_text=_extract_final_text(state), + state=state, + tool_calls_completed=len(self.runtime_state.executed_tool_names), + action_tools_called=tuple(self.runtime_state.action_tool_names), + science_tools_called=tuple(self.runtime_state.science_tool_names), + executed_tool_names=tuple(self.runtime_state.executed_tool_names), + requested_finish=self.runtime_state.finished_turn, + requested_self_wake=self.runtime_state.science_tool_completed, + workflow_span_id=workflow_span_id, + thread_id=thread_id, + ) + self.trace( + "chemgraph_reasoning_turn_finished", + { + "round": round_index, + "thread_id": thread_id, + "workflow_span_id": workflow_span_id, + 
"action_tools_called": list(result.action_tools_called), + "science_tools_called": list(result.science_tools_called), + "requested_finish": result.requested_finish, + "requested_self_wake": result.requested_self_wake, + }, + ) + return result + + def build_wakeup_query(self, *, round_index: int) -> str: + """Build the user message for one ChemGraph turn.""" + state = self.build_wakeup_state(round_index=round_index) + return json.dumps(state, sort_keys=True) + + def build_wakeup_state(self, *, round_index: int) -> dict[str, Any]: + """Build the exact state visible to the logical agent this turn.""" + limits = self.prompt_profile.state_limits + return { + "campaign": self.campaign.run_id, + "user_task": self.campaign.user_task, + "agent_name": self.spec.name, + "role": self.spec.role, + "mission": self.spec.mission, + "round": round_index, + "max_decisions": self.max_decisions, + "resources": visible_resources_payload(self.campaign, self.spec), + "allowed_peers": list(self.spec.allowed_peers), + "peer_status": build_peer_status( + run_dir=self.run_dir, + peer_names=self.peer_names, + ), + "available_chemgraph_tools": list(self.spec.tool_names), + "received_messages": ( + self.received_message_history[ + -limits.received_messages_last_n : + ] + if limits.received_messages_last_n + else [] + ), + "local_chemgraph_tool_results": ( + self.tool_results[-limits.tool_results_last_n :] + if limits.tool_results_last_n + else [] + ), + "recent_actions": build_recent_actions( + outbox=self.outbox, + tool_results=self.tool_results, + limit=limits.actions_last_n, + ), + "current_final_result": self.get_final_result(), + "required_protocol": self.prompt_profile.protocol_prompt, + } + + +def build_peer_status( + *, + run_dir: Path, + peer_names: tuple[str, ...], + event_scan_limit: int = 1000, +) -> dict[str, dict[str, Any]]: + """Return compact status snapshots for peers visible to this agent.""" + if not peer_names: + return {} + + now = time.time() + peers = set(peer_names) + 
status: dict[str, dict[str, Any]] = { + peer: _status_from_agent_file(run_dir, peer, now=now) + for peer in peer_names + } + + for event in read_jsonl(run_dir / "events.jsonl")[-event_scan_limit:]: + agent_id = event.get("agent_id") + if agent_id not in peers: + continue + kind = str(event.get("event") or "") + timestamp = _float_or_none(event.get("timestamp")) + payload = event.get("payload") + payload = payload if isinstance(payload, dict) else {} + peer_status = status[str(agent_id)] + + if kind == "round_started": + peer_status["state"] = "busy" + peer_status["current_activity"] = { + "type": "reasoning_round", + "round": payload.get("round"), + "started_at": timestamp, + } + _set_update_age(peer_status, timestamp, now=now) + elif kind == "tool_call_started": + peer_status["state"] = "busy" + peer_status["current_activity"] = { + "type": "tool_call", + "tool_name": payload.get("tool_name"), + "tool_result_id": payload.get("tool_result_id"), + "tool_call_id": payload.get("tool_call_id"), + "started_at": timestamp, + } + _set_update_age(peer_status, timestamp, now=now) + elif kind in {"tool_call_finished", "tool_call_failed"}: + peer_status["state"] = "busy" + peer_status["current_activity"] = { + "type": "reasoning_after_tool", + "last_tool": payload.get("tool_name"), + "tool_result_id": payload.get("tool_result_id"), + "status": payload.get("status"), + "updated_at": timestamp, + } + _set_update_age(peer_status, timestamp, now=now) + elif kind == "message_sent": + peer_status["last_outbox_tldr"] = ( + payload.get("tldr") or _preview(payload.get("content")) + ) + peer_status["last_outbox_message_id"] = payload.get("message_id") + _set_update_age(peer_status, timestamp, now=now) + elif kind == "belief_updated": + peer_status["last_belief"] = _compact_belief(payload) + _set_update_age(peer_status, timestamp, now=now) + elif kind in { + "round_finished", + "turn_finished_without_external_action", + "workflow_finished", + }: + if kind == "workflow_finished" and 
payload.get("status") == "failed": + peer_status["state"] = "error" + else: + peer_status["state"] = "idle" + peer_status["current_activity"] = None + _set_update_age(peer_status, timestamp, now=now) + elif kind == "agent_error": + peer_status["state"] = "error" + peer_status["last_error"] = payload.get("error") + peer_status["current_activity"] = None + _set_update_age(peer_status, timestamp, now=now) + elif kind == "daemon_stopped": + peer_status["state"] = "finished" + peer_status["finished"] = True + peer_status["current_activity"] = None + _set_update_age(peer_status, timestamp, now=now) + + return status + + +def _status_from_agent_file( + run_dir: Path, + peer_name: str, + *, + now: float, +) -> dict[str, Any]: + data = read_json_file( + run_dir / "agent_status" / f"{peer_name}.json", + default={}, + ) + state = "unknown" + if data: + if data.get("last_error"): + state = "error" + elif data.get("finished") is True: + state = "finished" + else: + state = "idle" + timestamp = _float_or_none(data.get("status_updated_at")) + return { + "state": state, + "round": data.get("round"), + "finished": bool(data.get("finished")) if data else False, + "last_error": data.get("last_error"), + "current_activity": data.get("current_activity"), + "seconds_since_update": _age(timestamp, now=now), + "last_outbox_tldr": _last_outbox_tldr(data), + "last_outbox_message_id": _last_outbox_message_id(data), + "last_belief": _compact_belief(data.get("belief")), + } + + +def _last_outbox_tldr(data: Mapping[str, Any]) -> str | None: + recent = data.get("recent_outbox") + if not isinstance(recent, list) or not recent: + return None + last = recent[-1] + if not isinstance(last, dict): + return None + return last.get("tldr") or _preview(last.get("content")) + + +def _last_outbox_message_id(data: Mapping[str, Any]) -> str | None: + recent = data.get("recent_outbox") + if not isinstance(recent, list) or not recent: + return None + last = recent[-1] + if not isinstance(last, dict): + return 
None + value = last.get("message_id") + return str(value) if value else None + + +def _compact_belief(value: Any) -> dict[str, Any] | None: + if not isinstance(value, dict): + return None + summary = value.get("summary") or value.get("hypothesis") + if not summary: + return None + return { + "summary": _preview(summary, max_chars=220), + "confidence": value.get("confidence"), + } + + +def _set_update_age( + peer_status: dict[str, Any], + timestamp: float | None, + *, + now: float, +) -> None: + peer_status["seconds_since_update"] = _age(timestamp, now=now) + + +def _age(timestamp: float | None, *, now: float) -> float | None: + if timestamp is None: + return None + return max(0.0, round(now - timestamp, 3)) + + +def _float_or_none(value: Any) -> float | None: + if isinstance(value, bool) or value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def build_recent_actions( + *, + outbox: list[dict[str, Any]], + tool_results: list[dict[str, Any]], + limit: int, +) -> list[dict[str, Any]]: + """Build a compact chronological action history for LM prompt state.""" + if limit <= 0: + return [] + + actions: list[dict[str, Any]] = [] + for message in outbox[-limit:]: + kind = str(message.get("kind") or "message") + action_type = "ask_peer" if kind == "question" else "send_message" + actions.append( + { + "type": action_type, + "recipient": message.get("recipient"), + "tldr": message.get("tldr") or _preview(message.get("content")), + "message_id": message.get("message_id"), + "timestamp": message.get("timestamp"), + }, + ) + + for result in tool_results[-limit:]: + actions.append( + { + "type": "tool_call", + "tool_name": result.get("tool_name"), + "tool_result_id": result.get("tool_result_id"), + "status": result.get("status"), + "timestamp": result.get("timestamp"), + }, + ) + + actions.sort(key=lambda item: float(item.get("timestamp") or 0.0)) + return actions[-limit:] + + +def _preview(value: Any, *, max_chars: int = 
160) -> str: + text = "" if value is None else str(value) + if len(text) <= max_chars: + return text + return text[: max_chars - 1] + "..." + + +def _ensure_state_dict(state: Any) -> dict[str, Any]: + if isinstance(state, dict): + return state + return {"value": state} + + +def _extract_final_text(state: Mapping[str, Any]) -> str: + messages = state.get("messages") + if not isinstance(messages, list) or not messages: + return "" + last = messages[-1] + if isinstance(last, dict): + content = last.get("content") + return "" if content is None else str(content) + content = getattr(last, "content", None) + return "" if content is None else str(content) diff --git a/src/chemgraph/academy/dashboard/__init__.py b/src/chemgraph/academy/dashboard/__init__.py new file mode 100644 index 00000000..27eb75b0 --- /dev/null +++ b/src/chemgraph/academy/dashboard/__init__.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from chemgraph.academy.dashboard.server import DashboardHandler +from chemgraph.academy.dashboard.server import events_payload +from chemgraph.academy.dashboard.server import main +from chemgraph.academy.dashboard.server import parse_args +from chemgraph.academy.dashboard.server import serve_dashboard +from chemgraph.academy.dashboard.server import snapshot +from chemgraph.academy.dashboard.server import status_payload + +__all__ = [ + 'DashboardHandler', + 'events_payload', + 'main', + 'parse_args', + 'serve_dashboard', + 'snapshot', + 'status_payload', +] diff --git a/src/chemgraph/academy/dashboard/__main__.py b/src/chemgraph/academy/dashboard/__main__.py new file mode 100644 index 00000000..a9021b2a --- /dev/null +++ b/src/chemgraph/academy/dashboard/__main__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from chemgraph.academy.dashboard.server import main + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/src/chemgraph/academy/dashboard/server.py b/src/chemgraph/academy/dashboard/server.py new file mode 100644 index 
00000000..141a0f90 --- /dev/null +++ b/src/chemgraph/academy/dashboard/server.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import argparse +import socket +import json +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from importlib.resources import files +from pathlib import Path +from typing import Any + +from chemgraph.academy.observability.event_log import read_events +from chemgraph.academy.observability.run_files import read_json_file +from chemgraph.academy.observability.run_artifacts import write_run_artifacts + +_STATIC_CACHE: dict[str, bytes] = {} + + +def _static_file(name: str, content_type: str) -> tuple[bytes, str]: + if name not in _STATIC_CACHE: + resource = files('chemgraph.academy.dashboard').joinpath( + 'static', + name, + ) + _STATIC_CACHE[name] = resource.read_bytes() + return _STATIC_CACHE[name], content_type + + +class DashboardHandler(BaseHTTPRequestHandler): + run_dir: Path + + def do_GET(self) -> None: + path = self.path.split('?', 1)[0] + if path in {'/', '/index.html'}: + body, content_type = _static_file('index.html', 'text/html; charset=utf-8') + self._send_bytes(200, body, content_type) + return + if path == '/static/app.js': + body, content_type = _static_file( + 'app.js', + 'application/javascript; charset=utf-8', + ) + self._send_bytes(200, body, content_type) + return + if path == '/api/status': + self._send_json(200, status_payload(self)) + return + if path == '/api/events': + self._send_json(200, events_payload(self.run_dir)) + return + if path == '/api/snapshot': + self._send_json(200, snapshot(self)) + return + self._send_json(404, {'error': 'not found'}) + + def log_message(self, format: str, *args: Any) -> None: + return + + def _send_json(self, status: int, payload: dict[str, Any]) -> None: + body = json.dumps(payload, indent=2, sort_keys=True).encode('utf-8') + self._send_bytes(status, body, 'application/json') + + def _send(self, status: int, body: str, content_type: str) -> None: + 
self._send_bytes(status, body.encode('utf-8'), content_type) + + def _send_bytes(self, status: int, body: bytes, content_type: str) -> None: + try: + self.send_response(status) + self.send_header('Content-Type', content_type) + self.send_header('Content-Length', str(len(body))) + self.end_headers() + self.wfile.write(body) + except (BrokenPipeError, ConnectionResetError, socket.timeout): + return + + +def snapshot(handler: DashboardHandler) -> dict[str, Any]: + data = status_payload(handler) + data.update(events_payload(handler.run_dir)) + return data + + +def status_payload(handler: DashboardHandler) -> dict[str, Any]: + run_dir = handler.run_dir + status_path = run_dir / "status.json" + status: dict[str, Any] = {} + if status_path.exists(): + try: + status = json.loads(status_path.read_text(encoding="utf-8")) + except json.JSONDecodeError: + status = {} + artifacts = write_run_artifacts(run_dir) + manifest = read_json_file(run_dir / "manifest.json", default={}) + updated = status.get("updated") or status.get("timestamp") + schema = ( + status.get("mode") + or (manifest.get("mode") if isinstance(manifest, dict) else None) + or "canonical_events" + ) + return { + "run_dir": str(run_dir), + "updated": updated, + "schema": schema, + "status": status, + "placement": artifacts["placement"], + "communication_proof": artifacts["communication_proof"], + "summary": artifacts["summary"], + } + + +def events_payload(run_dir: Path) -> dict[str, Any]: + events = [ + event.model_dump(mode="json") for event in read_events(run_dir / "events.jsonl") + ] + return { + "run_dir": str(run_dir), + "events": events, + } + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--run-dir", required=True) + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=8765) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + return serve_dashboard( + 
run_dir=Path(args.run_dir).resolve(), + host=args.host, + port=args.port, + ) + + +def serve_dashboard(*, run_dir: Path, host: str, port: int) -> int: + DashboardHandler.run_dir = run_dir + server = ThreadingHTTPServer((host, port), DashboardHandler) + print(f"Serving {run_dir} at http://{host}:{port}", flush=True) + try: + server.serve_forever() + except KeyboardInterrupt: + print("\nStopping dashboard.", flush=True) + finally: + server.server_close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/chemgraph/academy/dashboard/static/app.js b/src/chemgraph/academy/dashboard/static/app.js new file mode 100644 index 00000000..a06f2b30 --- /dev/null +++ b/src/chemgraph/academy/dashboard/static/app.js @@ -0,0 +1,3053 @@ + let snapshot = null; + let snapshotIdentity = null; + let selectedAgent = null; + let selectedEdgeKey = null; + let selectedActivityEventKey = null; + let timelineIndex = null; + let followLatest = true; + let graphMode = 'recent'; + let isReplaying = false; + let replayTimer = null; + let replayStartedAtMs = 0; + let replayStartTimestamp = null; + let replayStartIndex = 0; + let graphView = null; + let graphPanDrag = null; + let embeddedWorkflowView = null; + let embeddedWorkflowPanDrag = null; + let selectedEmbeddedWorkflowEventKey = null; + let embeddedWorkflowAnchorKey = null; + let workflowPanelOpen = true; + let workflowPanelFrame = {x: 28, y: 76, width: 1120, height: 640}; + let workflowPanelDrag = null; + let workflowPanelResizeDrag = null; + let detailResizeDrag = null; + let lastRenderedDetailIdentity = null; + let lastEmbeddedWorkflowInspectorIdentity = null; + const recentMessageWindow = 4; + const actionToolNames = new Set(['send_message', 'ask_peer', 'submit_result', 'finish_turn']); + const renderedHtmlCache = new WeakMap(); + + const esc = (s) => String(s ?? '').replace(/[&<>"']/g, c => ({'&':'&','<':'<','>':'>','"':'"',"'":'''}[c])); + const trunc = (s, n=180) => { + s = String(s ?? 
''); + return s.length > n ? s.slice(0, n - 1) + '…' : s; + }; + const formatTime = (timestamp) => timestamp ? new Date(timestamp * 1000).toLocaleTimeString() : '-'; + const eventTimestamp = (event) => { + const value = Number(event?.timestamp); + return Number.isFinite(value) ? value : null; + }; + const replaySpeed = () => Number(document.getElementById('replaySpeed')?.value || 25); + const eventHoldSeconds = () => Number(document.getElementById('eventHold')?.value || 8); + + function estimateLabelWidth(text) { + return Math.min(460, Math.max(20, String(text || '').length * 6.2 + 14)); + } + + function labelBox(x, y, text) { + const width = estimateLabelWidth(text); + const height = 18; + return { + x1: x - width / 2, + y1: y - height + 5, + x2: x + width / 2, + y2: y + 7, + }; + } + + function boxesOverlap(a, b, pad = 5) { + return !(a.x2 + pad < b.x1 || b.x2 + pad < a.x1 || a.y2 + pad < b.y1 || b.y2 + pad < a.y1); + } + + function placeEdgeLabel(baseX, baseY, rawLabel, occupiedBoxes, force = false) { + const text = trunc(String(rawLabel || ''), force ? 
96 : 72); + if (!text) return null; + const candidates = [ + [0, 0], [0, -18], [0, 18], [30, -12], [-30, 12], + [44, 18], [-44, -18], [0, -34], [0, 34], + ]; + for (const [dx, dy] of candidates) { + const x = baseX + dx; + const y = baseY + dy; + const box = labelBox(x, y, text); + const collision = occupiedBoxes.some(other => boxesOverlap(box, other)); + if (force || !collision) { + occupiedBoxes.push(box); + return {x, y, text}; + } + } + return null; + } + + async function load() { + const [statusRes, eventsRes] = await Promise.all([ + fetch('/api/status'), + fetch('/api/events'), + ]); + const statusData = await statusRes.json(); + const eventsData = await eventsRes.json(); + const nextSnapshot = {...statusData, events: eventsData.events || []}; + const nextIdentity = identityForSnapshot(nextSnapshot); + const previousEventCount = snapshot?.events?.length || 0; + const nextEventCount = nextSnapshot.events.length; + if ( + snapshotIdentity !== null + && ( + nextIdentity !== snapshotIdentity + || nextEventCount < previousEventCount + ) + ) { + resetInteractionState(); + } + snapshotIdentity = nextIdentity; + snapshot = nextSnapshot; + const latest = allEvents().length - 1; + if (followLatest || timelineIndex === null) { + timelineIndex = latest; + } else { + timelineIndex = Math.min(timelineIndex, latest); + } + render(); + } + + function identityForSnapshot(data) { + return [ + data?.schema || '', + data?.run_dir || '', + ].join('|'); + } + + function resetInteractionState() { + stopReplay(false); + selectedAgent = null; + selectedEdgeKey = null; + selectedActivityEventKey = null; + timelineIndex = null; + followLatest = true; + graphView = null; + graphPanDrag = null; + embeddedWorkflowView = null; + embeddedWorkflowPanDrag = null; + lastRenderedDetailIdentity = null; + lastEmbeddedWorkflowInspectorIdentity = null; + selectedEmbeddedWorkflowEventKey = null; + embeddedWorkflowAnchorKey = null; + } + + function allEvents() { + return snapshot?.events || []; + } + 
+ function isWorkflowMode() { + return snapshot?.schema === 'chemgraph_workflow'; + } + + function currentEventIndex() { + const events = allEvents(); + if (!events.length) return -1; + if (timelineIndex === null) return events.length - 1; + return Math.max(0, Math.min(timelineIndex, events.length - 1)); + } + + function visibleEvents() { + const index = currentEventIndex(); + return index < 0 ? [] : allEvents().slice(0, index + 1); + } + + function currentEvent() { + const index = currentEventIndex(); + return index < 0 ? null : allEvents()[index]; + } + + function eventKey(event) { + if (!event) return ''; + if (event.event_id) return String(event.event_id); + const payload = JSON.stringify(event.payload || {}); + return [ + event.timestamp ?? '', + event.event || '', + event.agent_id || '', + event.correlation_id || '', + payload.slice(0, 220), + ].join('|'); + } + + function firstTimestamp() { + const first = allEvents().find(event => eventTimestamp(event) !== null); + return eventTimestamp(first); + } + + function currentTimestamp() { + return eventTimestamp(currentEvent()); + } + + function eventIndexAtTimestamp(timestamp) { + const events = allEvents(); + if (!events.length) return -1; + if (timestamp === null || timestamp === undefined || Number.isNaN(timestamp)) { + return Math.min(events.length - 1, replayStartIndex + 1); + } + let index = 0; + for (let i = 0; i < events.length; i += 1) { + const ts = eventTimestamp(events[i]); + if (ts === null) { + index = i; + continue; + } + if (ts <= timestamp) index = i; + else break; + } + return index; + } + + function activeWindowEvents(multiplier = 1) { + const events = visibleEvents(); + const now = currentTimestamp(); + if (now === null) return events.slice(-Math.max(1, recentMessageWindow)); + const hold = eventHoldSeconds() * multiplier; + return events.filter(event => { + const ts = eventTimestamp(event); + return ts !== null && ts <= now && now - ts <= hold; + }); + } + + function eventsOf(type) { + return 
visibleEvents().filter(e => e.event === type); + } + + function graphMessageEvents() { + const sent = eventsOf('message_sent'); + if (graphMode === 'cumulative') return sent; + if (graphMode === 'current') { + const event = currentEvent(); + const activeMessages = activeWindowEvents(1).filter(e => e.event === 'message_sent'); + if (!event) return activeMessages; + if (activeMessages.length) return activeMessages; + if (event.event === 'message_sent') return [event]; + if (event.event === 'message_received') { + const messageId = event.payload?.message_id; + return sent.filter(item => item.payload?.message_id === messageId).slice(-1); + } + return []; + } + const now = currentTimestamp(); + if (now !== null) { + const windowSeconds = Math.max(eventHoldSeconds() * 2, 8); + const windowed = sent.filter(event => { + const ts = eventTimestamp(event); + return ts !== null && ts <= now && now - ts <= windowSeconds; + }); + if (windowed.length) return windowed; + } + return sent.slice(-recentMessageWindow); + } + + function graphModeLabel() { + if (graphMode === 'current') return `showing active events for ${eventHoldSeconds()}s`; + if (graphMode === 'cumulative') return 'showing all prior communication'; + return `showing recent communication window`; + } + + function latestEventOf(type, agentId = null) { + const matches = visibleEvents().filter(e => e.event === type && (!agentId || e.agent_id === agentId)); + return matches.length ? 
matches[matches.length - 1] : null; + } + + function agents() { + if (!snapshot) return []; + const specs = snapshot.status?.agents || []; + const currentEvents = visibleEvents(); + const visiblePlacements = {}; + currentEvents.forEach(event => { + if (event.event !== 'agent_started' || !event.agent_id) return; + const placement = event.payload?.placement; + if (placement) visiblePlacements[event.agent_id] = placement; + }); + const finalPlacements = snapshot.placement?.agents || {}; + return specs.map(spec => { + const agentId = spec.agent_id || spec.agent_name || spec.name; + return { + ...spec, + agent_id: agentId, + agent_name: spec.agent_name || agentId, + ...agentStateAt(agentId, currentEvents), + placement: visiblePlacements[agentId] || finalPlacements[agentId] || spec.placement || {}, + }; + }).filter(agent => agent.agent_id); + } + + function agentStateAt(agentId, events) { + const state = { + started: false, + last_error: null, + decision_count: 0, + received_message_count: 0, + outbox_count: 0, + tool_started_count: 0, + tool_finished_count: 0, + }; + events.forEach(event => { + if (event.agent_id !== agentId) return; + if (event.event === 'agent_started') state.started = true; + if (event.event === 'agent_error') state.last_error = event.payload?.error || 'agent_error'; + if (event.event === 'agent_decision') state.decision_count += 1; + if (event.event === 'message_received') state.received_message_count += 1; + if (event.event === 'message_sent') state.outbox_count += 1; + if (event.event === 'tool_call_started') state.tool_started_count += 1; + if (event.event === 'tool_call_finished' || event.event === 'tool_call_failed') state.tool_finished_count += 1; + }); + return state; + } + + function agentHost(agent) { + return agent?.placement?.short_hostname || agent?.placement?.hostname || (agent?.started ? 
'unknown host' : 'pending'); + } + + function hostColor(index) { + const colors = ['#dbeafe', '#dcfce7', '#fef3c7', '#fce7f3', '#e0e7ff', '#ccfbf1', '#fee2e2', '#ede9fe']; + return colors[index % colors.length]; + } + + function hostStroke(index) { + const colors = ['#2563eb', '#16a34a', '#d97706', '#db2777', '#4f46e5', '#0f766e', '#dc2626', '#7c3aed']; + return colors[index % colors.length]; + } + + function render() { + const detailScroll = captureDetailScrollSnapshot(); + document.getElementById('updated').textContent = snapshot.updated ? new Date(snapshot.updated * 1000).toLocaleTimeString() : ''; + document.getElementById('runPath').textContent = snapshot.run_dir || ''; + document.getElementById('graphTitle').textContent = isWorkflowMode() ? 'ChemGraph Workflow' : 'Agent Graph'; + renderTimeline(); + renderMetrics(); + renderGraph(); + renderAgentPicker(); + renderDetail(); + renderEmbeddedWorkflowPanel(); + restoreDetailScrollSnapshot(detailScroll); + lastRenderedDetailIdentity = currentDetailIdentity(); + } + + function renderTimeline() { + const events = allEvents(); + const slider = document.getElementById('timeSlider'); + const index = currentEventIndex(); + slider.max = String(Math.max(0, events.length - 1)); + slider.value = String(Math.max(0, index)); + slider.disabled = events.length === 0; + const event = index >= 0 ? events[index] : null; + const mode = isReplaying ? 'replay' : followLatest ? 'latest' : `event ${index + 1}`; + document.getElementById('timeLabel').textContent = `${mode} / ${events.length}`; + document.getElementById('timeEvent').textContent = event + ? `${formatTime(event.timestamp)} ${event.event}${event.agent_id ? ` · ${event.agent_id}` : ''} · ${graphModeLabel()}` + : ''; + document.getElementById('playReplay').textContent = isReplaying ? 
'Pause' : 'Replay'; + document.querySelectorAll('#graphMode button').forEach(button => { + button.classList.toggle('active', button.dataset.mode === graphMode); + }); + } + + function renderMetrics() { + if (isWorkflowMode()) { + renderWorkflowMetrics(); + return; + } + const events = visibleEvents(); + const counts = {}; + events.forEach(event => { counts[event.event] = (counts[event.event] || 0) + 1; }); + const currentAgents = agents(); + const startedAgents = currentAgents.filter(agent => agent.started); + const hostByAgent = new Map(currentAgents.map(agent => [agent.agent_id, agentHost(agent)])); + const hosts = new Set(startedAgents.map(agentHost).filter(host => host && host !== 'pending')); + const finish = latestEventOf('campaign_finished')?.payload || {}; + const messageEvents = events.filter(event => event.event === 'message_sent'); + const crossNodeMessages = messageEvents.filter(event => { + const p = event.payload || {}; + const senderHost = hostByAgent.get(p.sender); + const recipientHost = hostByAgent.get(p.recipient); + return senderHost && recipientHost && senderHost !== recipientHost; + }); + const maceResults = events.filter(event => ( + ['tool_call_finished', 'chemgraph_job_result'].includes(event.event) + && event.payload?.tool_name === 'run_mace_ensemble' + )); + const values = [ + ['Finish', finish.reason || 'running'], + ['Decisions', counts.agent_decision || 0], + ['Agents / Hosts', `${startedAgents.length} / ${hosts.size}`], + ['Errors', counts.agent_error || 0], + ['Messages', messageEvents.length], + ['Cross-node', crossNodeMessages.length], + ['Tool calls', counts.tool_call_started || 0], + ['Workflows', counts.workflow_started || 0], + ]; + document.getElementById('metrics').innerHTML = values.map(([k,v]) => ` +
${esc(k)}
${esc(v)}
+ `).join(''); + const proof = snapshot.communication_proof || {}; + document.getElementById('proof').innerHTML = proof.passes + ? Object.entries(proof.passes).map(([k,v]) => `${esc(k)}=${v}`).join('') + : ''; + } + + function renderWorkflowMetrics() { + const events = visibleEvents(); + const counts = {}; + events.forEach(event => { counts[event.event] = (counts[event.event] || 0) + 1; }); + const status = snapshot.status || {}; + const finish = events.filter(event => event.event === 'workflow_finished').slice(-1)[0]?.payload || {}; + const toolResults = events.filter(event => event.event === 'tool_call_finished' && event.payload?.runtime); + const tokenEvents = workflowTokenEvents(events); + const tokenTotals = summedTokenCounts(events); + const values = [ + ['Status', finish.status || status.status || 'running'], + ['Workflow', status.workflow_type || finish.workflow_type || '-'], + ['Events', events.length], + ['LM calls', tokenEvents.length || (counts.llm_decision || 0)], + ['LM tokens', tokenTotals ? formatTokenCount(tokenTotals.total) : '-'], + ['Tool results', toolResults.length], + ['Errors', finish.status === 'failed' ? 1 : 0], + ['Model', status.model_name || '-'], + ['Span', trunc(status.workflow_span_id || finish.span_id || '-', 18)], + ]; + document.getElementById('metrics').innerHTML = values.map(([k,v]) => ` +
${esc(k)}
${esc(v)}
+ `).join(''); + document.getElementById('proof').innerHTML = 'local ChemGraph workflow'; + } + + function workflowGraphEvents() { + return visibleEvents().filter(event => isWorkflowEvent(event)); + } + + function activeEmbeddedWorkflowContext() { + if (isWorkflowMode()) return null; + const activity = selectedActivityEvent(); + if (activity) { + const events = workflowEventsForSelection(activity); + if (events.length) { + return { + events, + anchorKey: eventKey(activity), + title: isWorkflowEvent(activity) + ? `ChemGraph: ${workflowAgentId(activity) || activity.agent_id || 'workflow'}` + : `ChemGraph: ${activity.agent_id || 'agent'}`, + meta: embeddedWorkflowMeta(events), + }; + } + } + return null; + } + + function embeddedWorkflowMeta(events) { + const flow = workflowFlowGraph(events); + const tokenEventCount = workflowTokenEvents(events).length; + const llmCount = tokenEventCount || events.filter(event => event.event === 'llm_decision').length; + const toolCount = flow.nodes.filter(node => node.type === 'tool').length; + const tokenTotals = summedTokenCounts(events); + const first = events[0]; + const p = first?.payload || {}; + return [ + p.thread_id || (p.round !== undefined ? `round ${p.round}` : ''), + `${llmCount} LM`, + tokenTotals ? 
`${formatTokenCount(tokenTotals.total)} tok` : '', + `${toolCount} tools/actions`, + `${events.length} events`, + ].filter(Boolean).join(' · '); + } + + function renderEmbeddedWorkflowPanel() { + const context = activeEmbeddedWorkflowContext(); + const panel = document.getElementById('workflowFloatingPanel'); + const tab = document.getElementById('workflowFloatingTab'); + if (!context || !context.events.length) { + panel.classList.add('hidden'); + tab.classList.add('hidden'); + return; + } + if (!workflowPanelOpen) { + panel.classList.add('hidden'); + tab.classList.remove('hidden'); + return; + } + tab.classList.add('hidden'); + panel.classList.remove('hidden'); + applyWorkflowPanelFrame(); + if (embeddedWorkflowAnchorKey !== context.anchorKey) { + embeddedWorkflowAnchorKey = context.anchorKey; + selectedEmbeddedWorkflowEventKey = null; + embeddedWorkflowView = null; + } + if ( + selectedEmbeddedWorkflowEventKey + && !context.events.some(event => eventKey(event) === selectedEmbeddedWorkflowEventKey) + ) { + selectedEmbeddedWorkflowEventKey = null; + } + document.getElementById('workflowFloatingTitle').textContent = context.title; + document.getElementById('workflowFloatingMeta').textContent = context.meta; + renderEmbeddedWorkflowGraph(context.events); + renderEmbeddedWorkflowInspector(context.events); + } + + function currentEmbeddedWorkflowInspectorIdentity(events) { + const selected = selectedEmbeddedWorkflowEvent(events); + return [ + embeddedWorkflowAnchorKey || '', + selected ? 
eventKey(selected) : 'summary', + ].join('|'); + } + + function applyWorkflowPanelFrame() { + const panel = document.getElementById('workflowFloatingPanel'); + const maxWidth = Math.max(520, window.innerWidth - 24); + const maxHeight = Math.max(320, window.innerHeight - 24); + workflowPanelFrame.width = Math.min(Math.max(workflowPanelFrame.width, 520), maxWidth); + workflowPanelFrame.height = Math.min(Math.max(workflowPanelFrame.height, 320), maxHeight); + workflowPanelFrame.x = Math.min(Math.max(workflowPanelFrame.x, 8), window.innerWidth - 80); + workflowPanelFrame.y = Math.min(Math.max(workflowPanelFrame.y, 8), window.innerHeight - 56); + panel.style.left = `${workflowPanelFrame.x}px`; + panel.style.top = `${workflowPanelFrame.y}px`; + panel.style.width = `${workflowPanelFrame.width}px`; + panel.style.height = `${workflowPanelFrame.height}px`; + } + + function renderEmbeddedWorkflowGraph(events) { + const svg = document.getElementById('embeddedWorkflowGraph'); + const empty = document.getElementById('embeddedWorkflowEmpty'); + const flow = workflowFlowGraph(events); + const nodes = flow.nodes; + const edges = flow.edges; + if (!nodes.length) { + svg.innerHTML = ''; + empty.textContent = 'No ChemGraph workflow nodes visible for this selection.'; + empty.classList.remove('hidden'); + return; + } + empty.classList.add('hidden'); + + const nodeW = 190; + const nodeH = 66; + const columnGap = 112; + const toolLaneGap = 82; + const toolStartY = 260; + const maxColumn = Math.max(...nodes.map(node => node.column || 0)); + const maxToolLanes = Math.max(1, ...nodes.filter(node => node.type === 'tool').map(node => node.laneCount || 1)); + const width = Math.max(1120, 120 + (maxColumn + 1) * (nodeW + columnGap)); + const height = Math.max(540, toolStartY + maxToolLanes * toolLaneGap + 100); + const yByType = {input: 236, lm: 128, output: 236}; + const positions = new Map(); + nodes.forEach(node => { + const column = node.column || 0; + const y = node.type === 'tool' + ? 
toolStartY + (node.laneIndex || 0) * toolLaneGap + : (yByType[node.type] || 236); + positions.set(node.id, { + x: 96 + nodeW / 2 + column * (nodeW + columnGap), + y, + }); + }); + + const selectedEvent = selectedEmbeddedWorkflowEvent(events); + const selectedNodeId = selectedEvent ? workflowFlowNodeId(selectedEvent) : null; + const current = currentEvent(); + const currentNodeId = current ? workflowFlowNodeId(current) : null; + const selectedEdgeIds = new Set(); + if (selectedNodeId) { + edges.forEach(edge => { + if (edge.from === selectedNodeId || edge.to === selectedNodeId) { + selectedEdgeIds.add(`${edge.from}->${edge.to}`); + } + }); + } + const nodeById = new Map(nodes.map(node => [node.id, node])); + const edgeSvg = edges.map(edge => { + const prev = nodeById.get(edge.from); + const node = nodeById.get(edge.to); + const source = positions.get(edge.from); + const target = positions.get(edge.to); + if (!prev || !node || !source || !target) return ''; + const startX = source.x + nodeW / 2; + const endX = target.x - nodeW / 2; + const midX = (startX + endX) / 2; + const controlY = Math.min(source.y, target.y) - 54; + const path = `M ${startX.toFixed(1)} ${source.y.toFixed(1)} Q ${midX.toFixed(1)} ${controlY.toFixed(1)} ${endX.toFixed(1)} ${target.y.toFixed(1)}`; + const cls = [ + 'workflow-edge', + node.id === currentNodeId || prev.id === currentNodeId ? 'current' : '', + selectedEdgeIds.has(`${edge.from}->${edge.to}`) ? 'related' : '', + ].filter(Boolean).join(' '); + return ` + + ${esc(prev.title)} -> ${esc(node.title)} + + `; + }).join(''); + + const nodeSvg = nodes.map(node => { + const pos = positions.get(node.id); + const classes = [ + 'workflow-node', + node.type, + node.toolClass || '', + node.failed ? 'error' : '', + node.id === currentNodeId ? 'current' : '', + node.id === selectedNodeId ? 
'selected' : '', + ].filter(Boolean).join(' '); + return ` + + + ${esc(trunc(node.title, 25))} + ${esc(trunc(node.meta, 34))} + ${esc(formatTime(node.event.timestamp))} + ${esc(formatWorkflowEvent(node.event))} + + `; + }).join(''); + + ensureEmbeddedWorkflowView(width, height); + svg.innerHTML = ` + + + + + + + ChemGraph turn · ${nodes.length} node(s) · ${events.length} visible event(s) + + ${edgeSvg} + ${nodeSvg} + `; + updateEmbeddedWorkflowViewBox(); + svg.querySelectorAll('[data-embedded-workflow-event-key]').forEach(node => { + node.addEventListener('click', event => { + selectedEmbeddedWorkflowEventKey = node.dataset.embeddedWorkflowEventKey; + renderEmbeddedWorkflowPanel(); + event.stopPropagation(); + }); + }); + } + + function selectedEmbeddedWorkflowEvent(events) { + if (!selectedEmbeddedWorkflowEventKey) return null; + return events.find(event => eventKey(event) === selectedEmbeddedWorkflowEventKey) || null; + } + + function renderEmbeddedWorkflowInspector(events) { + const title = document.getElementById('embeddedWorkflowInspectorTitle'); + const meta = document.getElementById('embeddedWorkflowInspectorMeta'); + const body = document.getElementById('embeddedWorkflowInspectorBody'); + const identity = currentEmbeddedWorkflowInspectorIdentity(events); + const previousIdentity = lastEmbeddedWorkflowInspectorIdentity; + const previousScrollTop = body.scrollTop; + const previousScrollLeft = body.scrollLeft; + const event = selectedEmbeddedWorkflowEvent(events); + if (!event) { + const tokenTotals = summedTokenCounts(events); + const flow = workflowFlowGraph(events); + const html = detailRich( + detailSection('Turn Summary', detailKvGrid([ + ['Visible events', events.length], + ['Nodes', flow.nodes.length], + ['LM calls', workflowTokenEvents(events).length || events.filter(item => item.event === 'llm_decision').length], + ['Input tokens', tokenTotals?.input ?? '-'], + ['Output tokens', tokenTotals?.output ?? '-'], + ['Total tokens', tokenTotals?.total ?? 
'-'], + ]), 'info'), + detailSection( + 'Inspect', + paragraphsHtml('Click a ChemGraph node in this panel to inspect its LM tokens, tool arguments, output, and payload without changing the outer dashboard selection.'), + ), + ); + title.textContent = 'ChemGraph Inspector'; + meta.textContent = 'Select an LM, tool, action, or output node.'; + setStableHtml(body, html, identity === previousIdentity); + if (identity === previousIdentity) { + body.scrollTop = previousScrollTop; + body.scrollLeft = previousScrollLeft; + } + lastEmbeddedWorkflowInspectorIdentity = identity; + return; + } + const html = detailRich( + chemgraphNodeDetailHtml(event), + payloadDetailHtml(event.payload || {}), + chemgraphNodeContextHtml(event), + ); + title.textContent = chemgraphNodeDetailTitle(event); + meta.textContent = `${formatTime(event.timestamp)} · ${event.event}`; + setStableHtml(body, html, identity === previousIdentity); + if (identity === previousIdentity) { + body.scrollTop = previousScrollTop; + body.scrollLeft = previousScrollLeft; + } + lastEmbeddedWorkflowInspectorIdentity = identity; + } + + function ensureEmbeddedWorkflowView(width, height) { + const padX = Math.max(180, width * 0.08); + const padY = Math.max(100, height * 0.12); + const bounds = { + x: -padX, + y: -padY, + width: width + padX * 2, + height: height + padY * 2, + }; + if (!embeddedWorkflowView) { + embeddedWorkflowView = { + x: bounds.x, + y: bounds.y, + width: bounds.width, + height: bounds.height, + layoutWidth: width, + layoutHeight: height, + boundsX: bounds.x, + boundsY: bounds.y, + boundsWidth: bounds.width, + boundsHeight: bounds.height, + }; + return; + } + if ( + embeddedWorkflowView.layoutWidth !== width + || embeddedWorkflowView.layoutHeight !== height + ) { + const nextView = preserveViewForLayoutChange( + embeddedWorkflowView, + bounds, + width, + height, + ); + embeddedWorkflowView = { + ...embeddedWorkflowView, + ...nextView, + layoutWidth: width, + layoutHeight: height, + boundsX: bounds.x, 
boundsY: bounds.y, + boundsWidth: bounds.width, + boundsHeight: bounds.height, + }; + clampEmbeddedWorkflowView(); + } + } + + /** Writes embeddedWorkflowView (x, y, width, height, one decimal each) into the viewBox attribute of the element with id embeddedWorkflowGraph. Does nothing while no view exists. */ function updateEmbeddedWorkflowViewBox() { + const svg = document.getElementById('embeddedWorkflowGraph'); + if (!embeddedWorkflowView) return; + svg.setAttribute( + 'viewBox', + `${embeddedWorkflowView.x.toFixed(1)} ${embeddedWorkflowView.y.toFixed(1)} ${embeddedWorkflowView.width.toFixed(1)} ${embeddedWorkflowView.height.toFixed(1)}` + ); + } + + /** Clamps embeddedWorkflowView in place. Width and height are limited to the range from one twelfth of the layout size up to the bounds size; x and y are then moved so the rectangle stays inside the bounds. Missing bounds fall back to origin 0 and the layout size. */ function clampEmbeddedWorkflowView() { + if (!embeddedWorkflowView) return; + const boundsX = embeddedWorkflowView.boundsX ?? 0; + const boundsY = embeddedWorkflowView.boundsY ?? 0; + const boundsWidth = embeddedWorkflowView.boundsWidth ?? embeddedWorkflowView.layoutWidth; + const boundsHeight = embeddedWorkflowView.boundsHeight ?? embeddedWorkflowView.layoutHeight; + /* size first, because the position limits below depend on the clamped size */ embeddedWorkflowView.width = Math.min(boundsWidth, Math.max(embeddedWorkflowView.layoutWidth / 12, embeddedWorkflowView.width)); + embeddedWorkflowView.height = Math.min(boundsHeight, Math.max(embeddedWorkflowView.layoutHeight / 12, embeddedWorkflowView.height)); + embeddedWorkflowView.x = Math.min(Math.max(boundsX, embeddedWorkflowView.x), boundsX + boundsWidth - embeddedWorkflowView.width); + embeddedWorkflowView.y = Math.min(Math.max(boundsY, embeddedWorkflowView.y), boundsY + boundsHeight - embeddedWorkflowView.height); + } + + /** Multiplies the view width and height by factor while keeping the current center fixed (a factor below 1 shrinks the viewBox), then clamps the view and rewrites the viewBox. No-op without a view. */ function zoomEmbeddedWorkflow(factor) { + if (!embeddedWorkflowView) return; + const centerX = embeddedWorkflowView.x + embeddedWorkflowView.width / 2; + const centerY = embeddedWorkflowView.y + embeddedWorkflowView.height / 2; + embeddedWorkflowView.width *= factor; + embeddedWorkflowView.height *= factor; + embeddedWorkflowView.x = centerX - embeddedWorkflowView.width / 2; + embeddedWorkflowView.y = centerY - embeddedWorkflowView.height / 2; + clampEmbeddedWorkflowView(); + updateEmbeddedWorkflowViewBox(); + } + + /** Discards the stored view and re-renders the embedded workflow panel. */ function resetEmbeddedWorkflowView() { + embeddedWorkflowView = null; + 
renderEmbeddedWorkflowPanel(); + } + + /** Renders the ChemGraph workflow flow graph into the element with id graph. Sets the hostLegend content, shows a placeholder when there are no workflow events or no flow nodes, otherwise places each node by column (tool nodes stacked in lanes starting at toolStartY, other types at a fixed y per type), builds edge and node markup and wires click selection. NOTE(review): the markup inside the template literals and innerHTML strings looks stripped in this copy; confirm against the real source file. */ function renderWorkflowGraph() { + const svg = document.getElementById('graph'); + const events = workflowGraphEvents(); + document.getElementById('hostLegend').innerHTML = ` + query + LM + action tool + science tool + output + failure + `; + if (!events.length) { + svg.setAttribute('viewBox', '0 0 1000 260'); + svg.innerHTML = 'No ChemGraph workflow events yet.'; + return; + } + + const flow = workflowFlowGraph(events); + const nodes = flow.nodes; + const edges = flow.edges; + if (!nodes.length) { + svg.setAttribute('viewBox', '0 0 1000 260'); + svg.innerHTML = 'Waiting for ChemGraph workflow execution events.'; + return; + } + + /* layout constants, in SVG user units */ const nodeW = 184; + const nodeH = 64; + const columnGap = 96; + const toolLaneGap = 76; + const toolStartY = 230; + const maxColumn = Math.max(...nodes.map(node => node.column || 0)); + const maxToolLanes = Math.max(1, ...nodes.filter(node => node.type === 'tool').map(node => node.laneCount || 1)); + /* canvas grows with the number of columns and tool lanes, with minimums of 1040 by 500 */ const width = Math.max(1040, 120 + (maxColumn + 1) * (nodeW + columnGap)); + const height = Math.max(500, toolStartY + maxToolLanes * toolLaneGap + 80); + const yByType = {input: 220, lm: 126, output: 220}; + const positions = new Map(); + nodes.forEach(node => { + const column = node.column || 0; + const y = node.type === 'tool' + ? toolStartY + (node.laneIndex || 0) * toolLaneGap + : (yByType[node.type] || 220); + /* positions hold node centers */ positions.set(node.id, { + x: 96 + nodeW / 2 + column * (nodeW + columnGap), + y, + }); + }); + const current = currentEvent(); + const selectedEvent = selectedActivityEvent(); + const currentNodeId = current ? workflowFlowNodeId(current) : null; + const selectedNodeId = selectedEvent ? 
workflowFlowNodeId(selectedEvent) : null; + const selectedEdgeIds = new Set(); + if (selectedNodeId) { + edges.forEach(edge => { + if (edge.from === selectedNodeId || edge.to === selectedNodeId) { + selectedEdgeIds.add(`${edge.from}->${edge.to}`); + } + }); + } + + const nodeById = new Map(nodes.map(node => [node.id, node])); + const edgeSvg = edges.map(edge => { + const prev = nodeById.get(edge.from); + const node = nodeById.get(edge.to); + const source = positions.get(edge.from); + const target = positions.get(edge.to); + if (!prev || !node || !source || !target) return ''; + const startX = source.x + nodeW / 2; + const endX = target.x - nodeW / 2; + const midX = (startX + endX) / 2; + const controlY = Math.min(source.y, target.y) - 46; + const path = `M ${startX.toFixed(1)} ${source.y.toFixed(1)} Q ${midX.toFixed(1)} ${controlY.toFixed(1)} ${endX.toFixed(1)} ${target.y.toFixed(1)}`; + const cls = [ + 'workflow-edge', + node.id === currentNodeId || prev.id === currentNodeId ? 'current' : '', + selectedEdgeIds.has(`${edge.from}->${edge.to}`) ? 'related' : '', + ].filter(Boolean).join(' '); + return ` + + ${esc(prev.title)} -> ${esc(node.title)} + + `; + }).join(''); + + const nodeSvg = nodes.map(node => { + const pos = positions.get(node.id); + const classes = [ + 'workflow-node', + node.type, + node.toolClass || '', + node.failed ? 'error' : '', + node.id === currentNodeId ? 'current' : '', + node.id === selectedNodeId ? 
'selected' : '', + ].filter(Boolean).join(' '); + return ` + + + ${esc(trunc(node.title, 24))} + ${esc(trunc(node.meta, 32))} + ${esc(formatTime(node.event.timestamp))} + ${esc(formatWorkflowEvent(node.event))} + + `; + }).join(''); + + svg.style.minHeight = `${height}px`; + ensureGraphView(width, height); + svg.innerHTML = ` + + + + + + + ${esc(snapshot.status?.workflow_type || 'ChemGraph workflow')} · ${nodes.length} flow node(s) · ${events.length} visible event(s) + + ${edgeSvg} + ${nodeSvg} + `; + updateGraphViewBox(); + svg.querySelectorAll('[data-activity-event-key]').forEach(activityEl => { + activityEl.addEventListener('click', event => { + selectedActivityEventKey = activityEl.dataset.activityEventKey; + selectedAgent = null; + selectedEdgeKey = null; + event.stopPropagation(); + render(); + }); + }); + } + + /** Builds the flow graph for a list of workflow events and returns an object with nodes and edges. The graph has an optional query node in column 0, one node per llm_decision event, tool_call_ events grouped into a batch column after the LM turn that precedes them (stacked in lanes), and a final output node. Each node carries column, and tool nodes also laneIndex and laneCount. */ function workflowFlowGraph(events) { + const nodes = []; + const edges = []; + /* node id to index in nodes, so later events of the same tool call replace the node */ const toolNodeIndexes = new Map(); + const hasGraphWork = events.some(event => ( + event.event === 'llm_decision' + || event.event === 'workflow_output' + || event.event === 'run_finished' + || event.event.startsWith('tool_call_') + )); + const runStart = events.find(event => event.event === 'run_started') + || events.find(event => event.event === 'workflow_started'); + let lastColumn = -1; + let lastNodeIds = []; + let currentLmId = null; + let currentToolBatchIds = []; + let currentToolBatchColumn = null; + let lmTurn = 0; + + /* adds a directed edge once; ignores missing ids and self loops */ function addEdge(from, to) { + if (!from || !to || from === to) return; + if (!edges.some(edge => edge.from === from && edge.to === to)) { + edges.push({from, to}); + } + } + + /* links every distinct, non-empty source id to the target */ function addEdges(fromIds, to) { + Array.from(new Set(fromIds.filter(Boolean))).forEach(from => addEdge(from, to)); + } + + /* rewrites laneIndex and laneCount on all nodes of the open tool batch after it grew */ function updateCurrentToolBatchLanes() { + currentToolBatchIds.forEach((id, laneIndex) => { + const index = toolNodeIndexes.get(id); + if (index === undefined) return; + nodes[index].laneIndex = laneIndex; + nodes[index].laneCount = 
currentToolBatchIds.length; + }); + } + + /* the query node is only added once there is real graph work to attach to it */ if (runStart && hasGraphWork) { + const node = workflowFlowNode(runStart, 0); + node.column = 0; + nodes.push(node); + lastColumn = 0; + lastNodeIds = [node.id]; + } + + events.forEach(event => { + if (event.event === 'llm_decision') { + lmTurn += 1; + const afterToolBatch = currentToolBatchIds.length > 0; + const node = workflowFlowNode(event, nodes.length, {lmTurn}); + /* an LM turn sits one column after the open tool batch, or after the last node when no batch is open */ node.column = afterToolBatch + ? currentToolBatchColumn + 1 + : lastColumn + 1; + nodes.push(node); + addEdges(afterToolBatch ? currentToolBatchIds : lastNodeIds, node.id); + lastColumn = node.column; + lastNodeIds = [node.id]; + currentLmId = node.id; + /* the LM turn closes the tool batch */ currentToolBatchIds = []; + currentToolBatchColumn = null; + return; + } + if (event.event.startsWith('tool_call_')) { + if (currentToolBatchColumn === null) { + currentToolBatchColumn = lastColumn + 1; + } + const node = workflowFlowNode(event, nodes.length); + /* a later event for a known tool call replaces the node but keeps its column and lane */ if (toolNodeIndexes.has(node.id)) { + const index = toolNodeIndexes.get(node.id); + node.column = nodes[index].column; + node.laneIndex = nodes[index].laneIndex; + node.laneCount = nodes[index].laneCount; + nodes[index] = node; + } else { + node.column = currentToolBatchColumn; + node.laneIndex = currentToolBatchIds.length; + node.laneCount = currentToolBatchIds.length + 1; + toolNodeIndexes.set(node.id, nodes.length); + currentToolBatchIds.push(node.id); + nodes.push(node); + addEdges(currentLmId ? [currentLmId] : lastNodeIds, node.id); + updateCurrentToolBatchLanes(); + } + } + }); + + /* output node: the last workflow_output event, else (only when nodes exist) the last workflow_finished or run_finished event */ const output = events.filter(event => event.event === 'workflow_output').slice(-1)[0] + || ( + nodes.length + ? events.filter(event => event.event === 'workflow_finished' || event.event === 'run_finished').slice(-1)[0] + : null + ); + if (output) { + const afterToolBatch = currentToolBatchIds.length > 0; + const node = workflowFlowNode(output, nodes.length); + node.column = afterToolBatch + ? 
currentToolBatchColumn + 1 + : lastColumn + 1; + nodes.push(node); + addEdges(afterToolBatch ? currentToolBatchIds : lastNodeIds, node.id); + } + return {nodes, edges}; + } + + /** Returns the flow node id for an event: workflow-query for run_started and workflow_started, workflow-output for workflow_output, workflow_finished and run_finished, workflow-tool- followed by payload.tool_call_id for tool_call_ events that carry one, and otherwise payload.span_id or eventKey(event). Returns null when event is missing. All events of one tool call therefore share one id. */ function workflowFlowNodeId(event) { + const p = event?.payload || {}; + if (!event) return null; + if (event.event === 'run_started' || event.event === 'workflow_started') return 'workflow-query'; + if (event.event === 'workflow_output' || event.event === 'workflow_finished' || event.event === 'run_finished') return 'workflow-output'; + if (event.event.startsWith('tool_call_')) { + return p.tool_call_id ? `workflow-tool-${p.tool_call_id}` : (p.span_id || eventKey(event)); + } + return p.span_id || eventKey(event); + } + + /** Summarizes a list of tool calls as a comma separated string. Each call is labelled by its name, else its id, else the word tool; labels that occur more than once get an x and the count appended. Labels keep first-seen order. */ function toolCallSummary(toolCalls) { + const counts = new Map(); + toolCalls + .map(call => call?.name || call?.id || 'tool') + .filter(Boolean) + .forEach(name => counts.set(name, (counts.get(name) || 0) + 1)); + const parts = Array.from(counts.entries()).map(([name, count]) => count > 1 ? `${name} x${count}` : name); + return parts.join(', '); + } + + /** Converts value with Number() and returns it when finite, otherwise null. */ function numberOrNull(value) { + const number = Number(value); + return Number.isFinite(number) ? 
number : null; + } + + /** Looks up the keys listed in names on raw, in order, and returns the first value that is neither undefined nor null and converts to a finite number. Returns null when raw is not an object or no key qualifies. */ function tokenField(raw, names) { + if (!raw || typeof raw !== 'object') return null; + for (const name of names) { + if (raw[name] !== undefined && raw[name] !== null) { + const value = numberOrNull(raw[name]); + if (value !== null) return value; + } + } + return null; + } + + /** Normalizes event.payload.token_counts into an object with input, output, total, source, estimateScope, rawUsage and raw. Input is read from input_tokens or prompt_tokens, output from output_tokens or completion_tokens, total from total_tokens. A missing total is derived as input plus output (a missing side counts as 0). Returns null when there is no token_counts object or no count at all. */ function llmTokenCounts(event) { + const raw = event?.payload?.token_counts; + if (!raw || typeof raw !== 'object') return null; + const input = tokenField(raw, ['input_tokens', 'prompt_tokens']); + const output = tokenField(raw, ['output_tokens', 'completion_tokens']); + let total = tokenField(raw, ['total_tokens']); + if (total === null && (input !== null || output !== null)) { + total = (input || 0) + (output || 0); + } + if (input === null && output === null && total === null) return null; + return { + input, + output, + total, + source: raw.source || 'unknown', + estimateScope: raw.estimate_scope || '', + rawUsage: raw.raw_usage, + raw, + }; + } + + /** Pairs every event with its llmTokenCounts result and keeps only the pairs that have counts. */ function workflowTokenEvents(events) { + return events + .map(event => ({event, counts: llmTokenCounts(event)})) + .filter(item => item.counts); + } + + /** Adds up input, output and total token counts over all events that have counts. A field that no event provided stays null instead of 0, and source is the distinct sources joined with a plus sign (unknown when empty). Returns null when no event has counts. */ function summedTokenCounts(events) { + const items = workflowTokenEvents(events); + if (!items.length) return null; + let input = 0; + let output = 0; + let total = 0; + /* saw flags distinguish a real sum of 0 from no data */ let sawInput = false; + let sawOutput = false; + let sawTotal = false; + const sources = new Set(); + items.forEach(({counts}) => { + if (counts.input !== null) { + input += counts.input; + sawInput = true; + } + if (counts.output !== null) { + output += counts.output; + sawOutput = true; + } + if (counts.total !== null) { + total += counts.total; + sawTotal = true; + } + if (counts.source) sources.add(counts.source); + }); + if (!sawTotal && (sawInput || sawOutput)) total = input + output; + return { + input: sawInput ? input : null, + output: sawOutput ? output : null, + total: sawTotal || sawInput || sawOutput ? 
total : null, + source: Array.from(sources).join('+') || 'unknown', + }; + } + + /** Formats a token count compactly: a dash for non-numeric input, one decimal with an M suffix from one million, whole thousands with a k suffix from 10000, one decimal with a k suffix from 1000, and the rounded integer below that. Thresholds compare the absolute value. */ function formatTokenCount(value) { + const number = numberOrNull(value); + if (number === null) return '-'; + if (Math.abs(number) >= 1000000) return `${(number / 1000000).toFixed(1)}M`; + if (Math.abs(number) >= 10000) return `${Math.round(number / 1000)}k`; + if (Math.abs(number) >= 1000) return `${(number / 1000).toFixed(1)}k`; + return String(Math.round(number)); + } + + /** Maps a token count source to its display label: local_estimate becomes estimate, provider stays provider, any other non-empty value is passed through, and an empty value becomes unknown. */ function tokenSourceLabel(source) { + if (source === 'local_estimate') return 'estimate'; + if (source === 'provider') return 'provider'; + return source || 'unknown'; + } + + /** Returns a one-line token summary for an event, or an empty string when it has no token counts. With compact set it gives the short tok in / out form, suffixed with est for estimates, nothing for provider counts and the source label otherwise; without it, input, output, total and source joined by middle dots. */ function tokenSummary(event, {compact = false} = {}) { + const counts = llmTokenCounts(event); + if (!counts) return ''; + const source = tokenSourceLabel(counts.source); + if (compact) { + const sourceSuffix = source === 'estimate' ? ' · est' : source === 'provider' ? '' : ` · ${source}`; + return `tok ${formatTokenCount(counts.input)} in / ${formatTokenCount(counts.output)} out${sourceSuffix}`; + } + return [ + `input ${formatTokenCount(counts.input)}`, + `output ${formatTokenCount(counts.output)}`, + `total ${formatTokenCount(counts.total)}`, + source, + ].filter(Boolean).join(' · '); + } + + /** Returns the show-full-prompt markup for event.payload.prompt_messages (message count plus the escaped JSON of the messages), or an empty string when that field is not a non-empty array. NOTE(review): the surrounding markup of the template looks stripped in this copy; confirm against the real source file. */ function promptDisclosureHtml(event) { + const messages = event?.payload?.prompt_messages; + if (!Array.isArray(messages) || !messages.length) return ''; + return ` +
+ Show full prompt (${messages.length} messages) +
${esc(formatJson(messages))}
+
+ `; + } + + /** Returns the Prompt detail section for events that have token counts, else an empty string. Shows the prompt disclosure with the info tone when prompt messages were captured, otherwise an explanatory paragraph with the warn tone. */ function promptDetailSection(event) { + if (!llmTokenCounts(event)) return ''; + const disclosure = promptDisclosureHtml(event); + return detailSection( + 'Prompt', + disclosure || paragraphsHtml('Full prompt was not captured for this event. Rerun with the updated ChemGraph observability code to populate this field.'), + disclosure ? 'info' : 'warn', + ); + } + + /** Returns the token count detail section(s) for an event, or an empty string without counts. The main section lists rounded input, output and total counts, the source label and the estimate scope, using the ok tone for provider counts and info otherwise; a Provider Usage section follows when raw usage data is present. */ function tokenDetailSection(event, title = 'Token Counts') { + const counts = llmTokenCounts(event); + if (!counts) return ''; + const body = [ + detailKvGrid([ + ['Input tokens', counts.input === null ? '-' : Math.round(counts.input)], + ['Output tokens', counts.output === null ? '-' : Math.round(counts.output)], + ['Total tokens', counts.total === null ? '-' : Math.round(counts.total)], + ['Source', tokenSourceLabel(counts.source)], + ['Estimate scope', counts.estimateScope || ''], + ]), + ].filter(Boolean).join(''); + return [ + detailSection(title, body, counts.source === 'provider' ? 'ok' : 'info'), + counts.rawUsage ? detailSection('Provider Usage', detailValueHtml(counts.rawUsage)) : '', + ].filter(Boolean).join(''); + } + + /** Classifies a tool event as action-tool when its display name is in actionToolNames, otherwise science-tool. */ function workflowToolKind(event) { + const name = toolDisplayName(event); + return actionToolNames.has(name) ? 'action-tool' : 'science-tool'; + } + + /** Short label for the tool kind: action or science. */ function workflowToolKindLabel(event) { + return workflowToolKind(event) === 'action-tool' ? 'action' : 'science'; + } + + /** Builds the flow node descriptor (id, type, title, meta, event, eventKey, failed, plus toolClass for tools) for an event. Start events become input nodes, llm_decision events become lm nodes titled with options.lmTurn (falling back to index), tool_call_ events become tool nodes, and any other event becomes the output node. */ function workflowFlowNode(event, index, options = {}) { + const p = event.payload || {}; + const key = eventKey(event); + if (event.event === 'run_started' || event.event === 'workflow_started') { + return { + id: workflowFlowNodeId(event), + type: 'input', + title: p.nested ? 'Prompt' : 'Query', + meta: p.query || p.thread_id || (p.round !== undefined ? `round ${p.round}` : snapshot.status?.query || 'user request'), + event, + eventKey: key, + failed: false, + }; + } + if (event.event === 'llm_decision') { + const calls = Array.isArray(p.tool_calls) ? 
p.tool_calls : []; + const callNames = toolCallSummary(calls); + const tokens = tokenSummary(event, {compact: true}); + return { + id: workflowFlowNodeId(event), + type: 'lm', + title: `LM turn ${options.lmTurn || index}`, + meta: [tokens, callNames ? `calls: ${callNames}` : 'response'].filter(Boolean).join(' · '), + event, + eventKey: key, + failed: false, + }; + } + if (event.event.startsWith('tool_call_')) { + /* failure can be signalled by the event name or by payload.status */ const failed = event.event === 'tool_call_failed' || p.status === 'failed'; + const kind = workflowToolKind(event); + return { + id: workflowFlowNodeId(event), + type: 'tool', + toolClass: kind, + title: toolDisplayName(event), + meta: `${workflowToolKindLabel(event)} · ${failed ? 'failed' : event.event === 'tool_call_started' ? 'running' : 'finished'}`, + event, + eventKey: key, + failed, + }; + } + /* every remaining event type is shown as the output node */ return { + id: workflowFlowNodeId(event), + type: 'output', + title: 'Output', + meta: tokenSummary(event, {compact: true}) || p.content_preview || p.status || snapshot.status?.status || 'finished', + event, + eventKey: key, + failed: p.status === 'failed', + }; + } + + /** Renders the main graph into the element with id graph. In workflow mode it delegates to renderWorkflowGraph. Otherwise it groups agents by host, lays them out radially when there are two or more agents and in one vertical lane per host otherwise, aggregates message events into one edge per sender and recipient pair, draws bands, edges, nodes and activity bubbles, and wires click and keyboard selection. Shows a placeholder when there are no agents. */ function renderGraph() { + if (isWorkflowMode()) { + renderWorkflowGraph(); + return; + } + const svg = document.getElementById('graph'); + const currentAgents = agents(); + if (!currentAgents.length) { + svg.setAttribute('viewBox', '0 0 1000 260'); + svg.innerHTML = 'No agents yet.'; + document.getElementById('hostLegend').innerHTML = 'Waiting for placement.'; + return; + } + + /* host name to the list of agents placed on it */ const byHost = new Map(); + currentAgents.forEach(agent => { + const host = agentHost(agent); + if (!byHost.has(host)) byHost.set(host, []); + byHost.get(host).push(agent); + }); + const hosts = Array.from(byHost.keys()).sort(); + const maxPerHost = Math.max(...hosts.map(host => byHost.get(host).length)); + const radial = currentAgents.length >= 2; + const width = radial + ? Math.max(1200, currentAgents.length * 220, hosts.length * 220) + : 1000; + const height = radial + ? 
Math.max(760, currentAgents.length * 135) + : Math.max(340, 110 + maxPerHost * 88); + const marginX = 42; + const top = 58; + const bottom = 34; + const laneGap = radial ? 8 : 14; + const laneWidth = (width - marginX * 2 - laneGap * Math.max(0, hosts.length - 1)) / hosts.length; + const nodeW = radial ? 132 : Math.min(154, Math.max(82, laneWidth - 18)); + const nodeH = radial ? 58 : 58; + const positions = new Map(); + svg.style.minWidth = ''; + svg.style.minHeight = `${height}px`; + + const hostIndex = new Map(hosts.map((host, index) => [host, index])); + const legendPrefix = ` + ChemGraph turn chip + ${radial ? 'radial host layout' : ''} + `; + const legend = legendPrefix + hosts.map((host, index) => ` + + + ${esc(host)} (${byHost.get(host).length}) + + `).join(''); + document.getElementById('hostLegend').innerHTML = legend; + + let bands = ''; + if (radial) { + const centerX = width / 2; + const centerY = height / 2; + const radiusScale = currentAgents.length < 5 ? 0.31 : 0.36; + const radiusX = width * radiusScale; + const radiusY = height * (currentAgents.length < 5 ? 0.30 : 0.34); + const hostCenters = new Map(); + bands = ` + + + ${currentAgents.length} daemon agents across ${hosts.length} host(s) + `; + hosts.forEach((host, index) => { + const angle = -Math.PI / 2 + (2 * Math.PI * index) / hosts.length; + const x = centerX + Math.cos(angle) * radiusX; + const y = centerY + Math.sin(angle) * radiusY; + hostCenters.set(host, {x, y, angle}); + const fill = hostColor(index); + const stroke = hostStroke(index); + bands += ` + + ${esc(trunc(host, 28))} + `; + }); + hosts.forEach((host, hIndex) => { + const list = byHost.get(host).slice().sort((a, b) => a.agent_id.localeCompare(b.agent_id)); + const center = hostCenters.get(host); + const spread = list.length > 1 ? 
Math.max(nodeH + 22, 84) : 0; + list.forEach((agent, index) => { + const offset = (index - (list.length - 1) / 2) * spread; + const tangentX = -Math.sin(center.angle); + const tangentY = Math.cos(center.angle); + const x = center.x + tangentX * offset; + const y = center.y + tangentY * offset; + positions.set(agent.agent_id, {x, y, host, hostIndex: hIndex, agent}); + }); + }); + } else { + bands = hosts.map((host, index) => { + const x = marginX + index * (laneWidth + laneGap); + const fill = hostColor(index); + const stroke = hostStroke(index); + const label = trunc(host, 34); + return ` + + ${esc(label)} + `; + }).join(''); + + hosts.forEach((host, hIndex) => { + const list = byHost.get(host).sort((a, b) => a.agent_id.localeCompare(b.agent_id)); + const x = marginX + hIndex * (laneWidth + laneGap) + laneWidth / 2; + list.forEach((agent, index) => { + const usable = height - top - bottom; + const y = top + ((index + 1) * usable) / (list.length + 1); + positions.set(agent.agent_id, {x, y, host, hostIndex: hIndex, agent}); + }); + }); + } + + const sent = graphMessageEvents(); + const recentIds = new Set(sent.slice(-10).map(e => e.payload?.message_id).filter(Boolean)); + const edgeMap = new Map(); + sent.forEach((event, index) => { + const p = event.payload || {}; + if (!positions.has(p.sender) || !positions.has(p.recipient)) return; + const key = `${p.sender}->${p.recipient}`; + const prev = edgeMap.get(key) || { + key, + sender: p.sender, + recipient: p.recipient, + count: 0, + latestIndex: -1, + latestMessageId: null, + latestTldr: '', + latestContent: '', + messages: [], + }; + prev.count += 1; + prev.latestIndex = index; + prev.latestMessageId = p.message_id; + prev.latestTldr = p.tldr || ''; + prev.latestContent = p.content || ''; + prev.messages.push(event); + edgeMap.set(key, prev); + }); + const edges = Array.from(edgeMap.values()).sort((a, b) => a.latestIndex - b.latestIndex); + if (selectedEdgeKey && !edgeMap.has(selectedEdgeKey)) { + selectedEdgeKey = 
null; + } + const allVisibleMessages = eventsOf('message_sent').length; + const showEdgeLabels = Boolean(selectedEdgeKey) || (graphMode !== 'cumulative' && edges.length <= 8); + + const labelBoxes = []; + const edgeSvg = edges.map((edge, index) => { + const source = positions.get(edge.sender); + const target = positions.get(edge.recipient); + const cross = source.host !== target.host; + const selectedByAgent = selectedAgent && (edge.sender === selectedAgent || edge.recipient === selectedAgent); + const selectedByEdge = selectedEdgeKey === edge.key; + const dimmed = (selectedAgent && !selectedByAgent) || (selectedEdgeKey && !selectedByEdge); + const recent = recentIds.has(edge.latestMessageId); + const start = edgeEndpoint(source, target, nodeW, nodeH, true); + const end = edgeEndpoint(source, target, nodeW, nodeH, false); + const hasReverse = edgeMap.has(`${edge.recipient}->${edge.sender}`); + const route = curvedRoute(start, end, edge, index, radial, width / 2, height / 2, hasReverse); + const path = route.path; + const cls = ['edge', cross ? 'cross-node' : '', recent ? 'recent' : '', selectedByEdge ? 'selected' : '', dimmed ? 'dimmed' : ''].filter(Boolean).join(' '); + const marker = cross ? 'url(#arrowCross)' : 'url(#arrow)'; + const labelX = route.labelX; + const labelY = route.labelY; + const edgeSummary = edge.latestTldr || ''; + const shouldShowSummary = edgeSummary && !dimmed && (selectedByEdge || recent || showEdgeLabels); + const rawLabel = shouldShowSummary ? edgeSummary : (selectedByEdge || showEdgeLabels ? edge.count : ''); + const placedLabel = placeEdgeLabel(labelX, labelY, rawLabel, labelBoxes, selectedByEdge); + const titleText = edgeSummary || edge.latestContent; + return ` + + ${esc(edge.sender)} -> ${esc(edge.recipient)} (${edge.count}) ${cross ? 'cross-node' : 'same-node'} ${esc(trunc(titleText, 180))} + + + ${placedLabel ? `${esc(placedLabel.text)}` : ''} + `; + }).join(''); + + const edgeHint = !edges.length ? ` + + ${allVisibleMessages ? 
'No routes visible in this graph mode. Try Recent or All.' : 'No messages visible at this time point.'} + + ` : ''; + const bubbleSvg = renderActivityBubbles(currentAgents, positions, nodeW, nodeH, width, height); + + const nodeSvg = currentAgents.map(agent => { + const pos = positions.get(agent.agent_id); + const current = currentEvent(); + const hIndex = hostIndex.get(pos.host) || 0; + const x = pos.x - nodeW / 2; + const y = pos.y - nodeH / 2; + const classes = [ + 'agent-node', + agent.agent_id === selectedAgent ? 'selected' : '', + current?.agent_id === agent.agent_id ? 'current' : '', + agent.last_error ? 'error' : '', + !agent.started ? 'pending' : '', + selectedEdgeKey && !selectedEdgeKey.split('->').includes(agent.agent_id) ? 'dimmed' : '', + ].filter(Boolean).join(' '); + const status = !agent.started ? 'pending' : agent.last_error ? 'error' : `${agent.decision_count || 0} decisions`; + return ` + + + ${esc(trunc(agent.agent_id, 20))} + ${esc(trunc(agent.role || '', 24))} + ${esc(status)} + ${esc(agent.agent_id)} host: ${esc(pos.host)} role: ${esc(agent.role || '')} + + `; + }).join(''); + + ensureGraphView(width, height); + svg.innerHTML = ` + + + + + + + + + ${bands} + ${edgeSvg} + ${edgeHint} + ${nodeSvg} + ${bubbleSvg} + `; + updateGraphViewBox(); + svg.querySelectorAll('[data-edge-key]').forEach(edgeEl => { + edgeEl.addEventListener('click', () => { + selectedEdgeKey = edgeEl.dataset.edgeKey; + selectedAgent = null; + selectedActivityEventKey = null; + render(); + }); + }); + svg.querySelectorAll('.agent-node').forEach(node => { + node.addEventListener('click', () => { + selectedAgent = node.dataset.agent; + selectedEdgeKey = null; + selectedActivityEventKey = null; + render(); + }); + node.addEventListener('keydown', event => { + if (event.key === 'Enter' || event.key === ' ') { + selectedAgent = node.dataset.agent; + selectedEdgeKey = null; + selectedActivityEventKey = null; + render(); + } + }); + }); + 
svg.querySelectorAll('[data-activity-event-key]').forEach(activityEl => { + activityEl.addEventListener('click', event => { + selectedActivityEventKey = activityEl.dataset.activityEventKey; + selectedAgent = null; + selectedEdgeKey = null; + event.stopPropagation(); + render(); + }); + }); + } + + /** Returns the point on a node box (centered at the node position, nodeW by nodeH) where an edge should attach. isStart true gives the attachment on source facing target, false the one on target facing source. When the nodes are further apart horizontally than vertically the point lies on the left or right side, otherwise on the top or bottom; along the other axis it is shifted toward the other node by at most 18 percent of the box size. */ function edgeEndpoint(source, target, nodeW, nodeH, isStart) { + const from = isStart ? source : target; + const to = isStart ? target : source; + const dx = to.x - from.x; + const dy = to.y - from.y; + if (Math.abs(dx) > Math.abs(dy)) { + return { + /* dx || 1 picks the right side when dx is 0 */ x: from.x + Math.sign(dx || 1) * nodeW / 2, + y: from.y + (dy / Math.max(Math.abs(dx), 1)) * nodeH * 0.18, + }; + } + return { + x: from.x + (dx / Math.max(Math.abs(dy), 1)) * nodeW * 0.18, + y: from.y + Math.sign(dy || 1) * nodeH / 2, + }; + } + + /** Deterministic, non-negative string hash: for each character hash becomes hash times 31 plus the character code, truncated to 32 bits, and the absolute value is returned. Falsy input hashes like the empty string (0). */ function stableHash(text) { + let hash = 0; + for (const ch of String(text || '')) { + /* (hash << 5) - hash equals hash * 31; | 0 wraps to a signed 32 bit integer */ hash = ((hash << 5) - hash + ch.charCodeAt(0)) | 0; + } + return Math.abs(hash); + } + + /** Builds a quadratic Bezier route from start to end and returns its path string plus a label anchor (labelX, labelY). The control point is pushed perpendicular to the segment: the side depends on the ordering of edge.sender and edge.recipient, the size on the distance (clamped, reduced to 62 percent when there is no reverse edge) plus a small jitter derived from stableHash(edge.key). Radial layouts additionally push the control point away from (centerX, centerY); lane layouts shift it horizontally by -18, 0 or 18 depending on index. */ function curvedRoute(start, end, edge, index, radial, centerX, centerY, hasReverse) { + const dx = end.x - start.x; + const dy = end.y - start.y; + const distance = Math.max(Math.hypot(dx, dy), 1); + const midX = (start.x + end.x) / 2; + const midY = (start.y + end.y) / 2; + /* unit vector perpendicular to the segment */ const perpX = -dy / distance; + const perpY = dx / distance; + const direction = edge.sender < edge.recipient ? 1 : -1; + const jitter = ((stableHash(edge.key) % 5) - 2) * (radial ? 11 : 8); + let curve = Math.min(radial ? 230 : 130, Math.max(radial ? 78 : 42, distance * (radial ? 
0.24 : 0.17))); + if (!hasReverse) curve *= 0.62; + curve = curve * direction + jitter; + + let controlX = midX + perpX * curve; + let controlY = midY + perpY * curve; + if (radial) { + /* push the control point away from the layout center */ const outX = midX - centerX; + const outY = midY - centerY; + const outDistance = Math.max(Math.hypot(outX, outY), 1); + const outward = Math.min(130, Math.max(36, distance * 0.16)); + controlX += (outX / outDistance) * outward; + controlY += (outY / outDistance) * outward; + } else { + controlX += ((index % 3) - 1) * 18; + } + + return { + path: `M ${start.x.toFixed(1)} ${start.y.toFixed(1)} Q ${controlX.toFixed(1)} ${controlY.toFixed(1)} ${end.x.toFixed(1)} ${end.y.toFixed(1)}`, + /* label anchor: 65 percent chord midpoint, 35 percent control point, raised by 4 */ labelX: (midX * 0.65 + controlX * 0.35), + labelY: (midY * 0.65 + controlY * 0.35) - 4, + }; + } + + /** Returns the markup for the small activity bubbles drawn beside agent nodes. Candidates are the events from activeWindowEvents(1) that are not workflow events, tool_call_ events or chemgraph_job_result and for which bubbleInfo returns a descriptor, plus the last three reasoning turns of every agent. Per agent only the four newest candidates (by sortTime) are drawn, stacked 24 units apart next to the node. */ function renderActivityBubbles(currentAgents, positions, nodeW, nodeH, layoutWidth, layoutHeight) { + const grouped = new Map(); + /* collects an item under its agent, but only for agents that have a position */ function addGrouped(agentId, item) { + if (!agentId || !positions.has(agentId)) return; + if (!grouped.has(agentId)) grouped.set(agentId, []); + grouped.get(agentId).push(item); + } + activeWindowEvents(1).forEach(event => { + if (isWorkflowEvent(event)) return; + if (event.event.startsWith('tool_call_') || event.event === 'chemgraph_job_result') return; + const bubble = bubbleInfo(event); + if (!bubble) return; + addGrouped(event.agent_id, { + event, + bubble, + selected: eventKey(event) === selectedActivityEventKey, + sortTime: eventTimestamp(event) ?? 0, + }); + }); + currentAgents.forEach(agent => { + reasoningTurnsForAgent(agent.agent_id).slice(-3).forEach(turn => { + const representative = turn.representative; + if (!representative) return; + /* a turn bubble counts as selected when any of its events is the selected one */ const selected = turn.events.some(event => eventKey(event) === selectedActivityEventKey); + addGrouped(agent.agent_id, { + event: representative, + bubble: chemgraphTurnBubbleInfo(turn), + selected, + sortTime: turn.lastTimestamp ?? turn.firstTimestamp ?? 
0, + }); + }); + }); + const rows = []; + grouped.forEach((items, agentId) => { + const pos = positions.get(agentId); + items + .slice() + .sort((a, b) => (a.sortTime || 0) - (b.sortTime || 0)) + .slice(-4) + .forEach((item, index) => { + const key = eventKey(item.event); + const selected = item.selected || key === selectedActivityEventKey; + /* bubble width follows the label length, limited to 64..156 */ const width = Math.max(64, Math.min(156, item.bubble.label.length * 7 + 18)); + let x = pos.x + nodeW / 2 + 8; + /* flip to the left side of the node when the bubble would cross the right layout edge */ if (x + width > layoutWidth - 10) x = pos.x - nodeW / 2 - width - 8; + const y = Math.max(12, Math.min(layoutHeight - 24, pos.y - nodeH / 2 + index * 24)); + rows.push(` + + + ${esc(item.bubble.label)} + ${esc(item.bubble.title)} + + `); + }); + }); + return rows.join(''); + } + + /** Builds the bubble descriptor (className, label, title) for a ChemGraph reasoning turn. The label has the form CG r, round, space, tool count, t, where the tool count adds science and action tools; the title lists agent, round, status, LM calls, tool counts and span on separate lines. A status of failed (any letter case) selects the error class. */ function chemgraphTurnBubbleInfo(turn) { + const round = turn.round === null || turn.round === undefined ? '-' : turn.round; + const toolCount = (turn.scienceToolCount || 0) + (turn.actionToolCount || 0); + const status = turn.status || 'running'; + const failed = String(status).toLowerCase() === 'failed'; + const label = `CG r${round} ${toolCount}t`; + const title = [ + `Open ChemGraph turn for ${workflowAgentId(turn.representative) || '-'}`, + `Round: ${round}`, + `Status: ${status}`, + `LM calls: ${turn.lmCount || 0}`, + `Science tools: ${turn.scienceToolCount || 0}`, + `Message actions: ${turn.actionToolCount || 0}`, + `Span: ${turn.spanId || '-'}`, + ].join('\n'); + return { + className: failed ? 'bubble-error' : 'bubble-chemgraph', + label, + title, + }; + } + + /** Returns the display name of a tool event: the first truthy value among payload.tool_name, tool, name, result.tool_name, result.name and tool_result.tool_name, falling back to the word tool. */ function toolDisplayName(event) { + const p = event.payload || {}; + return ( + p.tool_name + || p.tool + || p.name + || p.result?.tool_name + || p.result?.name + || p.tool_result?.tool_name + || 'tool' + ); + } + + /** Returns the bubble descriptor (className, label, title) for agent_decision, belief_updated and agent_error events, and null for every other event type. */ function bubbleInfo(event) { + const p = event.payload || {}; + if (event.event === 'agent_decision') { + const actions = Array.isArray(p.actions) ? p.actions.length : 0; + return { + className: 'bubble-decision', + label: actions ? 
`decide ${actions}` : 'decide', + title: `${event.agent_id} decision\n${trunc(p.rationale || p.wake_reason || '', 220)}`, + }; + } + if (event.event === 'belief_updated') { + return { + className: 'bubble-belief', + label: 'belief', + title: `${event.agent_id} belief\n${formatBelief(p)}`, + }; + } + if (event.event === 'agent_error') { + return { + className: 'bubble-error', + label: 'error', + title: `${event.agent_id} error\n${p.error || formatJson(p)}`, + }; + } + return null; + } + + /** Creates or updates the shared graphView rectangle for the main graph at the given layout width and height. The allowed bounds are the layout plus padding; layouts at least 1180 wide get proportional padding (at least 420 horizontally and 220 vertically), narrower ones a fixed 140 and 120. On first use the view covers the whole bounds; afterwards it is only touched when the layout size changed, in which case preserveViewForLayoutChange remaps it and clampGraphView keeps it inside the new bounds. */ function ensureGraphView(width, height) { + const padX = width >= 1180 ? Math.max(420, width * 0.28) : 140; + /* note: the vertical padding also switches on width, not height */ const padY = width >= 1180 ? Math.max(220, height * 0.22) : 120; + const bounds = { + x: -padX, + y: -padY, + width: width + padX * 2, + height: height + padY * 2, + }; + if (!graphView) { + graphView = { + x: bounds.x, + y: bounds.y, + width: bounds.width, + height: bounds.height, + layoutWidth: width, + layoutHeight: height, + boundsX: bounds.x, + boundsY: bounds.y, + boundsWidth: bounds.width, + boundsHeight: bounds.height, + }; + return; + } + if ( + graphView.layoutWidth !== width + || graphView.layoutHeight !== height + ) { + const nextView = preserveViewForLayoutChange( + graphView, + bounds, + width, + height, + ); + graphView = { + ...graphView, + ...nextView, + layoutWidth: width, + layoutHeight: height, + boundsX: bounds.x, + boundsY: bounds.y, + boundsWidth: bounds.width, + boundsHeight: bounds.height, + }; + clampGraphView(); + } + } + + /** Maps an existing view rectangle onto new bounds and returns the new x, y, width and height. The zoom level (old bounds size divided by old view size, per axis) and the position of the view center as a fraction of the old bounds are both preserved. nextLayoutWidth and nextLayoutHeight are only used as fallbacks when the old view carries no bounds or layout size. */ function preserveViewForLayoutChange(view, nextBounds, nextLayoutWidth, nextLayoutHeight) { + const previousBoundsWidth = view.boundsWidth || view.layoutWidth || nextLayoutWidth; + const previousBoundsHeight = view.boundsHeight || view.layoutHeight || nextLayoutHeight; + const zoomX = previousBoundsWidth / Math.max(view.width || previousBoundsWidth, 1); + const zoomY = previousBoundsHeight / Math.max(view.height || previousBoundsHeight, 1); + const centerXRatio = ( + (view.x || 0) + (view.width || previousBoundsWidth) / 2 - 
(view.boundsX || 0) + ) / Math.max(previousBoundsWidth, 1); + const centerYRatio = ( + (view.y || 0) + (view.height || previousBoundsHeight) / 2 - (view.boundsY || 0) + ) / Math.max(previousBoundsHeight, 1); + const width = nextBounds.width / Math.max(zoomX, 1e-6); + const height = nextBounds.height / Math.max(zoomY, 1e-6); + const centerX = nextBounds.x + centerXRatio * nextBounds.width; + const centerY = nextBounds.y + centerYRatio * nextBounds.height; + return { + x: centerX - width / 2, + y: centerY - height / 2, + width, + height, + }; + } + + /** Write the current graphView into the viewBox attribute of the element with id graph, formatting x, y, width and height to one decimal place. Does nothing while graphView is unset. */ function updateGraphViewBox() { + const svg = document.getElementById('graph'); + if (!graphView) return; + svg.setAttribute( + 'viewBox', + `${graphView.x.toFixed(1)} ${graphView.y.toFixed(1)} ${graphView.width.toFixed(1)} ${graphView.height.toFixed(1)}` + ); + } + + /** Clamp graphView in place. Width and height are limited to the range from one tenth of the layout size up to the bounds size; x and y are then limited so the view rectangle stays inside the bounds. Bounds default to origin 0 and the layout size when the bounds fields are null or undefined. No-op while graphView is unset. */ function clampGraphView() { + if (!graphView) return; + const boundsX = graphView.boundsX ?? 0; + const boundsY = graphView.boundsY ?? 0; + const boundsWidth = graphView.boundsWidth ?? graphView.layoutWidth; + const boundsHeight = graphView.boundsHeight ??
graphView.layoutHeight; + graphView.width = Math.min(boundsWidth, Math.max(graphView.layoutWidth / 10, graphView.width)); + graphView.height = Math.min(boundsHeight, Math.max(graphView.layoutHeight / 10, graphView.height)); + graphView.x = Math.min(Math.max(boundsX, graphView.x), boundsX + boundsWidth - graphView.width); + graphView.y = Math.min(Math.max(boundsY, graphView.y), boundsY + boundsHeight - graphView.height); + } + + /** Multiply the view width and height by factor while keeping the view center fixed, then clamp the view and push it to the viewBox. A larger view rectangle shows more of the graph, so factor > 1 zooms out and factor < 1 zooms in. No-op while graphView is unset. */ function zoomGraph(factor) { + if (!graphView) return; + const centerX = graphView.x + graphView.width / 2; + const centerY = graphView.y + graphView.height / 2; + graphView.width *= factor; + graphView.height *= factor; + graphView.x = centerX - graphView.width / 2; + graphView.y = centerY - graphView.height / 2; + clampGraphView(); + updateGraphViewBox(); + } + + /** Discard the stored view and re-render the graph. NOTE(review): presumably renderGraph rebuilds the view through ensureGraphView; confirm in the caller chain. */ function resetGraphView() { + graphView = null; + renderGraph(); + } + + /** Refresh the agentSelect picker. In workflow mode its contents and value are reset and the control is disabled. Otherwise it is enabled, its innerHTML is rebuilt from a leading entry plus one entry per agent returned by agents(), and its value is set to selectedAgent or the empty string. NOTE(review): the option markup strings appear empty in this view; they may have been lost upstream, so verify against the real file. */ function renderAgentPicker() { + const picker = document.getElementById('agentSelect'); + if (isWorkflowMode()) { + picker.innerHTML = ''; + picker.value = ''; + picker.disabled = true; + return; + } + picker.disabled = false; + const options = [''].concat( + agents().map(agent => ``) + ).join(''); + picker.innerHTML = options; + picker.value = selectedAgent || ''; + } + + /** Return the agent record whose agent_id equals selectedAgent, or null when there is no match. */ function selectedState() { + return agents().find(agent => agent.agent_id === selectedAgent) || null; + } + + /** Return the event whose eventKey equals selectedActivityEventKey, or null when no key is selected or no event matches. */ function selectedActivityEvent() { + if (!selectedActivityEventKey) return null; + return allEvents().find(event => eventKey(event) === selectedActivityEventKey) || null; + } + + /** Return a string identifying what the detail pane is showing, in priority order: selected activity event, selected edge, selected agent, then the current timeline event (prefixed workflow-event in workflow mode, timeline-event otherwise) or an empty marker when there is no current event. */ function currentDetailIdentity() { + if (selectedActivityEventKey) return `activity:${selectedActivityEventKey}`; + if (selectedEdgeKey) return `edge:${selectedEdgeKey}`; + if (selectedAgent) return `agent:${selectedAgent}`; + if (isWorkflowMode()) { + const event = currentEvent(); + return event ? `workflow-event:${eventKey(event)}` : 'workflow-empty'; + } + const event = currentEvent(); + return event ?
`timeline-event:${eventKey(event)}` : 'empty'; + } + + /** Capture scrollTop and scrollLeft of the detailPrimary, detailSecondary and detailTertiary elements (those that exist) and return them together with lastRenderedDetailIdentity as {identity, blocks}. */ function captureDetailScrollSnapshot() { + const blockIds = ['detailPrimary', 'detailSecondary', 'detailTertiary']; + const blocks = {}; + blockIds.forEach(id => { + const el = document.getElementById(id); + if (!el) return; + blocks[id] = { + scrollTop: el.scrollTop, + scrollLeft: el.scrollLeft, + }; + }); + return { + identity: lastRenderedDetailIdentity, + blocks, + }; + } + + /** Re-apply the scroll offsets from a snapshot, but only when the snapshot identity equals currentDetailIdentity(), so offsets are not carried over to different content. Elements that no longer exist are skipped and missing offsets default to 0. */ function restoreDetailScrollSnapshot(snapshot) { + if (!snapshot || snapshot.identity !== currentDetailIdentity()) return; + Object.entries(snapshot.blocks || {}).forEach(([id, pos]) => { + const el = document.getElementById(id); + if (!el) return; + el.scrollTop = pos.scrollTop || 0; + el.scrollLeft = pos.scrollLeft || 0; + }); + } + + /** Render the detail pane. Precedence: a selected activity event; then, in workflow mode, the current timeline event (or the empty state); then a selected edge; then the selected agent, falling back to the empty state when no agent matches. NOTE(review): currentDetailIdentity ranks a selected edge or agent above workflow mode, while this function checks workflow mode first; confirm the scroll-restore identity matches what is actually rendered. */ function renderDetail() { + const activityEvent = selectedActivityEvent(); + if (activityEvent) { + renderTimelineEventDetail( + activityEvent, + allEvents().findIndex(event => eventKey(event) === selectedActivityEventKey), + ); + return; + } + if (isWorkflowMode()) { + const event = currentEvent(); + if (event) { + renderTimelineEventDetail(event); + } else { + renderEmptyDetail(); + } + return; + } + if (selectedEdgeKey) { + renderEdgeDetail(selectedEdgeKey); + return; + } + const agent = selectedState(); + if (!agent) { + renderEmptyDetail(); + return; + } + selectedAgent = agent.agent_id; + document.getElementById('agentSelect').value = selectedAgent; + document.getElementById('detailTitle').textContent = agent.agent_id; + document.getElementById('detailCards').innerHTML = detailCards([ + ['Role', agent.role || '-'], + ['Host', agentHost(agent)], + ['Decisions', agent.decision_count || 0], + ['Received / Sent', `${agent.received_message_count || 0} / ${agent.outbox_count || 0}`], + ['Tools', `${agent.tool_finished_count || 0} / ${agent.tool_started_count || 0}`], + ['State', agent.last_error ? 'error' : agent.started ?
'active' : 'pending'], + ]); + const current = currentEvent(); + if (current?.event === 'agent_decision' && current.agent_id === agent.agent_id) { + setDetailBlock('detailPrimaryTitle', 'Current Decision', 'detailPrimary', formatDecisionEvent(current)); + setDetailBlock('detailSecondaryTitle', 'Wake Context', 'detailSecondary', formatWakeEvents(current)); + const turnEvents = workflowEventsForSelection(current); + if (turnEvents.length) { + setDetailHtmlBlock( + 'detailTertiaryTitle', + 'ChemGraph Turn', + 'detailTertiary', + detailRich(detailSection( + 'ChemGraph Panel', + paragraphsHtml('Open in the floating ChemGraph panel. Click inner graph nodes there to inspect LM, tool, action, and output details inside the panel.'), + 'info', + )), + ); + } else { + const received = eventsOf('message_received').filter(e => e.agent_id === agent.agent_id).slice(-4); + setDetailBlock('detailTertiaryTitle', 'Recent Received Messages', 'detailTertiary', received.length + ? received.map(formatMessageEvent).join('\n\n') + : 'No received messages at this point in the timeline.'); + } + return; + } + const beliefEvents = eventsOf('belief_updated').filter(e => e.agent_id === agent.agent_id); + const latestBelief = beliefEvents.length ? beliefEvents[beliefEvents.length - 1].payload : null; + setDetailBlock('detailPrimaryTitle', 'Current Belief', 'detailPrimary', latestBelief + ? formatBelief(latestBelief) + : 'No belief recorded at this point in the timeline.'); + const received = eventsOf('message_received').filter(e => e.agent_id === agent.agent_id).slice(-6); + setDetailHtmlBlock('detailSecondaryTitle', 'Received Messages', 'detailSecondary', received.length + ? messageHistoryHtml(received) + : '
No received messages at this point in the timeline.
'); + const turns = reasoningTurnsForAgent(agent.agent_id); + setDetailHtmlBlock( + 'detailTertiaryTitle', + 'ChemGraph Turn Entries', + 'detailTertiary', + reasoningTurnListHtml(turns) + (turns.length + ? detailRich(detailSection( + 'Open Turn', + paragraphsHtml('Click a ChemGraph turn chip attached to this agent in the graph, or click a turn row above. The inner ChemGraph graph opens in the floating panel.'), + 'info', + )) + : '
No ChemGraph turns for this agent at this point in the timeline.
'), + ); + } + + function renderEdgeDetail(edgeKey) { + const [sender, recipient] = edgeKey.split('->'); + const currentAgents = agents(); + const senderAgent = currentAgents.find(agent => agent.agent_id === sender); + const recipientAgent = currentAgents.find(agent => agent.agent_id === recipient); + const senderHost = agentHost(senderAgent); + const recipientHost = agentHost(recipientAgent); + const messages = eventsOf('message_sent').filter(e => { + const p = e.payload || {}; + return p.sender === sender && p.recipient === recipient; + }); + const latest = messages.length ? messages[messages.length - 1] : null; + const latestPayload = latest?.payload || {}; + const route = senderHost && recipientHost && senderHost !== recipientHost ? 'cross-node' : 'same-node'; + document.getElementById('detailTitle').textContent = `${sender} -> ${recipient}`; + document.getElementById('detailCards').innerHTML = detailCards([ + ['Route', route], + ['Messages', messages.length], + ['From host', senderHost], + ['To host', recipientHost], + ]); + setDetailHtmlBlock('detailPrimaryTitle', 'Latest Message', 'detailPrimary', latest + ? messageDetailHtml(latest) + : '
No message visible at this point in the timeline.
'); + const history = messages.slice(-8); + setDetailHtmlBlock('detailSecondaryTitle', 'Message History', 'detailSecondary', history.length + ? messageHistoryHtml(history) + : '
No messages visible at this point in the timeline.
'); + const messageIds = new Set(messages.map(e => e.payload?.message_id).filter(Boolean)); + const relatedBeliefs = eventsOf('belief_updated').filter(e => { + const refs = e.payload?.supporting_message_ids || []; + return refs.some(ref => messageIds.has(ref)); + }); + setDetailBlock('detailTertiaryTitle', 'Beliefs Citing This Edge', 'detailTertiary', relatedBeliefs.length + ? relatedBeliefs.slice(-6).map(e => `${formatTime(e.timestamp)} ${e.agent_id}\n${formatBelief(e.payload)}`).join('\n\n') + : 'No belief cites this relationship at this point in the timeline.'); + } + + function renderEmptyDetail() { + const event = currentEvent(); + if (event) { + renderTimelineEventDetail(event); + return; + } + document.getElementById('detailTitle').textContent = 'Timeline Event'; + document.getElementById('detailCards').innerHTML = ''; + setDetailBlock('detailPrimaryTitle', 'State', 'detailPrimary', 'No events yet.'); + setDetailBlock('detailSecondaryTitle', 'Evidence', 'detailSecondary', ''); + setDetailBlock('detailTertiaryTitle', 'History', 'detailTertiary', ''); + } + + function renderTimelineEventDetail(event, indexOverride = null) { + const index = indexOverride ?? currentEventIndex(); + const isToolEvent = ['tool_call_started', 'tool_call_finished', 'tool_call_failed', 'chemgraph_job_result'].includes(event.event); + const isNestedWorkflowEvent = isWorkflowEvent(event); + document.getElementById('detailTitle').textContent = isNestedWorkflowEvent + ? chemgraphNodeDetailTitle(event) + : isToolEvent + ? 
`Tool: ${toolDisplayName(event)}` + : `Timeline Event ${index + 1}`; + const cards = [ + ['Event', event.event], + ['Time', formatTime(event.timestamp)], + ['Agent', event.agent_id || '-'], + ['Role', event.role || '-'], + ]; + if (isToolEvent) cards.push(['Tool', toolDisplayName(event)]); + if (isNestedWorkflowEvent) cards.push(['Runtime', event.payload?.runtime || '-']); + document.getElementById('detailCards').innerHTML = detailCards(cards); + if (isNestedWorkflowEvent) { + setDetailHtmlBlock('detailPrimaryTitle', chemgraphNodeDetailTitle(event), 'detailPrimary', chemgraphNodeDetailHtml(event)); + setDetailHtmlBlock('detailSecondaryTitle', 'Node Payload', 'detailSecondary', payloadDetailHtml(event.payload || {})); + setDetailHtmlBlock('detailTertiaryTitle', 'ChemGraph Context', 'detailTertiary', chemgraphNodeContextHtml(event)); + return; + } + if (event.event === 'agent_decision') { + setDetailBlock('detailPrimaryTitle', 'Agent Decision', 'detailPrimary', formatDecisionEvent(event)); + setDetailBlock('detailSecondaryTitle', 'Wake Context', 'detailSecondary', formatWakeEvents(event)); + const turnEvents = workflowEventsForSelection(event); + if (turnEvents.length) { + setDetailHtmlBlock('detailTertiaryTitle', 'ChemGraph Turn', 'detailTertiary', detailRich(detailSection( + 'ChemGraph Panel', + paragraphsHtml('The ChemGraph turn for this decision is shown in the floating panel. 
Click an inner node to inspect it inside the panel.'), + 'info', + ))); + } else { + setDetailBlock('detailTertiaryTitle', 'Raw Action Count', 'detailTertiary', `${event.payload?.actions?.length || 0} action(s) returned by LM.`); + } + return; + } + if (event.event === 'message_sent' || event.event === 'message_received') { + setDetailHtmlBlock('detailPrimaryTitle', 'Message', 'detailPrimary', messageDetailHtml(event)); + setDetailHtmlBlock('detailSecondaryTitle', 'Route', 'detailSecondary', routeDetailHtml(event)); + setDetailHtmlBlock('detailTertiaryTitle', 'Payload', 'detailTertiary', payloadDetailHtml(event.payload || {})); + return; + } + if (event.event === 'belief_updated') { + setDetailBlock('detailPrimaryTitle', 'Belief Update', 'detailPrimary', formatBelief(event.payload || {})); + setDetailBlock('detailSecondaryTitle', 'Supporting Messages', 'detailSecondary', (event.payload?.supporting_message_ids || []).join('\n') || 'No message refs.'); + setDetailBlock('detailTertiaryTitle', 'Supporting Artifacts', 'detailTertiary', (event.payload?.supporting_artifact_ids || []).join('\n') || 'No artifact refs.'); + return; + } + if (['tool_call_started', 'tool_call_finished', 'tool_call_failed', 'chemgraph_job_result'].includes(event.event)) { + setDetailHtmlBlock('detailPrimaryTitle', 'Tool Event', 'detailPrimary', toolDetailHtml(event)); + setDetailHtmlBlock('detailSecondaryTitle', 'Tool Payload', 'detailSecondary', payloadDetailHtml(event.payload || {})); + const nested = nestedWorkflowEventsForTool(event); + setDetailHtmlBlock('detailTertiaryTitle', nested.length ? 'ChemGraph Tool Trace' : 'Correlation', 'detailTertiary', nested.length + ? 
workflowHistoryHtml(nested) + : payloadDetailHtml({correlation_id: event.correlation_id || 'No correlation id.'})); + return; + } + setDetailHtmlBlock('detailPrimaryTitle', 'Event Payload', 'detailPrimary', payloadDetailHtml(event.payload || {})); + setDetailBlock('detailSecondaryTitle', 'Correlation', 'detailSecondary', event.correlation_id || 'No correlation id.'); + setDetailBlock('detailTertiaryTitle', 'Selection', 'detailTertiary', 'Click a node or edge to inspect derived agent or communication state.'); + } + + function detailCards(items) { + return items.map(([label, value]) => ` +
+
${esc(label)}
+
${esc(value)}
+
+ `).join(''); + } + + function setDetailBlock(titleId, title, bodyId, body) { + document.getElementById(titleId).textContent = title; + const el = document.getElementById(bodyId); + const text = body || ''; + if (el.textContent !== text) { + el.textContent = text; + renderedHtmlCache.delete(el); + } + } + + function setDetailHtmlBlock(titleId, title, bodyId, bodyHtml) { + document.getElementById(titleId).textContent = title; + setStableHtml(document.getElementById(bodyId), bodyHtml || '', true); + } + + function setStableHtml(el, html, preserveScroll = true) { + if (!el) return; + const next = html || ''; + if (renderedHtmlCache.get(el) === next) return; + const scrollTop = el.scrollTop; + const scrollLeft = el.scrollLeft; + el.innerHTML = next; + renderedHtmlCache.set(el, next); + if (preserveScroll) { + el.scrollTop = scrollTop; + el.scrollLeft = scrollLeft; + } + } + + function selectActivityEventKey(key) { + if (!key) return; + selectedActivityEventKey = key; + selectedAgent = null; + selectedEdgeKey = null; + const index = allEvents().findIndex(event => eventKey(event) === key); + if (index >= 0) { + followLatest = false; + timelineIndex = index; + const event = allEvents()[index]; + if (workflowEventsForSelection(event).length) { + workflowPanelOpen = true; + } + } + render(); + } + + function handleDetailPaneClick(event) { + const target = event.target.closest('[data-detail-activity-key]'); + if (!target) return; + event.preventDefault(); + event.stopPropagation(); + selectActivityEventKey(target.dataset.detailActivityKey); + } + + function detailRich(...parts) { + return `
${parts.filter(Boolean).join('')}
`; + } + + function detailSection(title, body, tone = '') { + return ` +
+
${esc(title)}
+ ${body || '
None
'} +
+ `; + } + + function detailKvGrid(rows) { + const visibleRows = rows.filter(([_, value]) => !isEmptyDetailValue(value)); + if (!visibleRows.length) return '
None
'; + return ` +
+ ${visibleRows.map(([label, value, kind]) => ` +
${esc(label)}
+
${detailValueHtml(value, kind)}
+ `).join('')} +
+ `; + } + + function isEmptyDetailValue(value) { + return value === undefined + || value === null + || value === '' + || (Array.isArray(value) && value.length === 0); + } + + function detailValueHtml(value, kind = '') { + if (value === undefined || value === null || value === '') return '-'; + if (Array.isArray(value)) { + if (!value.length) return 'none'; + if (value.every(item => ['string', 'number', 'boolean'].includes(typeof item))) { + return detailChips(value, kind); + } + return collapsedJsonHtml(`Array (${value.length})`, value); + } + if (typeof value === 'object') { + return collapsedJsonHtml(`Object (${Object.keys(value).length})`, value); + } + const text = String(value); + if (kind === 'text') return paragraphsHtml(text); + return esc(text); + } + + function detailChips(values, kind = '') { + const list = Array.isArray(values) ? values : [values]; + if (!list.length) return 'none'; + return ` +
+ ${list.map(value => `${esc(String(value))}`).join('')} +
+ `; + } + + function paragraphsHtml(text) { + const value = String(text || '').trim(); + if (!value) return '
None
'; + const paragraphs = value.split(/\n{2,}/).map(part => part.trim()).filter(Boolean); + return `
${paragraphs.map(part => `

${esc(part)}

`).join('')}
`; + } + + function collapsedJsonHtml(summary, value) { + return ` +
+ ${esc(summary)} +
${esc(formatJson(value))}
+
+ `; + } + + function rawJsonDetails(value) { + return ` +
+ Raw JSON +
${esc(formatJson(value))}
+
+ `; + } + + function statusTone(status) { + const value = String(status || '').toLowerCase(); + if (['ok', 'success', 'completed', 'finished'].includes(value)) return 'ok'; + if (['failed', 'failure', 'error'].includes(value)) return 'error'; + if (['running', 'pending', 'submitted'].includes(value)) return 'warn'; + return 'info'; + } + + function messageDetailHtml(event, {includeRaw = true} = {}) { + const p = event.payload || {}; + const sender = p.sender || event.agent_id || '-'; + const recipient = p.recipient || '-'; + const refs = [ + ...(Array.isArray(p.evidence_refs) ? p.evidence_refs : []), + ...(Array.isArray(p.tool_result_ids) ? p.tool_result_ids : []), + ...(Array.isArray(p.supporting_message_ids) ? p.supporting_message_ids : []), + ]; + return detailRich( + detailSection('Route', detailKvGrid([ + ['Direction', `${sender} -> ${recipient}`], + ['Time', formatTime(event.timestamp)], + ['Message id', p.message_id || '-', 'mono'], + ['Event', event.event], + ]), 'info'), + p.tldr ? detailSection('TLDR', paragraphsHtml(p.tldr), 'info') : '', + p.content ? detailSection('Content', paragraphsHtml(p.content)) : '', + p.reason ? detailSection('Reason', paragraphsHtml(p.reason), 'warn') : '', + refs.length ? detailSection('References', detailChips(refs, 'action')) : '', + includeRaw ? rawJsonDetails(p) : '', + ); + } + + function messageHistoryHtml(messages) { + if (!messages.length) return detailRich('
No messages visible.
'); + return detailRich(messages.map(event => { + const p = event.payload || {}; + const title = `${p.sender || event.agent_id || '-'} -> ${p.recipient || '-'}`; + const body = [ + detailKvGrid([ + ['Time', formatTime(event.timestamp)], + ['Message id', p.message_id || '-', 'mono'], + ]), + p.tldr ? paragraphsHtml(p.tldr) : paragraphsHtml(trunc(p.content || '', 360)), + ].join(''); + return detailSection(title, body, ''); + }).join('')); + } + + function routeDetailHtml(event) { + const p = event.payload || {}; + const currentAgents = agents(); + const sender = currentAgents.find(agent => agent.agent_id === p.sender); + const recipient = currentAgents.find(agent => agent.agent_id === p.recipient); + const senderHost = agentHost(sender); + const recipientHost = agentHost(recipient); + const route = senderHost && recipientHost && senderHost !== recipientHost ? 'cross-node' : 'same-node'; + return detailRich( + detailSection('Route', detailKvGrid([ + ['Type', route], + ['Sender host', senderHost], + ['Recipient host', recipientHost], + ['Message id', p.message_id || '-', 'mono'], + ]), route === 'cross-node' ? 'ok' : 'info'), + ); + } + + function issueListHtml(issues) { + if (!Array.isArray(issues) || !issues.length) { + return '
No issues recorded.
'; + } + return issues.map((issue, index) => detailSection( + `${index + 1}. ${issue.field || 'field'}`, + detailKvGrid([ + ['Expected', issue.expected || '-'], + ['Received type', issue.received_type || '-'], + ['Received', issue.received ?? '', 'mono'], + ['Defaulted to', issue.defaulted_to ?? '', 'mono'], + ['Normalized to', issue.normalized_to ?? '', 'mono'], + ['Dropped items', issue.dropped_items ?? ''], + ['Allowed peers', issue.allowed_peers || [], 'action'], + ]), + )).join(''); + } + + function workflowDetailHtml(event) { + const p = event.payload || {}; + const calls = Array.isArray(p.tool_calls) + ? p.tool_calls.map(call => call.name || call.id || 'tool').filter(Boolean) + : []; + return detailRich( + detailSection('Workflow', detailKvGrid([ + ['Status', p.status || event.event], + ['Runtime', p.runtime || '-'], + ['Workflow', p.workflow_type || '-'], + ['Node', p.workflow_node || '-'], + ['Round', p.round ?? '-'], + ['Thread', p.thread_id || '-', 'mono'], + ['Time', formatTime(event.timestamp)], + ]), statusTone(p.status)), + detailSection('Span', detailKvGrid([ + ['Span id', p.span_id || '-', 'mono'], + ['Parent span', p.parent_span_id || '-', 'mono'], + ['Correlation', event.correlation_id || '-', 'mono'], + ['Log dir', p.log_dir || '-', 'mono'], + ])), + calls.length ? detailSection('Tool Calls', detailChips(calls, 'science')) : '', + p.content_preview ? detailSection('Preview', paragraphsHtml(p.content_preview)) : '', + p.error ? 
detailSection('Error', paragraphsHtml(p.error), 'error') : '', + ); + } + + /** Return the detail-pane title for a nested ChemGraph event, chosen from event.event: llm_decision, any tool_call_ event (with the tool display name), workflow_started or run_started (wake input), workflow_finished, workflow_output or run_finished (output); any other event gets a generic title containing the event name. */ function chemgraphNodeDetailTitle(event) { + if (event.event === 'llm_decision') return 'LM Decision Node'; + if (event.event.startsWith('tool_call_')) return `Tool Node: ${toolDisplayName(event)}`; + if (event.event === 'workflow_started' || event.event === 'run_started') return 'Wake Input Node'; + if (event.event === 'workflow_finished' || event.event === 'workflow_output' || event.event === 'run_finished') return 'Output Node'; + return `ChemGraph Node: ${event.event}`; + } + + /** Build the primary detail HTML for a nested ChemGraph event. The branches follow the same event grouping as chemgraphNodeDetailTitle: LM decisions list metadata, token and prompt sections, requested tool calls and their arguments; tool_call_ events delegate to toolDetailHtml; wake-input and output events get their own summaries; anything else falls through to workflowDetailHtml. */ function chemgraphNodeDetailHtml(event) { + const p = event.payload || {}; + if (event.event === 'llm_decision') { + const calls = Array.isArray(p.tool_calls) ? p.tool_calls : []; + return detailRich( + detailSection('LM Decision', detailKvGrid([ + ['Agent', event.agent_id || p.agent_id || '-'], + ['Role', event.role || p.role || '-'], + ['Time', formatTime(event.timestamp)], + ['Message index', p.message_index ?? '-'], + ['Tool calls', calls.length], + ]), 'info'), + tokenDetailSection(event), + promptDetailSection(event), + calls.length ? detailSection('Requested Tool Calls', detailChips(calls.map(call => call.name || call.id || 'tool'), 'science')) : '', + calls.length ? detailSection('Call Arguments', calls.map((call, index) => detailSection( + `${index + 1}. ${call.name || call.id || 'tool'}`, + detailValueHtml(call.args || call.arguments || {}), + )).join('')) : '', + p.content_preview ? detailSection('Response Preview', paragraphsHtml(p.content_preview)) : '', + ); + } + if (event.event.startsWith('tool_call_')) { + return toolDetailHtml(event); + } + if (event.event === 'workflow_started' || event.event === 'run_started') { + return detailRich( + detailSection('Wake Input', detailKvGrid([ + ['Agent', event.agent_id || p.agent_id || '-'], + ['Role', event.role || p.role || '-'], + ['Round', p.round ??
'-'], + ['Thread', p.thread_id || '-', 'mono'], + ['Time', formatTime(event.timestamp)], + ['Tool count', Array.isArray(p.tool_names) ? p.tool_names.length : '-'], + ]), 'info'), + Array.isArray(p.tool_names) && p.tool_names.length + ? detailSection('Available Tools', detailChips(p.tool_names, 'science')) + : '', + p.query ? detailSection('Query', paragraphsHtml(p.query)) : '', + ); + } + if (event.event === 'workflow_finished' || event.event === 'workflow_output' || event.event === 'run_finished') { + return detailRich( + detailSection('Output', detailKvGrid([ + ['Status', p.status || event.event], + ['Agent', event.agent_id || p.agent_id || '-'], + ['Round', p.round ?? '-'], + ['Time', formatTime(event.timestamp)], + ]), statusTone(p.status)), + tokenDetailSection(event), + promptDetailSection(event), + p.content_preview ? detailSection('Preview', paragraphsHtml(p.content_preview)) : '', + p.error ? detailSection('Error', paragraphsHtml(p.error), 'error') : '', + ); + } + return workflowDetailHtml(event); + } + + function chemgraphNodeContextHtml(event) { + const p = event.payload || {}; + const turnEvents = workflowEventsForSelection(event); + const flow = workflowFlowGraph(turnEvents); + return detailRich( + detailSection('Turn Context', detailKvGrid([ + ['Agent', workflowAgentId(event) || '-'], + ['Runtime', p.runtime || '-'], + ['Round', p.round ?? '-'], + ['Thread', p.thread_id || '-', 'mono'], + ['Workflow span', workflowRootSpanId(event) || '-', 'mono'], + ['Selected span', p.span_id || event.correlation_id || '-', 'mono'], + ['Visible nodes', flow.nodes.length], + ['Visible events', turnEvents.length], + ]), 'info'), + ); + } + + function workflowHistoryHtml(events) { + if (!events.length) return detailRich('
No related workflow events visible.
'); + return detailRich(events.map(event => { + const p = event.payload || {}; + return detailSection(event.event, detailKvGrid([ + ['Time', formatTime(event.timestamp)], + ['Status', p.status || '-'], + ['Round', p.round ?? '-'], + ['Tool', p.tool_name || '-'], + ['Span', p.span_id || event.correlation_id || '-', 'mono'], + ]), statusTone(p.status)); + }).join('')); + } + + function toolDetailHtml(event) { + const p = event.payload || {}; + const status = p.status || p.result?.status || event.event; + return detailRich( + detailSection('Tool', detailKvGrid([ + ['Tool name', toolDisplayName(event)], + ['Status', status], + ['Event', event.event], + ['Time', formatTime(event.timestamp)], + ['Agent', event.agent_id || '-'], + ['Call id', p.tool_result_id || p.correlation_id || event.correlation_id || '-', 'mono'], + ]), statusTone(status)), + p.arguments ? detailSection('Arguments', detailValueHtml(p.arguments)) : '', + p.error ? detailSection('Error', paragraphsHtml(p.error), 'error') : '', + p.result ? 
detailSection('Result', detailValueHtml(p.result)) : '', + ); + } + + /** Render an event payload as rich detail HTML. A payload that is falsy, not an object, or an array is shown as a single Value section plus raw JSON. For plain objects, the keys named in the priority list are shown first, in that fixed order, under Key Fields (toned by payload.status); the remaining keys are shown alphabetically under Additional Fields; the raw JSON is always appended. */ function payloadDetailHtml(payload) { + if (!payload || typeof payload !== 'object' || Array.isArray(payload)) { + return detailRich(detailSection('Value', detailValueHtml(payload)), rawJsonDetails(payload)); + } + const priority = [ + 'status', + 'tool_name', + 'tool_result_id', + 'message_id', + 'sender', + 'recipient', + 'tldr', + 'content_preview', + 'error', + 'reason', + 'runtime', + 'workflow_type', + 'workflow_node', + 'round', + 'thread_id', + 'span_id', + 'parent_span_id', + 'correlation_id', + 'log_dir', + 'run_dir', + 'model_name', + ]; + const prioritySet = new Set(priority); + const priorityRows = priority + .filter(key => Object.prototype.hasOwnProperty.call(payload, key)) + .map(key => [fieldLabel(key), payload[key], fieldKind(key)]); + const otherRows = Object.keys(payload) + .filter(key => !prioritySet.has(key)) + .sort() + .map(key => [fieldLabel(key), payload[key], fieldKind(key)]); + return detailRich( + priorityRows.length ? detailSection('Key Fields', detailKvGrid(priorityRows), statusTone(payload.status)) : '', + otherRows.length ? detailSection('Additional Fields', detailKvGrid(otherRows)) : '', + rawJsonDetails(payload), + ); + } + + /** Turn a payload key into a display label by replacing every underscore with a space; a falsy key yields an empty string. */ function fieldLabel(key) { + return String(key || '').replaceAll('_', ' '); + } + + /** Choose the rendering kind for a payload key: mono when an underscore-delimited segment of the key is id, dir, path, file, span, thread or correlation; text for the listed free-text keys; otherwise the empty string. */ function fieldKind(key) { + if (/(^|_)(id|dir|path|file|span|thread|correlation)($|_)/.test(String(key))) return 'mono'; + if (['content', 'content_preview', 'reason', 'error', 'tldr'].includes(String(key))) return 'text'; + return ''; + } + + /** Format a belief payload as newline-joined text: hypothesis and confidence are always present (a dash when missing), followed by the reason, supporting message ids and supporting artifact ids when they are set and non-empty. */ function formatBelief(payload) { + const lines = [ + `Hypothesis: ${payload.hypothesis || '-'}`, + `Confidence: ${payload.confidence ??
'-'}`, + ]; + if (payload.reason) lines.push(`Reason: ${payload.reason}`); + if (payload.supporting_message_ids?.length) lines.push(`Messages: ${payload.supporting_message_ids.join(', ')}`); + if (payload.supporting_artifact_ids?.length) lines.push(`Artifacts: ${payload.supporting_artifact_ids.join(', ')}`); + return lines.join('\n'); + } + + /** Format a message event as plain text: a header line with time, sender (payload.sender, else event.agent_id) and recipient, then the TLDR, content, evidence references and reason when present. */ function formatMessageEvent(event) { + const p = event.payload || {}; + const lines = [ + `${formatTime(event.timestamp)} ${p.sender || event.agent_id || '-'} -> ${p.recipient || '-'}`, + ]; + if (p.tldr) lines.push(`TLDR: ${p.tldr}`); + if (p.content) lines.push(p.content); + if (p.evidence_refs?.length) lines.push(`Evidence: ${p.evidence_refs.join(', ')}`); + if (p.reason) lines.push(`Reason: ${p.reason}`); + return lines.filter(Boolean).join('\n'); + } + + /** Format the route of a message event as plain text. Sender and recipient are looked up by agent_id in agents() and their hosts obtained with agentHost; the route is cross-node only when both hosts are truthy and differ, otherwise same-node. */ function formatMessageRoute(event) { + const p = event.payload || {}; + const currentAgents = agents(); + const sender = currentAgents.find(agent => agent.agent_id === p.sender); + const recipient = currentAgents.find(agent => agent.agent_id === p.recipient); + const senderHost = agentHost(sender); + const recipientHost = agentHost(recipient); + const route = senderHost && recipientHost && senderHost !== recipientHost ? 'cross-node' : 'same-node'; + return [ + `Route: ${route}`, + `Sender host: ${senderHost}`, + `Recipient host: ${recipientHost}`, + `Message id: ${p.message_id || '-'}`, + ].join('\n'); + } + + /** Format an agent_decision event as plain text: a header line, optional mode, wake reason and rationale, then one formatted entry per action (or an explicit none marker) and, when present, the number of ignored actions. */ function formatDecisionEvent(event) { + const p = event.payload || {}; + const lines = [ + `${formatTime(event.timestamp)} ${event.agent_id || '-'} decision`, + ]; + if (p.mode) lines.push(`Mode: ${p.mode}`); + if (p.wake_reason) lines.push(`Wake reason: ${p.wake_reason}`); + if (p.rationale) lines.push(`\nRationale:\n${p.rationale}`); + const actions = Array.isArray(p.actions) ?
p.actions : []; + if (actions.length) { + lines.push('\nActions:'); + actions.forEach((action, index) => { + lines.push(formatDecisionToolCall(action, index + 1)); + }); + } else { + lines.push('\nActions: none'); + } + const ignored = Array.isArray(p.ignored_actions) ? p.ignored_actions : []; + if (ignored.length) { + lines.push(`\nIgnored actions: ${ignored.length}`); + } + return lines.join('\n'); + } + + /** Format one decision action as text. The first line is the given number and the action name (unknown_action when missing) followed by recipient, tool and confidence when set, joined with a pipe separator. Each populated detail field (question, content, hypothesis, reason, evidence, messages, artifacts, non-empty arguments) is appended on its own line. */ function formatDecisionToolCall(action, number) { + const parts = [`${number}. ${action.action || 'unknown_action'}`]; + if (action.recipient) parts.push(`recipient=${action.recipient}`); + if (action.tool_name) parts.push(`tool=${action.tool_name}`); + if (action.confidence !== null && action.confidence !== undefined) parts.push(`confidence=${action.confidence}`); + let text = parts.join(' | '); + if (action.question) text += `\n Question: ${action.question}`; + if (action.content) text += `\n Content: ${action.content}`; + if (action.hypothesis) text += `\n Hypothesis: ${action.hypothesis}`; + if (action.reason) text += `\n Reason: ${action.reason}`; + if (action.evidence_refs?.length) text += `\n Evidence: ${action.evidence_refs.join(', ')}`; + if (action.supporting_message_ids?.length) text += `\n Messages: ${action.supporting_message_ids.join(', ')}`; + if (action.supporting_artifact_ids?.length) text += `\n Artifacts: ${action.supporting_artifact_ids.join(', ')}`; + if (action.arguments && Object.keys(action.arguments).length) { + text += `\n Arguments: ${formatJson(action.arguments)}`; + } + return text; + } + + /** Format payload.wake_events of a decision event as numbered plain-text entries separated by blank lines. Returns a fixed notice when the list is missing, not an array, or empty. Each entry shows time, event name and agent id when set, plus message id, route, tool, content (truncated with trunc to 500) and result when present. */ function formatWakeEvents(event) { + const wakeEvents = event.payload?.wake_events || []; + if (!Array.isArray(wakeEvents) || !wakeEvents.length) { + return 'No wake events recorded for this decision.'; + } + return wakeEvents.map((wake, index) => { + const payload = wake.payload || {}; + const summary = [ + `${index + 1}. ${formatTime(wake.timestamp)} ${wake.event}${wake.agent_id ?
` · ${wake.agent_id}` : ''}`, + ]; + if (payload.message_id) summary.push(` Message: ${payload.message_id}`); + if (payload.sender || payload.recipient) summary.push(` Route: ${payload.sender || '-'} -> ${payload.recipient || '-'}`); + if (payload.tool_name) summary.push(` Tool: ${payload.tool_name}`); + if (payload.content) summary.push(` Content: ${trunc(payload.content, 500)}`); + if (payload.result) summary.push(` Result: ${formatJson(payload.result)}`); + return summary.join('\n'); + }).join('\n\n'); + } + + /** Format a tool event as plain text: time and event name, tool display name, status (payload.status, else result.status, else the event name), the correlation id when set, a per-item results summary when payload.results or result.results is an array, and the error when present. */ function formatToolEvent(event) { + const p = event.payload || {}; + const status = p.status || p.result?.status || event.event; + const toolName = toolDisplayName(event); + const lines = [ + `${formatTime(event.timestamp)} ${event.event}`, + `Tool: ${toolName}`, + `Status: ${status}`, + ]; + if (event.correlation_id) lines.push(`Call: ${event.correlation_id}`); + const results = p.results || p.result?.results; + if (Array.isArray(results)) { + lines.push(`Results: ${results.map(item => `${item.index ??
'-'}:${item.status || item.error_type || '-'}`).join(', ')}`); + } + if (p.error) lines.push(`Error: ${p.error}`); + return lines.join('\n'); + } + + /** Return true when the event name is one of the listed run, workflow, LM-decision or tool-call names AND its payload has a truthy nested or runtime field; the payload check is what separates nested workflow events from top-level events that share a name. */ function isWorkflowEvent(event) { + return [ + 'run_started', + 'run_finished', + 'workflow_started', + 'workflow_node_started', + 'workflow_node_finished', + 'workflow_output', + 'workflow_finished', + 'llm_decision', + 'tool_call_started', + 'tool_call_finished', + 'tool_call_failed', + ].includes(event.event) && Boolean(event.payload?.nested || event.payload?.runtime); + } + + /** Return the span id of an event: payload.span_id, else event.correlation_id, else null. */ function workflowSpanId(event) { + const p = event.payload || {}; + return p.span_id || event.correlation_id || null; + } + + /** Return payload.parent_span_id, or null when it is missing. */ function workflowParentSpanId(event) { + const p = event.payload || {}; + return p.parent_span_id || null; + } + + /** Return the first truthy of event.agent_id, payload.agent_id, payload.agent_name and payload.agent, or null. */ function workflowAgentId(event) { + const p = event.payload || {}; + return event.agent_id || p.agent_id || p.agent_name || p.agent || null; + } + + /** Resolve the span id of the workflow an event belongs to. An explicit payload.workflow_span_id wins; workflow_started and workflow_finished events use their own span_id; otherwise the first truthy of parent_span_id, span_id and correlation_id is used, or null. */ function workflowRootSpanId(event) { + const p = event.payload || {}; + if (p.workflow_span_id) return p.workflow_span_id; + if ((event.event === 'workflow_started' || event.event === 'workflow_finished') && p.span_id) { + return p.span_id; + } + return p.parent_span_id || p.span_id || event.correlation_id || null; + } + + /** Return the visible workflow events whose root span id equals spanId; an empty array for a falsy spanId. */ function workflowEventsForSpan(spanId) { + if (!spanId) return []; + return visibleEvents().filter(candidate => ( + isWorkflowEvent(candidate) && workflowRootSpanId(candidate) === spanId + )); + } + + /** Return the workflow events related to a selected event. When a workflow span can be resolved (payload.workflow_span_id, or the root span of a workflow event) the events of that span are returned; otherwise visible workflow events are matched by agent id and round. Returns an empty array for a missing event or when neither a span nor an agent and round pair is available. */ function workflowEventsForSelection(event) { + if (!event) return []; + const p = event.payload || {}; + const spanId = p.workflow_span_id || (isWorkflowEvent(event) ?
workflowRootSpanId(event) : null); + if (spanId) return workflowEventsForSpan(spanId); + const agentId = event.agent_id || p.agent_id || p.agent_name; + const round = p.round; + if (!agentId || round === undefined || round === null) return []; + return visibleEvents().filter(candidate => ( + isWorkflowEvent(candidate) + && workflowAgentId(candidate) === agentId + && candidate.payload?.round === round + )); + } + + function reasoningTurnsForAgent(agentId) { + if (!agentId) return []; + const bySpan = new Map(); + visibleEvents().forEach(event => { + if (!isWorkflowEvent(event) || workflowAgentId(event) !== agentId) return; + const spanId = workflowRootSpanId(event); + if (!spanId) return; + if (!bySpan.has(spanId)) { + bySpan.set(spanId, { + spanId, + events: [], + round: null, + threadId: '', + started: null, + finished: null, + firstTimestamp: null, + lastTimestamp: null, + }); + } + const turn = bySpan.get(spanId); + turn.events.push(event); + const p = event.payload || {}; + if (p.round !== undefined && p.round !== null) turn.round = p.round; + if (p.thread_id) turn.threadId = p.thread_id; + if (event.event === 'workflow_started') turn.started = event; + if (event.event === 'workflow_finished') turn.finished = event; + const ts = eventTimestamp(event); + if (ts !== null) { + turn.firstTimestamp = turn.firstTimestamp === null ? ts : Math.min(turn.firstTimestamp, ts); + turn.lastTimestamp = turn.lastTimestamp === null ? 
ts : Math.max(turn.lastTimestamp, ts); + } + }); + return Array.from(bySpan.values()) + .map(turn => { + const toolNodes = workflowFlowGraph(turn.events).nodes.filter(node => node.type === 'tool'); + const actionTools = toolNodes.filter(node => node.toolClass === 'action-tool'); + const scienceTools = toolNodes.filter(node => node.toolClass === 'science-tool'); + return { + ...turn, + status: turn.finished?.payload?.status || 'running', + representative: turn.started || turn.events[0], + lmCount: workflowTokenEvents(turn.events).length || turn.events.filter(event => event.event === 'llm_decision').length, + tokenTotals: summedTokenCounts(turn.events), + actionToolCount: actionTools.length, + scienceToolCount: scienceTools.length, + }; + }) + .sort((a, b) => (a.firstTimestamp ?? 0) - (b.firstTimestamp ?? 0)); + } + + function reasoningTurnListHtml(turns) { + if (!turns.length) return ''; + const rows = turns.slice(-6).reverse().map(turn => { + const key = eventKey(turn.representative); + const round = turn.round === null || turn.round === undefined ? '-' : turn.round; + const title = `Round ${round}`; + const meta = [ + `${turn.status}`, + `${turn.lmCount} LM`, + turn.tokenTotals ? `${formatTokenCount(turn.tokenTotals.total)} tok` : '', + `${turn.actionToolCount} action`, + `${turn.scienceToolCount} science`, + formatTime(turn.firstTimestamp), + ].filter(Boolean).join(' · '); + return ` + + `; + }).join(''); + return `
${rows}
`; + } + + function nestedWorkflowEventsForTool(event) { + const p = event.payload || {}; + const callId = p.tool_result_id || p.correlation_id || event.correlation_id; + if (!callId) return []; + return allEvents().filter(candidate => { + if (!isWorkflowEvent(candidate)) return false; + return workflowParentSpanId(candidate) === callId || candidate.payload?.parent_tool_name === toolDisplayName(event); + }); + } + + function relatedWorkflowEvents(event) { + const spanId = workflowSpanId(event); + const parentSpanId = workflowParentSpanId(event); + if (!spanId && !parentSpanId) return []; + return allEvents().filter(candidate => { + if (!isWorkflowEvent(candidate)) return false; + const candidateSpan = workflowSpanId(candidate); + const candidateParent = workflowParentSpanId(candidate); + return candidateSpan === spanId + || candidateParent === spanId + || (parentSpanId && (candidateSpan === parentSpanId || candidateParent === parentSpanId)); + }); + } + + function formatWorkflowEvent(event) { + const p = event.payload || {}; + const lines = [ + `${formatTime(event.timestamp)} ${event.event}`, + `Runtime: ${p.runtime || '-'}`, + `Span: ${p.span_id || '-'}`, + ]; + if (p.parent_span_id) lines.push(`Parent span: ${p.parent_span_id}`); + if (p.workflow_type) lines.push(`Workflow: ${p.workflow_type}`); + if (p.workflow_node) lines.push(`Node: ${p.workflow_node}`); + if (p.phase) lines.push(`Phase: ${p.phase}`); + if (p.status) lines.push(`Status: ${p.status}`); + if (p.model_name) lines.push(`Model: ${p.model_name}`); + if (p.log_dir) lines.push(`Log dir: ${p.log_dir}`); + if (Array.isArray(p.tool_calls) && p.tool_calls.length) { + lines.push(`Tool calls: ${p.tool_calls.map(call => call.name || call.id || 'tool').join(', ')}`); + } + if (p.tool_name) lines.push(`Tool: ${p.tool_name}`); + if (p.content_preview) lines.push(`Preview: ${p.content_preview}`); + if (p.error) lines.push(`Error: ${p.error}`); + return lines.join('\n'); + } + + function 
formatWorkflowHistory(event) { + const history = relatedWorkflowEvents(event); + return history.length + ? history.map(formatWorkflowEvent).join('\n\n') + : 'No related workflow events visible.'; + } + + function formatJson(value) { + try { + return JSON.stringify(value, null, 2); + } catch { + return String(value); + } + } + + function toggleReplay() { + if (isReplaying) { + stopReplay(true); + } else { + startReplay(); + } + } + + function startReplay() { + const events = allEvents(); + if (!events.length) return; + followLatest = false; + isReplaying = true; + graphMode = 'current'; + selectedEdgeKey = null; + replayStartIndex = Math.max(0, currentEventIndex()); + replayStartTimestamp = eventTimestamp(events[replayStartIndex]) ?? firstTimestamp(); + replayStartedAtMs = Date.now(); + if (replayTimer) window.clearInterval(replayTimer); + replayTimer = window.setInterval(tickReplay, 120); + tickReplay(); + } + + function stopReplay(renderNow = false) { + isReplaying = false; + if (replayTimer) { + window.clearInterval(replayTimer); + replayTimer = null; + } + if (renderNow) render(); + } + + function tickReplay() { + const events = allEvents(); + if (!events.length) { + stopReplay(true); + return; + } + if (replayStartTimestamp === null) { + timelineIndex = Math.min(events.length - 1, currentEventIndex() + 1); + } else { + const elapsedSeconds = ((Date.now() - replayStartedAtMs) / 1000) * replaySpeed(); + timelineIndex = Math.max( + replayStartIndex, + eventIndexAtTimestamp(replayStartTimestamp + elapsedSeconds) + ); + } + render(); + if (timelineIndex >= events.length - 1) stopReplay(false); + } + + function startGraphPan(event) { + if (!graphView || event.button !== 0) return; + if (event.target.closest('.agent-node') || event.target.closest('[data-edge-key]') || event.target.closest('.activity-bubble')) return; + graphPanDrag = { + clientX: event.clientX, + clientY: event.clientY, + startX: graphView.x, + startY: graphView.y, + }; + 
document.getElementById('graph').classList.add('panning'); + event.preventDefault(); + } + + function moveGraphPan(event) { + if (!graphPanDrag || !graphView) return; + const svg = document.getElementById('graph'); + const rect = svg.getBoundingClientRect(); + const dx = (event.clientX - graphPanDrag.clientX) * graphView.width / Math.max(rect.width, 1); + const dy = (event.clientY - graphPanDrag.clientY) * graphView.height / Math.max(rect.height, 1); + graphView.x = graphPanDrag.startX - dx; + graphView.y = graphPanDrag.startY - dy; + clampGraphView(); + updateGraphViewBox(); + } + + function stopGraphPan() { + graphPanDrag = null; + document.getElementById('graph').classList.remove('panning'); + } + + function startEmbeddedWorkflowPan(event) { + if (!embeddedWorkflowView || event.button !== 0) return; + if (event.target.closest('.workflow-node')) return; + embeddedWorkflowPanDrag = { + clientX: event.clientX, + clientY: event.clientY, + startX: embeddedWorkflowView.x, + startY: embeddedWorkflowView.y, + }; + document.getElementById('embeddedWorkflowGraph').classList.add('panning'); + event.preventDefault(); + } + + function moveEmbeddedWorkflowPan(event) { + if (!embeddedWorkflowPanDrag || !embeddedWorkflowView) return; + const svg = document.getElementById('embeddedWorkflowGraph'); + const rect = svg.getBoundingClientRect(); + const dx = (event.clientX - embeddedWorkflowPanDrag.clientX) * embeddedWorkflowView.width / Math.max(rect.width, 1); + const dy = (event.clientY - embeddedWorkflowPanDrag.clientY) * embeddedWorkflowView.height / Math.max(rect.height, 1); + embeddedWorkflowView.x = embeddedWorkflowPanDrag.startX - dx; + embeddedWorkflowView.y = embeddedWorkflowPanDrag.startY - dy; + clampEmbeddedWorkflowView(); + updateEmbeddedWorkflowViewBox(); + } + + function stopEmbeddedWorkflowPan() { + embeddedWorkflowPanDrag = null; + document.getElementById('embeddedWorkflowGraph').classList.remove('panning'); + } + + function startWorkflowPanelDrag(event) { + if 
(event.button !== 0 || event.target.closest('button')) return; + workflowPanelDrag = { + clientX: event.clientX, + clientY: event.clientY, + startX: workflowPanelFrame.x, + startY: workflowPanelFrame.y, + }; + event.preventDefault(); + } + + function moveWorkflowPanelDrag(event) { + if (!workflowPanelDrag) return; + workflowPanelFrame.x = workflowPanelDrag.startX + event.clientX - workflowPanelDrag.clientX; + workflowPanelFrame.y = workflowPanelDrag.startY + event.clientY - workflowPanelDrag.clientY; + applyWorkflowPanelFrame(); + } + + function stopWorkflowPanelDrag() { + workflowPanelDrag = null; + } + + function startWorkflowPanelResize(event) { + if (event.button !== 0) return; + workflowPanelResizeDrag = { + clientX: event.clientX, + clientY: event.clientY, + startWidth: workflowPanelFrame.width, + startHeight: workflowPanelFrame.height, + }; + event.preventDefault(); + event.stopPropagation(); + } + + function moveWorkflowPanelResize(event) { + if (!workflowPanelResizeDrag) return; + workflowPanelFrame.width = workflowPanelResizeDrag.startWidth + event.clientX - workflowPanelResizeDrag.clientX; + workflowPanelFrame.height = workflowPanelResizeDrag.startHeight + event.clientY - workflowPanelResizeDrag.clientY; + applyWorkflowPanelFrame(); + } + + function stopWorkflowPanelResize() { + workflowPanelResizeDrag = null; + } + + function startDetailResize(event) { + if (event.button !== 0) return; + const currentWidth = parseFloat(getComputedStyle(document.documentElement).getPropertyValue('--detail-width')) || 430; + detailResizeDrag = { + clientX: event.clientX, + startWidth: currentWidth, + }; + document.getElementById('detailResizer').classList.add('active'); + event.preventDefault(); + } + + function moveDetailResize(event) { + if (!detailResizeDrag) return; + const width = Math.min(760, Math.max(320, detailResizeDrag.startWidth - (event.clientX - detailResizeDrag.clientX))); + document.documentElement.style.setProperty('--detail-width', `${width}px`); + } + + 
function stopDetailResize() { + detailResizeDrag = null; + document.getElementById('detailResizer').classList.remove('active'); + } + + document.getElementById('refresh').addEventListener('click', load); + document.getElementById('playReplay').addEventListener('click', toggleReplay); + document.getElementById('timeSlider').addEventListener('input', e => { + stopReplay(false); + followLatest = false; + timelineIndex = Number(e.target.value); + render(); + }); + document.getElementById('latest').addEventListener('click', () => { + stopReplay(false); + followLatest = true; + timelineIndex = allEvents().length - 1; + render(); + }); + document.getElementById('eventHold').addEventListener('change', render); + document.getElementById('zoomIn').addEventListener('click', () => zoomGraph(0.86)); + document.getElementById('zoomOut').addEventListener('click', () => zoomGraph(1.16)); + document.getElementById('zoomReset').addEventListener('click', resetGraphView); + document.getElementById('graph').addEventListener('mousedown', startGraphPan); + document.getElementById('graph').addEventListener('wheel', event => { + if (!graphView) return; + event.preventDefault(); + zoomGraph(event.deltaY < 0 ? 0.94 : 1.06); + }, {passive: false}); + document.getElementById('embeddedWorkflowGraph').addEventListener('mousedown', startEmbeddedWorkflowPan); + document.getElementById('embeddedWorkflowGraph').addEventListener('wheel', event => { + if (!embeddedWorkflowView) return; + event.preventDefault(); + zoomEmbeddedWorkflow(event.deltaY < 0 ? 
0.94 : 1.06); + }, {passive: false}); + document.getElementById('workflowZoomIn').addEventListener('click', () => zoomEmbeddedWorkflow(0.86)); + document.getElementById('workflowZoomOut').addEventListener('click', () => zoomEmbeddedWorkflow(1.16)); + document.getElementById('workflowZoomReset').addEventListener('click', resetEmbeddedWorkflowView); + document.getElementById('workflowPanelClose').addEventListener('click', () => { + workflowPanelOpen = false; + renderEmbeddedWorkflowPanel(); + }); + document.getElementById('workflowFloatingTab').addEventListener('click', () => { + workflowPanelOpen = true; + renderEmbeddedWorkflowPanel(); + }); + document.getElementById('workflowFloatingHead').addEventListener('mousedown', startWorkflowPanelDrag); + document.getElementById('workflowPanelResize').addEventListener('mousedown', startWorkflowPanelResize); + document.getElementById('detailResizer').addEventListener('mousedown', startDetailResize); + window.addEventListener('mousemove', moveGraphPan); + window.addEventListener('mousemove', moveEmbeddedWorkflowPan); + window.addEventListener('mousemove', moveWorkflowPanelDrag); + window.addEventListener('mousemove', moveWorkflowPanelResize); + window.addEventListener('mousemove', moveDetailResize); + window.addEventListener('mouseup', stopGraphPan); + window.addEventListener('mouseup', stopEmbeddedWorkflowPan); + window.addEventListener('mouseup', stopWorkflowPanelDrag); + window.addEventListener('mouseup', stopWorkflowPanelResize); + window.addEventListener('mouseup', stopDetailResize); + document.getElementById('detailPrimary').addEventListener('click', handleDetailPaneClick); + document.getElementById('detailSecondary').addEventListener('click', handleDetailPaneClick); + document.getElementById('detailTertiary').addEventListener('click', handleDetailPaneClick); + document.querySelectorAll('#graphMode button').forEach(button => { + button.addEventListener('click', () => { + graphMode = button.dataset.mode; + 
selectedEdgeKey = null; + selectedActivityEventKey = null; + render(); + }); + }); + document.getElementById('agentSelect').addEventListener('change', e => { + selectedAgent = e.target.value || null; + selectedEdgeKey = null; + selectedActivityEventKey = null; + render(); + }); + load(); + setInterval(load, 2000); diff --git a/src/chemgraph/academy/dashboard/static/index.html b/src/chemgraph/academy/dashboard/static/index.html new file mode 100644 index 00000000..f26c106f --- /dev/null +++ b/src/chemgraph/academy/dashboard/static/index.html @@ -0,0 +1,703 @@ + + + + + + ChemGraph Academy Dashboard + + + +
+

ChemGraph Academy Dashboard

+
+ + +
+
+
+
+
+
+
Run State
+
+
+
+
+
+
+
+
+
Agent Graph
+
+
+
+
+ + + + + +
+
+ + + +
+
+ Graph + + + +
+
+
+ + +
+
+
+
+ +
+
+
+
+
+
+
+
+
Selection
+ +
+
+
+

State

+
+

Evidence

+
+

History

+
+
+
+
+
+ + + + + diff --git a/src/chemgraph/academy/observability/run_artifacts.py b/src/chemgraph/academy/observability/run_artifacts.py new file mode 100644 index 00000000..2af388c0 --- /dev/null +++ b/src/chemgraph/academy/observability/run_artifacts.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +import asyncio +import json +import pathlib +import shutil +import time +from collections import Counter +from typing import Any + +from chemgraph.academy.observability.communication_proof import ( + build_communication_proof, +) +from chemgraph.academy.observability.event_log import CampaignEvent +from chemgraph.academy.observability.event_log import read_events +from chemgraph.academy.observability.run_files import append_jsonl +from chemgraph.academy.observability.run_files import write_json +from chemgraph.academy.observability.run_files import write_json_atomic +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.runtime.mpi import append_system_trace +from chemgraph.academy.core.lm import LLMSettings + + +def write_run_artifacts(run_dir: str | pathlib.Path) -> dict[str, Any]: + """Write placement, communication proof, and summary artifacts.""" + root = pathlib.Path(run_dir) + events = read_events(root / "events.jsonl") + placement = build_placement(events, root / "status.json") + proof = build_communication_proof(events, placement) + summary = summarize_events(events) + + write_json(root / "placement.json", placement) + write_json(root / "communication_proof.json", proof) + write_json(root / "summary.json", summary) + return { + "placement": placement, + "communication_proof": proof, + "summary": summary, + } + + +def build_placement( + events: list[CampaignEvent], + status_path: str | pathlib.Path | None = None, +) -> dict[str, Any]: + """Build agent placement proof from events and latest 
status.""" + agents: dict[str, dict[str, Any]] = {} + for event in events: + if event.event != "agent_started" or not event.agent_id: + continue + placement = event.payload.get("placement") + if isinstance(placement, dict): + agents[event.agent_id] = { + "agent_id": event.agent_id, + "role": event.role, + **placement, + } + + if status_path is not None: + path = pathlib.Path(status_path) + if path.exists(): + try: + status = json.loads(path.read_text(encoding="utf-8")) + states = status.get("agent_states", {}) + if isinstance(states, dict): + for agent_id, state in states.items(): + if not isinstance(state, dict): + continue + placement = state.get("placement") + if isinstance(placement, dict): + agents.setdefault( + agent_id, + { + "agent_id": agent_id, + "role": state.get("role"), + **placement, + }, + ) + except json.JSONDecodeError: + pass + + hostnames = sorted( + { + str(record.get("hostname")) + for record in agents.values() + if record.get("hostname") + }, + ) + return { + "agent_count": len(agents), + "hostnames": hostnames, + "distinct_hostname_count": len(hostnames), + "agents": dict(sorted(agents.items())), + } + + +def summarize_events(events: list[CampaignEvent]) -> dict[str, Any]: + """Return compact run summary from campaign events.""" + counts = Counter(event.event for event in events) + final_reports = _final_reports(events) + return { + "event_count": len(events), + "event_counts": dict(sorted(counts.items())), + "finish": _last_payload( + events, + {"campaign_finished", "workflow_finished", "run_finished"}, + ), + "agent_errors": _payloads_of(events, "agent_error"), + "message_count": counts.get("message_sent", 0), + "final_reports": final_reports, + "tool_results": _tool_result_summaries(events), + } + + +def _last_payload( + events: list[CampaignEvent], + kinds: set[str], +) -> dict[str, Any] | None: + payloads = [event.payload for event in events if event.event in kinds] + return payloads[-1] if payloads else None + + +def 
_payloads_of(events: list[CampaignEvent], kind: str) -> list[dict[str, Any]]: + return [ + { + "agent_id": event.agent_id, + "role": event.role, + **event.payload, + } + for event in events + if event.event == kind + ] + + +def _final_reports(events: list[CampaignEvent]) -> list[dict[str, Any]]: + reports = [] + for event in events: + payload = event.payload + if event.event == "belief_updated": + reports.append( + { + "agent_id": event.agent_id, + "summary": payload.get("summary") or payload.get("hypothesis"), + "confidence": payload.get("confidence"), + "supporting_message_ids": payload.get("supporting_message_ids", []), + "supporting_tool_result_ids": payload.get( + "supporting_tool_result_ids", + [], + ), + }, + ) + return reports[-10:] + + +def _tool_result_summaries(events: list[CampaignEvent]) -> list[dict[str, Any]]: + results = [] + for event in events: + if event.event != "tool_call_finished": + continue + payload = event.payload + results.append( + { + "timestamp": event.timestamp, + "agent_id": event.agent_id, + "tool_name": payload.get("tool_name"), + "tool_result_id": payload.get("tool_result_id"), + "status": payload.get("status"), + "content_preview": payload.get("content_preview"), + }, + ) + return results + + +def default_agent_state(spec: ChemGraphAgentSpec) -> dict[str, Any]: + return { + 'agent_name': spec.name, + 'role': spec.role, + 'status_updated_at': None, + 'round': 0, + 'finished': False, + 'last_error': None, + 'current_activity': None, + 'received_message_count': 0, + 'outbox_count': 0, + 'recent_received_messages': [], + 'recent_outbox': [], + 'tool_names': list(spec.tool_names), + 'tool_result_count': 0, + 'recent_tool_results': [], + 'belief': { + 'hypothesis': None, + 'confidence': 0.0, + 'supporting_message_ids': [], + 'supporting_tool_result_ids': [], + 'reason': None, + }, + 'belief_history': [], + } + + +def write_status_snapshot( + *, + run_dir: pathlib.Path, + campaign: ChemGraphCampaign, + agent_state: dict[str, Any], + 
placement: dict[str, Any], +) -> None: + state_dir = run_dir / 'agent_status' + state_dir.mkdir(parents=True, exist_ok=True) + payload = dict(agent_state) + payload['placement'] = placement + write_json_atomic(state_dir / f'{agent_state["agent_name"]}.json', payload) + + states_by_agent: dict[str, dict[str, Any]] = {} + for path in state_dir.glob('*.json'): + try: + item = json.loads(path.read_text(encoding='utf-8')) + except json.JSONDecodeError: + continue + if isinstance(item, dict) and isinstance(item.get('agent_name'), str): + states_by_agent[item['agent_name']] = item + + agents = [] + placements = {} + for spec in campaign.agents: + state = states_by_agent.get(spec.name) or default_agent_state(spec) + agents.append(state) + if isinstance(state.get('placement'), dict): + placements[spec.name] = state['placement'] + + distinct_hosts = sorted( + { + item.get('short_hostname') or item.get('hostname') + for item in placements.values() + if item.get('short_hostname') or item.get('hostname') + }, + ) + placement_doc = { + 'agents': placements, + 'distinct_hostnames': distinct_hosts, + 'distinct_hostname_count': len(distinct_hosts), + } + write_json_atomic(run_dir / 'placement.json', placement_doc) + + proof = build_communication_proof( + read_events(run_dir / "events.jsonl"), + placement_doc, + ) + status = { + 'timestamp': time.time(), + 'mode': 'mpi_daemon', + 'campaign_kind': 'chemgraph_agent_swarm', + 'campaign': campaign.run_id, + 'agents': sorted(agents, key=lambda item: item['agent_name']), + 'communication_proof': proof, + 'placement': placement_doc, + 'converged': bool(proof.get('passes', {}).get('final_report')), + } + write_json_atomic(run_dir / 'status.json', status) + append_jsonl(run_dir / 'status_history.jsonl', status) + write_json_atomic(run_dir / 'communication_proof.json', proof) + + +async def wait_for_agent_statuses_finished( + *, + run_dir: pathlib.Path, + campaign: ChemGraphCampaign, + timeout_s: float, +) -> bool: + deadline = 
time.monotonic() + timeout_s + state_dir = run_dir / 'agent_status' + expected = {spec.name for spec in campaign.agents} + while True: + finished = set() + for path in state_dir.glob('*.json'): + try: + item = json.loads(path.read_text(encoding='utf-8')) + except (OSError, json.JSONDecodeError): + continue + if item.get('finished') is True and item.get('agent_name') in expected: + finished.add(item['agent_name']) + if finished == expected: + return True + if time.monotonic() > deadline: + return False + await asyncio.sleep(0.5) + + +def clear_run_outputs(run_dir: pathlib.Path) -> None: + for name in ( + 'academy_registrations.json', + 'communication_proof.json', + 'launch_plan.json', + 'messages.jsonl', + 'events.jsonl', + 'placement.json', + 'status.json', + 'status_history.jsonl', + 'tool_results.jsonl', + ): + path = run_dir / name + if path.exists(): + path.unlink() + for dirname in ('agent_status', 'artifacts', 'shared'): + path = run_dir / dirname + if path.exists(): + shutil.rmtree(path) + + +def initialize_run_files( + *, + run_dir: pathlib.Path, + campaign: ChemGraphCampaign, + config: ChemGraphDaemonConfig, + llm_settings: LLMSettings, +) -> None: + run_dir.mkdir(parents=True, exist_ok=True) + clear_run_outputs(run_dir) + write_json( + run_dir / 'campaign_private.json', + { + 'run_id': campaign.run_id, + 'user_task': campaign.user_task, + 'initial_agent': campaign.initial_agent, + 'prompt_profile': str(campaign.prompt_profile), + 'resources': { + name: spec.model_dump(exclude_none=True) + for name, spec in campaign.resources.items() + }, + 'agents': [ + { + 'name': spec.name, + 'role': spec.role, + 'mission': spec.mission, + 'allowed_peers': list(spec.allowed_peers), + 'tool_names': list(spec.tool_names), + 'resources': list(spec.resources), + } + for spec in campaign.agents + ], + }, + ) + write_json( + run_dir / 'manifest.json', + { + 'run_dir': str(run_dir), + 'run_token': config.run_token, + 'mode': 'chemgraph_mpi_daemon', + 'agent_runtime': 
'academy_runtime', + 'agent_count': config.agent_count, + 'max_decisions_per_agent': config.max_decisions, + 'campaign_config': ( + str(config.campaign_config) + if config.campaign_config is not None + else None + ), + 'prompt_profile': str(campaign.prompt_profile), + 'chemgraph_repo_root': str(config.chemgraph_repo_root), + 'communication_transport': 'academy_redis_actions', + 'redis_host': config.redis_host, + 'redis_port': config.redis_port, + 'redis_namespace': config.redis_namespace, + 'llm_model': llm_settings.model, + 'llm_base_url': llm_settings.base_url, + 'llm_provider': llm_settings.provider, + 'llm_user': llm_settings.user, + }, + ) + write_json( + run_dir / 'launch_plan.json', + { + 'agent_class': 'ChemGraphLogicalAgent', + 'exchange': { + 'backend': 'academy_redis', + 'host': config.redis_host, + 'port': config.redis_port, + }, + 'placement': { + 'launcher': 'mpiexec', + 'agent_count': config.agent_count, + }, + 'agents': [ + { + 'name': spec.name, + 'role': spec.role, + 'agent_class': 'ChemGraphLogicalAgent', + 'allowed_peers': list(spec.allowed_peers), + 'tool_names': list(spec.tool_names), + } + for spec in campaign.agents + ], + }, + ) + append_system_trace( + run_dir, + 'campaign_started', + { + 'mode': 'chemgraph_mpi_daemon', + 'agent_count': config.agent_count, + 'campaign': campaign.run_id, + }, + ) + append_system_trace( + run_dir, + 'campaign_planned', + { + 'agents': [spec.name for spec in campaign.agents], + 'roles': {spec.name: spec.role for spec in campaign.agents}, + 'tool_names': { + spec.name: list(spec.tool_names) + for spec in campaign.agents + }, + }, + ) diff --git a/src/chemgraph/academy/rate_limiter.py b/src/chemgraph/academy/rate_limiter.py deleted file mode 100644 index 9c521f55..00000000 --- a/src/chemgraph/academy/rate_limiter.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Token-bucket rate limiter for LLM API calls. - -Academy is LLM-agnostic, so rate limiting must be handled at the -ChemGraph layer. 
This module provides a shared :class:`RateLimiter` -that agents ``await`` before each LLM call to stay within per-provider -API quotas. -""" - -from __future__ import annotations - -import asyncio -import logging -import time -from dataclasses import dataclass, field - -logger = logging.getLogger(__name__) - - -@dataclass -class _ProviderBucket: - """Token bucket state for a single LLM provider.""" - - rpm: float - tokens: float = 0.0 - last_refill: float = field(default_factory=time.monotonic) - lock: asyncio.Lock = field(default_factory=asyncio.Lock) - - def __post_init__(self) -> None: - # Start with a full bucket. - self.tokens = self.rpm - - -class RateLimiter: - """Async token-bucket rate limiter keyed by LLM provider. - - Parameters - ---------- - default_rpm : int - Default requests-per-minute for providers not explicitly - configured (default ``60``). - provider_rpm : dict[str, int] or None - Per-provider overrides. Keys are provider prefixes or model - names (e.g. ``"openai"``, ``"anthropic"``, ``"gpt-4o"``). - - Usage - ----- - :: - - limiter = RateLimiter(default_rpm=60, provider_rpm={"openai": 500}) - await limiter.acquire("gpt-4o") # blocks if bucket empty - """ - - # Map model-name prefixes to canonical provider keys so that - # ``acquire("gpt-4o")`` matches a rule set for ``"openai"``. - _PREFIX_MAP: dict[str, str] = { - "gpt-": "openai", - "o1": "openai", - "o3": "openai", - "o4": "openai", - "argo:": "argo", - "claude-": "anthropic", - "gemini-": "google", - "groq:": "groq", - "llama": "alcf", - } - - def __init__( - self, - default_rpm: int = 60, - provider_rpm: dict[str, int] | None = None, - ) -> None: - self._default_rpm = default_rpm - self._provider_rpm: dict[str, int] = provider_rpm or {} - self._buckets: dict[str, _ProviderBucket] = {} - - def _resolve_provider(self, model_name: str) -> str: - """Map a model name to a canonical provider key.""" - # Direct match first. 
- if model_name in self._provider_rpm: - return model_name - - # Prefix match. - lower = model_name.lower() - for prefix, provider in self._PREFIX_MAP.items(): - if lower.startswith(prefix): - return provider - - return model_name - - def _get_bucket(self, provider: str) -> _ProviderBucket: - """Get or create the bucket for *provider*.""" - if provider not in self._buckets: - rpm = self._provider_rpm.get(provider, self._default_rpm) - self._buckets[provider] = _ProviderBucket(rpm=rpm) - return self._buckets[provider] - - async def acquire(self, model_name: str) -> None: - """Wait until a request token is available for *model_name*. - - Refills the token bucket based on elapsed time, then consumes - one token. If the bucket is empty, sleeps until a token - becomes available. - """ - provider = self._resolve_provider(model_name) - bucket = self._get_bucket(provider) - - async with bucket.lock: - now = time.monotonic() - elapsed = now - bucket.last_refill - # Refill at rpm / 60 tokens per second. - refill = elapsed * (bucket.rpm / 60.0) - bucket.tokens = min(bucket.rpm, bucket.tokens + refill) - bucket.last_refill = now - - if bucket.tokens >= 1.0: - bucket.tokens -= 1.0 - return - - # Need to wait for a token. - deficit = 1.0 - bucket.tokens - wait_seconds = deficit / (bucket.rpm / 60.0) - logger.debug( - "Rate limit: waiting %.1fs for provider %s (rpm=%d)", - wait_seconds, - provider, - bucket.rpm, - ) - - # Sleep outside the lock so other providers aren't blocked. - await asyncio.sleep(wait_seconds) - - # Consume after waking. 
- async with bucket.lock: - bucket.tokens = 0.0 - bucket.last_refill = time.monotonic() diff --git a/src/chemgraph/academy/runtime/__init__.py b/src/chemgraph/academy/runtime/__init__.py new file mode 100644 index 00000000..ccc9bab8 --- /dev/null +++ b/src/chemgraph/academy/runtime/__init__.py @@ -0,0 +1 @@ +"""Runtime launch and MPI support for ChemGraph Academy campaigns.""" diff --git a/src/chemgraph/academy/runtime/compute_launcher.py b/src/chemgraph/academy/runtime/compute_launcher.py new file mode 100644 index 00000000..959ccb46 --- /dev/null +++ b/src/chemgraph/academy/runtime/compute_launcher.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +import argparse +import dataclasses +import json +import os +import shutil +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import Any + +from chemgraph.academy.examples import campaign_launch_defaults +from chemgraph.academy.examples import resolve_builtin_campaign +from chemgraph.academy.examples import resolve_builtin_lm_config_template +from chemgraph.academy.runtime.profiles import list_builtin_system_profiles +from chemgraph.academy.runtime.profiles import load_system_profile +from chemgraph.academy.runtime.profiles.system import SystemProfile + + +OPERATOR_METADATA_FILE = "operator_metadata.json" + + +@dataclasses.dataclass(frozen=True) +class AllocationPlan: + """Resolved parameters needed to launch one MPI-backed campaign.""" + + run_dir: Path + run_token: str + agent_count: int + agents_per_node: int + campaign_config: Path + lm_config: Path + max_decisions: int + poll_timeout_s: float + idle_timeout_s: float + startup_timeout_s: float + completion_timeout_s: float + status_interval_s: float + redis_host: str + redis_port: int + redis_bind: str + redis_protected_mode: str + redis_namespace: str + start_redis: bool + mpiexec: str + chemgraph_repo_root: Path + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = 
argparse.ArgumentParser( + description=( + "Run a built-in ChemGraph Academy campaign inside the current " + "HPC compute allocation." + ), + ) + parser.add_argument( + "--system", + required=True, + help=( + "Built-in system profile or profile JSON path. Built-ins: " + + ", ".join(list_builtin_system_profiles()) + ), + ) + parser.add_argument("--run-id", required=True) + parser.add_argument("--campaign", required=True) + parser.add_argument("--run-dir") + parser.add_argument("--lm-base-url") + parser.add_argument("--relay-host") + parser.add_argument("--lm-model") + parser.add_argument("--lm-user") + parser.add_argument("--max-tokens", type=int) + parser.add_argument("--agent-count", type=int) + parser.add_argument("--agents-per-node", type=int) + parser.add_argument("--max-decisions", type=int) + parser.add_argument("--redis-port", type=int) + parser.add_argument("--no-start-redis", action="store_true") + return parser.parse_args(argv) + + +def _prepend_path(name: str, entries: list[str]) -> None: + existing = os.environ.get(name, "") + values = [entry for entry in entries if entry] + if existing: + values.append(existing) + os.environ[name] = os.pathsep.join(values) + + +def _prepare_environment(profile: SystemProfile) -> None: + for name in profile.unset_env: + os.environ.pop(name, None) + _prepend_path("PATH", profile.path_entries) + _prepend_path("PYTHONPATH", profile.pythonpath_entries) + for name, value in profile.env.items(): + os.environ.setdefault(name, value) + os.environ["no_proxy"] = profile.no_proxy + os.environ["NO_PROXY"] = profile.no_proxy + + +def _load_operator_metadata(run_dir: Path) -> dict[str, Any]: + path = run_dir / OPERATOR_METADATA_FILE + if not path.exists(): + return {} + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise RuntimeError(f"{path} must contain a JSON object") + return data + + +def _relay_host_from_profile(profile: SystemProfile) -> str: + path = Path(profile.relay_host_file) + if 
not path.exists(): + raise RuntimeError( + "Could not determine UAN relay host. Start the Mac operator " + f"console first, or pass --lm-base-url. Missing: {path}", + ) + host = path.read_text(encoding="utf-8").strip() + if not host: + raise RuntimeError(f"Relay host file is empty: {path}") + return host + + +def _resolve_lm_base_url( + *, + args: argparse.Namespace, + profile: SystemProfile, + metadata: dict[str, Any], +) -> str: + if args.lm_base_url: + return args.lm_base_url + value = metadata.get("lm_base_url") + if isinstance(value, str) and value.strip(): + return value.strip() + relay_host = args.relay_host or metadata.get("relay_host") + if not isinstance(relay_host, str) or not relay_host.strip(): + relay_host = _relay_host_from_profile(profile) + return f"http://{relay_host.strip()}:{profile.relay_port}/argoapi/v1" + + +def _write_lm_config( + *, + run_dir: Path, + template_name: str, + base_url: str, + lm_model: str | None, + lm_user: str | None, + max_tokens: int | None, +) -> Path: + template_path = resolve_builtin_lm_config_template(template_name) + data = json.loads(template_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise RuntimeError(f"LM template must contain a JSON object: {template_path}") + data["base_url"] = base_url + if lm_model: + data["model"] = lm_model + if lm_user: + data["user"] = lm_user + if max_tokens is not None: + data["max_tokens"] = max_tokens + + path = run_dir / "lm_config.json" + path.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8") + return path + + +def _write_compute_launch_metadata( + *, + run_dir: Path, + args: argparse.Namespace, + profile: SystemProfile, + lm_config: Path, + lm_base_url: str, + agent_count: int, + agents_per_node: int, + max_decisions: int, + redis_port: int, +) -> None: + payload = { + "system": profile.name, + "run_id": args.run_id, + "campaign": args.campaign, + "run_dir": str(run_dir), + "lm_base_url": lm_base_url, + "lm_config": str(lm_config), + 
"agent_count": agent_count, + "agents_per_node": agents_per_node, + "max_decisions": max_decisions, + "redis_host": socket.getfqdn(), + "redis_port": redis_port, + "repo_root": profile.repo_root, + } + (run_dir / "compute_launch.json").write_text( + json.dumps(payload, indent=2) + "\n", + encoding="utf-8", + ) + + +def _export_workflow_lm_environment(lm_config: Path) -> None: + data = json.loads(lm_config.read_text(encoding="utf-8")) + values = { + "CHEMGRAPH_WORKFLOW_BASE_URL": data.get("base_url"), + "CHEMGRAPH_WORKFLOW_MODEL": data.get("model"), + "CHEMGRAPH_WORKFLOW_API_KEY": data.get("api_key"), + "CHEMGRAPH_WORKFLOW_ARGO_USER": data.get("user"), + "ARGO_USER": data.get("user"), + } + for name, value in values.items(): + if isinstance(value, str) and value: + os.environ.setdefault(name, value) + + +def _run_token() -> str: + return f"{int(time.time())}-{os.getpid()}" + + +def prepare_compute_launch(args: argparse.Namespace) -> AllocationPlan: + """Resolve a system profile and operator metadata into an allocation plan.""" + profile = load_system_profile(args.system) + _prepare_environment(profile) + + defaults = campaign_launch_defaults(args.campaign) + run_dir = Path(args.run_dir or Path(profile.run_root) / args.run_id).resolve() + run_dir.mkdir(parents=True, exist_ok=True) + metadata = _load_operator_metadata(run_dir) + metadata_campaign = metadata.get("campaign") + if metadata_campaign and metadata_campaign != args.campaign: + raise RuntimeError( + f"Run metadata campaign {metadata_campaign!r} does not match " + f"--campaign {args.campaign!r}", + ) + + lm_base_url = _resolve_lm_base_url( + args=args, + profile=profile, + metadata=metadata, + ) + lm_config = _write_lm_config( + run_dir=run_dir, + template_name=defaults.lm_config_template, + base_url=lm_base_url, + lm_model=args.lm_model, + lm_user=args.lm_user, + max_tokens=args.max_tokens, + ) + _export_workflow_lm_environment(lm_config) + agent_count = args.agent_count or defaults.agent_count + 
agents_per_node = args.agents_per_node or defaults.agents_per_node + max_decisions = args.max_decisions or defaults.max_decisions + redis_port = args.redis_port or profile.redis_port + + _write_compute_launch_metadata( + run_dir=run_dir, + args=args, + profile=profile, + lm_config=lm_config, + lm_base_url=lm_base_url, + agent_count=agent_count, + agents_per_node=agents_per_node, + max_decisions=max_decisions, + redis_port=redis_port, + ) + + campaign_config = resolve_builtin_campaign(args.campaign) + if not campaign_config.exists(): + campaign_config = Path(args.campaign).resolve() + + return AllocationPlan( + run_dir=run_dir, + run_token=_run_token(), + agent_count=agent_count, + agents_per_node=agents_per_node, + campaign_config=campaign_config, + lm_config=lm_config, + max_decisions=max_decisions, + poll_timeout_s=2.0, + idle_timeout_s=600.0, + startup_timeout_s=120.0, + completion_timeout_s=60.0, + status_interval_s=5.0, + redis_host=socket.getfqdn(), + redis_port=redis_port, + redis_bind=profile.redis_bind, + redis_protected_mode=profile.redis_protected_mode, + redis_namespace=f"academy-chemgraph-swarm:{args.run_id}", + start_redis=not args.no_start_redis, + mpiexec=profile.mpiexec, + chemgraph_repo_root=Path(profile.repo_root).resolve(), + ) + + +def wait_redis(host: str, port: int, run_dir: Path) -> None: + import redis + + deadline = time.time() + 30 + while True: + try: + redis.Redis(host=host, port=port).ping() + return + except Exception: + if time.time() > deadline: + log = run_dir / "redis.log" + if log.exists(): + print(log.read_text(errors="replace")[-4000:], file=sys.stderr) + raise + time.sleep(1) + + +def run_allocation(plan: AllocationPlan) -> int: + """Start Redis if requested and run per-rank daemons under mpiexec.""" + plan.run_dir.mkdir(parents=True, exist_ok=True) + redis_proc: subprocess.Popen[bytes] | None = None + if plan.start_redis: + redis_server = shutil.which("redis-server") + if redis_server is None: + raise 
RuntimeError("redis-server is required unless --no-start-redis is set") + redis_log = (plan.run_dir / "redis.log").open("ab") + redis_proc = subprocess.Popen( + [ + redis_server, + "--bind", + plan.redis_bind, + "--port", + str(plan.redis_port), + "--protected-mode", + plan.redis_protected_mode, + "--save", + "", + "--appendonly", + "no", + "--daemonize", + "no", + ], + stdout=redis_log, + stderr=subprocess.STDOUT, + ) + (plan.run_dir / "redis.pid").write_text( + f"{redis_proc.pid}\n", + encoding="utf-8", + ) + try: + wait_redis(plan.redis_host, plan.redis_port, plan.run_dir) + daemon_args = [ + "--run-dir", str(plan.run_dir), + "--run-token", plan.run_token, + "--agent-count", str(plan.agent_count), + "--campaign-config", str(plan.campaign_config), + "--lm-config", str(plan.lm_config), + "--max-decisions", str(plan.max_decisions), + "--poll-timeout-s", str(plan.poll_timeout_s), + "--idle-timeout-s", str(plan.idle_timeout_s), + "--startup-timeout-s", str(plan.startup_timeout_s), + "--completion-timeout-s", str(plan.completion_timeout_s), + "--status-interval-s", str(plan.status_interval_s), + "--redis-host", plan.redis_host, + "--redis-port", str(plan.redis_port), + "--redis-namespace", plan.redis_namespace, + "--chemgraph-repo-root", str(plan.chemgraph_repo_root), + ] + cmd = [ + plan.mpiexec, + "-n", str(plan.agent_count), + "--ppn", str(plan.agents_per_node), + sys.executable, "-m", "chemgraph.cli.main", "academy", "mpi-daemon", "--", + *daemon_args, + ] + (plan.run_dir / "launch_command.txt").write_text( + " ".join(cmd) + "\n", + encoding="utf-8", + ) + return subprocess.call(cmd) + finally: + if redis_proc is not None: + redis_proc.terminate() + try: + redis_proc.wait(timeout=10) + except subprocess.TimeoutExpired: + redis_proc.kill() + redis_proc.wait() + + +def main(argv: list[str] | None = None) -> int: + args = parse_args(argv) + plan = prepare_compute_launch(args) + print(f"ChemGraph Academy run: {args.run_id}") + print(f" system: 
{load_system_profile(args.system).name}") + print(f" campaign: {args.campaign}") + print(f" run dir: {plan.run_dir}") + print(f" LM config: {plan.lm_config}") + print(f" agents: {plan.agent_count}, agents_per_node: {plan.agents_per_node}") + return run_allocation(plan) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/chemgraph/academy/runtime/daemon.py b/src/chemgraph/academy/runtime/daemon.py new file mode 100644 index 00000000..5c78697c --- /dev/null +++ b/src/chemgraph/academy/runtime/daemon.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import argparse +import asyncio +import pathlib + +from academy.exchange.redis import RedisExchangeFactory +from academy.handle import Handle +from academy.runtime import Runtime +from academy.runtime import RuntimeConfig + +from chemgraph.academy.core.peer_protocol import build_message +from chemgraph.academy.runtime.registration import load_academy_registrations +from chemgraph.academy.runtime.registration import wait_academy_registrations +from chemgraph.academy.runtime.registration import write_academy_registrations +from chemgraph.academy.observability.run_artifacts import initialize_run_files +from chemgraph.academy.observability.run_artifacts import ( + wait_for_agent_statuses_finished, +) +from chemgraph.academy.observability.run_artifacts import write_status_snapshot +from chemgraph.academy.core.campaign import campaign_bootstrap_text +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.core.campaign import ExecutionSpec +from chemgraph.academy.core.campaign import load_campaign +from chemgraph.academy.core.campaign import namespace_for_run +from chemgraph.academy.core.campaign import resolve_campaign_resources +from chemgraph.academy.core.campaign import selected_agent +from chemgraph.academy.core.campaign import validate_campaign +from chemgraph.academy.examples import resolve_builtin_campaign +from chemgraph.academy.runtime.mpi import 
append_system_trace +from chemgraph.academy.runtime.mpi import local_rank_from_env +from chemgraph.academy.runtime.mpi import placement_payload +from chemgraph.academy.runtime.mpi import rank_from_env +from chemgraph.academy.core.agent import ChemGraphLogicalAgent +from chemgraph.academy.core.lm import load_lm_config +from chemgraph.academy.core.prompt import load_prompt_profile +from chemgraph.academy.core.fastmcp import ( + build_campaign_fastmcp_tool_invoker, +) + + +async def run_daemon(config: ChemGraphDaemonConfig) -> int: + config.run_dir.mkdir(parents=True, exist_ok=True) + llm_settings = load_lm_config(config.lm_config) + campaign = resolve_campaign_resources( + load_campaign(config.campaign_config), + config.run_dir, + ) + prompt_profile = load_prompt_profile(campaign.prompt_profile) + validate_campaign(campaign, config.agent_count) + agent_spec = selected_agent(campaign, config.rank) + placement = placement_payload(config, agent_spec.name) + + academy_factory = RedisExchangeFactory( + hostname=config.redis_host, + port=config.redis_port, + ) + if config.rank == 0: + initialize_run_files( + run_dir=config.run_dir, + campaign=campaign, + config=config, + llm_settings=llm_settings, + ) + registrar = await academy_factory.create_user_client( + name=f'{config.run_dir.name}-registrar', + start_listener=False, + ) + try: + registered = await registrar.register_agents( + [ + (ChemGraphLogicalAgent, spec.name) + for spec in campaign.agents + ], + ) + finally: + await registrar.close() + registrations = dict( + zip( + (spec.name for spec in campaign.agents), + registered, + strict=True, + ), + ) + write_academy_registrations( + run_dir=config.run_dir, + run_token=config.run_token, + registrations=registrations, + ) + else: + registrations = await wait_academy_registrations( + config.run_dir, + run_token=config.run_token, + timeout_s=config.startup_timeout_s, + ) + + if config.rank == 0: + registrations = load_academy_registrations( + config.run_dir, + 
run_token=config.run_token, + ) + registration = registrations[agent_spec.name] + peer_agent_ids = { + peer: registrations[peer].agent_id + for peer in agent_spec.allowed_peers + if peer in registrations + } + + tool_invoker = await build_campaign_fastmcp_tool_invoker( + specs=list(agent_spec.tools), + execution=ExecutionSpec(backend='local', system='local'), + run_dir=config.run_dir, + agent_name=agent_spec.name, + ) + agent = ChemGraphLogicalAgent( + agent_spec, + campaign=campaign, + llm_settings=llm_settings, + prompt_profile=prompt_profile, + run_dir=config.run_dir, + max_decisions=config.max_decisions, + tool_invoker=tool_invoker, + peer_agent_ids=peer_agent_ids, + placement=placement, + poll_timeout_s=config.poll_timeout_s, + idle_timeout_s=config.idle_timeout_s, + status_interval_s=config.status_interval_s, + ) + runtime_config = RuntimeConfig( + terminate_on_success=False, + terminate_on_error=False, + ) + runtime = Runtime( + agent, + exchange_factory=academy_factory, + registration=registration, + config=runtime_config, + ) + async with runtime: + await agent.write_runtime_status() + + if config.rank == 0: + bootstrap = build_message( + sender='campaign', + recipient=campaign.initial_agent, + content=campaign_bootstrap_text(campaign), + kind='message', + tldr='Campaign bootstrap', + reason='Initial campaign task dispatch.', + confidence=1.0, + ) + initial_handle: Handle[Any] = Handle( + registrations[campaign.initial_agent].agent_id, + ) + await initial_handle.action( + 'receive_message', + bootstrap, + ) + append_system_trace( + config.run_dir, + 'bootstrap_message_dispatched', + { + 'agent': campaign.initial_agent, + 'message_id': bootstrap['message_id'], + 'via': 'academy_action', + }, + ) + + await runtime.wait_shutdown() + + if config.rank == 0: + all_done = await wait_for_agent_statuses_finished( + run_dir=config.run_dir, + campaign=campaign, + timeout_s=config.completion_timeout_s, + ) + append_system_trace( + config.run_dir, + 
'campaign_finished', + {'all_agents_done': all_done}, + ) + write_status_snapshot( + run_dir=config.run_dir, + campaign=campaign, + agent_state=await agent.report_state(), + placement=placement, + ) + return 0 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description='Run one persistent ChemGraph-style agent daemon.', + ) + parser.add_argument('--run-dir', required=True) + parser.add_argument('--run-token', required=True) + parser.add_argument('--agent-count', type=int, default=5) + parser.add_argument('--campaign-config', required=True) + parser.add_argument('--lm-config', required=True) + parser.add_argument('--max-decisions', type=int, default=6) + parser.add_argument('--poll-timeout-s', type=float, default=2) + parser.add_argument('--idle-timeout-s', type=float, default=600) + parser.add_argument('--startup-timeout-s', type=float, default=120) + parser.add_argument('--completion-timeout-s', type=float, default=60) + parser.add_argument('--status-interval-s', type=float, default=5) + parser.add_argument('--redis-host', default='127.0.0.1') + parser.add_argument('--redis-port', type=int, required=True) + parser.add_argument('--redis-namespace') + parser.add_argument('--rank', type=int) + parser.add_argument('--local-rank', type=int) + parser.add_argument('--no-clean-redis', action='store_true') + parser.add_argument('--chemgraph-repo-root') + return parser.parse_args() + + +def config_from_args(args: argparse.Namespace) -> ChemGraphDaemonConfig: + run_dir = pathlib.Path(args.run_dir).resolve() + resolved_campaign = resolve_builtin_campaign(args.campaign_config) + campaign_config = ( + resolved_campaign.resolve() + if resolved_campaign.exists() + else pathlib.Path(args.campaign_config).resolve() + ) + return ChemGraphDaemonConfig( + run_dir=run_dir, + run_token=args.run_token, + agent_count=args.agent_count, + campaign_config=campaign_config, + lm_config=pathlib.Path(args.lm_config).resolve(), + 
max_decisions=args.max_decisions, + poll_timeout_s=args.poll_timeout_s, + idle_timeout_s=args.idle_timeout_s, + startup_timeout_s=args.startup_timeout_s, + completion_timeout_s=args.completion_timeout_s, + status_interval_s=args.status_interval_s, + redis_host=args.redis_host, + redis_port=args.redis_port, + redis_namespace=args.redis_namespace or namespace_for_run(run_dir), + clean_redis=not args.no_clean_redis, + rank=args.rank if args.rank is not None else rank_from_env(), + local_rank=( + args.local_rank + if args.local_rank is not None + else local_rank_from_env() + ), + chemgraph_repo_root=( + pathlib.Path(args.chemgraph_repo_root).resolve() + if args.chemgraph_repo_root + else pathlib.Path.cwd().resolve() + ), + ) + + +def main() -> int: + return asyncio.run(run_daemon(config_from_args(parse_args()))) + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/src/chemgraph/academy/runtime/mpi.py b/src/chemgraph/academy/runtime/mpi.py new file mode 100644 index 00000000..dfe88cd8 --- /dev/null +++ b/src/chemgraph/academy/runtime/mpi.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import os +import pathlib +import socket +import sys +from collections.abc import Mapping +from typing import Any + +from chemgraph.academy.observability.event_log import EventLog +from chemgraph.academy.observability.run_files import write_json_atomic + +MPI_RANK_ENV = ( + 'PMI_RANK', + 'PMIX_RANK', + 'OMPI_COMM_WORLD_RANK', + 'PALS_RANK', + 'SLURM_PROCID', +) + +MPI_LOCAL_RANK_ENV = ( + 'MPI_LOCALRANKID', + 'PMI_LOCAL_RANK', + 'PMIX_LOCAL_RANK', + 'OMPI_COMM_WORLD_LOCAL_RANK', + 'PALS_LOCAL_RANK', + 'SLURM_LOCALID', +) + + +def rank_from_env(env: Mapping[str, str] | None = None) -> int: + env = os.environ if env is None else env + for name in MPI_RANK_ENV: + value = env.get(name) + if value is not None: + return int(value) + raise RuntimeError( + 'Could not determine MPI rank from environment. Expected one of ' + f'{", ".join(MPI_RANK_ENV)}. 
Run this through mpiexec.', + ) + + +def local_rank_from_env(env: Mapping[str, str] | None = None) -> int | None: + env = os.environ if env is None else env + for name in MPI_LOCAL_RANK_ENV: + value = env.get(name) + if value is not None: + return int(value) + return None + + +def append_system_trace( + run_dir: pathlib.Path, + event: str, + payload: dict[str, Any], +) -> None: + EventLog(run_dir / 'events.jsonl').emit( + event, # type: ignore[arg-type] + run_id=run_dir.name, + agent_id='system', + payload=payload, + ) + + +def placement_payload(config: Any, agent_name: str) -> dict[str, Any]: + host = socket.gethostname() + pbs_keys = ( + 'PBS_JOBID', + 'PBS_NODEFILE', + 'PBS_O_WORKDIR', + 'PBS_NCPUS', + 'PBS_NUM_NODES', + 'PBS_TASKNUM', + ) + mpi_keys = (*MPI_RANK_ENV, *MPI_LOCAL_RANK_ENV) + env = { + key: os.environ[key] + for key in (*pbs_keys, *mpi_keys) + if key in os.environ + } + nodefile = os.environ.get('PBS_NODEFILE') + nodes: list[str] = [] + if nodefile and pathlib.Path(nodefile).exists(): + nodes = [ + line.strip() + for line in pathlib.Path(nodefile).read_text().splitlines() + if line.strip() + ] + return { + 'agent_name': agent_name, + 'hostname': host, + 'short_hostname': host.split('.')[0], + 'pid': os.getpid(), + 'cwd': os.getcwd(), + 'python_executable': sys.executable, + 'rank': config.rank, + 'local_rank': config.local_rank, + 'redis_host': config.redis_host, + 'redis_port': config.redis_port, + 'redis_namespace': config.redis_namespace, + 'env': env, + 'pbs_nodefile_nodes': nodes, + } diff --git a/src/chemgraph/academy/runtime/operator_console.py b/src/chemgraph/academy/runtime/operator_console.py new file mode 100644 index 00000000..48b46c94 --- /dev/null +++ b/src/chemgraph/academy/runtime/operator_console.py @@ -0,0 +1,759 @@ +from __future__ import annotations + +import argparse +import json +import os +import shlex +import shutil +import subprocess +import sys +import threading +import time +import urllib.error +import urllib.request 
+from pathlib import Path +from typing import Any + +from chemgraph.academy.examples import campaign_launch_defaults +from chemgraph.academy.runtime.profiles import list_builtin_system_profiles +from chemgraph.academy.runtime.profiles import load_system_profile +from chemgraph.academy.runtime.profiles.system import SystemProfile + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Start the local operator console for a ChemGraph Academy run. " + "This prepares remote run metadata, starts the local dashboard, " + "and optionally starts the temporary Mac-to-UAN Argo relay." + ), + ) + parser.add_argument("run_id") + parser.add_argument( + "--system", + default="aurora", + help=( + "Built-in system profile or profile JSON path. Built-ins: " + + ", ".join(list_builtin_system_profiles()) + ), + ) + parser.add_argument("--campaign", default="mace-ensemble-screening-20") + parser.add_argument( + "--lm-connect", + choices=("mac-argo-relay", "direct"), + default="mac-argo-relay", + help=( + "How the compute job should reach the LM endpoint. " + "mac-argo-relay starts the current SSH reverse tunnel and UAN " + "relay. direct writes --lm-base-url to run metadata without " + "starting relay infrastructure." + ), + ) + parser.add_argument( + "--lm-base-url", + help="Required for --lm-connect direct. 
Overrides generated relay URL.", + ) + parser.add_argument("--operator-host", help="SSH target for the login/UAN host.") + parser.add_argument("--ssh-control-path") + parser.add_argument("--keep-ssh-master", action="store_true") + parser.add_argument("--local-argo-host", default="127.0.0.1") + parser.add_argument("--local-argo-port", type=int, default=18085) + parser.add_argument("--reverse-port", type=int, default=18185) + parser.add_argument("--relay-port", type=int) + parser.add_argument("--relay-python") + parser.add_argument("--rsync-interval-s", type=float, default=2.0) + parser.add_argument( + "--local-mirror-root", + default=str(Path.home() / "projects/chemgraph-academy/remote-runs"), + ) + parser.add_argument("--local-run-dir") + parser.add_argument("--dashboard-host", default="127.0.0.1") + parser.add_argument("--dashboard-port", type=int, default=8765) + parser.add_argument( + "--local", + action="store_true", + help="Only serve an already mirrored local run. No SSH, relay, or rsync.", + ) + parser.add_argument( + "--no-dashboard", + action="store_true", + help="Prepare operator metadata and return without serving dashboard.", + ) + parser.add_argument( + "--overwrite-run", + action="store_true", + help=( + "Delete the remote run directory and local mirror before starting. " + "This does not stop an already-running compute job." 
+ ), + ) + return parser.parse_args() + + +def _log(message: str) -> None: + print(message, flush=True) + + +def _http_ok(url: str, *, timeout_s: float = 5.0) -> bool: + try: + with urllib.request.urlopen(url, timeout=timeout_s) as response: + return 200 <= int(response.status) < 300 + except (OSError, urllib.error.URLError, urllib.error.HTTPError): + return False + + +def _run(command: list[str], *, input_text: str | None = None) -> subprocess.CompletedProcess[str]: + return subprocess.run( + command, + input=input_text, + text=True, + check=True, + ) + + +def _ssh_options(control_path: str, *, batch_mode: bool = True) -> list[str]: + opts = [ + "-o", + f"ControlPath={control_path}", + "-o", + "ControlMaster=auto", + "-o", + "ControlPersist=yes", + "-o", + "ServerAliveInterval=30", + "-o", + "ServerAliveCountMax=4", + ] + if batch_mode: + opts[:0] = ["-o", "BatchMode=yes"] + return opts + + +def _start_ssh_master(*, host: str, control_path: str) -> bool: + Path(control_path).expanduser().parent.mkdir(parents=True, exist_ok=True) + check = subprocess.run( + ["ssh", "-o", f"ControlPath={control_path}", "-O", "check", host], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + text=True, + check=False, + ) + if check.returncode == 0: + return False + + _log(f"Starting SSH ControlMaster for {host}...") + _run( + [ + "ssh", + "-M", + "-N", + "-f", + "-o", + "ControlMaster=yes", + "-o", + f"ControlPath={control_path}", + "-o", + "ControlPersist=yes", + "-o", + "ServerAliveInterval=30", + "-o", + "ServerAliveCountMax=4", + host, + ], + ) + return True + + +def _stop_ssh_master(*, host: str, control_path: str) -> None: + subprocess.run( + ["ssh", "-o", f"ControlPath={control_path}", "-O", "exit", host], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + text=True, + check=False, + ) + + +def _wrapper_text(profile: SystemProfile) -> str: + path_prefix = ":".join([profile.redis_bin_dir, f"{profile.remote_root}/bin"]) + pythonpath = 
":".join(profile.pythonpath_entries) + return f"""#!/bin/bash +set -euo pipefail + +log() {{ + printf '[chemgraph-academy-run] %s\\n' "$*" >&2 +}} + +export PATH="{path_prefix}:${{PATH}}" +export PYTHONPATH="{pythonpath}:${{PYTHONPATH:-}}" + +PYTHON_BIN="${{CHEMGRAPH_ACADEMY_PYTHON:-python}}" +if ! command -v "${{PYTHON_BIN}}" >/dev/null 2>&1; then + log "Python command not found: ${{PYTHON_BIN}}" + log "Load your site module and activate the ChemGraph/Academy environment first." + log "Profile Python, if you want to use it explicitly: {profile.venv_python}" + exit 1 +fi + +ACTIVE_PYTHON="$("${{PYTHON_BIN}}" -c 'import sys; print(sys.executable)')" +log "using active Python: ${{ACTIVE_PYTHON}}" +log "not loading modules or activating a venv inside this wrapper" + +if ! "${{PYTHON_BIN}}" -c 'import chemgraph.academy.runtime.compute_launcher' >/dev/null 2>&1; then + log "active Python cannot import chemgraph.academy.runtime.compute_launcher" + log "Load the proper site module and venv before running this command." 
+ log "Profile Python, if you want to use it explicitly: {profile.venv_python}" + exit 1 +fi + +log "starting ChemGraph Academy compute launcher" +exec "${{PYTHON_BIN}}" -m chemgraph.academy.runtime.compute_launcher "$@" +""" + + +def _install_compute_wrapper( + *, + profile: SystemProfile, + host: str, + ssh_opts: list[str], +) -> str: + wrapper_bin_dir = f"{profile.remote_root}/bin" + wrapper_path = f"{wrapper_bin_dir}/chemgraph-academy-run" + _log(f"Installing compute wrapper at {wrapper_path}...") + remote_command = ( + f"mkdir -p {shlex.quote(wrapper_bin_dir)} && " + f"cat > {shlex.quote(wrapper_path)} && " + f"chmod +x {shlex.quote(wrapper_path)}" + ) + _run( + ["ssh", *ssh_opts, host, remote_command], + input_text=_wrapper_text(profile), + ) + return wrapper_path + + +def _relay_script_text() -> str: + return r""" +set -euo pipefail + +REMOTE_ROOT="$1" +RELAY_SCRIPT="$2" +RELAY_HOST_FILE="$3" +RELAY_PID_FILE="$4" +RELAY_LOG_FILE="$5" +RELAY_PORT="$6" +REVERSE_PORT="$7" +RELAY_PYTHON="$8" + +cd "${REMOTE_ROOT}" +UAN_HOST="$(hostname -f)" +printf '%s\n' "${UAN_HOST}" > "${RELAY_HOST_FILE}" + +if [ -f "${RELAY_PID_FILE}" ]; then + OLD_PID="$(cat "${RELAY_PID_FILE}" 2>/dev/null || true)" + case "${OLD_PID}" in + ''|*[!0-9]*) ;; + *) kill "${OLD_PID}" 2>/dev/null || true ;; + esac +fi + +"${RELAY_PYTHON}" "${RELAY_SCRIPT}" \ + --listen-host 0.0.0.0 \ + --listen-port "${RELAY_PORT}" \ + --target-host 127.0.0.1 \ + --target-port "${REVERSE_PORT}" \ + > "${RELAY_LOG_FILE}" 2>&1 & +RELAY_PID="$!" +printf '%s\n' "${RELAY_PID}" > "${RELAY_PID_FILE}" + +cleanup_remote() { + kill "${RELAY_PID}" 2>/dev/null || true +} +trap cleanup_remote EXIT + +deadline=$((SECONDS + 45)) +while ! curl -fsS "http://${UAN_HOST}:${RELAY_PORT}/v1/models" >/dev/null; do + if ! kill -0 "${RELAY_PID}" 2>/dev/null; then + echo "UAN relay process exited before readiness. 
Last relay log lines:" >&2 + tail -n 80 "${RELAY_LOG_FILE}" >&2 || true + exit 1 + fi + if [ "${SECONDS}" -gt "${deadline}" ]; then + echo "UAN relay did not become ready. Last relay log lines:" >&2 + tail -n 80 "${RELAY_LOG_FILE}" >&2 || true + exit 1 + fi + sleep 1 +done + +echo "UAN_RELAY_HOST=${UAN_HOST}" +echo "UAN relay ready at http://${UAN_HOST}:${RELAY_PORT}/argoapi/v1" + +while true; do + sleep 3600 +done +""" + + +def _start_mac_argo_relay( + *, + profile: SystemProfile, + host: str, + ssh_opts: list[str], + local_argo_host: str, + local_argo_port: int, + reverse_port: int, + relay_port: int, + relay_python: str, + local_log_path: Path, +) -> subprocess.Popen[str]: + relay_script = f"{profile.academy_repo_root}/examples/09-polaris-lm-swarm/uan_http_relay.py" + relay_pid_file = f"{profile.remote_root}/uan-relay-{relay_port}.pid" + relay_log_file = f"{profile.remote_root}/uan-relay-{relay_port}.log" + local_log_path.parent.mkdir(parents=True, exist_ok=True) + log_file = local_log_path.open("w", encoding="utf-8") + + _log(f"Starting {profile.name} UAN relay through {host}...") + command = [ + "ssh", + *ssh_opts, + "-R", + f"127.0.0.1:{reverse_port}:{local_argo_host}:{local_argo_port}", + host, + "bash", + "-s", + "--", + profile.remote_root, + relay_script, + profile.relay_host_file, + relay_pid_file, + relay_log_file, + str(relay_port), + str(reverse_port), + relay_python, + ] + process = subprocess.Popen( + command, + stdin=subprocess.PIPE, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + ) + assert process.stdin is not None + process.stdin.write(_relay_script_text()) + process.stdin.close() + return process + + +def _remote_relay_ready( + *, + host: str, + ssh_opts: list[str], + relay_host_file: str, + relay_port: int, +) -> bool: + command = ( + f"host=$(cat {shlex.quote(relay_host_file)} 2>/dev/null || true); " + f'test -n "$host" && ' + f'curl -fsS "http://${{host}}:{relay_port}/v1/models" >/dev/null' + ) + result = subprocess.run( + 
["ssh", *ssh_opts, host, command], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + text=True, + check=False, + ) + return result.returncode == 0 + + +def _read_remote_file( + *, + host: str, + ssh_opts: list[str], + path: str, +) -> str: + result = subprocess.run( + ["ssh", *ssh_opts, host, "cat", path], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + return result.stdout.strip() + + +def _wait_for_relay( + *, + profile: SystemProfile, + host: str, + ssh_opts: list[str], + relay_port: int, + relay_process: subprocess.Popen[str], + local_log_path: Path, +) -> str: + _log("Waiting for relay readiness...") + deadline = time.time() + 60 + while time.time() < deadline: + if _remote_relay_ready( + host=host, + ssh_opts=ssh_opts, + relay_host_file=profile.relay_host_file, + relay_port=relay_port, + ): + relay_host = _read_remote_file( + host=host, + ssh_opts=ssh_opts, + path=profile.relay_host_file, + ) + _log(f"{profile.name} relay host: {relay_host}") + return relay_host + if relay_process.poll() is not None: + detail = local_log_path.read_text(encoding="utf-8", errors="replace") + raise RuntimeError( + "Relay SSH session exited before readiness. Local relay log:\n" + + detail, + ) + time.sleep(1) + detail = local_log_path.read_text(encoding="utf-8", errors="replace") + raise RuntimeError("Relay readiness timed out. 
Local relay log:\n" + detail) + + +def _write_operator_metadata( + *, + profile: SystemProfile, + host: str, + ssh_opts: list[str], + run_id: str, + campaign: str, + lm_connect: str, + lm_base_url: str, + relay_host: str | None, + relay_port: int | None, +) -> None: + remote_run_dir = f"{profile.run_root}/{run_id}" + payload: dict[str, Any] = { + "created_at": time.time(), + "created_by": "chemgraph-academy-console", + "run_id": run_id, + "system": profile.name, + "campaign": campaign, + "remote_run_dir": remote_run_dir, + "operator_host": host, + "lm_connect": lm_connect, + "lm_base_url": lm_base_url, + "workspace_root": profile.remote_root, + "academy_repo_root": profile.academy_repo_root, + "chemgraph_repo_root": profile.repo_root, + } + if relay_host: + payload["relay_host"] = relay_host + if relay_port is not None: + payload["relay_port"] = relay_port + + metadata = json.dumps(payload, indent=2) + "\n" + remote_path = f"{remote_run_dir}/operator_metadata.json" + remote_command = ( + f"mkdir -p {shlex.quote(remote_run_dir)} && " + f"cat > {shlex.quote(remote_path)}" + ) + _log(f"Writing run metadata: {host}:{remote_run_dir}/operator_metadata.json") + _run( + ["ssh", *ssh_opts, host, remote_command], + input_text=metadata, + ) + + +def _run_id_allows_delete(run_id: str) -> bool: + return bool(run_id) and "/" not in run_id and run_id not in {".", ".."} + + +def _delete_existing_run( + *, + profile: SystemProfile, + host: str, + ssh_opts: list[str], + run_id: str, + local_run_dir: Path, +) -> None: + if not _run_id_allows_delete(run_id): + raise RuntimeError(f"Refusing to overwrite unsafe run id: {run_id!r}") + + remote_run_dir = f"{profile.run_root}/{run_id}" + _log("Deleting existing run artifacts because --overwrite-run was set:") + _log(f" remote: {host}:{remote_run_dir}") + _log(f" local: {local_run_dir}") + + remote_command = ( + "set -euo pipefail; " + f"run_root={shlex.quote(profile.run_root)}; " + f"run_id={shlex.quote(run_id)}; " + 'case "$run_id" in 
""|.|..|*/*) echo "unsafe run id" >&2; exit 2;; esac; ' + 'run_dir="$run_root/$run_id"; ' + 'trash_root="$run_root/.deleted-runs"; ' + 'if [ -e "$run_dir" ]; then ' + 'mkdir -p "$trash_root"; ' + 'trash_dir="$trash_root/${run_id}.$(date +%Y%m%d%H%M%S).$$"; ' + 'mv -- "$run_dir" "$trash_dir"; ' + 'for delay in 0 1 2 5 10; do ' + 'sleep "$delay"; ' + 'if rm -rf -- "$trash_dir" 2>/dev/null; then break; fi; ' + 'done; ' + 'fi; ' + 'mkdir -p "$run_dir"' + ) + _run(["ssh", *ssh_opts, host, remote_command]) + if local_run_dir.exists(): + shutil.rmtree(local_run_dir) + + +def _start_rsync_loop( + *, + host: str, + control_path: str, + remote_run_dir: str, + local_run_dir: Path, + interval_s: float, + stop_event: threading.Event, +) -> threading.Thread: + local_run_dir.mkdir(parents=True, exist_ok=True) + log_path = local_run_dir / "rsync.log" + + def loop() -> None: + ssh_command = ( + "ssh " + "-o BatchMode=yes " + "-o ControlMaster=auto " + f"-o ControlPath={shlex.quote(control_path)} " + "-o ControlPersist=yes" + ) + while not stop_event.is_set(): + with log_path.open("a", encoding="utf-8") as log: + subprocess.run( + [ + "rsync", + "-az", + "--delete", + "-e", + ssh_command, + f"{host}:{remote_run_dir}/", + f"{local_run_dir}/", + ], + stdout=log, + stderr=subprocess.STDOUT, + text=True, + check=False, + ) + stop_event.wait(interval_s) + + thread = threading.Thread(target=loop, name="chemgraph-academy-rsync", daemon=True) + thread.start() + return thread + + +def _run_dashboard(*, local_run_dir: Path, host: str, port: int) -> int: + from chemgraph.academy import dashboard + + old_argv = sys.argv + try: + sys.argv = [ + "chemgraph-academy-console dashboard", + "--run-dir", + str(local_run_dir), + "--host", + host, + "--port", + str(port), + ] + return dashboard.main() + finally: + sys.argv = old_argv + + +def _print_compute_command( + *, + profile: SystemProfile, + wrapper_path: str, + run_id: str, + campaign: str, +) -> None: + _log("") + _log("Operator console is 
ready.") + _log("") + _log(f"On the {profile.name} compute node, use:") + if profile.name == "polaris": + _log(" module use /soft/modulefiles") + _log(" module load conda") + _log(" conda activate base") + _log(f" source {profile.remote_root}/venvs/academy-swarm/bin/activate") + else: + _log(" module load frameworks") + _log(f" source {profile.remote_root}/venvs/academy-swarm/bin/activate") + _log(f" export PATH={profile.remote_root}/bin:$PATH") + _log(" chemgraph-academy-run \\") + _log(f" --system {profile.name} \\") + _log(f" --run-id {run_id} \\") + _log(f" --campaign {campaign}") + _log("") + _log("If PATH is not configured, use:") + _log(f" {wrapper_path} \\") + _log(f" --system {profile.name} \\") + _log(f" --run-id {run_id} \\") + _log(f" --campaign {campaign}") + + +def _validate_campaign_name(campaign: str) -> None: + campaign_launch_defaults(campaign) + + +def main() -> int: + args = parse_args() + profile = load_system_profile(args.system) + _validate_campaign_name(args.campaign) + + local_run_dir = Path( + args.local_run_dir or Path(args.local_mirror_root) / args.run_id, + ).expanduser() + local_run_dir.mkdir(parents=True, exist_ok=True) + + if args.local and args.overwrite_run: + raise RuntimeError("--overwrite-run cannot be used with --local") + + if args.local: + if args.no_dashboard: + _log(f"Local run directory: {local_run_dir}") + return 0 + return _run_dashboard( + local_run_dir=local_run_dir, + host=args.dashboard_host, + port=args.dashboard_port, + ) + + operator_host = args.operator_host or profile.operator_host + control_path = ( + args.ssh_control_path + or str(Path.home() / f".ssh/{profile.name}-dashboard-%r@%h:%p") + ) + relay_port = args.relay_port or profile.relay_port + relay_python = args.relay_python or profile.venv_python + local_relay_log = Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log") + remote_run_dir = f"{profile.run_root}/{args.run_id}" + + relay_process: subprocess.Popen[str] | None = None + stop_rsync = 
threading.Event() + started_ssh_master = False + + try: + if args.lm_connect == "mac-argo-relay": + health_url = f"http://{args.local_argo_host}:{args.local_argo_port}/v1/models" + if not _http_ok(health_url): + raise RuntimeError( + "Local argo-shim is not reachable: " + f"{health_url}\n" + "Start it before using --lm-connect mac-argo-relay.", + ) + elif not args.lm_base_url: + raise RuntimeError("--lm-connect direct requires --lm-base-url") + + started_ssh_master = _start_ssh_master( + host=operator_host, + control_path=control_path, + ) + ssh_opts = _ssh_options(control_path) + if args.overwrite_run: + _delete_existing_run( + profile=profile, + host=operator_host, + ssh_opts=ssh_opts, + run_id=args.run_id, + local_run_dir=local_run_dir, + ) + wrapper_path = _install_compute_wrapper( + profile=profile, + host=operator_host, + ssh_opts=ssh_opts, + ) + + relay_host: str | None = None + if args.lm_connect == "mac-argo-relay": + relay_process = _start_mac_argo_relay( + profile=profile, + host=operator_host, + ssh_opts=ssh_opts, + local_argo_host=args.local_argo_host, + local_argo_port=args.local_argo_port, + reverse_port=args.reverse_port, + relay_port=relay_port, + relay_python=relay_python, + local_log_path=local_relay_log, + ) + relay_host = _wait_for_relay( + profile=profile, + host=operator_host, + ssh_opts=ssh_opts, + relay_port=relay_port, + relay_process=relay_process, + local_log_path=local_relay_log, + ) + lm_base_url = f"http://{relay_host}:{relay_port}/argoapi/v1" + else: + lm_base_url = str(args.lm_base_url) + + _log(f"Compute-node LM URL: {lm_base_url}") + _write_operator_metadata( + profile=profile, + host=operator_host, + ssh_opts=ssh_opts, + run_id=args.run_id, + campaign=args.campaign, + lm_connect=args.lm_connect, + lm_base_url=lm_base_url, + relay_host=relay_host, + relay_port=relay_port if relay_host else None, + ) + + _log("Starting rsync mirror:") + _log(f" {operator_host}:{remote_run_dir}/") + _log(f" {local_run_dir}/") + _start_rsync_loop( + 
host=operator_host, + control_path=control_path, + remote_run_dir=remote_run_dir, + local_run_dir=local_run_dir, + interval_s=args.rsync_interval_s, + stop_event=stop_rsync, + ) + + _print_compute_command( + profile=profile, + wrapper_path=wrapper_path, + run_id=args.run_id, + campaign=args.campaign, + ) + + if args.no_dashboard: + return 0 + + _log("") + _log(f"Starting dashboard at http://{args.dashboard_host}:{args.dashboard_port}") + _log("Ctrl-C stops the local dashboard, rsync loop, and relay tunnel.") + return _run_dashboard( + local_run_dir=local_run_dir, + host=args.dashboard_host, + port=args.dashboard_port, + ) + finally: + stop_rsync.set() + if relay_process is not None and relay_process.poll() is None: + relay_process.terminate() + try: + relay_process.wait(timeout=5) + except subprocess.TimeoutExpired: + relay_process.kill() + keep = args.keep_ssh_master or os.environ.get("CHEMGRAPH_ACADEMY_KEEP_SSH_MASTER") == "1" + if started_ssh_master and not keep: + _stop_ssh_master(host=operator_host, control_path=control_path) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/chemgraph/academy/runtime/profiles/__init__.py b/src/chemgraph/academy/runtime/profiles/__init__.py new file mode 100644 index 00000000..2ead8a21 --- /dev/null +++ b/src/chemgraph/academy/runtime/profiles/__init__.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from importlib import resources +from pathlib import Path + + +BUILTIN_SYSTEM_PROFILES = { + "aurora": "aurora.template.json", + "polaris": "polaris.template.json", +} + + +def resolve_builtin_system_profile(path_or_name: str | Path) -> Path: + value = str(path_or_name) + path = Path(value) + if path.exists(): + return path.resolve() + relative = BUILTIN_SYSTEM_PROFILES.get(value) + if relative is None: + return path + return Path(str(resources.files(__package__).joinpath(relative))) + + +def list_builtin_system_profiles() -> list[str]: + return sorted(BUILTIN_SYSTEM_PROFILES) + + +from 
chemgraph.academy.runtime.profiles.system import SystemProfile # noqa: E402 +from chemgraph.academy.runtime.profiles.system import load_system_profile # noqa: E402 + + +__all__ = [ + "BUILTIN_SYSTEM_PROFILES", + "SystemProfile", + "list_builtin_system_profiles", + "load_system_profile", + "resolve_builtin_system_profile", +] diff --git a/src/chemgraph/academy/runtime/profiles/aurora.template.json b/src/chemgraph/academy/runtime/profiles/aurora.template.json new file mode 100644 index 00000000..c8469792 --- /dev/null +++ b/src/chemgraph/academy/runtime/profiles/aurora.template.json @@ -0,0 +1,38 @@ +{ + "name": "aurora", + "operator_host": "${ALCF_USER}@aurora.alcf.anl.gov", + "remote_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}", + "academy_repo_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}/academy", + "repo_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph", + "run_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}/runs", + "relay_host_file": "/flare/${ALCF_PROJECT}/${ALCF_USER}/uan-relay-18186.host", + "relay_port": 18186, + "venv_python": "/flare/${ALCF_PROJECT}/${ALCF_USER}/venvs/academy-swarm/bin/python", + "redis_bin_dir": "/flare/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", + "redis_port": 6392, + "redis_bind": "0.0.0.0", + "redis_protected_mode": "no", + "mpiexec": "mpiexec", + "pythonpath_entries": [ + "/flare/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph/src", + "/flare/${ALCF_PROJECT}/${ALCF_USER}/academy" + ], + "path_entries": [ + "/flare/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", + "/flare/${ALCF_PROJECT}/${ALCF_USER}/bin" + ], + "env": { + "NUMEXPR_MAX_THREADS": "256", + "NUMEXPR_NUM_THREADS": "64", + "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_ACADEMY_PY": "0.0.0+aurora" + }, + "unset_env": [ + "http_proxy", + "HTTP_PROXY", + "https_proxy", + "HTTPS_PROXY", + "all_proxy", + "ALL_PROXY" + ], + "no_proxy": "127.0.0.1,localhost,.alcf.anl.gov,*.alcf.anl.gov" +} diff --git a/src/chemgraph/academy/runtime/profiles/polaris.template.json 
b/src/chemgraph/academy/runtime/profiles/polaris.template.json new file mode 100644 index 00000000..c7f54afb --- /dev/null +++ b/src/chemgraph/academy/runtime/profiles/polaris.template.json @@ -0,0 +1,38 @@ +{ + "name": "polaris", + "operator_host": "${ALCF_USER}@polaris.alcf.anl.gov", + "remote_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}", + "academy_repo_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/academy", + "repo_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph", + "run_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/runs", + "relay_host_file": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/academy/uan-relay-18186.host", + "relay_port": 18186, + "venv_python": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/venvs/academy-swarm/bin/python", + "redis_bin_dir": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", + "redis_port": 6392, + "redis_bind": "0.0.0.0", + "redis_protected_mode": "no", + "mpiexec": "mpiexec", + "pythonpath_entries": [ + "/eagle/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph/src", + "/eagle/${ALCF_PROJECT}/${ALCF_USER}/academy" + ], + "path_entries": [ + "/eagle/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", + "/eagle/${ALCF_PROJECT}/${ALCF_USER}/bin" + ], + "env": { + "NUMEXPR_MAX_THREADS": "256", + "NUMEXPR_NUM_THREADS": "64", + "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_ACADEMY_PY": "0.0.0+polaris" + }, + "unset_env": [ + "http_proxy", + "HTTP_PROXY", + "https_proxy", + "HTTPS_PROXY", + "all_proxy", + "ALL_PROXY" + ], + "no_proxy": "127.0.0.1,localhost,.alcf.anl.gov,*.alcf.anl.gov" +} diff --git a/src/chemgraph/academy/runtime/profiles/system.py b/src/chemgraph/academy/runtime/profiles/system.py new file mode 100644 index 00000000..a67f14c1 --- /dev/null +++ b/src/chemgraph/academy/runtime/profiles/system.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import json +import os +import re +from pathlib import Path + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field + +from chemgraph.academy.runtime.profiles import 
resolve_builtin_system_profile + + +class SystemProfile(BaseModel): + """Site/runtime paths for launching ChemGraph Academy on an HPC system.""" + + model_config = ConfigDict(extra="forbid") + + name: str + operator_host: str + remote_root: str + academy_repo_root: str + repo_root: str + run_root: str + relay_host_file: str + relay_port: int + venv_python: str + redis_bin_dir: str + redis_port: int + redis_bind: str + redis_protected_mode: str + mpiexec: str + pythonpath_entries: list[str] + path_entries: list[str] = Field(default_factory=list) + env: dict[str, str] = Field(default_factory=dict) + unset_env: list[str] = Field(default_factory=list) + no_proxy: str + + +def load_system_profile(path_or_name: str | Path) -> SystemProfile: + profile_path = resolve_builtin_system_profile(path_or_name) + text = os.path.expandvars(profile_path.read_text(encoding="utf-8")) + unresolved = sorted(set(re.findall(r"\$\{([^}]+)\}", text))) + if unresolved: + raise ValueError( + f"System profile {profile_path} contains unresolved environment " + f"variables: {', '.join(unresolved)}", + ) + data = json.loads(text) + return SystemProfile.model_validate(data) diff --git a/src/chemgraph/academy/runtime/registration.py b/src/chemgraph/academy/runtime/registration.py new file mode 100644 index 00000000..c56db752 --- /dev/null +++ b/src/chemgraph/academy/runtime/registration.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import asyncio +import json +import pathlib +import time +from collections.abc import Mapping +from typing import Any + +from academy.exchange.redis import RedisAgentRegistration +from academy.identifier import AgentId + +from chemgraph.academy.observability.run_files import write_json_atomic + + +def academy_registration_path(run_dir: pathlib.Path) -> pathlib.Path: + return run_dir / 'academy_registrations.json' + + +def registration_payload( + *, + run_token: str, + registrations: Mapping[str, RedisAgentRegistration[Any]], +) -> dict[str, Any]: + return { 
+ 'run_token': run_token, + 'exchange_type': 'redis', + 'agents': { + name: registration.agent_id.model_dump(mode='json') + for name, registration in registrations.items() + }, + } + + +def write_academy_registrations( + *, + run_dir: pathlib.Path, + run_token: str, + registrations: Mapping[str, RedisAgentRegistration[Any]], +) -> None: + write_json_atomic( + academy_registration_path(run_dir), + registration_payload(run_token=run_token, registrations=registrations), + ) + + +def load_academy_registrations( + run_dir: pathlib.Path, + *, + run_token: str, +) -> dict[str, RedisAgentRegistration[Any]]: + path = academy_registration_path(run_dir) + data = json.loads(path.read_text(encoding='utf-8')) + if data.get('run_token') != run_token: + raise RuntimeError( + f'Academy registration file {path} belongs to a different run', + ) + agents = data.get('agents') + if not isinstance(agents, dict): + raise RuntimeError(f'Academy registration file is malformed: {path}') + return { + name: RedisAgentRegistration( + agent_id=AgentId[Any].model_validate(agent_id), + ) + for name, agent_id in agents.items() + } + + +async def wait_academy_registrations( + run_dir: pathlib.Path, + *, + run_token: str, + timeout_s: float, +) -> dict[str, RedisAgentRegistration[Any]]: + path = academy_registration_path(run_dir) + deadline = time.monotonic() + timeout_s + while True: + if path.exists(): + return load_academy_registrations( + run_dir, + run_token=run_token, + ) + if time.monotonic() > deadline: + raise TimeoutError( + f'Timed out waiting for Academy registrations at {path}', + ) + await asyncio.sleep(0.25) diff --git a/src/chemgraph/academy/screening.py b/src/chemgraph/academy/screening.py deleted file mode 100644 index 09891642..00000000 --- a/src/chemgraph/academy/screening.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Screening agent for batch molecule processing. 
- -Wraps :class:`ChemGraphAgent` with a ``@loop`` that iterates over an -assigned list of molecules and publishes results via the Academy -exchange. -""" - -from __future__ import annotations - -import asyncio -import json -import logging -import os -import time -from typing import Any, Optional - -from academy.agent import Agent, action, loop - -from chemgraph.academy.agent import ChemGraphAgent - -logger = logging.getLogger(__name__) - - -class ScreeningAgent(ChemGraphAgent): - """Agent that screens a batch of molecules using a ChemGraph workflow. - - Parameters - ---------- - molecules : list[str] - SMILES strings to screen. - query_template : str - Query template with ``{smiles}`` placeholder, e.g. - ``"Optimize the geometry of {smiles} and compute its energy."``. - results_dir : str or None - Directory to write per-molecule JSON result files for - downstream aggregation. If ``None``, results are only - returned via the exchange. - model_name, workflow_type, log_dir, rate_limiter, **chemgraph_kwargs - Forwarded to :class:`ChemGraphAgent`. 
- """ - - def __init__( - self, - molecules: list[str], - query_template: str, - results_dir: Optional[str] = None, - model_name: str = "gpt-4o-mini", - workflow_type: str = "single_agent", - log_dir: Optional[str] = None, - rate_limiter: Any = None, - **chemgraph_kwargs: Any, - ) -> None: - super().__init__( - model_name=model_name, - workflow_type=workflow_type, - log_dir=log_dir, - rate_limiter=rate_limiter, - **chemgraph_kwargs, - ) - self._molecules = molecules - self._query_template = query_template - self._results_dir = results_dir - self.results: list[dict[str, Any]] = [] - self.completed: int = 0 - self.failed: int = 0 - - async def agent_on_startup(self) -> None: - await super().agent_on_startup() - if self._results_dir: - os.makedirs(self._results_dir, exist_ok=True) - logger.info( - "ScreeningAgent %s: %d molecules to process", - self._agent_uuid, - len(self._molecules), - ) - - @action - async def get_progress(self) -> dict[str, Any]: - """Return screening progress.""" - return { - "agent_uuid": self._agent_uuid, - "total": len(self._molecules), - "completed": self.completed, - "failed": self.failed, - } - - @loop - async def screening_loop(self, shutdown: asyncio.Event) -> None: - """Iterate over assigned molecules and run queries.""" - for smiles in self._molecules: - if shutdown.is_set(): - logger.info( - "ScreeningAgent %s: shutdown requested, stopping", - self._agent_uuid, - ) - break - - query = self._query_template.format(smiles=smiles) - t0 = time.monotonic() - try: - result = await self.run_query(query) - elapsed = time.monotonic() - t0 - record = { - "smiles": smiles, - "status": "success", - "result": result, - "elapsed_seconds": round(elapsed, 2), - "agent_uuid": self._agent_uuid, - } - self.completed += 1 - except Exception as exc: - elapsed = time.monotonic() - t0 - logger.exception( - "ScreeningAgent %s: failed on %s", - self._agent_uuid, - smiles, - ) - record = { - "smiles": smiles, - "status": "error", - "error": str(exc), - 
"elapsed_seconds": round(elapsed, 2), - "agent_uuid": self._agent_uuid, - } - self.failed += 1 - - self.results.append(record) - - # Write individual result file for aggregation. - if self._results_dir: - safe_name = smiles.replace("/", "_").replace("\\", "_")[:50] - path = os.path.join( - self._results_dir, - f"{self._agent_uuid}_{safe_name}.json", - ) - with open(path, "w") as f: - json.dump(record, f, default=str) - - logger.info( - "ScreeningAgent %s: finished (%d ok, %d failed)", - self._agent_uuid, - self.completed, - self.failed, - ) - # Signal that this agent is done. - self.agent_shutdown() - - @action - async def get_results(self) -> list[dict[str, Any]]: - """Return all collected results so far.""" - return self.results diff --git a/tests/test_academy_compute_launcher.py b/tests/test_academy_compute_launcher.py new file mode 100644 index 00000000..10652443 --- /dev/null +++ b/tests/test_academy_compute_launcher.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from pathlib import Path + +from chemgraph.academy.runtime import compute_launcher +from chemgraph.academy.runtime.compute_launcher import AllocationPlan + + +def _plan(tmp_path: Path) -> AllocationPlan: + lm_config = tmp_path / "lm.json" + campaign = tmp_path / "campaign.json" + lm_config.write_text("{}\n", encoding="utf-8") + campaign.write_text("{}\n", encoding="utf-8") + return AllocationPlan( + run_dir=tmp_path, + run_token="token-1", + agent_count=3, + agents_per_node=1, + campaign_config=campaign, + lm_config=lm_config, + max_decisions=7, + poll_timeout_s=2.0, + idle_timeout_s=600.0, + startup_timeout_s=120.0, + completion_timeout_s=60.0, + status_interval_s=5.0, + redis_host="redis-host", + redis_port=6392, + redis_bind="0.0.0.0", + redis_protected_mode="no", + redis_namespace="ns", + start_redis=False, + mpiexec="mpiexec", + chemgraph_repo_root=tmp_path / "ChemGraph", + ) + + +def test_run_allocation_builds_single_mpiexec_command(tmp_path, monkeypatch) -> None: + calls: 
list[list[str]] = [] + monkeypatch.setattr(compute_launcher, "wait_redis", lambda *args, **kwargs: None) + monkeypatch.setattr( + compute_launcher.subprocess, + "call", + lambda cmd: calls.append(cmd) or 0, + ) + + assert compute_launcher.run_allocation(_plan(tmp_path)) == 0 + + assert len(calls) == 1 + cmd = calls[0] + assert cmd[:4] == ["mpiexec", "-n", "3", "--ppn"] + assert "chemgraph.cli.main" in cmd + assert "mpi-daemon" in cmd + assert "--campaign-config" in cmd + assert "--lm-config" in cmd + assert "--chemgraph-repo-root" in cmd + assert (tmp_path / "launch_command.txt").exists() diff --git a/tests/test_academy_dashboard.py b/tests/test_academy_dashboard.py new file mode 100644 index 00000000..61d9bf04 --- /dev/null +++ b/tests/test_academy_dashboard.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import json + +import chemgraph.academy.dashboard as dashboard +from chemgraph.academy.observability.event_log import EventLog + + +def test_dashboard_reads_canonical_events_jsonl(tmp_path) -> None: + run_dir = tmp_path / "daemon-run" + run_dir.mkdir() + (run_dir / "status.json").write_text( + json.dumps({"mode": "mpi_daemon", "timestamp": 10.0, "agents": []}) + + "\n", + encoding="utf-8", + ) + log = EventLog(run_dir / "events.jsonl") + log.emit( + "agent_started", + agent_id="agent-00", + role="scheduler observer", + payload={ + "role": "scheduler observer", + "placement": {"hostname": "x1", "short_hostname": "x1"}, + "hostname": "x1", + "short_hostname": "x1", + }, + ) + log.emit( + "agent_decision", + agent_id="agent-00", + role="scheduler observer", + payload={ + "round": 1, + "tool_names": ["send_message"], + "actions": [{"action": "send_message"}], + }, + ) + + events = dashboard.events_payload(run_dir)["events"] + + assert events[0]["event"] == "agent_started" + assert events[0]["payload"]["placement"]["hostname"] == "x1" + assert events[1]["event"] == "agent_decision" + assert events[1]["payload"]["actions"] == [{"action": "send_message"}] + 
+ +def test_status_payload_builds_summary_and_proof_from_events(tmp_path) -> None: + run_dir = tmp_path / "daemon-run" + run_dir.mkdir() + (run_dir / "status.json").write_text( + json.dumps({"mode": "mpi_daemon", "agents": []}) + "\n", + encoding="utf-8", + ) + log = EventLog(run_dir / "events.jsonl") + for agent_id, hostname in (("agent-00", "x0"), ("agent-01", "x1")): + log.emit( + "agent_started", + agent_id=agent_id, + role="observer", + payload={ + "role": "observer", + "placement": {"hostname": hostname, "short_hostname": hostname}, + "hostname": hostname, + "short_hostname": hostname, + }, + ) + log.emit( + "message_sent", + agent_id="agent-00", + role="observer", + payload={ + "message_id": "msg-1", + "timestamp": 2.0, + "sender": "agent-00", + "recipient": "agent-01", + "kind": "message", + "content": "share evidence", + "tldr": "evidence", + "artifact_refs": [], + "tool_result_ids": [], + }, + ) + log.emit( + "belief_updated", + agent_id="agent-01", + role="observer", + payload={ + "hypothesis": "used peer evidence", + "confidence": 0.8, + "supporting_message_ids": ["msg-1"], + "supporting_tool_result_ids": [], + }, + ) + + class Handler: + pass + + handler = Handler() + handler.run_dir = run_dir + payload = dashboard.status_payload(handler) + + assert set(payload) == { + "communication_proof", + "placement", + "run_dir", + "schema", + "status", + "summary", + "updated", + } + assert payload["summary"]["message_count"] == 1 + assert payload["communication_proof"]["passes"]["has_message"] is True + assert payload["communication_proof"]["passes"]["has_cross_node_message"] is True + assert payload["communication_proof"]["passes"]["has_belief_citing_message"] is True + + +def test_dashboard_ignores_legacy_trace_jsonl(tmp_path) -> None: + run_dir = tmp_path / "old-run" + run_dir.mkdir() + (run_dir / "trace.jsonl").write_text( + json.dumps( + { + "timestamp": 1.0, + "agent": "agent-00", + "event": "daemon_started", + "payload": {"hostname": "x0"}, + }, + ) + + 
"\n", + encoding="utf-8", + ) + + assert dashboard.events_payload(run_dir)["events"] == [] diff --git a/tests/test_academy_operator_console.py b/tests/test_academy_operator_console.py new file mode 100644 index 00000000..99e726e3 --- /dev/null +++ b/tests/test_academy_operator_console.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from chemgraph.academy.runtime import operator_console +from chemgraph.academy.runtime.profiles.system import SystemProfile + + +def _profile(tmp_path: Path) -> SystemProfile: + return SystemProfile( + name="test-system", + operator_host="user@example", + remote_root="/remote/root", + academy_repo_root="/remote/root/academy", + repo_root="/remote/root/ChemGraph", + run_root="/remote/root/runs", + relay_host_file="/remote/root/relay.host", + relay_port=18186, + venv_python="/remote/root/venv/bin/python", + redis_bin_dir="/remote/root/tools/redis/bin", + redis_port=6392, + redis_bind="0.0.0.0", + redis_protected_mode="no", + mpiexec="mpiexec", + pythonpath_entries=[str(tmp_path)], + no_proxy="127.0.0.1,localhost", + ) + + +def test_delete_existing_run_removes_remote_and_local(tmp_path, monkeypatch) -> None: + local_run = tmp_path / "mirror" / "run-001" + local_run.mkdir(parents=True) + (local_run / "status.json").write_text("{}\n", encoding="utf-8") + calls: list[list[str]] = [] + + monkeypatch.setattr( + operator_console, + "_run", + lambda command, **kwargs: calls.append(command), + ) + + operator_console._delete_existing_run( + profile=_profile(tmp_path), + host="user@example", + ssh_opts=["-o", "BatchMode=yes"], + run_id="run-001", + local_run_dir=local_run, + ) + + assert not local_run.exists() + assert calls + assert calls[0][:4] == ["ssh", "-o", "BatchMode=yes", "user@example"] + assert 'mv -- "$run_dir" "$trash_dir"' in calls[0][-1] + assert 'rm -rf -- "$trash_dir"' in calls[0][-1] + assert 'mkdir -p "$run_dir"' in calls[0][-1] + + +def 
test_delete_existing_run_rejects_unsafe_run_id(tmp_path) -> None: + with pytest.raises(RuntimeError, match="unsafe run id"): + operator_console._delete_existing_run( + profile=_profile(tmp_path), + host="user@example", + ssh_opts=[], + run_id="../bad", + local_run_dir=tmp_path / "mirror", + ) diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py new file mode 100644 index 00000000..6e5f6c60 --- /dev/null +++ b/tests/test_academy_reasoning_phase2.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import asyncio +import dataclasses +import json +from pathlib import Path + +import pytest + +from chemgraph.academy.core.agent import ChemGraphLogicalAgent +from chemgraph.academy.core.turn import ( + build_peer_status, + ChemGraphReasoningRoundEngine, +) +from chemgraph.academy.core.turn import ReasoningTurnResult +from chemgraph.academy.core.tools import ReasoningToolRuntimeState +from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.campaign import ResourceSpec +from chemgraph.academy.core.campaign import resolve_campaign_resources +from chemgraph.academy.core.lm import LLMSettings +from chemgraph.academy.core.prompt import PromptProfile +from chemgraph.academy.core.prompt import PromptStateLimits + + +def _agent_spec() -> ChemGraphAgentSpec: + return ChemGraphAgentSpec( + name="agent-a", + role="Worker", + mission="Use explicit tools only.", + allowed_peers=(), + tools=(), + ) + + +def _agent_spec_with_peer() -> ChemGraphAgentSpec: + return ChemGraphAgentSpec( + name="agent-a", + role="Worker", + mission="Use explicit tools only.", + allowed_peers=("agent-b",), + tools=(), + ) + + +def _campaign(spec: ChemGraphAgentSpec) -> ChemGraphCampaign: + return ChemGraphCampaign( + run_id="campaign-1", + user_task="Rank staged candidates.", + 
class _SlowPeerHandle:
    """Fake peer handle whose ``action`` takes 0.1 s before recording a call.

    ``calls`` accumulates ``(name, message)`` tuples in arrival order and
    ``delivered`` is set once a call has been recorded, so a test can wait
    for the delayed delivery.
    """

    def __init__(self) -> None:
        # Start with nothing recorded and the delivery event unset.
        self.calls: list[tuple[str, dict]] = []
        self.delivered = asyncio.Event()

    async def action(self, name: str, message: dict) -> None:
        """Sleep briefly, then record the call and signal ``delivered``."""
        # The pause comes first so a caller that awaited this coroutine
        # directly would be held up for the whole delay.
        await asyncio.sleep(0.1)
        record = (name, message)
        self.calls += [record]
        self.delivered.set()
@pytest.mark.asyncio
async def test_send_message_does_not_block_on_busy_peer(tmp_path) -> None:
    """``send_message`` returns as ``queued`` before a slow peer has handled it.

    The peer handle needs 0.1 s per action while the tool call is given only
    0.05 s, so the call can only succeed if delivery happens in the background.
    """
    slow_peer = _SlowPeerHandle()
    recorded: list[tuple[str, dict]] = []
    queued: list[dict] = []

    def _record(event, payload):
        recorded.append((event, payload))

    def _event_names() -> list[str]:
        return [name for name, _ in recorded]

    built = await build_chemgraph_reasoning_tools(
        spec=_agent_spec_with_peer(),
        run_dir=tmp_path,
        tool_invoker=object(),
        peer_names=("agent-b",),
        peer_handles={"agent-b": slow_peer},
        outbox=queued,
        tool_results=[],
        get_round_index=lambda: 1,
        set_final_result=lambda result: None,
        trace=_record,
        runtime_state=ReasoningToolRuntimeState(),
    )
    send_message = next(tool for tool in built if tool.name == "send_message")

    arguments = {
        "recipient": "agent-b",
        "tldr": "short summary",
        "content": "full message",
        "artifact_refs": [],
        "tool_result_ids": [],
        "reason": "peer needs this evidence",
        "confidence": 0.8,
    }
    # The timeout is shorter than the peer's delay on purpose.
    result = await asyncio.wait_for(send_message.ainvoke(arguments), timeout=0.05)

    assert result["status"] == "sent"
    assert result["delivery"] == "queued"
    assert len(queued) == 1
    assert _event_names() == ["message_sent"]

    # Wait for the background delivery, then yield once so any follow-up
    # bookkeeping scheduled on the loop can run.
    await asyncio.wait_for(slow_peer.delivered.wait(), timeout=1)
    await asyncio.sleep(0)

    first_action_name = slow_peer.calls[0][0]
    assert first_action_name == "receive_message"
    assert _event_names() == [
        "message_sent",
        "message_delivered",
    ]
def test_build_peer_status_uses_inflight_tool_events(tmp_path) -> None:
    """A peer with a started tool call and no later event is reported busy."""
    ack_tldr = "Starting requested MACE energy run"

    # Status snapshot for agent-b as found under <run_dir>/agent_status/.
    status_dir = tmp_path / "agent_status"
    status_dir.mkdir()
    snapshot = {
        "agent_name": "agent-b",
        "round": 3,
        "finished": False,
        "last_error": None,
        "status_updated_at": 100.0,
        "recent_outbox": [
            {
                "message_id": "msg-ack",
                "tldr": ack_tldr,
            },
        ],
        "belief": {
            "hypothesis": None,
            "confidence": 0.0,
        },
    }
    status_dir.joinpath("agent-b.json").write_text(
        json.dumps(snapshot) + "\n",
        encoding="utf-8",
    )

    # Event log: an acknowledgement followed by a tool call that only started.
    message_event = {
        "timestamp": 101.0,
        "event": "message_sent",
        "agent_id": "agent-b",
        "payload": {
            "message_id": "msg-ack",
            "tldr": ack_tldr,
        },
    }
    tool_event = {
        "timestamp": 102.0,
        "event": "tool_call_started",
        "agent_id": "agent-b",
        "payload": {
            "tool_name": "run_mace_ensemble",
            "tool_result_id": "tool-1",
            "tool_call_id": "call-1",
        },
    }
    event_lines = [json.dumps(entry) + "\n" for entry in (message_event, tool_event)]
    (tmp_path / "events.jsonl").write_text("".join(event_lines), encoding="utf-8")

    status = build_peer_status(run_dir=tmp_path, peer_names=("agent-b",))

    peer = status["agent-b"]
    expected_activity = {
        "type": "tool_call",
        "tool_name": "run_mace_ensemble",
        "tool_result_id": "tool-1",
        "tool_call_id": "call-1",
        "started_at": 102.0,
    }
    assert peer["state"] == "busy"
    assert peer["last_outbox_tldr"] == ack_tldr
    assert peer["current_activity"] == expected_activity
async def _build_tools(tmp_path):
    """Build the reasoning tools for ``agent-a`` together with their fixtures.

    Returns a dict holding the tools keyed by name (``tools``) plus the
    objects handed to the builder -- ``runtime_state``, ``traces``,
    ``outbox`` and ``peer_handle`` -- so tests can assert on side effects.
    """
    state = ReasoningToolRuntimeState()
    recorded: list[tuple[str, dict[str, Any]]] = []
    queued: list[dict[str, Any]] = []
    fake_peer = _FakePeerHandle()

    def _record(event, payload):
        recorded.append((event, payload))

    built = await build_chemgraph_reasoning_tools(
        spec=_agent_spec(),
        run_dir=tmp_path,
        tool_invoker=object(),  # unused when spec.tools is empty
        peer_names=("agent-b",),
        peer_handles={"agent-b": fake_peer},
        outbox=queued,
        tool_results=[],
        get_round_index=lambda: 1,
        set_final_result=lambda result: None,
        trace=_record,
        runtime_state=state,
    )

    by_name = {}
    for tool in built:
        by_name[tool.name] = tool

    return dict(
        tools=by_name,
        runtime_state=state,
        traces=recorded,
        outbox=queued,
        peer_handle=fake_peer,
    )
"content": "content", + "artifact_refs": [], + "tool_result_ids": [], + "reason": "exercise validation", + "confidence": 1.5, + } + ) + + assert result["status"] == "error" + assert result["error_type"] == "invalid_tool_arguments" + assert result["errors"][0]["field"] == "confidence" + assert env["runtime_state"].action_tool_names == ["send_message"] + assert env["outbox"] == [] + assert env["peer_handle"].calls == [] + assert env["traces"] == [ + ( + "tool_call_failed", + { + "tool_name": "send_message", + "status": "failed", + "error": "invalid_tool_arguments", + "error_type": "invalid_tool_arguments", + "errors": result["errors"], + }, + ) + ] + + +@pytest.mark.asyncio +async def test_send_message_disallowed_recipient_does_not_deliver(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["send_message"].ainvoke( + { + "recipient": "not-a-peer", + "tldr": "wrong peer", + "content": "content", + "artifact_refs": [], + "tool_result_ids": [], + "reason": "exercise validation", + "confidence": 0.8, + } + ) + + assert result == { + "status": "error", + "tool_name": "send_message", + "error": "disallowed_recipient", + "error_type": "disallowed_recipient", + "recipient": "not-a-peer", + "allowed_peers": ["agent-b"], + } + assert env["outbox"] == [] + assert env["peer_handle"].calls == [] + assert env["traces"][0][0] == "tool_call_failed" + assert env["traces"][0][1]["error_type"] == "disallowed_recipient" + + +@pytest.mark.asyncio +async def test_ask_peer_requires_tldr(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["ask_peer"].ainvoke( + { + "recipient": "agent-b", + "tldr": "", + "question": "What happened?", + "reason": "need a peer check", + } + ) + + assert result["status"] == "error" + assert result["error_type"] == "invalid_tool_arguments" + assert result["errors"][0]["field"] == "tldr" + assert env["outbox"] == [] + assert env["peer_handle"].calls == [] + + +@pytest.mark.asyncio +async def 
test_valid_send_message_still_delivers(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["send_message"].ainvoke( + { + "recipient": "agent-b", + "tldr": "candidate ready", + "content": "Candidate C1 has a usable artifact.", + "artifact_refs": ["artifacts/c1.xyz"], + "tool_result_ids": ["tool-1"], + "reason": "peer needs the result", + "confidence": 0.9, + } + ) + + assert result["status"] == "sent" + assert result["recipient"] == "agent-b" + assert len(env["outbox"]) == 1 + assert env["peer_handle"].calls[0][0] == "receive_message" + assert env["peer_handle"].calls[0][1]["message_id"] == result["message_id"] + assert [event for event, _ in env["traces"]] == [ + "message_sent", + "message_delivered", + ] + assert { + json.loads(line)["message_id"] + for line in tmp_path.joinpath("messages.jsonl").read_text().splitlines() + } == {result["message_id"]} From 778e94c018cbfcc34c505e820aaae1d5914bddfb Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 14:12:33 -0500 Subject: [PATCH 050/119] feat(academy): add MACE screening campaign example --- pyproject.toml | 8 + src/chemgraph/academy/examples/__init__.py | 76 ++++++++++ .../campaign.jsonc | 117 +++++++++++++++ .../data/mace_screening_20_smiles.json | 106 +++++++++++++ .../lm_config.template.json | 12 ++ .../prompt_profiles/default.json | 12 ++ tests/test_academy_campaign.py | 142 ++++++++++++++++++ 7 files changed, 473 insertions(+) create mode 100644 src/chemgraph/academy/examples/__init__.py create mode 100644 src/chemgraph/academy/examples/example-002-mace-ensemble-screening/campaign.jsonc create mode 100644 src/chemgraph/academy/examples/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json create mode 100644 src/chemgraph/academy/examples/example-002-mace-ensemble-screening/lm_config.template.json create mode 100644 src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json create mode 100644 
def _resolve_builtin(
    path_or_name: str | Path,
    builtins: dict[str, str],
) -> Path:
    """Resolve a filesystem path or a built-in resource name to a ``Path``.

    Resolution order:

    1. An existing regular file wins and is returned fully resolved, so a
       user can always override a built-in by pointing at a real file.
    2. Otherwise, if the value is a key of ``builtins``, the packaged
       resource it maps to is returned.
    3. Otherwise any other existing path (for example a directory) is
       returned resolved, and a non-existent path is returned unchanged so
       the caller can report the missing file itself.

    Args:
        path_or_name: A path to a file, or the name of a built-in resource.
        builtins: Mapping of built-in names to paths relative to this package.

    Returns:
        The resolved location; it is not guaranteed to exist.
    """
    value = str(path_or_name)
    path = Path(value)

    # Only a regular file may take precedence over a built-in. Checking
    # ``exists()`` here would let a directory that merely shares the
    # built-in's name (such as a run directory named after the campaign)
    # shadow the packaged resource.
    if path.is_file():
        return path.resolve()

    relative = builtins.get(value)
    if relative is not None:
        return Path(str(resources.files(__package__).joinpath(relative)))

    # Not a built-in: existing non-file paths are still resolved, and
    # unknown values are passed through untouched.
    if path.exists():
        return path.resolve()
    return path
def campaign_launch_defaults(campaign: str) -> CampaignLaunchDefaults:
    """Return the launch defaults registered for a built-in campaign.

    Raises:
        KeyError: If ``campaign`` has no entry in
            ``BUILTIN_CAMPAIGN_LAUNCH_DEFAULTS``. The original lookup error
            is chained as the cause.
    """
    try:
        defaults = BUILTIN_CAMPAIGN_LAUNCH_DEFAULTS[campaign]
    except KeyError as lookup_error:
        # Re-raise with a message that names the campaign, not just the key.
        message = f'No built-in launch defaults for campaign {campaign!r}'
        raise KeyError(message) from lookup_error
    else:
        return defaults
+ "candidate_dataset": { + "kind": "json", + "path": "data/mace_screening_20_smiles.json", + "scope": "campaign_file", + "description": "The full input candidate list. Coordinator-agent may inspect it and delegate records by peer message.", + "expose_content": true + }, + "structure_output_directory": { + "kind": "directory", + "path": "academy_mace_structures", + "scope": "shared_run", + "description": "Shared run directory where generated XYZ coordinate files should be written." + }, + "mace_output_result_file": { + "kind": "file", + "path": "academy_mace_outputs/mace_results.json", + "scope": "shared_run", + "description": "Shared run file requested for the MACE ensemble result summary." + }, + "mace_model_file": { + "kind": "file", + "path": "models/mace-mpa-0-medium.model", + "scope": "campaign_file", + "description": "Local MACE model file shipped with this campaign example." + } + }, + "agents": [ + { + "name": "coordinator-agent", + "role": "MACEReadinessCoordinatorAgent", + "mission": "Coordinate the campaign from the bootstrap task. Send odd-numbered MOL candidates to structure-agent-a and even-numbered MOL candidates to structure-agent-b, including candidate_id, label, SMILES, and output_file. After structure evidence returns, ask mace-agent to run one MACE energy calculation on CPU using the provided structure directory, output result path, and local model file resource, then ask assessment-agent for readiness/ranking evidence before submitting the final result.", + "allowed_peers": [ + "structure-agent-a", + "structure-agent-b", + "mace-agent", + "assessment-agent" + ], + "tools": [], + "resources": [ + "candidate_dataset", + "structure_output_directory", + "mace_output_result_file", + "mace_model_file" + ] + }, + { + "name": "structure-agent-a", + "role": "MolecularStructureWorkerAgent", + "mission": "Process only candidates assigned by coordinator-agent. 
Generate XYZ coordinate files, then report concise artifact evidence and failures back to coordinator-agent.", + "allowed_peers": ["coordinator-agent"], + "tools": ["smiles_to_coordinate_file"], + "resources": [] + }, + { + "name": "structure-agent-b", + "role": "MolecularStructureWorkerAgent", + "mission": "Process only candidates assigned by coordinator-agent. Generate XYZ coordinate files, then report concise artifact evidence and failures back to coordinator-agent.", + "allowed_peers": ["coordinator-agent"], + "tools": ["smiles_to_coordinate_file"], + "resources": [] + }, + { + "name": "mace-agent", + "role": "MACEEnsembleAgent", + "mission": "Run MACE only after a concrete request from coordinator-agent. Report started, completed, partial, or failed evidence back to coordinator-agent, including output paths and tool_result_ids; pending work is not a failure.", + "allowed_peers": ["coordinator-agent"], + "tools": ["run_mace_ensemble", "inspect_json"], + "resources": ["mace_model_file"] + }, + { + "name": "assessment-agent", + "role": "ScreeningAssessmentAgent", + "mission": "Assess evidence received from coordinator-agent. Summarize structure coverage, MACE coverage, failures, ranking readiness, and pending work without treating pending MACE work as failure.", + "allowed_peers": ["coordinator-agent"], + "tools": ["inspect_json"], + "resources": [] + } + ], + "tools": [ + // Tool fields: + // module: Python module containing the ChemGraph FastMCP object. + // tool: concrete tool name exposed by that module. + { + "name": "smiles_to_coordinate_file", + "module": "chemgraph.mcp.mcp_tools", + "tool": "smiles_to_coordinate_file", + "description": "Convert one SMILES string to a generated XYZ coordinate file." + }, + { + "name": "run_mace_ensemble", + "module": "chemgraph.mcp.mace_mcp_hpc", + "tool": "run_mace_ensemble", + "description": "Run MACE calculations over every generated structure in a directory." 
+ }, + { + "name": "inspect_json", + "module": "chemgraph.mcp.hpc_misc_mcp", + "tool": "inspect_json", + "description": "Inspect a JSON file, output directory, or missing expected JSON path and return compact summaries." + } + ] +} diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json new file mode 100644 index 00000000..90bce655 --- /dev/null +++ b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json @@ -0,0 +1,106 @@ +{ + "dataset_id": "mace-screening-20-smiles-v1", + "description": "Twenty small-molecule SMILES for a ChemGraph-native Academy MACE ensemble screening demo.", + "candidates": [ + { + "candidate_id": "MOL-001", + "label": "water", + "smiles": "O" + }, + { + "candidate_id": "MOL-002", + "label": "methane", + "smiles": "C" + }, + { + "candidate_id": "MOL-003", + "label": "ammonia", + "smiles": "N" + }, + { + "candidate_id": "MOL-004", + "label": "carbon_dioxide", + "smiles": "O=C=O" + }, + { + "candidate_id": "MOL-005", + "label": "methanol", + "smiles": "CO" + }, + { + "candidate_id": "MOL-006", + "label": "ethanol", + "smiles": "CCO" + }, + { + "candidate_id": "MOL-007", + "label": "acetone", + "smiles": "CC(=O)C" + }, + { + "candidate_id": "MOL-008", + "label": "acetic_acid", + "smiles": "CC(=O)O" + }, + { + "candidate_id": "MOL-009", + "label": "benzene", + "smiles": "c1ccccc1" + }, + { + "candidate_id": "MOL-010", + "label": "toluene", + "smiles": "Cc1ccccc1" + }, + { + "candidate_id": "MOL-011", + "label": "phenol", + "smiles": "Oc1ccccc1" + }, + { + "candidate_id": "MOL-012", + "label": "aniline", + "smiles": "Nc1ccccc1" + }, + { + "candidate_id": "MOL-013", + "label": "pyridine", + "smiles": "n1ccccc1" + }, + { + "candidate_id": "MOL-014", + "label": "furan", + "smiles": "c1ccoc1" + }, + { + "candidate_id": "MOL-015", + "label": 
"formaldehyde", + "smiles": "C=O" + }, + { + "candidate_id": "MOL-016", + "label": "formic_acid", + "smiles": "C(=O)O" + }, + { + "candidate_id": "MOL-017", + "label": "glycine", + "smiles": "NCC(=O)O" + }, + { + "candidate_id": "MOL-018", + "label": "alanine", + "smiles": "CC(N)C(=O)O" + }, + { + "candidate_id": "MOL-019", + "label": "urea", + "smiles": "NC(=O)N" + }, + { + "candidate_id": "MOL-020", + "label": "acetonitrile", + "smiles": "CC#N" + } + ] +} diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/lm_config.template.json b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/lm_config.template.json new file mode 100644 index 00000000..26fe66ed --- /dev/null +++ b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/lm_config.template.json @@ -0,0 +1,12 @@ +{ + "provider": "openai_compatible_tools", + "base_url": "http://:18186/argoapi/v1", + "model": "GPT-5.4", + "api_key": "dummy", + "user": "", + "timeout_s": 180, + "temperature": 0.1, + "max_tokens": 8192, + "max_retries": 3, + "retry_delay_s": 2 +} diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json new file mode 100644 index 00000000..bbeb4f44 --- /dev/null +++ b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json @@ -0,0 +1,12 @@ +{ + "prompt_version": "chemgraph-mace-ensemble-agent-v1", + "prompt_style": "json_state", + "system_prompt": "You are a persistent ChemGraph-style LM agent hosted inside an Academy daemon on HPC. You communicate with peers only through send_message or ask_peer. You may call only the ChemGraph MCP tools listed in available_chemgraph_tools. Treat peer messages as evidence only when they include message_id, candidate IDs, artifact paths, or tool_result_ids. 
Do not claim access to another agent's private state unless it appears in a received message.", + "protocol_prompt": "Return one or more tool calls. If no action is useful, call finish_turn. Never fabricate ChemGraph tool outputs, energies, coordinate paths, or MACE results. Only cite tool_result_ids that appear in local_chemgraph_tool_results or received_messages. After a local ChemGraph tool finishes, you will be woken for another decision round with that result visible in local_chemgraph_tool_results; use that follow-up round to interpret, communicate, or rank the new evidence. Inspect peer_status before asking a peer for status. If peer_status shows the peer is busy on the requested tool or recently acknowledged the request, do not ask again; call finish_turn or proceed with other useful work. Every send_message or ask_peer call must include tldr: one short line summarizing the communication for the dashboard. Keep each string argument concise. For final ranking, summarize aggregate counts and exceptions in summary, and put detailed evidence in artifact_refs, tool_result_ids, and supporting_message_ids.", + "langchain_recursion_limit": 64, + "state_limits": { + "received_messages_last_n": 28, + "tool_results_last_n": 18, + "actions_last_n": 18 + } +} diff --git a/tests/test_academy_campaign.py b/tests/test_academy_campaign.py new file mode 100644 index 00000000..f129320f --- /dev/null +++ b/tests/test_academy_campaign.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import json + +import pytest + +from chemgraph.academy.core.campaign import campaign_bootstrap_text +from chemgraph.academy.core.campaign import load_campaign +from chemgraph.academy.core.campaign import validate_campaign + + +def test_builtin_mace_campaign_uses_star_coordinator_without_routing_policy() -> None: + campaign = load_campaign("mace-ensemble-screening-20") + + validate_campaign(campaign, len(campaign.agents)) + + assert campaign.initial_agent == "coordinator-agent" + assert 
def test_removed_structured_orchestration_fields_are_rejected(tmp_path) -> None:
    """Loading a campaign that still has ``parameters``/``routing_policy`` fails."""
    lone_agent = {
        "name": "agent-a",
        "role": "Role",
        "mission": "Do the task.",
        "allowed_peers": [],
        "tools": [],
    }
    # Otherwise minimal campaign that still carries the two removed fields.
    stale_campaign = {
        "run_id": "stale",
        "user_task": "test",
        "prompt_profile": "prompt.json",
        "parameters": {"old": "field"},
        "routing_policy": {"type": "old"},
        "agents": [lone_agent],
        "tools": [],
    }
    campaign_path = tmp_path / "campaign.json"
    campaign_path.write_text(json.dumps(stale_campaign), encoding="utf-8")

    with pytest.raises(RuntimeError, match="removed structured orchestration"):
        load_campaign(campaign_path)
def test_resource_kind_and_scope_are_option_sets(tmp_path) -> None:
    """A resource with an unknown ``kind`` is rejected when the campaign loads.

    The resource also has an unknown ``scope``, but only the ``kind`` error
    message is asserted here.
    """
    bad_resource = {
        "kind": "blob",
        "path": "input.json",
        "scope": "somewhere",
    }
    lone_agent = {
        "name": "agent-a",
        "role": "Role",
        "mission": "Do the task.",
        "allowed_peers": [],
        "tools": [],
    }
    document = {
        "run_id": "bad-resource",
        "user_task": "test",
        "prompt_profile": "prompt.json",
        "resources": {"input": bad_resource},
        "agents": [lone_agent],
        "tools": [],
    }
    campaign_path = tmp_path / "campaign.json"
    campaign_path.write_text(json.dumps(document), encoding="utf-8")

    with pytest.raises(ValueError, match="resource kind must be one of"):
        load_campaign(campaign_path)
"chemgraph.academy.dashboard:main" +chemgraph-dashboard-run = "chemgraph.observability.local_dashboard_run:main" [tool.setuptools.packages.find] where = ["src/"] diff --git a/src/chemgraph/cli/main.py b/src/chemgraph/cli/main.py index 86a6184d..ce254e06 100644 --- a/src/chemgraph/cli/main.py +++ b/src/chemgraph/cli/main.py @@ -232,6 +232,83 @@ def create_argument_parser() -> argparse.ArgumentParser: # ---- "models" subcommand --------------------------------------------- subparsers.add_parser("models", help="List all available LLM models.") + # ---- "dashboard" subcommands ---------------------------------------- + dashboard_parser = subparsers.add_parser( + "dashboard", + help="Serve the ChemGraph dashboard for a run directory.", + ) + dashboard_parser.add_argument( + "dashboard_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.academy.dashboard.", + ) + + dashboard_run_parser = subparsers.add_parser( + "dashboard-run", + help="Run a local ChemGraph workflow and write dashboard artifacts.", + ) + dashboard_run_parser.add_argument( + "dashboard_run_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.observability.local_dashboard_run.", + ) + + # ---- "academy" subcommand ------------------------------------------- + academy_parser = subparsers.add_parser( + "academy", + help="Run and inspect Academy-backed ChemGraph agent campaigns.", + ) + academy_sub = academy_parser.add_subparsers(dest="academy_command") + + daemon_parser = academy_sub.add_parser( + "mpi-daemon", + help="Run one ChemGraph Academy agent daemon inside mpiexec.", + ) + daemon_parser.add_argument( + "daemon_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.academy.runtime.daemon.", + ) + + dashboard_parser = academy_sub.add_parser( + "dashboard", + help="Serve the ChemGraph Academy dashboard for a run directory.", + ) + dashboard_parser.add_argument( + "dashboard_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded 
to chemgraph.academy.dashboard.", + ) + + compute_parser = academy_sub.add_parser( + "run-compute", + help="Run a profile-backed ChemGraph Academy campaign in this allocation.", + ) + compute_parser.add_argument( + "compute_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.academy.runtime.compute_launcher.", + ) + + console_parser = academy_sub.add_parser( + "console", + help="Start the local operator console for a ChemGraph Academy run.", + ) + console_parser.add_argument( + "console_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.academy.runtime.operator_console.", + ) + + academy_sub.add_parser( + "campaigns", + help="List built-in ChemGraph Academy campaign specs.", + ) + academy_sub.add_parser( + "logical-agent-configs", + help="List built-in ChemGraph Academy logical-agent prompt configs.", + ) + # ---- Legacy fallback args ------------------------------------------- # Also add run args to the top-level parser so that # `chemgraph -q "..."` keeps working without a subcommand. 
@@ -461,6 +538,75 @@ def _handle_run(args: argparse.Namespace) -> None: console.print("[dim]Thank you for using ChemGraph CLI![/dim]") +def _strip_remainder_separator(args: list[str]) -> list[str]: + """Remove an optional argparse remainder separator.""" + if args and args[0] == "--": + return args[1:] + return args + + +def _run_module_main(module_name: str, argv: list[str]) -> None: + """Run a module-level main() with forwarded command-line arguments.""" + import importlib + + module = importlib.import_module(module_name) + old_argv = sys.argv + try: + sys.argv = [f"chemgraph {module_name.rsplit('.', 1)[-1]}", *argv] + code = module.main() + finally: + sys.argv = old_argv + if isinstance(code, int) and code: + sys.exit(code) + + +def _handle_academy(args: argparse.Namespace) -> None: + """Handle Academy-backed ChemGraph campaign commands.""" + command = getattr(args, "academy_command", None) + if command == "mpi-daemon": + _run_module_main( + "chemgraph.academy.runtime.daemon", + _strip_remainder_separator(args.daemon_args), + ) + return + if command == "dashboard": + _run_module_main( + "chemgraph.academy.dashboard", + _strip_remainder_separator(args.dashboard_args), + ) + return + if command == "run-compute": + from chemgraph.academy.runtime.compute_launcher import main as compute_main + + code = compute_main(_strip_remainder_separator(args.compute_args)) + if code: + sys.exit(code) + return + if command == "console": + _run_module_main( + "chemgraph.academy.runtime.operator_console", + _strip_remainder_separator(args.console_args), + ) + return + if command == "campaigns": + from chemgraph.academy.examples import list_builtin_campaigns + + for name in list_builtin_campaigns(): + console.print(name) + return + if command == "logical-agent-configs": + from chemgraph.academy.examples import list_builtin_logical_agent_configs + + for name in list_builtin_logical_agent_configs(): + console.print(name) + return + console.print( + "Usage: chemgraph academy " + 
"{mpi-daemon,run-compute,console,dashboard,campaigns," + "logical-agent-configs}.", + ) + + # --------------------------------------------------------------------------- # Main entry point # --------------------------------------------------------------------------- @@ -496,6 +642,21 @@ def main() -> None: elif args.command == "models": list_models() + elif args.command == "dashboard": + _run_module_main( + "chemgraph.academy.dashboard", + _strip_remainder_separator(args.dashboard_args), + ) + + elif args.command == "dashboard-run": + _run_module_main( + "chemgraph.observability.local_dashboard_run", + _strip_remainder_separator(args.dashboard_run_args), + ) + + elif args.command == "academy": + _handle_academy(args) + elif args.command == "run": _handle_run(args) diff --git a/src/chemgraph/observability/local_dashboard_run.py b/src/chemgraph/observability/local_dashboard_run.py new file mode 100644 index 00000000..eb0f4ef0 --- /dev/null +++ b/src/chemgraph/observability/local_dashboard_run.py @@ -0,0 +1,170 @@ +"""Run a traditional ChemGraph workflow and write dashboard artifacts.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import shutil +import threading +import traceback +from pathlib import Path + +from chemgraph.academy.core.lm import load_lm_config +from chemgraph.observability.workflow_runner import run_observed_chemgraph_workflow + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run a local traditional ChemGraph workflow and emit event artifacts " + "that can be visualized by the ChemGraph dashboard." 
+ ), + ) + parser.add_argument("--run-dir", required=True) + parser.add_argument("--query", required=True) + parser.add_argument("--workflow-type", default="single_agent") + parser.add_argument("--return-option", choices=["last_message", "state"], default="state") + parser.add_argument("--recursion-limit", type=int, default=50) + parser.add_argument("--lm-config") + parser.add_argument("--model-name") + parser.add_argument("--base-url") + parser.add_argument("--api-key") + parser.add_argument("--argo-user") + parser.add_argument("--serve", action="store_true") + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=8765) + parser.add_argument( + "--overwrite", + action="store_true", + help="Replace an existing local dashboard run directory.", + ) + parser.add_argument( + "--json-output", + action="store_true", + help="Print the full workflow result JSON to stdout.", + ) + return parser.parse_args() + + +def _prepare_run_dir(path: Path, *, overwrite: bool) -> None: + existing_artifacts = [ + path / "events.jsonl", + path / "status.json", + path / "manifest.json", + path / "result.json", + path / "chemgraph_workflows", + ] + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + return + if overwrite: + _clear_run_dir(path) + elif any(item.exists() for item in existing_artifacts): + raise RuntimeError( + f"Run directory already contains dashboard artifacts: {path}\n" + "Use a new --run-dir, run chemgraph-dashboard to view the " + "existing run, or pass --overwrite to replace it.", + ) + path.mkdir(parents=True, exist_ok=True) + + +def _clear_run_dir(path: Path) -> None: + for item in path.iterdir(): + if item.is_dir() and not item.is_symlink(): + shutil.rmtree(item) + else: + item.unlink() + + +async def _run(args: argparse.Namespace) -> dict: + model_name = args.model_name + base_url = args.base_url + api_key = args.api_key + argo_user = args.argo_user + if args.lm_config: + settings = 
load_lm_config(args.lm_config) + model_name = model_name or settings.model + base_url = base_url or settings.base_url + api_key = api_key or settings.api_key + argo_user = argo_user or settings.user + + return await run_observed_chemgraph_workflow( + query=args.query, + run_dir=Path(args.run_dir), + workflow_type=args.workflow_type, + model_name=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + return_option=args.return_option, + recursion_limit=args.recursion_limit, + write_run_files=True, + ) + + +def _print_result_summary(*, result: dict, run_dir: Path, json_output: bool) -> None: + result_path = run_dir / "result.json" + print( + "ChemGraph workflow completed.\n" + f" status: {result.get('status')}\n" + f" workflow: {result.get('workflow_type')}\n" + f" span: {result.get('span_id')}\n" + f" result: {result_path}", + flush=True, + ) + if json_output: + print(json.dumps(result, indent=2, default=str), flush=True) + + +def _run_and_report(args: argparse.Namespace, *, run_dir: Path) -> None: + try: + result = asyncio.run(_run(args)) + except Exception: # noqa: BLE001 - surface background workflow failures + print("ChemGraph workflow failed. 
See status.json/events.jsonl if present.", flush=True) + traceback.print_exc() + return + _print_result_summary( + result=result, + run_dir=run_dir, + json_output=args.json_output, + ) + + +def main() -> int: + args = parse_args() + run_dir = Path(args.run_dir).resolve() + _prepare_run_dir(run_dir, overwrite=args.overwrite) + args.run_dir = str(run_dir) + if args.serve: + from chemgraph.academy.dashboard import serve_dashboard + + thread = threading.Thread( + target=_run_and_report, + kwargs={"args": args, "run_dir": run_dir}, + name="chemgraph-dashboard-workflow", + daemon=True, + ) + thread.start() + return serve_dashboard( + run_dir=run_dir, + host=args.host, + port=args.port, + ) + + result = asyncio.run(_run(args)) + _print_result_summary( + result=result, + run_dir=run_dir, + json_output=args.json_output, + ) + print( + "Dashboard command: " + f"chemgraph-dashboard --run-dir {run_dir}", + flush=True, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 96241d79b01c12e98cdd6fdf95eaf1d769996c9c Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 14:19:26 -0500 Subject: [PATCH 052/119] feat(models): support OpenAI-compatible Argo user metadata --- src/chemgraph/agent/llm_agent.py | 34 ++++++++- src/chemgraph/models/loader.py | 5 +- src/chemgraph/models/openai.py | 93 +++++++++++++++++++----- tests/test_openai_model_normalization.py | 61 ++++++++++++++++ 4 files changed, 173 insertions(+), 20 deletions(-) create mode 100644 tests/test_openai_model_normalization.py diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 5f9f7759..27b2117f 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -86,6 +86,34 @@ def serialize_state(state): return str(state) +def _custom_openai_compatible_kwargs( + *, + model_name: str, + temperature: float, + base_url: str, + api_key: str, + max_tokens: int, + top_p: float, + frequency_penalty: float, + presence_penalty: float, + 
argo_user: str | None, +) -> dict: + kwargs = { + "model": model_name, + "temperature": temperature, + "base_url": base_url, + "api_key": api_key, + "max_tokens": max_tokens, + "top_p": top_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + } + user = argo_user or os.getenv("ARGO_USER") + if base_url and "argoapi" in base_url and user: + kwargs["model_kwargs"] = {"user": user} + return kwargs + + class ChemGraph: """A graph-based workflow for LLM-powered computational chemistry tasks. @@ -260,8 +288,8 @@ def __init__( ) from langchain_openai import ChatOpenAI - llm = ChatOpenAI( - model=model_name, + llm_kwargs = _custom_openai_compatible_kwargs( + model_name=model_name, temperature=temperature, base_url=vllm_base_url, api_key=vllm_api_key, @@ -269,7 +297,9 @@ def __init__( top_p=top_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + argo_user=argo_user, ) + llm = ChatOpenAI(**llm_kwargs) logger.info( f"Successfully initialized ChatOpenAI for model '{model_name}' at {vllm_base_url}" ) diff --git a/src/chemgraph/models/loader.py b/src/chemgraph/models/loader.py index 07583777..2e2968d2 100644 --- a/src/chemgraph/models/loader.py +++ b/src/chemgraph/models/loader.py @@ -63,6 +63,8 @@ def load_chat_model( "temperature": temperature, "base_url": base_url, } + if api_key is not None: + kwargs["api_key"] = api_key if argo_user is not None: kwargs["argo_user"] = argo_user return load_openai_model(**kwargs) @@ -87,5 +89,6 @@ def load_chat_model( else: raise ValueError( f"Model '{model_name}' not found in any supported model list. " - f"Use a model from: OpenAI, Anthropic, Gemini, groq:, argo:, ALCF, or Ollama." + "Use a model from: OpenAI, Anthropic, Gemini, groq:, " + "argo:, ALCF, or Ollama." 
) diff --git a/src/chemgraph/models/openai.py b/src/chemgraph/models/openai.py index f904da67..e6a4a35a 100644 --- a/src/chemgraph/models/openai.py +++ b/src/chemgraph/models/openai.py @@ -2,7 +2,10 @@ import os from getpass import getpass +from urllib.parse import urlparse + from langchain_openai import ChatOpenAI + from chemgraph.models.supported_models import ( ARGO_DEFAULT_BASE_URL, supported_openai_models, @@ -60,29 +63,46 @@ } +ARGO_LOCAL_OPENAI_MODEL_MAP = { + # argo-shim advertises GPT-5.4 with this casing. Lowercase gpt-5.4 is + # rejected by the upstream Argo API behind the shim. + "argo:gpt-5.4": "GPT-5.4", +} + + def _normalize_argo_model(model_name: str, base_url: str) -> str: """Normalize an ``argo:``-prefixed model name for the target endpoint. - * Argo API (base_url contains ``argoapi``): map to internal wire - names via ``ARGO_MODEL_MAP`` (e.g. ``argo:gpt-4o`` -> ``gpt4o``). - * Other endpoints (ArgoProxy, custom): strip the ``argo:`` prefix - and send the remainder as-is (e.g. ``argo:gpt-4o`` -> ``gpt-4o``). + * Hosted Argo API endpoints use internal wire names via + ``ARGO_MODEL_MAP`` (e.g. ``argo:gpt-4o`` -> ``gpt4o``). + * Argo shim, ArgoProxy, and custom OpenAI-compatible endpoints strip the + ``argo:`` prefix and keep the OpenAI-style name. 
""" if not model_name.startswith("argo:"): return model_name - if base_url and "argoapi" in base_url: - # Argo API endpoint -- use the wire-name map - normalized = ARGO_MODEL_MAP.get(model_name) - if normalized: - logger.info("Normalized Argo model '%s' -> '%s'", model_name, normalized) - return normalized - # Fallback: strip prefix and remove punctuation - fallback = model_name.removeprefix("argo:").replace("-", "").replace(".", "") + model_format = os.getenv("CHEMGRAPH_ARGO_MODEL_FORMAT", "").lower() + if model_format == "shim": + return _normalize_argo_local_openai_model(model_name) + if model_format in {"openai", "openai-compatible"}: + stripped = model_name.removeprefix("argo:") + logger.info("Stripped argo: prefix '%s' -> '%s'", model_name, stripped) + return stripped + if model_format in {"wire", "argo"}: + return _normalize_argo_wire_model(model_name) + + if _is_local_http_endpoint(base_url): + stripped = _normalize_argo_local_openai_model(model_name) logger.info( - "Normalized Argo model '%s' -> '%s' (fallback)", model_name, fallback + "Using OpenAI-style Argo model for local endpoint '%s': '%s' -> '%s'", + base_url, + model_name, + stripped, ) - return fallback + return stripped + + if base_url and "argoapi" in base_url: + return _normalize_argo_wire_model(model_name) else: # Non-Argo-API endpoint -- strip prefix only stripped = model_name.removeprefix("argo:") @@ -90,6 +110,41 @@ def _normalize_argo_model(model_name: str, base_url: str) -> str: return stripped +def _normalize_argo_local_openai_model(model_name: str) -> str: + """Return the model name expected by local OpenAI-compatible Argo shims.""" + return ARGO_LOCAL_OPENAI_MODEL_MAP.get( + model_name, + model_name.removeprefix("argo:"), + ) + + +def _normalize_argo_wire_model(model_name: str) -> str: + """Return the hosted-Argo wire model for an ``argo:`` model name.""" + normalized = ARGO_MODEL_MAP.get(model_name) + if normalized: + logger.info("Normalized Argo model '%s' -> '%s'", model_name, 
normalized) + return normalized + + fallback = model_name.removeprefix("argo:").replace("-", "").replace(".", "") + logger.info( + "Normalized Argo model '%s' -> '%s' (fallback)", model_name, fallback + ) + return fallback + + +def _is_local_http_endpoint(base_url: str | None) -> bool: + """Return True for local HTTP endpoints such as ``argo-shim``.""" + if not base_url: + return False + parsed = urlparse(base_url) + return parsed.scheme == "http" and parsed.hostname in { + "localhost", + "127.0.0.1", + "::1", + "0.0.0.0", + } + + def load_openai_model( model_name: str, temperature: float, @@ -161,9 +216,13 @@ def load_openai_model( api_key = getpass("OpenAI API key: ") os.environ["OPENAI_API_KEY"] = api_key - if model_name not in supported_openai_models and model_name not in supported_argo_models: + if ( + model_name not in supported_openai_models + and model_name not in supported_argo_models + ): raise ValueError( - f"Unsupported model '{model_name}'. Supported models are: {supported_openai_models}." + f"Unsupported model '{model_name}'. " + f"Supported models are: {supported_openai_models}." ) is_argo_endpoint = bool(base_url and "argoapi" in base_url) @@ -202,7 +261,7 @@ def load_openai_model( api_key=api_key, max_tokens=6000, ) - # No guarantee that api_key is valid, authentication happens only during invocation + # Authentication happens only during invocation. 
logger.info(f"Requested model: {model_name}") logger.info("OpenAI model loaded successfully") return llm diff --git a/tests/test_openai_model_normalization.py b/tests/test_openai_model_normalization.py new file mode 100644 index 00000000..0ff55624 --- /dev/null +++ b/tests/test_openai_model_normalization.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from chemgraph.models.openai import _normalize_argo_model + + +def test_local_argo_shim_keeps_openai_style_model_name(monkeypatch): + monkeypatch.delenv("CHEMGRAPH_ARGO_MODEL_FORMAT", raising=False) + + assert ( + _normalize_argo_model( + "argo:gpt-4o-mini", + "http://127.0.0.1:18085/argoapi/v1", + ) + == "gpt-4o-mini" + ) + + +def test_local_argo_shim_uses_advertised_gpt54_model_name(monkeypatch): + monkeypatch.delenv("CHEMGRAPH_ARGO_MODEL_FORMAT", raising=False) + + assert ( + _normalize_argo_model( + "argo:gpt-5.4", + "http://127.0.0.1:18085/argoapi/v1", + ) + == "GPT-5.4" + ) + + +def test_hosted_argo_endpoint_uses_wire_model_name(monkeypatch): + monkeypatch.delenv("CHEMGRAPH_ARGO_MODEL_FORMAT", raising=False) + + assert ( + _normalize_argo_model( + "argo:gpt-4o-mini", + "https://apps.inside.anl.gov/argoapi/v1", + ) + == "gpt4omini" + ) + + +def test_argo_model_format_env_override(monkeypatch): + monkeypatch.setenv("CHEMGRAPH_ARGO_MODEL_FORMAT", "openai") + assert ( + _normalize_argo_model( + "argo:gpt-4o-mini", + "https://apps.inside.anl.gov/argoapi/v1", + ) + == "gpt-4o-mini" + ) + + +def test_argo_model_format_shim_override_uses_local_alias(monkeypatch): + monkeypatch.setenv("CHEMGRAPH_ARGO_MODEL_FORMAT", "shim") + assert ( + _normalize_argo_model( + "argo:gpt-5.4", + "https://apps.inside.anl.gov/argoapi/v1", + ) + == "GPT-5.4" + ) From f4992f6f1e65643c2469fb0aad0f396df833627d Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 14:21:54 -0500 Subject: [PATCH 053/119] chore(academy): add redis optional dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git 
a/pyproject.toml b/pyproject.toml index efb519a0..70cdc074 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ globus_compute = [ ] academy = [ "academy-py", + "redis", ] xanes = [ "mp-api; python_version >= '3.11'", From d8a7868efb14f1b3aa4c87491cb9d7eaa1be227e Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 15:07:33 -0500 Subject: [PATCH 054/119] refactor(academy): rename operator console dashboard launcher --- pyproject.toml | 2 +- .../academy/runtime/compute_launcher.py | 14 +++--- ...rator_console.py => dashboard_launcher.py} | 43 ++++++++++--------- .../runtime/profiles/aurora.template.json | 2 +- .../runtime/profiles/polaris.template.json | 2 +- .../academy/runtime/profiles/system.py | 2 +- src/chemgraph/cli/main.py | 32 ++++---------- ....py => test_academy_dashboard_launcher.py} | 10 ++--- 8 files changed, 46 insertions(+), 61 deletions(-) rename src/chemgraph/academy/runtime/{operator_console.py => dashboard_launcher.py} (95%) rename tests/{test_academy_operator_console.py => test_academy_dashboard_launcher.py} (90%) diff --git a/pyproject.toml b/pyproject.toml index 70cdc074..09a30600 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ rag = [ chemgraph = "chemgraph.cli:main" chemgraph-eval = "chemgraph.eval.cli:main" chemgraph-academy-run = "chemgraph.academy.runtime.compute_launcher:main" -chemgraph-academy-console = "chemgraph.academy.runtime.operator_console:main" +chemgraph-academy-dashboard = "chemgraph.academy.runtime.dashboard_launcher:main" chemgraph-dashboard = "chemgraph.academy.dashboard:main" chemgraph-dashboard-run = "chemgraph.observability.local_dashboard_run:main" diff --git a/src/chemgraph/academy/runtime/compute_launcher.py b/src/chemgraph/academy/runtime/compute_launcher.py index 959ccb46..d437331e 100644 --- a/src/chemgraph/academy/runtime/compute_launcher.py +++ b/src/chemgraph/academy/runtime/compute_launcher.py @@ -20,7 +20,7 @@ from chemgraph.academy.runtime.profiles.system import 
SystemProfile -OPERATOR_METADATA_FILE = "operator_metadata.json" +DASHBOARD_METADATA_FILE = "dashboard_metadata.json" @dataclasses.dataclass(frozen=True) @@ -99,8 +99,8 @@ def _prepare_environment(profile: SystemProfile) -> None: os.environ["NO_PROXY"] = profile.no_proxy -def _load_operator_metadata(run_dir: Path) -> dict[str, Any]: - path = run_dir / OPERATOR_METADATA_FILE +def _load_dashboard_metadata(run_dir: Path) -> dict[str, Any]: + path = run_dir / DASHBOARD_METADATA_FILE if not path.exists(): return {} data = json.loads(path.read_text(encoding="utf-8")) @@ -113,8 +113,8 @@ def _relay_host_from_profile(profile: SystemProfile) -> str: path = Path(profile.relay_host_file) if not path.exists(): raise RuntimeError( - "Could not determine UAN relay host. Start the Mac operator " - f"console first, or pass --lm-base-url. Missing: {path}", + "Could not determine UAN relay host. Start the Mac dashboard " + f"first, or pass --lm-base-url. Missing: {path}", ) host = path.read_text(encoding="utf-8").strip() if not host: @@ -216,14 +216,14 @@ def _run_token() -> str: def prepare_compute_launch(args: argparse.Namespace) -> AllocationPlan: - """Resolve a system profile and operator metadata into an allocation plan.""" + """Resolve a system profile and dashboard metadata into an allocation plan.""" profile = load_system_profile(args.system) _prepare_environment(profile) defaults = campaign_launch_defaults(args.campaign) run_dir = Path(args.run_dir or Path(profile.run_root) / args.run_id).resolve() run_dir.mkdir(parents=True, exist_ok=True) - metadata = _load_operator_metadata(run_dir) + metadata = _load_dashboard_metadata(run_dir) metadata_campaign = metadata.get("campaign") if metadata_campaign and metadata_campaign != args.campaign: raise RuntimeError( diff --git a/src/chemgraph/academy/runtime/operator_console.py b/src/chemgraph/academy/runtime/dashboard_launcher.py similarity index 95% rename from src/chemgraph/academy/runtime/operator_console.py rename to 
src/chemgraph/academy/runtime/dashboard_launcher.py index 48b46c94..a9626d28 100644 --- a/src/chemgraph/academy/runtime/operator_console.py +++ b/src/chemgraph/academy/runtime/dashboard_launcher.py @@ -22,8 +22,9 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( + prog="chemgraph academy dashboard", description=( - "Start the local operator console for a ChemGraph Academy run. " + "Start the local dashboard for a ChemGraph Academy run. " "This prepares remote run metadata, starts the local dashboard, " "and optionally starts the temporary Mac-to-UAN Argo relay." ), @@ -53,7 +54,7 @@ def parse_args() -> argparse.Namespace: "--lm-base-url", help="Required for --lm-connect direct. Overrides generated relay URL.", ) - parser.add_argument("--operator-host", help="SSH target for the login/UAN host.") + parser.add_argument("--remote-host", help="SSH target for the login/UAN host.") parser.add_argument("--ssh-control-path") parser.add_argument("--keep-ssh-master", action="store_true") parser.add_argument("--local-argo-host", default="127.0.0.1") @@ -77,7 +78,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--no-dashboard", action="store_true", - help="Prepare operator metadata and return without serving dashboard.", + help="Prepare dashboard metadata and return without serving dashboard.", ) parser.add_argument( "--overwrite-run", @@ -419,7 +420,7 @@ def _wait_for_relay( raise RuntimeError("Relay readiness timed out. 
Local relay log:\n" + detail) -def _write_operator_metadata( +def _write_dashboard_metadata( *, profile: SystemProfile, host: str, @@ -434,12 +435,12 @@ def _write_operator_metadata( remote_run_dir = f"{profile.run_root}/{run_id}" payload: dict[str, Any] = { "created_at": time.time(), - "created_by": "chemgraph-academy-console", + "created_by": "chemgraph-academy-dashboard", "run_id": run_id, "system": profile.name, "campaign": campaign, "remote_run_dir": remote_run_dir, - "operator_host": host, + "remote_host": host, "lm_connect": lm_connect, "lm_base_url": lm_base_url, "workspace_root": profile.remote_root, @@ -452,12 +453,12 @@ def _write_operator_metadata( payload["relay_port"] = relay_port metadata = json.dumps(payload, indent=2) + "\n" - remote_path = f"{remote_run_dir}/operator_metadata.json" + remote_path = f"{remote_run_dir}/dashboard_metadata.json" remote_command = ( f"mkdir -p {shlex.quote(remote_run_dir)} && " f"cat > {shlex.quote(remote_path)}" ) - _log(f"Writing run metadata: {host}:{remote_run_dir}/operator_metadata.json") + _log(f"Writing run metadata: {host}:{remote_run_dir}/dashboard_metadata.json") _run( ["ssh", *ssh_opts, host, remote_command], input_text=metadata, @@ -557,7 +558,7 @@ def _run_dashboard(*, local_run_dir: Path, host: str, port: int) -> int: old_argv = sys.argv try: sys.argv = [ - "chemgraph-academy-console dashboard", + "chemgraph-academy-dashboard serve", "--run-dir", str(local_run_dir), "--host", @@ -578,7 +579,7 @@ def _print_compute_command( campaign: str, ) -> None: _log("") - _log("Operator console is ready.") + _log("Dashboard launcher is ready.") _log("") _log(f"On the {profile.name} compute node, use:") if profile.name == "polaris": @@ -629,7 +630,7 @@ def main() -> int: port=args.dashboard_port, ) - operator_host = args.operator_host or profile.operator_host + remote_host = args.remote_host or profile.remote_host control_path = ( args.ssh_control_path or str(Path.home() / f".ssh/{profile.name}-dashboard-%r@%h:%p") @@ 
-656,21 +657,21 @@ def main() -> int: raise RuntimeError("--lm-connect direct requires --lm-base-url") started_ssh_master = _start_ssh_master( - host=operator_host, + host=remote_host, control_path=control_path, ) ssh_opts = _ssh_options(control_path) if args.overwrite_run: _delete_existing_run( profile=profile, - host=operator_host, + host=remote_host, ssh_opts=ssh_opts, run_id=args.run_id, local_run_dir=local_run_dir, ) wrapper_path = _install_compute_wrapper( profile=profile, - host=operator_host, + host=remote_host, ssh_opts=ssh_opts, ) @@ -678,7 +679,7 @@ def main() -> int: if args.lm_connect == "mac-argo-relay": relay_process = _start_mac_argo_relay( profile=profile, - host=operator_host, + host=remote_host, ssh_opts=ssh_opts, local_argo_host=args.local_argo_host, local_argo_port=args.local_argo_port, @@ -689,7 +690,7 @@ def main() -> int: ) relay_host = _wait_for_relay( profile=profile, - host=operator_host, + host=remote_host, ssh_opts=ssh_opts, relay_port=relay_port, relay_process=relay_process, @@ -700,9 +701,9 @@ def main() -> int: lm_base_url = str(args.lm_base_url) _log(f"Compute-node LM URL: {lm_base_url}") - _write_operator_metadata( + _write_dashboard_metadata( profile=profile, - host=operator_host, + host=remote_host, ssh_opts=ssh_opts, run_id=args.run_id, campaign=args.campaign, @@ -713,10 +714,10 @@ def main() -> int: ) _log("Starting rsync mirror:") - _log(f" {operator_host}:{remote_run_dir}/") + _log(f" {remote_host}:{remote_run_dir}/") _log(f" {local_run_dir}/") _start_rsync_loop( - host=operator_host, + host=remote_host, control_path=control_path, remote_run_dir=remote_run_dir, local_run_dir=local_run_dir, @@ -752,7 +753,7 @@ def main() -> int: relay_process.kill() keep = args.keep_ssh_master or os.environ.get("CHEMGRAPH_ACADEMY_KEEP_SSH_MASTER") == "1" if started_ssh_master and not keep: - _stop_ssh_master(host=operator_host, control_path=control_path) + _stop_ssh_master(host=remote_host, control_path=control_path) if __name__ == "__main__": 
diff --git a/src/chemgraph/academy/runtime/profiles/aurora.template.json b/src/chemgraph/academy/runtime/profiles/aurora.template.json index c8469792..db59c939 100644 --- a/src/chemgraph/academy/runtime/profiles/aurora.template.json +++ b/src/chemgraph/academy/runtime/profiles/aurora.template.json @@ -1,6 +1,6 @@ { "name": "aurora", - "operator_host": "${ALCF_USER}@aurora.alcf.anl.gov", + "remote_host": "${ALCF_USER}@aurora.alcf.anl.gov", "remote_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}", "academy_repo_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}/academy", "repo_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph", diff --git a/src/chemgraph/academy/runtime/profiles/polaris.template.json b/src/chemgraph/academy/runtime/profiles/polaris.template.json index c7f54afb..0737fefc 100644 --- a/src/chemgraph/academy/runtime/profiles/polaris.template.json +++ b/src/chemgraph/academy/runtime/profiles/polaris.template.json @@ -1,6 +1,6 @@ { "name": "polaris", - "operator_host": "${ALCF_USER}@polaris.alcf.anl.gov", + "remote_host": "${ALCF_USER}@polaris.alcf.anl.gov", "remote_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}", "academy_repo_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/academy", "repo_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph", diff --git a/src/chemgraph/academy/runtime/profiles/system.py b/src/chemgraph/academy/runtime/profiles/system.py index a67f14c1..fcddb3c8 100644 --- a/src/chemgraph/academy/runtime/profiles/system.py +++ b/src/chemgraph/academy/runtime/profiles/system.py @@ -18,7 +18,7 @@ class SystemProfile(BaseModel): model_config = ConfigDict(extra="forbid") name: str - operator_host: str + remote_host: str remote_root: str academy_repo_root: str repo_root: str diff --git a/src/chemgraph/cli/main.py b/src/chemgraph/cli/main.py index ce254e06..2ead5086 100644 --- a/src/chemgraph/cli/main.py +++ b/src/chemgraph/cli/main.py @@ -270,16 +270,6 @@ def create_argument_parser() -> argparse.ArgumentParser: help="Arguments forwarded to 
chemgraph.academy.runtime.daemon.", ) - dashboard_parser = academy_sub.add_parser( - "dashboard", - help="Serve the ChemGraph Academy dashboard for a run directory.", - ) - dashboard_parser.add_argument( - "dashboard_args", - nargs=argparse.REMAINDER, - help="Arguments forwarded to chemgraph.academy.dashboard.", - ) - compute_parser = academy_sub.add_parser( "run-compute", help="Run a profile-backed ChemGraph Academy campaign in this allocation.", @@ -290,14 +280,14 @@ def create_argument_parser() -> argparse.ArgumentParser: help="Arguments forwarded to chemgraph.academy.runtime.compute_launcher.", ) - console_parser = academy_sub.add_parser( - "console", - help="Start the local operator console for a ChemGraph Academy run.", + dashboard_parser = academy_sub.add_parser( + "dashboard", + help="Start the local dashboard launcher for a ChemGraph Academy run.", ) - console_parser.add_argument( - "console_args", + dashboard_parser.add_argument( + "dashboard_args", nargs=argparse.REMAINDER, - help="Arguments forwarded to chemgraph.academy.runtime.operator_console.", + help="Arguments forwarded to chemgraph.academy.runtime.dashboard_launcher.", ) academy_sub.add_parser( @@ -571,7 +561,7 @@ def _handle_academy(args: argparse.Namespace) -> None: return if command == "dashboard": _run_module_main( - "chemgraph.academy.dashboard", + "chemgraph.academy.runtime.dashboard_launcher", _strip_remainder_separator(args.dashboard_args), ) return @@ -582,12 +572,6 @@ def _handle_academy(args: argparse.Namespace) -> None: if code: sys.exit(code) return - if command == "console": - _run_module_main( - "chemgraph.academy.runtime.operator_console", - _strip_remainder_separator(args.console_args), - ) - return if command == "campaigns": from chemgraph.academy.examples import list_builtin_campaigns @@ -602,7 +586,7 @@ def _handle_academy(args: argparse.Namespace) -> None: return console.print( "Usage: chemgraph academy " - "{mpi-daemon,run-compute,console,dashboard,campaigns," + 
"{mpi-daemon,run-compute,dashboard,campaigns," "logical-agent-configs}.", ) diff --git a/tests/test_academy_operator_console.py b/tests/test_academy_dashboard_launcher.py similarity index 90% rename from tests/test_academy_operator_console.py rename to tests/test_academy_dashboard_launcher.py index 99e726e3..c4fcd857 100644 --- a/tests/test_academy_operator_console.py +++ b/tests/test_academy_dashboard_launcher.py @@ -4,14 +4,14 @@ import pytest -from chemgraph.academy.runtime import operator_console +from chemgraph.academy.runtime import dashboard_launcher from chemgraph.academy.runtime.profiles.system import SystemProfile def _profile(tmp_path: Path) -> SystemProfile: return SystemProfile( name="test-system", - operator_host="user@example", + remote_host="user@example", remote_root="/remote/root", academy_repo_root="/remote/root/academy", repo_root="/remote/root/ChemGraph", @@ -36,12 +36,12 @@ def test_delete_existing_run_removes_remote_and_local(tmp_path, monkeypatch) -> calls: list[list[str]] = [] monkeypatch.setattr( - operator_console, + dashboard_launcher, "_run", lambda command, **kwargs: calls.append(command), ) - operator_console._delete_existing_run( + dashboard_launcher._delete_existing_run( profile=_profile(tmp_path), host="user@example", ssh_opts=["-o", "BatchMode=yes"], @@ -59,7 +59,7 @@ def test_delete_existing_run_removes_remote_and_local(tmp_path, monkeypatch) -> def test_delete_existing_run_rejects_unsafe_run_id(tmp_path) -> None: with pytest.raises(RuntimeError, match="unsafe run id"): - operator_console._delete_existing_run( + dashboard_launcher._delete_existing_run( profile=_profile(tmp_path), host="user@example", ssh_opts=[], From 527acb9db2e28eea18606733d42efab64be653cc Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 9 Jun 2026 16:21:18 -0500 Subject: [PATCH 055/119] refactor(academy): merge peer request action into send_message --- src/chemgraph/academy/core/peer_protocol.py | 2 + src/chemgraph/academy/core/tools.py | 67 
++++--------------- src/chemgraph/academy/core/turn.py | 5 +- src/chemgraph/academy/dashboard/static/app.js | 2 +- .../prompt_profiles/default.json | 4 +- tests/test_academy_reasoning_phase2.py | 4 +- tests/test_tool_adapter_validation.py | 34 +++++++++- 7 files changed, 54 insertions(+), 64 deletions(-) diff --git a/src/chemgraph/academy/core/peer_protocol.py b/src/chemgraph/academy/core/peer_protocol.py index 503efd80..67182203 100644 --- a/src/chemgraph/academy/core/peer_protocol.py +++ b/src/chemgraph/academy/core/peer_protocol.py @@ -29,6 +29,7 @@ def build_message( tldr: str | None = None, artifact_refs: list[str] | None = None, tool_result_ids: list[str] | None = None, + reply_requested: bool = False, reason: str | None = None, confidence: float | None = None, ) -> dict[str, Any]: @@ -40,6 +41,7 @@ def build_message( 'recipient': recipient, 'kind': kind, 'content': content, + 'reply_requested': reply_requested, 'artifact_refs': artifact_refs or [], 'tool_result_ids': tool_result_ids or [], } diff --git a/src/chemgraph/academy/core/tools.py b/src/chemgraph/academy/core/tools.py index 38a60e38..f22d3f4c 100644 --- a/src/chemgraph/academy/core/tools.py +++ b/src/chemgraph/academy/core/tools.py @@ -96,6 +96,13 @@ class SendMessageArgs(BaseModel): default_factory=list, description="JSON array of ChemGraph tool_result_id strings cited by this message.", ) + reply_requested: bool = Field( + default=False, + description=( + "Set true when this message asks the peer to reply or take a " + "specific follow-up action; false for one-way updates." 
+ ), + ) reason: str = Field( min_length=1, max_length=600, @@ -108,21 +115,6 @@ class SendMessageArgs(BaseModel): ) -class AskPeerArgs(BaseModel): - """Arguments for asking a peer a question.""" - - model_config = ConfigDict(extra="forbid") - - recipient: str = Field(min_length=1) - tldr: str = Field( - min_length=1, - max_length=160, - description="One-line user-visible summary for dashboard edge labels.", - ) - question: str = Field(min_length=1, max_length=900) - reason: str = Field(min_length=1, max_length=600) - - class SubmitResultArgs(BaseModel): """Arguments for submitting a logical agent's current result.""" @@ -230,14 +222,15 @@ async def _send_message_impl( content: str, artifact_refs: list[str], tool_result_ids: list[str], + reply_requested: bool, reason: str, confidence: float, - kind: str, ) -> dict[str, Any]: if recipient not in peer_names: raise ValueError( f"{spec.name} tried to message disallowed peer {recipient}", ) + kind = "question" if reply_requested else "message" message = build_message( sender=spec.name, recipient=recipient, @@ -247,6 +240,7 @@ async def _send_message_impl( tldr=tldr, artifact_refs=artifact_refs, tool_result_ids=tool_result_ids, + reply_requested=reply_requested, reason=reason, confidence=confidence, ) @@ -325,33 +319,9 @@ async def send_message(**kwargs: Any) -> dict[str, Any]: content=args.content, artifact_refs=args.artifact_refs, tool_result_ids=args.tool_result_ids, + reply_requested=args.reply_requested, reason=args.reason, confidence=args.confidence, - kind="message", - ) - - async def ask_peer(**kwargs: Any) -> dict[str, Any]: - runtime_state.record_action("ask_peer") - try: - args = AskPeerArgs.model_validate(kwargs) - except ValidationError as exc: - return _invalid_args_response("ask_peer", exc, trace) - if args.recipient not in peer_names: - return _disallowed_recipient_response( - "ask_peer", - args.recipient, - peer_names, - trace, - ) - return await _send_message_impl( - recipient=args.recipient, - 
tldr=args.tldr, - content=args.question, - artifact_refs=[], - tool_result_ids=[], - reason=args.reason, - confidence=0.0, - kind="question", ) async def submit_result(**kwargs: Any) -> dict[str, Any]: @@ -395,25 +365,14 @@ async def finish_turn(**kwargs: Any) -> dict[str, Any]: "Send tool-backed evidence, reasoning, or a request to one " "allowed peer. Always provide recipient, tldr, content, " "artifact_refs as an array of strings or [], tool_result_ids " - "as an array of strings or [], a non-empty reason, and numeric " + "as an array of strings or [], reply_requested as true when " + "the peer should respond, a non-empty reason, and numeric " "confidence from 0 to 1." ), args_schema=SendMessageArgs, handle_validation_error=_validation_error_handler("send_message"), metadata={"chemgraph_academy_tool_kind": "action_tool"}, ), - StructuredTool.from_function( - coroutine=ask_peer, - name="ask_peer", - description=( - "Ask an allowed peer for missing information or a tool result " - "needed for the molecule workflow. Always provide recipient, " - "tldr, question, and reason." 
- ), - args_schema=AskPeerArgs, - handle_validation_error=_validation_error_handler("ask_peer"), - metadata={"chemgraph_academy_tool_kind": "action_tool"}, - ), StructuredTool.from_function( coroutine=submit_result, name="submit_result", diff --git a/src/chemgraph/academy/core/turn.py b/src/chemgraph/academy/core/turn.py index 0731adb5..b6294993 100644 --- a/src/chemgraph/academy/core/turn.py +++ b/src/chemgraph/academy/core/turn.py @@ -517,12 +517,11 @@ def build_recent_actions( actions: list[dict[str, Any]] = [] for message in outbox[-limit:]: - kind = str(message.get("kind") or "message") - action_type = "ask_peer" if kind == "question" else "send_message" actions.append( { - "type": action_type, + "type": "send_message", "recipient": message.get("recipient"), + "reply_requested": bool(message.get("reply_requested")), "tldr": message.get("tldr") or _preview(message.get("content")), "message_id": message.get("message_id"), "timestamp": message.get("timestamp"), diff --git a/src/chemgraph/academy/dashboard/static/app.js b/src/chemgraph/academy/dashboard/static/app.js index a06f2b30..15dfdf14 100644 --- a/src/chemgraph/academy/dashboard/static/app.js +++ b/src/chemgraph/academy/dashboard/static/app.js @@ -25,7 +25,7 @@ let lastRenderedDetailIdentity = null; let lastEmbeddedWorkflowInspectorIdentity = null; const recentMessageWindow = 4; - const actionToolNames = new Set(['send_message', 'ask_peer', 'submit_result', 'finish_turn']); + const actionToolNames = new Set(['send_message', 'submit_result', 'finish_turn']); const renderedHtmlCache = new WeakMap(); const esc = (s) => String(s ?? 
'').replace(/[&<>"']/g, c => ({'&':'&','<':'<','>':'>','"':'"',"'":'''}[c])); diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json index bbeb4f44..cc15d48b 100644 --- a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json +++ b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json @@ -1,8 +1,8 @@ { "prompt_version": "chemgraph-mace-ensemble-agent-v1", "prompt_style": "json_state", - "system_prompt": "You are a persistent ChemGraph-style LM agent hosted inside an Academy daemon on HPC. You communicate with peers only through send_message or ask_peer. You may call only the ChemGraph MCP tools listed in available_chemgraph_tools. Treat peer messages as evidence only when they include message_id, candidate IDs, artifact paths, or tool_result_ids. Do not claim access to another agent's private state unless it appears in a received message.", - "protocol_prompt": "Return one or more tool calls. If no action is useful, call finish_turn. Never fabricate ChemGraph tool outputs, energies, coordinate paths, or MACE results. Only cite tool_result_ids that appear in local_chemgraph_tool_results or received_messages. After a local ChemGraph tool finishes, you will be woken for another decision round with that result visible in local_chemgraph_tool_results; use that follow-up round to interpret, communicate, or rank the new evidence. Inspect peer_status before asking a peer for status. If peer_status shows the peer is busy on the requested tool or recently acknowledged the request, do not ask again; call finish_turn or proceed with other useful work. Every send_message or ask_peer call must include tldr: one short line summarizing the communication for the dashboard. Keep each string argument concise. 
For final ranking, summarize aggregate counts and exceptions in summary, and put detailed evidence in artifact_refs, tool_result_ids, and supporting_message_ids.", + "system_prompt": "You are a persistent ChemGraph-style LM agent hosted inside an Academy daemon on HPC. You communicate with peers only through send_message. You may call only the ChemGraph MCP tools listed in available_chemgraph_tools. Treat peer messages as evidence only when they include message_id, candidate IDs, artifact paths, or tool_result_ids. Do not claim access to another agent's private state unless it appears in a received message.", + "protocol_prompt": "Return one or more tool calls. If no action is useful, call finish_turn. Never fabricate ChemGraph tool outputs, energies, coordinate paths, or MACE results. Only cite tool_result_ids that appear in local_chemgraph_tool_results or received_messages. After a local ChemGraph tool finishes, you will be woken for another decision round with that result visible in local_chemgraph_tool_results; use that follow-up round to interpret, communicate, or rank the new evidence. Inspect peer_status before asking a peer for status. If peer_status shows the peer is busy on the requested tool or recently acknowledged the request, do not ask again; call finish_turn or proceed with other useful work. Every send_message call must include tldr: one short line summarizing the communication for the dashboard. Set reply_requested=true when the peer should answer or take follow-up action; otherwise set reply_requested=false. Keep each string argument concise. 
For final ranking, summarize aggregate counts and exceptions in summary, and put detailed evidence in artifact_refs, tool_result_ids, and supporting_message_ids.", "langchain_recursion_limit": 64, "state_limits": { "received_messages_last_n": 28, diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py index 6e5f6c60..0825f66e 100644 --- a/tests/test_academy_reasoning_phase2.py +++ b/tests/test_academy_reasoning_phase2.py @@ -133,7 +133,6 @@ async def test_reasoning_adapter_finish_turn_updates_runtime_state(tmp_path) -> assert [tool.name for tool in tools] == [ "send_message", - "ask_peer", "submit_result", "finish_turn", ] @@ -184,6 +183,7 @@ async def test_send_message_does_not_block_on_busy_peer(tmp_path) -> None: "content": "full message", "artifact_refs": [], "tool_result_ids": [], + "reply_requested": False, "reason": "peer needs this evidence", "confidence": 0.8, }, @@ -305,6 +305,7 @@ def test_reasoning_engine_builds_bounded_wakeup_state(tmp_path) -> None: { "type": "send_message", "recipient": "agent-b", + "reply_requested": False, "tldr": "old message", "message_id": "msg-old", "timestamp": 1, @@ -312,6 +313,7 @@ def test_reasoning_engine_builds_bounded_wakeup_state(tmp_path) -> None: { "type": "send_message", "recipient": "agent-b", + "reply_requested": False, "tldr": "new message", "message_id": "msg-new", "timestamp": 3, diff --git a/tests/test_tool_adapter_validation.py b/tests/test_tool_adapter_validation.py index 09ae3ade..f58a2b17 100644 --- a/tests/test_tool_adapter_validation.py +++ b/tests/test_tool_adapter_validation.py @@ -122,15 +122,19 @@ async def test_send_message_disallowed_recipient_does_not_deliver(tmp_path) -> N @pytest.mark.asyncio -async def test_ask_peer_requires_tldr(tmp_path) -> None: +async def test_send_message_request_requires_tldr(tmp_path) -> None: env = await _build_tools(tmp_path) - result = await env["tools"]["ask_peer"].ainvoke( + result = await env["tools"]["send_message"].ainvoke( { 
"recipient": "agent-b", "tldr": "", - "question": "What happened?", + "content": "What happened?", + "artifact_refs": [], + "tool_result_ids": [], + "reply_requested": True, "reason": "need a peer check", + "confidence": 0.5, } ) @@ -141,6 +145,28 @@ async def test_ask_peer_requires_tldr(tmp_path) -> None: assert env["peer_handle"].calls == [] +@pytest.mark.asyncio +async def test_send_message_reply_requested_marks_question(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["send_message"].ainvoke( + { + "recipient": "agent-b", + "tldr": "need status", + "content": "Please send current status.", + "artifact_refs": [], + "tool_result_ids": [], + "reply_requested": True, + "reason": "the report needs the peer status", + "confidence": 0.7, + } + ) + + assert result["status"] == "sent" + assert env["outbox"][0]["reply_requested"] is True + assert env["outbox"][0]["kind"] == "question" + + @pytest.mark.asyncio async def test_valid_send_message_still_delivers(tmp_path) -> None: env = await _build_tools(tmp_path) @@ -152,6 +178,7 @@ async def test_valid_send_message_still_delivers(tmp_path) -> None: "content": "Candidate C1 has a usable artifact.", "artifact_refs": ["artifacts/c1.xyz"], "tool_result_ids": ["tool-1"], + "reply_requested": False, "reason": "peer needs the result", "confidence": 0.9, } @@ -160,6 +187,7 @@ async def test_valid_send_message_still_delivers(tmp_path) -> None: assert result["status"] == "sent" assert result["recipient"] == "agent-b" assert len(env["outbox"]) == 1 + assert env["outbox"][0]["reply_requested"] is False assert env["peer_handle"].calls[0][0] == "receive_message" assert env["peer_handle"].calls[0][1]["message_id"] == result["message_id"] assert [event for event, _ in env["traces"]] == [ From 457351edeb6d79ed20cd9bbc9faf25a646c3bbcd Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 10:41:13 -0500 Subject: [PATCH 056/119] refactor(mcp): move in-process FastMCP adapter out of Academy --- 
src/chemgraph/academy/core/agent.py | 6 +-- src/chemgraph/academy/core/tools.py | 10 ++-- src/chemgraph/academy/core/turn.py | 6 +-- src/chemgraph/academy/runtime/daemon.py | 6 +-- .../core/fastmcp.py => mcp/fastmcp_client.py} | 52 +++++++++++++------ 5 files changed, 50 insertions(+), 30 deletions(-) rename src/chemgraph/{academy/core/fastmcp.py => mcp/fastmcp_client.py} (90%) diff --git a/src/chemgraph/academy/core/agent.py b/src/chemgraph/academy/core/agent.py index d4b55da7..2b005d34 100644 --- a/src/chemgraph/academy/core/agent.py +++ b/src/chemgraph/academy/core/agent.py @@ -13,8 +13,8 @@ from academy.handle import Handle from academy.identifier import AgentId -from chemgraph.academy.core.fastmcp import ( - CampaignFastMCPToolInvoker, +from chemgraph.mcp.fastmcp_client import ( + FastMCPToolInvoker, ) from chemgraph.academy.core.peer_protocol import validate_message from chemgraph.academy.observability.event_log import EventLog @@ -39,7 +39,7 @@ def __init__( prompt_profile: PromptProfile, run_dir: Path, max_decisions: int, - tool_invoker: CampaignFastMCPToolInvoker, + tool_invoker: FastMCPToolInvoker, peer_agent_ids: Mapping[str, AgentId[Any]] | None = None, placement: dict[str, Any] | None = None, poll_timeout_s: float = 2.0, diff --git a/src/chemgraph/academy/core/tools.py b/src/chemgraph/academy/core/tools.py index f22d3f4c..4603bc79 100644 --- a/src/chemgraph/academy/core/tools.py +++ b/src/chemgraph/academy/core/tools.py @@ -22,10 +22,10 @@ from pydantic import ValidationError from chemgraph.academy.core.campaign import ChemGraphAgentSpec -from chemgraph.academy.core.fastmcp import ToolInvocation -from chemgraph.academy.core.fastmcp import fastmcp_tool_schemas -from chemgraph.academy.core.fastmcp import ( - CampaignFastMCPToolInvoker, +from chemgraph.mcp.fastmcp_client import ToolInvocation +from chemgraph.mcp.fastmcp_client import fastmcp_tool_schemas +from chemgraph.mcp.fastmcp_client import ( + FastMCPToolInvoker, ) from 
chemgraph.academy.core.peer_protocol import build_message from chemgraph.academy.observability.run_files import append_jsonl @@ -203,7 +203,7 @@ async def build_chemgraph_reasoning_tools( *, spec: ChemGraphAgentSpec, run_dir: pathlib.Path, - tool_invoker: CampaignFastMCPToolInvoker, + tool_invoker: FastMCPToolInvoker, peer_names: tuple[str, ...], peer_handles: Mapping[str, Handle[Any]], outbox: list[dict[str, Any]], diff --git a/src/chemgraph/academy/core/turn.py b/src/chemgraph/academy/core/turn.py index b6294993..28452970 100644 --- a/src/chemgraph/academy/core/turn.py +++ b/src/chemgraph/academy/core/turn.py @@ -13,8 +13,8 @@ from academy.handle import Handle from langchain_core.tools import BaseTool -from chemgraph.academy.core.fastmcp import ( - CampaignFastMCPToolInvoker, +from chemgraph.mcp.fastmcp_client import ( + FastMCPToolInvoker, ) from chemgraph.academy.core.tools import ( ReasoningToolRuntimeState, @@ -98,7 +98,7 @@ async def create( prompt_profile: PromptProfile, run_dir: Path, max_decisions: int, - tool_invoker: CampaignFastMCPToolInvoker, + tool_invoker: FastMCPToolInvoker, peer_names: tuple[str, ...], peer_handles: Mapping[str, Handle[Any]], received_message_history: list[dict[str, Any]], diff --git a/src/chemgraph/academy/runtime/daemon.py b/src/chemgraph/academy/runtime/daemon.py index 5c78697c..f95e4806 100644 --- a/src/chemgraph/academy/runtime/daemon.py +++ b/src/chemgraph/academy/runtime/daemon.py @@ -34,8 +34,8 @@ from chemgraph.academy.core.agent import ChemGraphLogicalAgent from chemgraph.academy.core.lm import load_lm_config from chemgraph.academy.core.prompt import load_prompt_profile -from chemgraph.academy.core.fastmcp import ( - build_campaign_fastmcp_tool_invoker, +from chemgraph.mcp.fastmcp_client import ( + build_fastmcp_tool_invoker, ) @@ -106,7 +106,7 @@ async def run_daemon(config: ChemGraphDaemonConfig) -> int: if peer in registrations } - tool_invoker = await build_campaign_fastmcp_tool_invoker( + tool_invoker = await 
build_fastmcp_tool_invoker( specs=list(agent_spec.tools), execution=ExecutionSpec(backend='local', system='local'), run_dir=config.run_dir, diff --git a/src/chemgraph/academy/core/fastmcp.py b/src/chemgraph/mcp/fastmcp_client.py similarity index 90% rename from src/chemgraph/academy/core/fastmcp.py rename to src/chemgraph/mcp/fastmcp_client.py index b7c148b7..0d3b98e5 100644 --- a/src/chemgraph/academy/core/fastmcp.py +++ b/src/chemgraph/mcp/fastmcp_client.py @@ -1,20 +1,20 @@ -"""Campaign-scoped in-process FastMCP tool loading and invocation.""" +"""In-process client adapter for FastMCP tool modules.""" from __future__ import annotations import importlib import json import uuid +from collections.abc import Mapping from collections.abc import Sequence from pathlib import Path from typing import Any +from typing import Protocol from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field -from chemgraph.academy.core.campaign import ExecutionSpec -from chemgraph.academy.core.campaign import ToolSpec class ToolInvocation(BaseModel): """A normalized record of one agent-requested FastMCP tool call.""" @@ -40,6 +40,24 @@ class ToolResult(BaseModel): correlation_id: str +class FastMCPToolSpec(Protocol): + """Structural interface for config-declared FastMCP tools.""" + + name: str + module: str + tool: str + description: str | None + + +class FastMCPExecutionSpec(Protocol): + """Structural interface for backend configuration used by CGFastMCP.""" + + backend: str | None + system: str | None + config_path: str | None + options: Mapping[str, Any] + + def load_fastmcp_tool_module( module_name: str, *, @@ -63,7 +81,9 @@ def load_fastmcp_tool_module( return server -async def fastmcp_tool_schemas(specs: list[ToolSpec]) -> list[dict[str, Any]]: +async def fastmcp_tool_schemas( + specs: Sequence[FastMCPToolSpec], +) -> list[dict[str, Any]]: """Build OpenAI tool schemas from declared FastMCP ToolSpecs.""" schemas: list[dict[str, Any]] = [] module_cache: 
dict[str, Any] = {} @@ -108,7 +128,7 @@ def _fastmcp_tool_payload(tool: Any) -> dict[str, Any]: def _openai_tool_schema( - spec: ToolSpec, + spec: FastMCPToolSpec, tool_payload: dict[str, Any], ) -> dict[str, Any]: parameters = _sanitize_input_schema( @@ -195,14 +215,14 @@ def _first_json_text_result(values: list[Any]) -> Any | None: return None -class CampaignFastMCPToolInvoker: - """Invoke campaign-allowed tools through in-process FastMCP modules.""" +class FastMCPToolInvoker: + """Invoke allowed tools through in-process FastMCP modules.""" def __init__( self, *, - specs: list[ToolSpec], - execution: ExecutionSpec, + specs: Sequence[FastMCPToolSpec], + execution: FastMCPExecutionSpec, run_dir: str | Path, ) -> None: self.specs = {spec.name: spec for spec in specs} @@ -298,7 +318,7 @@ def _configure_fastmcp_backend( mcp: Any, *, module_name: str, - execution: ExecutionSpec, + execution: FastMCPExecutionSpec, run_dir: str | Path, ) -> None: """Configure a CGFastMCP backend without initialising compute resources.""" @@ -323,15 +343,15 @@ def _configure_fastmcp_backend( ) -async def build_campaign_fastmcp_tool_invoker( +async def build_fastmcp_tool_invoker( *, - specs: list[ToolSpec], - execution: ExecutionSpec, + specs: Sequence[FastMCPToolSpec], + execution: FastMCPExecutionSpec, run_dir: str | Path, agent_name: str, -) -> CampaignFastMCPToolInvoker: - """Build and verify one agent's campaign-scoped FastMCP tool invoker.""" - invoker = CampaignFastMCPToolInvoker( +) -> FastMCPToolInvoker: + """Build and verify one agent's in-process FastMCP tool invoker.""" + invoker = FastMCPToolInvoker( specs=list(specs), execution=execution, run_dir=run_dir, From 186d013f7342632937001e4afbd254bf016b4af0 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 10:57:14 -0500 Subject: [PATCH 057/119] refactor(academy): remove unused run artifacts --- src/chemgraph/academy/__init__.py | 2 - src/chemgraph/academy/core/__init__.py | 2 - src/chemgraph/academy/core/agent.py | 13 
-- src/chemgraph/academy/core/campaign.py | 16 +-- src/chemgraph/academy/core/tools.py | 10 +- src/chemgraph/academy/dashboard/server.py | 1 - src/chemgraph/academy/dashboard/static/app.js | 5 +- .../academy/observability/__init__.py | 2 - .../observability/communication_proof.py | 117 ---------------- .../academy/observability/payloads.py | 130 ------------------ .../academy/observability/run_artifacts.py | 75 +--------- .../academy/runtime/compute_launcher.py | 44 ------ src/chemgraph/academy/runtime/daemon.py | 16 +-- src/chemgraph/mcp/fastmcp_client.py | 17 ++- tests/test_academy_dashboard.py | 15 +- tests/test_academy_payloads.py | 69 ---------- 16 files changed, 39 insertions(+), 495 deletions(-) delete mode 100644 src/chemgraph/academy/observability/communication_proof.py delete mode 100644 src/chemgraph/academy/observability/payloads.py delete mode 100644 tests/test_academy_payloads.py diff --git a/src/chemgraph/academy/__init__.py b/src/chemgraph/academy/__init__.py index fc3dece9..a5ad2313 100644 --- a/src/chemgraph/academy/__init__.py +++ b/src/chemgraph/academy/__init__.py @@ -12,7 +12,6 @@ from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.academy.core.campaign import ChemGraphCampaign from chemgraph.academy.core.campaign import ChemGraphDaemonConfig -from chemgraph.academy.core.campaign import ExecutionSpec from chemgraph.academy.core.campaign import ResourceSpec from chemgraph.academy.core.campaign import ToolSpec from chemgraph.academy.core.campaign import load_campaign @@ -29,7 +28,6 @@ "ChemGraphCampaign", "ChemGraphDaemonConfig", "EventLog", - "ExecutionSpec", "PromptProfile", "ResourceSpec", "ChemGraphLogicalAgent", diff --git a/src/chemgraph/academy/core/__init__.py b/src/chemgraph/academy/core/__init__.py index 6dd05c9f..06b8f5e7 100644 --- a/src/chemgraph/academy/core/__init__.py +++ b/src/chemgraph/academy/core/__init__.py @@ -4,7 +4,6 @@ from chemgraph.academy.core.campaign import ChemGraphAgentSpec from 
chemgraph.academy.core.campaign import ChemGraphCampaign from chemgraph.academy.core.campaign import ChemGraphDaemonConfig -from chemgraph.academy.core.campaign import ExecutionSpec from chemgraph.academy.core.campaign import ResourceSpec from chemgraph.academy.core.campaign import ToolSpec from chemgraph.academy.core.campaign import load_campaign @@ -22,7 +21,6 @@ "ChemGraphDaemonConfig", "ChemGraphLogicalAgent", "ChemGraphReasoningRoundEngine", - "ExecutionSpec", "LLMSettings", "PromptProfile", "ReasoningTurnResult", diff --git a/src/chemgraph/academy/core/agent.py b/src/chemgraph/academy/core/agent.py index 2b005d34..2dacb033 100644 --- a/src/chemgraph/academy/core/agent.py +++ b/src/chemgraph/academy/core/agent.py @@ -19,7 +19,6 @@ from chemgraph.academy.core.peer_protocol import validate_message from chemgraph.academy.observability.event_log import EventLog from chemgraph.academy.observability.run_artifacts import write_status_snapshot -from chemgraph.academy.core.turn import build_recent_actions from chemgraph.academy.core.turn import ChemGraphReasoningRoundEngine from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.academy.core.campaign import ChemGraphCampaign @@ -208,18 +207,7 @@ async def report_state(self) -> dict[str, Any]: 'finished': self.finished, 'last_error': self.last_error, 'current_activity': None, - 'received_message_count': len(self.received_message_history), - 'outbox_count': len(self.outbox), - 'recent_received_messages': self.received_message_history[-10:], 'recent_outbox': self.outbox[-10:], - 'tool_names': list(self.spec.tool_names), - 'tool_result_count': len(self.tool_results), - 'recent_tool_results': self.tool_results[-8:], - 'recent_actions': build_recent_actions( - outbox=self.outbox, - tool_results=self.tool_results, - limit=12, - ), 'belief': self.final_result or { 'hypothesis': None, 'confidence': 0.0, @@ -227,7 +215,6 @@ async def report_state(self) -> dict[str, Any]: 'supporting_tool_result_ids': [], 
'reason': None, }, - 'belief_history': [self.final_result] if self.final_result else [], } async def _reasoning_round(self) -> bool: diff --git a/src/chemgraph/academy/core/campaign.py b/src/chemgraph/academy/core/campaign.py index 7eae1cd3..baaa570b 100644 --- a/src/chemgraph/academy/core/campaign.py +++ b/src/chemgraph/academy/core/campaign.py @@ -7,7 +7,7 @@ from typing import Any from chemgraph.academy.examples import resolve_builtin_campaign -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, field_validator _REMOVED_CAMPAIGN_FIELDS = frozenset( @@ -30,7 +30,7 @@ class ToolSpec(BaseModel): - """Campaign-declared in-process FastMCP tool available to agents.""" + """Campaign-declared external tool available to agents.""" model_config = ConfigDict(extra='forbid') @@ -48,17 +48,6 @@ def _non_empty(cls, value: str) -> str: return value -class ExecutionSpec(BaseModel): - """Execution defaults used when configuring ChemGraph FastMCP backends.""" - - model_config = ConfigDict(extra='forbid') - - backend: str = 'local' - system: str = 'local' - config_path: str | None = None - options: dict[str, Any] = Field(default_factory=dict) - - class ResourceSpec(BaseModel): """Campaign-declared resource or artifact handle. 
@@ -145,7 +134,6 @@ class ChemGraphDaemonConfig: redis_host: str redis_port: int redis_namespace: str - clean_redis: bool rank: int local_rank: int | None chemgraph_repo_root: pathlib.Path diff --git a/src/chemgraph/academy/core/tools.py b/src/chemgraph/academy/core/tools.py index 4603bc79..a933598d 100644 --- a/src/chemgraph/academy/core/tools.py +++ b/src/chemgraph/academy/core/tools.py @@ -1,4 +1,4 @@ -"""Adapt Academy actions and campaign FastMCP tools for ChemGraph turns.""" +"""Build Academy action tools and attach configured science tools.""" from __future__ import annotations @@ -40,21 +40,14 @@ class ReasoningToolRuntimeState: """Mutable per-turn state updated by ChemGraph reasoning tools.""" science_tool_completed: bool = False - submitted_result: bool = False finished_turn: bool = False executed_tool_names: list[str] = field(default_factory=list) action_tool_names: list[str] = field(default_factory=list) science_tool_names: list[str] = field(default_factory=list) background_tasks: set[asyncio.Task[Any]] = field(default_factory=set) - @property - def tool_completed(self) -> bool: - """Backward-compatible name for a completed science tool call.""" - return self.science_tool_completed - def reset(self) -> None: self.science_tool_completed = False - self.submitted_result = False self.finished_turn = False self.executed_tool_names.clear() self.action_tool_names.clear() @@ -330,7 +323,6 @@ async def submit_result(**kwargs: Any) -> dict[str, Any]: args = SubmitResultArgs.model_validate(kwargs) except ValidationError as exc: return _invalid_args_response("submit_result", exc, trace) - runtime_state.submitted_result = True result = { "timestamp": time.time(), "round": get_round_index(), diff --git a/src/chemgraph/academy/dashboard/server.py b/src/chemgraph/academy/dashboard/server.py index 141a0f90..3c50741f 100644 --- a/src/chemgraph/academy/dashboard/server.py +++ b/src/chemgraph/academy/dashboard/server.py @@ -102,7 +102,6 @@ def status_payload(handler: 
DashboardHandler) -> dict[str, Any]: "schema": schema, "status": status, "placement": artifacts["placement"], - "communication_proof": artifacts["communication_proof"], "summary": artifacts["summary"], } diff --git a/src/chemgraph/academy/dashboard/static/app.js b/src/chemgraph/academy/dashboard/static/app.js index 15dfdf14..0237d2e8 100644 --- a/src/chemgraph/academy/dashboard/static/app.js +++ b/src/chemgraph/academy/dashboard/static/app.js @@ -385,9 +385,8 @@ document.getElementById('metrics').innerHTML = values.map(([k,v]) => `
${esc(k)}
${esc(v)}
`).join(''); - const proof = snapshot.communication_proof || {}; - document.getElementById('proof').innerHTML = proof.passes - ? Object.entries(proof.passes).map(([k,v]) => `${esc(k)}=${v}`).join('') + document.getElementById('proof').innerHTML = crossNodeMessages.length + ? `cross-node messages=${crossNodeMessages.length}` : ''; } diff --git a/src/chemgraph/academy/observability/__init__.py b/src/chemgraph/academy/observability/__init__.py index 752e52d4..f6031201 100644 --- a/src/chemgraph/academy/observability/__init__.py +++ b/src/chemgraph/academy/observability/__init__.py @@ -3,11 +3,9 @@ from chemgraph.academy.observability.event_log import CampaignEvent from chemgraph.academy.observability.event_log import EventLog from chemgraph.academy.observability.event_log import read_events -from chemgraph.academy.observability.payloads import typed_payload __all__ = [ 'CampaignEvent', 'EventLog', 'read_events', - 'typed_payload', ] diff --git a/src/chemgraph/academy/observability/communication_proof.py b/src/chemgraph/academy/observability/communication_proof.py deleted file mode 100644 index 4f61b7bf..00000000 --- a/src/chemgraph/academy/observability/communication_proof.py +++ /dev/null @@ -1,117 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from chemgraph.academy.observability.event_log import CampaignEvent - - -def build_communication_proof( - events: list[CampaignEvent], - placement: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Build proof that communication could affect recipient behavior.""" - message_ids: dict[str, dict[str, Any]] = {} - sent_messages: list[dict[str, Any]] = [] - for event in events: - if event.event != "message_sent": - continue - payload = event.payload - message_id = payload.get("message_id") - if not isinstance(message_id, str): - continue - message = { - "message_id": message_id, - "sender": payload.get("sender"), - "recipient": payload.get("recipient"), - "content": payload.get("content"), - 
"evidence_refs": payload.get("evidence_refs", []), - "artifact_refs": payload.get("artifact_refs", []), - "tool_result_ids": payload.get("tool_result_ids", []), - "timestamp": payload.get("timestamp"), - } - message_ids[message_id] = message - sent_messages.append(message) - - agents = (placement or {}).get("agents", {}) - cross_node_messages = [] - if isinstance(agents, dict): - for message in sent_messages: - sender = agents.get(message.get("sender"), {}) - recipient = agents.get(message.get("recipient"), {}) - sender_host = sender.get("short_hostname") or sender.get("hostname") - recipient_host = recipient.get("short_hostname") or recipient.get("hostname") - if sender_host and recipient_host and sender_host != recipient_host: - cross_node_messages.append( - { - **message, - "sender_hostname": sender_host, - "recipient_hostname": recipient_host, - }, - ) - - cited_beliefs = [] - cited_message_ids: set[str] = set() - for event in events: - if event.event != "belief_updated": - continue - refs = event.payload.get("supporting_message_ids", []) - if not isinstance(refs, list): - continue - cited = [ref for ref in refs if isinstance(ref, str) and ref in message_ids] - if not cited: - continue - cited_message_ids.update(cited) - cited_beliefs.append( - { - "agent_id": event.agent_id, - "role": event.role, - "hypothesis": event.payload.get("hypothesis"), - "confidence": event.payload.get("confidence"), - "supporting_message_ids": cited, - }, - ) - - cited_tool_refs = [] - final_report_count = 0 - for event in events: - if event.event != "belief_updated": - continue - final_report_count += 1 - refs = ( - event.payload.get("supporting_tool_result_ids") - or event.payload.get("supporting_artifact_ids") - or [] - ) - if not isinstance(refs, list): - continue - calls = [ - ref - for ref in refs - if isinstance(ref, str) - and (ref.startswith("call-") or ref.startswith("tool-")) - ] - if calls: - cited_tool_refs.append( - { - "agent_id": event.agent_id, - "hypothesis": 
event.payload.get("hypothesis"), - "supporting_artifact_ids": calls, - }, - ) - - return { - "message_count": len(sent_messages), - "received_message_ids_cited_in_beliefs": sorted(cited_message_ids), - "belief_changes_citing_messages": len(cited_beliefs), - "belief_change_examples": cited_beliefs[:10], - "cross_node_message_count": len(cross_node_messages), - "cross_node_message_examples": cross_node_messages[:10], - "tool_refs_cited_in_beliefs": cited_tool_refs[:10], - "final_report_count": final_report_count, - "passes": { - "has_message": bool(sent_messages), - "has_belief_citing_message": bool(cited_beliefs), - "has_cross_node_message": bool(cross_node_messages), - "final_report": final_report_count > 0, - }, - } diff --git a/src/chemgraph/academy/observability/payloads.py b/src/chemgraph/academy/observability/payloads.py deleted file mode 100644 index 2844eec8..00000000 --- a/src/chemgraph/academy/observability/payloads.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from pydantic import BaseModel, ConfigDict, Field - -from chemgraph.academy.observability.event_log import CampaignEvent - - -class MessageSentPayload(BaseModel): - model_config = ConfigDict(extra="allow") - - message_id: str - sender: str - recipient: str - content: str - kind: str | None = None - tldr: str | None = None - artifact_refs: list[str] = Field(default_factory=list) - tool_result_ids: list[str] = Field(default_factory=list) - reason: str | None = None - confidence: float | None = None - round: int | None = None - timestamp: float | None = None - - -class MessageReceivedPayload(MessageSentPayload): - pass - - -class ToolCallStartedPayload(BaseModel): - model_config = ConfigDict(extra="allow") - - tool_result_id: str | None = None - tool_call_id: str | None = None - tool_name: str - arguments: dict[str, Any] = Field(default_factory=dict) - - -class ToolCallFinishedPayload(ToolCallStartedPayload): - status: str - result: Any = None - 
timestamp: float | None = None - agent_name: str | None = None - - -class ToolCallFailedPayload(ToolCallStartedPayload): - status: str - error: str - - -class WorkflowStartedPayload(BaseModel): - model_config = ConfigDict(extra="allow") - - workflow_type: str - workflow_node: str | None = None - model_name: str | None = None - query: str | None = None - log_dir: str | None = None - round: int | None = None - thread_id: str | None = None - tool_names: list[str] = Field(default_factory=list) - span_id: str | None = None - parent_span_id: str | None = None - - -class WorkflowFinishedPayload(BaseModel): - model_config = ConfigDict(extra="allow") - - workflow_type: str - status: str - error: str | None = None - log_dir: str | None = None - round: int | None = None - thread_id: str | None = None - span_id: str | None = None - parent_span_id: str | None = None - - -class LLMDecisionPayload(BaseModel): - model_config = ConfigDict(extra="allow") - - round: int | None = None - tool_names: list[str] = Field(default_factory=list) - action_tools_called: list[str] = Field(default_factory=list) - science_tools_called: list[str] = Field(default_factory=list) - workflow_span_id: str | None = None - thread_id: str | None = None - - -class AgentStartedPayload(BaseModel): - model_config = ConfigDict(extra="allow") - - role: str | None = None - tool_names: list[str] = Field(default_factory=list) - allowed_peers: list[str] = Field(default_factory=list) - placement: dict[str, Any] | None = None - hostname: str | None = None - short_hostname: str | None = None - - -class BeliefUpdatedPayload(BaseModel): - model_config = ConfigDict(extra="allow") - - hypothesis: str | None = None - summary: str | None = None - confidence: float | None = None - supporting_message_ids: list[str] = Field(default_factory=list) - supporting_tool_result_ids: list[str] = Field(default_factory=list) - reason: str | None = None - - -PAYLOAD_MODELS: dict[str, type[BaseModel]] = { - "message_sent": 
MessageSentPayload, - "message_received": MessageReceivedPayload, - "tool_call_started": ToolCallStartedPayload, - "tool_call_finished": ToolCallFinishedPayload, - "tool_call_failed": ToolCallFailedPayload, - "workflow_started": WorkflowStartedPayload, - "workflow_finished": WorkflowFinishedPayload, - "llm_decision": LLMDecisionPayload, - "llm_tool_calls": LLMDecisionPayload, - "agent_started": AgentStartedPayload, - "belief_updated": BeliefUpdatedPayload, -} - - -def typed_payload(event: CampaignEvent) -> BaseModel | None: - model = PAYLOAD_MODELS.get(event.event) - return model.model_validate(event.payload) if model else None diff --git a/src/chemgraph/academy/observability/run_artifacts.py b/src/chemgraph/academy/observability/run_artifacts.py index 2af388c0..750e8b13 100644 --- a/src/chemgraph/academy/observability/run_artifacts.py +++ b/src/chemgraph/academy/observability/run_artifacts.py @@ -8,9 +8,6 @@ from collections import Counter from typing import Any -from chemgraph.academy.observability.communication_proof import ( - build_communication_proof, -) from chemgraph.academy.observability.event_log import CampaignEvent from chemgraph.academy.observability.event_log import read_events from chemgraph.academy.observability.run_files import append_jsonl @@ -24,19 +21,16 @@ def write_run_artifacts(run_dir: str | pathlib.Path) -> dict[str, Any]: - """Write placement, communication proof, and summary artifacts.""" + """Write placement and summary artifacts.""" root = pathlib.Path(run_dir) events = read_events(root / "events.jsonl") placement = build_placement(events, root / "status.json") - proof = build_communication_proof(events, placement) summary = summarize_events(events) write_json(root / "placement.json", placement) - write_json(root / "communication_proof.json", proof) write_json(root / "summary.json", summary) return { "placement": placement, - "communication_proof": proof, "summary": summary, } @@ -182,13 +176,7 @@ def default_agent_state(spec: 
ChemGraphAgentSpec) -> dict[str, Any]: 'finished': False, 'last_error': None, 'current_activity': None, - 'received_message_count': 0, - 'outbox_count': 0, - 'recent_received_messages': [], 'recent_outbox': [], - 'tool_names': list(spec.tool_names), - 'tool_result_count': 0, - 'recent_tool_results': [], 'belief': { 'hypothesis': None, 'confidence': 0.0, @@ -196,7 +184,6 @@ def default_agent_state(spec: ChemGraphAgentSpec) -> dict[str, Any]: 'supporting_tool_result_ids': [], 'reason': None, }, - 'belief_history': [], } @@ -244,9 +231,8 @@ def write_status_snapshot( } write_json_atomic(run_dir / 'placement.json', placement_doc) - proof = build_communication_proof( - read_events(run_dir / "events.jsonl"), - placement_doc, + converged = bool(agents) and all( + bool(item.get('finished')) for item in agents ) status = { 'timestamp': time.time(), @@ -254,13 +240,11 @@ def write_status_snapshot( 'campaign_kind': 'chemgraph_agent_swarm', 'campaign': campaign.run_id, 'agents': sorted(agents, key=lambda item: item['agent_name']), - 'communication_proof': proof, 'placement': placement_doc, - 'converged': bool(proof.get('passes', {}).get('final_report')), + 'converged': converged, } write_json_atomic(run_dir / 'status.json', status) append_jsonl(run_dir / 'status_history.jsonl', status) - write_json_atomic(run_dir / 'communication_proof.json', proof) async def wait_for_agent_statuses_finished( @@ -291,7 +275,9 @@ async def wait_for_agent_statuses_finished( def clear_run_outputs(run_dir: pathlib.Path) -> None: for name in ( 'academy_registrations.json', + 'campaign_private.json', 'communication_proof.json', + 'compute_launch.json', 'launch_plan.json', 'messages.jsonl', 'events.jsonl', @@ -318,30 +304,6 @@ def initialize_run_files( ) -> None: run_dir.mkdir(parents=True, exist_ok=True) clear_run_outputs(run_dir) - write_json( - run_dir / 'campaign_private.json', - { - 'run_id': campaign.run_id, - 'user_task': campaign.user_task, - 'initial_agent': campaign.initial_agent, - 
'prompt_profile': str(campaign.prompt_profile), - 'resources': { - name: spec.model_dump(exclude_none=True) - for name, spec in campaign.resources.items() - }, - 'agents': [ - { - 'name': spec.name, - 'role': spec.role, - 'mission': spec.mission, - 'allowed_peers': list(spec.allowed_peers), - 'tool_names': list(spec.tool_names), - 'resources': list(spec.resources), - } - for spec in campaign.agents - ], - }, - ) write_json( run_dir / 'manifest.json', { @@ -368,31 +330,6 @@ def initialize_run_files( 'llm_user': llm_settings.user, }, ) - write_json( - run_dir / 'launch_plan.json', - { - 'agent_class': 'ChemGraphLogicalAgent', - 'exchange': { - 'backend': 'academy_redis', - 'host': config.redis_host, - 'port': config.redis_port, - }, - 'placement': { - 'launcher': 'mpiexec', - 'agent_count': config.agent_count, - }, - 'agents': [ - { - 'name': spec.name, - 'role': spec.role, - 'agent_class': 'ChemGraphLogicalAgent', - 'allowed_peers': list(spec.allowed_peers), - 'tool_names': list(spec.tool_names), - } - for spec in campaign.agents - ], - }, - ) append_system_trace( run_dir, 'campaign_started', diff --git a/src/chemgraph/academy/runtime/compute_launcher.py b/src/chemgraph/academy/runtime/compute_launcher.py index d437331e..655a9a6a 100644 --- a/src/chemgraph/academy/runtime/compute_launcher.py +++ b/src/chemgraph/academy/runtime/compute_launcher.py @@ -165,38 +165,6 @@ def _write_lm_config( return path -def _write_compute_launch_metadata( - *, - run_dir: Path, - args: argparse.Namespace, - profile: SystemProfile, - lm_config: Path, - lm_base_url: str, - agent_count: int, - agents_per_node: int, - max_decisions: int, - redis_port: int, -) -> None: - payload = { - "system": profile.name, - "run_id": args.run_id, - "campaign": args.campaign, - "run_dir": str(run_dir), - "lm_base_url": lm_base_url, - "lm_config": str(lm_config), - "agent_count": agent_count, - "agents_per_node": agents_per_node, - "max_decisions": max_decisions, - "redis_host": socket.getfqdn(), - 
"redis_port": redis_port, - "repo_root": profile.repo_root, - } - (run_dir / "compute_launch.json").write_text( - json.dumps(payload, indent=2) + "\n", - encoding="utf-8", - ) - - def _export_workflow_lm_environment(lm_config: Path) -> None: data = json.loads(lm_config.read_text(encoding="utf-8")) values = { @@ -250,18 +218,6 @@ def prepare_compute_launch(args: argparse.Namespace) -> AllocationPlan: max_decisions = args.max_decisions or defaults.max_decisions redis_port = args.redis_port or profile.redis_port - _write_compute_launch_metadata( - run_dir=run_dir, - args=args, - profile=profile, - lm_config=lm_config, - lm_base_url=lm_base_url, - agent_count=agent_count, - agents_per_node=agents_per_node, - max_decisions=max_decisions, - redis_port=redis_port, - ) - campaign_config = resolve_builtin_campaign(args.campaign) if not campaign_config.exists(): campaign_config = Path(args.campaign).resolve() diff --git a/src/chemgraph/academy/runtime/daemon.py b/src/chemgraph/academy/runtime/daemon.py index f95e4806..cf7735db 100644 --- a/src/chemgraph/academy/runtime/daemon.py +++ b/src/chemgraph/academy/runtime/daemon.py @@ -20,7 +20,6 @@ from chemgraph.academy.observability.run_artifacts import write_status_snapshot from chemgraph.academy.core.campaign import campaign_bootstrap_text from chemgraph.academy.core.campaign import ChemGraphDaemonConfig -from chemgraph.academy.core.campaign import ExecutionSpec from chemgraph.academy.core.campaign import load_campaign from chemgraph.academy.core.campaign import namespace_for_run from chemgraph.academy.core.campaign import resolve_campaign_resources @@ -35,6 +34,7 @@ from chemgraph.academy.core.lm import load_lm_config from chemgraph.academy.core.prompt import load_prompt_profile from chemgraph.mcp.fastmcp_client import ( + FastMCPExecutionConfig, build_fastmcp_tool_invoker, ) @@ -108,7 +108,7 @@ async def run_daemon(config: ChemGraphDaemonConfig) -> int: tool_invoker = await build_fastmcp_tool_invoker( 
specs=list(agent_spec.tools), - execution=ExecutionSpec(backend='local', system='local'), + execution=FastMCPExecutionConfig(backend='local', system='local'), run_dir=config.run_dir, agent_name=agent_spec.name, ) @@ -206,9 +206,6 @@ def parse_args() -> argparse.Namespace: parser.add_argument('--redis-host', default='127.0.0.1') parser.add_argument('--redis-port', type=int, required=True) parser.add_argument('--redis-namespace') - parser.add_argument('--rank', type=int) - parser.add_argument('--local-rank', type=int) - parser.add_argument('--no-clean-redis', action='store_true') parser.add_argument('--chemgraph-repo-root') return parser.parse_args() @@ -236,13 +233,8 @@ def config_from_args(args: argparse.Namespace) -> ChemGraphDaemonConfig: redis_host=args.redis_host, redis_port=args.redis_port, redis_namespace=args.redis_namespace or namespace_for_run(run_dir), - clean_redis=not args.no_clean_redis, - rank=args.rank if args.rank is not None else rank_from_env(), - local_rank=( - args.local_rank - if args.local_rank is not None - else local_rank_from_env() - ), + rank=rank_from_env(), + local_rank=local_rank_from_env(), chemgraph_repo_root=( pathlib.Path(args.chemgraph_repo_root).resolve() if args.chemgraph_repo_root diff --git a/src/chemgraph/mcp/fastmcp_client.py b/src/chemgraph/mcp/fastmcp_client.py index 0d3b98e5..18090e56 100644 --- a/src/chemgraph/mcp/fastmcp_client.py +++ b/src/chemgraph/mcp/fastmcp_client.py @@ -29,7 +29,7 @@ class ToolInvocation(BaseModel): class ToolResult(BaseModel): - """Normalized result from a campaign FastMCP tool call.""" + """Normalized result from a FastMCP tool call.""" model_config = ConfigDict(extra="allow") @@ -58,6 +58,17 @@ class FastMCPExecutionSpec(Protocol): options: Mapping[str, Any] +class FastMCPExecutionConfig(BaseModel): + """Concrete backend configuration for in-process FastMCP clients.""" + + model_config = ConfigDict(extra="forbid") + + backend: str | None = "local" + system: str | None = "local" + config_path: 
str | None = None + options: dict[str, Any] = Field(default_factory=dict) + + def load_fastmcp_tool_module( module_name: str, *, @@ -251,7 +262,7 @@ async def invoke(self, invocation: ToolInvocation) -> ToolResult: spec = self.specs.get(invocation.tool_name) if spec is None: raise KeyError( - f"unknown campaign FastMCP tool: {invocation.tool_name}", + f"unknown FastMCP tool: {invocation.tool_name}", ) try: @@ -350,7 +361,7 @@ async def build_fastmcp_tool_invoker( run_dir: str | Path, agent_name: str, ) -> FastMCPToolInvoker: - """Build and verify one agent's in-process FastMCP tool invoker.""" + """Build and verify one in-process FastMCP tool invoker.""" invoker = FastMCPToolInvoker( specs=list(specs), execution=execution, diff --git a/tests/test_academy_dashboard.py b/tests/test_academy_dashboard.py index 61d9bf04..6f62f37b 100644 --- a/tests/test_academy_dashboard.py +++ b/tests/test_academy_dashboard.py @@ -45,7 +45,7 @@ def test_dashboard_reads_canonical_events_jsonl(tmp_path) -> None: assert events[1]["payload"]["actions"] == [{"action": "send_message"}] -def test_status_payload_builds_summary_and_proof_from_events(tmp_path) -> None: +def test_status_payload_builds_summary_from_events(tmp_path) -> None: run_dir = tmp_path / "daemon-run" run_dir.mkdir() (run_dir / "status.json").write_text( @@ -101,7 +101,6 @@ class Handler: payload = dashboard.status_payload(handler) assert set(payload) == { - "communication_proof", "placement", "run_dir", "schema", @@ -110,9 +109,15 @@ class Handler: "updated", } assert payload["summary"]["message_count"] == 1 - assert payload["communication_proof"]["passes"]["has_message"] is True - assert payload["communication_proof"]["passes"]["has_cross_node_message"] is True - assert payload["communication_proof"]["passes"]["has_belief_citing_message"] is True + assert payload["summary"]["final_reports"] == [ + { + "agent_id": "agent-01", + "confidence": 0.8, + "summary": "used peer evidence", + "supporting_message_ids": ["msg-1"], + 
"supporting_tool_result_ids": [], + }, + ] def test_dashboard_ignores_legacy_trace_jsonl(tmp_path) -> None: diff --git a/tests/test_academy_payloads.py b/tests/test_academy_payloads.py deleted file mode 100644 index 36e2c057..00000000 --- a/tests/test_academy_payloads.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -from chemgraph.academy.observability.event_log import CampaignEvent -from chemgraph.academy.observability.payloads import PAYLOAD_MODELS -from chemgraph.academy.observability.payloads import typed_payload - - -def _payload_for(event: str) -> dict: - payloads = { - "message_sent": { - "message_id": "msg-1", - "sender": "agent-a", - "recipient": "agent-b", - "kind": "message", - "content": "content", - }, - "message_received": { - "message_id": "msg-1", - "sender": "agent-a", - "recipient": "agent-b", - "kind": "message", - "content": "content", - }, - "tool_call_started": { - "tool_result_id": "tool-1", - "tool_name": "tool", - "arguments": {}, - }, - "tool_call_finished": { - "tool_result_id": "tool-1", - "tool_name": "tool", - "arguments": {}, - "status": "ok", - }, - "tool_call_failed": { - "tool_result_id": "tool-1", - "tool_name": "tool", - "arguments": {}, - "status": "failed", - "error": "boom", - }, - "workflow_started": {"workflow_type": "single_agent"}, - "workflow_finished": {"workflow_type": "single_agent", "status": "completed"}, - "llm_decision": {}, - "llm_tool_calls": {}, - "agent_started": {}, - "belief_updated": {}, - } - return payloads[event] - - -def test_payload_models_round_trip() -> None: - for event, model in PAYLOAD_MODELS.items(): - payload = _payload_for(event) - parsed = model.model_validate(payload) - reparsed = model.model_validate(parsed.model_dump()) - assert reparsed.model_dump() == parsed.model_dump() - - -def test_typed_payload_selects_model() -> None: - event = CampaignEvent( - event="message_sent", - payload=_payload_for("message_sent"), - ) - - payload = typed_payload(event) - - assert 
payload is not None - assert payload.model_dump()["message_id"] == "msg-1" From a431bf0013a1ad1a320413fadcf358e14480c5a5 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 11:10:45 -0500 Subject: [PATCH 058/119] refactor(academy): extract dashboard launcher templates --- pyproject.toml | 1 + .../academy/runtime/dashboard_launcher.py | 844 +++--------------- .../academy/runtime/templates/__init__.py | 1 + .../runtime/templates/compute_wrapper.sh.tmpl | 31 + .../academy/runtime/templates/rsync_loop.sh | 19 + .../academy/runtime/templates/start_relay.sh | 59 ++ tests/test_academy_dashboard_launcher.py | 109 ++- 7 files changed, 331 insertions(+), 733 deletions(-) create mode 100644 src/chemgraph/academy/runtime/templates/__init__.py create mode 100644 src/chemgraph/academy/runtime/templates/compute_wrapper.sh.tmpl create mode 100644 src/chemgraph/academy/runtime/templates/rsync_loop.sh create mode 100644 src/chemgraph/academy/runtime/templates/start_relay.sh diff --git a/pyproject.toml b/pyproject.toml index 25557545..b5540e22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ where = ["src/"] "example-*/prompt_profiles/*.json", ] "chemgraph.academy.runtime.profiles" = ["*.json"] +"chemgraph.academy.runtime.templates" = ["*"] "chemgraph.academy.dashboard" = ["static/*"] "ui" = ["assets/*.png"] diff --git a/src/chemgraph/academy/runtime/dashboard_launcher.py b/src/chemgraph/academy/runtime/dashboard_launcher.py index a9626d28..f9735871 100644 --- a/src/chemgraph/academy/runtime/dashboard_launcher.py +++ b/src/chemgraph/academy/runtime/dashboard_launcher.py @@ -2,759 +2,193 @@ import argparse import json -import os -import shlex -import shutil -import subprocess -import sys -import threading +import os, shlex, shutil, signal, subprocess, threading import time import urllib.error import urllib.request +from importlib.resources import files from pathlib import Path -from typing import Any +from chemgraph.academy.dashboard import 
serve_dashboard from chemgraph.academy.examples import campaign_launch_defaults from chemgraph.academy.runtime.profiles import list_builtin_system_profiles from chemgraph.academy.runtime.profiles import load_system_profile from chemgraph.academy.runtime.profiles.system import SystemProfile - def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - prog="chemgraph academy dashboard", - description=( - "Start the local dashboard for a ChemGraph Academy run. " - "This prepares remote run metadata, starts the local dashboard, " - "and optionally starts the temporary Mac-to-UAN Argo relay." - ), - ) - parser.add_argument("run_id") - parser.add_argument( - "--system", - default="aurora", - help=( - "Built-in system profile or profile JSON path. Built-ins: " - + ", ".join(list_builtin_system_profiles()) - ), - ) - parser.add_argument("--campaign", default="mace-ensemble-screening-20") - parser.add_argument( - "--lm-connect", - choices=("mac-argo-relay", "direct"), - default="mac-argo-relay", - help=( - "How the compute job should reach the LM endpoint. " - "mac-argo-relay starts the current SSH reverse tunnel and UAN " - "relay. direct writes --lm-base-url to run metadata without " - "starting relay infrastructure." - ), - ) - parser.add_argument( - "--lm-base-url", - help="Required for --lm-connect direct. 
Overrides generated relay URL.", - ) - parser.add_argument("--remote-host", help="SSH target for the login/UAN host.") - parser.add_argument("--ssh-control-path") - parser.add_argument("--keep-ssh-master", action="store_true") - parser.add_argument("--local-argo-host", default="127.0.0.1") - parser.add_argument("--local-argo-port", type=int, default=18085) - parser.add_argument("--reverse-port", type=int, default=18185) - parser.add_argument("--relay-port", type=int) - parser.add_argument("--relay-python") - parser.add_argument("--rsync-interval-s", type=float, default=2.0) - parser.add_argument( - "--local-mirror-root", - default=str(Path.home() / "projects/chemgraph-academy/remote-runs"), - ) - parser.add_argument("--local-run-dir") - parser.add_argument("--dashboard-host", default="127.0.0.1") - parser.add_argument("--dashboard-port", type=int, default=8765) - parser.add_argument( - "--local", - action="store_true", - help="Only serve an already mirrored local run. No SSH, relay, or rsync.", - ) - parser.add_argument( - "--no-dashboard", - action="store_true", - help="Prepare dashboard metadata and return without serving dashboard.", - ) - parser.add_argument( - "--overwrite-run", - action="store_true", - help=( - "Delete the remote run directory and local mirror before starting. " - "This does not stop an already-running compute job." 
- ), - ) - return parser.parse_args() - - -def _log(message: str) -> None: - print(message, flush=True) - - -def _http_ok(url: str, *, timeout_s: float = 5.0) -> bool: - try: - with urllib.request.urlopen(url, timeout=timeout_s) as response: - return 200 <= int(response.status) < 300 - except (OSError, urllib.error.URLError, urllib.error.HTTPError): - return False - - -def _run(command: list[str], *, input_text: str | None = None) -> subprocess.CompletedProcess[str]: - return subprocess.run( - command, - input=input_text, - text=True, - check=True, - ) - - -def _ssh_options(control_path: str, *, batch_mode: bool = True) -> list[str]: - opts = [ - "-o", - f"ControlPath={control_path}", - "-o", - "ControlMaster=auto", - "-o", - "ControlPersist=yes", - "-o", - "ServerAliveInterval=30", - "-o", - "ServerAliveCountMax=4", - ] + p = argparse.ArgumentParser(prog="chemgraph academy dashboard") + a = p.add_argument + a("run_id") + a("--system", default="aurora", help="Built-ins: " + ", ".join(list_builtin_system_profiles())) + a("--campaign", default="mace-ensemble-screening-20") + a("--lm-connect", choices=("mac-argo-relay", "direct"), default="mac-argo-relay") + a("--lm-base-url") + a("--remote-host") + a("--ssh-control-path") + a("--keep-ssh-master", action="store_true") + a("--local-argo-host", default="127.0.0.1") + a("--local-argo-port", type=int, default=18085) + a("--reverse-port", type=int, default=18185) + a("--relay-port", type=int) + a("--relay-python") + a("--rsync-interval-s", type=float, default=2.0) + a("--local-mirror-root", default=str(Path.home() / "projects/chemgraph-academy/remote-runs")) + a("--local-run-dir") + a("--dashboard-host", default="127.0.0.1") + a("--dashboard-port", type=int, default=8765) + a("--local", action="store_true", help="Only serve an already mirrored local run.") + a("--no-dashboard", action="store_true") + a("--overwrite-run", action="store_true") + return p.parse_args() + +def template(name: str) -> str: + return 
files("chemgraph.academy.runtime.templates").joinpath(name).read_text() + +def ssh(host: str, command: str | list[str] | None, *, control_path: str, input_text: str | None = None, check: bool = True, capture: bool = False, batch_mode: bool = True, extra: list[str] | None = None) -> subprocess.CompletedProcess[str]: + cmd = ["ssh"] if batch_mode: - opts[:0] = ["-o", "BatchMode=yes"] - return opts - - -def _start_ssh_master(*, host: str, control_path: str) -> bool: - Path(control_path).expanduser().parent.mkdir(parents=True, exist_ok=True) - check = subprocess.run( - ["ssh", "-o", f"ControlPath={control_path}", "-O", "check", host], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - text=True, - check=False, - ) - if check.returncode == 0: - return False - - _log(f"Starting SSH ControlMaster for {host}...") - _run( - [ - "ssh", - "-M", - "-N", - "-f", - "-o", - "ControlMaster=yes", - "-o", - f"ControlPath={control_path}", - "-o", - "ControlPersist=yes", - "-o", - "ServerAliveInterval=30", - "-o", - "ServerAliveCountMax=4", - host, - ], - ) - return True - - -def _stop_ssh_master(*, host: str, control_path: str) -> None: - subprocess.run( - ["ssh", "-o", f"ControlPath={control_path}", "-O", "exit", host], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - text=True, - check=False, - ) - - -def _wrapper_text(profile: SystemProfile) -> str: - path_prefix = ":".join([profile.redis_bin_dir, f"{profile.remote_root}/bin"]) - pythonpath = ":".join(profile.pythonpath_entries) - return f"""#!/bin/bash -set -euo pipefail - -log() {{ - printf '[chemgraph-academy-run] %s\\n' "$*" >&2 -}} - -export PATH="{path_prefix}:${{PATH}}" -export PYTHONPATH="{pythonpath}:${{PYTHONPATH:-}}" - -PYTHON_BIN="${{CHEMGRAPH_ACADEMY_PYTHON:-python}}" -if ! command -v "${{PYTHON_BIN}}" >/dev/null 2>&1; then - log "Python command not found: ${{PYTHON_BIN}}" - log "Load your site module and activate the ChemGraph/Academy environment first." 
- log "Profile Python, if you want to use it explicitly: {profile.venv_python}" - exit 1 -fi - -ACTIVE_PYTHON="$("${{PYTHON_BIN}}" -c 'import sys; print(sys.executable)')" -log "using active Python: ${{ACTIVE_PYTHON}}" -log "not loading modules or activating a venv inside this wrapper" - -if ! "${{PYTHON_BIN}}" -c 'import chemgraph.academy.runtime.compute_launcher' >/dev/null 2>&1; then - log "active Python cannot import chemgraph.academy.runtime.compute_launcher" - log "Load the proper site module and venv before running this command." - log "Profile Python, if you want to use it explicitly: {profile.venv_python}" - exit 1 -fi - -log "starting ChemGraph Academy compute launcher" -exec "${{PYTHON_BIN}}" -m chemgraph.academy.runtime.compute_launcher "$@" -""" - - -def _install_compute_wrapper( - *, - profile: SystemProfile, - host: str, - ssh_opts: list[str], -) -> str: - wrapper_bin_dir = f"{profile.remote_root}/bin" - wrapper_path = f"{wrapper_bin_dir}/chemgraph-academy-run" - _log(f"Installing compute wrapper at {wrapper_path}...") - remote_command = ( - f"mkdir -p {shlex.quote(wrapper_bin_dir)} && " - f"cat > {shlex.quote(wrapper_path)} && " - f"chmod +x {shlex.quote(wrapper_path)}" - ) - _run( - ["ssh", *ssh_opts, host, remote_command], - input_text=_wrapper_text(profile), - ) - return wrapper_path - - -def _relay_script_text() -> str: - return r""" -set -euo pipefail - -REMOTE_ROOT="$1" -RELAY_SCRIPT="$2" -RELAY_HOST_FILE="$3" -RELAY_PID_FILE="$4" -RELAY_LOG_FILE="$5" -RELAY_PORT="$6" -REVERSE_PORT="$7" -RELAY_PYTHON="$8" - -cd "${REMOTE_ROOT}" -UAN_HOST="$(hostname -f)" -printf '%s\n' "${UAN_HOST}" > "${RELAY_HOST_FILE}" - -if [ -f "${RELAY_PID_FILE}" ]; then - OLD_PID="$(cat "${RELAY_PID_FILE}" 2>/dev/null || true)" - case "${OLD_PID}" in - ''|*[!0-9]*) ;; - *) kill "${OLD_PID}" 2>/dev/null || true ;; - esac -fi - -"${RELAY_PYTHON}" "${RELAY_SCRIPT}" \ - --listen-host 0.0.0.0 \ - --listen-port "${RELAY_PORT}" \ - --target-host 127.0.0.1 \ - --target-port 
"${REVERSE_PORT}" \ - > "${RELAY_LOG_FILE}" 2>&1 & -RELAY_PID="$!" -printf '%s\n' "${RELAY_PID}" > "${RELAY_PID_FILE}" - -cleanup_remote() { - kill "${RELAY_PID}" 2>/dev/null || true -} -trap cleanup_remote EXIT - -deadline=$((SECONDS + 45)) -while ! curl -fsS "http://${UAN_HOST}:${RELAY_PORT}/v1/models" >/dev/null; do - if ! kill -0 "${RELAY_PID}" 2>/dev/null; then - echo "UAN relay process exited before readiness. Last relay log lines:" >&2 - tail -n 80 "${RELAY_LOG_FILE}" >&2 || true - exit 1 - fi - if [ "${SECONDS}" -gt "${deadline}" ]; then - echo "UAN relay did not become ready. Last relay log lines:" >&2 - tail -n 80 "${RELAY_LOG_FILE}" >&2 || true - exit 1 - fi - sleep 1 -done - -echo "UAN_RELAY_HOST=${UAN_HOST}" -echo "UAN relay ready at http://${UAN_HOST}:${RELAY_PORT}/argoapi/v1" - -while true; do - sleep 3600 -done -""" - - -def _start_mac_argo_relay( - *, - profile: SystemProfile, - host: str, - ssh_opts: list[str], - local_argo_host: str, - local_argo_port: int, - reverse_port: int, - relay_port: int, - relay_python: str, - local_log_path: Path, -) -> subprocess.Popen[str]: + cmd += ["-o", "BatchMode=yes"] + cmd += ["-o", f"ControlPath={control_path}", "-o", "ControlMaster=auto", "-o", "ControlPersist=yes", "-o", "ServerAliveInterval=30", "-o", "ServerAliveCountMax=4"] + cmd += extra or [] + cmd.append(host) + cmd += command if isinstance(command, list) else ([command] if command else []) + return subprocess.run(cmd, input=input_text, text=True, check=check, stdout=subprocess.PIPE if capture else None, stderr=subprocess.PIPE if capture else None) + +def wrapper(profile: SystemProfile) -> str: + return ( + template("compute_wrapper.sh.tmpl") + .replace("%{path_prefix}%", ":".join([profile.redis_bin_dir, f"{profile.remote_root}/bin"])) + .replace("%{pythonpath}%", ":".join(profile.pythonpath_entries)) + .replace("%{venv_python}%", profile.venv_python) + ) + +def start_relay(profile: SystemProfile, host: str, control_path: str, args: argparse.Namespace, 
relay_port: int, relay_python: str, log_path: Path) -> subprocess.Popen[str]: relay_script = f"{profile.academy_repo_root}/examples/09-polaris-lm-swarm/uan_http_relay.py" - relay_pid_file = f"{profile.remote_root}/uan-relay-{relay_port}.pid" - relay_log_file = f"{profile.remote_root}/uan-relay-{relay_port}.log" - local_log_path.parent.mkdir(parents=True, exist_ok=True) - log_file = local_log_path.open("w", encoding="utf-8") - - _log(f"Starting {profile.name} UAN relay through {host}...") - command = [ - "ssh", - *ssh_opts, - "-R", - f"127.0.0.1:{reverse_port}:{local_argo_host}:{local_argo_port}", - host, - "bash", - "-s", - "--", - profile.remote_root, - relay_script, - profile.relay_host_file, - relay_pid_file, - relay_log_file, - str(relay_port), - str(reverse_port), - relay_python, - ] - process = subprocess.Popen( - command, - stdin=subprocess.PIPE, - stdout=log_file, - stderr=subprocess.STDOUT, - text=True, - ) + relay_args = ["bash", "-s", "--", profile.remote_root, relay_script, profile.relay_host_file, f"{profile.remote_root}/uan-relay-{relay_port}.pid", f"{profile.remote_root}/uan-relay-{relay_port}.log", str(relay_port), str(args.reverse_port), relay_python] + log_path.parent.mkdir(parents=True, exist_ok=True) + cmd = ["ssh", "-o", "BatchMode=yes", "-o", f"ControlPath={control_path}", "-o", "ControlMaster=auto", "-o", "ControlPersist=yes", "-o", "ServerAliveInterval=30", "-o", "ServerAliveCountMax=4", "-R", f"127.0.0.1:{args.reverse_port}:{args.local_argo_host}:{args.local_argo_port}", host, *relay_args] + process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=log_path.open("w", encoding="utf-8"), stderr=subprocess.STDOUT, text=True) assert process.stdin is not None - process.stdin.write(_relay_script_text()) + process.stdin.write(template("start_relay.sh")) process.stdin.close() return process - -def _remote_relay_ready( - *, - host: str, - ssh_opts: list[str], - relay_host_file: str, - relay_port: int, -) -> bool: - command = ( - f"host=$(cat 
{shlex.quote(relay_host_file)} 2>/dev/null || true); " - f'test -n "$host" && ' - f'curl -fsS "http://${{host}}:{relay_port}/v1/models" >/dev/null' - ) - result = subprocess.run( - ["ssh", *ssh_opts, host, command], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - text=True, - check=False, - ) - return result.returncode == 0 - - -def _read_remote_file( - *, - host: str, - ssh_opts: list[str], - path: str, -) -> str: - result = subprocess.run( - ["ssh", *ssh_opts, host, "cat", path], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - check=True, - ) - return result.stdout.strip() - - -def _wait_for_relay( - *, - profile: SystemProfile, - host: str, - ssh_opts: list[str], - relay_port: int, - relay_process: subprocess.Popen[str], - local_log_path: Path, -) -> str: - _log("Waiting for relay readiness...") - deadline = time.time() + 60 - while time.time() < deadline: - if _remote_relay_ready( - host=host, - ssh_opts=ssh_opts, - relay_host_file=profile.relay_host_file, - relay_port=relay_port, - ): - relay_host = _read_remote_file( - host=host, - ssh_opts=ssh_opts, - path=profile.relay_host_file, - ) - _log(f"{profile.name} relay host: {relay_host}") +def wait_relay(profile: SystemProfile, host: str, control_path: str, relay_port: int, process: subprocess.Popen[str], log_path: Path) -> str: + print("Waiting for relay readiness...", flush=True) + check = f"host=$(cat {shlex.quote(profile.relay_host_file)} 2>/dev/null || true); test -n \"$host\" && curl -fsS \"http://${{host}}:{relay_port}/v1/models\" >/dev/null" + for _ in range(60): + if ssh(host, check, control_path=control_path, check=False).returncode == 0: + relay_host = ssh(host, ["cat", profile.relay_host_file], control_path=control_path, capture=True).stdout.strip() + print(f"{profile.name} relay host: {relay_host}", flush=True) return relay_host - if relay_process.poll() is not None: - detail = local_log_path.read_text(encoding="utf-8", errors="replace") - raise RuntimeError( - "Relay 
SSH session exited before readiness. Local relay log:\n" - + detail, - ) + if process.poll() is not None: + raise RuntimeError("Relay SSH session exited before readiness. Local relay log:\n" + log_path.read_text(encoding="utf-8", errors="replace")) time.sleep(1) - detail = local_log_path.read_text(encoding="utf-8", errors="replace") - raise RuntimeError("Relay readiness timed out. Local relay log:\n" + detail) - - -def _write_dashboard_metadata( - *, - profile: SystemProfile, - host: str, - ssh_opts: list[str], - run_id: str, - campaign: str, - lm_connect: str, - lm_base_url: str, - relay_host: str | None, - relay_port: int | None, -) -> None: - remote_run_dir = f"{profile.run_root}/{run_id}" - payload: dict[str, Any] = { - "created_at": time.time(), - "created_by": "chemgraph-academy-dashboard", - "run_id": run_id, - "system": profile.name, - "campaign": campaign, - "remote_run_dir": remote_run_dir, - "remote_host": host, - "lm_connect": lm_connect, - "lm_base_url": lm_base_url, - "workspace_root": profile.remote_root, - "academy_repo_root": profile.academy_repo_root, - "chemgraph_repo_root": profile.repo_root, - } - if relay_host: - payload["relay_host"] = relay_host - if relay_port is not None: - payload["relay_port"] = relay_port - - metadata = json.dumps(payload, indent=2) + "\n" - remote_path = f"{remote_run_dir}/dashboard_metadata.json" - remote_command = ( - f"mkdir -p {shlex.quote(remote_run_dir)} && " - f"cat > {shlex.quote(remote_path)}" - ) - _log(f"Writing run metadata: {host}:{remote_run_dir}/dashboard_metadata.json") - _run( - ["ssh", *ssh_opts, host, remote_command], - input_text=metadata, - ) - - -def _run_id_allows_delete(run_id: str) -> bool: - return bool(run_id) and "/" not in run_id and run_id not in {".", ".."} - + raise RuntimeError("Relay readiness timed out. 
Local relay log:\n" + log_path.read_text(encoding="utf-8", errors="replace")) -def _delete_existing_run( - *, - profile: SystemProfile, - host: str, - ssh_opts: list[str], - run_id: str, - local_run_dir: Path, -) -> None: - if not _run_id_allows_delete(run_id): - raise RuntimeError(f"Refusing to overwrite unsafe run id: {run_id!r}") - - remote_run_dir = f"{profile.run_root}/{run_id}" - _log("Deleting existing run artifacts because --overwrite-run was set:") - _log(f" remote: {host}:{remote_run_dir}") - _log(f" local: {local_run_dir}") - - remote_command = ( - "set -euo pipefail; " - f"run_root={shlex.quote(profile.run_root)}; " - f"run_id={shlex.quote(run_id)}; " - 'case "$run_id" in ""|.|..|*/*) echo "unsafe run id" >&2; exit 2;; esac; ' - 'run_dir="$run_root/$run_id"; ' - 'trash_root="$run_root/.deleted-runs"; ' - 'if [ -e "$run_dir" ]; then ' - 'mkdir -p "$trash_root"; ' - 'trash_dir="$trash_root/${run_id}.$(date +%Y%m%d%H%M%S).$$"; ' - 'mv -- "$run_dir" "$trash_dir"; ' - 'for delay in 0 1 2 5 10; do ' - 'sleep "$delay"; ' - 'if rm -rf -- "$trash_dir" 2>/dev/null; then break; fi; ' - 'done; ' - 'fi; ' - 'mkdir -p "$run_dir"' - ) - _run(["ssh", *ssh_opts, host, remote_command]) - if local_run_dir.exists(): - shutil.rmtree(local_run_dir) - - -def _start_rsync_loop( - *, - host: str, - control_path: str, - remote_run_dir: str, - local_run_dir: Path, - interval_s: float, - stop_event: threading.Event, -) -> threading.Thread: +def start_rsync(host: str, control_path: str, remote_run_dir: str, local_run_dir: Path, interval_s: float, stop: threading.Event) -> None: local_run_dir.mkdir(parents=True, exist_ok=True) - log_path = local_run_dir / "rsync.log" + rsync_args = [host, control_path, remote_run_dir, str(local_run_dir), str(interval_s), str(local_run_dir / "rsync.log")] def loop() -> None: - ssh_command = ( - "ssh " - "-o BatchMode=yes " - "-o ControlMaster=auto " - f"-o ControlPath={shlex.quote(control_path)} " - "-o ControlPersist=yes" - ) - while not 
stop_event.is_set(): - with log_path.open("a", encoding="utf-8") as log: - subprocess.run( - [ - "rsync", - "-az", - "--delete", - "-e", - ssh_command, - f"{host}:{remote_run_dir}/", - f"{local_run_dir}/", - ], - stdout=log, - stderr=subprocess.STDOUT, - text=True, - check=False, - ) - stop_event.wait(interval_s) - - thread = threading.Thread(target=loop, name="chemgraph-academy-rsync", daemon=True) - thread.start() - return thread - - -def _run_dashboard(*, local_run_dir: Path, host: str, port: int) -> int: - from chemgraph.academy import dashboard - - old_argv = sys.argv - try: - sys.argv = [ - "chemgraph-academy-dashboard serve", - "--run-dir", - str(local_run_dir), - "--host", - host, - "--port", - str(port), - ] - return dashboard.main() - finally: - sys.argv = old_argv - - -def _print_compute_command( - *, - profile: SystemProfile, - wrapper_path: str, - run_id: str, - campaign: str, -) -> None: - _log("") - _log("Dashboard launcher is ready.") - _log("") - _log(f"On the {profile.name} compute node, use:") - if profile.name == "polaris": - _log(" module use /soft/modulefiles") - _log(" module load conda") - _log(" conda activate base") - _log(f" source {profile.remote_root}/venvs/academy-swarm/bin/activate") - else: - _log(" module load frameworks") - _log(f" source {profile.remote_root}/venvs/academy-swarm/bin/activate") - _log(f" export PATH={profile.remote_root}/bin:$PATH") - _log(" chemgraph-academy-run \\") - _log(f" --system {profile.name} \\") - _log(f" --run-id {run_id} \\") - _log(f" --campaign {campaign}") - _log("") - _log("If PATH is not configured, use:") - _log(f" {wrapper_path} \\") - _log(f" --system {profile.name} \\") - _log(f" --run-id {run_id} \\") - _log(f" --campaign {campaign}") - + process = subprocess.Popen(["bash", "-s", "--", *rsync_args], stdin=subprocess.PIPE, text=True, start_new_session=True) + assert process.stdin is not None + process.stdin.write(template("rsync_loop.sh")) + process.stdin.close() + stop.wait() + if 
process.poll() is None: + os.killpg(process.pid, signal.SIGTERM) + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + os.killpg(process.pid, signal.SIGKILL) -def _validate_campaign_name(campaign: str) -> None: - campaign_launch_defaults(campaign) + threading.Thread(target=loop, name="chemgraph-academy-rsync", daemon=True).start() +def compute_lines(profile: SystemProfile, wrapper_path: str, run_id: str, campaign: str) -> list[str]: + lines = [" module use /soft/modulefiles", " module load conda", " conda activate base"] if profile.name == "polaris" else [" module load frameworks"] + return lines + [f" source {profile.remote_root}/venvs/academy-swarm/bin/activate", f" export PATH={profile.remote_root}/bin:$PATH", " chemgraph-academy-run \\", f" --system {profile.name} \\", f" --run-id {run_id} \\", f" --campaign {campaign}", "", "If PATH is not configured, use:", f" {wrapper_path} \\", f" --system {profile.name} \\", f" --run-id {run_id} \\", f" --campaign {campaign}"] def main() -> int: args = parse_args() profile = load_system_profile(args.system) - _validate_campaign_name(args.campaign) - - local_run_dir = Path( - args.local_run_dir or Path(args.local_mirror_root) / args.run_id, - ).expanduser() + campaign_launch_defaults(args.campaign) + local_run_dir = Path(args.local_run_dir or Path(args.local_mirror_root) / args.run_id).expanduser() local_run_dir.mkdir(parents=True, exist_ok=True) - - if args.local and args.overwrite_run: - raise RuntimeError("--overwrite-run cannot be used with --local") - if args.local: - if args.no_dashboard: - _log(f"Local run directory: {local_run_dir}") - return 0 - return _run_dashboard( - local_run_dir=local_run_dir, - host=args.dashboard_host, - port=args.dashboard_port, - ) + if args.overwrite_run: + raise RuntimeError("--overwrite-run cannot be used with --local") + return 0 if args.no_dashboard else serve_dashboard(run_dir=local_run_dir, host=args.dashboard_host, port=args.dashboard_port) + if args.lm_connect == 
"direct" and not args.lm_base_url: + raise RuntimeError("--lm-connect direct requires --lm-base-url") + if args.lm_connect == "mac-argo-relay": + try: + with urllib.request.urlopen(f"http://{args.local_argo_host}:{args.local_argo_port}/v1/models", timeout=5) as response: + if int(response.status) >= 300: + raise OSError + except (OSError, urllib.error.URLError, urllib.error.HTTPError) as exc: + raise RuntimeError("Local argo-shim is not reachable. Start it before using --lm-connect mac-argo-relay.") from exc remote_host = args.remote_host or profile.remote_host - control_path = ( - args.ssh_control_path - or str(Path.home() / f".ssh/{profile.name}-dashboard-%r@%h:%p") - ) + control_path = args.ssh_control_path or str(Path.home() / f".ssh/{profile.name}-dashboard-%r@%h:%p") relay_port = args.relay_port or profile.relay_port - relay_python = args.relay_python or profile.venv_python - local_relay_log = Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log") remote_run_dir = f"{profile.run_root}/{args.run_id}" - relay_process: subprocess.Popen[str] | None = None - stop_rsync = threading.Event() - started_ssh_master = False - + stop = threading.Event() + started_master = False try: - if args.lm_connect == "mac-argo-relay": - health_url = f"http://{args.local_argo_host}:{args.local_argo_port}/v1/models" - if not _http_ok(health_url): - raise RuntimeError( - "Local argo-shim is not reachable: " - f"{health_url}\n" - "Start it before using --lm-connect mac-argo-relay.", - ) - elif not args.lm_base_url: - raise RuntimeError("--lm-connect direct requires --lm-base-url") - - started_ssh_master = _start_ssh_master( - host=remote_host, - control_path=control_path, - ) - ssh_opts = _ssh_options(control_path) + Path(control_path).expanduser().parent.mkdir(parents=True, exist_ok=True) + if ssh(remote_host, None, control_path=control_path, extra=["-O", "check"], check=False, batch_mode=False).returncode != 0: + print(f"Starting SSH ControlMaster for {remote_host}...", flush=True) + 
ssh(remote_host, None, control_path=control_path, extra=["-M", "-N", "-f", "-o", "ControlMaster=yes"], batch_mode=False) + started_master = True if args.overwrite_run: - _delete_existing_run( - profile=profile, - host=remote_host, - ssh_opts=ssh_opts, - run_id=args.run_id, - local_run_dir=local_run_dir, - ) - wrapper_path = _install_compute_wrapper( - profile=profile, - host=remote_host, - ssh_opts=ssh_opts, - ) - - relay_host: str | None = None + if not args.run_id or "/" in args.run_id or args.run_id in {".", ".."}: + raise RuntimeError(f"Refusing to overwrite unsafe run id: {args.run_id!r}") + print("Deleting existing run artifacts because --overwrite-run was set:", flush=True) + print(f" remote: {remote_host}:{remote_run_dir}", flush=True) + print(f" local: {local_run_dir}", flush=True) + delete = f"set -euo pipefail; run_root={shlex.quote(profile.run_root)}; run_id={shlex.quote(args.run_id)}; case \"$run_id\" in \"\"|.|..|*/*) echo \"unsafe run id\" >&2; exit 2;; esac; run_dir=\"$run_root/$run_id\"; trash_root=\"$run_root/.deleted-runs\"; if [ -e \"$run_dir\" ]; then mkdir -p \"$trash_root\"; trash_dir=\"$trash_root/${{run_id}}.$(date +%Y%m%d%H%M%S).$$\"; mv -- \"$run_dir\" \"$trash_dir\"; for delay in 0 1 2 5 10; do sleep \"$delay\"; if rm -rf -- \"$trash_dir\" 2>/dev/null; then break; fi; done; fi; mkdir -p \"$run_dir\"" + ssh(remote_host, delete, control_path=control_path) + if local_run_dir.exists(): + shutil.rmtree(local_run_dir) + wrapper_path = f"{profile.remote_root}/bin/chemgraph-academy-run" + print(f"Installing compute wrapper at {wrapper_path}...", flush=True) + ssh(remote_host, f"mkdir -p {shlex.quote(profile.remote_root + '/bin')} && cat > {shlex.quote(wrapper_path)} && chmod +x {shlex.quote(wrapper_path)}", control_path=control_path, input_text=wrapper(profile)) + relay_host = None if args.lm_connect == "mac-argo-relay": - relay_process = _start_mac_argo_relay( - profile=profile, - host=remote_host, - ssh_opts=ssh_opts, - 
local_argo_host=args.local_argo_host, - local_argo_port=args.local_argo_port, - reverse_port=args.reverse_port, - relay_port=relay_port, - relay_python=relay_python, - local_log_path=local_relay_log, - ) - relay_host = _wait_for_relay( - profile=profile, - host=remote_host, - ssh_opts=ssh_opts, - relay_port=relay_port, - relay_process=relay_process, - local_log_path=local_relay_log, - ) - lm_base_url = f"http://{relay_host}:{relay_port}/argoapi/v1" - else: - lm_base_url = str(args.lm_base_url) - - _log(f"Compute-node LM URL: {lm_base_url}") - _write_dashboard_metadata( - profile=profile, - host=remote_host, - ssh_opts=ssh_opts, - run_id=args.run_id, - campaign=args.campaign, - lm_connect=args.lm_connect, - lm_base_url=lm_base_url, - relay_host=relay_host, - relay_port=relay_port if relay_host else None, - ) - - _log("Starting rsync mirror:") - _log(f" {remote_host}:{remote_run_dir}/") - _log(f" {local_run_dir}/") - _start_rsync_loop( - host=remote_host, - control_path=control_path, - remote_run_dir=remote_run_dir, - local_run_dir=local_run_dir, - interval_s=args.rsync_interval_s, - stop_event=stop_rsync, - ) - - _print_compute_command( - profile=profile, - wrapper_path=wrapper_path, - run_id=args.run_id, - campaign=args.campaign, - ) - + print(f"Starting {profile.name} UAN relay through {remote_host}...", flush=True) + relay_process = start_relay(profile, remote_host, control_path, args, relay_port, args.relay_python or profile.venv_python, Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log")) + relay_host = wait_relay(profile, remote_host, control_path, relay_port, relay_process, Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log")) + lm_base_url = f"http://{relay_host}:{relay_port}/argoapi/v1" if relay_host else str(args.lm_base_url) + print(f"Compute-node LM URL: {lm_base_url}", flush=True) + metadata = {"created_at": time.time(), "created_by": "chemgraph-academy-dashboard", "run_id": args.run_id, "system": profile.name, "campaign": args.campaign, 
"remote_run_dir": remote_run_dir, "remote_host": remote_host, "lm_connect": args.lm_connect, "lm_base_url": lm_base_url, "workspace_root": profile.remote_root, "academy_repo_root": profile.academy_repo_root, "chemgraph_repo_root": profile.repo_root} + if relay_host: + metadata.update({"relay_host": relay_host, "relay_port": relay_port}) + print(f"Writing run metadata: {remote_host}:{remote_run_dir}/dashboard_metadata.json", flush=True) + ssh(remote_host, f"mkdir -p {shlex.quote(remote_run_dir)} && cat > {shlex.quote(remote_run_dir + '/dashboard_metadata.json')}", control_path=control_path, input_text=json.dumps(metadata, indent=2) + "\n") + print("Starting rsync mirror:", flush=True) + print(f" {remote_host}:{remote_run_dir}/", flush=True) + print(f" {local_run_dir}/", flush=True) + start_rsync(remote_host, control_path, remote_run_dir, local_run_dir, args.rsync_interval_s, stop) + print("\nDashboard launcher is ready.\n", flush=True) + print(f"On the {profile.name} compute node, use:", flush=True) + print("\n".join(compute_lines(profile, wrapper_path, args.run_id, args.campaign)), flush=True) if args.no_dashboard: return 0 - - _log("") - _log(f"Starting dashboard at http://{args.dashboard_host}:{args.dashboard_port}") - _log("Ctrl-C stops the local dashboard, rsync loop, and relay tunnel.") - return _run_dashboard( - local_run_dir=local_run_dir, - host=args.dashboard_host, - port=args.dashboard_port, - ) + print(f"\nStarting dashboard at http://{args.dashboard_host}:{args.dashboard_port}", flush=True) + print("Ctrl-C stops the local dashboard, rsync loop, and relay tunnel.", flush=True) + return serve_dashboard(run_dir=local_run_dir, host=args.dashboard_host, port=args.dashboard_port) finally: - stop_rsync.set() + stop.set() if relay_process is not None and relay_process.poll() is None: relay_process.terminate() try: relay_process.wait(timeout=5) except subprocess.TimeoutExpired: relay_process.kill() - keep = args.keep_ssh_master or 
os.environ.get("CHEMGRAPH_ACADEMY_KEEP_SSH_MASTER") == "1" - if started_ssh_master and not keep: - _stop_ssh_master(host=remote_host, control_path=control_path) - + if started_master and not args.keep_ssh_master: + ssh(remote_host, None, control_path=control_path, extra=["-O", "exit"], check=False, batch_mode=False) if __name__ == "__main__": raise SystemExit(main()) diff --git a/src/chemgraph/academy/runtime/templates/__init__.py b/src/chemgraph/academy/runtime/templates/__init__.py new file mode 100644 index 00000000..143a959e --- /dev/null +++ b/src/chemgraph/academy/runtime/templates/__init__.py @@ -0,0 +1 @@ +"""Runtime shell templates shipped with ChemGraph Academy.""" diff --git a/src/chemgraph/academy/runtime/templates/compute_wrapper.sh.tmpl b/src/chemgraph/academy/runtime/templates/compute_wrapper.sh.tmpl new file mode 100644 index 00000000..f168159b --- /dev/null +++ b/src/chemgraph/academy/runtime/templates/compute_wrapper.sh.tmpl @@ -0,0 +1,31 @@ +#!/bin/bash +set -euo pipefail + +log() { + printf '[chemgraph-academy-run] %s\n' "$*" >&2 +} + +export PATH="%{path_prefix}%:${PATH}" +export PYTHONPATH="%{pythonpath}%:${PYTHONPATH:-}" + +PYTHON_BIN="${CHEMGRAPH_ACADEMY_PYTHON:-python}" +if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then + log "Python command not found: ${PYTHON_BIN}" + log "Load your site module and activate the ChemGraph/Academy environment first." + log "Profile Python, if you want to use it explicitly: %{venv_python}%" + exit 1 +fi + +ACTIVE_PYTHON="$("${PYTHON_BIN}" -c 'import sys; print(sys.executable)')" +log "using active Python: ${ACTIVE_PYTHON}" +log "not loading modules or activating a venv inside this wrapper" + +if ! "${PYTHON_BIN}" -c 'import chemgraph.academy.runtime.compute_launcher' >/dev/null 2>&1; then + log "active Python cannot import chemgraph.academy.runtime.compute_launcher" + log "Load the proper site module and venv before running this command." 
+ log "Profile Python, if you want to use it explicitly: %{venv_python}%" + exit 1 +fi + +log "starting ChemGraph Academy compute launcher" +exec "${PYTHON_BIN}" -m chemgraph.academy.runtime.compute_launcher "$@" diff --git a/src/chemgraph/academy/runtime/templates/rsync_loop.sh b/src/chemgraph/academy/runtime/templates/rsync_loop.sh new file mode 100644 index 00000000..26663692 --- /dev/null +++ b/src/chemgraph/academy/runtime/templates/rsync_loop.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -euo pipefail + +HOST="$1" +CONTROL_PATH="$2" +REMOTE_RUN_DIR="$3" +LOCAL_RUN_DIR="$4" +INTERVAL_S="$5" +LOG_PATH="$6" + +mkdir -p "${LOCAL_RUN_DIR}" +while true; do + rsync -az --delete \ + -e "ssh -o BatchMode=yes -o ControlMaster=auto -o ControlPath=${CONTROL_PATH} -o ControlPersist=yes" \ + "${HOST}:${REMOTE_RUN_DIR}/" \ + "${LOCAL_RUN_DIR}/" \ + >> "${LOG_PATH}" 2>&1 || true + sleep "${INTERVAL_S}" +done diff --git a/src/chemgraph/academy/runtime/templates/start_relay.sh b/src/chemgraph/academy/runtime/templates/start_relay.sh new file mode 100644 index 00000000..1bb9e5fd --- /dev/null +++ b/src/chemgraph/academy/runtime/templates/start_relay.sh @@ -0,0 +1,59 @@ +#!/bin/bash +set -euo pipefail + +REMOTE_ROOT="$1" +RELAY_SCRIPT="$2" +RELAY_HOST_FILE="$3" +RELAY_PID_FILE="$4" +RELAY_LOG_FILE="$5" +RELAY_PORT="$6" +REVERSE_PORT="$7" +RELAY_PYTHON="$8" + +cd "${REMOTE_ROOT}" +UAN_HOST="$(hostname -f)" +printf '%s\n' "${UAN_HOST}" > "${RELAY_HOST_FILE}" + +if [ -f "${RELAY_PID_FILE}" ]; then + OLD_PID="$(cat "${RELAY_PID_FILE}" 2>/dev/null || true)" + case "${OLD_PID}" in + ''|*[!0-9]*) ;; + *) kill "${OLD_PID}" 2>/dev/null || true ;; + esac +fi + +"${RELAY_PYTHON}" "${RELAY_SCRIPT}" \ + --listen-host 0.0.0.0 \ + --listen-port "${RELAY_PORT}" \ + --target-host 127.0.0.1 \ + --target-port "${REVERSE_PORT}" \ + > "${RELAY_LOG_FILE}" 2>&1 & +RELAY_PID="$!" 
+printf '%s\n' "${RELAY_PID}" > "${RELAY_PID_FILE}" + +cleanup_remote() { + kill "${RELAY_PID}" 2>/dev/null || true +} +trap cleanup_remote EXIT + +deadline=$((SECONDS + 45)) +while ! curl -fsS "http://${UAN_HOST}:${RELAY_PORT}/v1/models" >/dev/null; do + if ! kill -0 "${RELAY_PID}" 2>/dev/null; then + echo "UAN relay process exited before readiness. Last relay log lines:" >&2 + tail -n 80 "${RELAY_LOG_FILE}" >&2 || true + exit 1 + fi + if [ "${SECONDS}" -gt "${deadline}" ]; then + echo "UAN relay did not become ready. Last relay log lines:" >&2 + tail -n 80 "${RELAY_LOG_FILE}" >&2 || true + exit 1 + fi + sleep 1 +done + +echo "UAN_RELAY_HOST=${UAN_HOST}" +echo "UAN relay ready at http://${UAN_HOST}:${RELAY_PORT}/argoapi/v1" + +while true; do + sleep 3600 +done diff --git a/tests/test_academy_dashboard_launcher.py b/tests/test_academy_dashboard_launcher.py index c4fcd857..587d2534 100644 --- a/tests/test_academy_dashboard_launcher.py +++ b/tests/test_academy_dashboard_launcher.py @@ -1,5 +1,8 @@ from __future__ import annotations +import argparse +import json +import subprocess from pathlib import Path import pytest @@ -24,45 +27,95 @@ def _profile(tmp_path: Path) -> SystemProfile: redis_bind="0.0.0.0", redis_protected_mode="no", mpiexec="mpiexec", - pythonpath_entries=[str(tmp_path)], + pythonpath_entries=[str(tmp_path), "/remote/root/ChemGraph/src"], no_proxy="127.0.0.1,localhost", ) -def test_delete_existing_run_removes_remote_and_local(tmp_path, monkeypatch) -> None: +def _args(tmp_path: Path, **overrides) -> argparse.Namespace: + values = { + "run_id": "run-001", + "system": "test-system", + "campaign": "mace-ensemble-screening-20", + "lm_connect": "direct", + "lm_base_url": "http://lm.example/v1", + "remote_host": None, + "ssh_control_path": str(tmp_path / "ssh-control"), + "keep_ssh_master": False, + "local_argo_host": "127.0.0.1", + "local_argo_port": 18085, + "reverse_port": 18185, + "relay_port": None, + "relay_python": None, + "rsync_interval_s": 2.0, + 
"local_mirror_root": str(tmp_path / "mirror"), + "local_run_dir": None, + "dashboard_host": "127.0.0.1", + "dashboard_port": 8765, + "local": False, + "no_dashboard": True, + "overwrite_run": True, + } + values.update(overrides) + return argparse.Namespace(**values) + + +def test_compute_wrapper_template_renders_profile_values(tmp_path) -> None: + text = dashboard_launcher.wrapper(_profile(tmp_path)) + + assert "%{" not in text + assert '/remote/root/tools/redis/bin:/remote/root/bin:${PATH}' in text + assert f'{tmp_path}:/remote/root/ChemGraph/src:${{PYTHONPATH:-}}' in text + assert "/remote/root/venv/bin/python" in text + + +def test_dashboard_launcher_overwrite_writes_remote_state(tmp_path, monkeypatch) -> None: local_run = tmp_path / "mirror" / "run-001" local_run.mkdir(parents=True) (local_run / "status.json").write_text("{}\n", encoding="utf-8") - calls: list[list[str]] = [] + calls: list[dict] = [] - monkeypatch.setattr( - dashboard_launcher, - "_run", - lambda command, **kwargs: calls.append(command), - ) + def fake_ssh(host, command, **kwargs): + calls.append({"host": host, "command": command, **kwargs}) + return subprocess.CompletedProcess(["ssh"], 0, stdout="") - dashboard_launcher._delete_existing_run( - profile=_profile(tmp_path), - host="user@example", - ssh_opts=["-o", "BatchMode=yes"], - run_id="run-001", - local_run_dir=local_run, - ) + monkeypatch.setattr(dashboard_launcher, "parse_args", lambda: _args(tmp_path)) + monkeypatch.setattr(dashboard_launcher, "load_system_profile", lambda _: _profile(tmp_path)) + monkeypatch.setattr(dashboard_launcher, "campaign_launch_defaults", lambda _: object()) + monkeypatch.setattr(dashboard_launcher, "ssh", fake_ssh) + monkeypatch.setattr(dashboard_launcher, "start_rsync", lambda *args, **kwargs: None) + assert dashboard_launcher.main() == 0 assert not local_run.exists() - assert calls - assert calls[0][:4] == ["ssh", "-o", "BatchMode=yes", "user@example"] - assert 'mv -- "$run_dir" "$trash_dir"' in calls[0][-1] 
- assert 'rm -rf -- "$trash_dir"' in calls[0][-1] - assert 'mkdir -p "$run_dir"' in calls[0][-1] + delete_command = calls[1]["command"] + assert 'mv -- "$run_dir" "$trash_dir"' in delete_command + assert 'rm -rf -- "$trash_dir"' in delete_command + assert 'mkdir -p "$run_dir"' in delete_command + + wrapper_call = calls[2] + assert wrapper_call["command"].endswith("chmod +x /remote/root/bin/chemgraph-academy-run") + assert "chemgraph.academy.runtime.compute_launcher" in wrapper_call["input_text"] + + metadata = json.loads(calls[3]["input_text"]) + assert metadata["run_id"] == "run-001" + assert metadata["lm_base_url"] == "http://lm.example/v1" + assert metadata["remote_run_dir"] == "/remote/root/runs/run-001" + + +def test_dashboard_launcher_rejects_unsafe_overwrite_run_id(tmp_path, monkeypatch) -> None: + monkeypatch.setattr( + dashboard_launcher, + "parse_args", + lambda: _args(tmp_path, run_id="../bad"), + ) + monkeypatch.setattr(dashboard_launcher, "load_system_profile", lambda _: _profile(tmp_path)) + monkeypatch.setattr(dashboard_launcher, "campaign_launch_defaults", lambda _: object()) + monkeypatch.setattr( + dashboard_launcher, + "ssh", + lambda *args, **kwargs: subprocess.CompletedProcess(["ssh"], 0, stdout=""), + ) -def test_delete_existing_run_rejects_unsafe_run_id(tmp_path) -> None: with pytest.raises(RuntimeError, match="unsafe run id"): - dashboard_launcher._delete_existing_run( - profile=_profile(tmp_path), - host="user@example", - ssh_opts=[], - run_id="../bad", - local_run_dir=tmp_path / "mirror", - ) + dashboard_launcher.main() From b71c441dcd7d5e8261709e2dce884365b7fd34bf Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 11:31:56 -0500 Subject: [PATCH 059/119] refactor(academy): run wakeups through chemgraph turn primitive --- pyproject.toml | 1 - src/chemgraph/academy/core/__init__.py | 4 +- src/chemgraph/academy/core/agent.py | 54 +- src/chemgraph/academy/core/tools.py | 108 +-- src/chemgraph/academy/core/turn.py | 618 
+++-------------- .../academy/observability/event_log.py | 3 + src/chemgraph/agent/llm_agent.py | 652 +++++++++++------- src/chemgraph/cli/main.py | 16 - src/chemgraph/mcp/fastmcp_client.py | 17 +- src/chemgraph/observability/__init__.py | 17 - src/chemgraph/observability/events.py | 119 ---- .../observability/langgraph_stream.py | 346 ---------- .../observability/local_dashboard_run.py | 170 ----- .../observability/workflow_runner.py | 397 ----------- tests/test_academy_payloads.py | 27 + tests/test_academy_reasoning_phase2.py | 282 +++----- tests/test_tool_adapter_validation.py | 5 - 17 files changed, 672 insertions(+), 2164 deletions(-) delete mode 100644 src/chemgraph/observability/__init__.py delete mode 100644 src/chemgraph/observability/events.py delete mode 100644 src/chemgraph/observability/langgraph_stream.py delete mode 100644 src/chemgraph/observability/local_dashboard_run.py delete mode 100644 src/chemgraph/observability/workflow_runner.py create mode 100644 tests/test_academy_payloads.py diff --git a/pyproject.toml b/pyproject.toml index b5540e22..e667eab2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,6 @@ chemgraph-eval = "chemgraph.eval.cli:main" chemgraph-academy-run = "chemgraph.academy.runtime.compute_launcher:main" chemgraph-academy-dashboard = "chemgraph.academy.runtime.dashboard_launcher:main" chemgraph-dashboard = "chemgraph.academy.dashboard:main" -chemgraph-dashboard-run = "chemgraph.observability.local_dashboard_run:main" [tool.setuptools.packages.find] where = ["src/"] diff --git a/src/chemgraph/academy/core/__init__.py b/src/chemgraph/academy/core/__init__.py index 06b8f5e7..5a6cc6b0 100644 --- a/src/chemgraph/academy/core/__init__.py +++ b/src/chemgraph/academy/core/__init__.py @@ -12,15 +12,14 @@ from chemgraph.academy.core.lm import load_lm_config from chemgraph.academy.core.prompt import PromptProfile from chemgraph.academy.core.prompt import load_prompt_profile -from chemgraph.academy.core.turn import 
ChemGraphReasoningRoundEngine from chemgraph.academy.core.turn import ReasoningTurnResult +from chemgraph.academy.core.turn import run_academy_turn __all__ = [ "ChemGraphAgentSpec", "ChemGraphCampaign", "ChemGraphDaemonConfig", "ChemGraphLogicalAgent", - "ChemGraphReasoningRoundEngine", "LLMSettings", "PromptProfile", "ReasoningTurnResult", @@ -30,4 +29,5 @@ "load_lm_config", "load_prompt_profile", "resolve_campaign_resources", + "run_academy_turn", ] diff --git a/src/chemgraph/academy/core/agent.py b/src/chemgraph/academy/core/agent.py index 2dacb033..5463b41d 100644 --- a/src/chemgraph/academy/core/agent.py +++ b/src/chemgraph/academy/core/agent.py @@ -19,7 +19,8 @@ from chemgraph.academy.core.peer_protocol import validate_message from chemgraph.academy.observability.event_log import EventLog from chemgraph.academy.observability.run_artifacts import write_status_snapshot -from chemgraph.academy.core.turn import ChemGraphReasoningRoundEngine +from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools +from chemgraph.academy.core.turn import run_academy_turn from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.academy.core.campaign import ChemGraphCampaign from chemgraph.academy.core.lm import LLMSettings @@ -69,7 +70,6 @@ def __init__( self.finished = False self.last_error: str | None = None self._wake_event: asyncio.Event | None = None - self._reasoning_engine: ChemGraphReasoningRoundEngine | None = None async def agent_on_startup(self) -> None: self._wake_event = asyncio.Event() @@ -78,24 +78,6 @@ async def agent_on_startup(self) -> None: for name, agent_id in self.peer_agent_ids.items() if name in self.peer_names } - self._reasoning_engine = await ChemGraphReasoningRoundEngine.create( - campaign=self.campaign, - spec=self.spec, - llm_settings=self.llm_settings, - prompt_profile=self.prompt_profile, - run_dir=self.run_dir, - max_decisions=self.max_decisions, - tool_invoker=self.tool_invoker, - peer_names=self.peer_names, - 
peer_handles=self.peer_handles, - received_message_history=self.received_message_history, - outbox=self.outbox, - tool_results=self.tool_results, - get_final_result=lambda: self.final_result, - get_round_index=lambda: self.round_index, - set_final_result=self._set_final_result, - trace=self._trace, - ) self._trace( 'agent_started', { @@ -218,10 +200,35 @@ async def report_state(self) -> dict[str, Any]: } async def _reasoning_round(self) -> bool: - if self._reasoning_engine is None: - raise RuntimeError('agent startup did not initialize reasoning engine') self._trace('round_started', {'round': self.round_index}) - result = await self._reasoning_engine.run_turn() + tools = await build_chemgraph_reasoning_tools( + spec=self.spec, + run_dir=self.run_dir, + tool_invoker=self.tool_invoker, + peer_names=self.peer_names, + peer_handles=self.peer_handles, + outbox=self.outbox, + tool_results=self.tool_results, + get_round_index=lambda: self.round_index, + set_final_result=self._set_final_result, + trace=self._trace, + ) + result = await run_academy_turn( + campaign=self.campaign, + spec=self.spec, + llm_settings=self.llm_settings, + prompt_profile=self.prompt_profile, + run_dir=self.run_dir, + max_decisions=self.max_decisions, + tools=tools, + received_message_history=self.received_message_history, + outbox=self.outbox, + tool_results=self.tool_results, + get_final_result=lambda: self.final_result, + get_round_index=lambda: self.round_index, + trace=self._trace, + peer_names=self.peer_names, + ) self._trace( 'agent_decision', { @@ -232,7 +239,6 @@ async def _reasoning_round(self) -> bool: 'tool_names': list(result.executed_tool_names), 'action_tools_called': list(result.action_tools_called), 'science_tools_called': list(result.science_tools_called), - 'workflow_span_id': result.workflow_span_id, 'thread_id': result.thread_id, 'engine': 'chemgraph_single_agent', 'actions': [ diff --git a/src/chemgraph/academy/core/tools.py b/src/chemgraph/academy/core/tools.py index 
a933598d..a03f5ec5 100644 --- a/src/chemgraph/academy/core/tools.py +++ b/src/chemgraph/academy/core/tools.py @@ -7,110 +7,43 @@ import time import uuid import asyncio -from collections.abc import Callable -from collections.abc import Mapping -from dataclasses import dataclass -from dataclasses import field +from collections.abc import Callable, Mapping from typing import Any from academy.handle import Handle -from langchain_core.tools import BaseTool -from langchain_core.tools import StructuredTool -from pydantic import BaseModel -from pydantic import ConfigDict -from pydantic import Field -from pydantic import ValidationError +from langchain_core.tools import BaseTool, StructuredTool +from pydantic import BaseModel, ConfigDict, Field, ValidationError from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.mcp.fastmcp_client import ToolInvocation from chemgraph.mcp.fastmcp_client import fastmcp_tool_schemas -from chemgraph.mcp.fastmcp_client import ( - FastMCPToolInvoker, -) +from chemgraph.mcp.fastmcp_client import FastMCPToolInvoker from chemgraph.academy.core.peer_protocol import build_message from chemgraph.academy.observability.run_files import append_jsonl TraceFn = Callable[[str, dict[str, Any]], None] SetFinalResultFn = Callable[[dict[str, Any]], None] - - -@dataclass -class ReasoningToolRuntimeState: - """Mutable per-turn state updated by ChemGraph reasoning tools.""" - - science_tool_completed: bool = False - finished_turn: bool = False - executed_tool_names: list[str] = field(default_factory=list) - action_tool_names: list[str] = field(default_factory=list) - science_tool_names: list[str] = field(default_factory=list) - background_tasks: set[asyncio.Task[Any]] = field(default_factory=set) - - def reset(self) -> None: - self.science_tool_completed = False - self.finished_turn = False - self.executed_tool_names.clear() - self.action_tool_names.clear() - self.science_tool_names.clear() - - def record_action(self, name: str) -> None: - 
self.executed_tool_names.append(name) - self.action_tool_names.append(name) - - def record_science(self, name: str) -> None: - self.executed_tool_names.append(name) - self.science_tool_names.append(name) +_BACKGROUND_DELIVERIES: set[asyncio.Task[Any]] = set() class SendMessageArgs(BaseModel): - """Arguments for the LM-visible peer-message action.""" - model_config = ConfigDict(extra="forbid") - recipient: str = Field( - min_length=1, - description="Allowed peer agent name that should receive this message.", - ) - tldr: str = Field( - min_length=1, - max_length=160, - description="One-line user-visible summary for dashboard edge labels.", - ) - content: str = Field( - min_length=1, - max_length=1800, - description="Full peer message content with concise evidence summaries.", - ) - artifact_refs: list[str] = Field( - default_factory=list, - description="JSON array of artifact path strings cited by this message.", - ) - tool_result_ids: list[str] = Field( - default_factory=list, - description="JSON array of ChemGraph tool_result_id strings cited by this message.", - ) + recipient: str = Field(min_length=1, description="Allowed peer agent name.") + tldr: str = Field(min_length=1, max_length=160, description="One-line dashboard edge label.") + content: str = Field(min_length=1, max_length=1800, description="Full peer message content.") + artifact_refs: list[str] = Field(default_factory=list, description="Artifact path strings.") + tool_result_ids: list[str] = Field(default_factory=list, description="ChemGraph tool_result_id strings.") reply_requested: bool = Field( default=False, - description=( - "Set true when this message asks the peer to reply or take a " - "specific follow-up action; false for one-way updates." 
- ), - ) - reason: str = Field( - min_length=1, - max_length=600, - description="Non-empty sentence explaining why this peer needs the message now.", - ) - confidence: float = Field( - ge=0, - le=1, - description="Numeric confidence from 0 to 1.", + description="True when this asks the peer to reply or act.", ) + reason: str = Field(min_length=1, max_length=600, description="Why this peer needs the message now.") + confidence: float = Field(ge=0, le=1, description="Numeric confidence from 0 to 1.") class SubmitResultArgs(BaseModel): - """Arguments for submitting a logical agent's current result.""" - model_config = ConfigDict(extra="forbid") summary: str = Field(min_length=1, max_length=1200) @@ -122,15 +55,12 @@ class SubmitResultArgs(BaseModel): class FinishTurnArgs(BaseModel): - """Arguments for ending the current logical-agent turn.""" - model_config = ConfigDict(extra="forbid") reason: str = Field(min_length=1, max_length=600) def _stable_validation_errors(exc: ValidationError) -> list[dict[str, str]]: - """Project Pydantic validation errors to a stable LM-facing shape.""" return [ { "field": ".".join(str(part) for part in error.get("loc", ())), @@ -204,7 +134,6 @@ async def build_chemgraph_reasoning_tools( get_round_index: Callable[[], int], set_final_result: SetFinalResultFn, trace: TraceFn, - runtime_state: ReasoningToolRuntimeState, ) -> list[BaseTool]: """Build explicit tools for one ChemGraph-backed reasoning turn.""" @@ -250,8 +179,8 @@ async def _send_message_impl( trace=trace, ), ) - runtime_state.background_tasks.add(task) - task.add_done_callback(runtime_state.background_tasks.discard) + _BACKGROUND_DELIVERIES.add(task) + task.add_done_callback(_BACKGROUND_DELIVERIES.discard) return { "status": "sent", "delivery": "queued", @@ -288,13 +217,11 @@ async def _deliver_message( def _validation_error_handler(tool_name: str) -> Callable[[ValidationError], dict[str, Any]]: def handle(exc: ValidationError) -> dict[str, Any]: - 
runtime_state.record_action(tool_name) return _invalid_args_response(tool_name, exc, trace) return handle async def send_message(**kwargs: Any) -> dict[str, Any]: - runtime_state.record_action("send_message") try: args = SendMessageArgs.model_validate(kwargs) except ValidationError as exc: @@ -318,7 +245,6 @@ async def send_message(**kwargs: Any) -> dict[str, Any]: ) async def submit_result(**kwargs: Any) -> dict[str, Any]: - runtime_state.record_action("submit_result") try: args = SubmitResultArgs.model_validate(kwargs) except ValidationError as exc: @@ -340,12 +266,10 @@ async def submit_result(**kwargs: Any) -> dict[str, Any]: return {"status": "submitted", "confidence": result["confidence"]} async def finish_turn(**kwargs: Any) -> dict[str, Any]: - runtime_state.record_action("finish_turn") try: args = FinishTurnArgs.model_validate(kwargs) except ValidationError as exc: return _invalid_args_response("finish_turn", exc, trace) - runtime_state.finished_turn = True trace("turn_finished_without_external_action", {"reason": args.reason}) return {"status": "finished", "reason": args.reason} @@ -405,7 +329,6 @@ async def run_fastmcp_tool( __tool_name: str = tool_spec.name, **kwargs: Any, ) -> dict[str, Any]: - runtime_state.record_science(__tool_name) if __tool_name not in spec.tool_names: raise RuntimeError( f"{spec.name} cannot call unavailable tool {__tool_name}", @@ -437,7 +360,6 @@ async def run_fastmcp_tool( trace("tool_call_failed", failure) raise RuntimeError(f"{__tool_name} failed: {failure['error']}") - runtime_state.science_tool_completed = True record = { **started, "timestamp": time.time(), diff --git a/src/chemgraph/academy/core/turn.py b/src/chemgraph/academy/core/turn.py index 28452970..be23084d 100644 --- a/src/chemgraph/academy/core/turn.py +++ b/src/chemgraph/academy/core/turn.py @@ -1,568 +1,146 @@ -"""Run one Academy logical-agent wakeup through ChemGraph LangGraph.""" +"""Run one Academy logical-agent wakeup through ChemGraph.""" from __future__ 
import annotations - import json import time from collections.abc import Callable -from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path from typing import Any - -from academy.handle import Handle from langchain_core.tools import BaseTool - -from chemgraph.mcp.fastmcp_client import ( - FastMCPToolInvoker, -) -from chemgraph.academy.core.tools import ( - ReasoningToolRuntimeState, -) -from chemgraph.academy.core.tools import ( - build_chemgraph_reasoning_tools, -) -from chemgraph.academy.core.campaign import ChemGraphAgentSpec -from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.campaign import ChemGraphAgentSpec, ChemGraphCampaign from chemgraph.academy.core.campaign import visible_resources_payload from chemgraph.academy.core.lm import LLMSettings from chemgraph.academy.core.prompt import PromptProfile from chemgraph.academy.observability.run_files import read_json_file -from chemgraph.academy.observability.run_files import read_jsonl +from chemgraph.agent.llm_agent import run_turn TraceFn = Callable[[str, dict[str, Any]], None] -SetFinalResultFn = Callable[[dict[str, Any]], None] - +ACTION_TOOL_NAMES = frozenset({"send_message", "ask_peer", "submit_result", "finish_turn"}) +TERMINAL_TOOL_NAMES = ("finish_turn", "submit_result") @dataclass(frozen=True) class ReasoningTurnResult: - """Summary of one ChemGraph-managed logical-agent reasoning turn.""" - final_text: str - state: dict[str, Any] - tool_calls_completed: int + executed_tool_names: tuple[str, ...] action_tools_called: tuple[str, ...] science_tools_called: tuple[str, ...] - executed_tool_names: tuple[str, ...] 
requested_finish: bool requested_self_wake: bool - workflow_span_id: str thread_id: str - -class ChemGraphReasoningRoundEngine: - """Use ChemGraph single_agent as the per-wakeup reasoning loop.""" - - def __init__( - self, - *, - campaign: ChemGraphCampaign, - spec: ChemGraphAgentSpec, - llm_settings: LLMSettings, - prompt_profile: PromptProfile, - run_dir: Path, - max_decisions: int, - tools: list[BaseTool], - runtime_state: ReasoningToolRuntimeState, - received_message_history: list[dict[str, Any]], - outbox: list[dict[str, Any]], - tool_results: list[dict[str, Any]], - get_final_result: Callable[[], dict[str, Any] | None], - get_round_index: Callable[[], int], - trace: TraceFn, - peer_names: tuple[str, ...] = (), - ) -> None: - self.campaign = campaign - self.spec = spec - self.llm_settings = llm_settings - self.prompt_profile = prompt_profile - self.run_dir = run_dir - self.max_decisions = max_decisions - self.tools = list(tools) - self.runtime_state = runtime_state - self.received_message_history = received_message_history - self.outbox = outbox - self.tool_results = tool_results - self.peer_names = tuple(peer_names) - self.get_final_result = get_final_result - self.get_round_index = get_round_index - self.trace = trace - - @classmethod - async def create( - cls, - *, - campaign: ChemGraphCampaign, - spec: ChemGraphAgentSpec, - llm_settings: LLMSettings, - prompt_profile: PromptProfile, - run_dir: Path, - max_decisions: int, - tool_invoker: FastMCPToolInvoker, - peer_names: tuple[str, ...], - peer_handles: Mapping[str, Handle[Any]], - received_message_history: list[dict[str, Any]], - outbox: list[dict[str, Any]], - tool_results: list[dict[str, Any]], - get_final_result: Callable[[], dict[str, Any] | None], - get_round_index: Callable[[], int], - set_final_result: SetFinalResultFn, - trace: TraceFn, - ) -> "ChemGraphReasoningRoundEngine": - runtime_state = ReasoningToolRuntimeState() - tools = await build_chemgraph_reasoning_tools( - spec=spec, - 
run_dir=run_dir, - tool_invoker=tool_invoker, - peer_names=peer_names, - peer_handles=peer_handles, - outbox=outbox, - tool_results=tool_results, - get_round_index=get_round_index, - set_final_result=set_final_result, - trace=trace, - runtime_state=runtime_state, - ) - return cls( - campaign=campaign, - spec=spec, - llm_settings=llm_settings, - prompt_profile=prompt_profile, - run_dir=run_dir, - max_decisions=max_decisions, - tools=tools, - runtime_state=runtime_state, - received_message_history=received_message_history, - outbox=outbox, - tool_results=tool_results, - peer_names=peer_names, - get_final_result=get_final_result, - get_round_index=get_round_index, - trace=trace, - ) - - async def run_turn(self) -> ReasoningTurnResult: - """Run one turn-local ChemGraph workflow for the current wakeup.""" - from chemgraph.agent.llm_agent import ChemGraph - from chemgraph.observability.events import WorkflowEventContext - from chemgraph.observability.events import emit_workflow_event - from chemgraph.observability.events import new_span_id - from chemgraph.observability.events import workflow_event_context - - round_index = self.get_round_index() - thread_id = f"{self.spec.name}-round-{round_index}" - workflow_span_id = new_span_id(f"chemgraph-turn-{self.spec.name}") - parent_span_id = f"academy-round-{self.spec.name}-{round_index}" - query = self.build_wakeup_query(round_index=round_index) - log_dir = ( - self.run_dir - / "chemgraph_turns" - / f"{self.spec.name}-round-{round_index:04d}" - ) - log_dir.mkdir(parents=True, exist_ok=True) - - self.runtime_state.reset() - self.trace( - "chemgraph_reasoning_turn_started", - { - "round": round_index, - "thread_id": thread_id, - "workflow_span_id": workflow_span_id, - "tool_names": [tool.name for tool in self.tools], - }, - ) - context = WorkflowEventContext( - run_id=self.run_dir.name, - run_dir=str(self.run_dir), - agent_id=self.spec.name, - role=self.spec.role, - parent_span_id=parent_span_id, - tool_name=None, - ) - - with 
workflow_event_context( - jsonl_path=self.run_dir / "events.jsonl", - context=context, - ): - emit_workflow_event( - "workflow_started", - { - "workflow_type": "single_agent", - "workflow_node": "ChemGraphReasoningRoundEngine", - "round": round_index, - "thread_id": thread_id, - "tool_names": [tool.name for tool in self.tools], - "log_dir": str(log_dir), - }, - span_id=workflow_span_id, - parent_span_id=parent_span_id, - ) - agent = ChemGraph( - model_name=self.llm_settings.model, - workflow_type="single_agent", - base_url=self.llm_settings.base_url, - api_key=self.llm_settings.api_key, - argo_user=self.llm_settings.user, - system_prompt=self.prompt_profile.system_prompt, - return_option="state", - recursion_limit=self.prompt_profile.langchain_recursion_limit, - tools=self.tools, - terminal_tool_names=("finish_turn", "submit_result"), - enable_memory=False, - log_dir=str(log_dir), - ) - try: - state = await agent.run( - query, - config={"configurable": {"thread_id": thread_id}}, - workflow_span_id=workflow_span_id, - ) - except Exception as exc: - emit_workflow_event( - "workflow_finished", - { - "workflow_type": "single_agent", - "workflow_node": "ChemGraphReasoningRoundEngine", - "round": round_index, - "thread_id": thread_id, - "status": "failed", - "error": repr(exc), - "log_dir": str(log_dir), - }, - span_id=workflow_span_id, - parent_span_id=parent_span_id, - ) - raise - else: - state = _ensure_state_dict(state) - emit_workflow_event( - "workflow_finished", - { - "workflow_type": "single_agent", - "workflow_node": "ChemGraphReasoningRoundEngine", - "round": round_index, - "thread_id": thread_id, - "status": "completed", - "log_dir": str(log_dir), - }, - span_id=workflow_span_id, - parent_span_id=parent_span_id, - ) - - if not self.runtime_state.executed_tool_names: - raise RuntimeError( - "ChemGraph reasoning turn returned without calling an " - "Academy action or science tool; logical agents must call " - "finish_turn when no external action is useful.", - ) 
- - result = ReasoningTurnResult( - final_text=_extract_final_text(state), - state=state, - tool_calls_completed=len(self.runtime_state.executed_tool_names), - action_tools_called=tuple(self.runtime_state.action_tool_names), - science_tools_called=tuple(self.runtime_state.science_tool_names), - executed_tool_names=tuple(self.runtime_state.executed_tool_names), - requested_finish=self.runtime_state.finished_turn, - requested_self_wake=self.runtime_state.science_tool_completed, - workflow_span_id=workflow_span_id, - thread_id=thread_id, - ) - self.trace( - "chemgraph_reasoning_turn_finished", - { - "round": round_index, - "thread_id": thread_id, - "workflow_span_id": workflow_span_id, - "action_tools_called": list(result.action_tools_called), - "science_tools_called": list(result.science_tools_called), - "requested_finish": result.requested_finish, - "requested_self_wake": result.requested_self_wake, - }, - ) - return result - - def build_wakeup_query(self, *, round_index: int) -> str: - """Build the user message for one ChemGraph turn.""" - state = self.build_wakeup_state(round_index=round_index) - return json.dumps(state, sort_keys=True) - - def build_wakeup_state(self, *, round_index: int) -> dict[str, Any]: - """Build the exact state visible to the logical agent this turn.""" - limits = self.prompt_profile.state_limits - return { - "campaign": self.campaign.run_id, - "user_task": self.campaign.user_task, - "agent_name": self.spec.name, - "role": self.spec.role, - "mission": self.spec.mission, - "round": round_index, - "max_decisions": self.max_decisions, - "resources": visible_resources_payload(self.campaign, self.spec), - "allowed_peers": list(self.spec.allowed_peers), - "peer_status": build_peer_status( - run_dir=self.run_dir, - peer_names=self.peer_names, - ), - "available_chemgraph_tools": list(self.spec.tool_names), - "received_messages": ( - self.received_message_history[ - -limits.received_messages_last_n : - ] - if limits.received_messages_last_n - else 
[] - ), - "local_chemgraph_tool_results": ( - self.tool_results[-limits.tool_results_last_n :] - if limits.tool_results_last_n - else [] - ), - "recent_actions": build_recent_actions( - outbox=self.outbox, - tool_results=self.tool_results, - limit=limits.actions_last_n, - ), - "current_final_result": self.get_final_result(), - "required_protocol": self.prompt_profile.protocol_prompt, - } - - -def build_peer_status( +async def run_academy_turn( *, + campaign: ChemGraphCampaign, + spec: ChemGraphAgentSpec, + llm_settings: LLMSettings, + prompt_profile: PromptProfile, run_dir: Path, - peer_names: tuple[str, ...], - event_scan_limit: int = 1000, -) -> dict[str, dict[str, Any]]: - """Return compact status snapshots for peers visible to this agent.""" - if not peer_names: - return {} - - now = time.time() - peers = set(peer_names) - status: dict[str, dict[str, Any]] = { - peer: _status_from_agent_file(run_dir, peer, now=now) - for peer in peer_names - } - - for event in read_jsonl(run_dir / "events.jsonl")[-event_scan_limit:]: - agent_id = event.get("agent_id") - if agent_id not in peers: - continue - kind = str(event.get("event") or "") - timestamp = _float_or_none(event.get("timestamp")) - payload = event.get("payload") - payload = payload if isinstance(payload, dict) else {} - peer_status = status[str(agent_id)] - - if kind == "round_started": - peer_status["state"] = "busy" - peer_status["current_activity"] = { - "type": "reasoning_round", - "round": payload.get("round"), - "started_at": timestamp, - } - _set_update_age(peer_status, timestamp, now=now) - elif kind == "tool_call_started": - peer_status["state"] = "busy" - peer_status["current_activity"] = { - "type": "tool_call", - "tool_name": payload.get("tool_name"), - "tool_result_id": payload.get("tool_result_id"), - "tool_call_id": payload.get("tool_call_id"), - "started_at": timestamp, - } - _set_update_age(peer_status, timestamp, now=now) - elif kind in {"tool_call_finished", "tool_call_failed"}: - 
peer_status["state"] = "busy" - peer_status["current_activity"] = { - "type": "reasoning_after_tool", - "last_tool": payload.get("tool_name"), - "tool_result_id": payload.get("tool_result_id"), - "status": payload.get("status"), - "updated_at": timestamp, - } - _set_update_age(peer_status, timestamp, now=now) - elif kind == "message_sent": - peer_status["last_outbox_tldr"] = ( - payload.get("tldr") or _preview(payload.get("content")) - ) - peer_status["last_outbox_message_id"] = payload.get("message_id") - _set_update_age(peer_status, timestamp, now=now) - elif kind == "belief_updated": - peer_status["last_belief"] = _compact_belief(payload) - _set_update_age(peer_status, timestamp, now=now) - elif kind in { - "round_finished", - "turn_finished_without_external_action", - "workflow_finished", - }: - if kind == "workflow_finished" and payload.get("status") == "failed": - peer_status["state"] = "error" - else: - peer_status["state"] = "idle" - peer_status["current_activity"] = None - _set_update_age(peer_status, timestamp, now=now) - elif kind == "agent_error": - peer_status["state"] = "error" - peer_status["last_error"] = payload.get("error") - peer_status["current_activity"] = None - _set_update_age(peer_status, timestamp, now=now) - elif kind == "daemon_stopped": - peer_status["state"] = "finished" - peer_status["finished"] = True - peer_status["current_activity"] = None - _set_update_age(peer_status, timestamp, now=now) - - return status - - -def _status_from_agent_file( - run_dir: Path, - peer_name: str, - *, - now: float, -) -> dict[str, Any]: - data = read_json_file( - run_dir / "agent_status" / f"{peer_name}.json", - default={}, + max_decisions: int, + tools: list[BaseTool], + received_message_history: list[dict[str, Any]], + outbox: list[dict[str, Any]], + tool_results: list[dict[str, Any]], + get_final_result: Callable[[], dict[str, Any] | None], + get_round_index: Callable[[], int], + trace: TraceFn, + peer_names: tuple[str, ...] 
= (), +) -> ReasoningTurnResult: + round_index = get_round_index() + thread_id = f"{spec.name}-round-{round_index}" + trace("chemgraph_reasoning_turn_started", {"round": round_index, "thread_id": thread_id, "tool_names": [t.name for t in tools]}) + + def on_event(event: str, payload: dict) -> None: + trace(event, {"round": round_index, **payload}) + + result = await run_turn( + query=json.dumps(_state(campaign, spec, prompt_profile, run_dir, max_decisions, round_index, received_message_history, outbox, tool_results, get_final_result, peer_names), sort_keys=True), + tools=tools, + model_name=llm_settings.model, + base_url=llm_settings.base_url, + api_key=llm_settings.api_key, + argo_user=llm_settings.user, + system_prompt=prompt_profile.system_prompt, + recursion_limit=prompt_profile.langchain_recursion_limit, + thread_id=thread_id, + terminal_tool_names=TERMINAL_TOOL_NAMES, + on_event=on_event, + ) + if not result.executed_tool_names: + raise RuntimeError("ChemGraph reasoning turn returned without calling an Academy action or science tool; call finish_turn when no external action is useful.") + action_tools = tuple(n for n in result.executed_tool_names if n in ACTION_TOOL_NAMES) + science_tools = tuple(n for n in result.executed_tool_names if n not in ACTION_TOOL_NAMES) + out = ReasoningTurnResult( + final_text=result.final_text, + executed_tool_names=result.executed_tool_names, + action_tools_called=action_tools, + science_tools_called=science_tools, + requested_finish=result.terminal_tool in TERMINAL_TOOL_NAMES, + requested_self_wake=bool(science_tools), + thread_id=result.thread_id, ) - state = "unknown" - if data: - if data.get("last_error"): - state = "error" - elif data.get("finished") is True: - state = "finished" - else: - state = "idle" - timestamp = _float_or_none(data.get("status_updated_at")) + trace("chemgraph_reasoning_turn_finished", {"round": round_index, "thread_id": out.thread_id, "action_tools_called": list(action_tools), "science_tools_called": 
list(science_tools), "requested_finish": out.requested_finish, "requested_self_wake": out.requested_self_wake}) + return out + +def _state(campaign, spec, profile, run_dir, max_decisions, round_index, messages, outbox, results, get_final_result, peer_names) -> dict[str, Any]: + limits = profile.state_limits return { - "state": state, - "round": data.get("round"), - "finished": bool(data.get("finished")) if data else False, - "last_error": data.get("last_error"), - "current_activity": data.get("current_activity"), - "seconds_since_update": _age(timestamp, now=now), - "last_outbox_tldr": _last_outbox_tldr(data), - "last_outbox_message_id": _last_outbox_message_id(data), - "last_belief": _compact_belief(data.get("belief")), + "campaign": campaign.run_id, + "user_task": campaign.user_task, + "agent_name": spec.name, + "role": spec.role, + "mission": spec.mission, + "round": round_index, + "max_decisions": max_decisions, + "resources": visible_resources_payload(campaign, spec), + "allowed_peers": list(spec.allowed_peers), + "peer_status": build_peer_status(run_dir=run_dir, peer_names=peer_names), + "available_chemgraph_tools": list(spec.tool_names), + "received_messages": _tail(messages, limits.received_messages_last_n), + "local_chemgraph_tool_results": _tail(results, limits.tool_results_last_n), + "recent_actions": build_recent_actions(outbox=outbox, tool_results=results, limit=limits.actions_last_n), + "current_final_result": get_final_result(), + "required_protocol": profile.protocol_prompt, } +def build_peer_status(*, run_dir: Path, peer_names: tuple[str, ...]) -> dict[str, dict[str, Any]]: + return {peer: _status(run_dir, peer, now=time.time()) for peer in peer_names} -def _last_outbox_tldr(data: Mapping[str, Any]) -> str | None: - recent = data.get("recent_outbox") - if not isinstance(recent, list) or not recent: - return None - last = recent[-1] - if not isinstance(last, dict): - return None - return last.get("tldr") or _preview(last.get("content")) - +def 
build_recent_actions(*, outbox: list[dict[str, Any]], tool_results: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]: + if limit <= 0: + return [] + actions = [{"type": "send_message", "recipient": m.get("recipient"), "reply_requested": bool(m.get("reply_requested")), "tldr": m.get("tldr") or _preview(m.get("content")), "message_id": m.get("message_id"), "timestamp": m.get("timestamp")} for m in outbox[-limit:]] + actions += [{"type": "tool_call", "tool_name": r.get("tool_name"), "tool_result_id": r.get("tool_result_id"), "status": r.get("status"), "timestamp": r.get("timestamp")} for r in tool_results[-limit:]] + return sorted(actions, key=lambda i: float(i.get("timestamp") or 0.0))[-limit:] -def _last_outbox_message_id(data: Mapping[str, Any]) -> str | None: - recent = data.get("recent_outbox") - if not isinstance(recent, list) or not recent: - return None - last = recent[-1] - if not isinstance(last, dict): - return None - value = last.get("message_id") - return str(value) if value else None +def _status(run_dir: Path, peer: str, *, now: float) -> dict[str, Any]: + data = read_json_file(run_dir / "agent_status" / f"{peer}.json", default={}) + timestamp = _float(data.get("status_updated_at")) + state = "unknown" if not data else "error" if data.get("last_error") else "finished" if data.get("finished") else "idle" + return {"state": state, "round": data.get("round"), "finished": bool(data.get("finished")) if data else False, "last_error": data.get("last_error"), "current_activity": data.get("current_activity"), "seconds_since_update": None if timestamp is None else max(0.0, round(now - timestamp, 3)), "last_outbox_tldr": _last_outbox(data), "last_belief": _belief(data.get("belief"))} -def _compact_belief(value: Any) -> dict[str, Any] | None: - if not isinstance(value, dict): - return None - summary = value.get("summary") or value.get("hypothesis") - if not summary: - return None - return { - "summary": _preview(summary, max_chars=220), - "confidence": 
value.get("confidence"), - } +def _last_outbox(data: dict[str, Any]) -> str | None: + recent = data.get("recent_outbox") + return (recent[-1].get("tldr") or _preview(recent[-1].get("content"))) if isinstance(recent, list) and recent and isinstance(recent[-1], dict) else None -def _set_update_age( - peer_status: dict[str, Any], - timestamp: float | None, - *, - now: float, -) -> None: - peer_status["seconds_since_update"] = _age(timestamp, now=now) +def _belief(value: Any) -> dict[str, Any] | None: + summary = value.get("summary") or value.get("hypothesis") if isinstance(value, dict) else None + return {"summary": _preview(summary, max_chars=220), "confidence": value.get("confidence")} if summary else None -def _age(timestamp: float | None, *, now: float) -> float | None: - if timestamp is None: - return None - return max(0.0, round(now - timestamp, 3)) +def _tail(items: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]: + return items[-limit:] if limit else [] -def _float_or_none(value: Any) -> float | None: - if isinstance(value, bool) or value is None: - return None +def _float(value: Any) -> float | None: try: - return float(value) + return None if value is None or isinstance(value, bool) else float(value) except (TypeError, ValueError): return None -def build_recent_actions( - *, - outbox: list[dict[str, Any]], - tool_results: list[dict[str, Any]], - limit: int, -) -> list[dict[str, Any]]: - """Build a compact chronological action history for LM prompt state.""" - if limit <= 0: - return [] - - actions: list[dict[str, Any]] = [] - for message in outbox[-limit:]: - actions.append( - { - "type": "send_message", - "recipient": message.get("recipient"), - "reply_requested": bool(message.get("reply_requested")), - "tldr": message.get("tldr") or _preview(message.get("content")), - "message_id": message.get("message_id"), - "timestamp": message.get("timestamp"), - }, - ) - - for result in tool_results[-limit:]: - actions.append( - { - "type": "tool_call", - 
"tool_name": result.get("tool_name"), - "tool_result_id": result.get("tool_result_id"), - "status": result.get("status"), - "timestamp": result.get("timestamp"), - }, - ) - - actions.sort(key=lambda item: float(item.get("timestamp") or 0.0)) - return actions[-limit:] - - def _preview(value: Any, *, max_chars: int = 160) -> str: text = "" if value is None else str(value) - if len(text) <= max_chars: - return text - return text[: max_chars - 1] + "..." - - -def _ensure_state_dict(state: Any) -> dict[str, Any]: - if isinstance(state, dict): - return state - return {"value": state} - - -def _extract_final_text(state: Mapping[str, Any]) -> str: - messages = state.get("messages") - if not isinstance(messages, list) or not messages: - return "" - last = messages[-1] - if isinstance(last, dict): - content = last.get("content") - return "" if content is None else str(content) - content = getattr(last, "content", None) - return "" if content is None else str(content) + return text if len(text) <= max_chars else text[: max_chars - 1] + "..." 
diff --git a/src/chemgraph/academy/observability/event_log.py b/src/chemgraph/academy/observability/event_log.py index e42013b6..c42a41cc 100644 --- a/src/chemgraph/academy/observability/event_log.py +++ b/src/chemgraph/academy/observability/event_log.py @@ -55,6 +55,9 @@ "workflow_finished", "workflow_node_started", "workflow_node_finished", + "llm_call_started", + "llm_call_finished", + "llm_call_failed", "llm_decision", "workflow_output", ] diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 1537da73..dbe55245 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -2,7 +2,8 @@ import datetime import dataclasses import os -from typing import Callable, Collection, List, Optional +import time +from typing import Any, Callable, Collection, List, Optional import uuid from chemgraph.memory.store import SessionStore @@ -22,8 +23,6 @@ supported_gemini_models, ) -from chemgraph.observability.langgraph_stream import ChemGraphWorkflowCallback -from chemgraph.observability.langgraph_stream import emit_live_message_events from chemgraph.schemas.ase_input import ( get_available_calculator_names, get_calculator_selection_context, @@ -42,8 +41,8 @@ aggregator_prompt as default_aggregator_prompt, planner_prompt as default_planner_prompt, ) -from langgraph.types import Command from langgraph.errors import GraphInterrupt +from langchain_core.callbacks import BaseCallbackHandler from chemgraph.graphs.single_agent import construct_single_agent_graph @@ -211,6 +210,355 @@ def _custom_openai_compatible_kwargs( return kwargs +EventCallback = Callable[[str, dict], None] + + +@dataclasses.dataclass(frozen=True) +class TurnResult: + """Result of one bounded ChemGraph single-agent turn.""" + + final_text: str + state: dict[str, Any] + executed_tool_names: tuple[str, ...] 
+ terminal_tool: str | None + thread_id: str + duration_s: float + + +class _TurnEventCallback(BaseCallbackHandler): + """Forward LangChain callback events to a small stable callback surface.""" + + def __init__(self, on_event: EventCallback, thread_id: str) -> None: + self._on_event = on_event + self._thread_id = thread_id + + def _emit(self, event: str, payload: dict[str, Any]) -> None: + try: + self._on_event(event, {"thread_id": self._thread_id, **payload}) + except Exception: # noqa: BLE001 - callbacks must not break the run. + logger.debug("turn event callback failed", exc_info=True) + + def on_chat_model_start(self, serialized, messages, **kwargs) -> None: + self._emit( + "llm_call_started", + { + "model": _serialized_name(serialized), + "message_count": len(messages[0]) if messages else 0, + }, + ) + + def on_llm_start(self, serialized, prompts, **kwargs) -> None: + self._emit( + "llm_call_started", + { + "model": _serialized_name(serialized), + "message_count": len(prompts or []), + }, + ) + + def on_llm_end(self, response, **kwargs) -> None: + payload: dict[str, Any] = {} + usage = getattr(response, "llm_output", None) + if isinstance(usage, dict): + payload["llm_output"] = usage + self._emit("llm_call_finished", payload) + + def on_llm_error(self, error, **kwargs) -> None: + self._emit("llm_call_failed", {"error": repr(error)}) + + def on_tool_start(self, serialized, input_str, **kwargs) -> None: + self._emit( + "tool_call_started", + { + "tool_name": _serialized_name(serialized), + "arguments": serialize_state(input_str), + }, + ) + + def on_tool_end(self, output, **kwargs) -> None: + payload: dict[str, Any] = {"result": serialize_state(output)} + name = kwargs.get("name") + if name: + payload["tool_name"] = name + self._emit("tool_call_finished", payload) + + def on_tool_error(self, error, **kwargs) -> None: + payload = {"error": repr(error)} + name = kwargs.get("name") + if name: + payload["tool_name"] = name + self._emit("tool_call_failed", payload) 
+ + +def _serialized_name(serialized: Any) -> str | None: + if isinstance(serialized, dict): + return serialized.get("name") or serialized.get("id") + return None + + +def _message_tool_calls(message: Any) -> list[Any]: + if isinstance(message, dict): + calls = message.get("tool_calls") + else: + calls = getattr(message, "tool_calls", None) + return calls if isinstance(calls, list) else [] + + +def _tool_message_name(message: Any) -> str | None: + if isinstance(message, dict): + name = message.get("name") + role = message.get("role") or message.get("type") + if name and role in {"tool", "tool_message", "ToolMessage"}: + return str(name) + return str(name) if name and not _message_tool_calls(message) else None + name = getattr(message, "name", None) + message_type = getattr(message, "type", None) + if name and message_type == "tool": + return str(name) + return str(name) if name and not _message_tool_calls(message) else None + + +def _call_name(call: Any) -> str | None: + if isinstance(call, dict): + if call.get("name"): + return str(call["name"]) + function = call.get("function") + if isinstance(function, dict) and function.get("name"): + return str(function["name"]) + name = getattr(call, "name", None) + return str(name) if name else None + + +def _state_messages(state: Any) -> list[Any]: + if isinstance(state, dict): + messages = state.get("messages", []) + else: + messages = getattr(state, "messages", []) + return list(messages or []) + + +def _executed_tool_names(messages: list[Any]) -> tuple[str, ...]: + names: list[str] = [] + for message in messages: + name = _tool_message_name(message) + if name: + names.append(name) + if names: + return tuple(names) + for message in messages: + for call in _message_tool_calls(message): + if name := _call_name(call): + names.append(name) + return tuple(names) + + +def _terminal_tool_name( + executed_tool_names: tuple[str, ...], + terminal_tool_names: Collection[str], +) -> str | None: + terminal = set(terminal_tool_names) + 
for name in reversed(executed_tool_names): + if name in terminal: + return name + return None + + +def _message_text(message: Any) -> str: + content = message.get("content") if isinstance(message, dict) else getattr(message, "content", "") + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + parts.append(str(item.get("text") or item.get("content") or item)) + else: + parts.append(str(item)) + return "\n".join(parts) + return "" if content is None else str(content) + + +def _final_text(messages: list[Any]) -> str: + for message in reversed(messages): + message_type = ( + message.get("role") or message.get("type") + if isinstance(message, dict) + else getattr(message, "type", None) + ) + if message_type in {"ai", "assistant"}: + return _message_text(message) + return _message_text(messages[-1]) if messages else "" + + +def _load_turn_llm( + *, + model_name: str, + base_url: str | None, + api_key: str | None, + argo_user: str | None, +) -> Any: + temperature = 0.0 + if model_name in supported_openai_models or model_name in supported_argo_models: + kwargs = { + "model_name": model_name, + "temperature": temperature, + "base_url": base_url, + } + if argo_user is not None: + kwargs["argo_user"] = argo_user + return load_openai_model(**kwargs) + if model_name in supported_ollama_models: + return load_ollama_model(model_name=model_name, temperature=temperature) + if model_name in supported_alcf_models: + return load_alcf_model( + model_name=model_name, + base_url=base_url, + api_key=api_key, + ) + if model_name in supported_anthropic_models: + return load_anthropic_model( + model_name=model_name, + api_key=api_key, + temperature=temperature, + ) + if model_name in supported_gemini_models: + return load_gemini_model( + model_name=model_name, + api_key=api_key, + temperature=temperature, + ) + if model_name.startswith("groq:"): + return load_groq_model( + model_name=model_name, + api_key=api_key, + 
temperature=temperature, + ) + + endpoint = os.getenv("VLLM_BASE_URL", base_url or "") + key = os.getenv("OPENAI_API_KEY", api_key or "dummy_vllm_key") + if not endpoint: + raise ValueError(f"Unsupported model or missing base URL for: {model_name}") + from langchain_openai import ChatOpenAI + + return ChatOpenAI( + **_custom_openai_compatible_kwargs( + model_name=model_name, + temperature=temperature, + base_url=endpoint, + api_key=key, + max_tokens=4000, + top_p=1.0, + frequency_penalty=0.0, + presence_penalty=0.0, + argo_user=argo_user, + ), + ) + + +async def run_turn( + *, + query: str, + tools: list[Any] | None = None, + model_name: str = "gpt-4o-mini", + base_url: str | None = None, + api_key: str | None = None, + argo_user: str | None = None, + system_prompt: str = single_agent_prompt, + formatter_prompt: str = default_formatter_prompt, + structured_output: bool = False, + generate_report: bool = False, + report_prompt: str = default_report_prompt, + recursion_limit: int = 50, + thread_id: str | None = None, + terminal_tool_names: Collection[str] = (), + human_supervised: bool = False, + on_event: EventCallback | None = None, +) -> TurnResult: + """Run one bounded single-agent ChemGraph LangGraph turn.""" + + started = time.time() + thread_id = thread_id or str(uuid.uuid4()) + callbacks = [_TurnEventCallback(on_event, thread_id)] if on_event else [] + event = on_event or (lambda _event, _payload: None) + event( + "workflow_started", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "tool_names": [getattr(tool, "name", str(tool)) for tool in tools or []], + }, + ) + llm = _load_turn_llm( + model_name=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + ) + workflow = construct_single_agent_graph( + llm, + system_prompt, + structured_output, + formatter_prompt, + generate_report, + report_prompt, + tools, + human_supervised=human_supervised, + terminal_tool_names=terminal_tool_names, + ) + config: dict[str, Any] = { + 
"configurable": {"thread_id": thread_id}, + "recursion_limit": recursion_limit, + } + if callbacks: + config["callbacks"] = callbacks + + last_state: Any = None + try: + async for state in workflow.astream( + {"messages": query}, + stream_mode="values", + config=config, + ): + last_state = state + except Exception as exc: + event( + "workflow_finished", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "status": "failed", + "error": repr(exc), + "duration_s": round(time.time() - started, 3), + }, + ) + raise + + if last_state is None: + raise RuntimeError("ChemGraph turn produced no states.") + + messages = _state_messages(last_state) + executed_tools = _executed_tool_names(messages) + terminal_tool = _terminal_tool_name(executed_tools, terminal_tool_names) + result = TurnResult( + final_text=_final_text(messages), + state=serialize_state(last_state), + executed_tool_names=executed_tools, + terminal_tool=terminal_tool, + thread_id=thread_id, + duration_s=round(time.time() - started, 3), + ) + event( + "workflow_finished", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "status": "completed", + "executed_tool_names": list(result.executed_tool_names), + "terminal_tool": terminal_tool, + "duration_s": result.duration_s, + }, + ) + return result + + class ChemGraph: """A graph-based workflow for LLM-powered computational chemistry tasks. @@ -471,6 +819,9 @@ def __init__( self.workflow_type = workflow_type self.model_name = model_name + self.base_url = base_url + self.api_key = api_key + self.argo_user = argo_user self.system_prompt = system_prompt self.formatter_prompt = formatter_prompt self.structured_output = structured_output @@ -954,7 +1305,6 @@ async def run( query: str, config=None, resume_from: Optional[str] = None, - workflow_span_id: Optional[str] = None, ): """ Async-only runner. Requires `self.workflow.astream(...)`. @@ -977,191 +1327,19 @@ async def run( Session ID to load context from. 
The previous conversation summary is prepended to the query. """ + if config is None: + config = {} + if not isinstance(config, dict): + raise TypeError(f"`config` must be a dictionary, got {type(config).__name__}") + if "thread_id" in config: + config.setdefault("configurable", {})["thread_id"] = str(config["thread_id"]) + config.setdefault("configurable", {}).setdefault("thread_id", "1") + config["recursion_limit"] = self.recursion_limit - def _validate_config(cfg): - """Normalize and validate the LangGraph run configuration. - - Parameters - ---------- - cfg : dict or None - User-provided configuration, optionally with top-level - ``thread_id``. - - Returns - ------- - dict - Config with ``configurable.thread_id`` and recursion limit set. - """ - if cfg is None: - cfg = {} - if not isinstance(cfg, dict): - raise TypeError( - f"`config` must be a dictionary, got {type(cfg).__name__}" - ) - - # Support top-level thread_id for convenience - if "thread_id" in cfg: - if "configurable" not in cfg: - cfg["configurable"] = {} - cfg["configurable"]["thread_id"] = str(cfg["thread_id"]) - - cfg.setdefault("configurable", {}).setdefault("thread_id", "1") - cfg["recursion_limit"] = self.recursion_limit - if workflow_span_id: - callbacks = list(cfg.get("callbacks") or []) - callbacks.append( - ChemGraphWorkflowCallback(workflow_span_id=workflow_span_id), - ) - cfg["callbacks"] = callbacks - return cfg - - def _save_state_and_select_return(last_state, cfg): - """Persist the final state and apply the configured return option. - - Parameters - ---------- - last_state : dict - Final streamed graph state. - cfg : dict - LangGraph run configuration used to retrieve/write state. - - Returns - ------- - Any - Final message or serialized state, depending on - ``self.return_option``. 
- """ - log_dir = self.log_dir - if not log_dir: - log_dir = "cg_logs" - - os.makedirs(log_dir, exist_ok=True) - log_path = None - self.write_state(config=cfg, file_path=log_path) - - if self.return_option == "last_message": - return last_state["messages"][-1] - elif self.return_option == "state": - return serialize_state(self.get_state(config=cfg)) - else: - raise ValueError( - f"Unsupported return_option: {self.return_option}. Use 'last_message' or 'state'." - ) - - async def _stream_until_interrupt(stream_input, cfg): - """Stream the workflow until completion or an interrupt. - - Parameters - ---------- - stream_input : dict or Command - Initial graph input or resume command to stream. - cfg : dict - LangGraph run configuration. - - Returns - ------- - tuple - ``(last_state, interrupt_value)`` where ``interrupt_value`` is - ``None`` when the graph completed normally. - - LangGraph's ``astream(stream_mode="values")`` does **not** - raise ``GraphInterrupt``. Instead the stream emits a state - containing an ``__interrupt__`` key and then ends. We - detect this in two ways: - - 1. Check for the ``__interrupt__`` key in streamed states. - 2. After the stream ends, inspect the checkpoint snapshot - for pending interrupt tasks. - """ - prev_msgs: list = [] - last_st = None - interrupt_val = None - try: - async for s in self.workflow.astream( - stream_input, stream_mode="values", config=cfg - ): - # Detect inline interrupt marker emitted by astream. - if "__interrupt__" in s: - int_data = s["__interrupt__"] - if isinstance(int_data, (list, tuple)) and int_data: - interrupt_val = int_data[0].value - elif hasattr(int_data, "value"): - interrupt_val = int_data.value - else: - interrupt_val = { - "question": "The workflow needs your input." 
- } - - if "messages" in s and s["messages"] != prev_msgs: - messages = s["messages"] - if workflow_span_id: - emit_live_message_events( - previous_messages=prev_msgs, - current_messages=messages, - workflow_span_id=workflow_span_id, - ) - new_messages = ( - messages[len(prev_msgs) :] - if len(messages) >= len(prev_msgs) - else messages[-1:] - ) - for new_message in new_messages: - try: - new_message.pretty_print() - except Exception: - pass - logger.info(new_message) - prev_msgs = list(messages) - last_st = s - except GraphInterrupt as gi: - # Fallback: some LangGraph versions may still raise. - interrupts = gi.args[0] if gi.args else [] - if interrupts: - interrupt_val = interrupts[0].value - else: - interrupt_val = { - "question": "The workflow needs your input." - } - - # Double-check the checkpoint for pending interrupts that - # the stream may not have surfaced explicitly. - if interrupt_val is None: - try: - snapshot = self.workflow.get_state(cfg) - if snapshot and snapshot.tasks: - for t in snapshot.tasks: - t_interrupts = getattr(t, "interrupts", None) - if t_interrupts: - interrupt_val = t_interrupts[0].value - break - except Exception: - pass - - if interrupt_val is not None: - logger.info("Graph interrupted: %s", interrupt_val) - # Refresh state from checkpoint for consistency. 
- try: - snapshot = self.workflow.get_state(cfg) - if snapshot: - last_st = snapshot.values - except Exception: - pass - - return last_st, interrupt_val - - logger.debug("run called with config=%s", config) - config = _validate_config(config) - logger.debug("validated config=%s", config) - - # Initialize logging directory before determining inputs or running workflow - # Check if CHEMGRAPH_LOG_DIR is already set if not os.environ.get("CHEMGRAPH_LOG_DIR"): os.environ["CHEMGRAPH_LOG_DIR"] = self.log_dir - # Ensure session exists in memory store self._ensure_session(query) - - # If resuming from a previous session, prepend context if resume_from and self.session_store: context = self.session_store.build_context_summary(resume_from) if context: @@ -1172,61 +1350,63 @@ async def _stream_until_interrupt(stream_input, cfg): ) logger.info(f"Injected context from session {resume_from}") - inputs = {"messages": query} + thread_id = str(config["configurable"]["thread_id"]) + if self.workflow_type == "single_agent": + result = await run_turn( + query=query, + tools=self.tools, + model_name=self.model_name, + base_url=self.base_url, + api_key=self.api_key, + argo_user=self.argo_user, + system_prompt=self.system_prompt, + formatter_prompt=self.formatter_prompt, + structured_output=self.structured_output, + generate_report=self.generate_report, + report_prompt=self.report_prompt, + recursion_limit=self.recursion_limit, + thread_id=thread_id, + terminal_tool_names=self.terminal_tool_names, + human_supervised=self.human_supervised, + ) + self._save_messages_to_store(result.state, query) + if self.return_option == "state": + return result.state + if self.return_option == "last_message": + return result.final_text + raise ValueError( + f"Unsupported return_option: {self.return_option}. " + "Use 'last_message' or 'state'." 
+ ) try: - last_state, interrupt_value = await _stream_until_interrupt(inputs, config) - - # --- Human-in-the-loop resume loop --- - # When the graph pauses with an interrupt, ask the human and - # resume. This loop handles chains of multiple interrupts - # (e.g., the agent asks a follow-up question after receiving - # the first answer). - max_interrupts = 10 # safety guard against infinite interrupt loops - interrupt_count = 0 - while interrupt_value is not None: - interrupt_count += 1 - if interrupt_count > max_interrupts: - logger.error( - "Exceeded maximum number of human interrupts (%d); " - "aborting workflow.", - max_interrupts, - ) - raise RuntimeError( - f"Workflow exceeded maximum of {max_interrupts} " - f"human interrupts." - ) - - # Extract the question text from the interrupt value. - if isinstance(interrupt_value, dict): - question = interrupt_value.get( - "question", - interrupt_value.get("message", str(interrupt_value)), - ) - else: - question = str(interrupt_value) - - logger.info("Requesting human input: %s", question) - human_answer = await self._call_human_input_handler(question) - logger.info("Human responded: %s", human_answer) - - # Resume the graph from the checkpoint with the human's answer. 
- resume_cmd = Command(resume=human_answer) - last_state, interrupt_value = await _stream_until_interrupt( - resume_cmd, config - ) - + last_state = None + async for state in self.workflow.astream( + {"messages": query}, + stream_mode="values", + config=config, + ): + if "messages" in state: + for message in state["messages"][-1:]: + try: + message.pretty_print() + except Exception: + pass + logger.info(message) + last_state = state if last_state is None: - raise RuntimeError("Workflow produced no states.") - - # Save messages to persistent session store + raise RuntimeError("Workflow produced no states") self._save_messages_to_store(last_state, query) - - return _save_state_and_select_return(last_state, config) - - except HumanInputRequired: - # No human_input_handler configured — propagate so the - # caller (CLI / UI) can prompt the user and resume. + self.write_state(config=config, file_path=None) + if self.return_option == "state": + return serialize_state(self.get_state(config=config)) + if self.return_option == "last_message": + return last_state["messages"][-1] + raise ValueError( + f"Unsupported return_option: {self.return_option}. " + "Use 'last_message' or 'state'." 
+ ) + except GraphInterrupt: raise except Exception as e: logger.error(f"Error running workflow {self.workflow_type}: {e}") diff --git a/src/chemgraph/cli/main.py b/src/chemgraph/cli/main.py index 429bd5c7..7e23b346 100644 --- a/src/chemgraph/cli/main.py +++ b/src/chemgraph/cli/main.py @@ -248,16 +248,6 @@ def create_argument_parser() -> argparse.ArgumentParser: help="Arguments forwarded to chemgraph.academy.dashboard.", ) - dashboard_run_parser = subparsers.add_parser( - "dashboard-run", - help="Run a local ChemGraph workflow and write dashboard artifacts.", - ) - dashboard_run_parser.add_argument( - "dashboard_run_args", - nargs=argparse.REMAINDER, - help="Arguments forwarded to chemgraph.observability.local_dashboard_run.", - ) - # ---- "academy" subcommand ------------------------------------------- academy_parser = subparsers.add_parser( "academy", @@ -653,12 +643,6 @@ def main() -> None: _strip_remainder_separator(args.dashboard_args), ) - elif args.command == "dashboard-run": - _run_module_main( - "chemgraph.observability.local_dashboard_run", - _strip_remainder_separator(args.dashboard_run_args), - ) - elif args.command == "academy": _handle_academy(args) diff --git a/src/chemgraph/mcp/fastmcp_client.py b/src/chemgraph/mcp/fastmcp_client.py index 18090e56..6519b313 100644 --- a/src/chemgraph/mcp/fastmcp_client.py +++ b/src/chemgraph/mcp/fastmcp_client.py @@ -279,22 +279,7 @@ async def invoke(self, invocation: ToolInvocation) -> ToolResult: execution=self.execution, run_dir=self.run_dir, ) - from chemgraph.observability.events import WorkflowEventContext - from chemgraph.observability.events import workflow_event_context - - context = WorkflowEventContext( - run_id=self.run_dir.name, - run_dir=str(self.run_dir), - agent_id=invocation.agent_id, - role=invocation.role, - parent_span_id=invocation.correlation_id, - tool_name=invocation.tool_name, - ) - with workflow_event_context( - jsonl_path=self.run_dir / "events.jsonl", - context=context, - ): - result = 
await mcp.call_tool(spec.tool, invocation.arguments) + result = await mcp.call_tool(spec.tool, invocation.arguments) except Exception as exc: # noqa: BLE001 - preserve tool failure as data return ToolResult( tool_name=invocation.tool_name, diff --git a/src/chemgraph/observability/__init__.py b/src/chemgraph/observability/__init__.py deleted file mode 100644 index d6b910e5..00000000 --- a/src/chemgraph/observability/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Shared observability helpers for ChemGraph runtimes.""" - -from chemgraph.observability.events import WorkflowEventContext -from chemgraph.observability.events import WorkflowEventSink -from chemgraph.observability.events import current_workflow_event_context -from chemgraph.observability.events import emit_workflow_event -from chemgraph.observability.events import new_span_id -from chemgraph.observability.events import workflow_event_context - -__all__ = [ - "WorkflowEventContext", - "WorkflowEventSink", - "current_workflow_event_context", - "emit_workflow_event", - "new_span_id", - "workflow_event_context", -] diff --git a/src/chemgraph/observability/events.py b/src/chemgraph/observability/events.py deleted file mode 100644 index b24d601f..00000000 --- a/src/chemgraph/observability/events.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import annotations - -import contextlib -import contextvars -import uuid -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Iterator - -from chemgraph.academy.observability.event_log import EventLog - - -def new_span_id(prefix: str) -> str: - return f"{prefix}-{uuid.uuid4()}" - - -@dataclass(frozen=True) -class WorkflowEventContext: - """Execution context for nested ChemGraph workflow events.""" - - run_id: str | None - run_dir: str | None - agent_id: str | None - role: str | None - parent_span_id: str | None - tool_name: str | None - runtime: str = "chemgraph-langgraph" - - -@dataclass(frozen=True) -class WorkflowEventSink: - """Write 
normalized workflow events to canonical Academy events.""" - - path: Path - context: WorkflowEventContext - - def emit( - self, - event: str, - payload: dict[str, Any] | None = None, - *, - span_id: str | None = None, - parent_span_id: str | None = None, - runtime: str | None = None, - agent_id: str | None = None, - role: str | None = None, - ) -> dict[str, Any]: - ctx = self.context - resolved_agent_id = agent_id or ctx.agent_id - resolved_role = role or ctx.role - body = { - **(payload or {}), - "span_id": span_id, - "parent_span_id": parent_span_id or ctx.parent_span_id, - "runtime": runtime or ctx.runtime, - "run_id": ctx.run_id, - "run_dir": ctx.run_dir, - "agent_id": resolved_agent_id, - "role": resolved_role, - "parent_tool_name": ctx.tool_name, - "nested": True, - } - record = EventLog(self.path).emit( - event, # type: ignore[arg-type] - run_id=ctx.run_id, - agent_id=resolved_agent_id or "system", - role=resolved_role, - correlation_id=span_id, - payload=body, - ) - return record.model_dump(mode="json") - - -_CURRENT_SINK: contextvars.ContextVar[WorkflowEventSink | None] = ( - contextvars.ContextVar("chemgraph_workflow_event_sink", default=None) -) -_CURRENT_CONTEXT: contextvars.ContextVar[WorkflowEventContext | None] = ( - contextvars.ContextVar("chemgraph_workflow_event_context", default=None) -) - - -def current_workflow_event_context() -> WorkflowEventContext | None: - return _CURRENT_CONTEXT.get() - - -def emit_workflow_event( - event: str, - payload: dict[str, Any] | None = None, - *, - span_id: str | None = None, - parent_span_id: str | None = None, - runtime: str | None = None, -) -> dict[str, Any] | None: - sink = _CURRENT_SINK.get() - if sink is None: - return None - return sink.emit( - event, - payload, - span_id=span_id, - parent_span_id=parent_span_id, - runtime=runtime, - ) - - -@contextlib.contextmanager -def workflow_event_context( - *, - jsonl_path: str | Path, - context: WorkflowEventContext, -) -> Iterator[WorkflowEventSink]: - sink = 
WorkflowEventSink(Path(jsonl_path), context=context) - sink_token = _CURRENT_SINK.set(sink) - context_token = _CURRENT_CONTEXT.set(context) - try: - yield sink - finally: - _CURRENT_CONTEXT.reset(context_token) - _CURRENT_SINK.reset(sink_token) diff --git a/src/chemgraph/observability/langgraph_stream.py b/src/chemgraph/observability/langgraph_stream.py deleted file mode 100644 index 1fe81923..00000000 --- a/src/chemgraph/observability/langgraph_stream.py +++ /dev/null @@ -1,346 +0,0 @@ -"""Live LangGraph/LangChain event emission for ChemGraph workflows.""" - -from __future__ import annotations - -import json -import math -from typing import Any -from uuid import UUID - -from langchain_core.callbacks import BaseCallbackHandler - -from chemgraph.observability.events import emit_workflow_event -from chemgraph.observability.events import new_span_id - - -def _compact(value: Any, *, max_chars: int = 1000) -> Any: - try: - text = json.dumps(value, default=str, sort_keys=True) - except TypeError: - text = str(value) - if len(text) <= max_chars: - try: - return json.loads(text) - except json.JSONDecodeError: - return text - return { - "truncated": True, - "preview": text[:max_chars], - } - - -def _message_type(message: Any) -> str: - if isinstance(message, dict): - return str(message.get("type") or message.get("role") or "") - return str(getattr(message, "type", "") or getattr(message, "role", "")) - - -def _message_content(message: Any) -> Any: - if isinstance(message, dict): - return message.get("content") - return getattr(message, "content", None) - - -def _message_tool_calls(message: Any) -> list[dict[str, Any]]: - calls = ( - message.get("tool_calls") - if isinstance(message, dict) - else getattr(message, "tool_calls", None) - ) - if not isinstance(calls, list): - return [] - normalized = [] - for call in calls: - if isinstance(call, dict): - normalized.append( - { - "name": call.get("name"), - "id": call.get("id"), - "args": _compact(call.get("args") or {}, 
max_chars=2000), - }, - ) - else: - normalized.append({"name": str(call), "id": None, "args": {}}) - return normalized - - -def _message_usage_metadata(message: Any) -> dict[str, Any]: - usage = ( - message.get("usage_metadata") - if isinstance(message, dict) - else getattr(message, "usage_metadata", None) - ) - if isinstance(usage, dict) and usage: - return usage - response_metadata = ( - message.get("response_metadata") - if isinstance(message, dict) - else getattr(message, "response_metadata", None) - ) - if not isinstance(response_metadata, dict): - return {} - token_usage = response_metadata.get("token_usage") or response_metadata.get("usage") - return token_usage if isinstance(token_usage, dict) else {} - - -def _usage_int(usage: dict[str, Any], *keys: str) -> int | None: - for key in keys: - value = usage.get(key) - if isinstance(value, int): - return value - if isinstance(value, float) and value.is_integer(): - return int(value) - return None - - -def _text_for_token_estimate(value: Any) -> str: - try: - return json.dumps(value, default=str, sort_keys=True) - except TypeError: - return str(value) - - -def _json_safe(value: Any) -> Any: - try: - return json.loads(json.dumps(value, default=str)) - except TypeError: - return str(value) - - -def _serialize_message(message: Any) -> dict[str, Any]: - if isinstance(message, dict): - return _json_safe(message) - if hasattr(message, "model_dump"): - return _json_safe(message.model_dump(mode="json")) - return { - "type": _message_type(message), - "content": _json_safe(_message_content(message)), - "tool_calls": _message_tool_calls(message), - } - - -def _serialize_messages(messages: list[Any]) -> list[dict[str, Any]]: - return [_serialize_message(message) for message in messages] - - -def _estimate_tokens(text: str) -> int: - try: - import tiktoken # type: ignore[import-not-found] - - encoding = tiktoken.get_encoding("cl100k_base") - return len(encoding.encode(text)) - except Exception: - return max(1, 
math.ceil(len(text) / 4)) - - -def _llm_token_counts( - *, - previous_messages: list[Any], - message: Any, - tool_calls: list[dict[str, Any]], -) -> dict[str, Any]: - usage = _message_usage_metadata(message) - provider_input = _usage_int(usage, "input_tokens", "prompt_tokens") - provider_output = _usage_int(usage, "output_tokens", "completion_tokens") - provider_total = _usage_int(usage, "total_tokens") - if provider_input is not None or provider_output is not None or provider_total is not None: - input_tokens = provider_input - output_tokens = provider_output - if provider_total is None: - provider_total = (input_tokens or 0) + (output_tokens or 0) - return { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": provider_total, - "source": "provider", - "raw_usage": _compact(usage, max_chars=1000), - } - - input_text = _text_for_token_estimate(previous_messages) - output_text = _text_for_token_estimate( - { - "content": _message_content(message), - "tool_calls": tool_calls, - }, - ) - input_tokens = _estimate_tokens(input_text) - output_tokens = _estimate_tokens(output_text) - return { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": input_tokens + output_tokens, - "source": "local_estimate", - "estimate_scope": "langgraph_state_messages", - } - - -def emit_live_message_events( - *, - previous_messages: list[Any], - current_messages: list[Any], - workflow_span_id: str, -) -> int: - """Emit live workflow events for newly streamed LangGraph messages.""" - if len(current_messages) <= len(previous_messages): - return 0 - count = 0 - for index, message in enumerate( - current_messages[len(previous_messages) :], - start=len(previous_messages), - ): - message_type = _message_type(message) - if message_type != "ai": - continue - tool_calls = _message_tool_calls(message) - token_counts = _llm_token_counts( - previous_messages=current_messages[:index], - message=message, - tool_calls=tool_calls, - ) - 
prompt_messages = _serialize_messages(current_messages[:index]) - if tool_calls: - emit_workflow_event( - "llm_decision", - { - "workflow_node": "ChemGraphAgent", - "message_index": index, - "tool_calls": tool_calls, - "token_counts": token_counts, - "prompt_messages": prompt_messages, - }, - span_id=new_span_id("chemgraph-llm-decision"), - parent_span_id=workflow_span_id, - ) - count += 1 - continue - content = _message_content(message) - if content: - emit_workflow_event( - "workflow_output", - { - "workflow_node": "ChemGraphAgent", - "message_index": index, - "content_preview": str(content)[:2000], - "token_counts": token_counts, - "prompt_messages": prompt_messages, - }, - span_id=new_span_id("chemgraph-output"), - parent_span_id=workflow_span_id, - ) - count += 1 - return count - - -def _tool_name(serialized: dict[str, Any] | None, kwargs: dict[str, Any]) -> str: - serialized = serialized or {} - value = ( - serialized.get("name") - or serialized.get("id") - or kwargs.get("name") - or kwargs.get("tool_name") - ) - if isinstance(value, list) and value: - value = value[-1] - return str(value or "tool") - - -def _run_id_text(run_id: UUID | str | None) -> str: - return str(run_id) if run_id is not None else new_span_id("tool-run") - - -class ChemGraphWorkflowCallback(BaseCallbackHandler): - """Emit live tool lifecycle events for a ChemGraph LangGraph run.""" - - def __init__(self, *, workflow_span_id: str) -> None: - self.workflow_span_id = workflow_span_id - self._tool_runs: dict[str, dict[str, Any]] = {} - - def on_tool_start( - self, - serialized: dict[str, Any], - input_str: str, - *, - run_id: UUID, - parent_run_id: UUID | None = None, - inputs: dict[str, Any] | None = None, - **kwargs: Any, - ) -> Any: - tool_run_id = _run_id_text(run_id) - tool_name = _tool_name(serialized, kwargs) - span_id = f"chemgraph-tool-call-{tool_run_id}" - self._tool_runs[tool_run_id] = { - "tool_name": tool_name, - "span_id": span_id, - } - emit_workflow_event( - 
"tool_call_started", - { - "workflow_node": "tools", - "tool_name": tool_name, - "tool_call_id": tool_run_id, - "parent_tool_run_id": _run_id_text(parent_run_id) - if parent_run_id - else None, - "input": _compact(inputs if inputs is not None else input_str), - }, - span_id=span_id, - parent_span_id=self.workflow_span_id, - ) - - def on_tool_end( - self, - output: Any, - *, - run_id: UUID, - parent_run_id: UUID | None = None, - **kwargs: Any, - ) -> Any: - tool_run_id = _run_id_text(run_id) - tool_run = self._tool_runs.get(tool_run_id, {}) - tool_name = str(tool_run.get("tool_name") or _tool_name(None, kwargs)) - span_id = str( - tool_run.get("span_id") or f"chemgraph-tool-call-{tool_run_id}", - ) - emit_workflow_event( - "tool_call_finished", - { - "workflow_node": "tools", - "tool_name": tool_name, - "tool_call_id": tool_run_id, - "parent_tool_run_id": _run_id_text(parent_run_id) - if parent_run_id - else None, - "content_preview": str(_compact(output))[:2000], - }, - span_id=span_id, - parent_span_id=self.workflow_span_id, - ) - - def on_tool_error( - self, - error: BaseException, - *, - run_id: UUID, - parent_run_id: UUID | None = None, - **kwargs: Any, - ) -> Any: - tool_run_id = _run_id_text(run_id) - tool_run = self._tool_runs.get(tool_run_id, {}) - tool_name = str(tool_run.get("tool_name") or _tool_name(None, kwargs)) - span_id = str( - tool_run.get("span_id") or f"chemgraph-tool-call-{tool_run_id}", - ) - emit_workflow_event( - "tool_call_failed", - { - "workflow_node": "tools", - "tool_name": tool_name, - "tool_call_id": tool_run_id, - "parent_tool_run_id": _run_id_text(parent_run_id) - if parent_run_id - else None, - "error": repr(error), - }, - span_id=span_id, - parent_span_id=self.workflow_span_id, - ) diff --git a/src/chemgraph/observability/local_dashboard_run.py b/src/chemgraph/observability/local_dashboard_run.py deleted file mode 100644 index eb0f4ef0..00000000 --- a/src/chemgraph/observability/local_dashboard_run.py +++ /dev/null @@ -1,170 +0,0 
@@ -"""Run a traditional ChemGraph workflow and write dashboard artifacts.""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import shutil -import threading -import traceback -from pathlib import Path - -from chemgraph.academy.core.lm import load_lm_config -from chemgraph.observability.workflow_runner import run_observed_chemgraph_workflow - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description=( - "Run a local traditional ChemGraph workflow and emit event artifacts " - "that can be visualized by the ChemGraph dashboard." - ), - ) - parser.add_argument("--run-dir", required=True) - parser.add_argument("--query", required=True) - parser.add_argument("--workflow-type", default="single_agent") - parser.add_argument("--return-option", choices=["last_message", "state"], default="state") - parser.add_argument("--recursion-limit", type=int, default=50) - parser.add_argument("--lm-config") - parser.add_argument("--model-name") - parser.add_argument("--base-url") - parser.add_argument("--api-key") - parser.add_argument("--argo-user") - parser.add_argument("--serve", action="store_true") - parser.add_argument("--host", default="127.0.0.1") - parser.add_argument("--port", type=int, default=8765) - parser.add_argument( - "--overwrite", - action="store_true", - help="Replace an existing local dashboard run directory.", - ) - parser.add_argument( - "--json-output", - action="store_true", - help="Print the full workflow result JSON to stdout.", - ) - return parser.parse_args() - - -def _prepare_run_dir(path: Path, *, overwrite: bool) -> None: - existing_artifacts = [ - path / "events.jsonl", - path / "status.json", - path / "manifest.json", - path / "result.json", - path / "chemgraph_workflows", - ] - if not path.exists(): - path.mkdir(parents=True, exist_ok=True) - return - if overwrite: - _clear_run_dir(path) - elif any(item.exists() for item in existing_artifacts): - raise RuntimeError( - f"Run 
directory already contains dashboard artifacts: {path}\n" - "Use a new --run-dir, run chemgraph-dashboard to view the " - "existing run, or pass --overwrite to replace it.", - ) - path.mkdir(parents=True, exist_ok=True) - - -def _clear_run_dir(path: Path) -> None: - for item in path.iterdir(): - if item.is_dir() and not item.is_symlink(): - shutil.rmtree(item) - else: - item.unlink() - - -async def _run(args: argparse.Namespace) -> dict: - model_name = args.model_name - base_url = args.base_url - api_key = args.api_key - argo_user = args.argo_user - if args.lm_config: - settings = load_lm_config(args.lm_config) - model_name = model_name or settings.model - base_url = base_url or settings.base_url - api_key = api_key or settings.api_key - argo_user = argo_user or settings.user - - return await run_observed_chemgraph_workflow( - query=args.query, - run_dir=Path(args.run_dir), - workflow_type=args.workflow_type, - model_name=model_name, - base_url=base_url, - api_key=api_key, - argo_user=argo_user, - return_option=args.return_option, - recursion_limit=args.recursion_limit, - write_run_files=True, - ) - - -def _print_result_summary(*, result: dict, run_dir: Path, json_output: bool) -> None: - result_path = run_dir / "result.json" - print( - "ChemGraph workflow completed.\n" - f" status: {result.get('status')}\n" - f" workflow: {result.get('workflow_type')}\n" - f" span: {result.get('span_id')}\n" - f" result: {result_path}", - flush=True, - ) - if json_output: - print(json.dumps(result, indent=2, default=str), flush=True) - - -def _run_and_report(args: argparse.Namespace, *, run_dir: Path) -> None: - try: - result = asyncio.run(_run(args)) - except Exception: # noqa: BLE001 - surface background workflow failures - print("ChemGraph workflow failed. 
See status.json/events.jsonl if present.", flush=True) - traceback.print_exc() - return - _print_result_summary( - result=result, - run_dir=run_dir, - json_output=args.json_output, - ) - - -def main() -> int: - args = parse_args() - run_dir = Path(args.run_dir).resolve() - _prepare_run_dir(run_dir, overwrite=args.overwrite) - args.run_dir = str(run_dir) - if args.serve: - from chemgraph.academy.dashboard import serve_dashboard - - thread = threading.Thread( - target=_run_and_report, - kwargs={"args": args, "run_dir": run_dir}, - name="chemgraph-dashboard-workflow", - daemon=True, - ) - thread.start() - return serve_dashboard( - run_dir=run_dir, - host=args.host, - port=args.port, - ) - - result = asyncio.run(_run(args)) - _print_result_summary( - result=result, - run_dir=run_dir, - json_output=args.json_output, - ) - print( - "Dashboard command: " - f"chemgraph-dashboard --run-dir {run_dir}", - flush=True, - ) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/chemgraph/observability/workflow_runner.py b/src/chemgraph/observability/workflow_runner.py deleted file mode 100644 index 198a5aca..00000000 --- a/src/chemgraph/observability/workflow_runner.py +++ /dev/null @@ -1,397 +0,0 @@ -"""Observed execution helpers for traditional ChemGraph workflows.""" - -from __future__ import annotations - -import json -import os -import time -from pathlib import Path -from typing import Any, Literal - -from chemgraph.agent.llm_agent import ChemGraph -from chemgraph.agent.llm_agent import serialize_state -from chemgraph.observability.events import WorkflowEventContext -from chemgraph.observability.events import current_workflow_event_context -from chemgraph.observability.events import emit_workflow_event -from chemgraph.observability.events import new_span_id -from chemgraph.observability.events import workflow_event_context - - -def _env_first(*names: str) -> str | None: - for name in names: - value = os.environ.get(name) - if value: - return 
value - return None - - -def normalize_model_name(model_name: str, base_url: str | None) -> str: - value = model_name.strip() - if base_url and "argoapi" in base_url and value.startswith("GPT-"): - return "argo:" + value.lower() - return value - - -def compact_value(value: Any, *, max_chars: int = 8000) -> Any: - try: - text = json.dumps(value, default=str, sort_keys=True) - except TypeError: - text = str(value) - if len(text) <= max_chars: - try: - return json.loads(text) - except json.JSONDecodeError: - return text - return { - "truncated": True, - "preview": text[:max_chars], - } - - -def _write_json(path: Path, payload: dict[str, Any]) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, indent=2, default=str) + "\n", encoding="utf-8") - - -def _write_status( - *, - run_dir: Path, - run_id: str, - workflow_span_id: str, - query: str, - workflow_type: str, - model_name: str, - base_url: str | None, - status: str, - started_at: float, - error: str | None = None, -) -> None: - now = time.time() - payload = { - "mode": "chemgraph_workflow", - "run_id": run_id, - "workflow_span_id": workflow_span_id, - "query": query, - "workflow_type": workflow_type, - "model_name": model_name, - "base_url": base_url, - "status": status, - "started": started_at, - "updated": now, - "finished": now if status in {"completed", "failed"} else None, - "events_path": str(run_dir / "events.jsonl"), - } - if error: - payload["error"] = error - _write_json(run_dir / "status.json", payload) - - -def _write_manifest( - *, - run_dir: Path, - run_id: str, - workflow_span_id: str, - query: str, - workflow_type: str, - model_name: str, - base_url: str | None, -) -> None: - _write_json( - run_dir / "manifest.json", - { - "mode": "chemgraph_workflow", - "run_id": run_id, - "workflow_span_id": workflow_span_id, - "query": query, - "workflow_type": workflow_type, - "model_name": model_name, - "base_url": base_url, - "events_path": str(run_dir / 
"events.jsonl"), - }, - ) - - -def _workflow_log_dir(run_dir: Path, workflow_span_id: str) -> str: - path = run_dir / "chemgraph_workflows" / workflow_span_id - path.mkdir(parents=True, exist_ok=True) - return str(path) - - -async def run_observed_chemgraph_workflow( - *, - query: str, - run_dir: str | Path | None = None, - run_id: str | None = None, - workflow_type: str = "single_agent", - model_name: str | None = None, - base_url: str | None = None, - api_key: str | None = None, - argo_user: str | None = None, - return_option: Literal["last_message", "state"] = "state", - recursion_limit: int = 50, - parent_span_id: str | None = None, - agent_id: str = "chemgraph-workflow", - role: str = "TraditionalChemGraphWorkflow", - write_run_files: bool = True, -) -> dict[str, Any]: - """Run a traditional ChemGraph workflow while emitting dashboard events.""" - current_context = current_workflow_event_context() - run_dir_value = run_dir - if run_dir_value is None and current_context and current_context.run_dir: - run_dir_value = current_context.run_dir - if run_dir_value is None: - run_dir_value = "runs/local-chemgraph-workflow" - effective_run_dir = Path(run_dir_value).resolve() - effective_run_dir.mkdir(parents=True, exist_ok=True) - - workflow_span_id = new_span_id("chemgraph-workflow") - effective_run_id = run_id or effective_run_dir.name - base_url = base_url or _env_first( - "CHEMGRAPH_WORKFLOW_BASE_URL", - "ACADEMY_LM_BASE_URL", - ) - model_name = normalize_model_name( - model_name - or _env_first("CHEMGRAPH_WORKFLOW_MODEL", "ACADEMY_LM_MODEL") - or "argo:gpt-5.4", - base_url, - ) - api_key = api_key or _env_first( - "CHEMGRAPH_WORKFLOW_API_KEY", - "ACADEMY_LM_API_KEY", - "OPENAI_API_KEY", - ) - argo_user = argo_user or _env_first( - "CHEMGRAPH_WORKFLOW_ARGO_USER", - "ACADEMY_LM_USER", - "ARGO_USER", - ) - - started_at = time.time() - if write_run_files: - _write_manifest( - run_dir=effective_run_dir, - run_id=effective_run_id, - workflow_span_id=workflow_span_id, - 
query=query, - workflow_type=workflow_type, - model_name=model_name, - base_url=base_url, - ) - _write_status( - run_dir=effective_run_dir, - run_id=effective_run_id, - workflow_span_id=workflow_span_id, - query=query, - workflow_type=workflow_type, - model_name=model_name, - base_url=base_url, - status="running", - started_at=started_at, - ) - - context_manager = ( - workflow_event_context( - jsonl_path=effective_run_dir / "events.jsonl", - context=WorkflowEventContext( - run_id=effective_run_id, - run_dir=str(effective_run_dir), - agent_id=agent_id, - role=role, - parent_span_id=parent_span_id, - tool_name=None, - ), - ) - if current_context is None - else None - ) - - if context_manager is None: - return await _run_observed_chemgraph_workflow_inner( - query=query, - run_dir=effective_run_dir, - run_id=effective_run_id, - workflow_span_id=workflow_span_id, - workflow_type=workflow_type, - model_name=model_name, - base_url=base_url, - api_key=api_key, - argo_user=argo_user, - return_option=return_option, - recursion_limit=recursion_limit, - write_run_files=write_run_files, - started_at=started_at, - ) - with context_manager: - return await _run_observed_chemgraph_workflow_inner( - query=query, - run_dir=effective_run_dir, - run_id=effective_run_id, - workflow_span_id=workflow_span_id, - workflow_type=workflow_type, - model_name=model_name, - base_url=base_url, - api_key=api_key, - argo_user=argo_user, - return_option=return_option, - recursion_limit=recursion_limit, - write_run_files=write_run_files, - started_at=started_at, - ) - - -async def _run_observed_chemgraph_workflow_inner( - *, - query: str, - run_dir: Path, - run_id: str, - workflow_span_id: str, - workflow_type: str, - model_name: str, - base_url: str | None, - api_key: str | None, - argo_user: str | None, - return_option: Literal["last_message", "state"], - recursion_limit: int, - write_run_files: bool, - started_at: float, -) -> dict[str, Any]: - log_dir = _workflow_log_dir(run_dir, workflow_span_id) 
- config = {"configurable": {"thread_id": workflow_span_id}} - - emit_workflow_event( - "run_started", - { - "workflow_type": workflow_type, - "model_name": model_name, - "query": query, - }, - span_id=workflow_span_id, - ) - emit_workflow_event( - "workflow_started", - { - "workflow_type": workflow_type, - "model_name": model_name, - "query": query, - "log_dir": log_dir, - }, - span_id=workflow_span_id, - ) - try: - emit_workflow_event( - "workflow_node_started", - {"workflow_node": "ChemGraph", "phase": "construct"}, - span_id=new_span_id("chemgraph-node"), - parent_span_id=workflow_span_id, - ) - agent = ChemGraph( - model_name=model_name, - workflow_type=workflow_type, - base_url=base_url, - api_key=api_key, - argo_user=argo_user, - return_option=return_option, - recursion_limit=recursion_limit, - log_dir=log_dir, - ) - emit_workflow_event( - "workflow_node_finished", - {"workflow_node": "ChemGraph", "phase": "construct"}, - span_id=new_span_id("chemgraph-node"), - parent_span_id=workflow_span_id, - ) - emit_workflow_event( - "workflow_node_started", - {"workflow_node": "LangGraph", "phase": "run"}, - span_id=new_span_id("chemgraph-node"), - parent_span_id=workflow_span_id, - ) - result = await agent.run( - query, - config=config, - workflow_span_id=workflow_span_id, - ) - state_payload = serialize_state(agent.get_state(config=config)) - emit_workflow_event( - "workflow_node_finished", - {"workflow_node": "LangGraph", "phase": "run"}, - span_id=new_span_id("chemgraph-node"), - parent_span_id=workflow_span_id, - ) - emit_workflow_event( - "workflow_finished", - { - "workflow_type": workflow_type, - "status": "completed", - "log_dir": log_dir, - }, - span_id=workflow_span_id, - ) - emit_workflow_event( - "run_finished", - { - "workflow_type": workflow_type, - "status": "completed", - }, - span_id=workflow_span_id, - ) - payload = { - "status": "completed", - "workflow_type": workflow_type, - "model_name": model_name, - "span_id": workflow_span_id, - "log_dir": 
log_dir, - "return_option": return_option, - "result": compact_value(serialize_state(result)), - "state": compact_value(state_payload), - } - if write_run_files: - _write_status( - run_dir=run_dir, - run_id=run_id, - workflow_span_id=workflow_span_id, - query=query, - workflow_type=workflow_type, - model_name=model_name, - base_url=base_url, - status="completed", - started_at=started_at, - ) - _write_json(run_dir / "result.json", payload) - return payload - except Exception as exc: - error = repr(exc) - emit_workflow_event( - "workflow_finished", - { - "workflow_type": workflow_type, - "status": "failed", - "error": error, - "log_dir": log_dir, - }, - span_id=workflow_span_id, - ) - emit_workflow_event( - "run_finished", - { - "workflow_type": workflow_type, - "status": "failed", - "error": error, - }, - span_id=workflow_span_id, - ) - if write_run_files: - _write_status( - run_dir=run_dir, - run_id=run_id, - workflow_span_id=workflow_span_id, - query=query, - workflow_type=workflow_type, - model_name=model_name, - base_url=base_url, - status="failed", - started_at=started_at, - error=error, - ) - raise diff --git a/tests/test_academy_payloads.py b/tests/test_academy_payloads.py new file mode 100644 index 00000000..8e9f82c8 --- /dev/null +++ b/tests/test_academy_payloads.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from chemgraph.academy.observability.event_log import EventLog, read_events + + +def test_event_log_preserves_payload_shape(tmp_path) -> None: + log = EventLog(tmp_path / "events.jsonl") + + log.emit( + "message_sent", + run_id="run-1", + agent_id="agent-a", + role="Worker", + payload={ + "message_id": "msg-1", + "recipient": "agent-b", + "tldr": "short", + }, + ) + + event = read_events(tmp_path / "events.jsonl")[0] + assert event.event == "message_sent" + assert event.payload == { + "message_id": "msg-1", + "recipient": "agent-b", + "tldr": "short", + } diff --git a/tests/test_academy_reasoning_phase2.py 
b/tests/test_academy_reasoning_phase2.py index 0825f66e..75d3fc31 100644 --- a/tests/test_academy_reasoning_phase2.py +++ b/tests/test_academy_reasoning_phase2.py @@ -4,24 +4,20 @@ import dataclasses import json from pathlib import Path +from typing import Any import pytest +from chemgraph.academy.core import agent as agent_module +from chemgraph.academy.core import turn as turn_module from chemgraph.academy.core.agent import ChemGraphLogicalAgent -from chemgraph.academy.core.turn import ( - build_peer_status, - ChemGraphReasoningRoundEngine, -) -from chemgraph.academy.core.turn import ReasoningTurnResult -from chemgraph.academy.core.tools import ReasoningToolRuntimeState -from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools -from chemgraph.academy.core.campaign import ChemGraphAgentSpec -from chemgraph.academy.core.campaign import ChemGraphCampaign -from chemgraph.academy.core.campaign import ResourceSpec -from chemgraph.academy.core.campaign import resolve_campaign_resources +from chemgraph.academy.core.campaign import ChemGraphAgentSpec, ChemGraphCampaign +from chemgraph.academy.core.campaign import ResourceSpec, resolve_campaign_resources from chemgraph.academy.core.lm import LLMSettings -from chemgraph.academy.core.prompt import PromptProfile -from chemgraph.academy.core.prompt import PromptStateLimits +from chemgraph.academy.core.prompt import PromptProfile, PromptStateLimits +from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools +from chemgraph.academy.core.turn import ReasoningTurnResult, build_peer_status +from chemgraph.agent.llm_agent import TurnResult def _agent_spec() -> ChemGraphAgentSpec: @@ -35,13 +31,7 @@ def _agent_spec() -> ChemGraphAgentSpec: def _agent_spec_with_peer() -> ChemGraphAgentSpec: - return ChemGraphAgentSpec( - name="agent-a", - role="Worker", - mission="Use explicit tools only.", - allowed_peers=("agent-b",), - tools=(), - ) + return dataclasses.replace(_agent_spec(), 
allowed_peers=("agent-b",)) def _campaign(spec: ChemGraphAgentSpec) -> ChemGraphCampaign: @@ -84,22 +74,6 @@ def _lm_settings() -> LLMSettings: ) -class _FakeReasoningEngine: - async def run_turn(self) -> ReasoningTurnResult: - return ReasoningTurnResult( - final_text="done", - state={"messages": []}, - tool_calls_completed=1, - action_tools_called=("finish_turn",), - science_tools_called=("science_tool",), - executed_tool_names=("science_tool", "finish_turn"), - requested_finish=True, - requested_self_wake=True, - workflow_span_id="workflow-1", - thread_id="agent-a-round-1", - ) - - class _SlowPeerHandle: def __init__(self) -> None: self.delivered = asyncio.Event() @@ -112,15 +86,12 @@ async def action(self, name: str, message: dict) -> None: @pytest.mark.asyncio -async def test_reasoning_adapter_finish_turn_updates_runtime_state(tmp_path) -> None: - spec = _agent_spec() - runtime_state = ReasoningToolRuntimeState() +async def test_reasoning_adapter_finish_turn_traces(tmp_path) -> None: traces: list[tuple[str, dict]] = [] - tools = await build_chemgraph_reasoning_tools( - spec=spec, + spec=_agent_spec(), run_dir=tmp_path, - tool_invoker=object(), # unused when spec.tools is empty + tool_invoker=object(), peer_names=(), peer_handles={}, outbox=[], @@ -128,22 +99,13 @@ async def test_reasoning_adapter_finish_turn_updates_runtime_state(tmp_path) -> get_round_index=lambda: 1, set_final_result=lambda result: None, trace=lambda event, payload: traces.append((event, payload)), - runtime_state=runtime_state, ) - assert [tool.name for tool in tools] == [ - "send_message", - "submit_result", - "finish_turn", - ] - - finish_turn = next(tool for tool in tools if tool.name == "finish_turn") - result = await finish_turn.ainvoke({"reason": "nothing useful now"}) + result = await next(t for t in tools if t.name == "finish_turn").ainvoke( + {"reason": "nothing useful now"}, + ) assert result == {"status": "finished", "reason": "nothing useful now"} - assert 
runtime_state.finished_turn is True - assert runtime_state.action_tool_names == ["finish_turn"] - assert runtime_state.executed_tool_names == ["finish_turn"] assert traces == [ ( "turn_finished_without_external_action", @@ -154,14 +116,11 @@ async def test_reasoning_adapter_finish_turn_updates_runtime_state(tmp_path) -> @pytest.mark.asyncio async def test_send_message_does_not_block_on_busy_peer(tmp_path) -> None: - spec = _agent_spec_with_peer() - runtime_state = ReasoningToolRuntimeState() peer = _SlowPeerHandle() traces: list[tuple[str, dict]] = [] outbox: list[dict] = [] - tools = await build_chemgraph_reasoning_tools( - spec=spec, + spec=_agent_spec_with_peer(), run_dir=tmp_path, tool_invoker=object(), peer_names=("agent-b",), @@ -171,12 +130,10 @@ async def test_send_message_does_not_block_on_busy_peer(tmp_path) -> None: get_round_index=lambda: 1, set_final_result=lambda result: None, trace=lambda event, payload: traces.append((event, payload)), - runtime_state=runtime_state, ) - send_message = next(tool for tool in tools if tool.name == "send_message") result = await asyncio.wait_for( - send_message.ainvoke( + next(t for t in tools if t.name == "send_message").ainvoke( { "recipient": "agent-b", "tldr": "short summary", @@ -191,43 +148,60 @@ async def test_send_message_does_not_block_on_busy_peer(tmp_path) -> None: timeout=0.05, ) - assert result["status"] == "sent" assert result["delivery"] == "queued" assert len(outbox) == 1 assert [name for name, _ in traces] == ["message_sent"] - await asyncio.wait_for(peer.delivered.wait(), timeout=1) - await asyncio.sleep(0) - assert peer.calls[0][0] == "receive_message" - assert [name for name, _ in traces] == [ - "message_sent", - "message_delivered", - ] @pytest.mark.asyncio -async def test_logical_agent_startup_initializes_chemgraph_reasoning_engine( - tmp_path, -) -> None: - spec = _agent_spec() - agent = ChemGraphLogicalAgent( - spec, - campaign=_campaign(spec), +async def 
test_run_academy_turn_maps_action_and_science_tools(monkeypatch, tmp_path) -> None: + async def fake_run_turn(**kwargs: Any) -> TurnResult: + payload = json.loads(kwargs["query"]) + assert payload["received_messages"] == [{"message_id": "new"}] + assert payload["local_chemgraph_tool_results"] == [{"tool_result_id": "new"}] + kwargs["on_event"]("workflow_started", {"thread_id": kwargs["thread_id"]}) + return TurnResult( + final_text="done", + state={"messages": []}, + executed_tool_names=("science_tool", "finish_turn"), + terminal_tool="finish_turn", + thread_id=kwargs["thread_id"], + duration_s=0.1, + ) + + monkeypatch.setattr(turn_module, "run_turn", fake_run_turn) + traces: list[tuple[str, dict]] = [] + result = await turn_module.run_academy_turn( + campaign=_campaign(_agent_spec()), + spec=_agent_spec(), llm_settings=_lm_settings(), prompt_profile=_prompt_profile(), run_dir=tmp_path, max_decisions=5, - tool_invoker=object(), # unused when spec.tools is empty + tools=[], + received_message_history=[{"message_id": "old"}, {"message_id": "new"}], + outbox=[], + tool_results=[{"tool_result_id": "old"}, {"tool_result_id": "new"}], + get_final_result=lambda: {"summary": "current"}, + get_round_index=lambda: 2, + trace=lambda event, payload: traces.append((event, payload)), ) - await agent.agent_on_startup() - - assert isinstance(agent._reasoning_engine, ChemGraphReasoningRoundEngine) + assert result.action_tools_called == ("finish_turn",) + assert result.science_tools_called == ("science_tool",) + assert result.requested_finish is True + assert result.requested_self_wake is True + assert [event for event, _ in traces] == [ + "chemgraph_reasoning_turn_started", + "workflow_started", + "chemgraph_reasoning_turn_finished", + ] @pytest.mark.asyncio -async def test_logical_agent_reasoning_round_uses_chemgraph_engine(tmp_path) -> None: +async def test_logical_agent_reasoning_round_calls_turn_runner(monkeypatch, tmp_path) -> None: spec = _agent_spec() agent = 
ChemGraphLogicalAgent( spec, @@ -239,11 +213,27 @@ async def test_logical_agent_reasoning_round_uses_chemgraph_engine(tmp_path) -> tool_invoker=object(), ) agent.round_index = 1 - agent._reasoning_engine = _FakeReasoningEngine() - self_wake = await agent._reasoning_round() + async def fake_tools(**kwargs: Any) -> list: + assert kwargs["spec"] is spec + return [] - assert self_wake is True + async def fake_turn(**kwargs: Any) -> ReasoningTurnResult: + assert kwargs["spec"] is spec + return ReasoningTurnResult( + final_text="done", + executed_tool_names=("science_tool", "finish_turn"), + action_tools_called=("finish_turn",), + science_tools_called=("science_tool",), + requested_finish=True, + requested_self_wake=True, + thread_id="agent-a-round-1", + ) + + monkeypatch.setattr(agent_module, "build_chemgraph_reasoning_tools", fake_tools) + monkeypatch.setattr(agent_module, "run_academy_turn", fake_turn) + + assert await agent._reasoning_round() is True events = [ json.loads(line)["event"] for line in tmp_path.joinpath("events.jsonl").read_text().splitlines() @@ -256,135 +246,28 @@ async def test_logical_agent_reasoning_round_uses_chemgraph_engine(tmp_path) -> ] -def test_reasoning_engine_builds_bounded_wakeup_state(tmp_path) -> None: - spec = _agent_spec() - received_message_history = [{"message_id": "old"}, {"message_id": "new"}] - outbox = [ - { - "message_id": "msg-old", - "recipient": "agent-b", - "tldr": "old message", - "timestamp": 1, - }, - { - "message_id": "msg-new", - "recipient": "agent-b", - "tldr": "new message", - "timestamp": 3, - }, - ] - tool_results = [{"tool_result_id": "old"}, {"tool_result_id": "new"}] - final_result = {"summary": "current belief"} - engine = ChemGraphReasoningRoundEngine( - campaign=_campaign(spec), - spec=spec, - llm_settings=_lm_settings(), - prompt_profile=_prompt_profile(), - run_dir=tmp_path, - max_decisions=5, - tools=[], - runtime_state=ReasoningToolRuntimeState(), - received_message_history=received_message_history, - 
outbox=outbox, - tool_results=tool_results, - get_final_result=lambda: final_result, - get_round_index=lambda: 2, - trace=lambda event, payload: None, - ) - - state = engine.build_wakeup_state(round_index=2) - - assert state["campaign"] == "campaign-1" - assert state["user_task"] == "Rank staged candidates." - assert state["agent_name"] == "agent-a" - assert state["available_chemgraph_tools"] == [] - assert state["peer_status"] == {} - assert state["received_messages"] == [{"message_id": "new"}] - assert state["local_chemgraph_tool_results"] == [{"tool_result_id": "new"}] - assert state["recent_actions"] == [ - { - "type": "send_message", - "recipient": "agent-b", - "reply_requested": False, - "tldr": "old message", - "message_id": "msg-old", - "timestamp": 1, - }, - { - "type": "send_message", - "recipient": "agent-b", - "reply_requested": False, - "tldr": "new message", - "message_id": "msg-new", - "timestamp": 3, - }, - ] - assert state["current_final_result"] == final_result - assert state["required_protocol"] == "call finish_turn when idle" - - -def test_build_peer_status_uses_inflight_tool_events(tmp_path) -> None: +def test_build_peer_status_uses_agent_status_file(tmp_path) -> None: state_dir = tmp_path / "agent_status" state_dir.mkdir() (state_dir / "agent-b.json").write_text( json.dumps( { - "agent_name": "agent-b", "round": 3, "finished": False, "last_error": None, "status_updated_at": 100.0, - "recent_outbox": [ - { - "message_id": "msg-ack", - "tldr": "Starting requested MACE energy run", - }, - ], - "belief": { - "hypothesis": None, - "confidence": 0.0, - }, + "recent_outbox": [{"message_id": "msg-ack", "tldr": "MACE running"}], }, ) + "\n", encoding="utf-8", ) - events = [ - { - "timestamp": 101.0, - "event": "message_sent", - "agent_id": "agent-b", - "payload": { - "message_id": "msg-ack", - "tldr": "Starting requested MACE energy run", - }, - }, - { - "timestamp": 102.0, - "event": "tool_call_started", - "agent_id": "agent-b", - "payload": { - 
"tool_name": "run_mace_ensemble", - "tool_result_id": "tool-1", - "tool_call_id": "call-1", - }, - }, - ] - with (tmp_path / "events.jsonl").open("w", encoding="utf-8") as fp: - for event in events: - fp.write(json.dumps(event) + "\n") status = build_peer_status(run_dir=tmp_path, peer_names=("agent-b",)) - assert status["agent-b"]["state"] == "busy" - assert status["agent-b"]["last_outbox_tldr"] == "Starting requested MACE energy run" - assert status["agent-b"]["current_activity"] == { - "type": "tool_call", - "tool_name": "run_mace_ensemble", - "tool_result_id": "tool-1", - "tool_call_id": "call-1", - "started_at": 102.0, - } + assert status["agent-b"]["state"] == "idle" + assert status["agent-b"]["round"] == 3 + assert status["agent-b"]["last_outbox_tldr"] == "MACE running" def test_campaign_resources_resolve_to_shared_run_artifacts(tmp_path) -> None: @@ -420,12 +303,7 @@ def test_campaign_resources_resolve_to_shared_run_artifacts(tmp_path) -> None: resolved = resolve_campaign_resources(campaign, tmp_path / "run-1") - assert campaign.resources["structure_output_directory"].path == ( - "academy_mace_structures" - ) - assert resolved.resources["candidate_dataset"].path == ( - "/source/data/candidates.json" - ) + assert resolved.resources["candidate_dataset"].path == "/source/data/candidates.json" assert resolved.resources["structure_output_directory"].path == str( tmp_path / "run-1" / "shared" / "academy_mace_structures", ) diff --git a/tests/test_tool_adapter_validation.py b/tests/test_tool_adapter_validation.py index f58a2b17..d255ca90 100644 --- a/tests/test_tool_adapter_validation.py +++ b/tests/test_tool_adapter_validation.py @@ -5,7 +5,6 @@ import pytest -from chemgraph.academy.core.tools import ReasoningToolRuntimeState from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools from chemgraph.academy.core.campaign import ChemGraphAgentSpec @@ -29,7 +28,6 @@ def _agent_spec() -> ChemGraphAgentSpec: async def _build_tools(tmp_path): - 
runtime_state = ReasoningToolRuntimeState() traces: list[tuple[str, dict[str, Any]]] = [] outbox: list[dict[str, Any]] = [] peer_handle = _FakePeerHandle() @@ -44,11 +42,9 @@ async def _build_tools(tmp_path): get_round_index=lambda: 1, set_final_result=lambda result: None, trace=lambda event, payload: traces.append((event, payload)), - runtime_state=runtime_state, ) return { "tools": {tool.name: tool for tool in tools}, - "runtime_state": runtime_state, "traces": traces, "outbox": outbox, "peer_handle": peer_handle, @@ -74,7 +70,6 @@ async def test_send_message_invalid_args_return_structured_tool_error(tmp_path) assert result["status"] == "error" assert result["error_type"] == "invalid_tool_arguments" assert result["errors"][0]["field"] == "confidence" - assert env["runtime_state"].action_tool_names == ["send_message"] assert env["outbox"] == [] assert env["peer_handle"].calls == [] assert env["traces"] == [ From 77425b2bb7fd4d30c79ea72484c0ee45505d22c7 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 11:38:09 -0500 Subject: [PATCH 060/119] refactor(academy): trim agent status snapshots --- src/chemgraph/academy/core/agent.py | 9 --------- src/chemgraph/academy/core/turn.py | 12 +----------- src/chemgraph/academy/observability/run_artifacts.py | 9 --------- tests/test_academy_reasoning_phase2.py | 3 +-- 4 files changed, 2 insertions(+), 31 deletions(-) diff --git a/src/chemgraph/academy/core/agent.py b/src/chemgraph/academy/core/agent.py index 5463b41d..94e03afd 100644 --- a/src/chemgraph/academy/core/agent.py +++ b/src/chemgraph/academy/core/agent.py @@ -188,15 +188,6 @@ async def report_state(self) -> dict[str, Any]: 'round': self.round_index, 'finished': self.finished, 'last_error': self.last_error, - 'current_activity': None, - 'recent_outbox': self.outbox[-10:], - 'belief': self.final_result or { - 'hypothesis': None, - 'confidence': 0.0, - 'supporting_message_ids': [], - 'supporting_tool_result_ids': [], - 'reason': None, - }, } async def 
_reasoning_round(self) -> bool: diff --git a/src/chemgraph/academy/core/turn.py b/src/chemgraph/academy/core/turn.py index be23084d..24d4baa2 100644 --- a/src/chemgraph/academy/core/turn.py +++ b/src/chemgraph/academy/core/turn.py @@ -117,17 +117,7 @@ def _status(run_dir: Path, peer: str, *, now: float) -> dict[str, Any]: data = read_json_file(run_dir / "agent_status" / f"{peer}.json", default={}) timestamp = _float(data.get("status_updated_at")) state = "unknown" if not data else "error" if data.get("last_error") else "finished" if data.get("finished") else "idle" - return {"state": state, "round": data.get("round"), "finished": bool(data.get("finished")) if data else False, "last_error": data.get("last_error"), "current_activity": data.get("current_activity"), "seconds_since_update": None if timestamp is None else max(0.0, round(now - timestamp, 3)), "last_outbox_tldr": _last_outbox(data), "last_belief": _belief(data.get("belief"))} - - -def _last_outbox(data: dict[str, Any]) -> str | None: - recent = data.get("recent_outbox") - return (recent[-1].get("tldr") or _preview(recent[-1].get("content"))) if isinstance(recent, list) and recent and isinstance(recent[-1], dict) else None - - -def _belief(value: Any) -> dict[str, Any] | None: - summary = value.get("summary") or value.get("hypothesis") if isinstance(value, dict) else None - return {"summary": _preview(summary, max_chars=220), "confidence": value.get("confidence")} if summary else None + return {"state": state, "round": data.get("round"), "finished": bool(data.get("finished")) if data else False, "last_error": data.get("last_error"), "seconds_since_update": None if timestamp is None else max(0.0, round(now - timestamp, 3))} def _tail(items: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]: diff --git a/src/chemgraph/academy/observability/run_artifacts.py b/src/chemgraph/academy/observability/run_artifacts.py index 750e8b13..f109380b 100644 --- a/src/chemgraph/academy/observability/run_artifacts.py 
+++ b/src/chemgraph/academy/observability/run_artifacts.py @@ -175,15 +175,6 @@ def default_agent_state(spec: ChemGraphAgentSpec) -> dict[str, Any]: 'round': 0, 'finished': False, 'last_error': None, - 'current_activity': None, - 'recent_outbox': [], - 'belief': { - 'hypothesis': None, - 'confidence': 0.0, - 'supporting_message_ids': [], - 'supporting_tool_result_ids': [], - 'reason': None, - }, } diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py index 75d3fc31..fa140c4a 100644 --- a/tests/test_academy_reasoning_phase2.py +++ b/tests/test_academy_reasoning_phase2.py @@ -256,7 +256,6 @@ def test_build_peer_status_uses_agent_status_file(tmp_path) -> None: "finished": False, "last_error": None, "status_updated_at": 100.0, - "recent_outbox": [{"message_id": "msg-ack", "tldr": "MACE running"}], }, ) + "\n", @@ -267,7 +266,7 @@ def test_build_peer_status_uses_agent_status_file(tmp_path) -> None: assert status["agent-b"]["state"] == "idle" assert status["agent-b"]["round"] == 3 - assert status["agent-b"]["last_outbox_tldr"] == "MACE running" + assert status["agent-b"]["last_error"] is None def test_campaign_resources_resolve_to_shared_run_artifacts(tmp_path) -> None: From 8772c7cbff844fd1adeaa51b82b1c060e09ac645 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 11:55:30 -0500 Subject: [PATCH 061/119] refactor(academy): remove stale cleanup targets --- src/chemgraph/academy/observability/run_artifacts.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/chemgraph/academy/observability/run_artifacts.py b/src/chemgraph/academy/observability/run_artifacts.py index f109380b..e884de3b 100644 --- a/src/chemgraph/academy/observability/run_artifacts.py +++ b/src/chemgraph/academy/observability/run_artifacts.py @@ -266,10 +266,6 @@ async def wait_for_agent_statuses_finished( def clear_run_outputs(run_dir: pathlib.Path) -> None: for name in ( 'academy_registrations.json', - 'campaign_private.json', - 
'communication_proof.json', - 'compute_launch.json', - 'launch_plan.json', 'messages.jsonl', 'events.jsonl', 'placement.json', From 11f896df0d6c71acecce42b90d0c98f95c321c47 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 12:20:39 -0500 Subject: [PATCH 062/119] refactor(agent): route single-agent workflows through run_turn --- src/chemgraph/agent/llm_agent.py | 251 ++++++++++------ src/chemgraph/graphs/graspa_agent.py | 182 ------------ src/chemgraph/graphs/mock_agent.py | 102 ------- src/chemgraph/graphs/python_relp_agent.py | 224 --------------- src/chemgraph/graphs/rag_agent.py | 245 ---------------- .../graphs/single_agent_architector.py | 143 --------- src/chemgraph/graphs/single_agent_mcp.py | 116 -------- src/chemgraph/graphs/single_agent_xanes.py | 272 ------------------ tests/test_agent_session.py | 137 ++++----- tests/test_graph_constructors.py | 99 +------ tests/test_graphs.py | 208 +++++++++++--- tests/test_llm_agent.py | 16 +- 12 files changed, 426 insertions(+), 1569 deletions(-) delete mode 100644 src/chemgraph/graphs/graspa_agent.py delete mode 100644 src/chemgraph/graphs/mock_agent.py delete mode 100644 src/chemgraph/graphs/python_relp_agent.py delete mode 100644 src/chemgraph/graphs/rag_agent.py delete mode 100644 src/chemgraph/graphs/single_agent_architector.py delete mode 100644 src/chemgraph/graphs/single_agent_mcp.py delete mode 100644 src/chemgraph/graphs/single_agent_xanes.py diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index dbe55245..10a18902 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -42,30 +42,119 @@ planner_prompt as default_planner_prompt, ) from langgraph.errors import GraphInterrupt +from langchain_core.messages import AIMessage from langchain_core.callbacks import BaseCallbackHandler from chemgraph.graphs.single_agent import construct_single_agent_graph - - -from chemgraph.graphs.python_relp_agent import construct_relp_graph from 
chemgraph.graphs.multi_agent import construct_multi_agent_graph -from chemgraph.graphs.graspa_agent import construct_graspa_graph -from chemgraph.graphs.mock_agent import construct_mock_agent_graph -from chemgraph.graphs.single_agent_mcp import construct_single_agent_mcp_graph from chemgraph.graphs.graspa_mcp import construct_graspa_mcp_graph -from chemgraph.graphs.rag_agent import construct_rag_agent_graph -from chemgraph.graphs.single_agent_xanes import construct_single_agent_xanes_graph from chemgraph.prompt.rag_prompt import rag_agent_prompt from chemgraph.prompt.xanes_prompt import ( xanes_single_agent_prompt as default_xanes_single_agent_prompt, xanes_formatter_prompt as default_xanes_formatter_prompt, ) +from chemgraph.tools.ase_tools import ( + file_to_atomsdata, + run_ase, + save_atomsdata_to_file, +) +from chemgraph.tools.cheminformatics_tools import ( + molecule_name_to_smiles, + smiles_to_atomsdata, + smiles_to_coordinate_file, +) +from chemgraph.tools.generic_tools import calculator, repl_tool +from chemgraph.tools.graspa_tools import run_graspa +from chemgraph.tools.rag_tools import load_document, query_knowledge_base +from chemgraph.tools.xanes_tools import ( + fetch_xanes_data, + plot_xanes_data, + run_xanes, +) import logging logger = logging.getLogger(__name__) +SINGLE_AGENT_TURN_WORKFLOWS = { + "single_agent", + "python_relp", + "graspa", + "mock_agent", + "single_agent_mcp", + "rag_agent", + "single_agent_xanes", +} + +LEGACY_GRAPH_WORKFLOWS = {"multi_agent", "graspa_mcp"} + + +def _tool_name(tool: Any) -> str: + return str(getattr(tool, "name", getattr(tool, "__name__", repr(tool)))) + + +def _merge_tools(*groups: Collection[Any] | None) -> list[Any]: + """Merge tool groups by visible tool name while preserving order.""" + merged: list[Any] = [] + seen: set[str] = set() + for group in groups: + for tool in group or (): + name = _tool_name(tool) + if name not in seen: + merged.append(tool) + seen.add(name) + return merged + + +def _xanes_tools() 
-> list[Any]: + return [ + molecule_name_to_smiles, + smiles_to_coordinate_file, + run_ase, + run_xanes, + fetch_xanes_data, + plot_xanes_data, + ] + + +def _rag_tools() -> list[Any]: + return [ + load_document, + query_knowledge_base, + file_to_atomsdata, + smiles_to_coordinate_file, + run_ase, + molecule_name_to_smiles, + save_atomsdata_to_file, + calculator, + ] + + +def _mock_tools() -> list[Any]: + return [ + file_to_atomsdata, + smiles_to_atomsdata, + run_ase, + molecule_name_to_smiles, + save_atomsdata_to_file, + calculator, + ] + + +def _last_ai_message(state: dict[str, Any], fallback_text: str) -> AIMessage: + """Return the last AI message from a turn state, preserving objects when present.""" + messages = state.get("messages", []) if isinstance(state, dict) else [] + for message in reversed(messages): + if isinstance(message, AIMessage): + return message + if isinstance(message, dict): + message_type = message.get("type") or message.get("role") + if message_type in {"ai", "assistant"}: + return AIMessage(content=_message_text(message)) + return AIMessage(content=fallback_text) + + def _is_mock_object(value) -> bool: """Return True for unittest.mock objects without importing test-only APIs. @@ -817,6 +906,16 @@ def __init__( logger.error(f"Exception thrown when loading {model_name}: {str(e)}") raise e + supported_workflows = SINGLE_AGENT_TURN_WORKFLOWS | LEGACY_GRAPH_WORKFLOWS + if workflow_type not in supported_workflows: + raise ValueError( + f"Unsupported workflow type: {workflow_type}. 
" + f"Available types: {sorted(supported_workflows)}" + ) + + self._using_default_system_prompt = system_prompt == single_agent_prompt + self._using_default_formatter_prompt = formatter_prompt == default_formatter_prompt + self.workflow_type = workflow_type self.model_name = model_name self.base_url = base_url @@ -839,6 +938,7 @@ def __init__( self.human_input_handler = human_input_handler self.human_supervised = human_supervised self.terminal_tool_names = tuple(terminal_tool_names) + self._last_run_state: dict[str, Any] | None = None # When human supervision is disabled and the caller is using the # default system prompt, strip the ask_human instructions so the @@ -879,36 +979,14 @@ def append_calculator_context(prompt: str) -> str: self.support_structured_output = support_structured_output self.workflow_map = { - "single_agent": {"constructor": construct_single_agent_graph}, "multi_agent": {"constructor": construct_multi_agent_graph}, - "python_relp": {"constructor": construct_relp_graph}, - "graspa": {"constructor": construct_graspa_graph}, - "mock_agent": {"constructor": construct_mock_agent_graph}, - "single_agent_mcp": {"constructor": construct_single_agent_mcp_graph}, "graspa_mcp": {"constructor": construct_graspa_mcp_graph}, - "rag_agent": {"constructor": construct_rag_agent_graph}, - "single_agent_xanes": {"constructor": construct_single_agent_xanes_graph}, } - if workflow_type not in self.workflow_map: - raise ValueError( - f"Unsupported workflow type: {workflow_type}. 
Available types: {list(self.workflow_map.keys())}" - ) + self.tools = self._resolve_turn_tools(tools, data_tools) + self._resolve_turn_prompts() - if self.workflow_type == "single_agent": - self.workflow = self.workflow_map[workflow_type]["constructor"]( - llm, - self.system_prompt, - self.structured_output, - self.formatter_prompt, - self.generate_report, - self.report_prompt, - self.tools, - max_retries=self.max_retries, - human_supervised=self.human_supervised, - terminal_tool_names=self.terminal_tool_names, - ) - elif self.workflow_type == "multi_agent": + if self.workflow_type == "multi_agent": self.workflow = self.workflow_map[workflow_type]["constructor"]( llm, planner_prompt=self.planner_prompt, @@ -918,55 +996,51 @@ def append_calculator_context(prompt: str) -> str: formatter_prompt=self.formatter_multi_prompt, max_retries=self.max_retries, ) - elif self.workflow_type == "python_relp": - self.workflow = self.workflow_map[workflow_type]["constructor"]( - llm, - self.system_prompt, - ) - elif self.workflow_type == "graspa": - self.workflow = self.workflow_map[workflow_type]["constructor"]( - llm, - self.system_prompt, - self.structured_output, - self.formatter_prompt, - ) - elif self.workflow_type == "mock_agent": - self.workflow = self.workflow_map[workflow_type]["constructor"]( - llm=llm, - system_prompt=self.system_prompt, - ) - elif self.workflow_type == "single_agent_mcp": - self.workflow = self.workflow_map[workflow_type]["constructor"]( - llm=llm, - system_prompt=self.system_prompt, - tools=self.tools, - ) elif self.workflow_type == "graspa_mcp": self.workflow = self.workflow_map[workflow_type]["constructor"]( llm=llm, executor_tools=self.tools, analysis_tools=self.data_tools, ) - elif self.workflow_type == "rag_agent": - self.workflow = self.workflow_map[workflow_type]["constructor"]( - llm=llm, - system_prompt=self.system_prompt - if self.system_prompt != single_agent_prompt - else rag_agent_prompt, - tools=self.tools, - ) + else: + self.workflow = 
None + + def _resolve_turn_tools( + self, + tools: Collection[Any] | None, + data_tools: Collection[Any] | None, + ) -> list[Any] | None: + """Resolve the LangGraph tools for run_turn-backed workflows.""" + if self.workflow_type == "single_agent": + return list(tools) if tools is not None else None + if self.workflow_type == "python_relp": + return _merge_tools(tools, [repl_tool, calculator]) + if self.workflow_type == "graspa": + return _merge_tools(tools, [run_graspa]) + if self.workflow_type == "mock_agent": + return _merge_tools(tools, _mock_tools()) + if self.workflow_type == "single_agent_mcp": + resolved = _merge_tools(tools, data_tools) + if not resolved: + raise ValueError( + "No MCP tools loaded. Ensure MCP servers are configured and reachable." + ) + return resolved + if self.workflow_type == "rag_agent": + return _merge_tools(tools, _rag_tools()) + if self.workflow_type == "single_agent_xanes": + return _merge_tools(tools, _xanes_tools()) + return list(tools) if tools is not None else None + + def _resolve_turn_prompts(self) -> None: + """Apply workflow-specific prompt defaults before run_turn.""" + if self.workflow_type == "rag_agent" and self._using_default_system_prompt: + self.system_prompt = rag_agent_prompt elif self.workflow_type == "single_agent_xanes": - self.workflow = self.workflow_map[workflow_type]["constructor"]( - llm, - system_prompt=self.system_prompt - if self.system_prompt != single_agent_prompt - else default_xanes_single_agent_prompt, - structured_output=self.structured_output, - formatter_prompt=self.formatter_prompt - if self.formatter_prompt != default_formatter_prompt - else default_xanes_formatter_prompt, - tools=self.tools, - ) + if self._using_default_system_prompt: + self.system_prompt = default_xanes_single_agent_prompt + if self._using_default_formatter_prompt: + self.formatter_prompt = default_xanes_formatter_prompt def visualize(self, method: str = "ascii"): """Visualize the LangGraph graph structure. 
@@ -991,6 +1065,11 @@ def visualize(self, method: str = "ascii"): Requires IPython and nest_asyncio to be installed. The visualization uses Mermaid diagrams with custom styling. """ + if self.workflow is None: + raise RuntimeError( + f"Workflow {self.workflow_type!r} is run-turn-backed and is built " + "inside ChemGraph.run(); it is not available for pre-run visualization." + ) import nest_asyncio from IPython.display import Image, display from langchain_core.runnables.graph import ( @@ -1034,6 +1113,12 @@ def get_state(self, config={"configurable": {"thread_id": "1"}}): list List of messages in the current state """ + if self.workflow is None: + if self._last_run_state is None: + raise RuntimeError( + f"Workflow {self.workflow_type!r} has not produced state yet." + ) + return self._last_run_state return self.workflow.get_state(config).values def write_state( @@ -1307,9 +1392,12 @@ async def run( resume_from: Optional[str] = None, ): """ - Async-only runner. Requires `self.workflow.astream(...)`. - Streams values, logs new messages, writes state, and returns according to - `self.return_option` ("last_message" or "state"). + Async runner for run-turn-backed and legacy graph-backed workflows. + + Run-turn-backed workflows delegate to :func:`run_turn`, while legacy + multi-node graph workflows stream through ``self.workflow.astream``. + The return value follows ``self.return_option`` ("last_message" or + "state"). 
When the graph pauses for human input (via ``interrupt()``), the ``human_input_handler`` callback is invoked to obtain the user's @@ -1351,7 +1439,7 @@ async def run( logger.info(f"Injected context from session {resume_from}") thread_id = str(config["configurable"]["thread_id"]) - if self.workflow_type == "single_agent": + if self.workflow_type in SINGLE_AGENT_TURN_WORKFLOWS: result = await run_turn( query=query, tools=self.tools, @@ -1369,11 +1457,13 @@ async def run( terminal_tool_names=self.terminal_tool_names, human_supervised=self.human_supervised, ) + self._last_run_state = result.state self._save_messages_to_store(result.state, query) + self.write_state(config=config, file_path=None) if self.return_option == "state": return result.state if self.return_option == "last_message": - return result.final_text + return _last_ai_message(result.state, result.final_text) raise ValueError( f"Unsupported return_option: {self.return_option}. " "Use 'last_message' or 'state'." @@ -1396,6 +1486,7 @@ async def run( last_state = state if last_state is None: raise RuntimeError("Workflow produced no states") + self._last_run_state = serialize_state(last_state) self._save_messages_to_store(last_state, query) self.write_state(config=config, file_path=None) if self.return_option == "state": diff --git a/src/chemgraph/graphs/graspa_agent.py b/src/chemgraph/graphs/graspa_agent.py deleted file mode 100644 index 578db3d1..00000000 --- a/src/chemgraph/graphs/graspa_agent.py +++ /dev/null @@ -1,182 +0,0 @@ - - -from langgraph.graph import StateGraph, START, END -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import ToolNode - -from chemgraph.tools.graspa_tools import run_graspa -from chemgraph.schemas.agent_response import ResponseFormatter -from chemgraph.prompt.single_agent_prompt import ( - single_agent_prompt, - formatter_prompt, -) -from chemgraph.utils.logging_config import setup_logger -from 
chemgraph.state.state import State - -logger = setup_logger(__name__) - - -def route_tools(state: State): - """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - - Returns - ------- - str - Either 'tools' or 'done' based on the state conditions - """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state to tool_edge: {state}") - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - return "tools" - return "done" - - -def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): - """LLM node that processes messages and decides next actions. - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for processing - system_prompt : str - The system prompt to guide the LLM's behavior - tools : list, optional - List of tools available to the agent, by default None - - Returns - ------- - dict - Updated state containing the LLM's response - """ - - # Load default tools if no tool is specified. 
- if tools is None: - tools = [run_graspa] - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - - -def ResponseAgent(state: State, llm: ChatOpenAI, formatter_prompt: str): - """An LLM agent responsible for formatting final messag - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for formatting - formatter_prompt : str - The prompt to guide the LLM's formatting behavior - - Returns - ------- - dict - Updated state containing the formatted response - """ - messages = [ - {"role": "system", "content": formatter_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_structured_output = llm.with_structured_output(ResponseFormatter) - response = llm_structured_output.invoke(messages).model_dump_json() - return {"messages": [response]} - - -def construct_graspa_graph( - llm: ChatOpenAI, - system_prompt: str = single_agent_prompt, - structured_output: bool = False, - formatter_prompt: str = formatter_prompt, - tools: list = None, -): - """Construct a geometry optimization graph. 
- - Parameters - ---------- - llm : ChatOpenAI - The language model to use for the graph - system_prompt : str, optional - The system prompt to guide the LLM's behavior, by default single_agent_prompt - structured_output : bool, optional - Whether to use structured output, by default False - formatter_prompt : str, optional - The prompt to guide the LLM's formatting behavior, by default formatter_prompt - tool: list, optional - The list of tools for the agent, by default None - Returns - ------- - StateGraph - The constructed geometry optimization graph - """ - try: - logger.info("Constructing gRASPA graph") - checkpointer = MemorySaver() - if tools is None: - tools = [run_graspa] - tool_node = ToolNode(tools=tools) - graph_builder = StateGraph(State) - - if not structured_output: - graph_builder.add_node( - "ChemGraphAgent", - lambda state: ChemGraphAgent( - state, llm, system_prompt=system_prompt, tools=tools - ), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_conditional_edges( - "ChemGraphAgent", - route_tools, - {"tools": "tools", "done": END}, - ) - graph_builder.add_edge("tools", "ChemGraphAgent") - graph_builder.add_edge(START, "ChemGraphAgent") - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("gRASPA graph construction completed") - return graph - else: - graph_builder.add_node( - "ChemGraphAgent", - lambda state: ChemGraphAgent( - state, llm, system_prompt=system_prompt, tools=tools - ), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_node( - "ResponseAgent", - lambda state: ResponseAgent( - state, llm, formatter_prompt=formatter_prompt - ), - ) - graph_builder.add_conditional_edges( - "ChemGraphAgent", - route_tools, - {"tools": "tools", "done": "ResponseAgent"}, - ) - graph_builder.add_edge("tools", "ChemGraphAgent") - graph_builder.add_edge(START, "ChemGraphAgent") - graph_builder.add_edge("ResponseAgent", END) - - graph = graph_builder.compile(checkpointer=checkpointer) - 
logger.info("gRASPA graph construction completed") - return graph - - except Exception as e: - logger.error(f"Error constructing graph: {str(e)}") - raise diff --git a/src/chemgraph/graphs/mock_agent.py b/src/chemgraph/graphs/mock_agent.py deleted file mode 100644 index d10441e7..00000000 --- a/src/chemgraph/graphs/mock_agent.py +++ /dev/null @@ -1,102 +0,0 @@ -from langgraph.graph import StateGraph, START, END -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import MemorySaver -from chemgraph.tools.ase_tools import ( - run_ase, - save_atomsdata_to_file, - file_to_atomsdata, -) -from chemgraph.tools.cheminformatics_tools import ( - molecule_name_to_smiles, - smiles_to_atomsdata, -) -from chemgraph.tools.generic_tools import calculator -from chemgraph.prompt.single_agent_prompt import ( - single_agent_prompt, -) -from chemgraph.utils.logging_config import setup_logger -from chemgraph.state.state import State - -logger = setup_logger(__name__) - - -def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): - """LLM node that processes messages and decides next actions. - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for processing - system_prompt : str - The system prompt to guide the LLM's behavior - tools : list, optional - List of tools available to the agent, by default None - - Returns - ------- - dict - Updated state containing the LLM's response - """ - - # Load default tools if no tool is specified. 
- if tools is None: - tools = [ - file_to_atomsdata, - smiles_to_atomsdata, - run_ase, - molecule_name_to_smiles, - save_atomsdata_to_file, - calculator, - ] - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - -def construct_mock_agent_graph( - llm: ChatOpenAI, - system_prompt: str = single_agent_prompt, - tools: list = None, -): - """Construct a geometry optimization graph. - - Parameters - ---------- - llm : ChatOpenAI - The language model to use for the graph - system_prompt : str, optional - The system prompt to guide the LLM's behavior, by default single_agent_prompt - tools: list, optional - The list of tools for the main agent, by default None - Returns - ------- - StateGraph - The constructed single agent graph - """ - logger.info("Constructing mock agent graph") - checkpointer = MemorySaver() - if tools is None: - tools = [ - file_to_atomsdata, - smiles_to_atomsdata, - run_ase, - molecule_name_to_smiles, - save_atomsdata_to_file, - calculator, - ] - graph_builder = StateGraph(State) - - graph_builder.add_node( - "ChemGraphAgent", - lambda state: ChemGraphAgent(state, llm, system_prompt=system_prompt, tools=tools), - ) - graph_builder.add_edge(START, "ChemGraphAgent") - graph_builder.add_edge("ChemGraphAgent", END) - - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("Mock agent graph construction completed") - return graph diff --git a/src/chemgraph/graphs/python_relp_agent.py b/src/chemgraph/graphs/python_relp_agent.py deleted file mode 100644 index dd8edf98..00000000 --- a/src/chemgraph/graphs/python_relp_agent.py +++ /dev/null @@ -1,224 +0,0 @@ -from typing import Annotated -from typing_extensions import TypedDict - -from langgraph.graph import StateGraph, START, END -from langgraph.graph.message import add_messages -from langchain_core.messages import 
ToolMessage -import json -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import MemorySaver -from chemgraph.tools.generic_tools import repl_tool -from chemgraph.tools.generic_tools import calculator -from chemgraph.prompt.single_agent_prompt import single_agent_prompt -from chemgraph.utils.logging_config import setup_logger - -logger = setup_logger(__name__) - - -class State(TypedDict): - """Type definition for the state dictionary used in the graph. - - Attributes - ---------- - messages : list - List of messages in the conversation, annotated with add_messages - """ - - messages: Annotated[list, add_messages] - - -class BasicToolNode: - """A node that executes tools requested in the last AIMessage. - - This class processes tool calls from AI messages and executes the corresponding - tools, handling their results and any potential errors. - - Parameters - ---------- - tools : list - List of tool objects that can be called by the node - - Attributes - ---------- - tools_by_name : dict - Dictionary mapping tool names to their corresponding tool objects - """ - - def __init__(self, tools: list) -> None: - """Initialize the tool node. - - Parameters - ---------- - tools : list - Tool objects keyed by their ``name`` attribute. - """ - self.tools_by_name = {tool.name: tool for tool in tools} - - def __call__(self, inputs: State) -> State: - """Execute tools requested in the last message. 
- - Parameters - ---------- - inputs : State - The current state containing messages - - Returns - ------- - State - Updated state containing tool execution results - - Raises - ------ - ValueError - If no message is found in the input state - """ - if messages := inputs.get("messages", []): - message = messages[-1] - else: - raise ValueError("No message found in input") - - outputs = [] - for tool_call in message.tool_calls: - try: - tool_name = tool_call.get("name") - if not tool_name or tool_name not in self.tools_by_name: - raise ValueError(f"Invalid tool name: {tool_name}") - - tool_result = self.tools_by_name[tool_name].invoke(tool_call.get("args", {})) - - # Handle different types of tool results - result_content = ( - tool_result.dict() - if hasattr(tool_result, "dict") - else (tool_result if isinstance(tool_result, dict) else str(tool_result)) - ) - - outputs.append( - ToolMessage( - content=json.dumps(result_content), - name=tool_name, - tool_call_id=tool_call.get("id", ""), - ) - ) - - except Exception as e: - outputs.append( - ToolMessage( - content=json.dumps({"error": str(e)}), - name=tool_name if tool_name else "unknown_tool", - tool_call_id=tool_call.get("id", ""), - ) - ) - return {"messages": outputs} - - -def route_tools(state: State): - """Route to the 'tools' node if the last message has tool calls; otherwise, route to END. 
- - Parameters - ---------- - state : State - The current state containing messages - - Returns - ------- - str - Either 'tools' or END based on the presence of tool calls - - Raises - ------ - ValueError - If no messages are found in the input state - """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state to tool_edge: {state}") - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - return "tools" - return END - - -def CompChemAgent(state: State, llm: ChatOpenAI, system_prompt=single_agent_prompt, tools=None): - """LLM node that processes messages and decides next actions. - - Parameters - ---------- - state : State - The current state containing messages - llm : ChatOpenAI - The language model to use for processing - system_prompt : str, optional - The system prompt to guide the LLM's behavior, - by default single_agent_prompt - tools : list, optional - List of tools available to the agent, by default None - - Returns - ------- - dict - Updated state containing the LLM's response - """ - if tools is None: - tools = [] - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - - -def construct_relp_graph(llm: ChatOpenAI, system_prompt=single_agent_prompt): - """Construct a graph for REPL-based Python execution workflow. - - This function creates a state graph that implements a workflow for executing - Python code through a REPL interface, using LLM agents and tools. 
- - Parameters - ---------- - llm : ChatOpenAI - The language model to use in the workflow - system_prompt : str, optional - The system prompt to guide the LLM's behavior, - by default single_agent_prompt - - Returns - ------- - StateGraph - A compiled state graph implementing the REPL workflow - - Raises - ------ - Exception - If there is an error during graph construction - """ - try: - logger.info("Constructing geometry optimization graph") - checkpointer = MemorySaver() - tools = [ - repl_tool, - calculator, - ] - tool_node = BasicToolNode(tools=tools) - graph_builder = StateGraph(State) - graph_builder.add_node( - "CompChemAgent", - lambda state: CompChemAgent(state, llm, system_prompt=system_prompt, tools=tools), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_conditional_edges( - "CompChemAgent", - route_tools, - {"tools": "tools", END: END}, - ) - graph_builder.add_edge("tools", "CompChemAgent") - graph_builder.add_edge(START, "CompChemAgent") - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("Graph construction completed") - return graph - except Exception as e: - logger.error(f"Error constructing graph: {str(e)}") - raise diff --git a/src/chemgraph/graphs/rag_agent.py b/src/chemgraph/graphs/rag_agent.py deleted file mode 100644 index 91611166..00000000 --- a/src/chemgraph/graphs/rag_agent.py +++ /dev/null @@ -1,245 +0,0 @@ -"""LangGraph workflow for the RAG (Retrieval-Augmented Generation) agent. - -This graph combines document retrieval tools (load_document, -query_knowledge_base) with the standard chemistry tools so the agent -can answer questions grounded in user-provided text documents *and* -run molecular simulations when needed. 
- -Graph structure ---------------- - - START - | - v - RAGAgent <-------+ - | | - (route) | - / \\ | - v v | - tools done-->END | - | | - +----------------+ - -The agent loops through a ReAct cycle: it can call any combination of -RAG tools and chemistry tools, inspect the results, and decide whether -to call more tools or produce a final answer. -""" - -from langgraph.graph import StateGraph, START, END -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import ToolNode - -from chemgraph.tools.rag_tools import load_document, query_knowledge_base -from chemgraph.tools.ase_tools import ( - run_ase, - save_atomsdata_to_file, - file_to_atomsdata, -) -from chemgraph.tools.cheminformatics_tools import ( - molecule_name_to_smiles, - smiles_to_coordinate_file, -) -from chemgraph.tools.generic_tools import calculator -from chemgraph.prompt.rag_prompt import rag_agent_prompt -from chemgraph.state.state import State -from chemgraph.utils.logging_config import setup_logger - -logger = setup_logger(__name__) - - -# --------------------------------------------------------------------------- -# Helpers (reuse the repeated-tool-call detection from single_agent) -# --------------------------------------------------------------------------- -def _tool_call_signature(tool_calls) -> tuple: - """Create a comparable signature for a list of tool calls. - - Parameters - ---------- - tool_calls : list - Tool-call dictionaries from an AI message. - - Returns - ------- - tuple - Deterministic signature of tool names and arguments. 
- """ - signature = [] - for call in tool_calls or []: - name = call.get("name") if isinstance(call, dict) else None - args = call.get("args", {}) if isinstance(call, dict) else {} - if isinstance(args, dict): - args_sig = tuple(sorted(args.items())) - else: - args_sig = str(args) - signature.append((name, args_sig)) - return tuple(signature) - - -def _is_repeated_tool_cycle(messages) -> bool: - """Detect if the most recent AI tool-call set repeats the previous one. - - Parameters - ---------- - messages : list - Message history to inspect. - - Returns - ------- - bool - ``True`` when the last two AI tool-call sets are identical. - """ - ai_with_calls = [ - m - for m in messages - if hasattr(m, "tool_calls") and getattr(m, "tool_calls", None) - ] - if len(ai_with_calls) < 2: - return False - last = _tool_call_signature(ai_with_calls[-1].tool_calls) - prev = _tool_call_signature(ai_with_calls[-2].tool_calls) - return bool(last) and last == prev - - -# --------------------------------------------------------------------------- -# Routing -# --------------------------------------------------------------------------- -def route_tools(state: State): - """Route to 'tools' if the last message has tool calls, else 'done'. - - Parameters - ---------- - state : State - Current graph state. - - Returns - ------- - str - ``"tools"`` or ``"done"``. 
- """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state: {state}") - - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - if not isinstance(state, list) and _is_repeated_tool_cycle(messages): - return "done" - return "tools" - return "done" - - -# --------------------------------------------------------------------------- -# Agent node -# --------------------------------------------------------------------------- -def RAGAgent(state: State, llm, system_prompt: str, tools=None): - """LLM node that can retrieve from documents and run chemistry tools. - - Parameters - ---------- - state : State - Current graph state with messages. - llm : BaseChatModel - The bound language model. - system_prompt : str - System prompt guiding the agent's behaviour. - tools : list, optional - Tools available to the agent. Uses the default RAG + chemistry - tool set when ``None``. - - Returns - ------- - dict - Updated state with the LLM's response appended to messages. 
- """ - if tools is None: - tools = _default_tools() - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - - -# --------------------------------------------------------------------------- -# Default tool set -# --------------------------------------------------------------------------- -def _default_tools(): - """Return the combined RAG + chemistry tool list.""" - return [ - # RAG tools - load_document, - query_knowledge_base, - # Chemistry tools - file_to_atomsdata, - smiles_to_coordinate_file, - run_ase, - molecule_name_to_smiles, - save_atomsdata_to_file, - calculator, - ] - - -# --------------------------------------------------------------------------- -# Graph constructor -# --------------------------------------------------------------------------- -def construct_rag_agent_graph( - llm, - system_prompt: str = rag_agent_prompt, - tools: list = None, -): - """Construct a RAG agent graph with document retrieval and chemistry tools. - - Parameters - ---------- - llm : BaseChatModel - The language model to power the agent. - system_prompt : str, optional - System prompt for the RAG agent, by default ``rag_agent_prompt``. - tools : list, optional - Custom tool list. When ``None`` the default RAG + chemistry - tools are used. - - Returns - ------- - CompiledStateGraph - The compiled LangGraph workflow ready for execution. 
- """ - try: - logger.info("Constructing RAG agent graph") - checkpointer = MemorySaver() - - if tools is None: - tools = _default_tools() - - tool_node = ToolNode(tools=tools) - graph_builder = StateGraph(State) - - # Nodes - graph_builder.add_node( - "RAGAgent", - lambda state: RAGAgent( - state, llm, system_prompt=system_prompt, tools=tools - ), - ) - graph_builder.add_node("tools", tool_node) - - # Edges - graph_builder.add_edge(START, "RAGAgent") - graph_builder.add_conditional_edges( - "RAGAgent", - route_tools, - {"tools": "tools", "done": END}, - ) - graph_builder.add_edge("tools", "RAGAgent") - - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("RAG agent graph construction completed") - return graph - - except Exception as e: - logger.error(f"Error constructing RAG agent graph: {e}") - raise diff --git a/src/chemgraph/graphs/single_agent_architector.py b/src/chemgraph/graphs/single_agent_architector.py deleted file mode 100644 index 9e61747d..00000000 --- a/src/chemgraph/graphs/single_agent_architector.py +++ /dev/null @@ -1,143 +0,0 @@ -from langgraph.graph import StateGraph, START, END -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import ToolNode -from chemgraph.tools.cheminformatics_tools import ( - molecule_name_to_smiles, - smiles_to_coordinate_file, -) - -from chemgraph.tools.architector_tools import ( - visualize_molecule, - image_to_connection_points, - build_metal_complex -) -from chemgraph.utils.logging_config import setup_logger -from chemgraph.state.state import State - -logger = setup_logger(__name__) - -single_agent_prompt = "" - -def route_tools(state: State): - """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. 
- - Parameters - ---------- - state : State - The current state containing messages and remaining steps - - Returns - ------- - str - Either 'tools' or 'done' based on the state conditions - """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state to tool_edge: {state}") - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - return "tools" - return "done" - - -def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): - """LLM node that processes messages and decides next actions. - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for processing - system_prompt : str - The system prompt to guide the LLM's behavior - tools : list, optional - List of tools available to the agent, by default None - - Returns - ------- - dict - Updated state containing the LLM's response - """ - - # Load default tools if no tool is specified. - if tools is None: - tools = [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - visualize_molecule, - image_to_connection_points, - build_metal_complex - ] - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - -def construct_single_agent_architector_graph( - llm: ChatOpenAI, - system_prompt: str = "", - tools: list = None, -): - """Construct a geometry optimization graph. 
- - Parameters - ---------- - llm : ChatOpenAI - The language model to use for the graph - system_prompt : str, optional - The system prompt to guide the LLM's behavior, by default single_agent_prompt - structured_output : bool, optional - Whether to use structured output, by default False - formatter_prompt : str, optional - The prompt to guide the LLM's formatting behavior, by default formatter_prompt - generate_report: bool, optional - Whether to generate a report, by default False - report_prompt: str, optional - The prompt to guide the LLM's report generation behavior, by default report_prompt - tool: list, optional - The list of tools for the main agent, by default None - Returns - ------- - StateGraph - The constructed single agent graph - """ - try: - logger.info("Constructing single agent graph") - checkpointer = MemorySaver() - if tools is None: - tools = [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - visualize_molecule, - image_to_connection_points, - build_metal_complex - ] - tool_node = ToolNode(tools=tools) - graph_builder = StateGraph(State) - - graph_builder.add_node( - "ChemGraphAgent", - lambda state: ChemGraphAgent(state, llm, system_prompt=system_prompt, tools=tools), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_edge(START, "ChemGraphAgent") - graph_builder.add_conditional_edges( - "ChemGraphAgent", - route_tools, - {"tools": "tools", "done": END}, - ) - graph_builder.add_edge("tools", "ChemGraphAgent") - graph_builder.add_edge("ChemGraphAgent", END) - - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("Graph construction completed") - return graph - except Exception as e: - logger.error(f"Error constructing graph: {str(e)}") - raise diff --git a/src/chemgraph/graphs/single_agent_mcp.py b/src/chemgraph/graphs/single_agent_mcp.py deleted file mode 100644 index f858a9c0..00000000 --- a/src/chemgraph/graphs/single_agent_mcp.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import List, Any - 
-from langgraph.graph import StateGraph, START, END -from langchain_openai import ChatOpenAI -from langgraph.prebuilt import ToolNode -from langgraph.checkpoint.memory import MemorySaver - -from chemgraph.prompt.single_agent_prompt import ( - single_agent_prompt, -) -from chemgraph.utils.logging_config import setup_logger -from chemgraph.state.state import State - -logger = setup_logger(__name__) - - -def route_tools(state: State) -> str: - """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - - Returns - ------- - str - Either 'tools' or 'done' based on the state conditions - """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state to tool_edge: {state}") - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - return "tools" - return "done" - - -def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): - """LLM node that processes messages and decides next actions. 
- - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for processing - system_prompt : str - The system prompt to guide the LLM's behavior - tools : list, optional - List of tools available to the agent, by default None - - Returns - ------- - dict - Updated state containing the LLM's response - """ - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - - -def construct_single_agent_mcp_graph( - llm: ChatOpenAI, - system_prompt: str = single_agent_prompt, - tools: List[Any] = None, -): - """Construct a geometry optimization graph. - - Parameters - ---------- - llm : ChatOpenAI - The language model to use for the graph - system_prompt : str, optional - The system prompt to guide the LLM's behavior, by default single_agent_prompt - Returns - ------- - StateGraph - The constructed single agent graph - """ - if not tools: - raise ValueError( - "No MCP tools loaded. Ensure MCP servers are configured and reachable." 
- ) - logger.info("Constructing single agent MCP graph (sync)") - - checkpointer = MemorySaver() - tool_node = ToolNode(tools=tools) - graph_builder = StateGraph(State) - - graph_builder.add_node( - "ChemGraphAgent", - lambda state: ChemGraphAgent( - state, llm, system_prompt=system_prompt, tools=tools - ), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_edge(START, "ChemGraphAgent") - - graph_builder.add_conditional_edges( - "ChemGraphAgent", - route_tools, - {"tools": "tools", "done": END}, - ) - graph_builder.add_edge("tools", "ChemGraphAgent") - graph_builder.add_edge("ChemGraphAgent", END) - - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("Graph construction completed") - return graph diff --git a/src/chemgraph/graphs/single_agent_xanes.py b/src/chemgraph/graphs/single_agent_xanes.py deleted file mode 100644 index 1c3935d8..00000000 --- a/src/chemgraph/graphs/single_agent_xanes.py +++ /dev/null @@ -1,272 +0,0 @@ -import os - -from langgraph.graph import StateGraph, START, END -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import ToolNode -from chemgraph.tools.cheminformatics_tools import ( - molecule_name_to_smiles, - smiles_to_coordinate_file, -) -from chemgraph.tools.ase_tools import run_ase -from chemgraph.tools.xanes_tools import ( - run_xanes, - fetch_xanes_data, - plot_xanes_data,) -from chemgraph.schemas.agent_response import ResponseFormatter -from chemgraph.prompt.xanes_prompt import ( - xanes_single_agent_prompt, - xanes_formatter_prompt, -) -from chemgraph.utils.logging_config import setup_logger -from chemgraph.state.state import State - -logger = setup_logger(__name__) - - -def _tool_call_signature(tool_calls) -> tuple: - """Create a comparable signature for a list of tool calls. - - Parameters - ---------- - tool_calls : list - Tool-call dictionaries from an AI message. 
- - Returns - ------- - tuple - Deterministic signature of tool names and arguments. - """ - signature = [] - for call in tool_calls or []: - name = call.get("name") if isinstance(call, dict) else None - args = call.get("args", {}) if isinstance(call, dict) else {} - if isinstance(args, dict): - args_sig = tuple(sorted(args.items())) - else: - args_sig = str(args) - signature.append((name, args_sig)) - return tuple(signature) - - -def _is_repeated_tool_cycle(messages) -> bool: - """Detect if the most recent AI tool-call set repeats the previous one. - - Parameters - ---------- - messages : list - Message history to inspect. - - Returns - ------- - bool - ``True`` when the last two AI tool-call sets are identical. - """ - ai_with_calls = [] - for message in messages: - if hasattr(message, "tool_calls") and getattr(message, "tool_calls", None): - ai_with_calls.append(message) - - if len(ai_with_calls) < 2: - return False - - last_calls = _tool_call_signature(ai_with_calls[-1].tool_calls) - prev_calls = _tool_call_signature(ai_with_calls[-2].tool_calls) - return bool(last_calls) and last_calls == prev_calls - - -def route_tools(state: State): - """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. 
- - Parameters - ---------- - state : State - The current state containing messages and remaining steps - - Returns - ------- - str - Either 'tools' or 'done' based on the state conditions - """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state to tool_edge: {state}") - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - if not isinstance(state, list) and _is_repeated_tool_cycle(messages): - return "done" - return "tools" - return "done" - - -def XANESAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): - """LLM node for XANES workflows that processes messages and decides next actions. - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for processing - system_prompt : str - The system prompt to guide the LLM's behavior - tools : list, optional - List of tools available to the agent, by default None - - Returns - ------- - dict - Updated state containing the LLM's response - """ - if tools is None: - tools = [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - run_ase, - run_xanes, - fetch_xanes_data, - plot_xanes_data, - ] - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - - -def ResponseAgent(state: State, llm: ChatOpenAI, formatter_prompt: str): - """An LLM agent responsible for formatting final message. 
- - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for formatting - formatter_prompt : str - The prompt to guide the LLM's formatting behavior - - Returns - ------- - dict - Updated state containing the formatted response - """ - messages = [ - {"role": "system", "content": formatter_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_structured_output = llm.with_structured_output(ResponseFormatter) - response = llm_structured_output.invoke(messages).model_dump_json() - return {"messages": [response]} - - -def construct_single_agent_xanes_graph( - llm: ChatOpenAI, - system_prompt: str = xanes_single_agent_prompt, - structured_output: bool = False, - formatter_prompt: str = xanes_formatter_prompt, - tools: list = None, -): - """Construct a single-agent graph for XANES/FDMNES workflows. - - Parameters - ---------- - llm : ChatOpenAI - The language model to use for the graph - system_prompt : str, optional - The system prompt to guide the LLM's behavior, - by default xanes_single_agent_prompt - structured_output : bool, optional - Whether to use structured output, by default False - formatter_prompt : str, optional - The prompt to guide the LLM's formatting behavior, - by default xanes_formatter_prompt - tools : list, optional - The list of tools for the main agent, by default None - - Returns - ------- - StateGraph - The constructed single agent XANES graph - """ - try: - logger.info("Constructing single agent XANES graph") - - if not os.environ.get("MP_API_KEY"): - logger.warning( - "MP_API_KEY environment variable is not set. " - "The fetch_xanes_data tool will require an API key " - "to be passed explicitly." - ) - if not os.environ.get("FDMNES_EXE"): - logger.warning( - "FDMNES_EXE environment variable is not set. " - "The run_xanes tool will not work without the FDMNES executable." 
- ) - - checkpointer = MemorySaver() - if tools is None: - tools = [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - run_ase, - run_xanes, - fetch_xanes_data, - plot_xanes_data, - ] - tool_node = ToolNode(tools=tools) - graph_builder = StateGraph(State) - - if not structured_output: - graph_builder.add_node( - "XANESAgent", - lambda state: XANESAgent( - state, llm, system_prompt=system_prompt, tools=tools - ), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_edge(START, "XANESAgent") - graph_builder.add_conditional_edges( - "XANESAgent", - route_tools, - {"tools": "tools", "done": END}, - ) - graph_builder.add_edge("tools", "XANESAgent") - graph_builder.add_edge("XANESAgent", END) - - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("XANES graph construction completed") - return graph - else: - graph_builder.add_node( - "XANESAgent", - lambda state: XANESAgent( - state, llm, system_prompt=system_prompt, tools=tools - ), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_node( - "ResponseAgent", - lambda state: ResponseAgent( - state, llm, formatter_prompt=formatter_prompt - ), - ) - graph_builder.add_conditional_edges( - "XANESAgent", - route_tools, - {"tools": "tools", "done": "ResponseAgent"}, - ) - graph_builder.add_edge("tools", "XANESAgent") - graph_builder.add_edge(START, "XANESAgent") - graph_builder.add_edge("ResponseAgent", END) - - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("XANES graph construction completed") - return graph - - except Exception as e: - logger.error(f"Error constructing XANES graph: {str(e)}") - raise diff --git a/tests/test_agent_session.py b/tests/test_agent_session.py index f646c33d..a9799c97 100644 --- a/tests/test_agent_session.py +++ b/tests/test_agent_session.py @@ -16,7 +16,7 @@ import pytest from unittest.mock import Mock, patch -from chemgraph.agent.llm_agent import ChemGraph, serialize_state +from chemgraph.agent.llm_agent import 
ChemGraph, TurnResult, serialize_state from chemgraph.memory.store import SessionStore @@ -46,14 +46,28 @@ def tmp_db(tmp_path): @pytest.fixture def mock_agent_patches(): - """Patch LLM loading and graph construction for fast agent creation.""" + """Patch LLM loading and run_turn for fast agent creation.""" with ( patch("chemgraph.agent.llm_agent.load_openai_model") as mock_load, - patch("chemgraph.agent.llm_agent.construct_single_agent_graph") as mock_graph, + patch("chemgraph.agent.llm_agent.run_turn") as mock_run_turn, ): mock_load.return_value = Mock() - mock_graph.return_value = Mock() - yield mock_load, mock_graph + + async def default_run_turn(**kwargs): + ai_msg = Mock() + ai_msg.type = "ai" + ai_msg.content = "Test response" + return TurnResult( + final_text="Test response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) + + mock_run_turn.side_effect = default_run_turn + yield mock_load, mock_run_turn def _make_agent(clean_env, mock_agent_patches, tmp_db, **kwargs): @@ -62,6 +76,7 @@ def _make_agent(clean_env, mock_agent_patches, tmp_db, **kwargs): "model_name": "gpt-4o-mini", "enable_memory": True, "memory_db_path": tmp_db, + "log_dir": os.path.join(os.path.dirname(tmp_db), "logs"), } defaults.update(kwargs) agent = ChemGraph(**defaults) @@ -128,7 +143,7 @@ def test_uuid_set_when_log_dir_preset(self, mock_agent_patches, tmp_db): """uuid must be set even when CHEMGRAPH_LOG_DIR is already in env.""" os.environ["CHEMGRAPH_LOG_DIR"] = "/tmp/preset_log_dir" try: - agent = _make_agent(None, mock_agent_patches, tmp_db) + agent = _make_agent(None, mock_agent_patches, tmp_db, log_dir=None) assert agent.uuid is not None assert len(agent.uuid) == 8 assert agent.log_dir == "/tmp/preset_log_dir" @@ -350,8 +365,7 @@ def test_filename_includes_uuid( ): agent = _make_agent(clean_env, mock_agent_patches, tmp_db) - # Mock get_state to return something serializable - 
agent.workflow.get_state = Mock(return_value=Mock(values={"messages": []})) + agent._last_run_state = {"messages": []} log_dir = str(tmp_path / "test_logs") os.makedirs(log_dir, exist_ok=True) @@ -382,7 +396,7 @@ def test_no_overwrite_same_second( if "CHEMGRAPH_LOG_DIR" in os.environ: del os.environ["CHEMGRAPH_LOG_DIR"] a = _make_agent(clean_env, mock_agent_patches, tmp_db) - a.workflow.get_state = Mock(return_value=Mock(values={"messages": []})) + a._last_run_state = {"messages": []} a.log_dir = log_dir agents.append(a) @@ -402,23 +416,8 @@ def test_no_overwrite_same_second( class TestResumeFrom: def _make_streamable_agent(self, clean_env, mock_agent_patches, tmp_db): - """Create an agent with a mock async workflow.""" - agent = _make_agent(clean_env, mock_agent_patches, tmp_db) - - # Set up a mock astream that yields one state - ai_msg = Mock() - ai_msg.type = "ai" - ai_msg.content = "Test response" - ai_msg.pretty_print = Mock() - - final_state = {"messages": [ai_msg]} - - async def mock_astream(inputs, stream_mode, config): - yield final_state - - agent.workflow.astream = mock_astream - agent.workflow.get_state = Mock(return_value=Mock(values=final_state)) - return agent + """Create an agent whose run path is mocked through run_turn.""" + return _make_agent(clean_env, mock_agent_patches, tmp_db) @pytest.mark.asyncio async def test_resume_prepends_context(self, clean_env, mock_agent_patches, tmp_db): @@ -435,23 +434,24 @@ async def test_resume_prepends_context(self, clean_env, mock_agent_patches, tmp_ # Create second agent sharing the same DB agent2 = self._make_streamable_agent(clean_env, mock_agent_patches, tmp_db) - # Track what inputs are passed to astream + # Track what query is passed to run_turn. 
captured_inputs = [] - async def tracking_astream(inputs, stream_mode, config): - captured_inputs.append(inputs) + async def tracking_run_turn(**kwargs): + captured_inputs.append({"messages": kwargs["query"]}) ai_msg = Mock() ai_msg.type = "ai" ai_msg.content = "Follow-up response" - ai_msg.pretty_print = Mock() - yield {"messages": [ai_msg]} - - agent2.workflow.astream = tracking_astream - agent2.workflow.get_state = Mock( - return_value=Mock( - values={"messages": [Mock(type="ai", content="Follow-up")]} + return TurnResult( + final_text="Follow-up response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, ) - ) + + mock_agent_patches[1].side_effect = tracking_run_turn await agent2.run("Continue the analysis", resume_from=session_id) @@ -469,18 +469,21 @@ async def test_resume_from_nonexistent_session( captured_inputs = [] - async def tracking_astream(inputs, stream_mode, config): - captured_inputs.append(inputs) + async def tracking_run_turn(**kwargs): + captured_inputs.append({"messages": kwargs["query"]}) ai_msg = Mock() ai_msg.type = "ai" ai_msg.content = "Response" - ai_msg.pretty_print = Mock() - yield {"messages": [ai_msg]} + return TurnResult( + final_text="Response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) - agent.workflow.astream = tracking_astream - agent.workflow.get_state = Mock( - return_value=Mock(values={"messages": [Mock(type="ai", content="resp")]}) - ) + mock_agent_patches[1].side_effect = tracking_run_turn await agent.run("Hello", resume_from="nonexistent_id") @@ -495,21 +498,23 @@ async def test_resume_from_ignored_when_memory_disabled( ): agent = _make_agent(clean_env, mock_agent_patches, tmp_db, enable_memory=False) - ai_msg = Mock() - ai_msg.type = "ai" - ai_msg.content = "Response" - ai_msg.pretty_print = Mock() - captured_inputs = [] - async def 
tracking_astream(inputs, stream_mode, config): - captured_inputs.append(inputs) - yield {"messages": [ai_msg]} + async def tracking_run_turn(**kwargs): + captured_inputs.append({"messages": kwargs["query"]}) + ai_msg = Mock() + ai_msg.type = "ai" + ai_msg.content = "Response" + return TurnResult( + final_text="Response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) - agent.workflow.astream = tracking_astream - agent.workflow.get_state = Mock( - return_value=Mock(values={"messages": [ai_msg]}) - ) + mock_agent_patches[1].side_effect = tracking_run_turn await agent.run("Hello", resume_from="some_id") @@ -528,7 +533,6 @@ async def test_full_lifecycle(self, clean_env, mock_agent_patches, tmp_db): """init -> run -> messages saved -> load_previous_context -> resume""" agent = _make_agent(clean_env, mock_agent_patches, tmp_db) - # Set up mock workflow human_msg = Mock() human_msg.type = "human" human_msg.content = "Calculate energy of H2" @@ -540,11 +544,17 @@ async def test_full_lifecycle(self, clean_env, mock_agent_patches, tmp_db): final_state = {"messages": [human_msg, ai_msg]} - async def mock_astream(inputs, stream_mode, config): - yield final_state + async def mock_run_turn(**kwargs): + return TurnResult( + final_text=ai_msg.content, + state=final_state, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) - agent.workflow.astream = mock_astream - agent.workflow.get_state = Mock(return_value=Mock(values=final_state)) + mock_agent_patches[1].side_effect = mock_run_turn # Step 1: Run await agent.run("Calculate energy of H2") @@ -568,8 +578,7 @@ async def mock_astream(inputs, stream_mode, config): del os.environ["CHEMGRAPH_LOG_DIR"] agent2 = _make_agent(clean_env, mock_agent_patches, tmp_db) - agent2.workflow.astream = mock_astream - agent2.workflow.get_state = Mock(return_value=Mock(values=final_state)) + 
mock_agent_patches[1].side_effect = mock_run_turn await agent2.run("Now optimize H2", resume_from=agent.uuid) diff --git a/tests/test_graph_constructors.py b/tests/test_graph_constructors.py index 72e58b55..6efcaef3 100644 --- a/tests/test_graph_constructors.py +++ b/tests/test_graph_constructors.py @@ -1,94 +1,5 @@ -import pytest -from chemgraph.agent.llm_agent import ChemGraph - - -WORKFLOWS = [ - "single_agent", - "multi_agent", - "python_relp", - "graspa", - "mock_agent", - "single_agent_mcp", - "graspa_mcp", - "single_agent_xanes", -] - - -@pytest.mark.parametrize("workflow_type", WORKFLOWS) -def test_constructor_is_called(monkeypatch, workflow_type): - called = {} - - def fake_constructor(*args, **kwargs): - called["args"] = (args, kwargs) - return f"WORKFLOW-SENTINEL-{workflow_type}" - - # Patch the constructor name used by chemgraph.agent.llm_agent - constructor_attr = { - "single_agent": "construct_single_agent_graph", - "multi_agent": "construct_multi_agent_graph", - "python_relp": "construct_relp_graph", - "graspa": "construct_graspa_graph", - "mock_agent": "construct_mock_agent_graph", - "single_agent_mcp": "construct_single_agent_mcp_graph", - "graspa_mcp": "construct_graspa_mcp_graph", - "single_agent_xanes": "construct_single_agent_xanes_graph", - }[workflow_type] - - monkeypatch.setattr( - f"chemgraph.agent.llm_agent.{constructor_attr}", - fake_constructor, - ) - - # Ensure model loading is deterministic and doesn't call external APIs - monkeypatch.setattr( - "chemgraph.agent.llm_agent.load_openai_model", - lambda model_name, temperature, base_url=None: "FAKE_LLM", - ) - - # For MCP workflows some constructors expect tools; pass a non-empty list - kwargs = {} - if workflow_type in {"single_agent_mcp", "graspa_mcp"}: - kwargs["tools"] = ["DUMMY_TOOL"] - kwargs["data_tools"] = ["DUMMY_TOOL"] - - cg = ChemGraph( - model_name="gpt-4o-mini", - workflow_type=workflow_type, - enable_memory=False, - **kwargs, - ) - assert cg.workflow == 
f"WORKFLOW-SENTINEL-{workflow_type}" - args_tuple, kwargs_called = called["args"] - if args_tuple: - assert args_tuple[0] == "FAKE_LLM" - else: - assert kwargs_called.get("llm") == "FAKE_LLM" - - -def test_single_agent_initialization_injects_calculator_availability(monkeypatch): - called = {} - - def fake_constructor(*args, **kwargs): - called["args"] = (args, kwargs) - return "WORKFLOW-SENTINEL-single_agent" - - monkeypatch.setattr( - "chemgraph.agent.llm_agent.construct_single_agent_graph", - fake_constructor, - ) - monkeypatch.setattr( - "chemgraph.agent.llm_agent.load_openai_model", - lambda model_name, temperature, base_url=None: "FAKE_LLM", - ) - - cg = ChemGraph( - model_name="gpt-4o-mini", - workflow_type="single_agent", - enable_memory=False, - ) - - args_tuple, _ = called["args"] - system_prompt = args_tuple[1] - assert "Calculator availability detected during ChemGraph initialization" in system_prompt - assert cg.default_calculator in system_prompt - assert cg.default_calculator in cg.available_calculators +from tests.test_graphs import ( + test_legacy_graph_constructor_is_called, + test_run_turn_workflow_tool_and_prompt_wiring, + test_single_agent_initialization_injects_calculator_availability, +) diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 86426c23..1bf66180 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -1,56 +1,178 @@ import pytest -from chemgraph.agent.llm_agent import ChemGraph +from langchain_core.messages import AIMessage -WORKFLOWS = [ - "single_agent", "multi_agent", "python_relp", "graspa", - "mock_agent", "single_agent_mcp", "graspa_mcp", -] +from chemgraph.agent import llm_agent +from chemgraph.agent.llm_agent import ChemGraph, TurnResult -@pytest.mark.parametrize("workflow_type", WORKFLOWS) -def test_constructor_is_called(monkeypatch, workflow_type): - called_data = {} - def fake_constructor(*args, **kwargs): - called_data["args"] = args - called_data["kwargs"] = kwargs +class _DummyTool: + def 
__init__(self, name): + self.name = name + + +def _tool_names(tools): + return [getattr(tool, "name", str(tool)) for tool in tools or []] + + +@pytest.mark.parametrize( + ("workflow_type", "constructor_attr", "kwargs"), + [ + ("multi_agent", "construct_multi_agent_graph", {}), + ( + "graspa_mcp", + "construct_graspa_mcp_graph", + {"tools": [_DummyTool("executor")], "data_tools": [_DummyTool("analysis")]}, + ), + ], +) +def test_legacy_graph_constructor_is_called( + monkeypatch, + tmp_path, + workflow_type, + constructor_attr, + kwargs, +): + called = {} + + def fake_constructor(*args, **constructor_kwargs): + called["args"] = args + called["kwargs"] = constructor_kwargs return f"WORKFLOW-SENTINEL-{workflow_type}" - mapping = { - "single_agent": "construct_single_agent_graph", - "multi_agent": "construct_multi_agent_graph", - "python_relp": "construct_relp_graph", - "graspa": "construct_graspa_graph", - "mock_agent": "construct_mock_agent_graph", - "single_agent_mcp": "construct_single_agent_mcp_graph", - "graspa_mcp": "construct_graspa_mcp_graph", - } - - constructor_attr = mapping[workflow_type] - - # Patch the graph constructor monkeypatch.setattr(f"chemgraph.agent.llm_agent.{constructor_attr}", fake_constructor) monkeypatch.setattr( "chemgraph.agent.llm_agent.load_openai_model", - lambda **kwargs: "FAKE_LLM", + lambda **_kwargs: "FAKE_LLM", ) - # Set up inputs - test_tools = ["DUMMY_TOOL"] - kwargs = {"tools": test_tools, "data_tools": test_tools} if "_mcp" in workflow_type else {} - - # Initialize - cg = ChemGraph(model_name="gpt-4o-mini", workflow_type=workflow_type, **kwargs) + cg = ChemGraph( + model_name="gpt-4o-mini", + workflow_type=workflow_type, + enable_memory=False, + log_dir=str(tmp_path / "logs"), + **kwargs, + ) - # Assertions assert cg.workflow == f"WORKFLOW-SENTINEL-{workflow_type}" - - # Check if LLM was passed as the first positional arg or a keyword arg - args = called_data.get("args", []) - kwargs_called = called_data.get("kwargs", {}) - - 
llm_passed = (len(args) > 0 and args[0] == "FAKE_LLM") or (kwargs_called.get("llm") == "FAKE_LLM") - assert llm_passed, f"LLM not passed to {workflow_type} constructor" - - # Specific check for MCP tool passing - if workflow_type == "graspa_mcp": - assert kwargs_called.get("executor_tools") == test_tools \ No newline at end of file + args = called.get("args", ()) + constructor_kwargs = called.get("kwargs", {}) + assert (args and args[0] == "FAKE_LLM") or constructor_kwargs.get("llm") == "FAKE_LLM" + + +@pytest.mark.parametrize( + ("workflow_type", "kwargs", "expected_extra_tools", "expected_prompt"), + [ + ("single_agent", {"tools": [_DummyTool("custom")]}, [], None), + ("python_relp", {"tools": [_DummyTool("custom")]}, ["python_repl", "calculator"], None), + ("graspa", {"tools": [_DummyTool("custom")]}, ["run_graspa"], None), + ( + "mock_agent", + {"tools": [_DummyTool("custom")]}, + [ + "file_to_atomsdata", + "smiles_to_atomsdata", + "run_ase", + "molecule_name_to_smiles", + "save_atomsdata_to_file", + "calculator", + ], + None, + ), + ( + "single_agent_mcp", + {"tools": [_DummyTool("mcp_tool")], "data_tools": [_DummyTool("data_tool")]}, + ["data_tool"], + None, + ), + ( + "rag_agent", + {"tools": [_DummyTool("custom")]}, + [ + "load_document", + "query_knowledge_base", + "file_to_atomsdata", + "smiles_to_coordinate_file", + "run_ase", + "molecule_name_to_smiles", + "save_atomsdata_to_file", + "calculator", + ], + llm_agent.rag_agent_prompt, + ), + ( + "single_agent_xanes", + {"tools": [_DummyTool("custom")]}, + [ + "molecule_name_to_smiles", + "smiles_to_coordinate_file", + "run_ase", + "run_xanes", + "fetch_xanes_data", + "plot_xanes_data", + ], + llm_agent.default_xanes_single_agent_prompt, + ), + ], +) +@pytest.mark.asyncio +async def test_run_turn_workflow_tool_and_prompt_wiring( + monkeypatch, + tmp_path, + workflow_type, + kwargs, + expected_extra_tools, + expected_prompt, +): + captured = {} + + async def fake_run_turn(**run_kwargs): + 
captured.update(run_kwargs) + return TurnResult( + final_text="done", + state={"messages": [AIMessage(content="done")]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=run_kwargs["thread_id"], + duration_s=0.0, + ) + + monkeypatch.setattr("chemgraph.agent.llm_agent.run_turn", fake_run_turn) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda **_kwargs: "FAKE_LLM", + ) + + cg = ChemGraph( + model_name="gpt-4o-mini", + workflow_type=workflow_type, + enable_memory=False, + log_dir=str(tmp_path / "logs"), + **kwargs, + ) + response = await cg.run("hello", config={"thread_id": "test-thread"}) + + assert response.content == "done" + tool_names = _tool_names(captured["tools"]) + assert tool_names[0] == list(kwargs["tools"])[0].name + for name in expected_extra_tools: + assert name in tool_names + if expected_prompt is not None: + assert captured["system_prompt"] == expected_prompt + + +def test_single_agent_initialization_injects_calculator_availability(monkeypatch, tmp_path): + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda **_kwargs: "FAKE_LLM", + ) + + cg = ChemGraph( + model_name="gpt-4o-mini", + workflow_type="single_agent", + enable_memory=False, + log_dir=str(tmp_path / "logs"), + ) + + assert "Calculator availability detected during ChemGraph initialization" in cg.system_prompt + assert cg.default_calculator in cg.system_prompt + assert cg.default_calculator in cg.available_calculators diff --git a/tests/test_llm_agent.py b/tests/test_llm_agent.py index 8d46339d..70aed013 100644 --- a/tests/test_llm_agent.py +++ b/tests/test_llm_agent.py @@ -10,13 +10,17 @@ def mock_llm(): return Mock() -def test_chemgraph_initialization(): +def test_chemgraph_initialization(tmp_path): with patch("chemgraph.agent.llm_agent.load_openai_model") as mock_load: mock_load.return_value = Mock() - agent = ChemGraph(model_name="gpt-4o-mini") + agent = ChemGraph( + model_name="gpt-4o-mini", + enable_memory=False, + 
log_dir=str(tmp_path / "logs"), + ) assert hasattr(agent, "workflow") -def test_agent_query(mock_llm): +def test_agent_query(mock_llm, tmp_path): with patch("chemgraph.agent.llm_agent.load_openai_model") as mock_load: # Set up the mock chain mock_chain = Mock() @@ -24,7 +28,11 @@ def test_agent_query(mock_llm): mock_llm.bind_tools.return_value = mock_chain mock_load.return_value = mock_llm - agent = ChemGraph(model_name="gpt-4o-mini") + agent = ChemGraph( + model_name="gpt-4o-mini", + enable_memory=False, + log_dir=str(tmp_path / "logs"), + ) response = asyncio.run(agent.run("What is the SMILES string for water?")) assert isinstance(response, AIMessage) assert response.content == "Test response" From 5e64525a9967a15e87cd3539d1d7f4660162d7dd Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 12:36:54 -0500 Subject: [PATCH 063/119] fix(academy): avoid stdlib shadowing in profiles --- src/chemgraph/academy/runtime/profiles/aurora.template.json | 3 +-- src/chemgraph/academy/runtime/profiles/polaris.template.json | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/chemgraph/academy/runtime/profiles/aurora.template.json b/src/chemgraph/academy/runtime/profiles/aurora.template.json index db59c939..3d6404de 100644 --- a/src/chemgraph/academy/runtime/profiles/aurora.template.json +++ b/src/chemgraph/academy/runtime/profiles/aurora.template.json @@ -14,8 +14,7 @@ "redis_protected_mode": "no", "mpiexec": "mpiexec", "pythonpath_entries": [ - "/flare/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph/src", - "/flare/${ALCF_PROJECT}/${ALCF_USER}/academy" + "/flare/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph/src" ], "path_entries": [ "/flare/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", diff --git a/src/chemgraph/academy/runtime/profiles/polaris.template.json b/src/chemgraph/academy/runtime/profiles/polaris.template.json index 0737fefc..c1cf3dc1 100644 --- a/src/chemgraph/academy/runtime/profiles/polaris.template.json +++ 
b/src/chemgraph/academy/runtime/profiles/polaris.template.json @@ -14,8 +14,7 @@ "redis_protected_mode": "no", "mpiexec": "mpiexec", "pythonpath_entries": [ - "/eagle/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph/src", - "/eagle/${ALCF_PROJECT}/${ALCF_USER}/academy" + "/eagle/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph/src" ], "path_entries": [ "/eagle/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", From 23dab005f592798c58233a43074927cb38e8d05c Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 13:15:15 -0500 Subject: [PATCH 064/119] fix(dashboard): show daemon workflow turns --- src/chemgraph/academy/dashboard/static/app.js | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/chemgraph/academy/dashboard/static/app.js b/src/chemgraph/academy/dashboard/static/app.js index 0237d2e8..1796c0e8 100644 --- a/src/chemgraph/academy/dashboard/static/app.js +++ b/src/chemgraph/academy/dashboard/static/app.js @@ -2602,7 +2602,8 @@ } function isWorkflowEvent(event) { - return [ + const p = event.payload || {}; + const workflowNames = [ 'run_started', 'run_finished', 'workflow_started', @@ -2614,7 +2615,18 @@ 'tool_call_started', 'tool_call_finished', 'tool_call_failed', - ].includes(event.event) && Boolean(event.payload?.nested || event.payload?.runtime); + ]; + if (!workflowNames.includes(event.event)) return false; + return Boolean( + p.nested + || p.runtime + || p.workflow_type + || p.workflow_span_id + || p.span_id + || p.parent_span_id + || p.thread_id + || (event.agent_id && p.round !== undefined && p.round !== null) + ); } function workflowSpanId(event) { @@ -2635,10 +2647,18 @@ function workflowRootSpanId(event) { const p = event.payload || {}; if (p.workflow_span_id) return p.workflow_span_id; + if (p.thread_id) return p.thread_id; if ((event.event === 'workflow_started' || event.event === 'workflow_finished') && p.span_id) { return p.span_id; } - return p.parent_span_id || p.span_id || event.correlation_id || null; + if 
(p.parent_span_id || p.span_id || event.correlation_id) { + return p.parent_span_id || p.span_id || event.correlation_id; + } + const agentId = workflowAgentId(event); + if (agentId && p.round !== undefined && p.round !== null) { + return `${agentId}-round-${p.round}`; + } + return null; } function workflowEventsForSpan(spanId) { From c529682e11116a04eb6d219d11b2972263495eec Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 13:51:41 -0500 Subject: [PATCH 065/119] refactor(models): share LLM endpoint settings --- src/chemgraph/academy/core/__init__.py | 4 - src/chemgraph/academy/core/agent.py | 2 +- src/chemgraph/academy/core/lm.py | 71 -------- src/chemgraph/academy/core/turn.py | 2 +- .../academy/observability/run_artifacts.py | 2 +- src/chemgraph/academy/runtime/daemon.py | 4 +- src/chemgraph/agent/llm_agent.py | 47 ++---- src/chemgraph/models/loader.py | 22 ++- src/chemgraph/models/settings.py | 157 ++++++++++++++++++ tests/test_academy_reasoning_phase2.py | 2 +- tests/test_llm_agent.py | 7 +- 11 files changed, 200 insertions(+), 120 deletions(-) delete mode 100644 src/chemgraph/academy/core/lm.py create mode 100644 src/chemgraph/models/settings.py diff --git a/src/chemgraph/academy/core/__init__.py b/src/chemgraph/academy/core/__init__.py index 5a6cc6b0..bf12083b 100644 --- a/src/chemgraph/academy/core/__init__.py +++ b/src/chemgraph/academy/core/__init__.py @@ -8,8 +8,6 @@ from chemgraph.academy.core.campaign import ToolSpec from chemgraph.academy.core.campaign import load_campaign from chemgraph.academy.core.campaign import resolve_campaign_resources -from chemgraph.academy.core.lm import LLMSettings -from chemgraph.academy.core.lm import load_lm_config from chemgraph.academy.core.prompt import PromptProfile from chemgraph.academy.core.prompt import load_prompt_profile from chemgraph.academy.core.turn import ReasoningTurnResult @@ -20,13 +18,11 @@ "ChemGraphCampaign", "ChemGraphDaemonConfig", "ChemGraphLogicalAgent", - "LLMSettings", 
"PromptProfile", "ReasoningTurnResult", "ResourceSpec", "ToolSpec", "load_campaign", - "load_lm_config", "load_prompt_profile", "resolve_campaign_resources", "run_academy_turn", diff --git a/src/chemgraph/academy/core/agent.py b/src/chemgraph/academy/core/agent.py index 94e03afd..f28ae9be 100644 --- a/src/chemgraph/academy/core/agent.py +++ b/src/chemgraph/academy/core/agent.py @@ -23,8 +23,8 @@ from chemgraph.academy.core.turn import run_academy_turn from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.academy.core.campaign import ChemGraphCampaign -from chemgraph.academy.core.lm import LLMSettings from chemgraph.academy.core.prompt import PromptProfile +from chemgraph.models.settings import LLMSettings class ChemGraphLogicalAgent(Agent): diff --git a/src/chemgraph/academy/core/lm.py b/src/chemgraph/academy/core/lm.py deleted file mode 100644 index 52f886c0..00000000 --- a/src/chemgraph/academy/core/lm.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -import dataclasses -import json -from pathlib import Path -from typing import Any - - -@dataclasses.dataclass(frozen=True) -class LLMSettings: - """Configuration for an OpenAI-compatible chat-completions endpoint.""" - - base_url: str - model: str - provider: str - timeout_s: float - temperature: float - max_tokens: int - max_retries: int - retry_delay_s: float - api_key: str | None = None - user: str | None = None - - -def load_lm_config(path: str | Path) -> LLMSettings: - """Load LM settings from a JSON config file.""" - config_path = Path(path) - data = json.loads(config_path.read_text(encoding="utf-8")) - if not isinstance(data, dict): - raise ValueError(f"LM config must be a JSON object: {config_path}") - return _settings_from_mapping(data, source=str(config_path)) - - -def _settings_from_mapping(data: dict[str, Any], *, source: str) -> LLMSettings: - required = ( - "base_url", - "model", - "provider", - "timeout_s", - "temperature", - "max_tokens", - 
"max_retries", - "retry_delay_s", - ) - missing = [name for name in required if data.get(name) is None] - if missing: - raise ValueError(f"LM config {source} is missing required keys: {missing}") - - provider = str(data["provider"]) - if provider != "openai_compatible_tools": - raise ValueError( - f"LM config {source} provider must be 'openai_compatible_tools'", - ) - if not data.get("api_key"): - raise ValueError( - f"LM config {source} requires api_key; use 'dummy' for Argo shim " - "routes that do not require auth", - ) - - return LLMSettings( - base_url=str(data["base_url"]), - model=str(data["model"]), - provider=provider, - api_key=str(data["api_key"]), - user=str(data["user"]) if data.get("user") else None, - timeout_s=float(data["timeout_s"]), - temperature=float(data["temperature"]), - max_tokens=int(data["max_tokens"]), - max_retries=int(data["max_retries"]), - retry_delay_s=float(data["retry_delay_s"]), - ) diff --git a/src/chemgraph/academy/core/turn.py b/src/chemgraph/academy/core/turn.py index 24d4baa2..8a61bdaa 100644 --- a/src/chemgraph/academy/core/turn.py +++ b/src/chemgraph/academy/core/turn.py @@ -10,10 +10,10 @@ from langchain_core.tools import BaseTool from chemgraph.academy.core.campaign import ChemGraphAgentSpec, ChemGraphCampaign from chemgraph.academy.core.campaign import visible_resources_payload -from chemgraph.academy.core.lm import LLMSettings from chemgraph.academy.core.prompt import PromptProfile from chemgraph.academy.observability.run_files import read_json_file from chemgraph.agent.llm_agent import run_turn +from chemgraph.models.settings import LLMSettings TraceFn = Callable[[str, dict[str, Any]], None] ACTION_TOOL_NAMES = frozenset({"send_message", "ask_peer", "submit_result", "finish_turn"}) diff --git a/src/chemgraph/academy/observability/run_artifacts.py b/src/chemgraph/academy/observability/run_artifacts.py index e884de3b..b2e4fdb5 100644 --- a/src/chemgraph/academy/observability/run_artifacts.py +++ 
b/src/chemgraph/academy/observability/run_artifacts.py @@ -17,7 +17,7 @@ from chemgraph.academy.core.campaign import ChemGraphCampaign from chemgraph.academy.core.campaign import ChemGraphDaemonConfig from chemgraph.academy.runtime.mpi import append_system_trace -from chemgraph.academy.core.lm import LLMSettings +from chemgraph.models.settings import LLMSettings def write_run_artifacts(run_dir: str | pathlib.Path) -> dict[str, Any]: diff --git a/src/chemgraph/academy/runtime/daemon.py b/src/chemgraph/academy/runtime/daemon.py index cf7735db..751d5925 100644 --- a/src/chemgraph/academy/runtime/daemon.py +++ b/src/chemgraph/academy/runtime/daemon.py @@ -31,8 +31,8 @@ from chemgraph.academy.runtime.mpi import placement_payload from chemgraph.academy.runtime.mpi import rank_from_env from chemgraph.academy.core.agent import ChemGraphLogicalAgent -from chemgraph.academy.core.lm import load_lm_config from chemgraph.academy.core.prompt import load_prompt_profile +from chemgraph.models.settings import load_lm_settings from chemgraph.mcp.fastmcp_client import ( FastMCPExecutionConfig, build_fastmcp_tool_invoker, @@ -41,7 +41,7 @@ async def run_daemon(config: ChemGraphDaemonConfig) -> int: config.run_dir.mkdir(parents=True, exist_ok=True) - llm_settings = load_lm_config(config.lm_config) + llm_settings = load_lm_settings(config.lm_config) campaign = resolve_campaign_resources( load_campaign(config.campaign_config), config.run_dir, diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 10a18902..7791ed0e 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -14,6 +14,7 @@ from chemgraph.models.anthropic import load_anthropic_model from chemgraph.models.gemini import load_gemini_model from chemgraph.models.groq import load_groq_model +from chemgraph.models.loader import load_chat_model from chemgraph.models.supported_models import ( supported_openai_models, supported_ollama_models, @@ -23,6 +24,7 @@ 
supported_gemini_models, ) +from chemgraph.models.settings import LLMSettings from chemgraph.schemas.ase_input import ( get_available_calculator_names, get_calculator_selection_context, @@ -485,41 +487,18 @@ def _load_turn_llm( argo_user: str | None, ) -> Any: temperature = 0.0 - if model_name in supported_openai_models or model_name in supported_argo_models: - kwargs = { - "model_name": model_name, - "temperature": temperature, - "base_url": base_url, - } - if argo_user is not None: - kwargs["argo_user"] = argo_user - return load_openai_model(**kwargs) - if model_name in supported_ollama_models: - return load_ollama_model(model_name=model_name, temperature=temperature) - if model_name in supported_alcf_models: - return load_alcf_model( - model_name=model_name, - base_url=base_url, - api_key=api_key, - ) - if model_name in supported_anthropic_models: - return load_anthropic_model( - model_name=model_name, - api_key=api_key, - temperature=temperature, - ) - if model_name in supported_gemini_models: - return load_gemini_model( - model_name=model_name, - api_key=api_key, - temperature=temperature, - ) - if model_name.startswith("groq:"): - return load_groq_model( - model_name=model_name, - api_key=api_key, - temperature=temperature, + try: + return load_chat_model( + settings=LLMSettings( + model=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + temperature=temperature, + ), ) + except ValueError: + pass endpoint = os.getenv("VLLM_BASE_URL", base_url or "") key = os.getenv("OPENAI_API_KEY", api_key or "dummy_vllm_key") diff --git a/src/chemgraph/models/loader.py b/src/chemgraph/models/loader.py index 2e2968d2..64f0f105 100644 --- a/src/chemgraph/models/loader.py +++ b/src/chemgraph/models/loader.py @@ -14,29 +14,31 @@ from chemgraph.models.groq import load_groq_model from chemgraph.models.local_model import load_ollama_model from chemgraph.models.openai import load_openai_model +from chemgraph.models.settings import LLMSettings from 
chemgraph.models.supported_models import ( supported_alcf_models, supported_anthropic_models, supported_argo_models, supported_gemini_models, - supported_ollama_models, supported_openai_models, ) def load_chat_model( - model_name: str, + model_name: str | None = None, temperature: float = 0.0, base_url: Optional[str] = None, api_key: Optional[str] = None, argo_user: Optional[str] = None, + *, + settings: LLMSettings | None = None, ): """Load a LangChain chat model by provider auto-detection. Parameters ---------- - model_name : str + model_name : str, optional Model name from any supported provider list. temperature : float Sampling temperature (default 0.0 for deterministic output). @@ -46,6 +48,9 @@ def load_chat_model( API key override (falls back to environment variables). argo_user : str, optional Argo user identifier. + settings : LLMSettings, optional + Canonical endpoint settings. When provided, this overrides + model_name/base_url/api_key/argo_user. Returns ------- @@ -57,6 +62,17 @@ def load_chat_model( ValueError If the model name is not found in any supported provider list. 
""" + if settings is not None: + model_name = settings.model + base_url = settings.base_url + api_key = settings.api_key + argo_user = settings.argo_user + if settings.temperature is not None: + temperature = settings.temperature + + if model_name is None: + raise ValueError("load_chat_model requires model_name or settings") + if model_name in supported_openai_models or model_name in supported_argo_models: kwargs = { "model_name": model_name, diff --git a/src/chemgraph/models/settings.py b/src/chemgraph/models/settings.py new file mode 100644 index 00000000..e24951bf --- /dev/null +++ b/src/chemgraph/models/settings.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import dataclasses +import json +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib # type: ignore[no-redef] + + +@dataclasses.dataclass(frozen=True, init=False) +class LLMSettings: + """Fully resolved description of one LLM endpoint.""" + + model: str + base_url: str | None = None + api_key: str | None = None + argo_user: str | None = None + provider: str | None = None + timeout_s: float | None = None + temperature: float | None = None + max_tokens: int | None = None + max_retries: int | None = None + retry_delay_s: float | None = None + + def __init__( + self, + model: str, + base_url: str | None = None, + api_key: str | None = None, + argo_user: str | None = None, + provider: str | None = None, + timeout_s: float | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + max_retries: int | None = None, + retry_delay_s: float | None = None, + user: str | None = None, + ) -> None: + object.__setattr__(self, "model", model) + object.__setattr__(self, "base_url", base_url) + object.__setattr__(self, "api_key", api_key) + object.__setattr__(self, "argo_user", argo_user or user) + object.__setattr__(self, "provider", provider) + object.__setattr__(self, 
"timeout_s", timeout_s) + object.__setattr__(self, "temperature", temperature) + object.__setattr__(self, "max_tokens", max_tokens) + object.__setattr__(self, "max_retries", max_retries) + object.__setattr__(self, "retry_delay_s", retry_delay_s) + + @property + def user(self) -> str | None: + """Backward-compatible academy name for Argo user metadata.""" + return self.argo_user + + +def load_lm_settings(source: str | Path | Mapping[str, Any]) -> LLMSettings: + """Build LLMSettings from a JSON file, TOML file, or already-parsed dict.""" + if isinstance(source, Mapping): + return _from_mapping(source) + + path = Path(source) + text = path.read_text(encoding="utf-8") + if path.suffix.lower() == ".toml": + raw = tomllib.loads(text) + return _from_mapping(_extract_endpoint_from_cli_toml(raw)) + return _from_mapping(json.loads(text)) + + +def _from_mapping(data: Mapping[str, Any]) -> LLMSettings: + if not isinstance(data, Mapping): + raise ValueError("LM config must be a mapping/object") + + model = data.get("model") or data.get("model_name") + if not isinstance(model, str) or not model: + raise ValueError("LM config requires a non-empty 'model' field") + + provider = data.get("provider") + if provider is not None and provider != "openai_compatible_tools": + raise ValueError( + "LM config 'provider' must be 'openai_compatible_tools' or absent", + ) + + api_key = data.get("api_key") + if provider == "openai_compatible_tools" and not api_key: + raise ValueError( + "openai_compatible_tools provider requires api_key " + "(use 'dummy' for Argo shim routes that ignore auth)", + ) + + return LLMSettings( + model=str(model), + base_url=_str_or_none(data.get("base_url")), + api_key=_str_or_none(api_key), + argo_user=_str_or_none(data.get("user") or data.get("argo_user")), + provider=_str_or_none(provider), + timeout_s=_float_or_none(data.get("timeout_s")), + temperature=_float_or_none(data.get("temperature")), + max_tokens=_int_or_none(data.get("max_tokens")), + 
max_retries=_int_or_none(data.get("max_retries")), + retry_delay_s=_float_or_none(data.get("retry_delay_s")), + ) + + +def _extract_endpoint_from_cli_toml(raw: Mapping[str, Any]) -> dict[str, Any]: + """Pull LLM endpoint fields out of the CLI's nested TOML structure.""" + general = raw.get("general") or {} + api = raw.get("api") or {} + model = general.get("model") + argo_user = general.get("argo_user") or (api.get("argo") or {}).get("user") + + base_url = None + if isinstance(model, str): + if model.startswith("argo:"): + base_url = (api.get("argo") or {}).get("base_url") + else: + for section_name in ("openai", "anthropic", "gemini", "alcf", "ollama"): + section = api.get(section_name) or {} + if section.get("base_url"): + base_url = section["base_url"] + break + + return { + "model": model, + "base_url": base_url, + "argo_user": argo_user, + "api_key": (api.get(_provider_section_for(model)) or {}).get("api_key"), + } + + +def _provider_section_for(model: Any) -> str: + if isinstance(model, str): + if model.startswith("argo:"): + return "argo" + if model.startswith("groq:"): + return "groq" + return "openai" + + +def _str_or_none(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value or None + return str(value) or None + + +def _float_or_none(value: Any) -> float | None: + return None if value is None else float(value) + + +def _int_or_none(value: Any) -> int | None: + return None if value is None else int(value) diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py index fa140c4a..e4180474 100644 --- a/tests/test_academy_reasoning_phase2.py +++ b/tests/test_academy_reasoning_phase2.py @@ -13,11 +13,11 @@ from chemgraph.academy.core.agent import ChemGraphLogicalAgent from chemgraph.academy.core.campaign import ChemGraphAgentSpec, ChemGraphCampaign from chemgraph.academy.core.campaign import ResourceSpec, resolve_campaign_resources -from chemgraph.academy.core.lm import 
LLMSettings from chemgraph.academy.core.prompt import PromptProfile, PromptStateLimits from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools from chemgraph.academy.core.turn import ReasoningTurnResult, build_peer_status from chemgraph.agent.llm_agent import TurnResult +from chemgraph.models.settings import LLMSettings def _agent_spec() -> ChemGraphAgentSpec: diff --git a/tests/test_llm_agent.py b/tests/test_llm_agent.py index 70aed013..32b72b87 100644 --- a/tests/test_llm_agent.py +++ b/tests/test_llm_agent.py @@ -21,12 +21,15 @@ def test_chemgraph_initialization(tmp_path): assert hasattr(agent, "workflow") def test_agent_query(mock_llm, tmp_path): - with patch("chemgraph.agent.llm_agent.load_openai_model") as mock_load: + with patch("chemgraph.agent.llm_agent.load_openai_model") as mock_init_load, patch( + "chemgraph.models.loader.load_openai_model" + ) as mock_turn_load: # Set up the mock chain mock_chain = Mock() mock_chain.invoke.return_value = AIMessage(content="Test response") mock_llm.bind_tools.return_value = mock_chain - mock_load.return_value = mock_llm + mock_init_load.return_value = mock_llm + mock_turn_load.return_value = mock_llm agent = ChemGraph( model_name="gpt-4o-mini", From f67b3401c18ba1aaba430e1e7f777fe6368fb4a0 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 14:07:45 -0500 Subject: [PATCH 066/119] chore(academy): remove stale built-in listings --- src/chemgraph/academy/examples/__init__.py | 5 ----- src/chemgraph/cli/main.py | 13 +------------ 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/src/chemgraph/academy/examples/__init__.py b/src/chemgraph/academy/examples/__init__.py index 3f376834..52a4ba45 100644 --- a/src/chemgraph/academy/examples/__init__.py +++ b/src/chemgraph/academy/examples/__init__.py @@ -12,7 +12,6 @@ } BUILTIN_LM_CONFIG_TEMPLATES = { - 'argo-gpt54-template': f'{EXAMPLE_002}/lm_config.template.json', 'argo-gpt54-mace-template': f'{EXAMPLE_002}/lm_config.template.json', } @@ 
-63,10 +62,6 @@ def list_builtin_campaigns() -> list[str]: return sorted(BUILTIN_CAMPAIGNS) -def list_builtin_lm_config_templates() -> list[str]: - return sorted(BUILTIN_LM_CONFIG_TEMPLATES) - - def campaign_launch_defaults(campaign: str) -> CampaignLaunchDefaults: try: return BUILTIN_CAMPAIGN_LAUNCH_DEFAULTS[campaign] diff --git a/src/chemgraph/cli/main.py b/src/chemgraph/cli/main.py index 7e23b346..94af80fc 100644 --- a/src/chemgraph/cli/main.py +++ b/src/chemgraph/cli/main.py @@ -289,10 +289,6 @@ def create_argument_parser() -> argparse.ArgumentParser: "campaigns", help="List built-in ChemGraph Academy campaign specs.", ) - academy_sub.add_parser( - "logical-agent-configs", - help="List built-in ChemGraph Academy logical-agent prompt configs.", - ) # ---- Legacy fallback args ------------------------------------------- # Also add run args to the top-level parser so that @@ -589,16 +585,9 @@ def _handle_academy(args: argparse.Namespace) -> None: for name in list_builtin_campaigns(): console.print(name) return - if command == "logical-agent-configs": - from chemgraph.academy.examples import list_builtin_logical_agent_configs - - for name in list_builtin_logical_agent_configs(): - console.print(name) - return console.print( "Usage: chemgraph academy " - "{mpi-daemon,run-compute,dashboard,campaigns," - "logical-agent-configs}.", + "{mpi-daemon,run-compute,dashboard,campaigns}.", ) From 2cee93a9d082cffc73f6b8858e625217f8b08478 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 14:12:24 -0500 Subject: [PATCH 067/119] chore(academy): remove console script entrypoints --- pyproject.toml | 3 --- src/chemgraph/academy/runtime/dashboard_launcher.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e667eab2..842d495c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,9 +96,6 @@ rag = [ [project.scripts] chemgraph = "chemgraph.cli:main" chemgraph-eval = "chemgraph.eval.cli:main" 
-chemgraph-academy-run = "chemgraph.academy.runtime.compute_launcher:main" -chemgraph-academy-dashboard = "chemgraph.academy.runtime.dashboard_launcher:main" -chemgraph-dashboard = "chemgraph.academy.dashboard:main" [tool.setuptools.packages.find] where = ["src/"] diff --git a/src/chemgraph/academy/runtime/dashboard_launcher.py b/src/chemgraph/academy/runtime/dashboard_launcher.py index f9735871..b38ac733 100644 --- a/src/chemgraph/academy/runtime/dashboard_launcher.py +++ b/src/chemgraph/academy/runtime/dashboard_launcher.py @@ -107,7 +107,7 @@ def loop() -> None: def compute_lines(profile: SystemProfile, wrapper_path: str, run_id: str, campaign: str) -> list[str]: lines = [" module use /soft/modulefiles", " module load conda", " conda activate base"] if profile.name == "polaris" else [" module load frameworks"] - return lines + [f" source {profile.remote_root}/venvs/academy-swarm/bin/activate", f" export PATH={profile.remote_root}/bin:$PATH", " chemgraph-academy-run \\", f" --system {profile.name} \\", f" --run-id {run_id} \\", f" --campaign {campaign}", "", "If PATH is not configured, use:", f" {wrapper_path} \\", f" --system {profile.name} \\", f" --run-id {run_id} \\", f" --campaign {campaign}"] + return lines + [f" source {profile.remote_root}/venvs/academy-swarm/bin/activate", f" export PATH={profile.remote_root}/bin:$PATH", " chemgraph academy run-compute \\", f" --system {profile.name} \\", f" --run-id {run_id} \\", f" --campaign {campaign}", "", "If PATH is not configured, use:", f" {wrapper_path} \\", f" --system {profile.name} \\", f" --run-id {run_id} \\", f" --campaign {campaign}"] def main() -> int: args = parse_args() @@ -162,7 +162,7 @@ def main() -> int: relay_host = wait_relay(profile, remote_host, control_path, relay_port, relay_process, Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log")) lm_base_url = f"http://{relay_host}:{relay_port}/argoapi/v1" if relay_host else str(args.lm_base_url) print(f"Compute-node LM URL: {lm_base_url}", 
flush=True) - metadata = {"created_at": time.time(), "created_by": "chemgraph-academy-dashboard", "run_id": args.run_id, "system": profile.name, "campaign": args.campaign, "remote_run_dir": remote_run_dir, "remote_host": remote_host, "lm_connect": args.lm_connect, "lm_base_url": lm_base_url, "workspace_root": profile.remote_root, "academy_repo_root": profile.academy_repo_root, "chemgraph_repo_root": profile.repo_root} + metadata = {"created_at": time.time(), "created_by": "chemgraph academy dashboard", "run_id": args.run_id, "system": profile.name, "campaign": args.campaign, "remote_run_dir": remote_run_dir, "remote_host": remote_host, "lm_connect": args.lm_connect, "lm_base_url": lm_base_url, "workspace_root": profile.remote_root, "academy_repo_root": profile.academy_repo_root, "chemgraph_repo_root": profile.repo_root} if relay_host: metadata.update({"relay_host": relay_host, "relay_port": relay_port}) print(f"Writing run metadata: {remote_host}:{remote_run_dir}/dashboard_metadata.json", flush=True) From 7e5f91bc0fda04ceb51e43a6a90e3d7672312d43 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 15:53:51 -0500 Subject: [PATCH 068/119] refactor(academy): launch campaign MCP tools as servers --- pyproject.toml | 1 + src/chemgraph/academy/__init__.py | 4 +- src/chemgraph/academy/core/__init__.py | 4 +- src/chemgraph/academy/core/agent.py | 14 +- src/chemgraph/academy/core/campaign.py | 91 ++--- src/chemgraph/academy/core/tools.py | 103 +---- src/chemgraph/academy/core/turn.py | 9 +- .../campaign.jsonc | 49 +-- .../academy/observability/run_artifacts.py | 4 +- src/chemgraph/academy/runtime/daemon.py | 272 +++++++------ .../academy/runtime/mcp_supervisor.py | 274 +++++++++++++ src/chemgraph/mcp/fastmcp_client.py | 360 ------------------ tests/test_academy_campaign.py | 94 ++++- tests/test_academy_mcp_supervisor.py | 125 ++++++ tests/test_academy_reasoning_phase2.py | 5 +- tests/test_tool_adapter_validation.py | 3 +- 16 files changed, 710 insertions(+), 702 
deletions(-) create mode 100644 src/chemgraph/academy/runtime/mcp_supervisor.py delete mode 100644 src/chemgraph/mcp/fastmcp_client.py create mode 100644 tests/test_academy_mcp_supervisor.py diff --git a/pyproject.toml b/pyproject.toml index 842d495c..7ea2ed94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ globus_compute = [ ] academy = [ "academy-py", + "httpx", "redis", ] xanes = [ diff --git a/src/chemgraph/academy/__init__.py b/src/chemgraph/academy/__init__.py index a5ad2313..fd0a6c4c 100644 --- a/src/chemgraph/academy/__init__.py +++ b/src/chemgraph/academy/__init__.py @@ -12,8 +12,8 @@ from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.academy.core.campaign import ChemGraphCampaign from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.core.campaign import MCPServerSpec from chemgraph.academy.core.campaign import ResourceSpec -from chemgraph.academy.core.campaign import ToolSpec from chemgraph.academy.core.campaign import load_campaign from chemgraph.academy.core.campaign import resolve_campaign_resources from chemgraph.academy.observability.event_log import CampaignEvent @@ -28,11 +28,11 @@ "ChemGraphCampaign", "ChemGraphDaemonConfig", "EventLog", + "MCPServerSpec", "PromptProfile", "ResourceSpec", "ChemGraphLogicalAgent", "load_campaign", "load_prompt_profile", "resolve_campaign_resources", - "ToolSpec", ] diff --git a/src/chemgraph/academy/core/__init__.py b/src/chemgraph/academy/core/__init__.py index bf12083b..b47c1c18 100644 --- a/src/chemgraph/academy/core/__init__.py +++ b/src/chemgraph/academy/core/__init__.py @@ -4,8 +4,8 @@ from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.academy.core.campaign import ChemGraphCampaign from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.core.campaign import MCPServerSpec from chemgraph.academy.core.campaign import ResourceSpec -from chemgraph.academy.core.campaign 
import ToolSpec from chemgraph.academy.core.campaign import load_campaign from chemgraph.academy.core.campaign import resolve_campaign_resources from chemgraph.academy.core.prompt import PromptProfile @@ -18,10 +18,10 @@ "ChemGraphCampaign", "ChemGraphDaemonConfig", "ChemGraphLogicalAgent", + "MCPServerSpec", "PromptProfile", "ReasoningTurnResult", "ResourceSpec", - "ToolSpec", "load_campaign", "load_prompt_profile", "resolve_campaign_resources", diff --git a/src/chemgraph/academy/core/agent.py b/src/chemgraph/academy/core/agent.py index f28ae9be..6f2c81ca 100644 --- a/src/chemgraph/academy/core/agent.py +++ b/src/chemgraph/academy/core/agent.py @@ -4,7 +4,7 @@ import asyncio import time -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from pathlib import Path from typing import Any @@ -12,10 +12,8 @@ from academy.agent import loop from academy.handle import Handle from academy.identifier import AgentId +from langchain_core.tools import BaseTool -from chemgraph.mcp.fastmcp_client import ( - FastMCPToolInvoker, -) from chemgraph.academy.core.peer_protocol import validate_message from chemgraph.academy.observability.event_log import EventLog from chemgraph.academy.observability.run_artifacts import write_status_snapshot @@ -39,7 +37,7 @@ def __init__( prompt_profile: PromptProfile, run_dir: Path, max_decisions: int, - tool_invoker: FastMCPToolInvoker, + external_tools: Sequence[BaseTool] = (), peer_agent_ids: Mapping[str, AgentId[Any]] | None = None, placement: dict[str, Any] | None = None, poll_timeout_s: float = 2.0, @@ -53,7 +51,7 @@ def __init__( self.prompt_profile = prompt_profile self.run_dir = run_dir self.max_decisions = max_decisions - self.tool_invoker = tool_invoker + self.external_tools = list(external_tools) self.peer_agent_ids = dict(peer_agent_ids or {}) self.placement = placement or {} self.poll_timeout_s = poll_timeout_s @@ -82,7 +80,7 @@ async def agent_on_startup(self) -> None: 'agent_started', { 'role': 
self.spec.role, - 'tool_names': list(self.spec.tool_names), + 'tool_names': [tool.name for tool in self.external_tools], 'allowed_peers': list(self.spec.allowed_peers), 'placement': self.placement, **self.placement, @@ -195,7 +193,7 @@ async def _reasoning_round(self) -> bool: tools = await build_chemgraph_reasoning_tools( spec=self.spec, run_dir=self.run_dir, - tool_invoker=self.tool_invoker, + external_tools=self.external_tools, peer_names=self.peer_names, peer_handles=self.peer_handles, outbox=self.outbox, diff --git a/src/chemgraph/academy/core/campaign.py b/src/chemgraph/academy/core/campaign.py index baaa570b..4e514f1c 100644 --- a/src/chemgraph/academy/core/campaign.py +++ b/src/chemgraph/academy/core/campaign.py @@ -7,7 +7,7 @@ from typing import Any from chemgraph.academy.examples import resolve_builtin_campaign -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator _REMOVED_CAMPAIGN_FIELDS = frozenset( @@ -29,22 +29,28 @@ ) -class ToolSpec(BaseModel): - """Campaign-declared external tool available to agents.""" +class MCPServerSpec(BaseModel): + """Campaign-declared MCP server subprocess available to agents.""" model_config = ConfigDict(extra='forbid') - name: str - module: str - tool: str - description: str = '' + name: str = Field(min_length=1) + command: str = Field( + min_length=1, + description=( + "Shell command to launch the MCP server. Tokens after the first " + "are arguments. Do not include --transport/--host/--port; the " + "supervisor adds them." 
+ ), + ) + env: dict[str, str] = Field(default_factory=dict) - @field_validator('name', 'module', 'tool') + @field_validator('name', 'command') @classmethod def _non_empty(cls, value: str) -> str: value = value.strip() if not value: - raise ValueError('tool spec fields must be non-empty strings') + raise ValueError('field must be non-empty') return value @@ -99,13 +105,9 @@ class ChemGraphAgentSpec: role: str mission: str allowed_peers: tuple[str, ...] - tools: tuple[ToolSpec, ...] + mcp_servers: tuple[str, ...] = () resources: tuple[str, ...] = () - @property - def tool_names(self) -> tuple[str, ...]: - return tuple(tool.name for tool in self.tools) - @dataclasses.dataclass(frozen=True) class ChemGraphCampaign: @@ -114,7 +116,7 @@ class ChemGraphCampaign: initial_agent: str prompt_profile: pathlib.Path agents: tuple[ChemGraphAgentSpec, ...] - tool_catalog: tuple[ToolSpec, ...] = () + mcp_servers: tuple[MCPServerSpec, ...] = () resources: Mapping[str, ResourceSpec] = dataclasses.field(default_factory=dict) @@ -223,7 +225,10 @@ def load_campaign(path: str | pathlib.Path) -> ChemGraphCampaign: field_name='prompt_profile', ) - tool_catalog = _load_tool_catalog(data) + mcp_servers = tuple( + MCPServerSpec.model_validate(raw) + for raw in data.get('mcp_servers', ()) + ) resources = { name: _resolve_resource_spec(raw, campaign_path=path) for name, raw in dict(data.get('resources', {})).items() @@ -236,7 +241,7 @@ def load_campaign(path: str | pathlib.Path) -> ChemGraphCampaign: role=item['role'], mission=item['mission'], allowed_peers=tuple(item.get('allowed_peers', ())), - tools=_load_declared_tools(item, tool_catalog), + mcp_servers=tuple(item.get('mcp_servers', ())), resources=tuple(item.get('resources', ())), ), ) @@ -246,7 +251,7 @@ def load_campaign(path: str | pathlib.Path) -> ChemGraphCampaign: initial_agent=data.get('initial_agent', agents[0].name), prompt_profile=prompt_profile, agents=tuple(agents), - tool_catalog=tuple(tool_catalog.values()), + 
mcp_servers=mcp_servers, resources=resources, ) @@ -342,44 +347,6 @@ def _resolve_campaign_relative_path( return path.resolve() -def _load_tool_catalog(data: Mapping[str, Any]) -> dict[str, ToolSpec]: - catalog: dict[str, ToolSpec] = {} - for raw in data.get('tools', ()): - if not isinstance(raw, dict): - raise RuntimeError('campaign top-level tools[] entries must be objects') - spec = ToolSpec.model_validate(raw) - if spec.name in catalog: - raise RuntimeError(f'duplicate campaign tool name: {spec.name}') - catalog[spec.name] = spec - return catalog - - -def _load_declared_tools( - item: Mapping[str, Any], - catalog: Mapping[str, ToolSpec], -) -> tuple[ToolSpec, ...]: - raw_tools = item.get('tools') - if raw_tools is None: - raw_tools = item.get('tool_names', ()) - tools: list[ToolSpec] = [] - for raw in raw_tools: - if isinstance(raw, str): - try: - tools.append(catalog[raw]) - except KeyError as exc: - raise RuntimeError( - f'agent {item.get("name")!r} references unknown campaign tool {raw!r}; ' - 'declare it in top-level tools[] or inline as a FastMCP ToolSpec object', - ) from exc - elif isinstance(raw, dict): - tools.append(ToolSpec.model_validate(raw)) - else: - raise RuntimeError( - f'agent {item.get("name")!r} tools[] entries must be strings or objects', - ) - return tuple(tools) - - def validate_campaign(campaign: ChemGraphCampaign, agent_count: int) -> None: if len(campaign.agents) != agent_count: raise RuntimeError( @@ -393,6 +360,10 @@ def validate_campaign(campaign: ChemGraphCampaign, agent_count: int) -> None: raise RuntimeError( f'initial_agent {campaign.initial_agent!r} is not an agent', ) + server_names = [server.name for server in campaign.mcp_servers] + if len(set(server_names)) != len(server_names): + raise RuntimeError('campaign MCP server names must be unique') + declared_servers = set(server_names) for agent in campaign.agents: unknown = sorted(set(agent.allowed_peers).difference(names)) if unknown: @@ -401,9 +372,11 @@ def 
validate_campaign(campaign: ChemGraphCampaign, agent_count: int) -> None: ) if agent.name in agent.allowed_peers: raise RuntimeError(f'{agent.name} must not list itself as a peer') - tool_names = list(agent.tool_names) - if len(set(tool_names)) != len(tool_names): - raise RuntimeError(f'{agent.name} has duplicate tool declarations') + unknown_servers = sorted(set(agent.mcp_servers).difference(declared_servers)) + if unknown_servers: + raise RuntimeError( + f'{agent.name} references unknown MCP servers: {unknown_servers}', + ) unknown_resources = sorted(set(agent.resources).difference(campaign.resources)) if unknown_resources: raise RuntimeError( diff --git a/src/chemgraph/academy/core/tools.py b/src/chemgraph/academy/core/tools.py index a03f5ec5..a636b4a5 100644 --- a/src/chemgraph/academy/core/tools.py +++ b/src/chemgraph/academy/core/tools.py @@ -2,12 +2,10 @@ from __future__ import annotations -import json import pathlib import time -import uuid import asyncio -from collections.abc import Callable, Mapping +from collections.abc import Callable, Mapping, Sequence from typing import Any from academy.handle import Handle @@ -15,9 +13,6 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationError from chemgraph.academy.core.campaign import ChemGraphAgentSpec -from chemgraph.mcp.fastmcp_client import ToolInvocation -from chemgraph.mcp.fastmcp_client import fastmcp_tool_schemas -from chemgraph.mcp.fastmcp_client import FastMCPToolInvoker from chemgraph.academy.core.peer_protocol import build_message from chemgraph.academy.observability.run_files import append_jsonl @@ -104,29 +99,11 @@ def _disallowed_recipient_response( return {**payload, "status": "error"} -def _compact_for_lm(value: Any, *, max_chars: int = 4000) -> Any: - """Return a JSON-safe, size-bounded value for tool feedback.""" - try: - text = json.dumps(value, sort_keys=True) - except TypeError: - text = repr(value) - if len(text) <= max_chars: - try: - return json.loads(text) - except 
json.JSONDecodeError: - return text - return { - "truncated": True, - "preview": text[:max_chars], - "full_result_location": "tool_results.jsonl", - } - - async def build_chemgraph_reasoning_tools( *, spec: ChemGraphAgentSpec, run_dir: pathlib.Path, - tool_invoker: FastMCPToolInvoker, + external_tools: Sequence[BaseTool] = (), peer_names: tuple[str, ...], peer_handles: Mapping[str, Handle[Any]], outbox: list[dict[str, Any]], @@ -314,80 +291,6 @@ async def finish_turn(**kwargs: Any) -> dict[str, Any]: metadata={"chemgraph_academy_tool_kind": "action_tool"}, ), ] - - fastmcp_schemas = await fastmcp_tool_schemas(list(spec.tools)) - schema_by_name = { - schema["function"]["name"]: schema["function"] - for schema in fastmcp_schemas - if schema.get("type") == "function" - } - - for tool_spec in spec.tools: - function_schema = schema_by_name[tool_spec.name] - - async def run_fastmcp_tool( - __tool_name: str = tool_spec.name, - **kwargs: Any, - ) -> dict[str, Any]: - if __tool_name not in spec.tool_names: - raise RuntimeError( - f"{spec.name} cannot call unavailable tool {__tool_name}", - ) - tool_result_id = f"tool-{uuid.uuid4()}" - started = { - "tool_result_id": tool_result_id, - "tool_name": __tool_name, - "arguments": kwargs, - } - trace("tool_call_started", started) - result_record = await tool_invoker.invoke( - ToolInvocation( - tool_name=__tool_name, - arguments=kwargs, - agent_id=spec.name, - role=spec.role, - correlation_id=tool_result_id, - ), - ) - if result_record.status != "success": - failure = { - **started, - "status": "failed", - "error": result_record.error - or "tool returned non-success status", - } - append_jsonl(run_dir / "tool_results.jsonl", failure) - trace("tool_call_failed", failure) - raise RuntimeError(f"{__tool_name} failed: {failure['error']}") - - record = { - **started, - "timestamp": time.time(), - "agent_name": spec.name, - "status": "ok", - "result": result_record.result, - } - tool_results.append(record) - append_jsonl(run_dir / 
"tool_results.jsonl", record) - trace("tool_call_finished", record) - return { - "status": "ok", - "tool_result_id": tool_result_id, - "tool_name": __tool_name, - "result": _compact_for_lm(result_record.result), - } - - tools.append( - StructuredTool.from_function( - coroutine=run_fastmcp_tool, - name=tool_spec.name, - description=function_schema.get("description") - or tool_spec.description - or f"Run ChemGraph FastMCP tool {tool_spec.name}.", - args_schema=function_schema.get("parameters") - or {"type": "object", "properties": {}}, - metadata={"chemgraph_academy_tool_kind": "science_tool"}, - ), - ) + tools.extend(external_tools) return tools diff --git a/src/chemgraph/academy/core/turn.py b/src/chemgraph/academy/core/turn.py index 8a61bdaa..967efc2b 100644 --- a/src/chemgraph/academy/core/turn.py +++ b/src/chemgraph/academy/core/turn.py @@ -53,8 +53,11 @@ async def run_academy_turn( def on_event(event: str, payload: dict) -> None: trace(event, {"round": round_index, **payload}) + available_tool_names = tuple( + tool.name for tool in tools if tool.name not in ACTION_TOOL_NAMES + ) result = await run_turn( - query=json.dumps(_state(campaign, spec, prompt_profile, run_dir, max_decisions, round_index, received_message_history, outbox, tool_results, get_final_result, peer_names), sort_keys=True), + query=json.dumps(_state(campaign, spec, prompt_profile, run_dir, max_decisions, round_index, received_message_history, outbox, tool_results, get_final_result, peer_names, available_tool_names), sort_keys=True), tools=tools, model_name=llm_settings.model, base_url=llm_settings.base_url, @@ -82,7 +85,7 @@ def on_event(event: str, payload: dict) -> None: trace("chemgraph_reasoning_turn_finished", {"round": round_index, "thread_id": out.thread_id, "action_tools_called": list(action_tools), "science_tools_called": list(science_tools), "requested_finish": out.requested_finish, "requested_self_wake": out.requested_self_wake}) return out -def _state(campaign, spec, profile, 
run_dir, max_decisions, round_index, messages, outbox, results, get_final_result, peer_names) -> dict[str, Any]: +def _state(campaign, spec, profile, run_dir, max_decisions, round_index, messages, outbox, results, get_final_result, peer_names, available_tool_names) -> dict[str, Any]: limits = profile.state_limits return { "campaign": campaign.run_id, @@ -95,7 +98,7 @@ def _state(campaign, spec, profile, run_dir, max_decisions, round_index, message "resources": visible_resources_payload(campaign, spec), "allowed_peers": list(spec.allowed_peers), "peer_status": build_peer_status(run_dir=run_dir, peer_names=peer_names), - "available_chemgraph_tools": list(spec.tool_names), + "available_chemgraph_tools": list(available_tool_names), "received_messages": _tail(messages, limits.received_messages_last_n), "local_chemgraph_tool_results": _tail(results, limits.tool_results_last_n), "recent_actions": build_recent_actions(outbox=outbox, tool_results=results, limit=limits.actions_last_n), diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/campaign.jsonc b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/campaign.jsonc index 6cac1d79..23fb9362 100644 --- a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/campaign.jsonc +++ b/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/campaign.jsonc @@ -39,6 +39,22 @@ "description": "Local MACE model file shipped with this campaign example." } }, + "mcp_servers": [ + // MCP server fields: + // command: launch command; runtime appends --transport/--host/--port. 
+ { + "name": "general", + "command": "python -m chemgraph.mcp.mcp_tools" + }, + { + "name": "mace", + "command": "python -m chemgraph.mcp.mace_mcp_hpc" + }, + { + "name": "hpc_misc", + "command": "python -m chemgraph.mcp.hpc_misc_mcp" + } + ], "agents": [ { "name": "coordinator-agent", @@ -50,7 +66,7 @@ "mace-agent", "assessment-agent" ], - "tools": [], + "mcp_servers": [], "resources": [ "candidate_dataset", "structure_output_directory", @@ -63,7 +79,7 @@ "role": "MolecularStructureWorkerAgent", "mission": "Process only candidates assigned by coordinator-agent. Generate XYZ coordinate files, then report concise artifact evidence and failures back to coordinator-agent.", "allowed_peers": ["coordinator-agent"], - "tools": ["smiles_to_coordinate_file"], + "mcp_servers": ["general"], "resources": [] }, { @@ -71,7 +87,7 @@ "role": "MolecularStructureWorkerAgent", "mission": "Process only candidates assigned by coordinator-agent. Generate XYZ coordinate files, then report concise artifact evidence and failures back to coordinator-agent.", "allowed_peers": ["coordinator-agent"], - "tools": ["smiles_to_coordinate_file"], + "mcp_servers": ["general"], "resources": [] }, { @@ -79,7 +95,7 @@ "role": "MACEEnsembleAgent", "mission": "Run MACE only after a concrete request from coordinator-agent. Report started, completed, partial, or failed evidence back to coordinator-agent, including output paths and tool_result_ids; pending work is not a failure.", "allowed_peers": ["coordinator-agent"], - "tools": ["run_mace_ensemble", "inspect_json"], + "mcp_servers": ["mace", "hpc_misc"], "resources": ["mace_model_file"] }, { @@ -87,31 +103,8 @@ "role": "ScreeningAssessmentAgent", "mission": "Assess evidence received from coordinator-agent. 
Summarize structure coverage, MACE coverage, failures, ranking readiness, and pending work without treating pending MACE work as failure.", "allowed_peers": ["coordinator-agent"], - "tools": ["inspect_json"], + "mcp_servers": ["hpc_misc"], "resources": [] } - ], - "tools": [ - // Tool fields: - // module: Python module containing the ChemGraph FastMCP object. - // tool: concrete tool name exposed by that module. - { - "name": "smiles_to_coordinate_file", - "module": "chemgraph.mcp.mcp_tools", - "tool": "smiles_to_coordinate_file", - "description": "Convert one SMILES string to a generated XYZ coordinate file." - }, - { - "name": "run_mace_ensemble", - "module": "chemgraph.mcp.mace_mcp_hpc", - "tool": "run_mace_ensemble", - "description": "Run MACE calculations over every generated structure in a directory." - }, - { - "name": "inspect_json", - "module": "chemgraph.mcp.hpc_misc_mcp", - "tool": "inspect_json", - "description": "Inspect a JSON file, output directory, or missing expected JSON path and return compact summaries." 
- } ] } diff --git a/src/chemgraph/academy/observability/run_artifacts.py b/src/chemgraph/academy/observability/run_artifacts.py index b2e4fdb5..a2d11980 100644 --- a/src/chemgraph/academy/observability/run_artifacts.py +++ b/src/chemgraph/academy/observability/run_artifacts.py @@ -332,8 +332,8 @@ def initialize_run_files( { 'agents': [spec.name for spec in campaign.agents], 'roles': {spec.name: spec.role for spec in campaign.agents}, - 'tool_names': { - spec.name: list(spec.tool_names) + 'mcp_servers': { + spec.name: list(spec.mcp_servers) for spec in campaign.agents }, }, diff --git a/src/chemgraph/academy/runtime/daemon.py b/src/chemgraph/academy/runtime/daemon.py index 751d5925..6b384a92 100644 --- a/src/chemgraph/academy/runtime/daemon.py +++ b/src/chemgraph/academy/runtime/daemon.py @@ -3,6 +3,7 @@ import argparse import asyncio import pathlib +import signal from academy.exchange.redis import RedisExchangeFactory from academy.handle import Handle @@ -32,11 +33,8 @@ from chemgraph.academy.runtime.mpi import rank_from_env from chemgraph.academy.core.agent import ChemGraphLogicalAgent from chemgraph.academy.core.prompt import load_prompt_profile +from chemgraph.academy.runtime.mcp_supervisor import MCPServerSupervisor from chemgraph.models.settings import load_lm_settings -from chemgraph.mcp.fastmcp_client import ( - FastMCPExecutionConfig, - build_fastmcp_tool_invoker, -) async def run_daemon(config: ChemGraphDaemonConfig) -> int: @@ -50,142 +48,150 @@ async def run_daemon(config: ChemGraphDaemonConfig) -> int: validate_campaign(campaign, config.agent_count) agent_spec = selected_agent(campaign, config.rank) placement = placement_payload(config, agent_spec.name) - - academy_factory = RedisExchangeFactory( - hostname=config.redis_host, - port=config.redis_port, + supervisor = MCPServerSupervisor( + specs=[ + spec + for spec in campaign.mcp_servers + if spec.name in agent_spec.mcp_servers + ], + run_dir=config.run_dir / f'rank{config.rank}', ) - if config.rank == 
0: - initialize_run_files( - run_dir=config.run_dir, - campaign=campaign, - config=config, - llm_settings=llm_settings, - ) - registrar = await academy_factory.create_user_client( - name=f'{config.run_dir.name}-registrar', - start_listener=False, - ) - try: - registered = await registrar.register_agents( - [ - (ChemGraphLogicalAgent, spec.name) - for spec in campaign.agents - ], - ) - finally: - await registrar.close() - registrations = dict( - zip( - (spec.name for spec in campaign.agents), - registered, - strict=True, - ), - ) - write_academy_registrations( - run_dir=config.run_dir, - run_token=config.run_token, - registrations=registrations, - ) - else: - registrations = await wait_academy_registrations( - config.run_dir, - run_token=config.run_token, - timeout_s=config.startup_timeout_s, - ) - if config.rank == 0: - registrations = load_academy_registrations( - config.run_dir, - run_token=config.run_token, - ) - registration = registrations[agent_spec.name] - peer_agent_ids = { - peer: registrations[peer].agent_id - for peer in agent_spec.allowed_peers - if peer in registrations - } - - tool_invoker = await build_fastmcp_tool_invoker( - specs=list(agent_spec.tools), - execution=FastMCPExecutionConfig(backend='local', system='local'), - run_dir=config.run_dir, - agent_name=agent_spec.name, - ) - agent = ChemGraphLogicalAgent( - agent_spec, - campaign=campaign, - llm_settings=llm_settings, - prompt_profile=prompt_profile, - run_dir=config.run_dir, - max_decisions=config.max_decisions, - tool_invoker=tool_invoker, - peer_agent_ids=peer_agent_ids, - placement=placement, - poll_timeout_s=config.poll_timeout_s, - idle_timeout_s=config.idle_timeout_s, - status_interval_s=config.status_interval_s, - ) - runtime_config = RuntimeConfig( - terminate_on_success=False, - terminate_on_error=False, - ) - runtime = Runtime( - agent, - exchange_factory=academy_factory, - registration=registration, - config=runtime_config, - ) - async with runtime: - await 
agent.write_runtime_status() + try: + await supervisor.start_all() + external_tools = await supervisor.get_tools(agent_spec.mcp_servers) + academy_factory = RedisExchangeFactory( + hostname=config.redis_host, + port=config.redis_port, + ) if config.rank == 0: - bootstrap = build_message( - sender='campaign', - recipient=campaign.initial_agent, - content=campaign_bootstrap_text(campaign), - kind='message', - tldr='Campaign bootstrap', - reason='Initial campaign task dispatch.', - confidence=1.0, + initialize_run_files( + run_dir=config.run_dir, + campaign=campaign, + config=config, + llm_settings=llm_settings, ) - initial_handle: Handle[Any] = Handle( - registrations[campaign.initial_agent].agent_id, + registrar = await academy_factory.create_user_client( + name=f'{config.run_dir.name}-registrar', + start_listener=False, ) - await initial_handle.action( - 'receive_message', - bootstrap, + try: + registered = await registrar.register_agents( + [ + (ChemGraphLogicalAgent, spec.name) + for spec in campaign.agents + ], + ) + finally: + await registrar.close() + registrations = dict( + zip( + (spec.name for spec in campaign.agents), + registered, + strict=True, + ), ) - append_system_trace( + write_academy_registrations( + run_dir=config.run_dir, + run_token=config.run_token, + registrations=registrations, + ) + else: + registrations = await wait_academy_registrations( config.run_dir, - 'bootstrap_message_dispatched', - { - 'agent': campaign.initial_agent, - 'message_id': bootstrap['message_id'], - 'via': 'academy_action', - }, + run_token=config.run_token, + timeout_s=config.startup_timeout_s, ) - await runtime.wait_shutdown() + if config.rank == 0: + registrations = load_academy_registrations( + config.run_dir, + run_token=config.run_token, + ) + registration = registrations[agent_spec.name] + peer_agent_ids = { + peer: registrations[peer].agent_id + for peer in agent_spec.allowed_peers + if peer in registrations + } - if config.rank == 0: - all_done = await 
wait_for_agent_statuses_finished( - run_dir=config.run_dir, + agent = ChemGraphLogicalAgent( + agent_spec, campaign=campaign, - timeout_s=config.completion_timeout_s, - ) - append_system_trace( - config.run_dir, - 'campaign_finished', - {'all_agents_done': all_done}, - ) - write_status_snapshot( + llm_settings=llm_settings, + prompt_profile=prompt_profile, run_dir=config.run_dir, - campaign=campaign, - agent_state=await agent.report_state(), + max_decisions=config.max_decisions, + external_tools=external_tools, + peer_agent_ids=peer_agent_ids, placement=placement, + poll_timeout_s=config.poll_timeout_s, + idle_timeout_s=config.idle_timeout_s, + status_interval_s=config.status_interval_s, + ) + runtime_config = RuntimeConfig( + terminate_on_success=False, + terminate_on_error=False, + ) + runtime = Runtime( + agent, + exchange_factory=academy_factory, + registration=registration, + config=runtime_config, ) - return 0 + async with runtime: + await agent.write_runtime_status() + + if config.rank == 0: + bootstrap = build_message( + sender='campaign', + recipient=campaign.initial_agent, + content=campaign_bootstrap_text(campaign), + kind='message', + tldr='Campaign bootstrap', + reason='Initial campaign task dispatch.', + confidence=1.0, + ) + initial_handle: Handle[Any] = Handle( + registrations[campaign.initial_agent].agent_id, + ) + await initial_handle.action( + 'receive_message', + bootstrap, + ) + append_system_trace( + config.run_dir, + 'bootstrap_message_dispatched', + { + 'agent': campaign.initial_agent, + 'message_id': bootstrap['message_id'], + 'via': 'academy_action', + }, + ) + + await runtime.wait_shutdown() + + if config.rank == 0: + all_done = await wait_for_agent_statuses_finished( + run_dir=config.run_dir, + campaign=campaign, + timeout_s=config.completion_timeout_s, + ) + append_system_trace( + config.run_dir, + 'campaign_finished', + {'all_agents_done': all_done}, + ) + write_status_snapshot( + run_dir=config.run_dir, + campaign=campaign, + 
agent_state=await agent.report_state(), + placement=placement, + ) + return 0 + finally: + await supervisor.shutdown() def parse_args() -> argparse.Namespace: @@ -243,8 +249,22 @@ def config_from_args(args: argparse.Namespace) -> ChemGraphDaemonConfig: ) +async def _main_async() -> int: + task = asyncio.create_task(run_daemon(config_from_args(parse_args()))) + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, task.cancel) + except (NotImplementedError, RuntimeError): + pass + try: + return await task + except asyncio.CancelledError: + return 130 + + def main() -> int: - return asyncio.run(run_daemon(config_from_args(parse_args()))) + return asyncio.run(_main_async()) if __name__ == '__main__': diff --git a/src/chemgraph/academy/runtime/mcp_supervisor.py b/src/chemgraph/academy/runtime/mcp_supervisor.py new file mode 100644 index 00000000..ea83846a --- /dev/null +++ b/src/chemgraph/academy/runtime/mcp_supervisor.py @@ -0,0 +1,274 @@ +"""Spawn per-rank MCP server subprocesses, wait for readiness, connect.""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import os +import shlex +import socket +import subprocess +import time +from pathlib import Path +from typing import Any + +import httpx +from langchain_core.tools import BaseTool +from langchain_core.tools import StructuredTool +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import CallToolResult + +from chemgraph.academy.core.campaign import MCPServerSpec + +logger = logging.getLogger(__name__) + +_READINESS_TIMEOUT_S = 30.0 +_READINESS_POLL_INTERVAL_S = 0.25 +_SHUTDOWN_TIMEOUT_S = 5.0 + + +def _pick_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +class MCPServerSupervisor: + """Per-rank MCP subprocess lifecycle and client 
wiring.""" + + def __init__(self, specs: list[MCPServerSpec], run_dir: Path) -> None: + self._specs = list(specs) + self._run_dir = Path(run_dir) + self._log_dir = self._run_dir / "mcp_logs" + self._processes: dict[str, subprocess.Popen[bytes]] = {} + self._log_handles: dict[str, object] = {} + self._urls: dict[str, str] = {} + + @property + def urls(self) -> dict[str, str]: + return dict(self._urls) + + async def start_all(self) -> dict[str, str]: + if not self._specs: + return {} + self._log_dir.mkdir(parents=True, exist_ok=True) + for spec in self._specs: + port = _pick_free_port() + url = f"http://127.0.0.1:{port}/mcp/" + cmd = shlex.split(spec.command) + [ + "--transport", + "streamable_http", + "--host", + "127.0.0.1", + "--port", + str(port), + ] + env = {**os.environ, **spec.env} + log_path = self._log_dir / f"{spec.name}.log" + log_handle = log_path.open("ab") + logger.info( + "spawning MCP server %s on port %d: %s", + spec.name, + port, + " ".join(cmd), + ) + proc = subprocess.Popen( + cmd, + stdout=log_handle, + stderr=subprocess.STDOUT, + env=env, + start_new_session=True, + ) + self._processes[spec.name] = proc + self._log_handles[spec.name] = log_handle + self._urls[spec.name] = url + await self._await_all_ready() + return dict(self._urls) + + async def get_tools( + self, + server_names: tuple[str, ...] 
| None = None, + ) -> list[BaseTool]: + if not self._urls: + return [] + wanted = tuple(server_names) if server_names else tuple(self._urls) + unknown = sorted(set(wanted) - set(self._urls)) + if unknown: + raise RuntimeError( + f"agent requested unknown MCP servers: {unknown}; " + f"available: {sorted(self._urls)}", + ) + connections = { + name: self._urls[name] + for name in wanted + } + tools: list[BaseTool] = [] + tool_names: set[str] = set() + for server_name, url in connections.items(): + async with streamablehttp_client(url) as (read, write, _): + async with ClientSession(read, write) as session: + await session.initialize() + listed = await session.list_tools() + for mcp_tool in listed.tools: + if mcp_tool.name in tool_names: + raise RuntimeError( + f"duplicate MCP tool name {mcp_tool.name!r} " + f"from server {server_name!r}", + ) + tool_names.add(mcp_tool.name) + tools.append( + _langchain_tool( + server_name=server_name, + server_url=url, + tool_name=mcp_tool.name, + description=mcp_tool.description + or f"MCP tool {mcp_tool.name}.", + input_schema=mcp_tool.inputSchema, + ), + ) + return tools + + async def shutdown(self) -> None: + for name, proc in list(self._processes.items()): + if proc.poll() is not None: + continue + with contextlib.suppress(ProcessLookupError): + proc.terminate() + + deadline = time.monotonic() + _SHUTDOWN_TIMEOUT_S + for name, proc in list(self._processes.items()): + remaining = max(0.0, deadline - time.monotonic()) + try: + proc.wait(timeout=remaining) + except subprocess.TimeoutExpired: + logger.warning("MCP server %s did not exit; killing", name) + with contextlib.suppress(ProcessLookupError): + proc.kill() + with contextlib.suppress(subprocess.TimeoutExpired): + proc.wait(timeout=2) + + for handle in self._log_handles.values(): + with contextlib.suppress(Exception): + handle.close() + self._processes.clear() + self._log_handles.clear() + self._urls.clear() + + async def _await_all_ready(self) -> None: + deadline = 
time.monotonic() + _READINESS_TIMEOUT_S + pending = dict(self._urls) + async with httpx.AsyncClient(timeout=2.0) as client: + while pending and time.monotonic() < deadline: + ready_now: list[str] = [] + for name, url in pending.items(): + proc = self._processes[name] + if proc.poll() is not None: + log_tail = self._tail_log(name) + raise RuntimeError( + f"MCP server {name!r} exited before readiness " + f"(returncode={proc.returncode}). Last log lines:\n" + f"{log_tail}", + ) + try: + response = await client.get(url) + if response.status_code < 500: + ready_now.append(name) + except httpx.RequestError: + pass + for name in ready_now: + logger.info("MCP server %s ready at %s", name, pending[name]) + pending.pop(name) + if pending: + await asyncio.sleep(_READINESS_POLL_INTERVAL_S) + if pending: + stuck = sorted(pending) + tails = "\n".join( + f"=== {name} ===\n{self._tail_log(name)}" + for name in stuck + ) + raise RuntimeError( + f"MCP servers did not become ready within " + f"{_READINESS_TIMEOUT_S:.0f}s: {stuck}\n{tails}", + ) + + def _tail_log(self, name: str, n: int = 40) -> str: + path = self._log_dir / f"{name}.log" + if not path.exists(): + return "(no log file)" + try: + text = path.read_text(encoding="utf-8", errors="replace") + except OSError: + return "(log unreadable)" + return "\n".join(text.splitlines()[-n:]) + + +def _langchain_tool( + *, + server_name: str, + server_url: str, + tool_name: str, + description: str, + input_schema: dict[str, Any], +) -> BaseTool: + async def call_mcp_tool(**kwargs: Any) -> Any: + return await _call_mcp_tool( + server_url=server_url, + tool_name=tool_name, + arguments=kwargs, + ) + + call_mcp_tool.__name__ = f"{server_name}_{tool_name}" + return StructuredTool.from_function( + coroutine=call_mcp_tool, + name=tool_name, + description=description, + args_schema=input_schema, + metadata={ + "chemgraph_academy_tool_kind": "science_tool", + "mcp_server": server_name, + }, + ) + + +async def _call_mcp_tool( + *, + server_url: 
str, + tool_name: str, + arguments: dict[str, Any], +) -> Any: + async with streamablehttp_client(server_url) as (read, write, _): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool(tool_name, arguments) + return _serialize_call_tool_result(result) + + +def _serialize_call_tool_result(result: CallToolResult) -> dict[str, Any]: + payload: dict[str, Any] = { + "is_error": bool(result.isError), + "content": [ + _json_safe(block) + for block in result.content + ], + } + if result.structuredContent is not None: + payload["structured_content"] = _json_safe(result.structuredContent) + if result.isError: + payload["status"] = "error" + else: + payload["status"] = "ok" + return payload + + +def _json_safe(value: Any) -> Any: + if hasattr(value, "model_dump"): + return value.model_dump(mode="json") + if isinstance(value, dict): + return {str(key): _json_safe(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [_json_safe(item) for item in value] + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return repr(value) diff --git a/src/chemgraph/mcp/fastmcp_client.py b/src/chemgraph/mcp/fastmcp_client.py deleted file mode 100644 index 6519b313..00000000 --- a/src/chemgraph/mcp/fastmcp_client.py +++ /dev/null @@ -1,360 +0,0 @@ -"""In-process client adapter for FastMCP tool modules.""" - -from __future__ import annotations - -import importlib -import json -import uuid -from collections.abc import Mapping -from collections.abc import Sequence -from pathlib import Path -from typing import Any -from typing import Protocol - -from pydantic import BaseModel -from pydantic import ConfigDict -from pydantic import Field - - -class ToolInvocation(BaseModel): - """A normalized record of one agent-requested FastMCP tool call.""" - - model_config = ConfigDict(extra="forbid") - - tool_name: str - arguments: dict[str, Any] = Field(default_factory=dict) - agent_id: 
str | None = None - role: str | None = None - correlation_id: str = Field(default_factory=lambda: f"call-{uuid.uuid4()}") - - -class ToolResult(BaseModel): - """Normalized result from a FastMCP tool call.""" - - model_config = ConfigDict(extra="allow") - - tool_name: str - status: str - result: Any = None - error: str | None = None - correlation_id: str - - -class FastMCPToolSpec(Protocol): - """Structural interface for config-declared FastMCP tools.""" - - name: str - module: str - tool: str - description: str | None - - -class FastMCPExecutionSpec(Protocol): - """Structural interface for backend configuration used by CGFastMCP.""" - - backend: str | None - system: str | None - config_path: str | None - options: Mapping[str, Any] - - -class FastMCPExecutionConfig(BaseModel): - """Concrete backend configuration for in-process FastMCP clients.""" - - model_config = ConfigDict(extra="forbid") - - backend: str | None = "local" - system: str | None = "local" - config_path: str | None = None - options: dict[str, Any] = Field(default_factory=dict) - - -def load_fastmcp_tool_module( - module_name: str, - *, - cache: dict[str, Any] | None = None, -) -> Any: - """Return a module's top-level FastMCP object declared by a campaign tool.""" - if cache is not None and module_name in cache: - return cache[module_name] - - module = importlib.import_module(module_name) - try: - server = module.mcp - except AttributeError as exc: - raise RuntimeError( - f"FastMCP tool module {module_name!r} does not expose " - "a top-level 'mcp' object", - ) from exc - - if cache is not None: - cache[module_name] = server - return server - - -async def fastmcp_tool_schemas( - specs: Sequence[FastMCPToolSpec], -) -> list[dict[str, Any]]: - """Build OpenAI tool schemas from declared FastMCP ToolSpecs.""" - schemas: list[dict[str, Any]] = [] - module_cache: dict[str, Any] = {} - tools_cache: dict[str, dict[str, Any]] = {} - for spec in specs: - if spec.module not in tools_cache: - tools = await 
load_fastmcp_tool_module( - spec.module, - cache=module_cache, - ).list_tools() - tools_cache[spec.module] = { - _fastmcp_tool_name(tool): _fastmcp_tool_payload(tool) - for tool in tools - } - try: - tool_payload = tools_cache[spec.module][spec.tool] - except KeyError as exc: - raise RuntimeError( - f"FastMCP module {spec.module!r} does not expose tool " - f"{spec.tool!r}", - ) from exc - schemas.append(_openai_tool_schema(spec, tool_payload)) - return schemas - - -def _fastmcp_tool_name(tool: Any) -> str: - if isinstance(tool, dict): - return str(tool.get("name", "")) - return str(getattr(tool, "name", "")) - - -def _fastmcp_tool_payload(tool: Any) -> dict[str, Any]: - if isinstance(tool, dict): - return dict(tool) - if hasattr(tool, "model_dump"): - return tool.model_dump(mode="json") - return { - "name": getattr(tool, "name", ""), - "description": getattr(tool, "description", ""), - "inputSchema": getattr(tool, "inputSchema", None), - } - - -def _openai_tool_schema( - spec: FastMCPToolSpec, - tool_payload: dict[str, Any], -) -> dict[str, Any]: - parameters = _sanitize_input_schema( - tool_payload.get("inputSchema") or {"type": "object", "properties": {}}, - ) - return { - "type": "function", - "function": { - "name": spec.name, - "description": spec.description - or str(tool_payload.get("description") or ""), - "parameters": parameters, - }, - } - - -def _sanitize_input_schema(schema: Any) -> dict[str, Any]: - if hasattr(schema, "model_dump"): - schema = schema.model_dump(mode="json") - if not isinstance(schema, dict): - return {"type": "object", "properties": {}, "additionalProperties": False} - sanitized = json.loads(json.dumps(schema)) - sanitized.setdefault("type", "object") - sanitized.setdefault("properties", {}) - sanitized.setdefault("additionalProperties", False) - return sanitized - - -def serialize_fastmcp_result(result: Any) -> Any: - """Convert FastMCP content blocks to JSON-friendly values.""" - if isinstance(result, dict): - return result - if 
isinstance(result, (str, int, float, bool)) or result is None: - return result - if hasattr(result, "model_dump"): - return result.model_dump(mode="json") - if isinstance(result, Sequence) and not isinstance(result, (str, bytes)): - values = [serialize_fastmcp_result(item) for item in result] - structured = _first_structured_result(values) - if structured is not None: - return structured - json_text = _first_json_text_result(values) - if json_text is not None: - return json_text - return values - return str(result) - - -def _first_structured_result(values: list[Any]) -> dict[str, Any] | None: - for value in values: - if isinstance(value, dict) and ( - "results" in value - or "batch_id" in value - or "progress_pct" in value - or value.get("status") in {"completed", "submitted"} - ): - return value - if isinstance(value, list): - nested = _first_structured_result(value) - if nested is not None: - return nested - if isinstance(value, dict) and isinstance(value.get("text"), str): - try: - parsed = json.loads(value["text"]) - except json.JSONDecodeError: - continue - nested = _first_structured_result([parsed]) - if nested is not None: - return nested - return None - - -def _first_json_text_result(values: list[Any]) -> Any | None: - for value in values: - if isinstance(value, dict) and isinstance(value.get("text"), str): - try: - return json.loads(value["text"]) - except json.JSONDecodeError: - continue - if isinstance(value, list): - nested = _first_json_text_result(value) - if nested is not None: - return nested - return None - - -class FastMCPToolInvoker: - """Invoke allowed tools through in-process FastMCP modules.""" - - def __init__( - self, - *, - specs: Sequence[FastMCPToolSpec], - execution: FastMCPExecutionSpec, - run_dir: str | Path, - ) -> None: - self.specs = {spec.name: spec for spec in specs} - self.execution = execution - self.run_dir = Path(run_dir) - self._module_cache: dict[str, Any] = {} - self._available_cache: dict[str, set[str]] = {} - - def 
names(self) -> list[str]: - return sorted(self.specs) - - async def verify_allowed_tools(self) -> list[str]: - """Return tools missing from their declared FastMCP module.""" - missing: list[str] = [] - for spec in self.specs.values(): - try: - available = await self._available_tool_names(spec.module) - except Exception: # noqa: BLE001 - caller needs aggregate missing names - missing.append(spec.name) - continue - if spec.tool not in available: - missing.append(spec.name) - return missing - - async def invoke(self, invocation: ToolInvocation) -> ToolResult: - spec = self.specs.get(invocation.tool_name) - if spec is None: - raise KeyError( - f"unknown FastMCP tool: {invocation.tool_name}", - ) - - try: - available = await self._available_tool_names(spec.module) - if spec.tool not in available: - raise KeyError( - f"FastMCP module {spec.module!r} does not expose " - f"tool {spec.tool!r}", - ) - mcp = self._fastmcp_module(spec.module) - _configure_fastmcp_backend( - mcp, - module_name=spec.module, - execution=self.execution, - run_dir=self.run_dir, - ) - result = await mcp.call_tool(spec.tool, invocation.arguments) - except Exception as exc: # noqa: BLE001 - preserve tool failure as data - return ToolResult( - tool_name=invocation.tool_name, - status="error", - error=repr(exc), - correlation_id=invocation.correlation_id, - ) - - return ToolResult( - tool_name=invocation.tool_name, - status="success", - result=serialize_fastmcp_result(result), - correlation_id=invocation.correlation_id, - ) - - async def _available_tool_names(self, module_name: str) -> set[str]: - if module_name not in self._available_cache: - tools = await self._fastmcp_module(module_name).list_tools() - self._available_cache[module_name] = { - str(getattr(tool, "name", "")) - if not isinstance(tool, dict) - else str(tool.get("name", "")) - for tool in tools - } - return self._available_cache[module_name] - - def _fastmcp_module(self, module_name: str) -> Any: - return 
load_fastmcp_tool_module(module_name, cache=self._module_cache) - - -def _configure_fastmcp_backend( - mcp: Any, - *, - module_name: str, - execution: FastMCPExecutionSpec, - run_dir: str | Path, -) -> None: - """Configure a CGFastMCP backend without initialising compute resources.""" - if not hasattr(mcp, "init_backend"): - return - if getattr(mcp, "_backend_kwargs", None) is not None: - return - - kwargs: dict[str, Any] = dict(execution.options) - if execution.config_path: - kwargs["config_path"] = execution.config_path - if execution.backend: - kwargs["backend_name"] = execution.backend - if execution.system: - kwargs["system"] = execution.system - - tracker_name = module_name.replace(".", "_") - tracker_path = Path(run_dir) / f"{tracker_name}_jobs.json" - mcp.init_backend( - tracker_kwargs={"persist_file": str(tracker_path)}, - **kwargs, - ) - - -async def build_fastmcp_tool_invoker( - *, - specs: Sequence[FastMCPToolSpec], - execution: FastMCPExecutionSpec, - run_dir: str | Path, - agent_name: str, -) -> FastMCPToolInvoker: - """Build and verify one in-process FastMCP tool invoker.""" - invoker = FastMCPToolInvoker( - specs=list(specs), - execution=execution, - run_dir=run_dir, - ) - missing = await invoker.verify_allowed_tools() - if missing: - raise RuntimeError( - f"Could not resolve requested FastMCP tools for {agent_name}: {missing}", - ) - return invoker diff --git a/tests/test_academy_campaign.py b/tests/test_academy_campaign.py index f129320f..51c69bed 100644 --- a/tests/test_academy_campaign.py +++ b/tests/test_academy_campaign.py @@ -6,6 +6,7 @@ from chemgraph.academy.core.campaign import campaign_bootstrap_text from chemgraph.academy.core.campaign import load_campaign +from chemgraph.academy.core.campaign import MCPServerSpec from chemgraph.academy.core.campaign import validate_campaign @@ -55,10 +56,10 @@ def test_removed_structured_orchestration_fields_are_rejected(tmp_path) -> None: "role": "Role", "mission": "Do the task.", "allowed_peers": [], 
- "tools": [], + "mcp_servers": [], }, ], - "tools": [], + "mcp_servers": [], }, ), encoding="utf-8", @@ -92,11 +93,16 @@ def test_campaign_loader_accepts_jsonc_comments(tmp_path) -> None: "role": "Role", "mission": "Do the task.", "allowed_peers": [], - "tools": [], + "mcp_servers": ["general"], "resources": ["input"] } ], - "tools": [] + "mcp_servers": [ + { + "name": "general", + "command": "python -m chemgraph.mcp.mcp_tools" + } + ] } """, encoding="utf-8", @@ -106,6 +112,23 @@ def test_campaign_loader_accepts_jsonc_comments(tmp_path) -> None: assert campaign.run_id == "commented" assert campaign.resources["input"].kind == "json" + assert campaign.mcp_servers[0].name == "general" + assert campaign.agents[0].mcp_servers == ("general",) + + +def test_mcp_server_spec_validation() -> None: + spec = MCPServerSpec.model_validate( + {"name": "general", "command": "python -m server"}, + ) + assert spec.env == {} + + with pytest.raises(ValueError, match="field required|Field required"): + MCPServerSpec.model_validate({"name": "general"}) + + with pytest.raises(ValueError): + MCPServerSpec.model_validate( + {"name": "general", "command": "python -m server", "extra": "bad"}, + ) def test_resource_kind_and_scope_are_option_sets(tmp_path) -> None: @@ -129,10 +152,10 @@ def test_resource_kind_and_scope_are_option_sets(tmp_path) -> None: "role": "Role", "mission": "Do the task.", "allowed_peers": [], - "tools": [], + "mcp_servers": [], }, ], - "tools": [], + "mcp_servers": [], }, ), encoding="utf-8", @@ -140,3 +163,62 @@ def test_resource_kind_and_scope_are_option_sets(tmp_path) -> None: with pytest.raises(ValueError, match="resource kind must be one of"): load_campaign(campaign_path) + + +def test_validate_campaign_rejects_unknown_mcp_server(tmp_path) -> None: + campaign_path = tmp_path / "campaign.json" + campaign_path.write_text( + json.dumps( + { + "run_id": "bad-server", + "user_task": "test", + "prompt_profile": "prompt.json", + "mcp_servers": [], + "agents": [ + { + 
"name": "agent-a", + "role": "Role", + "mission": "Do the task.", + "allowed_peers": [], + "mcp_servers": ["missing"], + }, + ], + }, + ), + encoding="utf-8", + ) + + campaign = load_campaign(campaign_path) + with pytest.raises(RuntimeError, match="unknown MCP servers"): + validate_campaign(campaign, 1) + + +def test_validate_campaign_rejects_duplicate_mcp_server_names(tmp_path) -> None: + campaign_path = tmp_path / "campaign.json" + campaign_path.write_text( + json.dumps( + { + "run_id": "duplicate-server", + "user_task": "test", + "prompt_profile": "prompt.json", + "mcp_servers": [ + {"name": "general", "command": "python -m one"}, + {"name": "general", "command": "python -m two"}, + ], + "agents": [ + { + "name": "agent-a", + "role": "Role", + "mission": "Do the task.", + "allowed_peers": [], + "mcp_servers": ["general"], + }, + ], + }, + ), + encoding="utf-8", + ) + + campaign = load_campaign(campaign_path) + with pytest.raises(RuntimeError, match="MCP server names must be unique"): + validate_campaign(campaign, 1) diff --git a/tests/test_academy_mcp_supervisor.py b/tests/test_academy_mcp_supervisor.py new file mode 100644 index 00000000..901cc78d --- /dev/null +++ b/tests/test_academy_mcp_supervisor.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import os +import sys +from pathlib import Path + +import pytest + +from chemgraph.academy.core.campaign import MCPServerSpec +from chemgraph.academy.runtime.mcp_supervisor import MCPServerSupervisor + + +def _pythonpath(tmp_path: Path) -> str: + current = os.environ.get("PYTHONPATH", "") + parts = [str(tmp_path)] + if current: + parts.append(current) + return os.pathsep.join(parts) + + +def _write_tiny_server(tmp_path: Path) -> None: + (tmp_path / "tiny_mcp.py").write_text( + """ +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("tiny") + +@mcp.tool(name="echo", description="Echo one string.") +def echo(text: str) -> dict: + return {"text": text} + +if __name__ == "__main__": + from 
chemgraph.mcp.server_utils import run_mcp_server + + run_mcp_server(mcp, default_port=0) +""", + encoding="utf-8", + ) + + +@pytest.mark.asyncio +async def test_mcp_supervisor_starts_server_and_gets_tools(tmp_path) -> None: + _write_tiny_server(tmp_path) + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="tiny", + command=f"{sys.executable} -m tiny_mcp", + env={"PYTHONPATH": _pythonpath(tmp_path)}, + ), + ], + run_dir=tmp_path / "run", + ) + try: + urls = await supervisor.start_all() + tools = await supervisor.get_tools(("tiny",)) + echo = next(tool for tool in tools if tool.name == "echo") + result = await echo.ainvoke({"text": "hello"}) + finally: + await supervisor.shutdown() + + assert sorted(urls) == ["tiny"] + assert "echo" in {tool.name for tool in tools} + assert result["status"] == "ok" + assert "hello" in repr(result) + + +@pytest.mark.asyncio +async def test_mcp_supervisor_shutdown_terminates_process(tmp_path) -> None: + _write_tiny_server(tmp_path) + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="tiny", + command=f"{sys.executable} -m tiny_mcp", + env={"PYTHONPATH": _pythonpath(tmp_path)}, + ), + ], + run_dir=tmp_path / "run", + ) + await supervisor.start_all() + proc = supervisor._processes["tiny"] + + await supervisor.shutdown() + + assert proc.poll() is not None + + +@pytest.mark.asyncio +async def test_mcp_supervisor_reports_server_exit_log_tail(tmp_path) -> None: + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="bad", + command=f"{sys.executable} -c \"print('boom'); raise SystemExit(1)\"", + ), + ], + run_dir=tmp_path / "run", + ) + + with pytest.raises(RuntimeError, match="boom"): + await supervisor.start_all() + + await supervisor.shutdown() + + +@pytest.mark.asyncio +async def test_mcp_supervisor_rejects_unknown_server_request(tmp_path) -> None: + _write_tiny_server(tmp_path) + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="tiny", + command=f"{sys.executable} -m tiny_mcp", + 
env={"PYTHONPATH": _pythonpath(tmp_path)}, + ), + ], + run_dir=tmp_path / "run", + ) + try: + await supervisor.start_all() + with pytest.raises(RuntimeError, match="available"): + await supervisor.get_tools(("missing",)) + finally: + await supervisor.shutdown() diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py index e4180474..35e85fd4 100644 --- a/tests/test_academy_reasoning_phase2.py +++ b/tests/test_academy_reasoning_phase2.py @@ -26,7 +26,7 @@ def _agent_spec() -> ChemGraphAgentSpec: role="Worker", mission="Use explicit tools only.", allowed_peers=(), - tools=(), + mcp_servers=(), ) @@ -91,7 +91,6 @@ async def test_reasoning_adapter_finish_turn_traces(tmp_path) -> None: tools = await build_chemgraph_reasoning_tools( spec=_agent_spec(), run_dir=tmp_path, - tool_invoker=object(), peer_names=(), peer_handles={}, outbox=[], @@ -122,7 +121,6 @@ async def test_send_message_does_not_block_on_busy_peer(tmp_path) -> None: tools = await build_chemgraph_reasoning_tools( spec=_agent_spec_with_peer(), run_dir=tmp_path, - tool_invoker=object(), peer_names=("agent-b",), peer_handles={"agent-b": peer}, outbox=outbox, @@ -210,7 +208,6 @@ async def test_logical_agent_reasoning_round_calls_turn_runner(monkeypatch, tmp_ prompt_profile=_prompt_profile(), run_dir=tmp_path, max_decisions=5, - tool_invoker=object(), ) agent.round_index = 1 diff --git a/tests/test_tool_adapter_validation.py b/tests/test_tool_adapter_validation.py index d255ca90..47dfed6a 100644 --- a/tests/test_tool_adapter_validation.py +++ b/tests/test_tool_adapter_validation.py @@ -23,7 +23,7 @@ def _agent_spec() -> ChemGraphAgentSpec: role="Worker", mission="Use explicit tools only.", allowed_peers=("agent-b",), - tools=(), + mcp_servers=(), ) @@ -34,7 +34,6 @@ async def _build_tools(tmp_path): tools = await build_chemgraph_reasoning_tools( spec=_agent_spec(), run_dir=tmp_path, - tool_invoker=object(), # unused when spec.tools is empty peer_names=("agent-b",), 
peer_handles={"agent-b": peer_handle}, outbox=outbox, From de0a79dc6ef67c343ef934f4a4c6bce710890cc1 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 16:12:27 -0500 Subject: [PATCH 069/119] refactor(academy): generalize exchange registration --- src/chemgraph/academy/core/campaign.py | 1 + .../academy/observability/run_artifacts.py | 3 +- .../academy/runtime/compute_launcher.py | 14 ++- src/chemgraph/academy/runtime/daemon.py | 13 ++- src/chemgraph/academy/runtime/exchange.py | 39 +++++++ src/chemgraph/academy/runtime/mpi.py | 1 + src/chemgraph/academy/runtime/registration.py | 51 +++++++-- tests/test_academy_compute_launcher.py | 1 + tests/test_academy_exchange_registration.py | 106 ++++++++++++++++++ 9 files changed, 213 insertions(+), 16 deletions(-) create mode 100644 src/chemgraph/academy/runtime/exchange.py create mode 100644 tests/test_academy_exchange_registration.py diff --git a/src/chemgraph/academy/core/campaign.py b/src/chemgraph/academy/core/campaign.py index 4e514f1c..a08e3ad4 100644 --- a/src/chemgraph/academy/core/campaign.py +++ b/src/chemgraph/academy/core/campaign.py @@ -139,6 +139,7 @@ class ChemGraphDaemonConfig: rank: int local_rank: int | None chemgraph_repo_root: pathlib.Path + exchange_type: str = 'redis' def namespace_for_run(run_dir: pathlib.Path) -> str: diff --git a/src/chemgraph/academy/observability/run_artifacts.py b/src/chemgraph/academy/observability/run_artifacts.py index a2d11980..11fa8b4a 100644 --- a/src/chemgraph/academy/observability/run_artifacts.py +++ b/src/chemgraph/academy/observability/run_artifacts.py @@ -307,7 +307,8 @@ def initialize_run_files( ), 'prompt_profile': str(campaign.prompt_profile), 'chemgraph_repo_root': str(config.chemgraph_repo_root), - 'communication_transport': 'academy_redis_actions', + 'communication_transport': f'academy_{config.exchange_type}_actions', + 'exchange_type': config.exchange_type, 'redis_host': config.redis_host, 'redis_port': config.redis_port, 'redis_namespace': 
config.redis_namespace, diff --git a/src/chemgraph/academy/runtime/compute_launcher.py b/src/chemgraph/academy/runtime/compute_launcher.py index 655a9a6a..d96ee263 100644 --- a/src/chemgraph/academy/runtime/compute_launcher.py +++ b/src/chemgraph/academy/runtime/compute_launcher.py @@ -47,6 +47,7 @@ class AllocationPlan: start_redis: bool mpiexec: str chemgraph_repo_root: Path + exchange_type: str = "redis" def parse_args(argv: list[str] | None = None) -> argparse.Namespace: @@ -76,6 +77,11 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser.add_argument("--agents-per-node", type=int) parser.add_argument("--max-decisions", type=int) parser.add_argument("--redis-port", type=int) + parser.add_argument( + "--exchange-type", + choices=("redis", "local", "hybrid"), + default="redis", + ) parser.add_argument("--no-start-redis", action="store_true") return parser.parse_args(argv) @@ -243,6 +249,7 @@ def prepare_compute_launch(args: argparse.Namespace) -> AllocationPlan: start_redis=not args.no_start_redis, mpiexec=profile.mpiexec, chemgraph_repo_root=Path(profile.repo_root).resolve(), + exchange_type=args.exchange_type, ) @@ -267,7 +274,8 @@ def run_allocation(plan: AllocationPlan) -> int: """Start Redis if requested and run per-rank daemons under mpiexec.""" plan.run_dir.mkdir(parents=True, exist_ok=True) redis_proc: subprocess.Popen[bytes] | None = None - if plan.start_redis: + uses_redis = plan.exchange_type in {"redis", "hybrid"} + if plan.start_redis and uses_redis: redis_server = shutil.which("redis-server") if redis_server is None: raise RuntimeError("redis-server is required unless --no-start-redis is set") @@ -296,7 +304,8 @@ def run_allocation(plan: AllocationPlan) -> int: encoding="utf-8", ) try: - wait_redis(plan.redis_host, plan.redis_port, plan.run_dir) + if uses_redis: + wait_redis(plan.redis_host, plan.redis_port, plan.run_dir) daemon_args = [ "--run-dir", str(plan.run_dir), "--run-token", plan.run_token, @@ -312,6 +321,7 @@ def 
run_allocation(plan: AllocationPlan) -> int: "--redis-host", plan.redis_host, "--redis-port", str(plan.redis_port), "--redis-namespace", plan.redis_namespace, + "--exchange-type", plan.exchange_type, "--chemgraph-repo-root", str(plan.chemgraph_repo_root), ] cmd = [ diff --git a/src/chemgraph/academy/runtime/daemon.py b/src/chemgraph/academy/runtime/daemon.py index 6b384a92..15094b3e 100644 --- a/src/chemgraph/academy/runtime/daemon.py +++ b/src/chemgraph/academy/runtime/daemon.py @@ -5,12 +5,12 @@ import pathlib import signal -from academy.exchange.redis import RedisExchangeFactory from academy.handle import Handle from academy.runtime import Runtime from academy.runtime import RuntimeConfig from chemgraph.academy.core.peer_protocol import build_message +from chemgraph.academy.runtime.exchange import build_exchange_factory from chemgraph.academy.runtime.registration import load_academy_registrations from chemgraph.academy.runtime.registration import wait_academy_registrations from chemgraph.academy.runtime.registration import write_academy_registrations @@ -61,10 +61,7 @@ async def run_daemon(config: ChemGraphDaemonConfig) -> int: await supervisor.start_all() external_tools = await supervisor.get_tools(agent_spec.mcp_servers) - academy_factory = RedisExchangeFactory( - hostname=config.redis_host, - port=config.redis_port, - ) + academy_factory = build_exchange_factory(config) if config.rank == 0: initialize_run_files( run_dir=config.run_dir, @@ -212,6 +209,11 @@ def parse_args() -> argparse.Namespace: parser.add_argument('--redis-host', default='127.0.0.1') parser.add_argument('--redis-port', type=int, required=True) parser.add_argument('--redis-namespace') + parser.add_argument( + '--exchange-type', + choices=('redis', 'local', 'hybrid'), + default='redis', + ) parser.add_argument('--chemgraph-repo-root') return parser.parse_args() @@ -239,6 +241,7 @@ def config_from_args(args: argparse.Namespace) -> ChemGraphDaemonConfig: redis_host=args.redis_host, 
redis_port=args.redis_port, redis_namespace=args.redis_namespace or namespace_for_run(run_dir), + exchange_type=args.exchange_type, rank=rank_from_env(), local_rank=local_rank_from_env(), chemgraph_repo_root=( diff --git a/src/chemgraph/academy/runtime/exchange.py b/src/chemgraph/academy/runtime/exchange.py new file mode 100644 index 00000000..6a8b2b2d --- /dev/null +++ b/src/chemgraph/academy/runtime/exchange.py @@ -0,0 +1,39 @@ +"""Build the Academy exchange factory matching a daemon config.""" + +from __future__ import annotations + +from typing import Any + +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig + + +def build_exchange_factory(config: ChemGraphDaemonConfig) -> Any: + """Return the Academy exchange factory matching ``config.exchange_type``.""" + exchange_type = config.exchange_type + + if exchange_type == 'redis': + from academy.exchange.redis import RedisExchangeFactory + + return RedisExchangeFactory( + hostname=config.redis_host, + port=config.redis_port, + ) + + if exchange_type == 'local': + from academy.exchange.local import LocalExchangeFactory + + return LocalExchangeFactory() + + if exchange_type == 'hybrid': + from academy.exchange.hybrid import HybridExchangeFactory + + return HybridExchangeFactory( + redis_host=config.redis_host, + redis_port=config.redis_port, + namespace=config.redis_namespace, + ) + + raise ValueError( + f"Unsupported exchange type {exchange_type!r}; expected one of " + "'redis', 'local', 'hybrid'.", + ) diff --git a/src/chemgraph/academy/runtime/mpi.py b/src/chemgraph/academy/runtime/mpi.py index dfe88cd8..7439f587 100644 --- a/src/chemgraph/academy/runtime/mpi.py +++ b/src/chemgraph/academy/runtime/mpi.py @@ -95,6 +95,7 @@ def placement_payload(config: Any, agent_name: str) -> dict[str, Any]: 'python_executable': sys.executable, 'rank': config.rank, 'local_rank': config.local_rank, + 'exchange_type': config.exchange_type, 'redis_host': config.redis_host, 'redis_port': config.redis_port, 
'redis_namespace': config.redis_namespace, diff --git a/src/chemgraph/academy/runtime/registration.py b/src/chemgraph/academy/runtime/registration.py index c56db752..ef8823da 100644 --- a/src/chemgraph/academy/runtime/registration.py +++ b/src/chemgraph/academy/runtime/registration.py @@ -7,24 +7,53 @@ from collections.abc import Mapping from typing import Any +from academy.exchange.hybrid import HybridAgentRegistration +from academy.exchange.local import LocalAgentRegistration from academy.exchange.redis import RedisAgentRegistration +from academy.exchange.transport import AgentRegistration from academy.identifier import AgentId +from pydantic import BaseModel from chemgraph.academy.observability.run_files import write_json_atomic +_REGISTRATION_TYPES: dict[str, type[BaseModel]] = { + 'local': LocalAgentRegistration, + 'hybrid': HybridAgentRegistration, + 'redis': RedisAgentRegistration, +} + + def academy_registration_path(run_dir: pathlib.Path) -> pathlib.Path: return run_dir / 'academy_registrations.json' +def _exchange_type_of(registration: AgentRegistration[Any]) -> str: + value = getattr(registration, 'exchange_type', None) + if not isinstance(value, str): + raise TypeError( + f'Registration {type(registration).__name__} has no string ' + '`exchange_type` field; cannot persist.', + ) + return value + + def registration_payload( *, run_token: str, - registrations: Mapping[str, RedisAgentRegistration[Any]], + registrations: Mapping[str, AgentRegistration[Any]], ) -> dict[str, Any]: + if not registrations: + raise ValueError('at least one registration is required') + exchange_types = {_exchange_type_of(r) for r in registrations.values()} + if len(exchange_types) > 1: + raise ValueError( + f'mixed exchange types in one campaign: {sorted(exchange_types)}', + ) + (exchange_type,) = exchange_types return { 'run_token': run_token, - 'exchange_type': 'redis', + 'exchange_type': exchange_type, 'agents': { name: registration.agent_id.model_dump(mode='json') for name, 
registration in registrations.items() @@ -36,7 +65,7 @@ def write_academy_registrations( *, run_dir: pathlib.Path, run_token: str, - registrations: Mapping[str, RedisAgentRegistration[Any]], + registrations: Mapping[str, AgentRegistration[Any]], ) -> None: write_json_atomic( academy_registration_path(run_dir), @@ -48,20 +77,26 @@ def load_academy_registrations( run_dir: pathlib.Path, *, run_token: str, -) -> dict[str, RedisAgentRegistration[Any]]: +) -> dict[str, AgentRegistration[Any]]: path = academy_registration_path(run_dir) data = json.loads(path.read_text(encoding='utf-8')) if data.get('run_token') != run_token: raise RuntimeError( f'Academy registration file {path} belongs to a different run', ) + exchange_type = data.get('exchange_type') + if exchange_type not in _REGISTRATION_TYPES: + raise RuntimeError( + f'Academy registration file has unsupported exchange_type ' + f'{exchange_type!r}; expected one of ' + f'{sorted(_REGISTRATION_TYPES)}', + ) + cls = _REGISTRATION_TYPES[exchange_type] agents = data.get('agents') if not isinstance(agents, dict): raise RuntimeError(f'Academy registration file is malformed: {path}') return { - name: RedisAgentRegistration( - agent_id=AgentId[Any].model_validate(agent_id), - ) + name: cls(agent_id=AgentId[Any].model_validate(agent_id)) for name, agent_id in agents.items() } @@ -71,7 +106,7 @@ async def wait_academy_registrations( *, run_token: str, timeout_s: float, -) -> dict[str, RedisAgentRegistration[Any]]: +) -> dict[str, AgentRegistration[Any]]: path = academy_registration_path(run_dir) deadline = time.monotonic() + timeout_s while True: diff --git a/tests/test_academy_compute_launcher.py b/tests/test_academy_compute_launcher.py index 10652443..abce39cd 100644 --- a/tests/test_academy_compute_launcher.py +++ b/tests/test_academy_compute_launcher.py @@ -53,5 +53,6 @@ def test_run_allocation_builds_single_mpiexec_command(tmp_path, monkeypatch) -> assert "mpi-daemon" in cmd assert "--campaign-config" in cmd assert 
"--lm-config" in cmd + assert "--exchange-type" in cmd assert "--chemgraph-repo-root" in cmd assert (tmp_path / "launch_command.txt").exists() diff --git a/tests/test_academy_exchange_registration.py b/tests/test_academy_exchange_registration.py new file mode 100644 index 00000000..4780cc21 --- /dev/null +++ b/tests/test_academy_exchange_registration.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from academy.exchange.hybrid import HybridAgentRegistration +from academy.exchange.local import LocalAgentRegistration +from academy.exchange.redis import RedisAgentRegistration +from academy.identifier import AgentId + +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.runtime.exchange import build_exchange_factory +from chemgraph.academy.runtime.registration import load_academy_registrations +from chemgraph.academy.runtime.registration import registration_payload +from chemgraph.academy.runtime.registration import write_academy_registrations + + +def _config(tmp_path: Path, exchange_type: str) -> ChemGraphDaemonConfig: + return ChemGraphDaemonConfig( + run_dir=tmp_path, + run_token='token-1', + agent_count=1, + campaign_config=tmp_path / 'campaign.json', + lm_config=tmp_path / 'lm.json', + max_decisions=1, + poll_timeout_s=1.0, + idle_timeout_s=1.0, + startup_timeout_s=1.0, + completion_timeout_s=1.0, + status_interval_s=1.0, + redis_host='localhost', + redis_port=6392, + redis_namespace='ns', + rank=0, + local_rank=0, + chemgraph_repo_root=tmp_path, + exchange_type=exchange_type, + ) + + +@pytest.mark.parametrize( + ('exchange_type', 'expected_class'), + [ + ('redis', 'RedisExchangeFactory'), + ('local', 'LocalExchangeFactory'), + ('hybrid', 'HybridExchangeFactory'), + ], +) +def test_build_exchange_factory_dispatches_by_config( + tmp_path, + exchange_type, + expected_class, +) -> None: + factory = build_exchange_factory(_config(tmp_path, exchange_type)) + + assert 
type(factory).__name__ == expected_class + + +def test_build_exchange_factory_rejects_unknown_exchange(tmp_path) -> None: + with pytest.raises(ValueError, match='Unsupported exchange type'): + build_exchange_factory(_config(tmp_path, 'bad')) + + +@pytest.mark.parametrize( + 'registration_cls', + [ + RedisAgentRegistration, + LocalAgentRegistration, + HybridAgentRegistration, + ], +) +def test_academy_registration_round_trips_by_exchange_type( + tmp_path, + registration_cls, +) -> None: + registration = registration_cls(agent_id=AgentId.new('agent-a')) + write_academy_registrations( + run_dir=tmp_path, + run_token='token-1', + registrations={'agent-a': registration}, + ) + + loaded = load_academy_registrations(tmp_path, run_token='token-1') + + assert isinstance(loaded['agent-a'], registration_cls) + assert loaded['agent-a'].agent_id == registration.agent_id + + +def test_registration_payload_rejects_mixed_exchange_types() -> None: + with pytest.raises(ValueError, match='mixed exchange types'): + registration_payload( + run_token='token-1', + registrations={ + 'redis-agent': RedisAgentRegistration( + agent_id=AgentId.new('redis-agent'), + ), + 'local-agent': LocalAgentRegistration( + agent_id=AgentId.new('local-agent'), + ), + }, + ) + + +def test_registration_payload_rejects_empty_registrations() -> None: + with pytest.raises(ValueError, match='at least one registration'): + registration_payload(run_token='token-1', registrations={}) From 99f073d9b2d69130f4e81c30463f7471088c86bf Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 16:29:32 -0500 Subject: [PATCH 070/119] fix(agent): emit llm decision events for tool calls --- src/chemgraph/agent/llm_agent.py | 33 +++++++++++++ tests/test_llm_agent.py | 84 +++++++++++++++++++++++++++++++- 2 files changed, 115 insertions(+), 2 deletions(-) diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 7791ed0e..1b08541c 100644 --- a/src/chemgraph/agent/llm_agent.py +++ 
b/src/chemgraph/agent/llm_agent.py @@ -353,6 +353,8 @@ def on_llm_end(self, response, **kwargs) -> None: if isinstance(usage, dict): payload["llm_output"] = usage self._emit("llm_call_finished", payload) + if tool_calls := _response_tool_calls(response): + self._emit("llm_decision", {"tool_calls": tool_calls}) def on_llm_error(self, error, **kwargs) -> None: self._emit("llm_call_failed", {"error": repr(error)}) @@ -395,6 +397,29 @@ def _message_tool_calls(message: Any) -> list[Any]: return calls if isinstance(calls, list) else [] +def _response_tool_calls(response: Any) -> list[dict[str, str | None]]: + try: + generations = getattr(response, "generations", None) or [] + tool_calls: list[dict[str, str | None]] = [] + for generation_group in generations: + for generation in generation_group or []: + message = getattr(generation, "message", None) + for call in _message_tool_calls(message): + name = _call_name(call) + if not name: + continue + tool_calls.append( + { + "name": name, + "id": _call_id(call), + }, + ) + return tool_calls + except Exception: # noqa: BLE001 - event extraction must not break runs. 
+ logger.debug("failed to extract llm_decision tool calls", exc_info=True) + return [] + + def _tool_message_name(message: Any) -> str | None: if isinstance(message, dict): name = message.get("name") @@ -420,6 +445,14 @@ def _call_name(call: Any) -> str | None: return str(name) if name else None +def _call_id(call: Any) -> str | None: + if isinstance(call, dict): + value = call.get("id") or call.get("tool_call_id") + else: + value = getattr(call, "id", None) or getattr(call, "tool_call_id", None) + return str(value) if value else None + + def _state_messages(state: Any) -> list[Any]: if isinstance(state, dict): messages = state.get("messages", []) diff --git a/tests/test_llm_agent.py b/tests/test_llm_agent.py index 32b72b87..18794a62 100644 --- a/tests/test_llm_agent.py +++ b/tests/test_llm_agent.py @@ -1,8 +1,10 @@ -import pytest import asyncio -from chemgraph.agent.llm_agent import ChemGraph +from types import SimpleNamespace from unittest.mock import Mock, patch + +import pytest from langchain_core.messages import AIMessage +from chemgraph.agent.llm_agent import ChemGraph, _TurnEventCallback @pytest.fixture @@ -41,3 +43,81 @@ def test_agent_query(mock_llm, tmp_path): assert response.content == "Test response" mock_llm.bind_tools.assert_called_once() mock_chain.invoke.assert_called_once() + + +def test_turn_event_callback_emits_llm_decision_for_tool_calls(): + events = [] + callback = _TurnEventCallback( + lambda event, payload: events.append((event, payload)), + "thread-1", + ) + response = SimpleNamespace( + llm_output={"token_usage": {"total_tokens": 12}}, + generations=[ + [ + SimpleNamespace( + message=SimpleNamespace( + tool_calls=[ + {"name": "molecule_name_to_smiles", "id": "call-1"}, + { + "function": {"name": "smiles_to_coordinate_file"}, + "tool_call_id": "call-2", + }, + ], + ), + ), + ], + ], + ) + + callback.on_llm_end(response) + + assert events == [ + ( + "llm_call_finished", + { + "thread_id": "thread-1", + "llm_output": {"token_usage": 
{"total_tokens": 12}}, + }, + ), + ( + "llm_decision", + { + "thread_id": "thread-1", + "tool_calls": [ + {"name": "molecule_name_to_smiles", "id": "call-1"}, + {"name": "smiles_to_coordinate_file", "id": "call-2"}, + ], + }, + ), + ] + + +def test_turn_event_callback_skips_llm_decision_without_tool_calls(): + events = [] + callback = _TurnEventCallback( + lambda event, payload: events.append((event, payload)), + "thread-1", + ) + + callback.on_llm_end( + SimpleNamespace(generations=[[SimpleNamespace(message=AIMessage(content="done"))]]), + ) + + assert [event for event, _payload in events] == ["llm_call_finished"] + + +def test_turn_event_callback_ignores_llm_decision_extraction_errors(): + class BrokenGenerationGroup: + def __iter__(self): + raise RuntimeError("broken response") + + events = [] + callback = _TurnEventCallback( + lambda event, payload: events.append((event, payload)), + "thread-1", + ) + + callback.on_llm_end(SimpleNamespace(generations=[BrokenGenerationGroup()])) + + assert [event for event, _payload in events] == ["llm_call_finished"] From 83bcd7a60c66c03e7116eea60935b62231c00b3f Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 17:24:50 -0500 Subject: [PATCH 071/119] fix(mcp): isolate hpc backend workers PicklingError: Can't pickle run_mace_singleArguments was triggered by dill recursing into the worker's module globals and hitting a FastMCP-generated Pydantic class. Fix by isolating backend worker functions in modules with no FastMCP state so Parsl can serialize import-clean workers for MACE, XANES, and gRASPA. hpc_misc_mcp.py was inspected and left unchanged because it has no backend worker callables. 
--- src/chemgraph/mcp/graspa_mcp_hpc.py | 52 +---------------- src/chemgraph/mcp/graspa_worker.py | 54 ++++++++++++++++++ src/chemgraph/mcp/mace_mcp_hpc.py | 74 +----------------------- src/chemgraph/mcp/mace_worker.py | 72 +++++++++++++++++++++++ src/chemgraph/mcp/xanes_mcp_hpc.py | 88 ++--------------------------- src/chemgraph/mcp/xanes_worker.py | 80 ++++++++++++++++++++++++++ tests/test_mcp.py | 29 +++++++++- 7 files changed, 239 insertions(+), 210 deletions(-) create mode 100644 src/chemgraph/mcp/graspa_worker.py create mode 100644 src/chemgraph/mcp/mace_worker.py create mode 100644 src/chemgraph/mcp/xanes_worker.py diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py index be7737a6..d1af7343 100644 --- a/src/chemgraph/mcp/graspa_mcp_hpc.py +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -14,7 +14,6 @@ """ import logging -import os from pathlib import Path from chemgraph.execution.base import TaskSpec @@ -24,6 +23,7 @@ resolve_structure_files, ) from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.graspa_worker import _graspa_worker, _ls_remote_files from chemgraph.mcp.transfer_tools import register_transfer_tools from chemgraph.schemas.graspa_schema import graspa_input_schema_ensemble @@ -64,59 +64,9 @@ ) -# ── Worker (runs on the backend) ─────────────────────────────────────── - - -def _graspa_worker(job: dict) -> dict: - """Execute a single gRASPA simulation on a backend worker.""" - from chemgraph.schemas.graspa_schema import graspa_input_schema - from chemgraph.tools.graspa_tools import run_graspa_core - - job = dict(job) - structure = job.pop("_structure_name", None) - temperature = job.get("temperature") - pressure = job.get("pressure") - - remote_file = job.pop("remote_structure_file", None) - if remote_file is not None: - job["input_structure_file"] = remote_file - if not os.path.isabs(job.get("output_result_file", "")): - job["output_result_file"] = os.path.join( - os.path.dirname(remote_file), - 
job.get("output_result_file", "raspa.log"), - ) - - params = graspa_input_schema(**job) - result = run_graspa_core(params) - - if isinstance(result, dict): - merged = { - "structure": structure, - "temperature": temperature, - "pressure": pressure, - **result, - } - merged.setdefault("status", "success") - return merged - return { - "structure": structure, - "temperature": temperature, - "pressure": pressure, - "result": result, - "status": "success", - } - - # ── Ensemble fanout ──────────────────────────────────────────────────── -def _ls_remote_files(path: str) -> list[str]: - """Backend-side helper: list non-directory entries in *path*.""" - return sorted( - f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) - ) - - def _expand_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: """Server-side expansion of an ensemble request into per-job dicts. diff --git a/src/chemgraph/mcp/graspa_worker.py b/src/chemgraph/mcp/graspa_worker.py new file mode 100644 index 00000000..2e26cd5e --- /dev/null +++ b/src/chemgraph/mcp/graspa_worker.py @@ -0,0 +1,54 @@ +"""Backend worker functions for gRASPA MCP tools. + +This module intentionally contains no FastMCP/CGFastMCP objects or tool +decorators, keeping worker functions safe for Parsl/dill serialization. 
+""" + +import os + + +def _graspa_worker(job: dict) -> dict: + """Execute a single gRASPA simulation on a backend worker.""" + from chemgraph.schemas.graspa_schema import graspa_input_schema + from chemgraph.tools.graspa_tools import run_graspa_core + + job = dict(job) + structure = job.pop("_structure_name", None) + temperature = job.get("temperature") + pressure = job.get("pressure") + + remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "raspa.log"), + ) + + params = graspa_input_schema(**job) + result = run_graspa_core(params) + + if isinstance(result, dict): + merged = { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + **result, + } + merged.setdefault("status", "success") + return merged + return { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + "result": result, + "status": "success", + } + + +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) + ) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index b93dcd4c..8f1c715b 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -27,12 +27,13 @@ resolve_structure_files, ) from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.mace_worker import _ls_remote_files, _mace_worker from chemgraph.mcp.transfer_tools import register_transfer_tools from chemgraph.schemas.mace_parsl_schema import ( mace_input_schema, mace_input_schema_ensemble, ) -from chemgraph.tools.parsl_tools import extract_output_json, run_mace_core +from chemgraph.tools.parsl_tools import extract_output_json logger = 
logging.getLogger(__name__) @@ -73,70 +74,6 @@ ) -# ── Worker (runs on the backend) ─────────────────────────────────────── - - -def _mace_worker(job: dict) -> dict: - """Execute a single MACE simulation on a backend worker. - - Accepts a *job dict* (not the schema) so the pre-submit hook can - attach transport keys ``inline_structure`` / ``remote_structure_file`` - before submission. - """ - import json - import tempfile - - job = dict(job) - - # Pre-staged remote file: use the path directly on the worker FS. - remote_file = job.pop("remote_structure_file", None) - if remote_file is not None: - job["input_structure_file"] = remote_file - if not os.path.isabs(job.get("output_result_file", "")): - job["output_result_file"] = os.path.join( - os.path.dirname(remote_file), - job.get("output_result_file", "output.json"), - ) - - # Inline structure: materialise on the worker's filesystem. - inline = job.pop("inline_structure", None) - if inline is not None: - from ase import Atoms - from ase.io import write as ase_write - - atoms = Atoms( - numbers=inline["numbers"], - positions=inline["positions"], - cell=inline.get("cell"), - pbc=inline.get("pbc"), - ) - tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") - xyz_path = os.path.join(tmpdir, "structure.xyz") - ase_write(xyz_path, atoms) - job["input_structure_file"] = xyz_path - if not os.path.isabs(job.get("output_result_file", "")): - job["output_result_file"] = os.path.join( - tmpdir, job.get("output_result_file", "output.json") - ) - - output_file = job.get("output_result_file") - if output_file: - os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True) - - params = mace_input_schema(**job) - result = run_mace_core(params) - - # When inline, embed full output so the caller doesn't need to read - # a file on the remote filesystem to recover the results. 
- if inline is not None and isinstance(result, dict): - out_file = job.get("output_result_file", "") - if os.path.isfile(out_file): - with open(out_file) as fh: - result["full_output"] = json.load(fh) - - return result - - # ── Pre-submit transport hook ────────────────────────────────────────── @@ -214,13 +151,6 @@ def run_mace_single(params: mace_input_schema) -> dict: # ── Ensemble fanout ──────────────────────────────────────────────────── -def _ls_remote_files(path: str) -> list[str]: - """Backend-side helper: list non-directory entries in *path*.""" - return sorted( - f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) - ) - - def _expand_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: """Server-side expansion of an ensemble request into per-file jobs. diff --git a/src/chemgraph/mcp/mace_worker.py b/src/chemgraph/mcp/mace_worker.py new file mode 100644 index 00000000..b4b00b23 --- /dev/null +++ b/src/chemgraph/mcp/mace_worker.py @@ -0,0 +1,72 @@ +"""Backend worker functions for MACE MCP tools. + +This module intentionally contains no FastMCP/CGFastMCP objects or tool +decorators. Parsl/dill serializes worker functions by walking their module +globals, so backend workers must live outside modules that contain FastMCP's +runtime-generated argument classes. 
+""" + +import os + + +def _mace_worker(job: dict) -> dict: + """Execute a single MACE simulation on a backend worker.""" + import json + import tempfile + + from chemgraph.schemas.mace_parsl_schema import mace_input_schema + from chemgraph.tools.parsl_tools import run_mace_core + + job = dict(job) + + remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "output.json"), + ) + + inline = job.pop("inline_structure", None) + if inline is not None: + from ase import Atoms + from ase.io import write as ase_write + + atoms = Atoms( + numbers=inline["numbers"], + positions=inline["positions"], + cell=inline.get("cell"), + pbc=inline.get("pbc"), + ) + tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") + xyz_path = os.path.join(tmpdir, "structure.xyz") + ase_write(xyz_path, atoms) + job["input_structure_file"] = xyz_path + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + tmpdir, + job.get("output_result_file", "output.json"), + ) + + output_file = job.get("output_result_file") + if output_file: + os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True) + + params = mace_input_schema(**job) + result = run_mace_core(params) + + if inline is not None and isinstance(result, dict): + out_file = job.get("output_result_file", "") + if os.path.isfile(out_file): + with open(out_file, encoding="utf-8") as fh: + result["full_output"] = json.load(fh) + + return result + + +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) + ) diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 
8583ae65..c9394c31 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -17,16 +17,15 @@ """ import logging -import subprocess from pathlib import Path from chemgraph.execution.config import get_transfer_manager from chemgraph.execution.utils import resolve_structure_files from chemgraph.mcp.cg_fastmcp import CGFastMCP from chemgraph.mcp.transfer_tools import register_transfer_tools +from chemgraph.mcp.xanes_worker import _xanes_ensemble_worker, run_xanes_single from chemgraph.schemas.xanes_schema import ( mp_query_schema, - xanes_input_schema, xanes_input_schema_ensemble, ) @@ -66,96 +65,17 @@ ) -# ── Single-structure tool ────────────────────────────────────────────── - - -def _xanes_single_worker(params: xanes_input_schema) -> dict: - """Run a single FDMNES calculation on a backend worker.""" - from chemgraph.tools.xanes_tools import run_xanes_core - - result = run_xanes_core(params) - if isinstance(result, dict): - result.setdefault("status", "success") - return result - return {"status": "success", "result": result} - - -@mcp.tool( +mcp.tool( name="run_xanes_single", description="Run a single XANES/FDMNES calculation for one input structure.", +)( + run_xanes_single ) -def run_xanes_single(params: xanes_input_schema): - """Run a single FDMNES calculation using the core engine. - - The CGFastMCP wrapper submits this call to the configured backend; - the body is the direct-call fallback when no backend is active. - """ - return _xanes_single_worker(params) # ── Ensemble fanout ──────────────────────────────────────────────────── -def _xanes_ensemble_worker(item: dict) -> dict: - """Execute one prepared FDMNES run on the backend. - - The expander has already written ``input_fdmnes.txt`` (or the - equivalent) into ``item['run_dir']``; this worker runs the binary - via subprocess and then extracts convergence data. 
- """ - from chemgraph.tools.xanes_tools import extract_conv - - run_dir = item["run_dir"] - fdmnes_exe = item["fdmnes_exe"] - meta = { - "structure": item.get("structure"), - "run_dir": run_dir, - "z_absorber": item.get("z_absorber"), - } - - stdout_path = Path(run_dir) / "fdmnes_stdout.txt" - stderr_path = Path(run_dir) / "fdmnes_stderr.txt" - try: - with open(stdout_path, "w") as out, open(stderr_path, "w") as err: - proc = subprocess.run( - [fdmnes_exe], - cwd=run_dir, - stdout=out, - stderr=err, - check=False, - ) - if proc.returncode != 0: - return { - **meta, - "status": "failure", - "error_type": "FDMNESExitCode", - "message": f"FDMNES exited with code {proc.returncode}", - "returncode": proc.returncode, - } - except Exception as e: - return { - **meta, - "status": "failure", - "error_type": type(e).__name__, - "message": f"FDMNES launch failed: {e}", - } - - try: - conv_data = extract_conv(run_dir) - return { - **meta, - "status": "success", - "n_conv_files": len(conv_data), - } - except Exception as e: - return { - **meta, - "status": "failure", - "error_type": type(e).__name__, - "message": f"Post-processing failed: {e}", - } - - def _expand_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: """Server-side expansion: prepare per-structure run dirs and return one item per structure for the worker to execute.""" diff --git a/src/chemgraph/mcp/xanes_worker.py b/src/chemgraph/mcp/xanes_worker.py new file mode 100644 index 00000000..ce15cd6a --- /dev/null +++ b/src/chemgraph/mcp/xanes_worker.py @@ -0,0 +1,80 @@ +"""Backend worker functions for XANES MCP tools. + +This module intentionally contains no FastMCP/CGFastMCP objects or tool +decorators, keeping worker functions safe for Parsl/dill serialization. 
+""" + +import subprocess +from pathlib import Path + +from chemgraph.schemas.xanes_schema import xanes_input_schema + + +def run_xanes_single(params: xanes_input_schema) -> dict: + """Run a single FDMNES calculation on a backend worker.""" + from chemgraph.tools.xanes_tools import run_xanes_core + + result = run_xanes_core(params) + if isinstance(result, dict): + result.setdefault("status", "success") + return result + return {"status": "success", "result": result} + + +def _xanes_ensemble_worker(item: dict) -> dict: + """Execute one prepared FDMNES run on the backend.""" + from chemgraph.tools.xanes_tools import extract_conv + + run_dir = item["run_dir"] + fdmnes_exe = item["fdmnes_exe"] + meta = { + "structure": item.get("structure"), + "run_dir": run_dir, + "z_absorber": item.get("z_absorber"), + } + + stdout_path = Path(run_dir) / "fdmnes_stdout.txt" + stderr_path = Path(run_dir) / "fdmnes_stderr.txt" + try: + with open(stdout_path, "w", encoding="utf-8") as out, open( + stderr_path, + "w", + encoding="utf-8", + ) as err: + proc = subprocess.run( + [fdmnes_exe], + cwd=run_dir, + stdout=out, + stderr=err, + check=False, + ) + if proc.returncode != 0: + return { + **meta, + "status": "failure", + "error_type": "FDMNESExitCode", + "message": f"FDMNES exited with code {proc.returncode}", + "returncode": proc.returncode, + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"FDMNES launch failed: {e}", + } + + try: + conv_data = extract_conv(run_dir) + return { + **meta, + "status": "success", + "n_conv_files": len(conv_data), + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"Post-processing failed: {e}", + } diff --git a/tests/test_mcp.py b/tests/test_mcp.py index b4615871..51a42f4d 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -48,7 +48,8 @@ def fanout(params: dict) -> list[dict]: def 
test_mace_worker_creates_inline_output_parent(monkeypatch): from ase import Atoms - from chemgraph.mcp import mace_mcp_hpc + from chemgraph.mcp import mace_worker + from chemgraph.tools import parsl_tools from chemgraph.tools.ase_core import atoms_to_atomsdata atoms = Atoms(numbers=[1, 1], positions=[[0, 0, 0], [0, 0, 0.74]]) @@ -60,9 +61,9 @@ def fake_run_mace_core(params): output_path.write_text('{"ok": true}', encoding="utf-8") return {"status": "success"} - monkeypatch.setattr(mace_mcp_hpc, "run_mace_core", fake_run_mace_core) + monkeypatch.setattr(parsl_tools, "run_mace_core", fake_run_mace_core) - result = mace_mcp_hpc._mace_worker( + result = mace_worker._mace_worker( { "inline_structure": atoms_to_atomsdata(atoms).model_dump(), "output_result_file": output_file, @@ -76,6 +77,28 @@ def fake_run_mace_core(params): assert result["full_output"] == {"ok": True} +def test_hpc_worker_functions_are_dill_picklable(): + dill = pytest.importorskip("dill") + + from chemgraph.mcp.graspa_worker import ( + _graspa_worker, + _ls_remote_files as _graspa_ls_remote_files, + ) + from chemgraph.mcp.mace_worker import _ls_remote_files as _mace_ls_remote_files + from chemgraph.mcp.mace_worker import _mace_worker + from chemgraph.mcp.xanes_worker import _xanes_ensemble_worker, run_xanes_single + + for worker in ( + _mace_worker, + _mace_ls_remote_files, + run_xanes_single, + _xanes_ensemble_worker, + _graspa_worker, + _graspa_ls_remote_files, + ): + dill.dumps(worker) + + @pytest.mark.asyncio async def test_split_cif_dataset(tmp_path): """Test splitting a dataset of CIF files.""" From 4441d98533a5b595c43955c1333209763df8c0e4 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 18:39:56 -0500 Subject: [PATCH 072/119] refactor(academy): move packaged campaigns out of examples --- .../README.md | 35 +++++++++++++++++++ .../notes.md | 22 ++++++++++++ pyproject.toml | 3 +- .../{examples => campaigns}/__init__.py | 34 +++++++++--------- .../campaign.json} | 0 
.../data/mace_screening_20_smiles.json | 0 .../lm_config.json} | 0 .../prompt_profiles/default.json | 0 src/chemgraph/academy/core/campaign.py | 4 +-- .../academy/runtime/compute_launcher.py | 12 +++---- src/chemgraph/academy/runtime/daemon.py | 4 +-- .../academy/runtime/dashboard_launcher.py | 2 +- src/chemgraph/cli/main.py | 6 ++-- 13 files changed, 90 insertions(+), 32 deletions(-) create mode 100644 examples/academy/example-002-mace-ensemble-screening/README.md create mode 100644 examples/academy/example-002-mace-ensemble-screening/notes.md rename src/chemgraph/academy/{examples => campaigns}/__init__.py (52%) rename src/chemgraph/academy/{examples/example-002-mace-ensemble-screening/campaign.jsonc => campaigns/example-002-mace-ensemble-screening/campaign.json} (100%) rename src/chemgraph/academy/{examples => campaigns}/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json (100%) rename src/chemgraph/academy/{examples/example-002-mace-ensemble-screening/lm_config.template.json => campaigns/example-002-mace-ensemble-screening/lm_config.json} (100%) rename src/chemgraph/academy/{examples => campaigns}/example-002-mace-ensemble-screening/prompt_profiles/default.json (100%) diff --git a/examples/academy/example-002-mace-ensemble-screening/README.md b/examples/academy/example-002-mace-ensemble-screening/README.md new file mode 100644 index 00000000..bd6f7104 --- /dev/null +++ b/examples/academy/example-002-mace-ensemble-screening/README.md @@ -0,0 +1,35 @@ +# Example 002: MACE Ensemble Screening + +This example demonstrates five persistent ChemGraph Academy logical agents +running under MPI: + +```text +coordinator-agent +structure-agent-a +structure-agent-b +mace-agent +assessment-agent +``` + +The coordinator delegates 20 SMILES candidates, structure agents generate XYZ +files, the MACE agent runs an ensemble energy screen, and the assessment agent +summarizes readiness/ranking evidence. 
+ +The campaign assets are packaged under: + +```text +src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/ +``` + +Run it by campaign name: + +```bash +chemgraph academy run-compute \ + --system aurora \ + --run-id aurora-mace-ensemble-screening-001 \ + --campaign mace-ensemble-screening-20 \ + --lm-user +``` + +See `notes.md` for the high-level architecture notes. The internal E2E user +guide is intentionally not stored in this public example directory. diff --git a/examples/academy/example-002-mace-ensemble-screening/notes.md b/examples/academy/example-002-mace-ensemble-screening/notes.md new file mode 100644 index 00000000..bff829a5 --- /dev/null +++ b/examples/academy/example-002-mace-ensemble-screening/notes.md @@ -0,0 +1,22 @@ +# Notes + +This root example directory is for user-facing explanation only. The CLI loads +the actual campaign from package data so installed ChemGraph environments can +run the same campaign without relying on a source checkout's root `examples/` +directory. + +Packaged assets: + +```text +src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/ + campaign.json + lm_config.json + prompt_profiles/ + data/ + models/ +``` + +The campaign declares MCP server subprocesses for general ChemGraph tools, MACE +screening, and HPC utility inspection. The Academy runtime places one logical +agent per MPI rank, launches the declared MCP servers for each agent, and uses +Academy exchange handles for peer communication. 
diff --git a/pyproject.toml b/pyproject.toml index 7ea2ed94..f6db21fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,10 +103,11 @@ where = ["src/"] [tool.setuptools.package-data] "chemgraph.eval" = ["data/*.json"] -"chemgraph.academy.examples" = [ +"chemgraph.academy.campaigns" = [ "example-*/*.json", "example-*/*.jsonc", "example-*/data/*.json", + "example-*/models/*", "example-*/prompt_profiles/*.json", ] "chemgraph.academy.runtime.profiles" = ["*.json"] diff --git a/src/chemgraph/academy/examples/__init__.py b/src/chemgraph/academy/campaigns/__init__.py similarity index 52% rename from src/chemgraph/academy/examples/__init__.py rename to src/chemgraph/academy/campaigns/__init__.py index 52a4ba45..2b01a733 100644 --- a/src/chemgraph/academy/examples/__init__.py +++ b/src/chemgraph/academy/campaigns/__init__.py @@ -7,18 +7,18 @@ EXAMPLE_002 = 'example-002-mace-ensemble-screening' -BUILTIN_CAMPAIGNS = { - 'mace-ensemble-screening-20': f'{EXAMPLE_002}/campaign.jsonc', +CAMPAIGNS = { + 'mace-ensemble-screening-20': f'{EXAMPLE_002}/campaign.json', } -BUILTIN_LM_CONFIG_TEMPLATES = { - 'argo-gpt54-mace-template': f'{EXAMPLE_002}/lm_config.template.json', +LM_CONFIG_TEMPLATES = { + 'argo-gpt54-mace-template': f'{EXAMPLE_002}/lm_config.json', } @dataclasses.dataclass(frozen=True) class CampaignLaunchDefaults: - """Runtime defaults for a built-in ChemGraph Academy campaign.""" + """Runtime defaults for a packaged ChemGraph Academy campaign.""" lm_config_template: str agent_count: int @@ -26,7 +26,7 @@ class CampaignLaunchDefaults: max_decisions: int -BUILTIN_CAMPAIGN_LAUNCH_DEFAULTS = { +CAMPAIGN_LAUNCH_DEFAULTS = { 'mace-ensemble-screening-20': CampaignLaunchDefaults( lm_config_template='argo-gpt54-mace-template', agent_count=5, @@ -36,36 +36,36 @@ class CampaignLaunchDefaults: } -def _resolve_builtin( +def _resolve_campaign_asset( path_or_name: str | Path, - builtins: dict[str, str], + known_assets: dict[str, str], ) -> Path: value = str(path_or_name) path = 
Path(value) if path.exists(): return path.resolve() - relative = builtins.get(value) + relative = known_assets.get(value) if relative is None: return path return Path(str(resources.files(__package__).joinpath(relative))) -def resolve_builtin_campaign(path_or_name: str | Path) -> Path: - return _resolve_builtin(path_or_name, BUILTIN_CAMPAIGNS) +def resolve_campaign(path_or_name: str | Path) -> Path: + return _resolve_campaign_asset(path_or_name, CAMPAIGNS) -def resolve_builtin_lm_config_template(path_or_name: str | Path) -> Path: - return _resolve_builtin(path_or_name, BUILTIN_LM_CONFIG_TEMPLATES) +def resolve_lm_config_template(path_or_name: str | Path) -> Path: + return _resolve_campaign_asset(path_or_name, LM_CONFIG_TEMPLATES) -def list_builtin_campaigns() -> list[str]: - return sorted(BUILTIN_CAMPAIGNS) +def list_campaigns() -> list[str]: + return sorted(CAMPAIGNS) def campaign_launch_defaults(campaign: str) -> CampaignLaunchDefaults: try: - return BUILTIN_CAMPAIGN_LAUNCH_DEFAULTS[campaign] + return CAMPAIGN_LAUNCH_DEFAULTS[campaign] except KeyError as exc: raise KeyError( - f'No built-in launch defaults for campaign {campaign!r}', + f'No launch defaults for campaign {campaign!r}', ) from exc diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/campaign.jsonc b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.json similarity index 100% rename from src/chemgraph/academy/examples/example-002-mace-ensemble-screening/campaign.jsonc rename to src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.json diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json similarity index 100% rename from src/chemgraph/academy/examples/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json rename to 
src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/lm_config.template.json b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/lm_config.json similarity index 100% rename from src/chemgraph/academy/examples/example-002-mace-ensemble-screening/lm_config.template.json rename to src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/lm_config.json diff --git a/src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/prompt_profiles/default.json similarity index 100% rename from src/chemgraph/academy/examples/example-002-mace-ensemble-screening/prompt_profiles/default.json rename to src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/prompt_profiles/default.json diff --git a/src/chemgraph/academy/core/campaign.py b/src/chemgraph/academy/core/campaign.py index a08e3ad4..d2fcd14c 100644 --- a/src/chemgraph/academy/core/campaign.py +++ b/src/chemgraph/academy/core/campaign.py @@ -6,7 +6,7 @@ from collections.abc import Mapping from typing import Any -from chemgraph.academy.examples import resolve_builtin_campaign +from chemgraph.academy.campaigns import resolve_campaign from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -217,7 +217,7 @@ def _resolve_resource_spec( def load_campaign(path: str | pathlib.Path) -> ChemGraphCampaign: - path = resolve_builtin_campaign(path) + path = resolve_campaign(path) data = _load_jsonc(path) _reject_removed_campaign_fields(data, campaign_path=path) prompt_profile = _resolve_campaign_relative_path( diff --git a/src/chemgraph/academy/runtime/compute_launcher.py b/src/chemgraph/academy/runtime/compute_launcher.py index d96ee263..3ba9ad41 100644 --- a/src/chemgraph/academy/runtime/compute_launcher.py +++ 
b/src/chemgraph/academy/runtime/compute_launcher.py @@ -12,9 +12,9 @@ from pathlib import Path from typing import Any -from chemgraph.academy.examples import campaign_launch_defaults -from chemgraph.academy.examples import resolve_builtin_campaign -from chemgraph.academy.examples import resolve_builtin_lm_config_template +from chemgraph.academy.campaigns import campaign_launch_defaults +from chemgraph.academy.campaigns import resolve_campaign +from chemgraph.academy.campaigns import resolve_lm_config_template from chemgraph.academy.runtime.profiles import list_builtin_system_profiles from chemgraph.academy.runtime.profiles import load_system_profile from chemgraph.academy.runtime.profiles.system import SystemProfile @@ -53,7 +53,7 @@ class AllocationPlan: def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser( description=( - "Run a built-in ChemGraph Academy campaign inside the current " + "Run a ChemGraph Academy campaign inside the current " "HPC compute allocation." 
), ) @@ -154,7 +154,7 @@ def _write_lm_config( lm_user: str | None, max_tokens: int | None, ) -> Path: - template_path = resolve_builtin_lm_config_template(template_name) + template_path = resolve_lm_config_template(template_name) data = json.loads(template_path.read_text(encoding="utf-8")) if not isinstance(data, dict): raise RuntimeError(f"LM template must contain a JSON object: {template_path}") @@ -224,7 +224,7 @@ def prepare_compute_launch(args: argparse.Namespace) -> AllocationPlan: max_decisions = args.max_decisions or defaults.max_decisions redis_port = args.redis_port or profile.redis_port - campaign_config = resolve_builtin_campaign(args.campaign) + campaign_config = resolve_campaign(args.campaign) if not campaign_config.exists(): campaign_config = Path(args.campaign).resolve() diff --git a/src/chemgraph/academy/runtime/daemon.py b/src/chemgraph/academy/runtime/daemon.py index 15094b3e..cff65bd8 100644 --- a/src/chemgraph/academy/runtime/daemon.py +++ b/src/chemgraph/academy/runtime/daemon.py @@ -26,7 +26,7 @@ from chemgraph.academy.core.campaign import resolve_campaign_resources from chemgraph.academy.core.campaign import selected_agent from chemgraph.academy.core.campaign import validate_campaign -from chemgraph.academy.examples import resolve_builtin_campaign +from chemgraph.academy.campaigns import resolve_campaign from chemgraph.academy.runtime.mpi import append_system_trace from chemgraph.academy.runtime.mpi import local_rank_from_env from chemgraph.academy.runtime.mpi import placement_payload @@ -220,7 +220,7 @@ def parse_args() -> argparse.Namespace: def config_from_args(args: argparse.Namespace) -> ChemGraphDaemonConfig: run_dir = pathlib.Path(args.run_dir).resolve() - resolved_campaign = resolve_builtin_campaign(args.campaign_config) + resolved_campaign = resolve_campaign(args.campaign_config) campaign_config = ( resolved_campaign.resolve() if resolved_campaign.exists() diff --git a/src/chemgraph/academy/runtime/dashboard_launcher.py 
b/src/chemgraph/academy/runtime/dashboard_launcher.py index b38ac733..1869a1e1 100644 --- a/src/chemgraph/academy/runtime/dashboard_launcher.py +++ b/src/chemgraph/academy/runtime/dashboard_launcher.py @@ -10,7 +10,7 @@ from pathlib import Path from chemgraph.academy.dashboard import serve_dashboard -from chemgraph.academy.examples import campaign_launch_defaults +from chemgraph.academy.campaigns import campaign_launch_defaults from chemgraph.academy.runtime.profiles import list_builtin_system_profiles from chemgraph.academy.runtime.profiles import load_system_profile from chemgraph.academy.runtime.profiles.system import SystemProfile diff --git a/src/chemgraph/cli/main.py b/src/chemgraph/cli/main.py index 94af80fc..3fb6c0b4 100644 --- a/src/chemgraph/cli/main.py +++ b/src/chemgraph/cli/main.py @@ -287,7 +287,7 @@ def create_argument_parser() -> argparse.ArgumentParser: academy_sub.add_parser( "campaigns", - help="List built-in ChemGraph Academy campaign specs.", + help="List ChemGraph Academy campaign specs.", ) # ---- Legacy fallback args ------------------------------------------- @@ -580,9 +580,9 @@ def _handle_academy(args: argparse.Namespace) -> None: sys.exit(code) return if command == "campaigns": - from chemgraph.academy.examples import list_builtin_campaigns + from chemgraph.academy.campaigns import list_campaigns - for name in list_builtin_campaigns(): + for name in list_campaigns(): console.print(name) return console.print( From f5d2f4d9fabe4f1ab6d5172ba788592730c2ef6b Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Wed, 10 Jun 2026 18:41:06 -0500 Subject: [PATCH 073/119] feat(cli): --trace-dir option for traditional dashboard view --- src/chemgraph/agent/llm_agent.py | 3 + src/chemgraph/cli/commands.py | 2 + src/chemgraph/cli/main.py | 53 +++++++++++++- src/chemgraph/cli/trace.py | 114 +++++++++++++++++++++++++++++++ 4 files changed, 169 insertions(+), 3 deletions(-) create mode 100644 src/chemgraph/cli/trace.py diff --git 
a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 1b08541c..6de5f7fc 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -747,6 +747,7 @@ def __init__( human_input_handler: Optional[Callable[[str], str]] = None, human_supervised: bool = False, terminal_tool_names: Collection[str] = (), + on_event: Optional[EventCallback] = None, ): """Initialize a ChemGraph workflow instance. @@ -950,6 +951,7 @@ def __init__( self.human_input_handler = human_input_handler self.human_supervised = human_supervised self.terminal_tool_names = tuple(terminal_tool_names) + self.on_event = on_event self._last_run_state: dict[str, Any] | None = None # When human supervision is disabled and the caller is using the @@ -1468,6 +1470,7 @@ async def run( thread_id=thread_id, terminal_tool_names=self.terminal_tool_names, human_supervised=self.human_supervised, + on_event=self.on_event, ) self._last_run_state = result.state self._save_messages_to_store(result.state, query) diff --git a/src/chemgraph/cli/commands.py b/src/chemgraph/cli/commands.py index abbd0fff..70934046 100644 --- a/src/chemgraph/cli/commands.py +++ b/src/chemgraph/cli/commands.py @@ -177,6 +177,7 @@ def initialize_agent( verbose: bool = False, human_supervised: bool = False, tools: Optional[list] = None, + on_event: Optional[Any] = None, ) -> Any: """Initialize a ChemGraph agent with progress indication. 
@@ -280,6 +281,7 @@ def _create_agent() -> Any: structured_output=structured_output, human_supervised=human_supervised, tools=tools, + on_event=on_event, ) try: diff --git a/src/chemgraph/cli/main.py b/src/chemgraph/cli/main.py index 3fb6c0b4..5d1222d9 100644 --- a/src/chemgraph/cli/main.py +++ b/src/chemgraph/cli/main.py @@ -173,6 +173,16 @@ def _add_run_args(parser: argparse.ArgumentParser) -> None: default="ChemGraph General Tools", help="Display name for the MCP server connection (default: 'ChemGraph General Tools')", ) + parser.add_argument( + "--trace-dir", + type=str, + default=None, + help=( + "Write per-run events to this directory so the run is viewable " + "via 'chemgraph dashboard -- --run-dir '. " + "Currently only effective for single-agent workflows." + ), + ) def create_argument_parser() -> argparse.ArgumentParser: @@ -493,6 +503,33 @@ def _handle_run(args: argparse.Namespace) -> None: # Show banner console.print(create_banner()) + # ---- Optional run trace for the local dashboard -------------------- + trace = None + trace_dir = getattr(args, "trace_dir", None) or config.get("trace_dir") + if trace_dir: + from pathlib import Path + + from chemgraph.cli.trace import CLIRunTrace + + if args.workflow != "single_agent": + console.print( + "[yellow]--trace-dir is currently only effective for the " + "single_agent workflow; events will not be written for " + f"{args.workflow!r}.[/yellow]" + ) + else: + trace = CLIRunTrace( + Path(trace_dir), + model_name=args.model, + workflow_type=args.workflow, + query=args.query, + ) + trace.start() + console.print( + f"[dim]Tracing run to {trace.trace_dir}. 
" + f"View with: chemgraph dashboard -- --run-dir {trace.trace_dir}[/dim]" + ) + # Initialize agent agent = initialize_agent( args.model, @@ -506,18 +543,28 @@ def _handle_run(args: argparse.Namespace) -> None: verbose=(args.verbose > 0), human_supervised=args.human_supervised, tools=mcp_tools, + on_event=trace.on_event if trace else None, ) if not agent: + if trace is not None: + trace.finish(status="failed", error="agent_initialization_failed") sys.exit(1) # Execute query console.print(f"[bold blue]Query:[/bold blue] {args.query}") if args.resume: console.print(f"[bold blue]Resuming from:[/bold blue] {args.resume}") - result = run_query( - agent, args.query, verbose=(args.verbose > 0), resume_from=args.resume - ) + try: + result = run_query( + agent, args.query, verbose=(args.verbose > 0), resume_from=args.resume + ) + except Exception: + if trace is not None: + trace.finish(status="failed") + raise + if trace is not None: + trace.finish(status="completed") if result: format_response(result, verbose=(args.verbose > 0)) diff --git a/src/chemgraph/cli/trace.py b/src/chemgraph/cli/trace.py new file mode 100644 index 00000000..47d5bf5e --- /dev/null +++ b/src/chemgraph/cli/trace.py @@ -0,0 +1,114 @@ +"""Trace writer for traditional ChemGraph CLI runs. + +Bridges the `run_turn` event callback into the dashboard's on-disk +schema (`events.jsonl` + `status.json` + `manifest.json`), so the +existing ``chemgraph dashboard`` browser UI can render a single-agent +ChemGraph run without going through the Academy daemon path. +""" + +from __future__ import annotations + +import time +from pathlib import Path + +from chemgraph.academy.observability.event_log import EventLog +from chemgraph.academy.observability.run_files import write_json_atomic + + +_AGENT_ID = "chemgraph" +_AGENT_ROLE = "single_agent" + + +class CLIRunTrace: + """Writer for a single traditional ChemGraph run. 
+ + Produces the on-disk layout the dashboard expects: + + :: + + /events.jsonl + /status.json + /manifest.json + + The ``status.json.mode`` field is ``"chemgraph_workflow"`` so the + dashboard renders the per-agent workflow inspector (the "inner tab" + you'd see if you clicked a logical-agent node in an Academy run). + """ + + def __init__( + self, + trace_dir: Path, + *, + run_id: str | None = None, + model_name: str | None = None, + workflow_type: str | None = None, + query: str | None = None, + ) -> None: + self.trace_dir = Path(trace_dir) + self.run_id = run_id or self.trace_dir.name + self.model_name = model_name + self.workflow_type = workflow_type + self.query = query + self._log = EventLog(self.trace_dir / "events.jsonl") + + def start(self) -> None: + """Initialise the run directory and write the static metadata.""" + self.trace_dir.mkdir(parents=True, exist_ok=True) + write_json_atomic( + self.trace_dir / "manifest.json", + { + "mode": "chemgraph_workflow", + "run_id": self.run_id, + "model": self.model_name, + "workflow_type": self.workflow_type, + }, + ) + self._write_status() + self._log.emit( + "run_started", + run_id=self.run_id, + agent_id=_AGENT_ID, + role=_AGENT_ROLE, + payload={ + "model": self.model_name, + "workflow_type": self.workflow_type, + "query": self.query, + }, + ) + + def finish(self, *, status: str, error: str | None = None) -> None: + """Mark the run as completed and refresh ``status.json``.""" + self._log.emit( + "run_finished", + run_id=self.run_id, + agent_id=_AGENT_ID, + role=_AGENT_ROLE, + payload={"status": status, "error": error} if error else {"status": status}, + ) + self._write_status() + + def on_event(self, event: str, payload: dict) -> None: + """Callback handed to :func:`chemgraph.agent.llm_agent.run_turn`.""" + self._log.emit( + event, # type: ignore[arg-type] + run_id=self.run_id, + agent_id=_AGENT_ID, + role=_AGENT_ROLE, + payload=payload, + ) + + def _write_status(self) -> None: + write_json_atomic( + 
self.trace_dir / "status.json", + { + "mode": "chemgraph_workflow", + "updated": time.time(), + "agents": [ + { + "agent_id": _AGENT_ID, + "agent_name": _AGENT_ID, + "role": _AGENT_ROLE, + }, + ], + }, + ) From 41a3f2242bf2663cb5fdf9d9345974935636cef9 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 11 Jun 2026 11:06:08 -0500 Subject: [PATCH 074/119] docs(example-002): add sanitized e2e guide --- .../e2e_guide.md | 294 ++++++++++++++++++ 1 file changed, 294 insertions(+) create mode 100644 examples/academy/example-002-mace-ensemble-screening/e2e_guide.md diff --git a/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md new file mode 100644 index 00000000..8b9e9a1c --- /dev/null +++ b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md @@ -0,0 +1,294 @@ +# Example 002 E2E Guide + +This guide runs the `mace-ensemble-screening-20` ChemGraph Academy campaign on +Aurora or Polaris. The campaign starts five persistent logical agents under MPI: + +```text +coordinator-agent +structure-agent-a +structure-agent-b +mace-agent +assessment-agent +``` + +The coordinator delegates 20 SMILES candidates, structure agents generate XYZ +files, the MACE agent runs an ensemble energy screen, and the assessment agent +summarizes readiness/ranking evidence. + +## Configure Paths + +Set these values in each terminal before copying the commands below: + +```bash +export ALCF_PROJECT= +export ALCF_USER= +export ALCF_LOGIN= +export ARGO_USER= + +export LOCAL_CHEMGRAPH= +``` + +For Aurora: + +```bash +export ALCF_SYSTEM=aurora +export ALCF_HOST=aurora.alcf.anl.gov +export REMOTE_ROOT=/flare/$ALCF_PROJECT/$ALCF_USER +``` + +For Polaris: + +```bash +export ALCF_SYSTEM=polaris +export ALCF_HOST=polaris.alcf.anl.gov +export REMOTE_ROOT=/eagle/$ALCF_PROJECT/$ALCF_USER +``` + +`ALCF_USER` is the shared-filesystem path component. It may differ from the SSH +login and from the Argo user. 
+ +## One-Time Setup + +Sync ChemGraph: + +```bash +cd "$LOCAL_CHEMGRAPH" + +rsync -az --delete --delete-excluded \ + --exclude '.git/' \ + --exclude '__pycache__/' \ + --exclude '.pytest_cache/' \ + --exclude 'runs/' \ + --exclude 'venvs/' \ + --exclude '*.pyc' \ + ./ \ + "$ALCF_LOGIN@$ALCF_HOST:$REMOTE_ROOT/ChemGraph/" +``` + +Install ChemGraph dependencies on the remote system: + +```bash +ssh "$ALCF_LOGIN@$ALCF_HOST" +cd "$REMOTE_ROOT/ChemGraph" + +# Aurora: +module load frameworks + +# Polaris: +# module use /soft/modulefiles +# module load conda +# conda activate base + +source "$REMOTE_ROOT/venvs/academy-swarm/bin/activate" +python -m pip install -e ".[academy,parsl]" +``` + +Verify the campaign is visible: + +```bash +PYTHONDONTWRITEBYTECODE=1 PYTHONPATH=src \ +python -m chemgraph.cli.main academy campaigns +``` + +Expected: + +```text +mace-ensemble-screening-20 +``` + +Verify Redis: + +```bash +export PATH="$REMOTE_ROOT/tools/redis/bin:$PATH" +command -v redis-server +redis-server --version +``` + +If Redis is missing, build it once on a login/UAN node: + +```bash +cd "$REMOTE_ROOT" +mkdir -p src tools +cd src +test -d redis || git clone --depth 1 https://github.com/redis/redis.git +cd redis +make -j4 +make PREFIX="$REMOTE_ROOT/tools/redis" install +``` + +Stage the MACE model: + +```bash +cd "$LOCAL_CHEMGRAPH" + +MODEL=src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/models/mace-mpa-0-medium.model +URL=https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model + +mkdir -p "$(dirname "$MODEL")" +test -f "$MODEL" || curl -L --fail -o "$MODEL" "$URL" +ls -lh "$MODEL" +``` + +Then sync ChemGraph again. 
+ +## Start argo-shim + +On the local machine: + +```bash +CELS_USERNAME="$ARGO_USER" \ +PYTHONPATH= \ +python -m argo_shim --no-auth --no-update-settings --port 18085 +``` + +## Start Dashboard + +Use a fresh run id: + +```bash +cd "$LOCAL_CHEMGRAPH" + +export RUN_ID="${ALCF_SYSTEM}-mace-ensemble-screening-001" + +PYTHONPATH=src python -m chemgraph.cli.main academy dashboard -- \ + --system "$ALCF_SYSTEM" \ + --remote-host "$ALCF_LOGIN@$ALCF_HOST" \ + --campaign mace-ensemble-screening-20 \ + --lm-connect mac-argo-relay \ + "$RUN_ID" +``` + +The dashboard command starts the local dashboard, an rsync mirror, an SSH +control connection, and a relay from compute nodes to local `argo-shim`. + +## Start The Campaign On Compute + +Run inside an interactive allocation: + +```bash +cd "$REMOTE_ROOT/ChemGraph" + +# Aurora: +module load frameworks + +# Polaris: +# module use /soft/modulefiles +# module load conda +# conda activate base + +source "$REMOTE_ROOT/venvs/academy-swarm/bin/activate" + +export RUN_ID="${ALCF_SYSTEM}-mace-ensemble-screening-001" + +export NUMEXPR_MAX_THREADS=256 +export NUMEXPR_NUM_THREADS=64 +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +export CHEMGRAPH_EXECUTION_BACKEND=parsl +export COMPUTE_SYSTEM="$ALCF_SYSTEM" + +export PATH="$REMOTE_ROOT/bin:$REMOTE_ROOT/tools/redis/bin:$PATH" + +: "${CHEMGRAPH_EXECUTION_BACKEND:?must be set to 'parsl' before launch}" +: "${COMPUTE_SYSTEM:?must be set to aurora or polaris before launch}" +echo "execution backend = $CHEMGRAPH_EXECUTION_BACKEND" +echo "compute system = $COMPUTE_SYSTEM" + +chemgraph academy run-compute \ + --system "$ALCF_SYSTEM" \ + --run-id "$RUN_ID" \ + --campaign mace-ensemble-screening-20 \ + --lm-user "$ARGO_USER" +``` + +If you reconnect to the login/compute node and re-run only the final +`chemgraph academy run-compute` invocation, the env exports above will not be +in your shell. Re-run the full block, or re-export both variables, before +relaunching. 
If `CHEMGRAPH_EXECUTION_BACKEND` is unset, the MCP server can fall +back to LocalBackend and produce `BrokenProcessPool` failures under per-rank +memory pressure. + +If the wrapper is installed but `chemgraph` is not on `PATH`, use: + +```bash +chemgraph-academy-run \ + --system "$ALCF_SYSTEM" \ + --run-id "$RUN_ID" \ + --campaign mace-ensemble-screening-20 \ + --lm-user "$ARGO_USER" +``` + +## Reopen A Local Dashboard + +Once the run has been synced locally: + +```bash +cd "$LOCAL_CHEMGRAPH" + +PYTHONPATH=src python -m chemgraph.cli.main academy dashboard -- \ + --system "$ALCF_SYSTEM" \ + --remote-host "$ALCF_LOGIN@$ALCF_HOST" \ + --campaign mace-ensemble-screening-20 \ + "$RUN_ID" \ + --local +``` + +## Troubleshooting + +Check the relay from compute: + +```bash +UAN_RELAY_HOST="$(tr -d '[:space:]' < "$REMOTE_ROOT/uan-relay-18186.host")" +curl --noproxy '*' -I "http://${UAN_RELAY_HOST}:18186/v1/models" +``` + +Expected: + +```text +HTTP/1.1 200 OK +``` + +If the first model response is an Argo access-denied notice for ``, +the compute command was launched without `--lm-user "$ARGO_USER"`. Use a fresh +run id, or restart the dashboard with `--overwrite-run`, then rerun compute +with `--lm-user`. + +If imports are slow or NumExpr complains, set: + +```bash +export NUMEXPR_MAX_THREADS=256 +export NUMEXPR_NUM_THREADS=64 +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +``` + +If MACE results come back as `PicklingError: Can't pickle +run_mace_singleArguments`, the remote ChemGraph checkout does not have the +worker-module fix synced. Sync the latest ChemGraph checkout to the ALCF +filesystem, restart the dashboard with a fresh run id, and rerun from a fresh +compute allocation. 
+ +If MACE results come back as `BrokenProcessPool` failures, confirm the MACE MCP +server initialized Parsl: + +```bash +grep "backend initialised" \ + "$REMOTE_ROOT/runs/$RUN_ID/rank3/mcp_logs/mace.log" +``` + +Expected: + +```text +CGFastMCP backend initialised: ParslBackend +``` + +If the log shows `LocalBackend initialized with 4 workers`, re-run the full +compute block with `CHEMGRAPH_EXECUTION_BACKEND=parsl`. + +If the log shows `Parsl is required for the ParslBackend`, the Parsl package is +missing from the venv: + +```bash +python -m pip install -e ".[academy,parsl]" +``` From 64c467cea45edcd1e0cd6b76bea80bc94e86df15 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 11 Jun 2026 12:56:57 -0500 Subject: [PATCH 075/119] refactor(agent): restore per-workflow graphs; move run_turn to agent/turn.py Restore per-workflow graph files as the canonical CLI path while keeping Academy on the run_turn primitive now housed in chemgraph.agent.turn. ChemGraph.run is routed through graph astream again with matching trace events, and Academy imports run_turn directly from the new module. No intended behavior change for either caller. 
--- src/chemgraph/academy/core/turn.py | 2 +- src/chemgraph/agent/llm_agent.py | 815 ++++-------------- src/chemgraph/agent/turn.py | 526 +++++++++++ src/chemgraph/cli/commands.py | 1 + src/chemgraph/cli/main.py | 3 +- src/chemgraph/cli/trace.py | 10 +- src/chemgraph/graphs/graspa_agent.py | 182 ++++ src/chemgraph/graphs/mock_agent.py | 102 +++ src/chemgraph/graphs/python_relp_agent.py | 224 +++++ src/chemgraph/graphs/rag_agent.py | 245 ++++++ .../graphs/single_agent_architector.py | 162 ++++ src/chemgraph/graphs/single_agent_mcp.py | 116 +++ src/chemgraph/graphs/single_agent_xanes.py | 272 ++++++ tests/test_agent_session.py | 76 +- tests/test_graph_constructors.py | 99 ++- tests/test_graphs.py | 192 +++-- tests/test_llm_agent.py | 59 ++ 17 files changed, 2287 insertions(+), 799 deletions(-) create mode 100644 src/chemgraph/agent/turn.py create mode 100644 src/chemgraph/graphs/graspa_agent.py create mode 100644 src/chemgraph/graphs/mock_agent.py create mode 100644 src/chemgraph/graphs/python_relp_agent.py create mode 100644 src/chemgraph/graphs/rag_agent.py create mode 100644 src/chemgraph/graphs/single_agent_architector.py create mode 100644 src/chemgraph/graphs/single_agent_mcp.py create mode 100644 src/chemgraph/graphs/single_agent_xanes.py diff --git a/src/chemgraph/academy/core/turn.py b/src/chemgraph/academy/core/turn.py index 967efc2b..3b833849 100644 --- a/src/chemgraph/academy/core/turn.py +++ b/src/chemgraph/academy/core/turn.py @@ -12,7 +12,7 @@ from chemgraph.academy.core.campaign import visible_resources_payload from chemgraph.academy.core.prompt import PromptProfile from chemgraph.academy.observability.run_files import read_json_file -from chemgraph.agent.llm_agent import run_turn +from chemgraph.agent.turn import run_turn from chemgraph.models.settings import LLMSettings TraceFn = Callable[[str, dict[str, Any]], None] diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 6de5f7fc..22c40729 100644 --- 
a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -1,6 +1,5 @@ import asyncio import datetime -import dataclasses import os import time from typing import Any, Callable, Collection, List, Optional @@ -14,7 +13,6 @@ from chemgraph.models.anthropic import load_anthropic_model from chemgraph.models.gemini import load_gemini_model from chemgraph.models.groq import load_groq_model -from chemgraph.models.loader import load_chat_model from chemgraph.models.supported_models import ( supported_openai_models, supported_ollama_models, @@ -22,9 +20,7 @@ supported_alcf_models, supported_argo_models, supported_gemini_models, - ) -from chemgraph.models.settings import LLMSettings from chemgraph.schemas.ase_input import ( get_available_calculator_names, get_calculator_selection_context, @@ -44,280 +40,44 @@ planner_prompt as default_planner_prompt, ) from langgraph.errors import GraphInterrupt -from langchain_core.messages import AIMessage from langchain_core.callbacks import BaseCallbackHandler +from chemgraph.agent.turn import ( + EventCallback, + TurnResult, + _TurnEventCallback, + _custom_openai_compatible_kwargs, + _executed_tool_names, + _response_tool_calls, + _serialized_name, + _state_messages, + _terminal_tool_name, + run_turn, + serialize_state, +) from chemgraph.graphs.single_agent import construct_single_agent_graph from chemgraph.graphs.multi_agent import construct_multi_agent_graph +from chemgraph.graphs.python_relp_agent import construct_relp_graph +from chemgraph.graphs.graspa_agent import construct_graspa_graph +from chemgraph.graphs.mock_agent import construct_mock_agent_graph +from chemgraph.graphs.single_agent_mcp import construct_single_agent_mcp_graph from chemgraph.graphs.graspa_mcp import construct_graspa_mcp_graph +from chemgraph.graphs.rag_agent import construct_rag_agent_graph +from chemgraph.graphs.single_agent_xanes import construct_single_agent_xanes_graph +from chemgraph.graphs.single_agent_architector import 
construct_single_agent_architector_graph from chemgraph.prompt.rag_prompt import rag_agent_prompt from chemgraph.prompt.xanes_prompt import ( xanes_single_agent_prompt as default_xanes_single_agent_prompt, xanes_formatter_prompt as default_xanes_formatter_prompt, ) -from chemgraph.tools.ase_tools import ( - file_to_atomsdata, - run_ase, - save_atomsdata_to_file, -) -from chemgraph.tools.cheminformatics_tools import ( - molecule_name_to_smiles, - smiles_to_atomsdata, - smiles_to_coordinate_file, -) -from chemgraph.tools.generic_tools import calculator, repl_tool -from chemgraph.tools.graspa_tools import run_graspa -from chemgraph.tools.rag_tools import load_document, query_knowledge_base -from chemgraph.tools.xanes_tools import ( - fetch_xanes_data, - plot_xanes_data, - run_xanes, -) import logging logger = logging.getLogger(__name__) -SINGLE_AGENT_TURN_WORKFLOWS = { - "single_agent", - "python_relp", - "graspa", - "mock_agent", - "single_agent_mcp", - "rag_agent", - "single_agent_xanes", -} - -LEGACY_GRAPH_WORKFLOWS = {"multi_agent", "graspa_mcp"} - - -def _tool_name(tool: Any) -> str: - return str(getattr(tool, "name", getattr(tool, "__name__", repr(tool)))) - - -def _merge_tools(*groups: Collection[Any] | None) -> list[Any]: - """Merge tool groups by visible tool name while preserving order.""" - merged: list[Any] = [] - seen: set[str] = set() - for group in groups: - for tool in group or (): - name = _tool_name(tool) - if name not in seen: - merged.append(tool) - seen.add(name) - return merged - - -def _xanes_tools() -> list[Any]: - return [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - run_ase, - run_xanes, - fetch_xanes_data, - plot_xanes_data, - ] - - -def _rag_tools() -> list[Any]: - return [ - load_document, - query_knowledge_base, - file_to_atomsdata, - smiles_to_coordinate_file, - run_ase, - molecule_name_to_smiles, - save_atomsdata_to_file, - calculator, - ] - - -def _mock_tools() -> list[Any]: - return [ - file_to_atomsdata, - 
smiles_to_atomsdata, - run_ase, - molecule_name_to_smiles, - save_atomsdata_to_file, - calculator, - ] - - -def _last_ai_message(state: dict[str, Any], fallback_text: str) -> AIMessage: - """Return the last AI message from a turn state, preserving objects when present.""" - messages = state.get("messages", []) if isinstance(state, dict) else [] - for message in reversed(messages): - if isinstance(message, AIMessage): - return message - if isinstance(message, dict): - message_type = message.get("type") or message.get("role") - if message_type in {"ai", "assistant"}: - return AIMessage(content=_message_text(message)) - return AIMessage(content=fallback_text) - - -def _is_mock_object(value) -> bool: - """Return True for unittest.mock objects without importing test-only APIs. - - Parameters - ---------- - value : Any - Object to inspect. - - Returns - ------- - bool - ``True`` when the object comes from ``unittest.mock``. - """ - return value.__class__.__module__.startswith("unittest.mock") - - -def serialize_state(state, *, max_depth: int = 50, _seen: set[int] | None = None): - """Convert non-serializable objects in state to a JSON-friendly format. - - Parameters - ---------- - state : Any - The state object to be serialized. Can be a list, dict, or object with __dict__ - max_depth : int, optional - Maximum object nesting depth to serialize before falling back to a - placeholder. This prevents runaway recursion for complex graph objects. 
- - Returns - ------- - Any - A JSON-serializable version of the input state - """ - if _seen is None: - _seen = set() - - if max_depth < 0: - return f"" - - if isinstance(state, (str, int, float, bool)) or state is None: - return state - - if isinstance(state, (datetime.datetime, datetime.date)): - return state.isoformat() - - if _is_mock_object(state): - return str(state) - - state_id = id(state) - if state_id in _seen: - return f"" - - if isinstance(state, dict): - _seen.add(state_id) - try: - return { - str(key): serialize_state( - value, max_depth=max_depth - 1, _seen=_seen - ) - for key, value in state.items() - } - finally: - _seen.remove(state_id) - - if isinstance(state, (list, tuple, set, frozenset)): - _seen.add(state_id) - try: - return [ - serialize_state(item, max_depth=max_depth - 1, _seen=_seen) - for item in state - ] - finally: - _seen.remove(state_id) - - model_dump = getattr(state, "model_dump", None) - if callable(model_dump): - _seen.add(state_id) - try: - try: - dumped = model_dump(mode="json") - except TypeError: - dumped = model_dump() - return serialize_state(dumped, max_depth=max_depth - 1, _seen=_seen) - except Exception: - return str(state) - finally: - _seen.remove(state_id) - - if dataclasses.is_dataclass(state) and not isinstance(state, type): - _seen.add(state_id) - try: - return { - field.name: serialize_state( - getattr(state, field.name), - max_depth=max_depth - 1, - _seen=_seen, - ) - for field in dataclasses.fields(state) - } - finally: - _seen.remove(state_id) - - if hasattr(state, "__dict__"): - _seen.add(state_id) - try: - return { - str(key): serialize_state( - value, max_depth=max_depth - 1, _seen=_seen - ) - for key, value in vars(state).items() - } - finally: - _seen.remove(state_id) - - return str(state) - - -def _custom_openai_compatible_kwargs( - *, - model_name: str, - temperature: float, - base_url: str, - api_key: str, - max_tokens: int, - top_p: float, - frequency_penalty: float, - presence_penalty: float, - 
argo_user: str | None, -) -> dict: - kwargs = { - "model": model_name, - "temperature": temperature, - "base_url": base_url, - "api_key": api_key, - "max_tokens": max_tokens, - "top_p": top_p, - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - } - user = argo_user or os.getenv("ARGO_USER") - if base_url and "argoapi" in base_url and user: - kwargs["model_kwargs"] = {"user": user} - return kwargs - - -EventCallback = Callable[[str, dict], None] - - -@dataclasses.dataclass(frozen=True) -class TurnResult: - """Result of one bounded ChemGraph single-agent turn.""" - - final_text: str - state: dict[str, Any] - executed_tool_names: tuple[str, ...] - terminal_tool: str | None - thread_id: str - duration_s: float - - -class _TurnEventCallback(BaseCallbackHandler): - """Forward LangChain callback events to a small stable callback surface.""" +class _AstreamEventCallback(BaseCallbackHandler): + """Forward LangChain callback events from graph-backed CLI runs.""" def __init__(self, on_event: EventCallback, thread_id: str) -> None: self._on_event = on_event @@ -327,7 +87,7 @@ def _emit(self, event: str, payload: dict[str, Any]) -> None: try: self._on_event(event, {"thread_id": self._thread_id, **payload}) except Exception: # noqa: BLE001 - callbacks must not break the run. 
- logger.debug("turn event callback failed", exc_info=True) + logger.debug("astream event callback failed", exc_info=True) def on_chat_model_start(self, serialized, messages, **kwargs) -> None: self._emit( @@ -383,283 +143,6 @@ def on_tool_error(self, error, **kwargs) -> None: self._emit("tool_call_failed", payload) -def _serialized_name(serialized: Any) -> str | None: - if isinstance(serialized, dict): - return serialized.get("name") or serialized.get("id") - return None - - -def _message_tool_calls(message: Any) -> list[Any]: - if isinstance(message, dict): - calls = message.get("tool_calls") - else: - calls = getattr(message, "tool_calls", None) - return calls if isinstance(calls, list) else [] - - -def _response_tool_calls(response: Any) -> list[dict[str, str | None]]: - try: - generations = getattr(response, "generations", None) or [] - tool_calls: list[dict[str, str | None]] = [] - for generation_group in generations: - for generation in generation_group or []: - message = getattr(generation, "message", None) - for call in _message_tool_calls(message): - name = _call_name(call) - if not name: - continue - tool_calls.append( - { - "name": name, - "id": _call_id(call), - }, - ) - return tool_calls - except Exception: # noqa: BLE001 - event extraction must not break runs. 
- logger.debug("failed to extract llm_decision tool calls", exc_info=True) - return [] - - -def _tool_message_name(message: Any) -> str | None: - if isinstance(message, dict): - name = message.get("name") - role = message.get("role") or message.get("type") - if name and role in {"tool", "tool_message", "ToolMessage"}: - return str(name) - return str(name) if name and not _message_tool_calls(message) else None - name = getattr(message, "name", None) - message_type = getattr(message, "type", None) - if name and message_type == "tool": - return str(name) - return str(name) if name and not _message_tool_calls(message) else None - - -def _call_name(call: Any) -> str | None: - if isinstance(call, dict): - if call.get("name"): - return str(call["name"]) - function = call.get("function") - if isinstance(function, dict) and function.get("name"): - return str(function["name"]) - name = getattr(call, "name", None) - return str(name) if name else None - - -def _call_id(call: Any) -> str | None: - if isinstance(call, dict): - value = call.get("id") or call.get("tool_call_id") - else: - value = getattr(call, "id", None) or getattr(call, "tool_call_id", None) - return str(value) if value else None - - -def _state_messages(state: Any) -> list[Any]: - if isinstance(state, dict): - messages = state.get("messages", []) - else: - messages = getattr(state, "messages", []) - return list(messages or []) - - -def _executed_tool_names(messages: list[Any]) -> tuple[str, ...]: - names: list[str] = [] - for message in messages: - name = _tool_message_name(message) - if name: - names.append(name) - if names: - return tuple(names) - for message in messages: - for call in _message_tool_calls(message): - if name := _call_name(call): - names.append(name) - return tuple(names) - - -def _terminal_tool_name( - executed_tool_names: tuple[str, ...], - terminal_tool_names: Collection[str], -) -> str | None: - terminal = set(terminal_tool_names) - for name in reversed(executed_tool_names): - if name in 
terminal: - return name - return None - - -def _message_text(message: Any) -> str: - content = message.get("content") if isinstance(message, dict) else getattr(message, "content", "") - if isinstance(content, list): - parts: list[str] = [] - for item in content: - if isinstance(item, dict): - parts.append(str(item.get("text") or item.get("content") or item)) - else: - parts.append(str(item)) - return "\n".join(parts) - return "" if content is None else str(content) - - -def _final_text(messages: list[Any]) -> str: - for message in reversed(messages): - message_type = ( - message.get("role") or message.get("type") - if isinstance(message, dict) - else getattr(message, "type", None) - ) - if message_type in {"ai", "assistant"}: - return _message_text(message) - return _message_text(messages[-1]) if messages else "" - - -def _load_turn_llm( - *, - model_name: str, - base_url: str | None, - api_key: str | None, - argo_user: str | None, -) -> Any: - temperature = 0.0 - try: - return load_chat_model( - settings=LLMSettings( - model=model_name, - base_url=base_url, - api_key=api_key, - argo_user=argo_user, - temperature=temperature, - ), - ) - except ValueError: - pass - - endpoint = os.getenv("VLLM_BASE_URL", base_url or "") - key = os.getenv("OPENAI_API_KEY", api_key or "dummy_vllm_key") - if not endpoint: - raise ValueError(f"Unsupported model or missing base URL for: {model_name}") - from langchain_openai import ChatOpenAI - - return ChatOpenAI( - **_custom_openai_compatible_kwargs( - model_name=model_name, - temperature=temperature, - base_url=endpoint, - api_key=key, - max_tokens=4000, - top_p=1.0, - frequency_penalty=0.0, - presence_penalty=0.0, - argo_user=argo_user, - ), - ) - - -async def run_turn( - *, - query: str, - tools: list[Any] | None = None, - model_name: str = "gpt-4o-mini", - base_url: str | None = None, - api_key: str | None = None, - argo_user: str | None = None, - system_prompt: str = single_agent_prompt, - formatter_prompt: str = 
default_formatter_prompt, - structured_output: bool = False, - generate_report: bool = False, - report_prompt: str = default_report_prompt, - recursion_limit: int = 50, - thread_id: str | None = None, - terminal_tool_names: Collection[str] = (), - human_supervised: bool = False, - on_event: EventCallback | None = None, -) -> TurnResult: - """Run one bounded single-agent ChemGraph LangGraph turn.""" - - started = time.time() - thread_id = thread_id or str(uuid.uuid4()) - callbacks = [_TurnEventCallback(on_event, thread_id)] if on_event else [] - event = on_event or (lambda _event, _payload: None) - event( - "workflow_started", - { - "workflow_type": "single_agent", - "thread_id": thread_id, - "tool_names": [getattr(tool, "name", str(tool)) for tool in tools or []], - }, - ) - llm = _load_turn_llm( - model_name=model_name, - base_url=base_url, - api_key=api_key, - argo_user=argo_user, - ) - workflow = construct_single_agent_graph( - llm, - system_prompt, - structured_output, - formatter_prompt, - generate_report, - report_prompt, - tools, - human_supervised=human_supervised, - terminal_tool_names=terminal_tool_names, - ) - config: dict[str, Any] = { - "configurable": {"thread_id": thread_id}, - "recursion_limit": recursion_limit, - } - if callbacks: - config["callbacks"] = callbacks - - last_state: Any = None - try: - async for state in workflow.astream( - {"messages": query}, - stream_mode="values", - config=config, - ): - last_state = state - except Exception as exc: - event( - "workflow_finished", - { - "workflow_type": "single_agent", - "thread_id": thread_id, - "status": "failed", - "error": repr(exc), - "duration_s": round(time.time() - started, 3), - }, - ) - raise - - if last_state is None: - raise RuntimeError("ChemGraph turn produced no states.") - - messages = _state_messages(last_state) - executed_tools = _executed_tool_names(messages) - terminal_tool = _terminal_tool_name(executed_tools, terminal_tool_names) - result = TurnResult( - 
final_text=_final_text(messages), - state=serialize_state(last_state), - executed_tool_names=executed_tools, - terminal_tool=terminal_tool, - thread_id=thread_id, - duration_s=round(time.time() - started, 3), - ) - event( - "workflow_finished", - { - "workflow_type": "single_agent", - "thread_id": thread_id, - "status": "completed", - "executed_tool_names": list(result.executed_tool_names), - "terminal_tool": terminal_tool, - "duration_s": result.duration_s, - }, - ) - return result - - class ChemGraph: """A graph-based workflow for LLM-powered computational chemistry tasks. @@ -919,11 +402,26 @@ def __init__( logger.error(f"Exception thrown when loading {model_name}: {str(e)}") raise e - supported_workflows = SINGLE_AGENT_TURN_WORKFLOWS | LEGACY_GRAPH_WORKFLOWS - if workflow_type not in supported_workflows: + self.workflow_map = { + "single_agent": {"constructor": construct_single_agent_graph}, + "multi_agent": {"constructor": construct_multi_agent_graph}, + "python_relp": {"constructor": construct_relp_graph}, + "graspa": {"constructor": construct_graspa_graph}, + "graspa_agent": {"constructor": construct_graspa_graph}, + "mock_agent": {"constructor": construct_mock_agent_graph}, + "single_agent_mcp": {"constructor": construct_single_agent_mcp_graph}, + "graspa_mcp": {"constructor": construct_graspa_mcp_graph}, + "rag_agent": {"constructor": construct_rag_agent_graph}, + "single_agent_xanes": {"constructor": construct_single_agent_xanes_graph}, + "single_agent_architector": { + "constructor": construct_single_agent_architector_graph, + }, + } + + if workflow_type not in self.workflow_map: raise ValueError( f"Unsupported workflow type: {workflow_type}. 
" - f"Available types: {sorted(supported_workflows)}" + f"Available types: {list(self.workflow_map.keys())}" ) self._using_default_system_prompt = system_prompt == single_agent_prompt @@ -965,18 +463,7 @@ def __init__( self.calculator_selection_context = get_calculator_selection_context() def append_calculator_context(prompt: str) -> str: - """Append calculator availability guidance to a prompt once. - - Parameters - ---------- - prompt : str - Prompt text to augment. - - Returns - ------- - str - Prompt with calculator-selection context appended. - """ + """Append calculator availability guidance to a prompt once.""" if self.calculator_selection_context in prompt: return prompt return f"{prompt}{self.calculator_selection_context}" @@ -992,15 +479,20 @@ def append_calculator_context(prompt: str) -> str: else: self.support_structured_output = support_structured_output - self.workflow_map = { - "multi_agent": {"constructor": construct_multi_agent_graph}, - "graspa_mcp": {"constructor": construct_graspa_mcp_graph}, - } - - self.tools = self._resolve_turn_tools(tools, data_tools) - self._resolve_turn_prompts() - - if self.workflow_type == "multi_agent": + if self.workflow_type == "single_agent": + self.workflow = self.workflow_map[workflow_type]["constructor"]( + llm, + self.system_prompt, + self.structured_output, + self.formatter_prompt, + self.generate_report, + self.report_prompt, + self.tools, + max_retries=self.max_retries, + human_supervised=self.human_supervised, + terminal_tool_names=self.terminal_tool_names, + ) + elif self.workflow_type == "multi_agent": self.workflow = self.workflow_map[workflow_type]["constructor"]( llm, planner_prompt=self.planner_prompt, @@ -1010,51 +502,61 @@ def append_calculator_context(prompt: str) -> str: formatter_prompt=self.formatter_multi_prompt, max_retries=self.max_retries, ) + elif self.workflow_type == "python_relp": + self.workflow = self.workflow_map[workflow_type]["constructor"]( + llm, + self.system_prompt, + ) + elif 
self.workflow_type in {"graspa", "graspa_agent"}: + self.workflow = self.workflow_map[workflow_type]["constructor"]( + llm, + self.system_prompt, + self.structured_output, + self.formatter_prompt, + ) + elif self.workflow_type == "mock_agent": + self.workflow = self.workflow_map[workflow_type]["constructor"]( + llm=llm, + system_prompt=self.system_prompt, + ) + elif self.workflow_type == "single_agent_mcp": + self.workflow = self.workflow_map[workflow_type]["constructor"]( + llm=llm, + system_prompt=self.system_prompt, + tools=self.tools, + ) elif self.workflow_type == "graspa_mcp": self.workflow = self.workflow_map[workflow_type]["constructor"]( llm=llm, executor_tools=self.tools, analysis_tools=self.data_tools, ) - else: - self.workflow = None - - def _resolve_turn_tools( - self, - tools: Collection[Any] | None, - data_tools: Collection[Any] | None, - ) -> list[Any] | None: - """Resolve the LangGraph tools for run_turn-backed workflows.""" - if self.workflow_type == "single_agent": - return list(tools) if tools is not None else None - if self.workflow_type == "python_relp": - return _merge_tools(tools, [repl_tool, calculator]) - if self.workflow_type == "graspa": - return _merge_tools(tools, [run_graspa]) - if self.workflow_type == "mock_agent": - return _merge_tools(tools, _mock_tools()) - if self.workflow_type == "single_agent_mcp": - resolved = _merge_tools(tools, data_tools) - if not resolved: - raise ValueError( - "No MCP tools loaded. Ensure MCP servers are configured and reachable." 
- ) - return resolved - if self.workflow_type == "rag_agent": - return _merge_tools(tools, _rag_tools()) - if self.workflow_type == "single_agent_xanes": - return _merge_tools(tools, _xanes_tools()) - return list(tools) if tools is not None else None - - def _resolve_turn_prompts(self) -> None: - """Apply workflow-specific prompt defaults before run_turn.""" - if self.workflow_type == "rag_agent" and self._using_default_system_prompt: - self.system_prompt = rag_agent_prompt + elif self.workflow_type == "rag_agent": + self.workflow = self.workflow_map[workflow_type]["constructor"]( + llm=llm, + system_prompt=self.system_prompt + if not self._using_default_system_prompt + else rag_agent_prompt, + tools=self.tools, + ) elif self.workflow_type == "single_agent_xanes": - if self._using_default_system_prompt: - self.system_prompt = default_xanes_single_agent_prompt - if self._using_default_formatter_prompt: - self.formatter_prompt = default_xanes_formatter_prompt + self.workflow = self.workflow_map[workflow_type]["constructor"]( + llm, + system_prompt=self.system_prompt + if not self._using_default_system_prompt + else default_xanes_single_agent_prompt, + structured_output=self.structured_output, + formatter_prompt=self.formatter_prompt + if not self._using_default_formatter_prompt + else default_xanes_formatter_prompt, + tools=self.tools, + ) + elif self.workflow_type == "single_agent_architector": + self.workflow = self.workflow_map[workflow_type]["constructor"]( + llm=llm, + system_prompt=self.system_prompt, + tools=self.tools, + ) def visualize(self, method: str = "ascii"): """Visualize the LangGraph graph structure. @@ -1079,11 +581,6 @@ def visualize(self, method: str = "ascii"): Requires IPython and nest_asyncio to be installed. The visualization uses Mermaid diagrams with custom styling. 
""" - if self.workflow is None: - raise RuntimeError( - f"Workflow {self.workflow_type!r} is run-turn-backed and is built " - "inside ChemGraph.run(); it is not available for pre-run visualization." - ) import nest_asyncio from IPython.display import Image, display from langchain_core.runnables.graph import ( @@ -1127,12 +624,6 @@ def get_state(self, config={"configurable": {"thread_id": "1"}}): list List of messages in the current state """ - if self.workflow is None: - if self._last_run_state is None: - raise RuntimeError( - f"Workflow {self.workflow_type!r} has not produced state yet." - ) - return self._last_run_state return self.workflow.get_state(config).values def write_state( @@ -1405,29 +896,11 @@ async def run( config=None, resume_from: Optional[str] = None, ): - """ - Async runner for run-turn-backed and legacy graph-backed workflows. + """Run a graph-backed ChemGraph workflow. - Run-turn-backed workflows delegate to :func:`run_turn`, while legacy - multi-node graph workflows stream through ``self.workflow.astream``. - The return value follows ``self.return_option`` ("last_message" or - "state"). - - When the graph pauses for human input (via ``interrupt()``), the - ``human_input_handler`` callback is invoked to obtain the user's - response, and the graph is automatically resumed. If no handler - is configured, the ``GraphInterrupt`` exception propagates to the - caller. - - Parameters - ---------- - query : str - The user query to execute. - config : dict, optional - LangGraph config with thread_id, etc. - resume_from : str, optional - Session ID to load context from. The previous conversation - summary is prepended to the query. + All CLI workflows execute through their restored LangGraph + constructors. Academy uses :func:`run_turn` directly instead of this + method. 
""" if config is None: config = {} @@ -1452,37 +925,21 @@ async def run( ) logger.info(f"Injected context from session {resume_from}") + started = time.time() thread_id = str(config["configurable"]["thread_id"]) - if self.workflow_type in SINGLE_AGENT_TURN_WORKFLOWS: - result = await run_turn( - query=query, - tools=self.tools, - model_name=self.model_name, - base_url=self.base_url, - api_key=self.api_key, - argo_user=self.argo_user, - system_prompt=self.system_prompt, - formatter_prompt=self.formatter_prompt, - structured_output=self.structured_output, - generate_report=self.generate_report, - report_prompt=self.report_prompt, - recursion_limit=self.recursion_limit, - thread_id=thread_id, - terminal_tool_names=self.terminal_tool_names, - human_supervised=self.human_supervised, - on_event=self.on_event, - ) - self._last_run_state = result.state - self._save_messages_to_store(result.state, query) - self.write_state(config=config, file_path=None) - if self.return_option == "state": - return result.state - if self.return_option == "last_message": - return _last_ai_message(result.state, result.final_text) - raise ValueError( - f"Unsupported return_option: {self.return_option}. " - "Use 'last_message' or 'state'." 
- ) + event = self.on_event or (lambda _event, _payload: None) + if self.on_event: + callbacks = list(config.get("callbacks") or []) + callbacks.append(_AstreamEventCallback(self.on_event, thread_id)) + config["callbacks"] = callbacks + event( + "workflow_started", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "tool_names": [getattr(tool, "name", str(tool)) for tool in self.tools or []], + }, + ) try: last_state = None @@ -1501,6 +958,24 @@ async def run( last_state = state if last_state is None: raise RuntimeError("Workflow produced no states") + + messages = _state_messages(last_state) + executed_tools = _executed_tool_names(messages) + terminal_tool = _terminal_tool_name( + executed_tools, + self.terminal_tool_names, + ) + event( + "workflow_finished", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "status": "completed", + "executed_tool_names": list(executed_tools), + "terminal_tool": terminal_tool, + "duration_s": round(time.time() - started, 3), + }, + ) self._last_run_state = serialize_state(last_state) self._save_messages_to_store(last_state, query) self.write_state(config=config, file_path=None) @@ -1515,6 +990,16 @@ async def run( except GraphInterrupt: raise except Exception as e: + event( + "workflow_finished", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "status": "failed", + "error": repr(e), + "duration_s": round(time.time() - started, 3), + }, + ) logger.error(f"Error running workflow {self.workflow_type}: {e}") raise diff --git a/src/chemgraph/agent/turn.py b/src/chemgraph/agent/turn.py new file mode 100644 index 00000000..a6cf9720 --- /dev/null +++ b/src/chemgraph/agent/turn.py @@ -0,0 +1,526 @@ +from __future__ import annotations + +import dataclasses +import datetime +import logging +import os +import time +import uuid +from typing import Any, Callable, Collection + +from langchain_core.callbacks import BaseCallbackHandler + +from chemgraph.graphs.single_agent import 
construct_single_agent_graph +from chemgraph.models.loader import load_chat_model +from chemgraph.models.settings import LLMSettings +from chemgraph.prompt.single_agent_prompt import ( + formatter_prompt as default_formatter_prompt, +) +from chemgraph.prompt.single_agent_prompt import report_prompt as default_report_prompt +from chemgraph.prompt.single_agent_prompt import single_agent_prompt + +logger = logging.getLogger(__name__) + + +def _is_mock_object(value) -> bool: + """Return True for unittest.mock objects without importing test-only APIs. + + Parameters + ---------- + value : Any + Object to inspect. + + Returns + ------- + bool + ``True`` when the object comes from ``unittest.mock``. + """ + return value.__class__.__module__.startswith("unittest.mock") + + +def serialize_state(state, *, max_depth: int = 50, _seen: set[int] | None = None): + """Convert non-serializable objects in state to a JSON-friendly format. + + Parameters + ---------- + state : Any + The state object to be serialized. Can be a list, dict, or object with __dict__ + max_depth : int, optional + Maximum object nesting depth to serialize before falling back to a + placeholder. This prevents runaway recursion for complex graph objects. 
+ + Returns + ------- + Any + A JSON-serializable version of the input state + """ + if _seen is None: + _seen = set() + + if max_depth < 0: + return f"" + + if isinstance(state, (str, int, float, bool)) or state is None: + return state + + if isinstance(state, (datetime.datetime, datetime.date)): + return state.isoformat() + + if _is_mock_object(state): + return str(state) + + state_id = id(state) + if state_id in _seen: + return f"" + + if isinstance(state, dict): + _seen.add(state_id) + try: + return { + str(key): serialize_state( + value, max_depth=max_depth - 1, _seen=_seen + ) + for key, value in state.items() + } + finally: + _seen.remove(state_id) + + if isinstance(state, (list, tuple, set, frozenset)): + _seen.add(state_id) + try: + return [ + serialize_state(item, max_depth=max_depth - 1, _seen=_seen) + for item in state + ] + finally: + _seen.remove(state_id) + + model_dump = getattr(state, "model_dump", None) + if callable(model_dump): + _seen.add(state_id) + try: + try: + dumped = model_dump(mode="json") + except TypeError: + dumped = model_dump() + return serialize_state(dumped, max_depth=max_depth - 1, _seen=_seen) + except Exception: + return str(state) + finally: + _seen.remove(state_id) + + if dataclasses.is_dataclass(state) and not isinstance(state, type): + _seen.add(state_id) + try: + return { + field.name: serialize_state( + getattr(state, field.name), + max_depth=max_depth - 1, + _seen=_seen, + ) + for field in dataclasses.fields(state) + } + finally: + _seen.remove(state_id) + + if hasattr(state, "__dict__"): + _seen.add(state_id) + try: + return { + str(key): serialize_state( + value, max_depth=max_depth - 1, _seen=_seen + ) + for key, value in vars(state).items() + } + finally: + _seen.remove(state_id) + + return str(state) + + +def _custom_openai_compatible_kwargs( + *, + model_name: str, + temperature: float, + base_url: str, + api_key: str, + max_tokens: int, + top_p: float, + frequency_penalty: float, + presence_penalty: float, + 
argo_user: str | None, +) -> dict: + kwargs = { + "model": model_name, + "temperature": temperature, + "base_url": base_url, + "api_key": api_key, + "max_tokens": max_tokens, + "top_p": top_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + } + user = argo_user or os.getenv("ARGO_USER") + if base_url and "argoapi" in base_url and user: + kwargs["model_kwargs"] = {"user": user} + return kwargs + + +EventCallback = Callable[[str, dict], None] + + +@dataclasses.dataclass(frozen=True) +class TurnResult: + """Result of one bounded ChemGraph single-agent turn.""" + + final_text: str + state: dict[str, Any] + executed_tool_names: tuple[str, ...] + terminal_tool: str | None + thread_id: str + duration_s: float + + +class _TurnEventCallback(BaseCallbackHandler): + """Forward LangChain callback events to a small stable callback surface.""" + + def __init__(self, on_event: EventCallback, thread_id: str) -> None: + self._on_event = on_event + self._thread_id = thread_id + + def _emit(self, event: str, payload: dict[str, Any]) -> None: + try: + self._on_event(event, {"thread_id": self._thread_id, **payload}) + except Exception: # noqa: BLE001 - callbacks must not break the run. 
+ logger.debug("turn event callback failed", exc_info=True) + + def on_chat_model_start(self, serialized, messages, **kwargs) -> None: + self._emit( + "llm_call_started", + { + "model": _serialized_name(serialized), + "message_count": len(messages[0]) if messages else 0, + }, + ) + + def on_llm_start(self, serialized, prompts, **kwargs) -> None: + self._emit( + "llm_call_started", + { + "model": _serialized_name(serialized), + "message_count": len(prompts or []), + }, + ) + + def on_llm_end(self, response, **kwargs) -> None: + payload: dict[str, Any] = {} + usage = getattr(response, "llm_output", None) + if isinstance(usage, dict): + payload["llm_output"] = usage + self._emit("llm_call_finished", payload) + if tool_calls := _response_tool_calls(response): + self._emit("llm_decision", {"tool_calls": tool_calls}) + + def on_llm_error(self, error, **kwargs) -> None: + self._emit("llm_call_failed", {"error": repr(error)}) + + def on_tool_start(self, serialized, input_str, **kwargs) -> None: + self._emit( + "tool_call_started", + { + "tool_name": _serialized_name(serialized), + "arguments": serialize_state(input_str), + }, + ) + + def on_tool_end(self, output, **kwargs) -> None: + payload: dict[str, Any] = {"result": serialize_state(output)} + name = kwargs.get("name") + if name: + payload["tool_name"] = name + self._emit("tool_call_finished", payload) + + def on_tool_error(self, error, **kwargs) -> None: + payload = {"error": repr(error)} + name = kwargs.get("name") + if name: + payload["tool_name"] = name + self._emit("tool_call_failed", payload) + + +def _serialized_name(serialized: Any) -> str | None: + if isinstance(serialized, dict): + return serialized.get("name") or serialized.get("id") + return None + + +def _message_tool_calls(message: Any) -> list[Any]: + if isinstance(message, dict): + calls = message.get("tool_calls") + else: + calls = getattr(message, "tool_calls", None) + return calls if isinstance(calls, list) else [] + + +def 
_response_tool_calls(response: Any) -> list[dict[str, str | None]]: + try: + generations = getattr(response, "generations", None) or [] + tool_calls: list[dict[str, str | None]] = [] + for generation_group in generations: + for generation in generation_group or []: + message = getattr(generation, "message", None) + for call in _message_tool_calls(message): + name = _call_name(call) + if not name: + continue + tool_calls.append( + { + "name": name, + "id": _call_id(call), + }, + ) + return tool_calls + except Exception: # noqa: BLE001 - event extraction must not break runs. + logger.debug("failed to extract llm_decision tool calls", exc_info=True) + return [] + + +def _tool_message_name(message: Any) -> str | None: + if isinstance(message, dict): + name = message.get("name") + role = message.get("role") or message.get("type") + if name and role in {"tool", "tool_message", "ToolMessage"}: + return str(name) + return str(name) if name and not _message_tool_calls(message) else None + name = getattr(message, "name", None) + message_type = getattr(message, "type", None) + if name and message_type == "tool": + return str(name) + return str(name) if name and not _message_tool_calls(message) else None + + +def _call_name(call: Any) -> str | None: + if isinstance(call, dict): + if call.get("name"): + return str(call["name"]) + function = call.get("function") + if isinstance(function, dict) and function.get("name"): + return str(function["name"]) + name = getattr(call, "name", None) + return str(name) if name else None + + +def _call_id(call: Any) -> str | None: + if isinstance(call, dict): + value = call.get("id") or call.get("tool_call_id") + else: + value = getattr(call, "id", None) or getattr(call, "tool_call_id", None) + return str(value) if value else None + + +def _state_messages(state: Any) -> list[Any]: + if isinstance(state, dict): + messages = state.get("messages", []) + else: + messages = getattr(state, "messages", []) + return list(messages or []) + + +def 
_executed_tool_names(messages: list[Any]) -> tuple[str, ...]: + names: list[str] = [] + for message in messages: + name = _tool_message_name(message) + if name: + names.append(name) + if names: + return tuple(names) + for message in messages: + for call in _message_tool_calls(message): + if name := _call_name(call): + names.append(name) + return tuple(names) + + +def _terminal_tool_name( + executed_tool_names: tuple[str, ...], + terminal_tool_names: Collection[str], +) -> str | None: + terminal = set(terminal_tool_names) + for name in reversed(executed_tool_names): + if name in terminal: + return name + return None + + +def _message_text(message: Any) -> str: + content = message.get("content") if isinstance(message, dict) else getattr(message, "content", "") + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + parts.append(str(item.get("text") or item.get("content") or item)) + else: + parts.append(str(item)) + return "\n".join(parts) + return "" if content is None else str(content) + + +def _final_text(messages: list[Any]) -> str: + for message in reversed(messages): + message_type = ( + message.get("role") or message.get("type") + if isinstance(message, dict) + else getattr(message, "type", None) + ) + if message_type in {"ai", "assistant"}: + return _message_text(message) + return _message_text(messages[-1]) if messages else "" + + +def _load_turn_llm( + *, + model_name: str, + base_url: str | None, + api_key: str | None, + argo_user: str | None, +) -> Any: + temperature = 0.0 + try: + return load_chat_model( + settings=LLMSettings( + model=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + temperature=temperature, + ), + ) + except ValueError: + pass + + endpoint = os.getenv("VLLM_BASE_URL", base_url or "") + key = os.getenv("OPENAI_API_KEY", api_key or "dummy_vllm_key") + if not endpoint: + raise ValueError(f"Unsupported model or missing base URL for: {model_name}") + from 
langchain_openai import ChatOpenAI + + return ChatOpenAI( + **_custom_openai_compatible_kwargs( + model_name=model_name, + temperature=temperature, + base_url=endpoint, + api_key=key, + max_tokens=4000, + top_p=1.0, + frequency_penalty=0.0, + presence_penalty=0.0, + argo_user=argo_user, + ), + ) + + +async def run_turn( + *, + query: str, + tools: list[Any] | None = None, + model_name: str = "gpt-4o-mini", + base_url: str | None = None, + api_key: str | None = None, + argo_user: str | None = None, + system_prompt: str = single_agent_prompt, + formatter_prompt: str = default_formatter_prompt, + structured_output: bool = False, + generate_report: bool = False, + report_prompt: str = default_report_prompt, + recursion_limit: int = 50, + thread_id: str | None = None, + terminal_tool_names: Collection[str] = (), + human_supervised: bool = False, + on_event: EventCallback | None = None, +) -> TurnResult: + """Run one bounded single-agent ChemGraph LangGraph turn.""" + + started = time.time() + thread_id = thread_id or str(uuid.uuid4()) + callbacks = [_TurnEventCallback(on_event, thread_id)] if on_event else [] + event = on_event or (lambda _event, _payload: None) + event( + "workflow_started", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "tool_names": [getattr(tool, "name", str(tool)) for tool in tools or []], + }, + ) + llm = _load_turn_llm( + model_name=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + ) + workflow = construct_single_agent_graph( + llm, + system_prompt, + structured_output, + formatter_prompt, + generate_report, + report_prompt, + tools, + human_supervised=human_supervised, + terminal_tool_names=terminal_tool_names, + ) + config: dict[str, Any] = { + "configurable": {"thread_id": thread_id}, + "recursion_limit": recursion_limit, + } + if callbacks: + config["callbacks"] = callbacks + + last_state: Any = None + try: + async for state in workflow.astream( + {"messages": query}, + stream_mode="values", + 
config=config, + ): + last_state = state + except Exception as exc: + event( + "workflow_finished", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "status": "failed", + "error": repr(exc), + "duration_s": round(time.time() - started, 3), + }, + ) + raise + + if last_state is None: + raise RuntimeError("ChemGraph turn produced no states.") + + messages = _state_messages(last_state) + executed_tools = _executed_tool_names(messages) + terminal_tool = _terminal_tool_name(executed_tools, terminal_tool_names) + result = TurnResult( + final_text=_final_text(messages), + state=serialize_state(last_state), + executed_tool_names=executed_tools, + terminal_tool=terminal_tool, + thread_id=thread_id, + duration_s=round(time.time() - started, 3), + ) + event( + "workflow_finished", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "status": "completed", + "executed_tool_names": list(result.executed_tool_names), + "terminal_tool": terminal_tool, + "duration_s": result.duration_s, + }, + ) + return result + diff --git a/src/chemgraph/cli/commands.py b/src/chemgraph/cli/commands.py index 70934046..ad351e61 100644 --- a/src/chemgraph/cli/commands.py +++ b/src/chemgraph/cli/commands.py @@ -48,6 +48,7 @@ "graspa_mcp", "rag_agent", "single_agent_xanes", + "single_agent_architector", ] # Common aliases so users can type the "obvious" name. diff --git a/src/chemgraph/cli/main.py b/src/chemgraph/cli/main.py index 5d1222d9..788d5aa6 100644 --- a/src/chemgraph/cli/main.py +++ b/src/chemgraph/cli/main.py @@ -179,8 +179,7 @@ def _add_run_args(parser: argparse.ArgumentParser) -> None: default=None, help=( "Write per-run events to this directory so the run is viewable " - "via 'chemgraph dashboard -- --run-dir '. " - "Currently only effective for single-agent workflows." + "via 'chemgraph dashboard -- --run-dir '." 
), ) diff --git a/src/chemgraph/cli/trace.py b/src/chemgraph/cli/trace.py index 47d5bf5e..9faae2cc 100644 --- a/src/chemgraph/cli/trace.py +++ b/src/chemgraph/cli/trace.py @@ -1,9 +1,9 @@ """Trace writer for traditional ChemGraph CLI runs. -Bridges the `run_turn` event callback into the dashboard's on-disk -schema (`events.jsonl` + `status.json` + `manifest.json`), so the -existing ``chemgraph dashboard`` browser UI can render a single-agent -ChemGraph run without going through the Academy daemon path. +Bridges ChemGraph run events into the dashboard's on-disk schema +(`events.jsonl` + `status.json` + `manifest.json`), so the existing +``chemgraph dashboard`` browser UI can render a traditional ChemGraph run +without going through the Academy daemon path. """ from __future__ import annotations @@ -88,7 +88,7 @@ def finish(self, *, status: str, error: str | None = None) -> None: self._write_status() def on_event(self, event: str, payload: dict) -> None: - """Callback handed to :func:`chemgraph.agent.llm_agent.run_turn`.""" + """Callback handed to :class:`chemgraph.agent.llm_agent.ChemGraph`.""" self._log.emit( event, # type: ignore[arg-type] run_id=self.run_id, diff --git a/src/chemgraph/graphs/graspa_agent.py b/src/chemgraph/graphs/graspa_agent.py new file mode 100644 index 00000000..578db3d1 --- /dev/null +++ b/src/chemgraph/graphs/graspa_agent.py @@ -0,0 +1,182 @@ + + +from langgraph.graph import StateGraph, START, END +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import ToolNode + +from chemgraph.tools.graspa_tools import run_graspa +from chemgraph.schemas.agent_response import ResponseFormatter +from chemgraph.prompt.single_agent_prompt import ( + single_agent_prompt, + formatter_prompt, +) +from chemgraph.utils.logging_config import setup_logger +from chemgraph.state.state import State + +logger = setup_logger(__name__) + + +def route_tools(state: State): + """Route to the 'tools' node if 
the last message has tool calls; otherwise, route to 'done'. + + Parameters + ---------- + state : State + The current state containing messages and remaining steps + + Returns + ------- + str + Either 'tools' or 'done' based on the state conditions + """ + if isinstance(state, list): + ai_message = state[-1] + elif messages := state.get("messages", []): + ai_message = messages[-1] + else: + raise ValueError(f"No messages found in input state to tool_edge: {state}") + if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + return "tools" + return "done" + + +def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): + """LLM node that processes messages and decides next actions. + + Parameters + ---------- + state : State + The current state containing messages and remaining steps + llm : ChatOpenAI + The language model to use for processing + system_prompt : str + The system prompt to guide the LLM's behavior + tools : list, optional + List of tools available to the agent, by default None + + Returns + ------- + dict + Updated state containing the LLM's response + """ + + # Load default tools if no tool is specified. 
+ if tools is None: + tools = [run_graspa] + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_with_tools = llm.bind_tools(tools=tools) + return {"messages": [llm_with_tools.invoke(messages)]} + + +def ResponseAgent(state: State, llm: ChatOpenAI, formatter_prompt: str): + """An LLM agent responsible for formatting the final message. + + Parameters + ---------- + state : State + The current state containing messages and remaining steps + llm : ChatOpenAI + The language model to use for formatting + formatter_prompt : str + The prompt to guide the LLM's formatting behavior + + Returns + ------- + dict + Updated state containing the formatted response + """ + messages = [ + {"role": "system", "content": formatter_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_structured_output = llm.with_structured_output(ResponseFormatter) + response = llm_structured_output.invoke(messages).model_dump_json() + return {"messages": [response]} + + +def construct_graspa_graph( + llm: ChatOpenAI, + system_prompt: str = single_agent_prompt, + structured_output: bool = False, + formatter_prompt: str = formatter_prompt, + tools: list = None, +): + """Construct a gRASPA agent graph.
+ + Parameters + ---------- + llm : ChatOpenAI + The language model to use for the graph + system_prompt : str, optional + The system prompt to guide the LLM's behavior, by default single_agent_prompt + structured_output : bool, optional + Whether to use structured output, by default False + formatter_prompt : str, optional + The prompt to guide the LLM's formatting behavior, by default formatter_prompt + tools : list, optional + The list of tools for the agent, by default None + Returns + ------- + StateGraph + The constructed gRASPA agent graph + """ + try: + logger.info("Constructing gRASPA graph") + checkpointer = MemorySaver() + if tools is None: + tools = [run_graspa] + tool_node = ToolNode(tools=tools) + graph_builder = StateGraph(State) + + if not structured_output: + graph_builder.add_node( + "ChemGraphAgent", + lambda state: ChemGraphAgent( + state, llm, system_prompt=system_prompt, tools=tools + ), + ) + graph_builder.add_node("tools", tool_node) + graph_builder.add_conditional_edges( + "ChemGraphAgent", + route_tools, + {"tools": "tools", "done": END}, + ) + graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_edge(START, "ChemGraphAgent") + graph = graph_builder.compile(checkpointer=checkpointer) + logger.info("gRASPA graph construction completed") + return graph + else: + graph_builder.add_node( + "ChemGraphAgent", + lambda state: ChemGraphAgent( + state, llm, system_prompt=system_prompt, tools=tools + ), + ) + graph_builder.add_node("tools", tool_node) + graph_builder.add_node( + "ResponseAgent", + lambda state: ResponseAgent( + state, llm, formatter_prompt=formatter_prompt + ), + ) + graph_builder.add_conditional_edges( + "ChemGraphAgent", + route_tools, + {"tools": "tools", "done": "ResponseAgent"}, + ) + graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_edge(START, "ChemGraphAgent") + graph_builder.add_edge("ResponseAgent", END) + + graph = graph_builder.compile(checkpointer=checkpointer) +
logger.info("gRASPA graph construction completed") + return graph + + except Exception as e: + logger.error(f"Error constructing graph: {str(e)}") + raise diff --git a/src/chemgraph/graphs/mock_agent.py b/src/chemgraph/graphs/mock_agent.py new file mode 100644 index 00000000..d10441e7 --- /dev/null +++ b/src/chemgraph/graphs/mock_agent.py @@ -0,0 +1,102 @@ +from langgraph.graph import StateGraph, START, END +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from chemgraph.tools.ase_tools import ( + run_ase, + save_atomsdata_to_file, + file_to_atomsdata, +) +from chemgraph.tools.cheminformatics_tools import ( + molecule_name_to_smiles, + smiles_to_atomsdata, +) +from chemgraph.tools.generic_tools import calculator +from chemgraph.prompt.single_agent_prompt import ( + single_agent_prompt, +) +from chemgraph.utils.logging_config import setup_logger +from chemgraph.state.state import State + +logger = setup_logger(__name__) + + +def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): + """LLM node that processes messages and decides next actions. + + Parameters + ---------- + state : State + The current state containing messages and remaining steps + llm : ChatOpenAI + The language model to use for processing + system_prompt : str + The system prompt to guide the LLM's behavior + tools : list, optional + List of tools available to the agent, by default None + + Returns + ------- + dict + Updated state containing the LLM's response + """ + + # Load default tools if no tool is specified. 
+ if tools is None: + tools = [ + file_to_atomsdata, + smiles_to_atomsdata, + run_ase, + molecule_name_to_smiles, + save_atomsdata_to_file, + calculator, + ] + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_with_tools = llm.bind_tools(tools=tools) + return {"messages": [llm_with_tools.invoke(messages)]} + +def construct_mock_agent_graph( + llm: ChatOpenAI, + system_prompt: str = single_agent_prompt, + tools: list = None, +): + """Construct a geometry optimization graph. + + Parameters + ---------- + llm : ChatOpenAI + The language model to use for the graph + system_prompt : str, optional + The system prompt to guide the LLM's behavior, by default single_agent_prompt + tools: list, optional + The list of tools for the main agent, by default None + Returns + ------- + StateGraph + The constructed single agent graph + """ + logger.info("Constructing mock agent graph") + checkpointer = MemorySaver() + if tools is None: + tools = [ + file_to_atomsdata, + smiles_to_atomsdata, + run_ase, + molecule_name_to_smiles, + save_atomsdata_to_file, + calculator, + ] + graph_builder = StateGraph(State) + + graph_builder.add_node( + "ChemGraphAgent", + lambda state: ChemGraphAgent(state, llm, system_prompt=system_prompt, tools=tools), + ) + graph_builder.add_edge(START, "ChemGraphAgent") + graph_builder.add_edge("ChemGraphAgent", END) + + graph = graph_builder.compile(checkpointer=checkpointer) + logger.info("Mock agent graph construction completed") + return graph diff --git a/src/chemgraph/graphs/python_relp_agent.py b/src/chemgraph/graphs/python_relp_agent.py new file mode 100644 index 00000000..dd8edf98 --- /dev/null +++ b/src/chemgraph/graphs/python_relp_agent.py @@ -0,0 +1,224 @@ +from typing import Annotated +from typing_extensions import TypedDict + +from langgraph.graph import StateGraph, START, END +from langgraph.graph.message import add_messages +from langchain_core.messages import 
ToolMessage +import json +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from chemgraph.tools.generic_tools import repl_tool +from chemgraph.tools.generic_tools import calculator +from chemgraph.prompt.single_agent_prompt import single_agent_prompt +from chemgraph.utils.logging_config import setup_logger + +logger = setup_logger(__name__) + + +class State(TypedDict): + """Type definition for the state dictionary used in the graph. + + Attributes + ---------- + messages : list + List of messages in the conversation, annotated with add_messages + """ + + messages: Annotated[list, add_messages] + + +class BasicToolNode: + """A node that executes tools requested in the last AIMessage. + + This class processes tool calls from AI messages and executes the corresponding + tools, handling their results and any potential errors. + + Parameters + ---------- + tools : list + List of tool objects that can be called by the node + + Attributes + ---------- + tools_by_name : dict + Dictionary mapping tool names to their corresponding tool objects + """ + + def __init__(self, tools: list) -> None: + """Initialize the tool node. + + Parameters + ---------- + tools : list + Tool objects keyed by their ``name`` attribute. + """ + self.tools_by_name = {tool.name: tool for tool in tools} + + def __call__(self, inputs: State) -> State: + """Execute tools requested in the last message. 
+ + Parameters + ---------- + inputs : State + The current state containing messages + + Returns + ------- + State + Updated state containing tool execution results + + Raises + ------ + ValueError + If no message is found in the input state + """ + if messages := inputs.get("messages", []): + message = messages[-1] + else: + raise ValueError("No message found in input") + + outputs = [] + for tool_call in message.tool_calls: + try: + tool_name = tool_call.get("name") + if not tool_name or tool_name not in self.tools_by_name: + raise ValueError(f"Invalid tool name: {tool_name}") + + tool_result = self.tools_by_name[tool_name].invoke(tool_call.get("args", {})) + + # Handle different types of tool results + result_content = ( + tool_result.dict() + if hasattr(tool_result, "dict") + else (tool_result if isinstance(tool_result, dict) else str(tool_result)) + ) + + outputs.append( + ToolMessage( + content=json.dumps(result_content), + name=tool_name, + tool_call_id=tool_call.get("id", ""), + ) + ) + + except Exception as e: + outputs.append( + ToolMessage( + content=json.dumps({"error": str(e)}), + name=tool_name if tool_name else "unknown_tool", + tool_call_id=tool_call.get("id", ""), + ) + ) + return {"messages": outputs} + + +def route_tools(state: State): + """Route to the 'tools' node if the last message has tool calls; otherwise, route to END. 
+ + Parameters + ---------- + state : State + The current state containing messages + + Returns + ------- + str + Either 'tools' or END based on the presence of tool calls + + Raises + ------ + ValueError + If no messages are found in the input state + """ + if isinstance(state, list): + ai_message = state[-1] + elif messages := state.get("messages", []): + ai_message = messages[-1] + else: + raise ValueError(f"No messages found in input state to tool_edge: {state}") + if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + return "tools" + return END + + +def CompChemAgent(state: State, llm: ChatOpenAI, system_prompt=single_agent_prompt, tools=None): + """LLM node that processes messages and decides next actions. + + Parameters + ---------- + state : State + The current state containing messages + llm : ChatOpenAI + The language model to use for processing + system_prompt : str, optional + The system prompt to guide the LLM's behavior, + by default single_agent_prompt + tools : list, optional + List of tools available to the agent, by default None + + Returns + ------- + dict + Updated state containing the LLM's response + """ + if tools is None: + tools = [] + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_with_tools = llm.bind_tools(tools=tools) + return {"messages": [llm_with_tools.invoke(messages)]} + + +def construct_relp_graph(llm: ChatOpenAI, system_prompt=single_agent_prompt): + """Construct a graph for REPL-based Python execution workflow. + + This function creates a state graph that implements a workflow for executing + Python code through a REPL interface, using LLM agents and tools. 
+ + Parameters + ---------- + llm : ChatOpenAI + The language model to use in the workflow + system_prompt : str, optional + The system prompt to guide the LLM's behavior, + by default single_agent_prompt + + Returns + ------- + StateGraph + A compiled state graph implementing the REPL workflow + + Raises + ------ + Exception + If there is an error during graph construction + """ + try: + logger.info("Constructing geometry optimization graph") + checkpointer = MemorySaver() + tools = [ + repl_tool, + calculator, + ] + tool_node = BasicToolNode(tools=tools) + graph_builder = StateGraph(State) + graph_builder.add_node( + "CompChemAgent", + lambda state: CompChemAgent(state, llm, system_prompt=system_prompt, tools=tools), + ) + graph_builder.add_node("tools", tool_node) + graph_builder.add_conditional_edges( + "CompChemAgent", + route_tools, + {"tools": "tools", END: END}, + ) + graph_builder.add_edge("tools", "CompChemAgent") + graph_builder.add_edge(START, "CompChemAgent") + graph = graph_builder.compile(checkpointer=checkpointer) + logger.info("Graph construction completed") + return graph + except Exception as e: + logger.error(f"Error constructing graph: {str(e)}") + raise diff --git a/src/chemgraph/graphs/rag_agent.py b/src/chemgraph/graphs/rag_agent.py new file mode 100644 index 00000000..91611166 --- /dev/null +++ b/src/chemgraph/graphs/rag_agent.py @@ -0,0 +1,245 @@ +"""LangGraph workflow for the RAG (Retrieval-Augmented Generation) agent. + +This graph combines document retrieval tools (load_document, +query_knowledge_base) with the standard chemistry tools so the agent +can answer questions grounded in user-provided text documents *and* +run molecular simulations when needed. 
+ +Graph structure +--------------- + + START + | + v + RAGAgent <-------+ + | | + (route) | + / \\ | + v v | + tools done-->END | + | | + +----------------+ + +The agent loops through a ReAct cycle: it can call any combination of +RAG tools and chemistry tools, inspect the results, and decide whether +to call more tools or produce a final answer. +""" + +from langgraph.graph import StateGraph, START, END +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import ToolNode + +from chemgraph.tools.rag_tools import load_document, query_knowledge_base +from chemgraph.tools.ase_tools import ( + run_ase, + save_atomsdata_to_file, + file_to_atomsdata, +) +from chemgraph.tools.cheminformatics_tools import ( + molecule_name_to_smiles, + smiles_to_coordinate_file, +) +from chemgraph.tools.generic_tools import calculator +from chemgraph.prompt.rag_prompt import rag_agent_prompt +from chemgraph.state.state import State +from chemgraph.utils.logging_config import setup_logger + +logger = setup_logger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers (reuse the repeated-tool-call detection from single_agent) +# --------------------------------------------------------------------------- +def _tool_call_signature(tool_calls) -> tuple: + """Create a comparable signature for a list of tool calls. + + Parameters + ---------- + tool_calls : list + Tool-call dictionaries from an AI message. + + Returns + ------- + tuple + Deterministic signature of tool names and arguments. 
+ """ + signature = [] + for call in tool_calls or []: + name = call.get("name") if isinstance(call, dict) else None + args = call.get("args", {}) if isinstance(call, dict) else {} + if isinstance(args, dict): + args_sig = tuple(sorted(args.items())) + else: + args_sig = str(args) + signature.append((name, args_sig)) + return tuple(signature) + + +def _is_repeated_tool_cycle(messages) -> bool: + """Detect if the most recent AI tool-call set repeats the previous one. + + Parameters + ---------- + messages : list + Message history to inspect. + + Returns + ------- + bool + ``True`` when the last two AI tool-call sets are identical. + """ + ai_with_calls = [ + m + for m in messages + if hasattr(m, "tool_calls") and getattr(m, "tool_calls", None) + ] + if len(ai_with_calls) < 2: + return False + last = _tool_call_signature(ai_with_calls[-1].tool_calls) + prev = _tool_call_signature(ai_with_calls[-2].tool_calls) + return bool(last) and last == prev + + +# --------------------------------------------------------------------------- +# Routing +# --------------------------------------------------------------------------- +def route_tools(state: State): + """Route to 'tools' if the last message has tool calls, else 'done'. + + Parameters + ---------- + state : State + Current graph state. + + Returns + ------- + str + ``"tools"`` or ``"done"``. 
+ """ + if isinstance(state, list): + ai_message = state[-1] + elif messages := state.get("messages", []): + ai_message = messages[-1] + else: + raise ValueError(f"No messages found in input state: {state}") + + if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + if not isinstance(state, list) and _is_repeated_tool_cycle(messages): + return "done" + return "tools" + return "done" + + +# --------------------------------------------------------------------------- +# Agent node +# --------------------------------------------------------------------------- +def RAGAgent(state: State, llm, system_prompt: str, tools=None): + """LLM node that can retrieve from documents and run chemistry tools. + + Parameters + ---------- + state : State + Current graph state with messages. + llm : BaseChatModel + The bound language model. + system_prompt : str + System prompt guiding the agent's behaviour. + tools : list, optional + Tools available to the agent. Uses the default RAG + chemistry + tool set when ``None``. + + Returns + ------- + dict + Updated state with the LLM's response appended to messages. 
+ """ + if tools is None: + tools = _default_tools() + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_with_tools = llm.bind_tools(tools=tools) + return {"messages": [llm_with_tools.invoke(messages)]} + + +# --------------------------------------------------------------------------- +# Default tool set +# --------------------------------------------------------------------------- +def _default_tools(): + """Return the combined RAG + chemistry tool list.""" + return [ + # RAG tools + load_document, + query_knowledge_base, + # Chemistry tools + file_to_atomsdata, + smiles_to_coordinate_file, + run_ase, + molecule_name_to_smiles, + save_atomsdata_to_file, + calculator, + ] + + +# --------------------------------------------------------------------------- +# Graph constructor +# --------------------------------------------------------------------------- +def construct_rag_agent_graph( + llm, + system_prompt: str = rag_agent_prompt, + tools: list = None, +): + """Construct a RAG agent graph with document retrieval and chemistry tools. + + Parameters + ---------- + llm : BaseChatModel + The language model to power the agent. + system_prompt : str, optional + System prompt for the RAG agent, by default ``rag_agent_prompt``. + tools : list, optional + Custom tool list. When ``None`` the default RAG + chemistry + tools are used. + + Returns + ------- + CompiledStateGraph + The compiled LangGraph workflow ready for execution. 
+ """ + try: + logger.info("Constructing RAG agent graph") + checkpointer = MemorySaver() + + if tools is None: + tools = _default_tools() + + tool_node = ToolNode(tools=tools) + graph_builder = StateGraph(State) + + # Nodes + graph_builder.add_node( + "RAGAgent", + lambda state: RAGAgent( + state, llm, system_prompt=system_prompt, tools=tools + ), + ) + graph_builder.add_node("tools", tool_node) + + # Edges + graph_builder.add_edge(START, "RAGAgent") + graph_builder.add_conditional_edges( + "RAGAgent", + route_tools, + {"tools": "tools", "done": END}, + ) + graph_builder.add_edge("tools", "RAGAgent") + + graph = graph_builder.compile(checkpointer=checkpointer) + logger.info("RAG agent graph construction completed") + return graph + + except Exception as e: + logger.error(f"Error constructing RAG agent graph: {e}") + raise diff --git a/src/chemgraph/graphs/single_agent_architector.py b/src/chemgraph/graphs/single_agent_architector.py new file mode 100644 index 00000000..758ed5e6 --- /dev/null +++ b/src/chemgraph/graphs/single_agent_architector.py @@ -0,0 +1,162 @@ +from langgraph.graph import StateGraph, START, END +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import ToolNode +from chemgraph.tools.cheminformatics_tools import ( + molecule_name_to_smiles, + smiles_to_coordinate_file, +) + +try: + from chemgraph.tools.architector_tools import ( + visualize_molecule, + image_to_connection_points, + build_metal_complex, + ) +except ModuleNotFoundError: + def _missing_architector_tool(*_args, **_kwargs): + raise ImportError( + "single_agent_architector requires chemgraph.tools.architector_tools, " + "which is not available in this installation." 
+ ) + + def visualize_molecule(smiles: str) -> str: + """Visualize a molecule for Architector workflows.""" + return _missing_architector_tool(smiles) + + def image_to_connection_points(image_path: str) -> str: + """Extract connection points from an image for Architector workflows.""" + return _missing_architector_tool(image_path) + + def build_metal_complex(specification: str) -> str: + """Build a metal complex for Architector workflows.""" + return _missing_architector_tool(specification) +from chemgraph.utils.logging_config import setup_logger +from chemgraph.state.state import State + +logger = setup_logger(__name__) + +single_agent_prompt = "" + +def route_tools(state: State): + """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. + + Parameters + ---------- + state : State + The current state containing messages and remaining steps + + Returns + ------- + str + Either 'tools' or 'done' based on the state conditions + """ + if isinstance(state, list): + ai_message = state[-1] + elif messages := state.get("messages", []): + ai_message = messages[-1] + else: + raise ValueError(f"No messages found in input state to tool_edge: {state}") + if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + return "tools" + return "done" + + +def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): + """LLM node that processes messages and decides next actions. + + Parameters + ---------- + state : State + The current state containing messages and remaining steps + llm : ChatOpenAI + The language model to use for processing + system_prompt : str + The system prompt to guide the LLM's behavior + tools : list, optional + List of tools available to the agent, by default None + + Returns + ------- + dict + Updated state containing the LLM's response + """ + + # Load default tools if no tool is specified. 
+ if tools is None: + tools = [ + molecule_name_to_smiles, + smiles_to_coordinate_file, + visualize_molecule, + image_to_connection_points, + build_metal_complex + ] + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_with_tools = llm.bind_tools(tools=tools) + return {"messages": [llm_with_tools.invoke(messages)]} + +def construct_single_agent_architector_graph( + llm: ChatOpenAI, + system_prompt: str = "", + tools: list = None, +): + """Construct a geometry optimization graph. + + Parameters + ---------- + llm : ChatOpenAI + The language model to use for the graph + system_prompt : str, optional + The system prompt to guide the LLM's behavior, by default single_agent_prompt + structured_output : bool, optional + Whether to use structured output, by default False + formatter_prompt : str, optional + The prompt to guide the LLM's formatting behavior, by default formatter_prompt + generate_report: bool, optional + Whether to generate a report, by default False + report_prompt: str, optional + The prompt to guide the LLM's report generation behavior, by default report_prompt + tool: list, optional + The list of tools for the main agent, by default None + Returns + ------- + StateGraph + The constructed single agent graph + """ + try: + logger.info("Constructing single agent graph") + checkpointer = MemorySaver() + if tools is None: + tools = [ + molecule_name_to_smiles, + smiles_to_coordinate_file, + visualize_molecule, + image_to_connection_points, + build_metal_complex + ] + tool_node = ToolNode(tools=tools) + graph_builder = StateGraph(State) + + graph_builder.add_node( + "ChemGraphAgent", + lambda state: ChemGraphAgent(state, llm, system_prompt=system_prompt, tools=tools), + ) + graph_builder.add_node("tools", tool_node) + graph_builder.add_edge(START, "ChemGraphAgent") + graph_builder.add_conditional_edges( + "ChemGraphAgent", + route_tools, + {"tools": "tools", "done": END}, + ) + 
graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_edge("ChemGraphAgent", END) + + graph = graph_builder.compile(checkpointer=checkpointer) + logger.info("Graph construction completed") + return graph + except Exception as e: + logger.error(f"Error constructing graph: {str(e)}") + raise diff --git a/src/chemgraph/graphs/single_agent_mcp.py b/src/chemgraph/graphs/single_agent_mcp.py new file mode 100644 index 00000000..f858a9c0 --- /dev/null +++ b/src/chemgraph/graphs/single_agent_mcp.py @@ -0,0 +1,116 @@ +from typing import List, Any + +from langgraph.graph import StateGraph, START, END +from langchain_openai import ChatOpenAI +from langgraph.prebuilt import ToolNode +from langgraph.checkpoint.memory import MemorySaver + +from chemgraph.prompt.single_agent_prompt import ( + single_agent_prompt, +) +from chemgraph.utils.logging_config import setup_logger +from chemgraph.state.state import State + +logger = setup_logger(__name__) + + +def route_tools(state: State) -> str: + """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. + + Parameters + ---------- + state : State + The current state containing messages and remaining steps + + Returns + ------- + str + Either 'tools' or 'done' based on the state conditions + """ + if isinstance(state, list): + ai_message = state[-1] + elif messages := state.get("messages", []): + ai_message = messages[-1] + else: + raise ValueError(f"No messages found in input state to tool_edge: {state}") + if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + return "tools" + return "done" + + +def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): + """LLM node that processes messages and decides next actions. 
+ + Parameters + ---------- + state : State + The current state containing messages and remaining steps + llm : ChatOpenAI + The language model to use for processing + system_prompt : str + The system prompt to guide the LLM's behavior + tools : list, optional + List of tools available to the agent, by default None + + Returns + ------- + dict + Updated state containing the LLM's response + """ + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_with_tools = llm.bind_tools(tools=tools) + return {"messages": [llm_with_tools.invoke(messages)]} + + +def construct_single_agent_mcp_graph( + llm: ChatOpenAI, + system_prompt: str = single_agent_prompt, + tools: List[Any] = None, +): + """Construct a geometry optimization graph. + + Parameters + ---------- + llm : ChatOpenAI + The language model to use for the graph + system_prompt : str, optional + The system prompt to guide the LLM's behavior, by default single_agent_prompt + Returns + ------- + StateGraph + The constructed single agent graph + """ + if not tools: + raise ValueError( + "No MCP tools loaded. Ensure MCP servers are configured and reachable." 
+ ) + logger.info("Constructing single agent MCP graph (sync)") + + checkpointer = MemorySaver() + tool_node = ToolNode(tools=tools) + graph_builder = StateGraph(State) + + graph_builder.add_node( + "ChemGraphAgent", + lambda state: ChemGraphAgent( + state, llm, system_prompt=system_prompt, tools=tools + ), + ) + graph_builder.add_node("tools", tool_node) + graph_builder.add_edge(START, "ChemGraphAgent") + + graph_builder.add_conditional_edges( + "ChemGraphAgent", + route_tools, + {"tools": "tools", "done": END}, + ) + graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_edge("ChemGraphAgent", END) + + graph = graph_builder.compile(checkpointer=checkpointer) + logger.info("Graph construction completed") + return graph diff --git a/src/chemgraph/graphs/single_agent_xanes.py b/src/chemgraph/graphs/single_agent_xanes.py new file mode 100644 index 00000000..1c3935d8 --- /dev/null +++ b/src/chemgraph/graphs/single_agent_xanes.py @@ -0,0 +1,272 @@ +import os + +from langgraph.graph import StateGraph, START, END +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import ToolNode +from chemgraph.tools.cheminformatics_tools import ( + molecule_name_to_smiles, + smiles_to_coordinate_file, +) +from chemgraph.tools.ase_tools import run_ase +from chemgraph.tools.xanes_tools import ( + run_xanes, + fetch_xanes_data, + plot_xanes_data,) +from chemgraph.schemas.agent_response import ResponseFormatter +from chemgraph.prompt.xanes_prompt import ( + xanes_single_agent_prompt, + xanes_formatter_prompt, +) +from chemgraph.utils.logging_config import setup_logger +from chemgraph.state.state import State + +logger = setup_logger(__name__) + + +def _tool_call_signature(tool_calls) -> tuple: + """Create a comparable signature for a list of tool calls. + + Parameters + ---------- + tool_calls : list + Tool-call dictionaries from an AI message. 
+ + Returns + ------- + tuple + Deterministic signature of tool names and arguments. + """ + signature = [] + for call in tool_calls or []: + name = call.get("name") if isinstance(call, dict) else None + args = call.get("args", {}) if isinstance(call, dict) else {} + if isinstance(args, dict): + args_sig = tuple(sorted(args.items())) + else: + args_sig = str(args) + signature.append((name, args_sig)) + return tuple(signature) + + +def _is_repeated_tool_cycle(messages) -> bool: + """Detect if the most recent AI tool-call set repeats the previous one. + + Parameters + ---------- + messages : list + Message history to inspect. + + Returns + ------- + bool + ``True`` when the last two AI tool-call sets are identical. + """ + ai_with_calls = [] + for message in messages: + if hasattr(message, "tool_calls") and getattr(message, "tool_calls", None): + ai_with_calls.append(message) + + if len(ai_with_calls) < 2: + return False + + last_calls = _tool_call_signature(ai_with_calls[-1].tool_calls) + prev_calls = _tool_call_signature(ai_with_calls[-2].tool_calls) + return bool(last_calls) and last_calls == prev_calls + + +def route_tools(state: State): + """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. 
+ + Parameters + ---------- + state : State + The current state containing messages and remaining steps + + Returns + ------- + str + Either 'tools' or 'done' based on the state conditions + """ + if isinstance(state, list): + ai_message = state[-1] + elif messages := state.get("messages", []): + ai_message = messages[-1] + else: + raise ValueError(f"No messages found in input state to tool_edge: {state}") + if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: + if not isinstance(state, list) and _is_repeated_tool_cycle(messages): + return "done" + return "tools" + return "done" + + +def XANESAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): + """LLM node for XANES workflows that processes messages and decides next actions. + + Parameters + ---------- + state : State + The current state containing messages and remaining steps + llm : ChatOpenAI + The language model to use for processing + system_prompt : str + The system prompt to guide the LLM's behavior + tools : list, optional + List of tools available to the agent, by default None + + Returns + ------- + dict + Updated state containing the LLM's response + """ + if tools is None: + tools = [ + molecule_name_to_smiles, + smiles_to_coordinate_file, + run_ase, + run_xanes, + fetch_xanes_data, + plot_xanes_data, + ] + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_with_tools = llm.bind_tools(tools=tools) + return {"messages": [llm_with_tools.invoke(messages)]} + + +def ResponseAgent(state: State, llm: ChatOpenAI, formatter_prompt: str): + """An LLM agent responsible for formatting final message. 
+ + Parameters + ---------- + state : State + The current state containing messages and remaining steps + llm : ChatOpenAI + The language model to use for formatting + formatter_prompt : str + The prompt to guide the LLM's formatting behavior + + Returns + ------- + dict + Updated state containing the formatted response + """ + messages = [ + {"role": "system", "content": formatter_prompt}, + {"role": "user", "content": f"{state['messages']}"}, + ] + llm_structured_output = llm.with_structured_output(ResponseFormatter) + response = llm_structured_output.invoke(messages).model_dump_json() + return {"messages": [response]} + + +def construct_single_agent_xanes_graph( + llm: ChatOpenAI, + system_prompt: str = xanes_single_agent_prompt, + structured_output: bool = False, + formatter_prompt: str = xanes_formatter_prompt, + tools: list = None, +): + """Construct a single-agent graph for XANES/FDMNES workflows. + + Parameters + ---------- + llm : ChatOpenAI + The language model to use for the graph + system_prompt : str, optional + The system prompt to guide the LLM's behavior, + by default xanes_single_agent_prompt + structured_output : bool, optional + Whether to use structured output, by default False + formatter_prompt : str, optional + The prompt to guide the LLM's formatting behavior, + by default xanes_formatter_prompt + tools : list, optional + The list of tools for the main agent, by default None + + Returns + ------- + StateGraph + The constructed single agent XANES graph + """ + try: + logger.info("Constructing single agent XANES graph") + + if not os.environ.get("MP_API_KEY"): + logger.warning( + "MP_API_KEY environment variable is not set. " + "The fetch_xanes_data tool will require an API key " + "to be passed explicitly." + ) + if not os.environ.get("FDMNES_EXE"): + logger.warning( + "FDMNES_EXE environment variable is not set. " + "The run_xanes tool will not work without the FDMNES executable." 
+ ) + + checkpointer = MemorySaver() + if tools is None: + tools = [ + molecule_name_to_smiles, + smiles_to_coordinate_file, + run_ase, + run_xanes, + fetch_xanes_data, + plot_xanes_data, + ] + tool_node = ToolNode(tools=tools) + graph_builder = StateGraph(State) + + if not structured_output: + graph_builder.add_node( + "XANESAgent", + lambda state: XANESAgent( + state, llm, system_prompt=system_prompt, tools=tools + ), + ) + graph_builder.add_node("tools", tool_node) + graph_builder.add_edge(START, "XANESAgent") + graph_builder.add_conditional_edges( + "XANESAgent", + route_tools, + {"tools": "tools", "done": END}, + ) + graph_builder.add_edge("tools", "XANESAgent") + graph_builder.add_edge("XANESAgent", END) + + graph = graph_builder.compile(checkpointer=checkpointer) + logger.info("XANES graph construction completed") + return graph + else: + graph_builder.add_node( + "XANESAgent", + lambda state: XANESAgent( + state, llm, system_prompt=system_prompt, tools=tools + ), + ) + graph_builder.add_node("tools", tool_node) + graph_builder.add_node( + "ResponseAgent", + lambda state: ResponseAgent( + state, llm, formatter_prompt=formatter_prompt + ), + ) + graph_builder.add_conditional_edges( + "XANESAgent", + route_tools, + {"tools": "tools", "done": "ResponseAgent"}, + ) + graph_builder.add_edge("tools", "XANESAgent") + graph_builder.add_edge(START, "XANESAgent") + graph_builder.add_edge("ResponseAgent", END) + + graph = graph_builder.compile(checkpointer=checkpointer) + logger.info("XANES graph construction completed") + return graph + + except Exception as e: + logger.error(f"Error constructing XANES graph: {str(e)}") + raise diff --git a/tests/test_agent_session.py b/tests/test_agent_session.py index a9799c97..40d987c9 100644 --- a/tests/test_agent_session.py +++ b/tests/test_agent_session.py @@ -14,6 +14,7 @@ import os import pytest +from types import SimpleNamespace from unittest.mock import Mock, patch from chemgraph.agent.llm_agent import ChemGraph, TurnResult, 
serialize_state @@ -44,30 +45,47 @@ def tmp_db(tmp_path): return str(tmp_path / "test_sessions.db") +class _GraphStreamCompatibleWorkflow: + def __init__(self): + self.side_effect = self.default_graph_stream + self.last_state = {"messages": []} + + async def default_graph_stream(self, **kwargs): + ai_msg = Mock() + ai_msg.type = "ai" + ai_msg.content = "Test response" + return TurnResult( + final_text="Test response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) + + async def astream(self, inputs, *, stream_mode, config): + result = await self.side_effect( + query=inputs.get("messages"), + thread_id=str(config["configurable"]["thread_id"]), + ) + self.last_state = result.state + yield self.last_state + + def get_state(self, config): + return SimpleNamespace(values=self.last_state) + + @pytest.fixture def mock_agent_patches(): - """Patch LLM loading and run_turn for fast agent creation.""" + """Patch LLM loading and graph streaming for fast agent creation.""" with ( patch("chemgraph.agent.llm_agent.load_openai_model") as mock_load, - patch("chemgraph.agent.llm_agent.run_turn") as mock_run_turn, + patch("chemgraph.agent.llm_agent.construct_single_agent_graph") as mock_constructor, ): mock_load.return_value = Mock() - - async def default_run_turn(**kwargs): - ai_msg = Mock() - ai_msg.type = "ai" - ai_msg.content = "Test response" - return TurnResult( - final_text="Test response", - state={"messages": [ai_msg]}, - executed_tool_names=(), - terminal_tool=None, - thread_id=kwargs["thread_id"], - duration_s=0.0, - ) - - mock_run_turn.side_effect = default_run_turn - yield mock_load, mock_run_turn + workflow = _GraphStreamCompatibleWorkflow() + mock_constructor.return_value = workflow + yield mock_load, workflow def _make_agent(clean_env, mock_agent_patches, tmp_db, **kwargs): @@ -416,7 +434,7 @@ def test_no_overwrite_same_second( class TestResumeFrom: def _make_streamable_agent(self, 
clean_env, mock_agent_patches, tmp_db): - """Create an agent whose run path is mocked through run_turn.""" + """Create an agent whose run path is mocked through graph stream.""" return _make_agent(clean_env, mock_agent_patches, tmp_db) @pytest.mark.asyncio @@ -434,10 +452,10 @@ async def test_resume_prepends_context(self, clean_env, mock_agent_patches, tmp_ # Create second agent sharing the same DB agent2 = self._make_streamable_agent(clean_env, mock_agent_patches, tmp_db) - # Track what query is passed to run_turn. + # Track what query is passed to graph stream. captured_inputs = [] - async def tracking_run_turn(**kwargs): + async def tracking_graph_stream(**kwargs): captured_inputs.append({"messages": kwargs["query"]}) ai_msg = Mock() ai_msg.type = "ai" @@ -451,7 +469,7 @@ async def tracking_run_turn(**kwargs): duration_s=0.0, ) - mock_agent_patches[1].side_effect = tracking_run_turn + mock_agent_patches[1].side_effect = tracking_graph_stream await agent2.run("Continue the analysis", resume_from=session_id) @@ -469,7 +487,7 @@ async def test_resume_from_nonexistent_session( captured_inputs = [] - async def tracking_run_turn(**kwargs): + async def tracking_graph_stream(**kwargs): captured_inputs.append({"messages": kwargs["query"]}) ai_msg = Mock() ai_msg.type = "ai" @@ -483,7 +501,7 @@ async def tracking_run_turn(**kwargs): duration_s=0.0, ) - mock_agent_patches[1].side_effect = tracking_run_turn + mock_agent_patches[1].side_effect = tracking_graph_stream await agent.run("Hello", resume_from="nonexistent_id") @@ -500,7 +518,7 @@ async def test_resume_from_ignored_when_memory_disabled( captured_inputs = [] - async def tracking_run_turn(**kwargs): + async def tracking_graph_stream(**kwargs): captured_inputs.append({"messages": kwargs["query"]}) ai_msg = Mock() ai_msg.type = "ai" @@ -514,7 +532,7 @@ async def tracking_run_turn(**kwargs): duration_s=0.0, ) - mock_agent_patches[1].side_effect = tracking_run_turn + mock_agent_patches[1].side_effect = 
tracking_graph_stream await agent.run("Hello", resume_from="some_id") @@ -544,7 +562,7 @@ async def test_full_lifecycle(self, clean_env, mock_agent_patches, tmp_db): final_state = {"messages": [human_msg, ai_msg]} - async def mock_run_turn(**kwargs): + async def mock_graph_stream(**kwargs): return TurnResult( final_text=ai_msg.content, state=final_state, @@ -554,7 +572,7 @@ async def mock_run_turn(**kwargs): duration_s=0.0, ) - mock_agent_patches[1].side_effect = mock_run_turn + mock_agent_patches[1].side_effect = mock_graph_stream # Step 1: Run await agent.run("Calculate energy of H2") @@ -578,7 +596,7 @@ async def mock_run_turn(**kwargs): del os.environ["CHEMGRAPH_LOG_DIR"] agent2 = _make_agent(clean_env, mock_agent_patches, tmp_db) - mock_agent_patches[1].side_effect = mock_run_turn + mock_agent_patches[1].side_effect = mock_graph_stream await agent2.run("Now optimize H2", resume_from=agent.uuid) diff --git a/tests/test_graph_constructors.py b/tests/test_graph_constructors.py index 6efcaef3..72e58b55 100644 --- a/tests/test_graph_constructors.py +++ b/tests/test_graph_constructors.py @@ -1,5 +1,94 @@ -from tests.test_graphs import ( - test_legacy_graph_constructor_is_called, - test_run_turn_workflow_tool_and_prompt_wiring, - test_single_agent_initialization_injects_calculator_availability, -) +import pytest +from chemgraph.agent.llm_agent import ChemGraph + + +WORKFLOWS = [ + "single_agent", + "multi_agent", + "python_relp", + "graspa", + "mock_agent", + "single_agent_mcp", + "graspa_mcp", + "single_agent_xanes", +] + + +@pytest.mark.parametrize("workflow_type", WORKFLOWS) +def test_constructor_is_called(monkeypatch, workflow_type): + called = {} + + def fake_constructor(*args, **kwargs): + called["args"] = (args, kwargs) + return f"WORKFLOW-SENTINEL-{workflow_type}" + + # Patch the constructor name used by chemgraph.agent.llm_agent + constructor_attr = { + "single_agent": "construct_single_agent_graph", + "multi_agent": "construct_multi_agent_graph", + 
"python_relp": "construct_relp_graph", + "graspa": "construct_graspa_graph", + "mock_agent": "construct_mock_agent_graph", + "single_agent_mcp": "construct_single_agent_mcp_graph", + "graspa_mcp": "construct_graspa_mcp_graph", + "single_agent_xanes": "construct_single_agent_xanes_graph", + }[workflow_type] + + monkeypatch.setattr( + f"chemgraph.agent.llm_agent.{constructor_attr}", + fake_constructor, + ) + + # Ensure model loading is deterministic and doesn't call external APIs + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda model_name, temperature, base_url=None: "FAKE_LLM", + ) + + # For MCP workflows some constructors expect tools; pass a non-empty list + kwargs = {} + if workflow_type in {"single_agent_mcp", "graspa_mcp"}: + kwargs["tools"] = ["DUMMY_TOOL"] + kwargs["data_tools"] = ["DUMMY_TOOL"] + + cg = ChemGraph( + model_name="gpt-4o-mini", + workflow_type=workflow_type, + enable_memory=False, + **kwargs, + ) + assert cg.workflow == f"WORKFLOW-SENTINEL-{workflow_type}" + args_tuple, kwargs_called = called["args"] + if args_tuple: + assert args_tuple[0] == "FAKE_LLM" + else: + assert kwargs_called.get("llm") == "FAKE_LLM" + + +def test_single_agent_initialization_injects_calculator_availability(monkeypatch): + called = {} + + def fake_constructor(*args, **kwargs): + called["args"] = (args, kwargs) + return "WORKFLOW-SENTINEL-single_agent" + + monkeypatch.setattr( + "chemgraph.agent.llm_agent.construct_single_agent_graph", + fake_constructor, + ) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda model_name, temperature, base_url=None: "FAKE_LLM", + ) + + cg = ChemGraph( + model_name="gpt-4o-mini", + workflow_type="single_agent", + enable_memory=False, + ) + + args_tuple, _ = called["args"] + system_prompt = args_tuple[1] + assert "Calculator availability detected during ChemGraph initialization" in system_prompt + assert cg.default_calculator in system_prompt + assert cg.default_calculator in 
cg.available_calculators diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 1bf66180..d0ef2d53 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -1,8 +1,10 @@ +from types import SimpleNamespace + import pytest from langchain_core.messages import AIMessage from chemgraph.agent import llm_agent -from chemgraph.agent.llm_agent import ChemGraph, TurnResult +from chemgraph.agent.llm_agent import ChemGraph class _DummyTool: @@ -10,22 +12,49 @@ def __init__(self, name): self.name = name -def _tool_names(tools): - return [getattr(tool, "name", str(tool)) for tool in tools or []] +class _FakeWorkflow: + def __init__(self): + self.astream_calls = [] + self.last_state = {"messages": [AIMessage(content="done")]} + + async def astream(self, inputs, *, stream_mode, config): + self.astream_calls.append( + {"inputs": inputs, "stream_mode": stream_mode, "config": config}, + ) + for callback in config.get("callbacks", []): + callback.on_chat_model_start({"name": "FakeChatModel"}, [["hello"]]) + callback.on_llm_end(SimpleNamespace(generations=[])) + yield self.last_state + + def get_state(self, config): + return SimpleNamespace(values=self.last_state) @pytest.mark.parametrize( ("workflow_type", "constructor_attr", "kwargs"), [ + ("single_agent", "construct_single_agent_graph", {}), ("multi_agent", "construct_multi_agent_graph", {}), + ("python_relp", "construct_relp_graph", {}), + ("graspa", "construct_graspa_graph", {}), + ("graspa_agent", "construct_graspa_graph", {}), + ("mock_agent", "construct_mock_agent_graph", {}), + ( + "single_agent_mcp", + "construct_single_agent_mcp_graph", + {"tools": [_DummyTool("mcp_tool")]}, + ), ( "graspa_mcp", "construct_graspa_mcp_graph", {"tools": [_DummyTool("executor")], "data_tools": [_DummyTool("analysis")]}, ), + ("rag_agent", "construct_rag_agent_graph", {}), + ("single_agent_xanes", "construct_single_agent_xanes_graph", {}), + ("single_agent_architector", "construct_single_agent_architector_graph", {}), ], ) -def 
test_legacy_graph_constructor_is_called( +def test_graph_constructor_is_called( monkeypatch, tmp_path, workflow_type, @@ -33,11 +62,12 @@ def test_legacy_graph_constructor_is_called( kwargs, ): called = {} + workflow = _FakeWorkflow() def fake_constructor(*args, **constructor_kwargs): called["args"] = args called["kwargs"] = constructor_kwargs - return f"WORKFLOW-SENTINEL-{workflow_type}" + return workflow monkeypatch.setattr(f"chemgraph.agent.llm_agent.{constructor_attr}", fake_constructor) monkeypatch.setattr( @@ -53,90 +83,21 @@ def fake_constructor(*args, **constructor_kwargs): **kwargs, ) - assert cg.workflow == f"WORKFLOW-SENTINEL-{workflow_type}" + assert cg.workflow is workflow args = called.get("args", ()) constructor_kwargs = called.get("kwargs", {}) assert (args and args[0] == "FAKE_LLM") or constructor_kwargs.get("llm") == "FAKE_LLM" -@pytest.mark.parametrize( - ("workflow_type", "kwargs", "expected_extra_tools", "expected_prompt"), - [ - ("single_agent", {"tools": [_DummyTool("custom")]}, [], None), - ("python_relp", {"tools": [_DummyTool("custom")]}, ["python_repl", "calculator"], None), - ("graspa", {"tools": [_DummyTool("custom")]}, ["run_graspa"], None), - ( - "mock_agent", - {"tools": [_DummyTool("custom")]}, - [ - "file_to_atomsdata", - "smiles_to_atomsdata", - "run_ase", - "molecule_name_to_smiles", - "save_atomsdata_to_file", - "calculator", - ], - None, - ), - ( - "single_agent_mcp", - {"tools": [_DummyTool("mcp_tool")], "data_tools": [_DummyTool("data_tool")]}, - ["data_tool"], - None, - ), - ( - "rag_agent", - {"tools": [_DummyTool("custom")]}, - [ - "load_document", - "query_knowledge_base", - "file_to_atomsdata", - "smiles_to_coordinate_file", - "run_ase", - "molecule_name_to_smiles", - "save_atomsdata_to_file", - "calculator", - ], - llm_agent.rag_agent_prompt, - ), - ( - "single_agent_xanes", - {"tools": [_DummyTool("custom")]}, - [ - "molecule_name_to_smiles", - "smiles_to_coordinate_file", - "run_ase", - "run_xanes", - 
"fetch_xanes_data", - "plot_xanes_data", - ], - llm_agent.default_xanes_single_agent_prompt, - ), - ], -) @pytest.mark.asyncio -async def test_run_turn_workflow_tool_and_prompt_wiring( - monkeypatch, - tmp_path, - workflow_type, - kwargs, - expected_extra_tools, - expected_prompt, -): - captured = {} - - async def fake_run_turn(**run_kwargs): - captured.update(run_kwargs) - return TurnResult( - final_text="done", - state={"messages": [AIMessage(content="done")]}, - executed_tool_names=(), - terminal_tool=None, - thread_id=run_kwargs["thread_id"], - duration_s=0.0, - ) +async def test_graph_backed_run_uses_astream_and_emits_events(monkeypatch, tmp_path): + workflow = _FakeWorkflow() + events = [] - monkeypatch.setattr("chemgraph.agent.llm_agent.run_turn", fake_run_turn) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.construct_single_agent_graph", + lambda *_args, **_kwargs: workflow, + ) monkeypatch.setattr( "chemgraph.agent.llm_agent.load_openai_model", lambda **_kwargs: "FAKE_LLM", @@ -144,23 +105,37 @@ async def fake_run_turn(**run_kwargs): cg = ChemGraph( model_name="gpt-4o-mini", - workflow_type=workflow_type, + workflow_type="single_agent", enable_memory=False, log_dir=str(tmp_path / "logs"), - **kwargs, + return_option="last_message", + on_event=lambda event, payload: events.append((event, payload)), ) response = await cg.run("hello", config={"thread_id": "test-thread"}) assert response.content == "done" - tool_names = _tool_names(captured["tools"]) - assert tool_names[0] == list(kwargs["tools"])[0].name - for name in expected_extra_tools: - assert name in tool_names - if expected_prompt is not None: - assert captured["system_prompt"] == expected_prompt + assert workflow.astream_calls[0]["inputs"] == {"messages": "hello"} + assert workflow.astream_calls[0]["stream_mode"] == "values" + assert workflow.astream_calls[0]["config"]["configurable"]["thread_id"] == "test-thread" + assert [event for event, _payload in events] == [ + "workflow_started", + 
"llm_call_started", + "llm_call_finished", + "workflow_finished", + ] def test_single_agent_initialization_injects_calculator_availability(monkeypatch, tmp_path): + called = {} + + def fake_constructor(*args, **kwargs): + called["args"] = (args, kwargs) + return _FakeWorkflow() + + monkeypatch.setattr( + "chemgraph.agent.llm_agent.construct_single_agent_graph", + fake_constructor, + ) monkeypatch.setattr( "chemgraph.agent.llm_agent.load_openai_model", lambda **_kwargs: "FAKE_LLM", @@ -173,6 +148,39 @@ def test_single_agent_initialization_injects_calculator_availability(monkeypatch log_dir=str(tmp_path / "logs"), ) - assert "Calculator availability detected during ChemGraph initialization" in cg.system_prompt - assert cg.default_calculator in cg.system_prompt + args_tuple, _ = called["args"] + system_prompt = args_tuple[1] + assert "Calculator availability detected during ChemGraph initialization" in system_prompt + assert cg.default_calculator in system_prompt assert cg.default_calculator in cg.available_calculators + + +def test_rag_and_xanes_default_prompts_are_preserved(monkeypatch, tmp_path): + captured = {} + + def fake_constructor(*args, **kwargs): + captured[kwargs.get("system_prompt")] = True + return _FakeWorkflow() + + monkeypatch.setattr("chemgraph.agent.llm_agent.construct_rag_agent_graph", fake_constructor) + monkeypatch.setattr("chemgraph.agent.llm_agent.construct_single_agent_xanes_graph", fake_constructor) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda **_kwargs: "FAKE_LLM", + ) + + ChemGraph( + model_name="gpt-4o-mini", + workflow_type="rag_agent", + enable_memory=False, + log_dir=str(tmp_path / "rag-logs"), + ) + ChemGraph( + model_name="gpt-4o-mini", + workflow_type="single_agent_xanes", + enable_memory=False, + log_dir=str(tmp_path / "xanes-logs"), + ) + + assert llm_agent.rag_agent_prompt in captured + assert llm_agent.default_xanes_single_agent_prompt in captured diff --git a/tests/test_llm_agent.py 
b/tests/test_llm_agent.py index 18794a62..116b96e5 100644 --- a/tests/test_llm_agent.py +++ b/tests/test_llm_agent.py @@ -1,4 +1,5 @@ import asyncio +import json from types import SimpleNamespace from unittest.mock import Mock, patch @@ -121,3 +122,61 @@ def __iter__(self): callback.on_llm_end(SimpleNamespace(generations=[BrokenGenerationGroup()])) assert [event for event, _payload in events] == ["llm_call_finished"] + + +@pytest.mark.asyncio +async def test_cli_trace_events_are_emitted_from_astream_path(monkeypatch, tmp_path): + from chemgraph.cli.trace import CLIRunTrace + + class FakeWorkflow: + def __init__(self): + self.state = {"messages": [AIMessage(content="done")]} + + async def astream(self, inputs, *, stream_mode, config): + for callback in config.get("callbacks", []): + callback.on_chat_model_start({"name": "FakeChatModel"}, [["hello"]]) + callback.on_llm_end(SimpleNamespace(generations=[])) + yield self.state + + def get_state(self, config): + return SimpleNamespace(values=self.state) + + monkeypatch.setattr( + "chemgraph.agent.llm_agent.construct_single_agent_graph", + lambda *_args, **_kwargs: FakeWorkflow(), + ) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda **_kwargs: Mock(), + ) + + trace = CLIRunTrace( + tmp_path / "trace", + run_id="trace-test", + model_name="gpt-4o-mini", + workflow_type="single_agent", + query="x", + ) + trace.start() + agent = ChemGraph( + model_name="gpt-4o-mini", + workflow_type="single_agent", + enable_memory=False, + log_dir=str(tmp_path / "logs"), + on_event=trace.on_event, + ) + await agent.run("x") + trace.finish(status="completed") + + events = [ + json.loads(line)["event"] + for line in (tmp_path / "trace" / "events.jsonl").read_text().splitlines() + ] + assert events == [ + "run_started", + "workflow_started", + "llm_call_started", + "llm_call_finished", + "workflow_finished", + "run_finished", + ] From b5183c4f88978969e9ecaa3c1dfc34aebb29079d Mon Sep 17 00:00:00 2001 From: Jinchu 
Li Date: Thu, 11 Jun 2026 13:41:17 -0500 Subject: [PATCH 076/119] chore(agent): drop unused turn re-exports from llm_agent TurnResult, _TurnEventCallback, and run_turn were imported in llm_agent.py purely to satisfy test imports. Tests now import those symbols from chemgraph.agent.turn directly; test_agent_session.py already exercises ChemGraph.run through the graph astream path rather than patching run_turn on llm_agent. Verification: focused affected tests pass. Full pytest tests/ -x still stops at the existing optional dependency failure for ensemble-launcher in tests/test_execution.py::TestELBackend::test_python_task. --- src/chemgraph/agent/llm_agent.py | 3 --- tests/test_academy_reasoning_phase2.py | 2 +- tests/test_agent_session.py | 3 ++- tests/test_llm_agent.py | 3 ++- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 22c40729..59abb36e 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -44,15 +44,12 @@ from chemgraph.agent.turn import ( EventCallback, - TurnResult, - _TurnEventCallback, _custom_openai_compatible_kwargs, _executed_tool_names, _response_tool_calls, _serialized_name, _state_messages, _terminal_tool_name, - run_turn, serialize_state, ) from chemgraph.graphs.single_agent import construct_single_agent_graph diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py index 35e85fd4..e4ce9198 100644 --- a/tests/test_academy_reasoning_phase2.py +++ b/tests/test_academy_reasoning_phase2.py @@ -16,7 +16,7 @@ from chemgraph.academy.core.prompt import PromptProfile, PromptStateLimits from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools from chemgraph.academy.core.turn import ReasoningTurnResult, build_peer_status -from chemgraph.agent.llm_agent import TurnResult +from chemgraph.agent.turn import TurnResult from chemgraph.models.settings import LLMSettings diff --git
a/tests/test_agent_session.py b/tests/test_agent_session.py index 40d987c9..5db9887e 100644 --- a/tests/test_agent_session.py +++ b/tests/test_agent_session.py @@ -17,7 +17,8 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from chemgraph.agent.llm_agent import ChemGraph, TurnResult, serialize_state +from chemgraph.agent.llm_agent import ChemGraph +from chemgraph.agent.turn import TurnResult, serialize_state from chemgraph.memory.store import SessionStore diff --git a/tests/test_llm_agent.py b/tests/test_llm_agent.py index 116b96e5..a27eca6c 100644 --- a/tests/test_llm_agent.py +++ b/tests/test_llm_agent.py @@ -5,7 +5,8 @@ import pytest from langchain_core.messages import AIMessage -from chemgraph.agent.llm_agent import ChemGraph, _TurnEventCallback +from chemgraph.agent.llm_agent import ChemGraph +from chemgraph.agent.turn import _TurnEventCallback @pytest.fixture From f1593ab1b7d0a96cae42d1151cfcbdd3ce080a75 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 11 Jun 2026 14:05:28 -0500 Subject: [PATCH 077/119] revert(agent): restore llm_agent.py to pre-academy shape Restore llm_agent.py to its earlier user-facing ChemGraph class shape, removing the temporary event callback wiring and turn primitive imports. Follow-up commits add back only the small hooks needed by the runtime and dashboard paths. 
--- src/chemgraph/agent/llm_agent.py | 584 +++++++++++++------------------ 1 file changed, 248 insertions(+), 336 deletions(-) diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 59abb36e..3f20a44b 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -1,8 +1,7 @@ import asyncio import datetime import os -import time -from typing import Any, Callable, Collection, List, Optional +from typing import Callable, List, Optional import uuid from chemgraph.memory.store import SessionStore @@ -20,11 +19,7 @@ supported_alcf_models, supported_argo_models, supported_gemini_models, -) -from chemgraph.schemas.ase_input import ( - get_available_calculator_names, - get_calculator_selection_context, - get_default_calculator_name, + ) from chemgraph.prompt.single_agent_prompt import ( @@ -39,29 +34,20 @@ aggregator_prompt as default_aggregator_prompt, planner_prompt as default_planner_prompt, ) +from langgraph.types import Command from langgraph.errors import GraphInterrupt -from langchain_core.callbacks import BaseCallbackHandler - -from chemgraph.agent.turn import ( - EventCallback, - _custom_openai_compatible_kwargs, - _executed_tool_names, - _response_tool_calls, - _serialized_name, - _state_messages, - _terminal_tool_name, - serialize_state, -) + from chemgraph.graphs.single_agent import construct_single_agent_graph -from chemgraph.graphs.multi_agent import construct_multi_agent_graph + + from chemgraph.graphs.python_relp_agent import construct_relp_graph +from chemgraph.graphs.multi_agent import construct_multi_agent_graph from chemgraph.graphs.graspa_agent import construct_graspa_graph from chemgraph.graphs.mock_agent import construct_mock_agent_graph from chemgraph.graphs.single_agent_mcp import construct_single_agent_mcp_graph from chemgraph.graphs.graspa_mcp import construct_graspa_mcp_graph from chemgraph.graphs.rag_agent import construct_rag_agent_graph from chemgraph.graphs.single_agent_xanes import 
construct_single_agent_xanes_graph -from chemgraph.graphs.single_agent_architector import construct_single_agent_architector_graph from chemgraph.prompt.rag_prompt import rag_agent_prompt from chemgraph.prompt.xanes_prompt import ( xanes_single_agent_prompt as default_xanes_single_agent_prompt, @@ -73,71 +59,29 @@ logger = logging.getLogger(__name__) -class _AstreamEventCallback(BaseCallbackHandler): - """Forward LangChain callback events from graph-backed CLI runs.""" - - def __init__(self, on_event: EventCallback, thread_id: str) -> None: - self._on_event = on_event - self._thread_id = thread_id +def serialize_state(state): + """Convert non-serializable objects in state to a JSON-friendly format. - def _emit(self, event: str, payload: dict[str, Any]) -> None: - try: - self._on_event(event, {"thread_id": self._thread_id, **payload}) - except Exception: # noqa: BLE001 - callbacks must not break the run. - logger.debug("astream event callback failed", exc_info=True) - - def on_chat_model_start(self, serialized, messages, **kwargs) -> None: - self._emit( - "llm_call_started", - { - "model": _serialized_name(serialized), - "message_count": len(messages[0]) if messages else 0, - }, - ) - - def on_llm_start(self, serialized, prompts, **kwargs) -> None: - self._emit( - "llm_call_started", - { - "model": _serialized_name(serialized), - "message_count": len(prompts or []), - }, - ) - - def on_llm_end(self, response, **kwargs) -> None: - payload: dict[str, Any] = {} - usage = getattr(response, "llm_output", None) - if isinstance(usage, dict): - payload["llm_output"] = usage - self._emit("llm_call_finished", payload) - if tool_calls := _response_tool_calls(response): - self._emit("llm_decision", {"tool_calls": tool_calls}) - - def on_llm_error(self, error, **kwargs) -> None: - self._emit("llm_call_failed", {"error": repr(error)}) - - def on_tool_start(self, serialized, input_str, **kwargs) -> None: - self._emit( - "tool_call_started", - { - "tool_name": 
_serialized_name(serialized), - "arguments": serialize_state(input_str), - }, - ) - - def on_tool_end(self, output, **kwargs) -> None: - payload: dict[str, Any] = {"result": serialize_state(output)} - name = kwargs.get("name") - if name: - payload["tool_name"] = name - self._emit("tool_call_finished", payload) + Parameters + ---------- + state : Any + The state object to be serialized. Can be a list, dict, or object with __dict__ - def on_tool_error(self, error, **kwargs) -> None: - payload = {"error": repr(error)} - name = kwargs.get("name") - if name: - payload["tool_name"] = name - self._emit("tool_call_failed", payload) + Returns + ------- + Any + A JSON-serializable version of the input state + """ + if isinstance(state, (int, float, bool)) or state is None: + return state + elif isinstance(state, list): + return [serialize_state(item) for item in state] + elif isinstance(state, dict): + return {key: serialize_state(value) for key, value in state.items()} + elif hasattr(state, "__dict__"): + return {key: serialize_state(value) for key, value in state.__dict__.items()} + else: + return str(state) class ChemGraph: @@ -226,66 +170,7 @@ def __init__( max_retries: int = 1, human_input_handler: Optional[Callable[[str], str]] = None, human_supervised: bool = False, - terminal_tool_names: Collection[str] = (), - on_event: Optional[EventCallback] = None, ): - """Initialize a ChemGraph workflow instance. - - Parameters - ---------- - model_name : str, optional - LLM model identifier. - workflow_type : str, optional - Workflow constructor key. - base_url : str, optional - Custom provider endpoint URL. - api_key : str, optional - API key passed to compatible model loaders. - argo_user : str, optional - Argo username for Argo-hosted models. - system_prompt : str, optional - System prompt for single-agent-style workflows. - formatter_prompt : str, optional - Prompt used to format single-agent final output. 
- structured_output : bool, optional - Whether structured final output is requested. - return_option : str, optional - Return mode, such as ``"last_message"`` or ``"state"``. - recursion_limit : int, optional - LangGraph recursion limit. - planner_prompt : str, optional - Planner prompt for multi-agent workflows. - executor_prompt : str, optional - Executor prompt for multi-agent workflows. - aggregator_prompt : str, optional - Aggregator prompt retained for compatibility. - formatter_multi_prompt : str, optional - Formatter prompt for multi-agent workflows. - generate_report : bool, optional - Whether report generation is enabled. - report_prompt : str, optional - Prompt used by the report-generation workflow. - support_structured_output : bool, optional - Whether the selected model supports structured output. - tools : list, optional - Custom tool list for applicable workflows. - data_tools : list, optional - Additional data-analysis tools for MCP workflows. - session_store : SessionStore, optional - Existing session store instance. - enable_memory : bool, optional - Whether persistent session memory is enabled. - memory_db_path : str, optional - SQLite path for the session store. - log_dir : str, optional - Directory for run logs and artifacts. - max_retries : int, optional - LLM parse-retry limit for formatter/planner nodes. - human_input_handler : Callable[[str], str], optional - Callback used to answer graph human-interrupt prompts. - human_supervised : bool, optional - Whether to expose human-supervision tools to the agent. 
- """ # Always generate a unique identifier for this instance self.uuid = str(uuid.uuid4())[:8] @@ -372,8 +257,8 @@ def __init__( ) from langchain_openai import ChatOpenAI - llm_kwargs = _custom_openai_compatible_kwargs( - model_name=model_name, + llm = ChatOpenAI( + model=model_name, temperature=temperature, base_url=vllm_base_url, api_key=vllm_api_key, @@ -381,9 +266,7 @@ def __init__( top_p=top_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, - argo_user=argo_user, ) - llm = ChatOpenAI(**llm_kwargs) logger.info( f"Successfully initialized ChatOpenAI for model '{model_name}' at {vllm_base_url}" ) @@ -399,36 +282,8 @@ def __init__( logger.error(f"Exception thrown when loading {model_name}: {str(e)}") raise e - self.workflow_map = { - "single_agent": {"constructor": construct_single_agent_graph}, - "multi_agent": {"constructor": construct_multi_agent_graph}, - "python_relp": {"constructor": construct_relp_graph}, - "graspa": {"constructor": construct_graspa_graph}, - "graspa_agent": {"constructor": construct_graspa_graph}, - "mock_agent": {"constructor": construct_mock_agent_graph}, - "single_agent_mcp": {"constructor": construct_single_agent_mcp_graph}, - "graspa_mcp": {"constructor": construct_graspa_mcp_graph}, - "rag_agent": {"constructor": construct_rag_agent_graph}, - "single_agent_xanes": {"constructor": construct_single_agent_xanes_graph}, - "single_agent_architector": { - "constructor": construct_single_agent_architector_graph, - }, - } - - if workflow_type not in self.workflow_map: - raise ValueError( - f"Unsupported workflow type: {workflow_type}. 
" - f"Available types: {list(self.workflow_map.keys())}" - ) - - self._using_default_system_prompt = system_prompt == single_agent_prompt - self._using_default_formatter_prompt = formatter_prompt == default_formatter_prompt - self.workflow_type = workflow_type self.model_name = model_name - self.base_url = base_url - self.api_key = api_key - self.argo_user = argo_user self.system_prompt = system_prompt self.formatter_prompt = formatter_prompt self.structured_output = structured_output @@ -445,9 +300,6 @@ def __init__( self.max_retries = max_retries self.human_input_handler = human_input_handler self.human_supervised = human_supervised - self.terminal_tool_names = tuple(terminal_tool_names) - self.on_event = on_event - self._last_run_state: dict[str, Any] | None = None # When human supervision is disabled and the caller is using the # default system prompt, strip the ask_human instructions so the @@ -455,27 +307,28 @@ def __init__( if not self.human_supervised and self.system_prompt == single_agent_prompt: self.system_prompt = get_single_agent_prompt(human_supervised=False) - self.available_calculators = get_available_calculator_names() - self.default_calculator = get_default_calculator_name() - self.calculator_selection_context = get_calculator_selection_context() - - def append_calculator_context(prompt: str) -> str: - """Append calculator availability guidance to a prompt once.""" - if self.calculator_selection_context in prompt: - return prompt - return f"{prompt}{self.calculator_selection_context}" - - if self.workflow_type in {"single_agent", "mock_agent", "single_agent_mcp"}: - self.system_prompt = append_calculator_context(self.system_prompt) - elif self.workflow_type == "multi_agent": - self.planner_prompt = append_calculator_context(self.planner_prompt) - self.executor_prompt = append_calculator_context(self.executor_prompt) - if model_name in supported_argo_models: self.support_structured_output = False else: self.support_structured_output = 
support_structured_output + self.workflow_map = { + "single_agent": {"constructor": construct_single_agent_graph}, + "multi_agent": {"constructor": construct_multi_agent_graph}, + "python_relp": {"constructor": construct_relp_graph}, + "graspa": {"constructor": construct_graspa_graph}, + "mock_agent": {"constructor": construct_mock_agent_graph}, + "single_agent_mcp": {"constructor": construct_single_agent_mcp_graph}, + "graspa_mcp": {"constructor": construct_graspa_mcp_graph}, + "rag_agent": {"constructor": construct_rag_agent_graph}, + "single_agent_xanes": {"constructor": construct_single_agent_xanes_graph}, + } + + if workflow_type not in self.workflow_map: + raise ValueError( + f"Unsupported workflow type: {workflow_type}. Available types: {list(self.workflow_map.keys())}" + ) + if self.workflow_type == "single_agent": self.workflow = self.workflow_map[workflow_type]["constructor"]( llm, @@ -487,7 +340,6 @@ def append_calculator_context(prompt: str) -> str: self.tools, max_retries=self.max_retries, human_supervised=self.human_supervised, - terminal_tool_names=self.terminal_tool_names, ) elif self.workflow_type == "multi_agent": self.workflow = self.workflow_map[workflow_type]["constructor"]( @@ -504,7 +356,7 @@ def append_calculator_context(prompt: str) -> str: llm, self.system_prompt, ) - elif self.workflow_type in {"graspa", "graspa_agent"}: + elif self.workflow_type == "graspa": self.workflow = self.workflow_map[workflow_type]["constructor"]( llm, self.system_prompt, @@ -532,7 +384,7 @@ def append_calculator_context(prompt: str) -> str: self.workflow = self.workflow_map[workflow_type]["constructor"]( llm=llm, system_prompt=self.system_prompt - if not self._using_default_system_prompt + if self.system_prompt != single_agent_prompt else rag_agent_prompt, tools=self.tools, ) @@ -540,20 +392,14 @@ def append_calculator_context(prompt: str) -> str: self.workflow = self.workflow_map[workflow_type]["constructor"]( llm, system_prompt=self.system_prompt - if not 
self._using_default_system_prompt + if self.system_prompt != single_agent_prompt else default_xanes_single_agent_prompt, structured_output=self.structured_output, formatter_prompt=self.formatter_prompt - if not self._using_default_formatter_prompt + if self.formatter_prompt != default_formatter_prompt else default_xanes_formatter_prompt, tools=self.tools, ) - elif self.workflow_type == "single_agent_architector": - self.workflow = self.workflow_map[workflow_type]["constructor"]( - llm=llm, - system_prompt=self.system_prompt, - tools=self.tools, - ) def visualize(self, method: str = "ascii"): """Visualize the LangGraph graph structure. @@ -561,18 +407,6 @@ def visualize(self, method: str = "ascii"): This method creates and displays a visual representation of the workflow graph using Mermaid diagrams. The visualization is shown in Jupyter notebooks. - Parameters - ---------- - method : str, optional - Visualization backend. ``"ascii"`` returns an ASCII graph; - any other value renders a Mermaid PNG in the active notebook. - - Returns - ------- - str or None - ASCII graph text when ``method`` is ``"ascii"``; otherwise - displays an image and returns ``None``. - Notes ----- Requires IPython and nest_asyncio to be installed. @@ -743,13 +577,7 @@ def session_id(self) -> str: return self.uuid def _ensure_session(self, query: str) -> None: - """Create a session record on first run if memory is enabled. - - Parameters - ---------- - query : str - User query used to generate the session title. - """ + """Create a session record on first run if memory is enabled.""" if self.session_store is None: return if self._session_created: @@ -767,15 +595,7 @@ def _ensure_session(self, query: str) -> None: logger.info(f"Created session {self.uuid}: {self._session_title}") def _save_messages_to_store(self, last_state: dict, query: str) -> None: - """Extract messages from workflow state and persist to session store. 
- - Parameters - ---------- - last_state : dict - Latest LangGraph state containing a ``messages`` sequence. - query : str - Original user query associated with the saved messages. - """ + """Extract messages from workflow state and persist to session store.""" if self.session_store is None or not self._session_created: return @@ -869,16 +689,6 @@ async def _call_human_input_handler(self, question: str) -> str: Raises :class:`HumanInputRequired` when no handler is configured, allowing external callers (CLI, UI) to catch it, prompt the user, and resume the graph. - - Parameters - ---------- - question : str - Prompt emitted by the graph for a human response. - - Returns - ------- - str - Human response returned by the configured handler. """ handler = self.human_input_handler if handler is None: @@ -887,31 +697,158 @@ async def _call_human_input_handler(self, question: str) -> str: return await handler(question) return handler(question) - async def run( - self, - query: str, - config=None, - resume_from: Optional[str] = None, - ): - """Run a graph-backed ChemGraph workflow. + async def run(self, query: str, config=None, resume_from: Optional[str] = None): + """ + Async-only runner. Requires `self.workflow.astream(...)`. + Streams values, logs new messages, writes state, and returns according to + `self.return_option` ("last_message" or "state"). + + When the graph pauses for human input (via ``interrupt()``), the + ``human_input_handler`` callback is invoked to obtain the user's + response, and the graph is automatically resumed. If no handler + is configured, the ``GraphInterrupt`` exception propagates to the + caller. - All CLI workflows execute through their restored LangGraph - constructors. Academy uses :func:`run_turn` directly instead of this - method. + Parameters + ---------- + query : str + The user query to execute. + config : dict, optional + LangGraph config with thread_id, etc. + resume_from : str, optional + Session ID to load context from. 
The previous conversation + summary is prepended to the query. """ - if config is None: - config = {} - if not isinstance(config, dict): - raise TypeError(f"`config` must be a dictionary, got {type(config).__name__}") - if "thread_id" in config: - config.setdefault("configurable", {})["thread_id"] = str(config["thread_id"]) - config.setdefault("configurable", {}).setdefault("thread_id", "1") - config["recursion_limit"] = self.recursion_limit + def _validate_config(cfg): + if cfg is None: + cfg = {} + if not isinstance(cfg, dict): + raise TypeError( + f"`config` must be a dictionary, got {type(cfg).__name__}" + ) + + # Support top-level thread_id for convenience + if "thread_id" in cfg: + if "configurable" not in cfg: + cfg["configurable"] = {} + cfg["configurable"]["thread_id"] = str(cfg["thread_id"]) + + cfg.setdefault("configurable", {}).setdefault("thread_id", "1") + cfg["recursion_limit"] = self.recursion_limit + return cfg + + def _save_state_and_select_return(last_state, cfg): + log_dir = self.log_dir + if not log_dir: + log_dir = "cg_logs" + + os.makedirs(log_dir, exist_ok=True) + log_path = None + self.write_state(config=cfg, file_path=log_path) + + if self.return_option == "last_message": + return last_state["messages"][-1] + elif self.return_option == "state": + return serialize_state(self.get_state(config=cfg)) + else: + raise ValueError( + f"Unsupported return_option: {self.return_option}. Use 'last_message' or 'state'." + ) + + async def _stream_until_interrupt(stream_input, cfg): + """Stream the workflow until completion or an interrupt. + + Returns ``(last_state, interrupt_value)`` where + ``interrupt_value`` is ``None`` when the graph completed + normally. + + LangGraph's ``astream(stream_mode="values")`` does **not** + raise ``GraphInterrupt``. Instead the stream emits a state + containing an ``__interrupt__`` key and then ends. We + detect this in two ways: + + 1. Check for the ``__interrupt__`` key in streamed states. + 2. 
After the stream ends, inspect the checkpoint snapshot + for pending interrupt tasks. + """ + prev_msgs: list = [] + last_st = None + interrupt_val = None + try: + async for s in self.workflow.astream( + stream_input, stream_mode="values", config=cfg + ): + # Detect inline interrupt marker emitted by astream. + if "__interrupt__" in s: + int_data = s["__interrupt__"] + if isinstance(int_data, (list, tuple)) and int_data: + interrupt_val = int_data[0].value + elif hasattr(int_data, "value"): + interrupt_val = int_data.value + else: + interrupt_val = { + "question": "The workflow needs your input." + } + + if "messages" in s and s["messages"] != prev_msgs: + new_message = s["messages"][-1] + try: + new_message.pretty_print() + except Exception: + pass + logger.info(new_message) + prev_msgs = s["messages"] + last_st = s + except GraphInterrupt as gi: + # Fallback: some LangGraph versions may still raise. + interrupts = gi.args[0] if gi.args else [] + if interrupts: + interrupt_val = interrupts[0].value + else: + interrupt_val = { + "question": "The workflow needs your input." + } + + # Double-check the checkpoint for pending interrupts that + # the stream may not have surfaced explicitly. + if interrupt_val is None: + try: + snapshot = self.workflow.get_state(cfg) + if snapshot and snapshot.tasks: + for t in snapshot.tasks: + t_interrupts = getattr(t, "interrupts", None) + if t_interrupts: + interrupt_val = t_interrupts[0].value + break + except Exception: + pass + + if interrupt_val is not None: + logger.info("Graph interrupted: %s", interrupt_val) + # Refresh state from checkpoint for consistency. 
+ try: + snapshot = self.workflow.get_state(cfg) + if snapshot: + last_st = snapshot.values + except Exception: + pass + + return last_st, interrupt_val + + logger.debug("run called with config=%s", config) + config = _validate_config(config) + logger.debug("validated config=%s", config) + + # Initialize logging directory before determining inputs or running workflow + # Check if CHEMGRAPH_LOG_DIR is already set if not os.environ.get("CHEMGRAPH_LOG_DIR"): os.environ["CHEMGRAPH_LOG_DIR"] = self.log_dir + # Ensure session exists in memory store self._ensure_session(query) + + # If resuming from a previous session, prepend context if resume_from and self.session_store: context = self.session_store.build_context_summary(resume_from) if context: @@ -922,81 +859,63 @@ async def run( ) logger.info(f"Injected context from session {resume_from}") - started = time.time() - thread_id = str(config["configurable"]["thread_id"]) - event = self.on_event or (lambda _event, _payload: None) - if self.on_event: - callbacks = list(config.get("callbacks") or []) - callbacks.append(_AstreamEventCallback(self.on_event, thread_id)) - config["callbacks"] = callbacks - event( - "workflow_started", - { - "workflow_type": self.workflow_type, - "thread_id": thread_id, - "tool_names": [getattr(tool, "name", str(tool)) for tool in self.tools or []], - }, - ) + inputs = {"messages": query} try: - last_state = None - async for state in self.workflow.astream( - {"messages": query}, - stream_mode="values", - config=config, - ): - if "messages" in state: - for message in state["messages"][-1:]: - try: - message.pretty_print() - except Exception: - pass - logger.info(message) - last_state = state + last_state, interrupt_value = await _stream_until_interrupt(inputs, config) + + # --- Human-in-the-loop resume loop --- + # When the graph pauses with an interrupt, ask the human and + # resume. 
This loop handles chains of multiple interrupts + # (e.g., the agent asks a follow-up question after receiving + # the first answer). + max_interrupts = 10 # safety guard against infinite interrupt loops + interrupt_count = 0 + while interrupt_value is not None: + interrupt_count += 1 + if interrupt_count > max_interrupts: + logger.error( + "Exceeded maximum number of human interrupts (%d); " + "aborting workflow.", + max_interrupts, + ) + raise RuntimeError( + f"Workflow exceeded maximum of {max_interrupts} " + f"human interrupts." + ) + + # Extract the question text from the interrupt value. + if isinstance(interrupt_value, dict): + question = interrupt_value.get( + "question", + interrupt_value.get("message", str(interrupt_value)), + ) + else: + question = str(interrupt_value) + + logger.info("Requesting human input: %s", question) + human_answer = await self._call_human_input_handler(question) + logger.info("Human responded: %s", human_answer) + + # Resume the graph from the checkpoint with the human's answer. 
+ resume_cmd = Command(resume=human_answer) + last_state, interrupt_value = await _stream_until_interrupt( + resume_cmd, config + ) + if last_state is None: - raise RuntimeError("Workflow produced no states") + raise RuntimeError("Workflow produced no states.") - messages = _state_messages(last_state) - executed_tools = _executed_tool_names(messages) - terminal_tool = _terminal_tool_name( - executed_tools, - self.terminal_tool_names, - ) - event( - "workflow_finished", - { - "workflow_type": self.workflow_type, - "thread_id": thread_id, - "status": "completed", - "executed_tool_names": list(executed_tools), - "terminal_tool": terminal_tool, - "duration_s": round(time.time() - started, 3), - }, - ) - self._last_run_state = serialize_state(last_state) + # Save messages to persistent session store self._save_messages_to_store(last_state, query) - self.write_state(config=config, file_path=None) - if self.return_option == "state": - return serialize_state(self.get_state(config=config)) - if self.return_option == "last_message": - return last_state["messages"][-1] - raise ValueError( - f"Unsupported return_option: {self.return_option}. " - "Use 'last_message' or 'state'." - ) - except GraphInterrupt: + + return _save_state_and_select_return(last_state, config) + + except HumanInputRequired: + # No human_input_handler configured — propagate so the + # caller (CLI / UI) can prompt the user and resume. raise except Exception as e: - event( - "workflow_finished", - { - "workflow_type": self.workflow_type, - "thread_id": thread_id, - "status": "failed", - "error": repr(e), - "duration_s": round(time.time() - started, 3), - }, - ) logger.error(f"Error running workflow {self.workflow_type}: {e}") raise @@ -1009,12 +928,5 @@ class HumanInputRequired(Exception): """ def __init__(self, question: str): - """Initialize the exception with the pending human question. - - Parameters - ---------- - question : str - Question that should be presented to the user. 
- """ self.question = question super().__init__(question) From bcf072d26beb64d905c047a67ce7ec5c96401774 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 11 Jun 2026 14:07:26 -0500 Subject: [PATCH 078/119] refactor(agent): extract dashboard event callbacks into agent/events.py Separate LangChain callback event translation from the run_turn primitive so the turn runner and graph-backed CLI path can share the same dashboard event shapes without coupling either driver to the other. --- src/chemgraph/agent/events.py | 109 ++++++++++++++++++++++++++++++++++ src/chemgraph/agent/turn.py | 77 ++---------------------- 2 files changed, 113 insertions(+), 73 deletions(-) create mode 100644 src/chemgraph/agent/events.py diff --git a/src/chemgraph/agent/events.py b/src/chemgraph/agent/events.py new file mode 100644 index 00000000..b714cd53 --- /dev/null +++ b/src/chemgraph/agent/events.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +from typing import Any, Callable + +from langchain_core.callbacks import BaseCallbackHandler + +logger = logging.getLogger(__name__) + +EventCallback = Callable[[str, dict], None] + + +def _serialized_name(serialized: Any) -> str | None: + from chemgraph.agent.turn import _serialized_name as turn_serialized_name + + return turn_serialized_name(serialized) + + +def _response_tool_calls(response: Any) -> list[dict[str, str | None]]: + from chemgraph.agent.turn import _response_tool_calls as turn_response_tool_calls + + return turn_response_tool_calls(response) + + +def _serialize_state(value: Any) -> Any: + from chemgraph.agent.turn import serialize_state + + return serialize_state(value) + + +class _BaseDashboardEventCallback(BaseCallbackHandler): + """Forward LangChain callback events to the dashboard event surface.""" + + _failure_log_message = "dashboard event callback failed" + + def __init__(self, on_event: EventCallback, thread_id: str) -> None: + self._on_event = on_event + self._thread_id = thread_id + + def 
_emit(self, event: str, payload: dict[str, Any]) -> None: + try: + self._on_event(event, {"thread_id": self._thread_id, **payload}) + except Exception: # noqa: BLE001 - callbacks must not break the run. + logger.debug(self._failure_log_message, exc_info=True) + + def on_chat_model_start(self, serialized, messages, **kwargs) -> None: + self._emit( + "llm_call_started", + { + "model": _serialized_name(serialized), + "message_count": len(messages[0]) if messages else 0, + }, + ) + + def on_llm_start(self, serialized, prompts, **kwargs) -> None: + self._emit( + "llm_call_started", + { + "model": _serialized_name(serialized), + "message_count": len(prompts or []), + }, + ) + + def on_llm_end(self, response, **kwargs) -> None: + payload: dict[str, Any] = {} + usage = getattr(response, "llm_output", None) + if isinstance(usage, dict): + payload["llm_output"] = usage + self._emit("llm_call_finished", payload) + if tool_calls := _response_tool_calls(response): + self._emit("llm_decision", {"tool_calls": tool_calls}) + + def on_llm_error(self, error, **kwargs) -> None: + self._emit("llm_call_failed", {"error": repr(error)}) + + def on_tool_start(self, serialized, input_str, **kwargs) -> None: + self._emit( + "tool_call_started", + { + "tool_name": _serialized_name(serialized), + "arguments": _serialize_state(input_str), + }, + ) + + def on_tool_end(self, output, **kwargs) -> None: + payload: dict[str, Any] = {"result": _serialize_state(output)} + name = kwargs.get("name") + if name: + payload["tool_name"] = name + self._emit("tool_call_finished", payload) + + def on_tool_error(self, error, **kwargs) -> None: + payload = {"error": repr(error)} + name = kwargs.get("name") + if name: + payload["tool_name"] = name + self._emit("tool_call_failed", payload) + + +class _TurnEventCallback(_BaseDashboardEventCallback): + """Forward run_turn callback events to the dashboard event surface.""" + + _failure_log_message = "turn event callback failed" + + +class 
_AstreamEventCallback(_BaseDashboardEventCallback): + """Forward graph stream callback events to the dashboard event surface.""" + + _failure_log_message = "astream event callback failed" diff --git a/src/chemgraph/agent/turn.py b/src/chemgraph/agent/turn.py index a6cf9720..e5652134 100644 --- a/src/chemgraph/agent/turn.py +++ b/src/chemgraph/agent/turn.py @@ -6,9 +6,7 @@ import os import time import uuid -from typing import Any, Callable, Collection - -from langchain_core.callbacks import BaseCallbackHandler +from typing import Any, Collection from chemgraph.graphs.single_agent import construct_single_agent_graph from chemgraph.models.loader import load_chat_model @@ -166,9 +164,6 @@ def _custom_openai_compatible_kwargs( return kwargs -EventCallback = Callable[[str, dict], None] - - @dataclasses.dataclass(frozen=True) class TurnResult: """Result of one bounded ChemGraph single-agent turn.""" @@ -181,73 +176,6 @@ class TurnResult: duration_s: float -class _TurnEventCallback(BaseCallbackHandler): - """Forward LangChain callback events to a small stable callback surface.""" - - def __init__(self, on_event: EventCallback, thread_id: str) -> None: - self._on_event = on_event - self._thread_id = thread_id - - def _emit(self, event: str, payload: dict[str, Any]) -> None: - try: - self._on_event(event, {"thread_id": self._thread_id, **payload}) - except Exception: # noqa: BLE001 - callbacks must not break the run. 
- logger.debug("turn event callback failed", exc_info=True) - - def on_chat_model_start(self, serialized, messages, **kwargs) -> None: - self._emit( - "llm_call_started", - { - "model": _serialized_name(serialized), - "message_count": len(messages[0]) if messages else 0, - }, - ) - - def on_llm_start(self, serialized, prompts, **kwargs) -> None: - self._emit( - "llm_call_started", - { - "model": _serialized_name(serialized), - "message_count": len(prompts or []), - }, - ) - - def on_llm_end(self, response, **kwargs) -> None: - payload: dict[str, Any] = {} - usage = getattr(response, "llm_output", None) - if isinstance(usage, dict): - payload["llm_output"] = usage - self._emit("llm_call_finished", payload) - if tool_calls := _response_tool_calls(response): - self._emit("llm_decision", {"tool_calls": tool_calls}) - - def on_llm_error(self, error, **kwargs) -> None: - self._emit("llm_call_failed", {"error": repr(error)}) - - def on_tool_start(self, serialized, input_str, **kwargs) -> None: - self._emit( - "tool_call_started", - { - "tool_name": _serialized_name(serialized), - "arguments": serialize_state(input_str), - }, - ) - - def on_tool_end(self, output, **kwargs) -> None: - payload: dict[str, Any] = {"result": serialize_state(output)} - name = kwargs.get("name") - if name: - payload["tool_name"] = name - self._emit("tool_call_finished", payload) - - def on_tool_error(self, error, **kwargs) -> None: - payload = {"error": repr(error)} - name = kwargs.get("name") - if name: - payload["tool_name"] = name - self._emit("tool_call_failed", payload) - - def _serialized_name(serialized: Any) -> str | None: if isinstance(serialized, dict): return serialized.get("name") or serialized.get("id") @@ -419,6 +347,9 @@ def _load_turn_llm( ) +from chemgraph.agent.events import EventCallback, _TurnEventCallback + + async def run_turn( *, query: str, From 789920e9e94a67f3a0c274216bdbecda57dc8848 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 11 Jun 2026 14:12:53 -0500 Subject: 
[PATCH 079/119] feat(agent): add minimum on_event and terminal_tool_names hooks for academy and dashboard Re-add the two constructor parameters and graph-stream event framing needed for trace-dir dashboard runs and terminal-tool stopping. Event callback translation stays in chemgraph.agent.events and turn-specific helpers stay in chemgraph.agent.turn. --- src/chemgraph/agent/llm_agent.py | 64 +++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index 3f20a44b..f736889d 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -1,9 +1,11 @@ import asyncio import datetime import os -from typing import Callable, List, Optional +import time +from typing import Callable, Collection, List, Optional import uuid +from chemgraph.agent.events import EventCallback, _AstreamEventCallback from chemgraph.memory.store import SessionStore from chemgraph.memory.schemas import SessionMessage from chemgraph.models.openai import load_openai_model @@ -133,6 +135,11 @@ class ChemGraph: pause and request human input. When ``False`` the tool is excluded from the tool list and the corresponding instruction is removed from the default system prompt, by default False. + terminal_tool_names : Collection[str], optional + Tool names that should terminate supported workflows after + successful execution, by default empty. + on_event : callable, optional + Callback invoked with dashboard workflow events, by default None. 
Raises ------ @@ -170,6 +177,8 @@ def __init__( max_retries: int = 1, human_input_handler: Optional[Callable[[str], str]] = None, human_supervised: bool = False, + terminal_tool_names: Collection[str] = (), + on_event: Optional[EventCallback] = None, ): # Always generate a unique identifier for this instance self.uuid = str(uuid.uuid4())[:8] @@ -300,6 +309,8 @@ def __init__( self.max_retries = max_retries self.human_input_handler = human_input_handler self.human_supervised = human_supervised + self.terminal_tool_names = tuple(terminal_tool_names) + self.on_event = on_event # When human supervision is disabled and the caller is using the # default system prompt, strip the ask_human instructions so the @@ -340,6 +351,7 @@ def __init__( self.tools, max_retries=self.max_retries, human_supervised=self.human_supervised, + terminal_tool_names=self.terminal_tool_names, ) elif self.workflow_type == "multi_agent": self.workflow = self.workflow_map[workflow_type]["constructor"]( @@ -719,6 +731,11 @@ async def run(self, query: str, config=None, resume_from: Optional[str] = None): Session ID to load context from. The previous conversation summary is prepended to the query. 
""" + from chemgraph.agent.turn import ( + _executed_tool_names, + _state_messages, + _terminal_tool_name, + ) def _validate_config(cfg): if cfg is None: @@ -838,6 +855,13 @@ async def _stream_until_interrupt(stream_input, cfg): logger.debug("run called with config=%s", config) config = _validate_config(config) + thread_id = str(config["configurable"]["thread_id"]) + started = time.time() + event = self.on_event or (lambda _event, _payload: None) + if self.on_event: + callbacks = list(config.get("callbacks") or []) + callbacks.append(_AstreamEventCallback(self.on_event, thread_id)) + config["callbacks"] = callbacks logger.debug("validated config=%s", config) # Initialize logging directory before determining inputs or running workflow @@ -860,6 +884,16 @@ async def _stream_until_interrupt(stream_input, cfg): logger.info(f"Injected context from session {resume_from}") inputs = {"messages": query} + event( + "workflow_started", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "tool_names": [ + getattr(tool, "name", str(tool)) for tool in self.tools or [] + ], + }, + ) try: last_state, interrupt_value = await _stream_until_interrupt(inputs, config) @@ -909,6 +943,24 @@ async def _stream_until_interrupt(stream_input, cfg): # Save messages to persistent session store self._save_messages_to_store(last_state, query) + messages = _state_messages(last_state) + executed_tools = _executed_tool_names(messages) + terminal_tool = _terminal_tool_name( + executed_tools, + self.terminal_tool_names, + ) + event( + "workflow_finished", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "status": "completed", + "executed_tool_names": list(executed_tools), + "terminal_tool": terminal_tool, + "duration_s": round(time.time() - started, 3), + }, + ) + return _save_state_and_select_return(last_state, config) except HumanInputRequired: @@ -916,6 +968,16 @@ async def _stream_until_interrupt(stream_input, cfg): # caller (CLI / UI) can prompt the user 
and resume. raise except Exception as e: + event( + "workflow_finished", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "status": "failed", + "error": repr(e), + "duration_s": round(time.time() - started, 3), + }, + ) logger.error(f"Error running workflow {self.workflow_type}: {e}") raise From 60817d1cd84209044f1ff66d5a23eca9ae8086b0 Mon Sep 17 00:00:00 2001 From: tdpham2 Date: Thu, 11 Jun 2026 19:25:00 +0000 Subject: [PATCH 080/119] Fix Parsl pickling of MCP server callables launched via "python -m" Closes the submitter PicklingError / worker AttributeError on run_mace_singleArguments by making FastMCP's dynamic Arguments and Output classes picklable by qualname, and by making top-level MCP-server callables pickle by reference even under the runpy double-module case (sys.modules["__main__"] vs sys.modules["pkg.mod"] are distinct objects when launched with python -m, and the leaf module isn't attached to its parent package). - cg_fastmcp.py: _register_fastmcp_dynamic_models() injects dynamic arg/output models into the func_metadata module namespace and rebinds the captured local in tools.base / prompts.base / resources.templates. _fix_module_for_pickle now also sets the function attr on the resolved target module and attaches the leaf module to its parent package, so dill's by-qualname lookup succeeds. Backend wrappers route args/kwargs through to_picklable. - mace_mcp_hpc.py: apply _fix_module_for_pickle to _mace_worker and _ls_remote_files; debug-log the transport hook. - execution/utils.py: add to_picklable() helper that recursively serializes Pydantic BaseModel instances via model_dump(). - execution/parsl_backend.py: wrap task.args / task.kwargs with to_picklable() before dispatching to the python app. 
Co-Authored-By: Claude Opus 4.7 --- src/chemgraph/execution/parsl_backend.py | 6 +- src/chemgraph/execution/utils.py | 26 ++++++ src/chemgraph/mcp/cg_fastmcp.py | 109 +++++++++++++++++++++-- src/chemgraph/mcp/mace_mcp_hpc.py | 19 ++++ 4 files changed, 152 insertions(+), 8 deletions(-) diff --git a/src/chemgraph/execution/parsl_backend.py b/src/chemgraph/execution/parsl_backend.py index f2e4fe37..9e963b50 100644 --- a/src/chemgraph/execution/parsl_backend.py +++ b/src/chemgraph/execution/parsl_backend.py @@ -91,7 +91,11 @@ def submit(self, task: TaskSpec) -> Future: raise ValueError( f"Task '{task.task_id}': task_type='python' requires a callable." ) - return self._python_app(task.callable, task.args, task.kwargs) + from chemgraph.execution.utils import to_picklable + + return self._python_app( + task.callable, to_picklable(task.args), to_picklable(task.kwargs) + ) elif task.task_type == "shell": if task.command is None: diff --git a/src/chemgraph/execution/utils.py b/src/chemgraph/execution/utils.py index ba941fd6..c7a0ed0b 100644 --- a/src/chemgraph/execution/utils.py +++ b/src/chemgraph/execution/utils.py @@ -18,6 +18,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional +from pydantic import BaseModel + if TYPE_CHECKING: from chemgraph.execution.base import ExecutionBackend from chemgraph.execution.job_tracker import JobTracker @@ -25,6 +27,30 @@ logger = logging.getLogger(__name__) +def to_picklable(value: Any) -> Any: + """Recursively convert Pydantic instances to plain dicts. + + FastMCP's ``func_metadata`` builds tool-argument models with + ``pydantic.create_model`` and a ``__module__`` that does not actually + contain the class, so cloudpickle cannot serialize instances of those + classes to a Parsl/Globus-Compute worker. Converting every Pydantic + instance to a dict at the framework boundary side-steps the problem + without patching the third-party library. 
+ + Containers (``dict``, ``list``, ``tuple``) are walked recursively and + rebuilt with the same shape; everything else passes through unchanged. + """ + if isinstance(value, BaseModel): + return value.model_dump() + if isinstance(value, dict): + return {k: to_picklable(v) for k, v in value.items()} + if isinstance(value, list): + return [to_picklable(v) for v in value] + if isinstance(value, tuple): + return tuple(to_picklable(v) for v in value) + return value + + def resolve_structure_files( input_source: str | list[str], extensions: set[str] | None = None, diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index 155dd76d..db2c3a23 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -23,6 +23,66 @@ logger = logging.getLogger(__name__) +def _register_fastmcp_dynamic_models() -> None: + """Make pydantic models built by ``fastmcp.func_metadata`` pickle-by-qualname. + + FastMCP builds per-tool ``Arguments`` / ``Output`` classes via + ``pydantic.create_model(__module__="mcp.server.fastmcp.utilities.func_metadata")`` + but never inserts them into that module's namespace. Dill's by-qualname + lookup then fails and either raises ``PicklingError`` or falls back to + pickle-by-value, which walks ``__globals__`` and can hit other surprises. + Wrapping ``func_metadata`` so the resulting models are inserted into the + module's ``__dict__`` makes the lookup succeed regardless of how the + pickle graph reaches the class. 
+ """ + import sys + + from mcp.server.fastmcp.utilities import func_metadata as _fm + + if getattr(_fm, "_chemgraph_models_registered", False): + return + + _orig = _fm.func_metadata + _mod_ns = sys.modules[_fm.__name__].__dict__ + + def _register(model): + if model is None: + return + name = getattr(model, "__name__", None) + if name and name not in _mod_ns: + _mod_ns[name] = model + try: + model.__module__ = _fm.__name__ + except (AttributeError, TypeError): + pass + + def _patched(*args, **kwargs): + meta = _orig(*args, **kwargs) + _register(getattr(meta, "arg_model", None)) + _register(getattr(meta, "output_model", None)) + return meta + + _fm.func_metadata = _patched + # Several fastmcp modules captured the original via + # ``from mcp.server.fastmcp.utilities.func_metadata import func_metadata`` + # before this patch ran, so they hold their own bound name. Rebind the + # name in each known call site so every tool registration goes through + # the wrapper. + for _modname in ( + "mcp.server.fastmcp.tools.base", + "mcp.server.fastmcp.prompts.base", + "mcp.server.fastmcp.resources.templates", + ): + _m = sys.modules.get(_modname) + if _m is not None and getattr(_m, "func_metadata", None) is _orig: + _m.func_metadata = _patched + + _fm._chemgraph_models_registered = True + + +_register_fastmcp_dynamic_models() + + class CGFastMCP(FastMCP): """FastMCP with an integrated execution backend. @@ -187,15 +247,46 @@ def check_endpoint_status() -> dict: @staticmethod def _fix_module_for_pickle(fn: Callable) -> None: - """Ensure *fn* is picklable when the MCP server runs as ``__main__``.""" + """Ensure *fn* is picklable when the MCP server runs as ``__main__``. + + Under ``python -m pkg.mod`` runpy sets ``__name__ == "__main__"`` + and populates both ``sys.modules["__main__"]`` and + ``sys.modules["pkg.mod"]`` -- but it does **not** attach + ``mod`` as an attribute of the parent package ``pkg``. 
Dill's + by-qualname pickling resolves ``pkg.mod.fn`` via + ``__import__("pkg", fromlist=["mod"])`` followed by + ``getattr(pkg, "mod")``, which fails for that reason and silently + falls back to pickle-by-value -- dragging the entire module's + globals (including the FastMCP dynamic ``arg_model`` classes) + into the byte stream. + + Three things must be true for dill to pickle ``fn`` by reference: + + 1. ``fn.__module__`` points at the real dotted name (not ``__main__``). + 2. ``sys.modules[fn.__module__]`` exists and contains ``fn`` as + an attribute. + 3. The parent package has the leaf module attached as an attribute + (so ``getattr(pkg, leaf)`` resolves to the same module object). + """ if fn.__module__ == "__main__": import sys spec = getattr(sys.modules.get("__main__"), "__spec__", None) if spec and spec.name: fn.__module__ = spec.name - if spec.name not in sys.modules: - sys.modules[spec.name] = sys.modules["__main__"] + target = sys.modules.get(spec.name) + if target is None: + target = sys.modules["__main__"] + sys.modules[spec.name] = target + elif getattr(target, fn.__name__, None) is not fn: + setattr(target, fn.__name__, fn) + # Attach the leaf module to its parent package so dill's + # ``__import__(parent, fromlist=[leaf])`` lookup succeeds. + if "." 
in spec.name: + parent_name, _, leaf = spec.name.rpartition(".") + parent = sys.modules.get(parent_name) + if parent is not None and getattr(parent, leaf, None) is not target: + setattr(parent, leaf, target) # ── Tool registration ─────────────────────────────────────────────── @@ -329,6 +420,8 @@ def decorator(fn: Callable) -> Callable: param_type = param.annotation async def wrapper(params): + from chemgraph.execution.utils import to_picklable + self._ensure_backend() pending = [] for i, p in enumerate(params): @@ -336,7 +429,7 @@ async def wrapper(params): task_id=f"{fn.__name__}_{i}", task_type="python", callable=fn, - kwargs={param.name: p}, + kwargs={param.name: to_picklable(p)}, **task_spec_kwargs, ) task = self._apply_pre_submit_hook(task) @@ -457,6 +550,8 @@ def decorator(expander: Callable) -> Callable: tool_name = name or expander.__name__ async def wrapper(**kwargs): + from chemgraph.execution.utils import to_picklable + self._ensure_backend() ensemble_params = kwargs[param.name] items = expander(ensemble_params) @@ -466,7 +561,7 @@ async def wrapper(**kwargs): task_id=f"{tool_name}_{i}", task_type="python", callable=worker, - kwargs={worker_param_name: item}, + kwargs={worker_param_name: to_picklable(item)}, **task_spec_kwargs, ) task = self._apply_pre_submit_hook(task) @@ -500,7 +595,7 @@ def _make_backend_wrapper( ) -> Callable: """Build an async wrapper that submits *fn* to the backend.""" from chemgraph.execution.base import TaskSpec - from chemgraph.execution.utils import submit_or_gather + from chemgraph.execution.utils import submit_or_gather, to_picklable self._fix_module_for_pickle(fn) @@ -511,7 +606,7 @@ async def wrapper(**kwargs: Any) -> Any: task_id=fn.__name__, task_type="python", callable=fn, - kwargs=kwargs, + kwargs=to_picklable(kwargs), **task_spec_kwargs, ) task = self._apply_pre_submit_hook(task) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index 58750b46..b18816a8 100644 --- 
a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -133,6 +133,17 @@ def _mace_worker(job: dict) -> dict: return result +# Force pickle-by-reference for callables that the transport hook installs +# as `task.callable`. Without this, dill sees `__module__ == "__main__"` +# (this file is run as ``python -m chemgraph.mcp.mace_mcp_hpc``) and falls +# back to pickle-by-value, which walks the module's globals and tries to +# serialize the dynamic ``run_mace_singleArguments`` class held by +# ``mcp._tool_manager._tools[...].fn_metadata.arg_model`` -- that class +# was created by ``pydantic.create_model`` with a ``__module__`` it was +# never registered into, so dill raises a PicklingError. +CGFastMCP._fix_module_for_pickle(_mace_worker) + + # ── Pre-submit transport hook ────────────────────────────────────────── @@ -163,6 +174,11 @@ def _normalize_model(job: dict) -> None: def _mace_transport_hook(task: TaskSpec) -> TaskSpec: """Route single-tool calls to the dict-based worker and embed local structures on whichever path is taken.""" + logger.debug( + "mace transport hook: task_id=%s callable=%s", + task.task_id, + getattr(task.callable, "__qualname__", task.callable), + ) if task.callable is run_mace_single: params = task.kwargs.get("params") if params is None: @@ -217,6 +233,9 @@ def _ls_remote_files(path: str) -> list[str]: ) +CGFastMCP._fix_module_for_pickle(_ls_remote_files) + + def _expand_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: """Server-side expansion of an ensemble request into per-file jobs. From 99ef6f394c1c0b9b162b3d46c2223485142464f3 Mon Sep 17 00:00:00 2001 From: tdpham2 Date: Thu, 11 Jun 2026 19:27:07 +0000 Subject: [PATCH 081/119] Cleanly shut down Parsl DFK in ParslBackend.shutdown parsl.clear() only removes the DFK from the global registry; it does not stop executors. 
Without parsl.dfk().cleanup(), Parsl logs "Python is exiting with a DFK still running" at interpreter exit and relies on atexit hooks for executor teardown. Call cleanup() before clear() and log (but do not raise) on cleanup failure. Co-Authored-By: Claude Opus 4.7 --- src/chemgraph/execution/parsl_backend.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/chemgraph/execution/parsl_backend.py b/src/chemgraph/execution/parsl_backend.py index 9e963b50..c1b1c286 100644 --- a/src/chemgraph/execution/parsl_backend.py +++ b/src/chemgraph/execution/parsl_backend.py @@ -119,6 +119,14 @@ def shutdown(self) -> None: try: import parsl + # cleanup() stops executors and releases resources; + # clear() only removes the DFK from the global registry. + # Without cleanup(), Parsl logs + # "Python is exiting with a DFK still running" at interpreter exit. + try: + parsl.dfk().cleanup() + except Exception: + logger.warning("Error during Parsl DFK cleanup.", exc_info=True) parsl.clear() logger.info("ParslBackend shut down.") except Exception: From 9be77409651fea6db8b2a0f7b04af7b0be386592 Mon Sep 17 00:00:00 2001 From: tdpham2 Date: Thu, 11 Jun 2026 19:29:39 +0000 Subject: [PATCH 082/119] Make Parsl worker_init configurable across HPC system configs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds resolve_worker_init(run_dir, fallback) in hpc_configs/loader.py with three-tier precedence: CHEMGRAPH_WORKER_INIT env override → submitter env auto-detect (VIRTUAL_ENV, then CONDA_PREFIX) → caller- provided per-system fallback. Every config now accepts an optional worker_init kwarg and routes through this helper so Parsl workers land in the same Python environment as the submitter without requiring code edits per HPC system. Per-system fallbacks: Crux "module load conda; conda activate base", Aurora "module load frameworks", Polaris/Local "true". 
Co-Authored-By: Claude Opus 4.7 --- src/chemgraph/hpc_configs/aurora_parsl.py | 12 +++++-- src/chemgraph/hpc_configs/crux_parsl.py | 15 ++++++--- src/chemgraph/hpc_configs/loader.py | 37 ++++++++++++++++++++++ src/chemgraph/hpc_configs/local_parsl.py | 11 +++++-- src/chemgraph/hpc_configs/polaris_parsl.py | 12 +++++-- 5 files changed, 77 insertions(+), 10 deletions(-) diff --git a/src/chemgraph/hpc_configs/aurora_parsl.py b/src/chemgraph/hpc_configs/aurora_parsl.py index 2f7ac354..26c80c63 100644 --- a/src/chemgraph/hpc_configs/aurora_parsl.py +++ b/src/chemgraph/hpc_configs/aurora_parsl.py @@ -5,9 +5,12 @@ from parsl.launchers import MpiExecLauncher from parsl.addresses import address_by_interface +from chemgraph.hpc_configs.loader import resolve_worker_init + def get_aurora_config( run_dir=None, + worker_init: str | None = None, ): """Create a Parsl configuration for Aurora PBS jobs. @@ -15,6 +18,11 @@ def get_aurora_config( ---------- run_dir : str, optional Directory used as Parsl's run directory and worker working directory. + worker_init : str, optional + Explicit shell snippet for worker init. When ``None`` (default), + :func:`resolve_worker_init` picks ``CHEMGRAPH_WORKER_INIT`` / + ``VIRTUAL_ENV`` / ``CONDA_PREFIX`` over the Aurora fallback + (``module load frameworks``). 
Returns ------- @@ -24,8 +32,8 @@ def get_aurora_config( if run_dir is None: run_dir = os.getcwd() - # Hard-wired worker_init for aurora - worker_init = f"export TMPDIR=/tmp; cd {run_dir}; module load frameworks" + if worker_init is None: + worker_init = resolve_worker_init(run_dir, fallback="module load frameworks") # Get the number of nodes: node_file = os.getenv("PBS_NODEFILE") diff --git a/src/chemgraph/hpc_configs/crux_parsl.py b/src/chemgraph/hpc_configs/crux_parsl.py index 07b3051b..e753ed3e 100644 --- a/src/chemgraph/hpc_configs/crux_parsl.py +++ b/src/chemgraph/hpc_configs/crux_parsl.py @@ -4,10 +4,13 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import MpiExecLauncher +from chemgraph.hpc_configs.loader import resolve_worker_init + def get_crux_config( run_dir=None, max_workers_per_node: int = 16, + worker_init: str | None = None, ): """Create a Parsl configuration for ALCF Crux PBS jobs. @@ -20,6 +23,10 @@ def get_crux_config( max_workers_per_node : int, optional Number of concurrent workers per node. Defaults to 16 (≈8 cores per worker on a 128-core node). + worker_init : str, optional + Explicit shell snippet for worker init. When ``None`` (default), + :func:`resolve_worker_init` picks ``CHEMGRAPH_WORKER_INIT`` / + ``VIRTUAL_ENV`` / ``CONDA_PREFIX`` over the Crux fallback. 
Returns ------- @@ -29,10 +36,10 @@ def get_crux_config( if run_dir is None: run_dir = os.getcwd() - worker_init = ( - f"export TMPDIR=/tmp; cd {run_dir}; " - "module load conda; conda activate base" - ) + if worker_init is None: + worker_init = resolve_worker_init( + run_dir, fallback="module load conda; conda activate base" + ) node_file = os.getenv("PBS_NODEFILE") if node_file and os.path.exists(node_file): diff --git a/src/chemgraph/hpc_configs/loader.py b/src/chemgraph/hpc_configs/loader.py index 7aed297f..29a71a10 100644 --- a/src/chemgraph/hpc_configs/loader.py +++ b/src/chemgraph/hpc_configs/loader.py @@ -13,6 +13,43 @@ logger = logging.getLogger(__name__) +def resolve_worker_init(run_dir: str, fallback: str) -> str: + """Build a Parsl ``worker_init`` shell snippet with layered precedence. + + Precedence (highest first): + + 1. Environment variable ``CHEMGRAPH_WORKER_INIT`` -- if set and non-empty, + used verbatim. Lets a user point Parsl workers at any env without + editing code. + 2. Auto-detect the submitting process's Python env and emit an activate + line for it (``VIRTUAL_ENV`` then ``CONDA_PREFIX``). The agent / MCP + subprocess runs from this env, so workers should too. + 3. The system-specific *fallback* string passed by the caller (e.g. + ``"module load conda; conda activate base"`` on Crux). + + The returned string is always prefixed with ``export TMPDIR=/tmp; + cd {run_dir};`` so Parsl workers land in the same directory the + submitter chose. 
+ """ + override = os.environ.get("CHEMGRAPH_WORKER_INIT", "").strip() + if override: + activate = override + else: + venv = os.environ.get("VIRTUAL_ENV", "").strip() + conda_prefix = os.environ.get("CONDA_PREFIX", "").strip() + conda_env = os.environ.get("CONDA_DEFAULT_ENV", "").strip() + if venv: + activate = f"source {venv}/bin/activate" + elif conda_prefix and conda_env: + activate = ( + f"source {conda_prefix}/etc/profile.d/conda.sh && " + f"conda activate {conda_env}" + ) + else: + activate = fallback + return f"export TMPDIR=/tmp; cd {run_dir}; {activate}" + + def load_parsl_config(system_name: str, run_dir: str | None = None, **kwargs): """Dynamically import and return a Parsl ``Config`` for the given HPC system. diff --git a/src/chemgraph/hpc_configs/local_parsl.py b/src/chemgraph/hpc_configs/local_parsl.py index b4c05f01..ac4f61ff 100644 --- a/src/chemgraph/hpc_configs/local_parsl.py +++ b/src/chemgraph/hpc_configs/local_parsl.py @@ -15,6 +15,8 @@ from parsl.executors import HighThroughputExecutor from parsl.providers import LocalProvider +from chemgraph.hpc_configs.loader import resolve_worker_init + logger = logging.getLogger(__name__) _DEFAULT_MAX_WORKERS = 4 @@ -23,7 +25,7 @@ def get_local_config( run_dir: str | None = None, max_workers: int = _DEFAULT_MAX_WORKERS, - worker_init: str = "export TMPDIR=/tmp", + worker_init: str | None = None, ) -> Config: """Generate a Parsl configuration for local execution. @@ -34,11 +36,16 @@ def get_local_config( max_workers : int, optional Maximum number of concurrent workers. Default: 4. worker_init : str, optional - Shell commands executed on each worker before tasks. + Explicit shell snippet for worker init. When ``None`` (default), + :func:`resolve_worker_init` picks ``CHEMGRAPH_WORKER_INIT`` / + ``VIRTUAL_ENV`` / ``CONDA_PREFIX`` over a noop fallback. 
""" if run_dir is None: run_dir = os.getcwd() + if worker_init is None: + worker_init = resolve_worker_init(run_dir, fallback="true") + logger.info("Creating local Parsl config with %d workers", max_workers) config = Config( diff --git a/src/chemgraph/hpc_configs/polaris_parsl.py b/src/chemgraph/hpc_configs/polaris_parsl.py index ef60f207..bdaa9075 100644 --- a/src/chemgraph/hpc_configs/polaris_parsl.py +++ b/src/chemgraph/hpc_configs/polaris_parsl.py @@ -4,10 +4,12 @@ from parsl.executors import HighThroughputExecutor from parsl.launchers import MpiExecLauncher +from chemgraph.hpc_configs.loader import resolve_worker_init + def get_polaris_config( run_dir=None, - worker_init: str = "export TMPDIR=/tmp", + worker_init: str | None = None, ): """Generate the Parsl configuration for the Polaris supercomputer. @@ -16,7 +18,10 @@ def get_polaris_config( run_dir : str, optional Directory used as Parsl's run directory. worker_init : str, optional - Shell initialization snippet run by each Parsl worker. + Explicit shell snippet for worker init. When ``None`` (default), + :func:`resolve_worker_init` picks ``CHEMGRAPH_WORKER_INIT`` / + ``VIRTUAL_ENV`` / ``CONDA_PREFIX`` over a bare ``export TMPDIR=/tmp`` + fallback. 
Returns ------- @@ -26,6 +31,9 @@ def get_polaris_config( if run_dir is None: run_dir = os.getcwd() + if worker_init is None: + worker_init = resolve_worker_init(run_dir, fallback="true") + # Get the number of nodes from the PBS environment node_file = os.getenv("PBS_NODEFILE") if node_file and os.path.exists(node_file): From 396a528a7c8020be87b67579c2ebd40f8b54688a Mon Sep 17 00:00:00 2001 From: tdpham2 Date: Thu, 11 Jun 2026 19:30:39 +0000 Subject: [PATCH 083/119] Add Crux support and worker-env forwarding to Parsl agent demo - demo_parsl_in_job_agent.py: accept crux as a supported system (default device=cpu); forward VIRTUAL_ENV, CONDA_PREFIX, CONDA_DEFAULT_ENV, CHEMGRAPH_WORKER_INIT, PBS_NODEFILE, and PBS_O_WORKDIR to the MCP stdio subprocess so the Parsl workers re-activate the submitter's Python env. - _demo_chemistry.py: include wall-time column in the agent prompt. - README.md: document the Crux PBS workflow. - run_crux_demo.sh: PBS-side wrapper that activates the venv and invokes the Parsl + EnsembleLauncher demos with system=crux. 
Co-Authored-By: Claude Opus 4.7 --- scripts/demo/README.md | 16 +++++ scripts/demo/_demo_chemistry.py | 2 +- scripts/demo/demo_parsl_in_job_agent.py | 22 +++++-- scripts/demo/run_crux_demo.sh | 87 +++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 7 deletions(-) create mode 100755 scripts/demo/run_crux_demo.sh diff --git a/scripts/demo/README.md b/scripts/demo/README.md index 17c7167b..2b73b216 100644 --- a/scripts/demo/README.md +++ b/scripts/demo/README.md @@ -144,6 +144,22 @@ python scripts/demo/demo_parsl_in_job_direct.py --device xpu python scripts/demo/demo_ensemble_launcher_in_job_direct.py --device xpu ``` +### Inside a PBS allocation on Crux (CPU-only) + +```bash +qsub -I -A <project> -l select=1 -l walltime=01:00:00 -q debug -l filesystems=home:eagle +cd /lus/eagle/projects/ChemGraph/thang/ChemGraph + +bash scripts/demo/run_crux_demo.sh # Parsl + EL, all 5 molecules +bash scripts/demo/run_crux_demo.sh --molecules water methane +bash scripts/demo/run_crux_demo.sh --parsl-only +bash scripts/demo/run_crux_demo.sh --el-only +``` + +The wrapper activates `.cg_crux_hpc/`, exports `COMPUTE_SYSTEM=crux`, and runs +`demo_parsl_in_job_direct.py` then `demo_ensemble_launcher_in_job_direct.py` +with `--device cpu`. CSVs land in `$PBS_O_WORKDIR/demo_{parsl,el}_out_crux/`. + Agent variants on either system require an LLM key and follow the same pattern as `demo_local_agent.py`. diff --git a/scripts/demo/_demo_chemistry.py b/scripts/demo/_demo_chemistry.py index 82f714ee..4e2d1547 100644 --- a/scripts/demo/_demo_chemistry.py +++ b/scripts/demo/_demo_chemistry.py @@ -269,7 +269,7 @@ def agent_prompt(device: str = "cpu") -> str: f"For each result, retrieve the optimized electronic energy, enthalpy, " f"entropy and Gibbs free energy by reading the output JSON via " f"extract_output_json.
After all five complete, report a markdown table " - f"with columns: molecule, energy (eV), H (eV), G (eV), then a one-line " + f"with columns: molecule, energy (eV), H (eV), G (eV), and wall-time then a one-line " f"observation about which molecule has the lowest Gibbs free energy.\n\n" f"(Structure paths for reference: {files})" ) diff --git a/scripts/demo/demo_parsl_in_job_agent.py b/scripts/demo/demo_parsl_in_job_agent.py index 4aab2f46..3d05c650 100644 --- a/scripts/demo/demo_parsl_in_job_agent.py +++ b/scripts/demo/demo_parsl_in_job_agent.py @@ -2,11 +2,11 @@ """Agent + MCP + Parsl demo on an HPC compute node. LLM agent on the compute node drives a local ``mace_mcp_hpc`` -subprocess whose backend is ``parsl`` configured for Polaris or -Aurora. The agent uses ``run_mace_single`` to compute thermochemistry +subprocess whose backend is ``parsl`` configured for Polaris, Aurora, +or Crux. The agent uses ``run_mace_single`` to compute thermochemistry for each of the 5 molecules and reports a markdown table. -Must run inside ``qsub -I`` on Polaris/Aurora. LLM API key required. +Must run inside ``qsub -I`` on Polaris/Aurora/Crux. LLM API key required. 
Run:: @@ -53,6 +53,9 @@ async def amain(model: str, system: str, device: str, query: str, verbose: int) "PATH": os.environ.get("PATH", ""), "HOME": os.environ.get("HOME", ""), "VIRTUAL_ENV": os.environ.get("VIRTUAL_ENV", ""), + "CONDA_PREFIX": os.environ.get("CONDA_PREFIX", ""), + "CONDA_DEFAULT_ENV": os.environ.get("CONDA_DEFAULT_ENV", ""), + "CHEMGRAPH_WORKER_INIT": os.environ.get("CHEMGRAPH_WORKER_INIT", ""), "PBS_NODEFILE": os.environ.get("PBS_NODEFILE", ""), "PBS_O_WORKDIR": os.environ.get("PBS_O_WORKDIR", ""), } @@ -116,9 +119,16 @@ def main() -> None: if not args.system: _abort("COMPUTE_SYSTEM env var not set and --system not given.") system = args.system.lower().strip() - if system not in ("polaris", "aurora"): - _abort(f"Unsupported --system: {system!r}") - device = args.device or ("xpu" if system == "aurora" else "cuda") + if system not in ("polaris", "aurora", "crux"): + _abort(f"Unsupported --system: {system!r} (expected polaris|aurora|crux)") + if args.device: + device = args.device + elif system == "aurora": + device = "xpu" + elif system == "crux": + device = "cpu" + else: + device = "cuda" query = args.query or agent_prompt(device=device) asyncio.run(amain(args.model, system, device, query, args.verbose)) diff --git a/scripts/demo/run_crux_demo.sh b/scripts/demo/run_crux_demo.sh new file mode 100755 index 00000000..53ba9acd --- /dev/null +++ b/scripts/demo/run_crux_demo.sh @@ -0,0 +1,87 @@ +#!/usr/bin/env bash +# Run Parsl + EnsembleLauncher demo (5-molecule thermo screen, MACE on CPU) +# on a Crux compute node. 
+# +# Must be executed INSIDE an interactive PBS allocation on Crux: +# qsub -I -A <project> -l select=1 -l walltime=01:00:00 -q debug +# cd /lus/eagle/projects/ChemGraph/thang/ChemGraph +# bash scripts/demo/run_crux_demo.sh # both backends +# bash scripts/demo/run_crux_demo.sh --parsl-only +# bash scripts/demo/run_crux_demo.sh --el-only +# bash scripts/demo/run_crux_demo.sh --molecules water methane + +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" + +abort() { + echo "[ABORT] $*" >&2 + exit 2 +} + +RUN_PARSL=1 +RUN_EL=1 +PASSTHROUGH=() +while (( $# )); do + case "$1" in + --parsl-only) RUN_EL=0; shift ;; + --el-only) RUN_PARSL=0; shift ;; + --molecules) shift; while (( $# )) && [[ "$1" != --* ]]; do PASSTHROUGH+=("$1"); shift; done; PASSTHROUGH=(--molecules "${PASSTHROUGH[@]}") ;; + --timeout) PASSTHROUGH+=("$1" "$2"); shift 2 ;; + -h|--help) sed -n '2,12p' "${BASH_SOURCE[0]}"; exit 0 ;; + *) abort "Unknown argument: $1" ;; + esac +done + +[[ -n "${PBS_NODEFILE:-}" && -f "${PBS_NODEFILE}" ]] \ + || abort "PBS_NODEFILE not set or missing -- run inside 'qsub -I' on Crux."
+ +VENV_ACTIVATE="$REPO_ROOT/.cg_crux_hpc/bin/activate" +[[ -f "$VENV_ACTIVATE" ]] || abort "Missing venv activate script: $VENV_ACTIVATE" + +if [[ "${VIRTUAL_ENV:-}" != "$REPO_ROOT/.cg_crux_hpc" ]]; then + module load conda 2>/dev/null || true + # shellcheck disable=SC1090 + source "$VENV_ACTIVATE" +fi + +export COMPUTE_SYSTEM=crux +RUN_DIR="${PBS_O_WORKDIR:-$PWD}/parsl_demo_runs_crux" +PARSL_OUT="${PBS_O_WORKDIR:-$PWD}/demo_parsl_out_crux" +EL_OUT="${PBS_O_WORKDIR:-$PWD}/demo_el_out_crux" +mkdir -p "$RUN_DIR" "$PARSL_OUT" "$EL_OUT" + +echo "REPO_ROOT=$REPO_ROOT" +echo "VIRTUAL_ENV=${VIRTUAL_ENV:-}" +echo "PBS_NODEFILE=$PBS_NODEFILE ($(wc -l <"$PBS_NODEFILE") node(s))" +echo "RUN_DIR=$RUN_DIR" +echo "PARSL_OUT=$PARSL_OUT EL_OUT=$EL_OUT" +echo + +parsl_rc=0 +el_rc=0 + +if (( RUN_PARSL )); then + echo "=== Parsl demo (system=crux, device=cpu) ===" + python "$REPO_ROOT/scripts/demo/demo_parsl_in_job_direct.py" \ + --system crux --device cpu --run-dir "$RUN_DIR" \ + --output-dir "$PARSL_OUT" "${PASSTHROUGH[@]}" \ + || parsl_rc=$? + echo +fi + +if (( RUN_EL )); then + echo "=== EnsembleLauncher demo (managed, system=crux, device=cpu) ===" + python "$REPO_ROOT/scripts/demo/demo_ensemble_launcher_in_job_direct.py" \ + --system crux --device cpu \ + --output-dir "$EL_OUT" "${PASSTHROUGH[@]}" \ + || el_rc=$? + echo +fi + +verdict() { (( $1 == 0 )) && echo PASS || echo "FAIL(rc=$1)"; } +echo "=== Summary ===" +(( RUN_PARSL )) && echo "parsl = $(verdict $parsl_rc) (output: $PARSL_OUT)" +(( RUN_EL )) && echo "el = $(verdict $el_rc) (output: $EL_OUT)" + +(( parsl_rc > el_rc )) && exit "$parsl_rc" || exit "$el_rc" From ff4aa1436032270a72387ed549ea87b7bbe35c48 Mon Sep 17 00:00:00 2001 From: tdpham2 Date: Thu, 11 Jun 2026 19:33:48 +0000 Subject: [PATCH 084/119] Add Crux support to Parsl + EnsembleLauncher smoke harness - _smoke_utils.py: add ensure_on_worker_pythonpath() so Parsl workers can import _smoke_utils from the script directory. 
- smoke_parsl_in_job.py / smoke_ensemble_launcher_in_job.py: call ensure_on_worker_pythonpath() at import time. - README.md: document the Crux PBS workflow. - run_crux_smoke.sh: PBS-side wrapper that activates the venv and runs both smoke entrypoints with system=crux. Co-Authored-By: Claude Opus 4.7 --- scripts/smoke/README.md | 16 ++++ scripts/smoke/_smoke_utils.py | 15 ++++ scripts/smoke/run_crux_smoke.sh | 80 +++++++++++++++++++ .../smoke/smoke_ensemble_launcher_in_job.py | 3 + scripts/smoke/smoke_parsl_in_job.py | 3 + 5 files changed, 117 insertions(+) create mode 100755 scripts/smoke/run_crux_smoke.sh diff --git a/scripts/smoke/README.md b/scripts/smoke/README.md index e52f2ee3..07e83e12 100644 --- a/scripts/smoke/README.md +++ b/scripts/smoke/README.md @@ -94,6 +94,22 @@ python scripts/smoke/smoke_parsl_in_job.py --device xpu python scripts/smoke/smoke_ensemble_launcher_in_job.py --mode managed --device xpu ``` +### Inside a PBS allocation on Crux (CPU-only) + +```bash +qsub -I -A YOUR_PROJECT -l select=1 -l walltime=00:30:00 -q debug -l filesystems=home:eagle +cd /lus/eagle/projects/ChemGraph/thang/ChemGraph + +bash scripts/smoke/run_crux_smoke.sh # both backends + MACE on CPU +bash scripts/smoke/run_crux_smoke.sh --quick # skip MACE +bash scripts/smoke/run_crux_smoke.sh --parsl-only +bash scripts/smoke/run_crux_smoke.sh --el-only +``` + +The wrapper activates `.cg_crux_hpc/`, exports `COMPUTE_SYSTEM=crux`, and runs +`smoke_parsl_in_job.py` then `smoke_ensemble_launcher_in_job.py` with +`--device cpu`. It exits non-zero if either backend fails. 
+ ### EnsembleLauncher client-only mode Exercises `EnsembleLauncherBackend(client_only=True, ...)` introduced in diff --git a/scripts/smoke/_smoke_utils.py b/scripts/smoke/_smoke_utils.py index 492657f8..2036c689 100644 --- a/scripts/smoke/_smoke_utils.py +++ b/scripts/smoke/_smoke_utils.py @@ -108,3 +108,18 @@ def trivial_env_probe() -> dict: except Exception as exc: info["torch_error"] = str(exc) return info + + +def ensure_on_worker_pythonpath() -> None: + """Add this file's directory to ``PYTHONPATH`` so that worker processes + (Parsl HTEX, EnsembleLauncher, Globus Compute) can ``import _smoke_utils`` + when unpickling tasks. Safe to call from the main process before backend + creation; no-op if already present. + """ + import os + + here = str(Path(__file__).resolve().parent) + existing = os.environ.get("PYTHONPATH", "") + parts = existing.split(os.pathsep) if existing else [] + if here not in parts: + os.environ["PYTHONPATH"] = os.pathsep.join([here, *parts]) if parts else here diff --git a/scripts/smoke/run_crux_smoke.sh b/scripts/smoke/run_crux_smoke.sh new file mode 100755 index 00000000..ab9f8850 --- /dev/null +++ b/scripts/smoke/run_crux_smoke.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +# Run Parsl + EnsembleLauncher smoke tests on a Crux compute node (MACE on CPU). +# +# Must be executed INSIDE an interactive PBS allocation on Crux: +# qsub -I -A YOUR_PROJECT -l select=1 -l walltime=00:30:00 -q debug +# cd /lus/eagle/projects/ChemGraph/thang/ChemGraph +# bash scripts/smoke/run_crux_smoke.sh # both backends + MACE +# bash scripts/smoke/run_crux_smoke.sh --quick # skip MACE +# bash scripts/smoke/run_crux_smoke.sh --parsl-only +# bash scripts/smoke/run_crux_smoke.sh --el-only + +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." 
&& pwd)" + +abort() { + echo "[ABORT] $*" >&2 + exit 2 +} + +QUICK="" +RUN_PARSL=1 +RUN_EL=1 +for arg in "$@"; do + case "$arg" in + --quick) QUICK="--quick" ;; + --parsl-only) RUN_EL=0 ;; + --el-only) RUN_PARSL=0 ;; + -h|--help) sed -n '2,11p' "${BASH_SOURCE[0]}"; exit 0 ;; + *) abort "Unknown argument: $arg" ;; + esac +done + +[[ -n "${PBS_NODEFILE:-}" && -f "${PBS_NODEFILE}" ]] \ + || abort "PBS_NODEFILE not set or missing -- run inside 'qsub -I' on Crux." + +VENV_ACTIVATE="$REPO_ROOT/.cg_crux_hpc/bin/activate" +[[ -f "$VENV_ACTIVATE" ]] || abort "Missing venv activate script: $VENV_ACTIVATE" + +if [[ "${VIRTUAL_ENV:-}" != "$REPO_ROOT/.cg_crux_hpc" ]]; then + module load conda 2>/dev/null || true + # shellcheck disable=SC1090 + source "$VENV_ACTIVATE" +fi + +export COMPUTE_SYSTEM=crux +RUN_DIR="${PBS_O_WORKDIR:-$PWD}/parsl_runs_smoke_crux" +mkdir -p "$RUN_DIR" + +echo "REPO_ROOT=$REPO_ROOT" +echo "VIRTUAL_ENV=${VIRTUAL_ENV:-}" +echo "PBS_NODEFILE=$PBS_NODEFILE ($(wc -l <"$PBS_NODEFILE") node(s))" +echo "RUN_DIR=$RUN_DIR" +echo + +parsl_rc=0 +el_rc=0 + +if (( RUN_PARSL )); then + echo "=== Parsl smoke (system=crux, device=cpu) ===" + python "$REPO_ROOT/scripts/smoke/smoke_parsl_in_job.py" \ + --system crux --device cpu --run-dir "$RUN_DIR" $QUICK \ + || parsl_rc=$? + echo +fi + +if (( RUN_EL )); then + echo "=== EnsembleLauncher smoke (managed, system=crux, device=cpu) ===" + python "$REPO_ROOT/scripts/smoke/smoke_ensemble_launcher_in_job.py" \ + --mode managed --system crux --device cpu $QUICK \ + || el_rc=$? 
+ echo +fi + +verdict() { (( $1 == 0 )) && echo PASS || echo "FAIL(rc=$1)"; } +echo "=== Summary ===" +(( RUN_PARSL )) && echo "parsl = $(verdict $parsl_rc)" +(( RUN_EL )) && echo "el = $(verdict $el_rc)" + +(( parsl_rc > el_rc )) && exit "$parsl_rc" || exit "$el_rc" diff --git a/scripts/smoke/smoke_ensemble_launcher_in_job.py b/scripts/smoke/smoke_ensemble_launcher_in_job.py index b178454d..250e6ba1 100644 --- a/scripts/smoke/smoke_ensemble_launcher_in_job.py +++ b/scripts/smoke/smoke_ensemble_launcher_in_job.py @@ -47,12 +47,15 @@ from _smoke_utils import ( SmokeReporter, + ensure_on_worker_pythonpath, trivial_add, trivial_hostname, trivial_square, water_xyz_path, ) +ensure_on_worker_pythonpath() + def _abort(msg: str) -> None: print(f"[ABORT] {msg}") diff --git a/scripts/smoke/smoke_parsl_in_job.py b/scripts/smoke/smoke_parsl_in_job.py index 75daead3..ee426162 100644 --- a/scripts/smoke/smoke_parsl_in_job.py +++ b/scripts/smoke/smoke_parsl_in_job.py @@ -29,6 +29,7 @@ from _smoke_utils import ( SmokeReporter, + ensure_on_worker_pythonpath, trivial_add, trivial_env_probe, trivial_hostname, @@ -36,6 +37,8 @@ water_xyz_path, ) +ensure_on_worker_pythonpath() + def _abort(msg: str) -> None: print(f"[ABORT] {msg}") From 4bca7d9b18f1d9774f1dded9eaef64b608933385 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 11 Jun 2026 15:20:23 -0500 Subject: [PATCH 085/119] fix(events): always emit llm_decision so dashboard renders single-LLM-call runs The browser dashboard workflowFlowGraph() builds zero nodes unless at least one event among llm_decision, workflow_output, run_finished, or tool_call_* appears in the stream. run_finished is filtered out by isWorkflowEvent for lacking workflow_type/thread_id markers. A CLI run that answers from the LLM without invoking any tools therefore rendered as "Waiting for ChemGraph workflow execution events." 
Always emitting llm_decision (with an empty tool_calls list when the LLM made no tool call) gives the renderer the marker it needs to draw the LM node. Both _AstreamEventCallback (CLI) and _TurnEventCallback (academy) inherit the change; academy already handled empty tool_calls correctly in its turn-classification logic. Also documents the --trace-dir + dashboard flow for traditional ChemGraph runs in the example-002 e2e guide. --- .../e2e_guide.md | 45 +++++++++++++++++++ src/chemgraph/agent/events.py | 6 ++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md index 8b9e9a1c..2f4a4de5 100644 --- a/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md +++ b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md @@ -234,6 +234,51 @@ PYTHONPATH=src python -m chemgraph.cli.main academy dashboard -- \ --local ``` +## Dashboard For Traditional ChemGraph Runs + +The dashboard also renders single-agent ChemGraph runs that were not launched +through Academy. Pass `--trace-dir DIR` to `chemgraph run` to write the +events the dashboard needs (`events.jsonl`, `status.json`, `manifest.json`), +then point the dashboard at that directory. + +On-site at ANL, the simplest path is the built-in Argo support — no shim or +relay needed (set `ARGO_USER` once per shell, or in your shell profile): + +```bash +export ARGO_USER="$ARGO_USER" + +chemgraph run \ + -q "What is the SMILES for water" \ + -m "argo:gpt-5.4" \ + --trace-dir ./run-001 +``` + +Then serve the trace directory: + +```bash +chemgraph dashboard -- --run-dir ./run-001 --port 8765 +# Open http://127.0.0.1:8765 +``` + +The browser shows the same per-agent workflow inspector that Academy displays +for a logical-agent node (query → LLM call → tool calls → output), but at the +top level since the run only has one agent. 
Use a fresh `--trace-dir` per run +so multiple runs don't pile into one `events.jsonl`. + +`--trace-dir` is currently only effective for the `single_agent` workflow. +Other workflows (`multi_agent`, `python_relp`, `graspa`, `rag_agent`, +`single_agent_xanes`, ...) run normally but don't yet emit dashboard events, +and the CLI prints a yellow warning for those. + +If the browser shows "Waiting for ChemGraph workflow execution events" after a +run completed successfully, the remote checkout is missing the +`llm_decision`-on-every-LLM-call fix. Sync the latest ChemGraph and clear +stale bytecode locally: + +```bash +find src/chemgraph -name __pycache__ -type d -exec rm -rf {} + +``` + ## Troubleshooting Check the relay from compute: diff --git a/src/chemgraph/agent/events.py b/src/chemgraph/agent/events.py index b714cd53..88090877 100644 --- a/src/chemgraph/agent/events.py +++ b/src/chemgraph/agent/events.py @@ -67,8 +67,10 @@ def on_llm_end(self, response, **kwargs) -> None: if isinstance(usage, dict): payload["llm_output"] = usage self._emit("llm_call_finished", payload) - if tool_calls := _response_tool_calls(response): - self._emit("llm_decision", {"tool_calls": tool_calls}) + self._emit( + "llm_decision", + {"tool_calls": _response_tool_calls(response) or []}, + ) def on_llm_error(self, error, **kwargs) -> None: self._emit("llm_call_failed", {"error": repr(error)}) From cbe6094fbe95b46a095b6f13883e271a98fde306 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 11 Jun 2026 21:30:43 -0500 Subject: [PATCH 086/119] fix(mcp): isolate MACE worker in a subprocess to dodge Parsl hang on Aurora MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Running MACE directly inside a Parsl process_worker_pool.py worker on Aurora hangs the worker indefinitely. No Python exception is raised, no OS kill signal lands in dmesg, RSS is 68 MB at death — the worker simply stops making progress somewhere inside the MACE call. 
The same call from a standalone Python interpreter completes in 11 seconds. The failure reproduces with: - 1 worker per node and 9 per node - SimpleLauncher and MpiExecLauncher - fork and spawn multiprocessing start methods - hb_threshold=120 and 900 - bond0 and hsn0 interfaces - 1, 5, and 20-task workloads It does NOT reproduce when MACE runs in a child subprocess of the Parsl worker. So the public _mace_worker now materializes its job dict to a temp JSON file and invokes a small subrunner module in a fresh Python interpreter, which runs the historical in-proc worker body (preserved as _mace_worker_inproc) and writes the result back. Verified on a 5-node debug-scaling allocation: 20 tasks distributed across 3 nodes complete successfully (11-24s each) where the non-subprocess version had 0/20 success. Cost: ~3-5s per MACE call for the extra Python startup + frameworks re-init. Acceptable for the 60+ second MACE calls real workloads run. tests/test_mcp.py now exercises _mace_worker_inproc directly so its run_mace_core monkeypatch is visible to the function under test. --- src/chemgraph/mcp/_mace_subrunner.py | 43 +++++++++++++++ src/chemgraph/mcp/mace_worker.py | 79 +++++++++++++++++++++++++++- tests/test_mcp.py | 5 +- 3 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 src/chemgraph/mcp/_mace_subrunner.py diff --git a/src/chemgraph/mcp/_mace_subrunner.py b/src/chemgraph/mcp/_mace_subrunner.py new file mode 100644 index 00000000..655d0535 --- /dev/null +++ b/src/chemgraph/mcp/_mace_subrunner.py @@ -0,0 +1,43 @@ +"""Subprocess entry point for one MACE calculation. + +Invoked by :func:`chemgraph.mcp.mace_worker._mace_worker` as a fresh +Python interpreter to dodge a silent worker-hang seen when MACE runs +directly inside a Parsl ``process_worker_pool.py`` worker on Aurora. + +CLI: + python -m chemgraph.mcp._mace_subrunner + +Reads *job.json*, runs the in-proc MACE worker, writes the result dict +to *result.json*. 
Exits non-zero on any uncaught exception. +""" + +from __future__ import annotations + +import json +import sys + + +def main() -> int: + if len(sys.argv) != 3: + print( + "usage: python -m chemgraph.mcp._mace_subrunner ", + file=sys.stderr, + ) + return 2 + + job_path, result_path = sys.argv[1], sys.argv[2] + with open(job_path, encoding="utf-8") as fh: + job = json.load(fh) + + from chemgraph.mcp.mace_worker import _mace_worker_inproc + + result = _mace_worker_inproc(job) + + with open(result_path, "w", encoding="utf-8") as fh: + json.dump(result, fh, default=str) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/chemgraph/mcp/mace_worker.py b/src/chemgraph/mcp/mace_worker.py index b4b00b23..6c07e7e0 100644 --- a/src/chemgraph/mcp/mace_worker.py +++ b/src/chemgraph/mcp/mace_worker.py @@ -4,13 +4,90 @@ decorators. Parsl/dill serializes worker functions by walking their module globals, so backend workers must live outside modules that contain FastMCP's runtime-generated argument classes. + +The public ``_mace_worker`` shells the actual MACE call into a fresh Python +subprocess. Running MACE directly inside a Parsl worker on Aurora hangs +the worker indefinitely (no Python exception, no OS-level kill signal) — +the failure is silent and only happens inside Parsl's +``process_worker_pool.py`` process model. A clean interpreter dodges it, +so we pay the per-call subprocess cost (~3-5s of Python startup + +``module load frameworks`` env) to keep MACE working. """ +import json import os +import subprocess +import sys +import tempfile + +_SUBPROCESS_TIMEOUT_S = 3600 +_SUBRUNNER_MODULE = "chemgraph.mcp._mace_subrunner" def _mace_worker(job: dict) -> dict: - """Execute a single MACE simulation on a backend worker.""" + """Run one MACE simulation in an isolated subprocess. + + Materializes *job* to a temp JSON file and invokes + :mod:`chemgraph.mcp._mace_subrunner` in a child interpreter. 
The child + writes the result back to a sibling JSON file which we read and return. + """ + with tempfile.NamedTemporaryFile( + mode="w", suffix=".job.json", delete=False, encoding="utf-8", + ) as job_fh: + json.dump(job, job_fh) + job_path = job_fh.name + result_path = f"{job_path}.result.json" + + try: + completed = subprocess.run( + [sys.executable, "-m", _SUBRUNNER_MODULE, job_path, result_path], + capture_output=True, + text=True, + timeout=_SUBPROCESS_TIMEOUT_S, + ) + if completed.returncode != 0: + return { + "status": "failure", + "error_type": "SubprocessError", + "message": ( + f"MACE subprocess exited with {completed.returncode}. " + f"stderr tail: {completed.stderr[-500:]}" + ), + } + if not os.path.isfile(result_path): + return { + "status": "failure", + "error_type": "SubprocessError", + "message": ( + "MACE subprocess exited 0 but wrote no result file. " + f"stdout tail: {completed.stdout[-500:]}" + ), + } + with open(result_path, encoding="utf-8") as fh: + return json.load(fh) + except subprocess.TimeoutExpired: + return { + "status": "failure", + "error_type": "TimeoutExpired", + "message": f"MACE subprocess exceeded {_SUBPROCESS_TIMEOUT_S}s", + } + finally: + for path in (job_path, result_path): + try: + os.unlink(path) + except OSError: + pass + + +def _mace_worker_inproc(job: dict) -> dict: + """Execute a single MACE simulation directly in the caller's interpreter. + + This is the historical worker body. It is invoked by + :mod:`chemgraph.mcp._mace_subrunner` inside the subprocess that + :func:`_mace_worker` spawns, and is exposed here so callers that want + direct in-process execution (e.g. local CLI runs without a Parsl + backend) can still reach the same code path. 
+ """ import json import tempfile diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 51a42f4d..e8911374 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -63,7 +63,10 @@ def fake_run_mace_core(params): monkeypatch.setattr(parsl_tools, "run_mace_core", fake_run_mace_core) - result = mace_worker._mace_worker( + # Exercise the in-proc body so we can monkeypatch run_mace_core. The + # public _mace_worker wraps this in a subprocess to dodge a Parsl-on- + # Aurora hang and would not see test monkeypatches. + result = mace_worker._mace_worker_inproc( { "inline_structure": atoms_to_atomsdata(atoms).model_dump(), "output_result_file": output_file, From e9dd903537bc2f3da4ec4ac492b6c106994b7fd2 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Fri, 12 Jun 2026 20:38:26 +0000 Subject: [PATCH 087/119] Pick a multi-node-safe MPI flavour per HPC system for EnsembleLauncher MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `get_launcher_config`'s default `mpi_flavour="test"` only works for single-host runs: its `write_file_to_nodes` does not actually distribute the per-child JSON spec to remote `/tmp`, so when `main.mN` tries to launch a worker on a different node it dies with `FileNotFoundError` on `/tmp/.mpiexec_tmp/child_.json` and the demo hangs. Flip the default to `"mpich"` (hydra `mpiexec`), widen the Literal to cover every flavour EL knows about, and pick the right one per system in `get_backend("ensemble_launcher", system=...)`: `aurora`/`polaris`/ `crux` → `"mpich"`, `local` → `"test"`. An explicit `[execution.ensemble_launcher] mpi_flavour` in `config.toml` still overrides. 
--- src/chemgraph/execution/config.py | 16 +++++++++++++++- .../execution/ensemble_launcher_backend.py | 11 ++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/chemgraph/execution/config.py b/src/chemgraph/execution/config.py index 6be99a58..37c858d7 100644 --- a/src/chemgraph/execution/config.py +++ b/src/chemgraph/execution/config.py @@ -160,9 +160,23 @@ def get_backend( f"Unknown system {resolved_system}: " f"only know {list(SYSTEM_CONFIG_REGISTRY.keys())}" ) + # System-appropriate MPI flavour: multi-node HPC systems need + # mpich/hydra so child-spec JSON actually lands on remote /tmp; + # "test" only works for single-host runs. + launcher_cfg_kwargs = dict(backend_cfg) + if "mpi_flavour" not in launcher_cfg_kwargs: + _system_mpi_flavour = { + "aurora": "mpich", + "polaris": "mpich", + "crux": "mpich", + "local": "test", + } + launcher_cfg_kwargs["mpi_flavour"] = _system_mpi_flavour.get( + resolved_system, "mpich" + ) merged_kwargs = { "system_config": SYSTEM_CONFIG_REGISTRY[resolved_system], - "launcher_config": get_launcher_config(**backend_cfg), + "launcher_config": get_launcher_config(**launcher_cfg_kwargs), } elif resolved_backend == "globus_compute": diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index b1631a9a..1ba3424b 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -102,8 +102,17 @@ def get_launcher_config( child_executor_policy: str = "fixed_leafs_children_policy", policy_config=None, checkpoint_dir=f"{os.getcwd()}/.ckpt_{uuid.uuid4().hex[:6]}", - mpi_flavour: Literal["test", "mpich"] = "test", + mpi_flavour: Literal[ + "test", "mpich", "intel", "cray-pals", "openmpi", "srun", "aprun", "jsrun" + ] = "mpich", ): + """Build a LauncherConfig. + + ``mpi_flavour`` defaults to ``"mpich"`` (hydra ``mpiexec``) which is the + multi-node-safe choice for Aurora/Polaris/Crux. 
Use ``"test"`` only for + single-node local runs — its ``write_file_to_nodes`` does not actually + distribute child-spec JSON to remote ``/tmp``. + """ _require_ensemble_launcher() if policy_config is None: policy_config = PolicyConfig(nlevels=2, leaf_nodes=len(get_nodes())) From 92108b16a8e25b58089d3d6bc6f09b3ff7863396 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Fri, 12 Jun 2026 20:38:43 +0000 Subject: [PATCH 088/119] Serialize MACE model loads across both threads and processes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The existing `threading.Lock` in `mace_calc._mace_lock` only protected threads inside one process. The EnsembleLauncher `async_processpool` spawns multiple Python workers in parallel on the same node, so sibling processes raced on the torch.load + symbolic_trace path (see issue #110) and tripped the same NameError / hang the lock was introduced to prevent. Add `mace_loading_lock()`, a context manager that holds the existing in-process lock and an `fcntl.flock` on a per-uid lockfile under `$CHEMGRAPH_MACE_LOCK_DIR` → `$TMPDIR` → `tempfile.gettempdir()` → `~/.cache/chemgraph`. Degrades gracefully to thread-only locking where `fcntl` or no writable directory is available. Move the lock acquisition into `load_calculator` so every entry path (Parsl, EnsembleLauncher, local, agent) is covered, not just `ase_tools.run_ase`. Drop the now-redundant `with _mace_lock:` in `run_ase`. 
--- .../schemas/calculators/mace_calc.py | 78 +++++++++++++++++++ src/chemgraph/tools/ase_core.py | 13 +++- src/chemgraph/tools/ase_tools.py | 7 +- 3 files changed, 92 insertions(+), 6 deletions(-) diff --git a/src/chemgraph/schemas/calculators/mace_calc.py b/src/chemgraph/schemas/calculators/mace_calc.py index 2ad50216..712df5c2 100644 --- a/src/chemgraph/schemas/calculators/mace_calc.py +++ b/src/chemgraph/schemas/calculators/mace_calc.py @@ -1,13 +1,19 @@ """MACE foundation models parameters for ChemGraph Reference: https://github.com/ACEsuit/mace/blob/main/mace/calculators/foundations_models.py""" +import functools +import logging import os +import tempfile import threading +from contextlib import contextmanager from pathlib import Path from typing import Optional, Union from pydantic import BaseModel, Field import torch +_logger = logging.getLogger(__name__) + # Process-wide lock for MACE operations. # MACE model deserialization (torch.load) triggers torch.fx.symbolic_trace # inside Contraction.__init__, which temporarily patches @@ -18,6 +24,78 @@ _mace_lock = threading.Lock() +@functools.lru_cache(maxsize=1) +def _mace_lockfile_path() -> Optional[str]: + """Return the path of the per-node MACE init lock file, or ``None`` if + no writable directory is available. Memoised so we only resolve once.""" + candidates = [ + os.environ.get("CHEMGRAPH_MACE_LOCK_DIR"), + os.environ.get("TMPDIR"), + tempfile.gettempdir(), + str(Path.home() / ".cache" / "chemgraph"), + ] + uid = os.getuid() if hasattr(os, "getuid") else "unknown" + for d in candidates: + if not d: + continue + try: + Path(d).mkdir(parents=True, exist_ok=True) + path = str(Path(d) / f"chemgraph_mace_init.{uid}.lock") + # Touch to confirm we can write. + with open(path, "a"): + pass + return path + except OSError: + continue + return None + + +@contextmanager +def mace_loading_lock(): + """Serialize MACE model loads across both threads and processes on one node. 
+ + EnsembleLauncher's ``AsyncProcessPool`` spawns multiple Python workers in + parallel; a per-process :data:`_mace_lock` is not enough because torch's + ``symbolic_trace`` patches ``torch.nn.Module.__call__`` at the class level + during MACE deserialization, and concurrent loads in sibling processes + racing on the same node can deadlock or trip the same NameError that #110 + describes. We add an ``fcntl.flock``-based file lock on top so that + siblings on the same node take turns. + + Degrades to thread-only locking when ``fcntl`` is unavailable (e.g. + Windows) or no writable lock directory exists. + """ + try: + import fcntl + except ImportError: + fcntl = None # type: ignore[assignment] + + path = _mace_lockfile_path() if fcntl is not None else None + fh = None + try: + with _mace_lock: + if path is not None: + fh = open(path, "w") + try: + fcntl.flock(fh.fileno(), fcntl.LOCK_EX) + except OSError as exc: + _logger.warning( + "fcntl.flock on %s failed (%s); proceeding without " + "inter-process MACE serialization.", + path, + exc, + ) + fh.close() + fh = None + yield + finally: + if fh is not None: + try: + fcntl.flock(fh.fileno(), fcntl.LOCK_UN) + finally: + fh.close() + + class MaceCalc(BaseModel): """MACE (Message-passing Atomic and Continuous Environment) calculator configuration. diff --git a/src/chemgraph/tools/ase_core.py b/src/chemgraph/tools/ase_core.py index 4e3dc915..ad54d0c5 100644 --- a/src/chemgraph/tools/ase_core.py +++ b/src/chemgraph/tools/ase_core.py @@ -202,7 +202,18 @@ def load_calculator(calculator: dict) -> tuple[object, dict, object]: if hasattr(calc, "get_atoms_properties"): extra_info = calc.get_atoms_properties() - return calc.get_calculator(), extra_info, calc + if "mace" in calc_type: + # MACE's torch.load + symbolic_trace is unsafe under concurrent loads, + # whether the concurrency is threads in one process or sibling processes + # spawned by the EnsembleLauncher process pool. See mace_calc._mace_lock. 
+ from chemgraph.schemas.calculators.mace_calc import mace_loading_lock + + with mace_loading_lock(): + ase_calculator = calc.get_calculator() + else: + ase_calculator = calc.get_calculator() + + return ase_calculator, extra_info, calc # --------------------------------------------------------------------------- diff --git a/src/chemgraph/tools/ase_tools.py b/src/chemgraph/tools/ase_tools.py index ff4650a3..369c2753 100644 --- a/src/chemgraph/tools/ase_tools.py +++ b/src/chemgraph/tools/ase_tools.py @@ -13,7 +13,6 @@ from chemgraph.schemas.atomsdata import AtomsData from chemgraph.schemas.ase_input import ASEInputSchema -from chemgraph.schemas.calculators.mace_calc import _mace_lock from chemgraph.tools.ase_core import ( _resolve_path, atoms_to_atomsdata, @@ -166,8 +165,6 @@ def run_ase(params: ASEInputSchema) -> dict: ValueError If the calculator is not supported or if the calculation fails """ - calc_type = params.calculator.calculator_type.lower() - if "mace" in calc_type: - with _mace_lock: - return run_ase_core(params) + # MACE thread/process serialization now lives in run_ase_core -> + # load_calculator, so this wrapper just delegates. return run_ase_core(params) From cb3e1970d7186cb1a7af72dab5814938021b03b8 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Fri, 12 Jun 2026 16:12:43 -0500 Subject: [PATCH 089/119] Revert "fix(mcp): isolate MACE worker in a subprocess to dodge Parsl hang on Aurora" This reverts commit cbe6094fbe95b46a095b6f13883e271a98fde306. 
--- src/chemgraph/mcp/_mace_subrunner.py | 43 --------------- src/chemgraph/mcp/mace_worker.py | 79 +--------------------------- tests/test_mcp.py | 5 +- 3 files changed, 2 insertions(+), 125 deletions(-) delete mode 100644 src/chemgraph/mcp/_mace_subrunner.py diff --git a/src/chemgraph/mcp/_mace_subrunner.py b/src/chemgraph/mcp/_mace_subrunner.py deleted file mode 100644 index 655d0535..00000000 --- a/src/chemgraph/mcp/_mace_subrunner.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Subprocess entry point for one MACE calculation. - -Invoked by :func:`chemgraph.mcp.mace_worker._mace_worker` as a fresh -Python interpreter to dodge a silent worker-hang seen when MACE runs -directly inside a Parsl ``process_worker_pool.py`` worker on Aurora. - -CLI: - python -m chemgraph.mcp._mace_subrunner - -Reads *job.json*, runs the in-proc MACE worker, writes the result dict -to *result.json*. Exits non-zero on any uncaught exception. -""" - -from __future__ import annotations - -import json -import sys - - -def main() -> int: - if len(sys.argv) != 3: - print( - "usage: python -m chemgraph.mcp._mace_subrunner ", - file=sys.stderr, - ) - return 2 - - job_path, result_path = sys.argv[1], sys.argv[2] - with open(job_path, encoding="utf-8") as fh: - job = json.load(fh) - - from chemgraph.mcp.mace_worker import _mace_worker_inproc - - result = _mace_worker_inproc(job) - - with open(result_path, "w", encoding="utf-8") as fh: - json.dump(result, fh, default=str) - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/chemgraph/mcp/mace_worker.py b/src/chemgraph/mcp/mace_worker.py index 6c07e7e0..b4b00b23 100644 --- a/src/chemgraph/mcp/mace_worker.py +++ b/src/chemgraph/mcp/mace_worker.py @@ -4,90 +4,13 @@ decorators. Parsl/dill serializes worker functions by walking their module globals, so backend workers must live outside modules that contain FastMCP's runtime-generated argument classes. 
- -The public ``_mace_worker`` shells the actual MACE call into a fresh Python -subprocess. Running MACE directly inside a Parsl worker on Aurora hangs -the worker indefinitely (no Python exception, no OS-level kill signal) — -the failure is silent and only happens inside Parsl's -``process_worker_pool.py`` process model. A clean interpreter dodges it, -so we pay the per-call subprocess cost (~3-5s of Python startup + -``module load frameworks`` env) to keep MACE working. """ -import json import os -import subprocess -import sys -import tempfile - -_SUBPROCESS_TIMEOUT_S = 3600 -_SUBRUNNER_MODULE = "chemgraph.mcp._mace_subrunner" def _mace_worker(job: dict) -> dict: - """Run one MACE simulation in an isolated subprocess. - - Materializes *job* to a temp JSON file and invokes - :mod:`chemgraph.mcp._mace_subrunner` in a child interpreter. The child - writes the result back to a sibling JSON file which we read and return. - """ - with tempfile.NamedTemporaryFile( - mode="w", suffix=".job.json", delete=False, encoding="utf-8", - ) as job_fh: - json.dump(job, job_fh) - job_path = job_fh.name - result_path = f"{job_path}.result.json" - - try: - completed = subprocess.run( - [sys.executable, "-m", _SUBRUNNER_MODULE, job_path, result_path], - capture_output=True, - text=True, - timeout=_SUBPROCESS_TIMEOUT_S, - ) - if completed.returncode != 0: - return { - "status": "failure", - "error_type": "SubprocessError", - "message": ( - f"MACE subprocess exited with {completed.returncode}. " - f"stderr tail: {completed.stderr[-500:]}" - ), - } - if not os.path.isfile(result_path): - return { - "status": "failure", - "error_type": "SubprocessError", - "message": ( - "MACE subprocess exited 0 but wrote no result file. 
" - f"stdout tail: {completed.stdout[-500:]}" - ), - } - with open(result_path, encoding="utf-8") as fh: - return json.load(fh) - except subprocess.TimeoutExpired: - return { - "status": "failure", - "error_type": "TimeoutExpired", - "message": f"MACE subprocess exceeded {_SUBPROCESS_TIMEOUT_S}s", - } - finally: - for path in (job_path, result_path): - try: - os.unlink(path) - except OSError: - pass - - -def _mace_worker_inproc(job: dict) -> dict: - """Execute a single MACE simulation directly in the caller's interpreter. - - This is the historical worker body. It is invoked by - :mod:`chemgraph.mcp._mace_subrunner` inside the subprocess that - :func:`_mace_worker` spawns, and is exposed here so callers that want - direct in-process execution (e.g. local CLI runs without a Parsl - backend) can still reach the same code path. - """ + """Execute a single MACE simulation on a backend worker.""" import json import tempfile diff --git a/tests/test_mcp.py b/tests/test_mcp.py index e8911374..51a42f4d 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -63,10 +63,7 @@ def fake_run_mace_core(params): monkeypatch.setattr(parsl_tools, "run_mace_core", fake_run_mace_core) - # Exercise the in-proc body so we can monkeypatch run_mace_core. The - # public _mace_worker wraps this in a subprocess to dodge a Parsl-on- - # Aurora hang and would not see test monkeypatches. - result = mace_worker._mace_worker_inproc( + result = mace_worker._mace_worker( { "inline_structure": atoms_to_atomsdata(atoms).model_dump(), "output_result_file": output_file, From bef9d870589b086961377d80cb6e9466a581447d Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Fri, 12 Jun 2026 16:12:43 -0500 Subject: [PATCH 090/119] Revert "fix(mcp): isolate hpc backend workers" This reverts commit 83bcd7a60c66c03e7116eea60935b62231c00b3f. 
--- src/chemgraph/mcp/graspa_mcp_hpc.py | 52 ++++++++++++++++- src/chemgraph/mcp/graspa_worker.py | 54 ------------------ src/chemgraph/mcp/mace_mcp_hpc.py | 74 +++++++++++++++++++++++- src/chemgraph/mcp/mace_worker.py | 72 ----------------------- src/chemgraph/mcp/xanes_mcp_hpc.py | 88 +++++++++++++++++++++++++++-- src/chemgraph/mcp/xanes_worker.py | 80 -------------------------- tests/test_mcp.py | 29 +--------- 7 files changed, 210 insertions(+), 239 deletions(-) delete mode 100644 src/chemgraph/mcp/graspa_worker.py delete mode 100644 src/chemgraph/mcp/mace_worker.py delete mode 100644 src/chemgraph/mcp/xanes_worker.py diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py index d1af7343..be7737a6 100644 --- a/src/chemgraph/mcp/graspa_mcp_hpc.py +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -14,6 +14,7 @@ """ import logging +import os from pathlib import Path from chemgraph.execution.base import TaskSpec @@ -23,7 +24,6 @@ resolve_structure_files, ) from chemgraph.mcp.cg_fastmcp import CGFastMCP -from chemgraph.mcp.graspa_worker import _graspa_worker, _ls_remote_files from chemgraph.mcp.transfer_tools import register_transfer_tools from chemgraph.schemas.graspa_schema import graspa_input_schema_ensemble @@ -64,9 +64,59 @@ ) +# ── Worker (runs on the backend) ─────────────────────────────────────── + + +def _graspa_worker(job: dict) -> dict: + """Execute a single gRASPA simulation on a backend worker.""" + from chemgraph.schemas.graspa_schema import graspa_input_schema + from chemgraph.tools.graspa_tools import run_graspa_core + + job = dict(job) + structure = job.pop("_structure_name", None) + temperature = job.get("temperature") + pressure = job.get("pressure") + + remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + 
job.get("output_result_file", "raspa.log"), + ) + + params = graspa_input_schema(**job) + result = run_graspa_core(params) + + if isinstance(result, dict): + merged = { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + **result, + } + merged.setdefault("status", "success") + return merged + return { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + "result": result, + "status": "success", + } + + # ── Ensemble fanout ──────────────────────────────────────────────────── +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) + ) + + def _expand_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: """Server-side expansion of an ensemble request into per-job dicts. diff --git a/src/chemgraph/mcp/graspa_worker.py b/src/chemgraph/mcp/graspa_worker.py deleted file mode 100644 index 2e26cd5e..00000000 --- a/src/chemgraph/mcp/graspa_worker.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Backend worker functions for gRASPA MCP tools. - -This module intentionally contains no FastMCP/CGFastMCP objects or tool -decorators, keeping worker functions safe for Parsl/dill serialization. 
-""" - -import os - - -def _graspa_worker(job: dict) -> dict: - """Execute a single gRASPA simulation on a backend worker.""" - from chemgraph.schemas.graspa_schema import graspa_input_schema - from chemgraph.tools.graspa_tools import run_graspa_core - - job = dict(job) - structure = job.pop("_structure_name", None) - temperature = job.get("temperature") - pressure = job.get("pressure") - - remote_file = job.pop("remote_structure_file", None) - if remote_file is not None: - job["input_structure_file"] = remote_file - if not os.path.isabs(job.get("output_result_file", "")): - job["output_result_file"] = os.path.join( - os.path.dirname(remote_file), - job.get("output_result_file", "raspa.log"), - ) - - params = graspa_input_schema(**job) - result = run_graspa_core(params) - - if isinstance(result, dict): - merged = { - "structure": structure, - "temperature": temperature, - "pressure": pressure, - **result, - } - merged.setdefault("status", "success") - return merged - return { - "structure": structure, - "temperature": temperature, - "pressure": pressure, - "result": result, - "status": "success", - } - - -def _ls_remote_files(path: str) -> list[str]: - """Backend-side helper: list non-directory entries in *path*.""" - return sorted( - f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) - ) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index 8f1c715b..b93dcd4c 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -27,13 +27,12 @@ resolve_structure_files, ) from chemgraph.mcp.cg_fastmcp import CGFastMCP -from chemgraph.mcp.mace_worker import _ls_remote_files, _mace_worker from chemgraph.mcp.transfer_tools import register_transfer_tools from chemgraph.schemas.mace_parsl_schema import ( mace_input_schema, mace_input_schema_ensemble, ) -from chemgraph.tools.parsl_tools import extract_output_json +from chemgraph.tools.parsl_tools import extract_output_json, run_mace_core logger = 
logging.getLogger(__name__) @@ -74,6 +73,70 @@ ) +# ── Worker (runs on the backend) ─────────────────────────────────────── + + +def _mace_worker(job: dict) -> dict: + """Execute a single MACE simulation on a backend worker. + + Accepts a *job dict* (not the schema) so the pre-submit hook can + attach transport keys ``inline_structure`` / ``remote_structure_file`` + before submission. + """ + import json + import tempfile + + job = dict(job) + + # Pre-staged remote file: use the path directly on the worker FS. + remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "output.json"), + ) + + # Inline structure: materialise on the worker's filesystem. + inline = job.pop("inline_structure", None) + if inline is not None: + from ase import Atoms + from ase.io import write as ase_write + + atoms = Atoms( + numbers=inline["numbers"], + positions=inline["positions"], + cell=inline.get("cell"), + pbc=inline.get("pbc"), + ) + tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") + xyz_path = os.path.join(tmpdir, "structure.xyz") + ase_write(xyz_path, atoms) + job["input_structure_file"] = xyz_path + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + tmpdir, job.get("output_result_file", "output.json") + ) + + output_file = job.get("output_result_file") + if output_file: + os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True) + + params = mace_input_schema(**job) + result = run_mace_core(params) + + # When inline, embed full output so the caller doesn't need to read + # a file on the remote filesystem to recover the results. 
+ if inline is not None and isinstance(result, dict): + out_file = job.get("output_result_file", "") + if os.path.isfile(out_file): + with open(out_file) as fh: + result["full_output"] = json.load(fh) + + return result + + # ── Pre-submit transport hook ────────────────────────────────────────── @@ -151,6 +214,13 @@ def run_mace_single(params: mace_input_schema) -> dict: # ── Ensemble fanout ──────────────────────────────────────────────────── +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) + ) + + def _expand_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: """Server-side expansion of an ensemble request into per-file jobs. diff --git a/src/chemgraph/mcp/mace_worker.py b/src/chemgraph/mcp/mace_worker.py deleted file mode 100644 index b4b00b23..00000000 --- a/src/chemgraph/mcp/mace_worker.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Backend worker functions for MACE MCP tools. - -This module intentionally contains no FastMCP/CGFastMCP objects or tool -decorators. Parsl/dill serializes worker functions by walking their module -globals, so backend workers must live outside modules that contain FastMCP's -runtime-generated argument classes. 
-""" - -import os - - -def _mace_worker(job: dict) -> dict: - """Execute a single MACE simulation on a backend worker.""" - import json - import tempfile - - from chemgraph.schemas.mace_parsl_schema import mace_input_schema - from chemgraph.tools.parsl_tools import run_mace_core - - job = dict(job) - - remote_file = job.pop("remote_structure_file", None) - if remote_file is not None: - job["input_structure_file"] = remote_file - if not os.path.isabs(job.get("output_result_file", "")): - job["output_result_file"] = os.path.join( - os.path.dirname(remote_file), - job.get("output_result_file", "output.json"), - ) - - inline = job.pop("inline_structure", None) - if inline is not None: - from ase import Atoms - from ase.io import write as ase_write - - atoms = Atoms( - numbers=inline["numbers"], - positions=inline["positions"], - cell=inline.get("cell"), - pbc=inline.get("pbc"), - ) - tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") - xyz_path = os.path.join(tmpdir, "structure.xyz") - ase_write(xyz_path, atoms) - job["input_structure_file"] = xyz_path - if not os.path.isabs(job.get("output_result_file", "")): - job["output_result_file"] = os.path.join( - tmpdir, - job.get("output_result_file", "output.json"), - ) - - output_file = job.get("output_result_file") - if output_file: - os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True) - - params = mace_input_schema(**job) - result = run_mace_core(params) - - if inline is not None and isinstance(result, dict): - out_file = job.get("output_result_file", "") - if os.path.isfile(out_file): - with open(out_file, encoding="utf-8") as fh: - result["full_output"] = json.load(fh) - - return result - - -def _ls_remote_files(path: str) -> list[str]: - """Backend-side helper: list non-directory entries in *path*.""" - return sorted( - f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) - ) diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 
c9394c31..8583ae65 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -17,15 +17,16 @@ """ import logging +import subprocess from pathlib import Path from chemgraph.execution.config import get_transfer_manager from chemgraph.execution.utils import resolve_structure_files from chemgraph.mcp.cg_fastmcp import CGFastMCP from chemgraph.mcp.transfer_tools import register_transfer_tools -from chemgraph.mcp.xanes_worker import _xanes_ensemble_worker, run_xanes_single from chemgraph.schemas.xanes_schema import ( mp_query_schema, + xanes_input_schema, xanes_input_schema_ensemble, ) @@ -65,17 +66,96 @@ ) -mcp.tool( +# ── Single-structure tool ────────────────────────────────────────────── + + +def _xanes_single_worker(params: xanes_input_schema) -> dict: + """Run a single FDMNES calculation on a backend worker.""" + from chemgraph.tools.xanes_tools import run_xanes_core + + result = run_xanes_core(params) + if isinstance(result, dict): + result.setdefault("status", "success") + return result + return {"status": "success", "result": result} + + +@mcp.tool( name="run_xanes_single", description="Run a single XANES/FDMNES calculation for one input structure.", -)( - run_xanes_single ) +def run_xanes_single(params: xanes_input_schema): + """Run a single FDMNES calculation using the core engine. + + The CGFastMCP wrapper submits this call to the configured backend; + the body is the direct-call fallback when no backend is active. + """ + return _xanes_single_worker(params) # ── Ensemble fanout ──────────────────────────────────────────────────── +def _xanes_ensemble_worker(item: dict) -> dict: + """Execute one prepared FDMNES run on the backend. + + The expander has already written ``input_fdmnes.txt`` (or the + equivalent) into ``item['run_dir']``; this worker runs the binary + via subprocess and then extracts convergence data. 
+ """ + from chemgraph.tools.xanes_tools import extract_conv + + run_dir = item["run_dir"] + fdmnes_exe = item["fdmnes_exe"] + meta = { + "structure": item.get("structure"), + "run_dir": run_dir, + "z_absorber": item.get("z_absorber"), + } + + stdout_path = Path(run_dir) / "fdmnes_stdout.txt" + stderr_path = Path(run_dir) / "fdmnes_stderr.txt" + try: + with open(stdout_path, "w") as out, open(stderr_path, "w") as err: + proc = subprocess.run( + [fdmnes_exe], + cwd=run_dir, + stdout=out, + stderr=err, + check=False, + ) + if proc.returncode != 0: + return { + **meta, + "status": "failure", + "error_type": "FDMNESExitCode", + "message": f"FDMNES exited with code {proc.returncode}", + "returncode": proc.returncode, + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"FDMNES launch failed: {e}", + } + + try: + conv_data = extract_conv(run_dir) + return { + **meta, + "status": "success", + "n_conv_files": len(conv_data), + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"Post-processing failed: {e}", + } + + def _expand_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: """Server-side expansion: prepare per-structure run dirs and return one item per structure for the worker to execute.""" diff --git a/src/chemgraph/mcp/xanes_worker.py b/src/chemgraph/mcp/xanes_worker.py deleted file mode 100644 index ce15cd6a..00000000 --- a/src/chemgraph/mcp/xanes_worker.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Backend worker functions for XANES MCP tools. - -This module intentionally contains no FastMCP/CGFastMCP objects or tool -decorators, keeping worker functions safe for Parsl/dill serialization. 
-""" - -import subprocess -from pathlib import Path - -from chemgraph.schemas.xanes_schema import xanes_input_schema - - -def run_xanes_single(params: xanes_input_schema) -> dict: - """Run a single FDMNES calculation on a backend worker.""" - from chemgraph.tools.xanes_tools import run_xanes_core - - result = run_xanes_core(params) - if isinstance(result, dict): - result.setdefault("status", "success") - return result - return {"status": "success", "result": result} - - -def _xanes_ensemble_worker(item: dict) -> dict: - """Execute one prepared FDMNES run on the backend.""" - from chemgraph.tools.xanes_tools import extract_conv - - run_dir = item["run_dir"] - fdmnes_exe = item["fdmnes_exe"] - meta = { - "structure": item.get("structure"), - "run_dir": run_dir, - "z_absorber": item.get("z_absorber"), - } - - stdout_path = Path(run_dir) / "fdmnes_stdout.txt" - stderr_path = Path(run_dir) / "fdmnes_stderr.txt" - try: - with open(stdout_path, "w", encoding="utf-8") as out, open( - stderr_path, - "w", - encoding="utf-8", - ) as err: - proc = subprocess.run( - [fdmnes_exe], - cwd=run_dir, - stdout=out, - stderr=err, - check=False, - ) - if proc.returncode != 0: - return { - **meta, - "status": "failure", - "error_type": "FDMNESExitCode", - "message": f"FDMNES exited with code {proc.returncode}", - "returncode": proc.returncode, - } - except Exception as e: - return { - **meta, - "status": "failure", - "error_type": type(e).__name__, - "message": f"FDMNES launch failed: {e}", - } - - try: - conv_data = extract_conv(run_dir) - return { - **meta, - "status": "success", - "n_conv_files": len(conv_data), - } - except Exception as e: - return { - **meta, - "status": "failure", - "error_type": type(e).__name__, - "message": f"Post-processing failed: {e}", - } diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 51a42f4d..b4615871 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -48,8 +48,7 @@ def fanout(params: dict) -> list[dict]: def 
test_mace_worker_creates_inline_output_parent(monkeypatch): from ase import Atoms - from chemgraph.mcp import mace_worker - from chemgraph.tools import parsl_tools + from chemgraph.mcp import mace_mcp_hpc from chemgraph.tools.ase_core import atoms_to_atomsdata atoms = Atoms(numbers=[1, 1], positions=[[0, 0, 0], [0, 0, 0.74]]) @@ -61,9 +60,9 @@ def fake_run_mace_core(params): output_path.write_text('{"ok": true}', encoding="utf-8") return {"status": "success"} - monkeypatch.setattr(parsl_tools, "run_mace_core", fake_run_mace_core) + monkeypatch.setattr(mace_mcp_hpc, "run_mace_core", fake_run_mace_core) - result = mace_worker._mace_worker( + result = mace_mcp_hpc._mace_worker( { "inline_structure": atoms_to_atomsdata(atoms).model_dump(), "output_result_file": output_file, @@ -77,28 +76,6 @@ def fake_run_mace_core(params): assert result["full_output"] == {"ok": True} -def test_hpc_worker_functions_are_dill_picklable(): - dill = pytest.importorskip("dill") - - from chemgraph.mcp.graspa_worker import ( - _graspa_worker, - _ls_remote_files as _graspa_ls_remote_files, - ) - from chemgraph.mcp.mace_worker import _ls_remote_files as _mace_ls_remote_files - from chemgraph.mcp.mace_worker import _mace_worker - from chemgraph.mcp.xanes_worker import _xanes_ensemble_worker, run_xanes_single - - for worker in ( - _mace_worker, - _mace_ls_remote_files, - run_xanes_single, - _xanes_ensemble_worker, - _graspa_worker, - _graspa_ls_remote_files, - ): - dill.dumps(worker) - - @pytest.mark.asyncio async def test_split_cif_dataset(tmp_path): """Test splitting a dataset of CIF files.""" From ca1019cc30de27944f8a85d410e9eb23497ffecc Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Fri, 12 Jun 2026 16:17:58 -0500 Subject: [PATCH 091/119] docs(example-002): switch MACE path to in-process run_ase The HPC MACE path (chemgraph.mcp.mace_mcp_hpc + ParslBackend) is being reworked in a parallel PR. 
Until that lands and the WorkerLost subprocess isolation can be folded back in, the example uses the general run_ase tool from chemgraph.mcp.mcp_tools, which runs MACE in-process inside the mace-agent MCP server. Campaign config changes: - mcp_servers: drop mace_mcp_hpc and hpc_misc_mcp, keep general only - mace-agent: server list [mace, hpc_misc] -> [general]; mission text now instructs the agent to call run_ase per XYZ with mace_mp - assessment-agent: server list [hpc_misc] -> [] (no MCP tools needed) - resources: replace mace_output_result_file + mace_model_file with a single mace_output_directory (mace_mp auto-downloads the model) e2e guide changes: - drop the MACE-model-staging step - drop CHEMGRAPH_EXECUTION_BACKEND/COMPUTE_SYSTEM env exports - drop Parsl-specific troubleshooting blocks - install command no longer requires the parsl extra - add a top-of-file note explaining the in-process MACE path --- .../e2e_guide.md | 79 ++++++------------- .../campaign.json | 43 ++++------ 2 files changed, 39 insertions(+), 83 deletions(-) diff --git a/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md index 2f4a4de5..44c351f6 100644 --- a/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md +++ b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md @@ -15,6 +15,19 @@ The coordinator delegates 20 SMILES candidates, structure agents generate XYZ files, the MACE agent runs an ensemble energy screen, and the assessment agent summarizes readiness/ranking evidence. +## About The MACE Path + +This example deliberately runs MACE through the general `run_ase` tool +(`chemgraph.mcp.mcp_tools`), which executes MACE in-process inside the MCP +server. It does **not** exercise `chemgraph.mcp.mace_mcp_hpc` or the +Parsl/EnsembleLauncher/Globus Compute backends — those are being reworked in +a separate PR. 
Once that lands and the WorkerLost subprocess fix is folded +back in, this example can be switched back to the HPC MACE path. + +In-process MACE means each per-structure energy evaluation runs synchronously +in the mace-agent's MCP server process. A 20-structure screen completes in +a few minutes on CPU. + ## Configure Paths Set these values in each terminal before copying the commands below: @@ -80,7 +93,7 @@ module load frameworks # conda activate base source "$REMOTE_ROOT/venvs/academy-swarm/bin/activate" -python -m pip install -e ".[academy,parsl]" +python -m pip install -e ".[academy]" ``` Verify the campaign is visible: @@ -116,21 +129,15 @@ make -j4 make PREFIX="$REMOTE_ROOT/tools/redis" install ``` -Stage the MACE model: +The `mace_mp` calculator downloads its foundation model on first use into +`~/.cache/mace`, so no manual MACE-model staging is needed for this example. +First-call download can take a minute; pre-warm it once on the compute node +if you want to skip that wait at run time: ```bash -cd "$LOCAL_CHEMGRAPH" - -MODEL=src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/models/mace-mpa-0-medium.model -URL=https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model - -mkdir -p "$(dirname "$MODEL")" -test -f "$MODEL" || curl -L --fail -o "$MODEL" "$URL" -ls -lh "$MODEL" +python -c "from mace.calculators import mace_mp; mace_mp(model='medium-mpa-0', device='cpu')" ``` -Then sync ChemGraph again. 
- ## Start argo-shim On the local machine: @@ -185,16 +192,8 @@ export NUMEXPR_NUM_THREADS=64 export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 -export CHEMGRAPH_EXECUTION_BACKEND=parsl -export COMPUTE_SYSTEM="$ALCF_SYSTEM" - export PATH="$REMOTE_ROOT/bin:$REMOTE_ROOT/tools/redis/bin:$PATH" -: "${CHEMGRAPH_EXECUTION_BACKEND:?must be set to 'parsl' before launch}" -: "${COMPUTE_SYSTEM:?must be set to aurora or polaris before launch}" -echo "execution backend = $CHEMGRAPH_EXECUTION_BACKEND" -echo "compute system = $COMPUTE_SYSTEM" - chemgraph academy run-compute \ --system "$ALCF_SYSTEM" \ --run-id "$RUN_ID" \ @@ -202,13 +201,6 @@ chemgraph academy run-compute \ --lm-user "$ARGO_USER" ``` -If you reconnect to the login/compute node and re-run only the final -`chemgraph academy run-compute` invocation, the env exports above will not be -in your shell. Re-run the full block, or re-export both variables, before -relaunching. If `CHEMGRAPH_EXECUTION_BACKEND` is unset, the MCP server can fall -back to LocalBackend and produce `BrokenProcessPool` failures under per-rank -memory pressure. - If the wrapper is installed but `chemgraph` is not on `PATH`, use: ```bash @@ -308,32 +300,7 @@ export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 ``` -If MACE results come back as `PicklingError: Can't pickle -run_mace_singleArguments`, the remote ChemGraph checkout does not have the -worker-module fix synced. Sync the latest ChemGraph checkout to the ALCF -filesystem, restart the dashboard with a fresh run id, and rerun from a fresh -compute allocation. - -If MACE results come back as `BrokenProcessPool` failures, confirm the MACE MCP -server initialized Parsl: - -```bash -grep "backend initialised" \ - "$REMOTE_ROOT/runs/$RUN_ID/rank3/mcp_logs/mace.log" -``` - -Expected: - -```text -CGFastMCP backend initialised: ParslBackend -``` - -If the log shows `LocalBackend initialized with 4 workers`, re-run the full -compute block with `CHEMGRAPH_EXECUTION_BACKEND=parsl`. 
 - -If the log shows `Parsl is required for the ParslBackend`, the Parsl package is -missing from the venv: - -```bash -python -m pip install -e ".[academy,parsl]" -``` +If MACE energy evaluations are slow, the first call per worker pays a +one-time foundation-model download into `~/.cache/mace`. Pre-warm by +running the `mace_mp` one-liner shown in the setup steps above on the compute node +before launching the campaign. diff --git a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.json b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.json index 23fb9362..068e110f 100644 --- a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.json +++ b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.json @@ -1,7 +1,7 @@ { // Campaign files support JSONC-style comments. "run_id": "mace-ensemble-screening-20", - "user_task": "Given 20 staged SMILES candidates, generate 3D XYZ structures, run a MACE ensemble energy screen over generated structures, and rank candidates by calculation readiness and available MACE evidence.", + "user_task": "Given 20 staged SMILES candidates, generate 3D XYZ structures, run a per-structure MACE energy calculation through the run_ase tool, and rank candidates by calculation readiness and available MACE evidence.", "prompt_profile": "prompt_profiles/default.json", "initial_agent": "coordinator-agent", "resources": { @@ -26,40 +26,30 @@ "scope": "shared_run", "description": "Shared run directory where generated XYZ coordinate files should be written." }, - "mace_output_result_file": { - "kind": "file", - "path": "academy_mace_outputs/mace_results.json", + "mace_output_directory": { + "kind": "directory", + "path": "academy_mace_outputs", "scope": "shared_run", - "description": "Shared run file requested for the MACE ensemble result summary."
- }, - "mace_model_file": { - "kind": "file", - "path": "models/mace-mpa-0-medium.model", - "scope": "campaign_file", - "description": "Local MACE model file shipped with this campaign example." + "description": "Shared run directory where mace-agent should write one JSON result file per structure (e.g. academy_mace_outputs/<candidate_id>.json)." } }, "mcp_servers": [ // MCP server fields: // command: launch command; runtime appends --transport/--host/--port. + // The HPC-specific servers (mace_mcp_hpc, hpc_misc_mcp) are intentionally + // omitted here because they go through chemgraph.execution.ParslBackend, + // which is being reworked in a separate PR. This example exercises the + // in-process MACE path through the general ``run_ase`` tool instead. { "name": "general", "command": "python -m chemgraph.mcp.mcp_tools" - }, - { - "name": "mace", - "command": "python -m chemgraph.mcp.mace_mcp_hpc" - }, - { - "name": "hpc_misc", - "command": "python -m chemgraph.mcp.hpc_misc_mcp" } ], "agents": [ { "name": "coordinator-agent", "role": "MACEReadinessCoordinatorAgent", - "mission": "Coordinate the campaign from the bootstrap task. Send odd-numbered MOL candidates to structure-agent-a and even-numbered MOL candidates to structure-agent-b, including candidate_id, label, SMILES, and output_file. After structure evidence returns, ask mace-agent to run one MACE energy calculation on CPU using the provided structure directory, output result path, and local model file resource, then ask assessment-agent for readiness/ranking evidence before submitting the final result.", + "mission": "Coordinate the campaign from the bootstrap task. Send odd-numbered MOL candidates to structure-agent-a and even-numbered MOL candidates to structure-agent-b, including candidate_id, label, SMILES, and output_file.
After structure evidence returns, ask mace-agent to run one MACE energy calculation per generated XYZ file using the run_ase tool with the mace_mp calculator on CPU, then ask assessment-agent for readiness/ranking evidence before submitting the final result.", "allowed_peers": [ "structure-agent-a", "structure-agent-b", @@ -70,8 +60,7 @@ "resources": [ "candidate_dataset", "structure_output_directory", - "mace_output_result_file", - "mace_model_file" + "mace_output_directory" ] }, { @@ -92,18 +81,18 @@ }, { "name": "mace-agent", - "role": "MACEEnsembleAgent", - "mission": "Run MACE only after a concrete request from coordinator-agent. Report started, completed, partial, or failed evidence back to coordinator-agent, including output paths and tool_result_ids; pending work is not a failure.", + "role": "MACEEnergyAgent", + "mission": "Run MACE only after a concrete request from coordinator-agent. For each assigned XYZ file, call the run_ase tool with driver='energy', a calculator block of {'calculator_type': 'mace_mp', 'model': 'medium-mpa-0', 'device': 'cpu'}, the input_structure_file pointing at the XYZ, and an output_results_file under the requested output directory. Report started, completed, partial, or failed evidence back to coordinator-agent, including output paths and tool_result_ids; pending work is not a failure.", "allowed_peers": ["coordinator-agent"], - "mcp_servers": ["mace", "hpc_misc"], - "resources": ["mace_model_file"] + "mcp_servers": ["general"], + "resources": ["mace_output_directory"] }, { "name": "assessment-agent", "role": "ScreeningAssessmentAgent", "mission": "Assess evidence received from coordinator-agent. 
Summarize structure coverage, MACE coverage, failures, ranking readiness, and pending work without treating pending MACE work as failure.", "allowed_peers": ["coordinator-agent"], - "mcp_servers": ["hpc_misc"], + "mcp_servers": [], "resources": [] } ] From 8043fe970a86ab63b3e8ac07f31c8b4bd3610737 Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Sat, 13 Jun 2026 12:12:09 -0500 Subject: [PATCH 092/119] fix static sync location --- src/chemgraph/execution/ensemble_launcher_backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index 1ba3424b..20336f3d 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -101,7 +101,7 @@ def get_launcher_config( task_executor_name: Union[str, List] = "async_processpool", child_executor_policy: str = "fixed_leafs_children_policy", policy_config=None, - checkpoint_dir=f"{os.getcwd()}/.ckpt_{uuid.uuid4().hex[:6]}", + checkpoint_dir=None, mpi_flavour: Literal[ "test", "mpich", "intel", "cray-pals", "openmpi", "srun", "aprun", "jsrun" ] = "mpich", @@ -116,6 +116,8 @@ def get_launcher_config( _require_ensemble_launcher() if policy_config is None: policy_config = PolicyConfig(nlevels=2, leaf_nodes=len(get_nodes())) + if checkpoint_dir is None: + checkpoint_dir = f"{os.getcwd()}/.ckpt_{uuid.uuid4().hex[:6]}" return LauncherConfig( child_executor_name="async_mpi", task_executor_name=task_executor_name, @@ -192,9 +194,7 @@ def initialize( "client_only=True requires a checkpoint_dir pointing " "to a running orchestrator." 
) - self._client = ClusterClient( - checkpoint_dir=checkpoint_dir, node_id=node_id - ) + self._client = ClusterClient(checkpoint_dir=checkpoint_dir, node_id=node_id) self._client.start() self._initialized = True logger.info( From 317204ba9bfef4bc4a0be5bf0b7e7146b744509b Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Mon, 15 Jun 2026 09:51:19 -0500 Subject: [PATCH 093/119] chore(campaigns): rename campaign.json to campaign.jsonc to match content The file has always contained JSONC-style // comments and is loaded via _load_jsonc in chemgraph.academy.core.campaign. The .json extension was making IDEs flag the comments as parse errors. Rename to .jsonc so the extension matches the content; the package-data glob in pyproject.toml already includes *.jsonc, so the package install is unaffected. Also updates the manifest in chemgraph.academy.campaigns.__init__, the example-002 notes, and tmpfile names in three academy tests for naming consistency. Verified: - chemgraph.academy.campaigns.resolve_campaign returns the new path. - load_campaign reads it cleanly (5 agents, 1 mcp server). - tests/test_academy_campaign.py, test_academy_compute_launcher.py, test_academy_exchange_registration.py: 17 passed. 
--- .../example-002-mace-ensemble-screening/notes.md | 2 +- src/chemgraph/academy/campaigns/__init__.py | 2 +- .../{campaign.json => campaign.jsonc} | 0 tests/test_academy_campaign.py | 10 +++++----- tests/test_academy_compute_launcher.py | 2 +- tests/test_academy_exchange_registration.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) rename src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/{campaign.json => campaign.jsonc} (100%) diff --git a/examples/academy/example-002-mace-ensemble-screening/notes.md b/examples/academy/example-002-mace-ensemble-screening/notes.md index bff829a5..87c98e3d 100644 --- a/examples/academy/example-002-mace-ensemble-screening/notes.md +++ b/examples/academy/example-002-mace-ensemble-screening/notes.md @@ -9,7 +9,7 @@ Packaged assets: ```text src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/ - campaign.json + campaign.jsonc lm_config.json prompt_profiles/ data/ diff --git a/src/chemgraph/academy/campaigns/__init__.py b/src/chemgraph/academy/campaigns/__init__.py index 2b01a733..8c3f5cd6 100644 --- a/src/chemgraph/academy/campaigns/__init__.py +++ b/src/chemgraph/academy/campaigns/__init__.py @@ -8,7 +8,7 @@ EXAMPLE_002 = 'example-002-mace-ensemble-screening' CAMPAIGNS = { - 'mace-ensemble-screening-20': f'{EXAMPLE_002}/campaign.json', + 'mace-ensemble-screening-20': f'{EXAMPLE_002}/campaign.jsonc', } LM_CONFIG_TEMPLATES = { diff --git a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.json b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc similarity index 100% rename from src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.json rename to src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc diff --git a/tests/test_academy_campaign.py b/tests/test_academy_campaign.py index 51c69bed..7588c79a 100644 --- a/tests/test_academy_campaign.py +++ b/tests/test_academy_campaign.py @@ 
-41,7 +41,7 @@ def test_builtin_mace_campaign_uses_star_coordinator_without_routing_policy() -> def test_removed_structured_orchestration_fields_are_rejected(tmp_path) -> None: - campaign_path = tmp_path / "campaign.json" + campaign_path = tmp_path / "campaign.jsonc" campaign_path.write_text( json.dumps( { @@ -70,7 +70,7 @@ def test_removed_structured_orchestration_fields_are_rejected(tmp_path) -> None: def test_campaign_loader_accepts_jsonc_comments(tmp_path) -> None: - campaign_path = tmp_path / "campaign.json" + campaign_path = tmp_path / "campaign.jsonc" campaign_path.write_text( """ { @@ -132,7 +132,7 @@ def test_mcp_server_spec_validation() -> None: def test_resource_kind_and_scope_are_option_sets(tmp_path) -> None: - campaign_path = tmp_path / "campaign.json" + campaign_path = tmp_path / "campaign.jsonc" campaign_path.write_text( json.dumps( { @@ -166,7 +166,7 @@ def test_resource_kind_and_scope_are_option_sets(tmp_path) -> None: def test_validate_campaign_rejects_unknown_mcp_server(tmp_path) -> None: - campaign_path = tmp_path / "campaign.json" + campaign_path = tmp_path / "campaign.jsonc" campaign_path.write_text( json.dumps( { @@ -194,7 +194,7 @@ def test_validate_campaign_rejects_unknown_mcp_server(tmp_path) -> None: def test_validate_campaign_rejects_duplicate_mcp_server_names(tmp_path) -> None: - campaign_path = tmp_path / "campaign.json" + campaign_path = tmp_path / "campaign.jsonc" campaign_path.write_text( json.dumps( { diff --git a/tests/test_academy_compute_launcher.py b/tests/test_academy_compute_launcher.py index abce39cd..20b04ea8 100644 --- a/tests/test_academy_compute_launcher.py +++ b/tests/test_academy_compute_launcher.py @@ -8,7 +8,7 @@ def _plan(tmp_path: Path) -> AllocationPlan: lm_config = tmp_path / "lm.json" - campaign = tmp_path / "campaign.json" + campaign = tmp_path / "campaign.jsonc" lm_config.write_text("{}\n", encoding="utf-8") campaign.write_text("{}\n", encoding="utf-8") return AllocationPlan( diff --git 
a/tests/test_academy_exchange_registration.py b/tests/test_academy_exchange_registration.py index 4780cc21..39aa1509 100644 --- a/tests/test_academy_exchange_registration.py +++ b/tests/test_academy_exchange_registration.py @@ -20,7 +20,7 @@ def _config(tmp_path: Path, exchange_type: str) -> ChemGraphDaemonConfig: run_dir=tmp_path, run_token='token-1', agent_count=1, - campaign_config=tmp_path / 'campaign.json', + campaign_config=tmp_path / 'campaign.jsonc', lm_config=tmp_path / 'lm.json', max_decisions=1, poll_timeout_s=1.0, From 4e6556d49bb5b53550f331013c47e9fa28b01dc8 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Mon, 15 Jun 2026 10:28:27 -0500 Subject: [PATCH 094/119] feat(academy): per-agent allowed_tools whitelist on top of mcp_servers Adds an optional `allowed_tools` field to ChemGraphAgentSpec that filters the tools an agent sees from its declared MCP servers. Empty (the default) keeps today's behavior of exposing every tool the agent's servers advertise. Non-empty restricts the agent to the named tools. Why: the MCP-server-per-agent contract gates capability at the server level only. An agent declaring `mcp_servers: ["general"]` sees every tool that server exposes, even when only one or two are relevant to its mission. That weakens per-agent sandboxing (mission prompt becomes the enforcement) and bloats the LangChain tool catalog the LLM has to choose from. Changes: - ChemGraphAgentSpec gains `allowed_tools: tuple[str, ...] = ()`. - Validator rejects duplicate entries and the case where allowed_tools is non-empty but mcp_servers is empty. - MCPServerSupervisor.get_tools accepts allowed_tools: frozenset|None; when set, tools whose name is not in the whitelist are skipped, and whitelist entries that match nothing log a warning (so typos surface without failing the run). - daemon.run threads agent_spec.allowed_tools through.
- example-002 campaign demonstrates the field: structure agents see only the SMILES tools, mace-agent sees only run_ase + extract_output_json. - 4 new tests in test_academy_campaign.py for parse + validation. - 4 new tests in test_academy_mcp_supervisor.py for filter behavior. --- .../notes.md | 7 + .../campaign.jsonc | 3 + src/chemgraph/academy/core/campaign.py | 24 ++++ src/chemgraph/academy/runtime/daemon.py | 7 +- .../academy/runtime/mcp_supervisor.py | 28 ++++ tests/test_academy_campaign.py | 129 +++++++++++++++++ tests/test_academy_mcp_supervisor.py | 131 ++++++++++++++++++ 7 files changed, 328 insertions(+), 1 deletion(-) diff --git a/examples/academy/example-002-mace-ensemble-screening/notes.md b/examples/academy/example-002-mace-ensemble-screening/notes.md index 87c98e3d..4bd81cf6 100644 --- a/examples/academy/example-002-mace-ensemble-screening/notes.md +++ b/examples/academy/example-002-mace-ensemble-screening/notes.md @@ -20,3 +20,10 @@ The campaign declares MCP server subprocesses for general ChemGraph tools, MACE screening, and HPC utility inspection. The Academy runtime places one logical agent per MPI rank, launches the declared MCP servers for each agent, and uses Academy exchange handles for peer communication. + +Each agent's `allowed_tools` field acts as a per-agent whitelist drawn from +the union of the tools its `mcp_servers` advertise. In this example the +structure agents see only `molecule_name_to_smiles` + `smiles_to_coordinate_file`, +and the mace-agent sees only `run_ase` + `extract_output_json` — even though +all four come from the same `general` MCP server. Omit `allowed_tools` (or set +it to `[]`) to expose every tool the connected servers advertise. 
diff --git a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc index 068e110f..d0b8640f 100644 --- a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc +++ b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc @@ -69,6 +69,7 @@ "mission": "Process only candidates assigned by coordinator-agent. Generate XYZ coordinate files, then report concise artifact evidence and failures back to coordinator-agent.", "allowed_peers": ["coordinator-agent"], "mcp_servers": ["general"], + "allowed_tools": ["molecule_name_to_smiles", "smiles_to_coordinate_file"], "resources": [] }, { @@ -77,6 +78,7 @@ "mission": "Process only candidates assigned by coordinator-agent. Generate XYZ coordinate files, then report concise artifact evidence and failures back to coordinator-agent.", "allowed_peers": ["coordinator-agent"], "mcp_servers": ["general"], + "allowed_tools": ["molecule_name_to_smiles", "smiles_to_coordinate_file"], "resources": [] }, { @@ -85,6 +87,7 @@ "mission": "Run MACE only after a concrete request from coordinator-agent. For each assigned XYZ file, call the run_ase tool with driver='energy', a calculator block of {'calculator_type': 'mace_mp', 'model': 'medium-mpa-0', 'device': 'cpu'}, the input_structure_file pointing at the XYZ, and an output_results_file under the requested output directory. 
Report started, completed, partial, or failed evidence back to coordinator-agent, including output paths and tool_result_ids; pending work is not a failure.", "allowed_peers": ["coordinator-agent"], "mcp_servers": ["general"], + "allowed_tools": ["run_ase", "extract_output_json"], "resources": ["mace_output_directory"] }, { diff --git a/src/chemgraph/academy/core/campaign.py b/src/chemgraph/academy/core/campaign.py index d2fcd14c..3b98fe71 100644 --- a/src/chemgraph/academy/core/campaign.py +++ b/src/chemgraph/academy/core/campaign.py @@ -106,6 +106,19 @@ class ChemGraphAgentSpec: mission: str allowed_peers: tuple[str, ...] mcp_servers: tuple[str, ...] = () + allowed_tools: tuple[str, ...] = () + """Optional per-agent whitelist of MCP tool names. + + Empty (the default) means the agent sees every tool advertised by the + servers listed in :attr:`mcp_servers`. When non-empty, only tools whose + name appears in this tuple are exposed to the agent; everything else + that the servers advertise is filtered out before reaching LangChain. + + The whitelist is flat and server-agnostic: a name matches any tool with + that name across the agent's connected servers. Duplicate tool names + across an agent's servers are still rejected by the supervisor (today's + behavior), so the whitelist does not introduce new ambiguity. + """ resources: tuple[str, ...] 
= () @@ -243,6 +256,7 @@ def load_campaign(path: str | pathlib.Path) -> ChemGraphCampaign: mission=item['mission'], allowed_peers=tuple(item.get('allowed_peers', ())), mcp_servers=tuple(item.get('mcp_servers', ())), + allowed_tools=tuple(item.get('allowed_tools', ())), resources=tuple(item.get('resources', ())), ), ) @@ -378,6 +392,16 @@ def validate_campaign(campaign: ChemGraphCampaign, agent_count: int) -> None: raise RuntimeError( f'{agent.name} references unknown MCP servers: {unknown_servers}', ) + if agent.allowed_tools: + if len(set(agent.allowed_tools)) != len(agent.allowed_tools): + raise RuntimeError( + f'{agent.name} has duplicate allowed_tools entries', + ) + if not agent.mcp_servers: + raise RuntimeError( + f'{agent.name} declares allowed_tools but no mcp_servers ' + 'to draw them from', + ) unknown_resources = sorted(set(agent.resources).difference(campaign.resources)) if unknown_resources: raise RuntimeError( diff --git a/src/chemgraph/academy/runtime/daemon.py b/src/chemgraph/academy/runtime/daemon.py index cff65bd8..e6cb05b8 100644 --- a/src/chemgraph/academy/runtime/daemon.py +++ b/src/chemgraph/academy/runtime/daemon.py @@ -59,7 +59,12 @@ async def run_daemon(config: ChemGraphDaemonConfig) -> int: try: await supervisor.start_all() - external_tools = await supervisor.get_tools(agent_spec.mcp_servers) + external_tools = await supervisor.get_tools( + agent_spec.mcp_servers, + allowed_tools=frozenset(agent_spec.allowed_tools) + if agent_spec.allowed_tools + else None, + ) academy_factory = build_exchange_factory(config) if config.rank == 0: diff --git a/src/chemgraph/academy/runtime/mcp_supervisor.py b/src/chemgraph/academy/runtime/mcp_supervisor.py index ea83846a..565e0fbe 100644 --- a/src/chemgraph/academy/runtime/mcp_supervisor.py +++ b/src/chemgraph/academy/runtime/mcp_supervisor.py @@ -90,7 +90,20 @@ async def start_all(self) -> dict[str, str]: async def get_tools( self, server_names: tuple[str, ...] 
| None = None, + allowed_tools: frozenset[str] | None = None, ) -> list[BaseTool]: + """Return LangChain tools advertised by the requested MCP servers. + + Parameters + ---------- + server_names + Optional subset of supervised servers to query. Defaults to all. + allowed_tools + Optional per-agent tool-name whitelist. When provided, tools + advertised by the connected servers but whose name is not in the + set are filtered out. When ``None`` (or empty), every tool the + servers advertise is returned (legacy behavior). + """ if not self._urls: return [] wanted = tuple(server_names) if server_names else tuple(self._urls) @@ -100,12 +113,14 @@ async def get_tools( f"agent requested unknown MCP servers: {unknown}; " f"available: {sorted(self._urls)}", ) + whitelist = frozenset(allowed_tools) if allowed_tools else None connections = { name: self._urls[name] for name in wanted } tools: list[BaseTool] = [] tool_names: set[str] = set() + matched_whitelist: set[str] = set() for server_name, url in connections.items(): async with streamablehttp_client(url) as (read, write, _): async with ClientSession(read, write) as session: @@ -118,6 +133,10 @@ async def get_tools( f"from server {server_name!r}", ) tool_names.add(mcp_tool.name) + if whitelist is not None: + if mcp_tool.name not in whitelist: + continue + matched_whitelist.add(mcp_tool.name) tools.append( _langchain_tool( server_name=server_name, @@ -128,6 +147,15 @@ async def get_tools( input_schema=mcp_tool.inputSchema, ), ) + if whitelist is not None: + missing = sorted(whitelist - matched_whitelist) + if missing: + logger.warning( + "allowed_tools whitelist references tools not advertised " + "by the connected MCP servers; they will be silently " + "absent from the agent: %s", + missing, + ) return tools async def shutdown(self) -> None: diff --git a/tests/test_academy_campaign.py b/tests/test_academy_campaign.py index 7588c79a..ec102fb7 100644 --- a/tests/test_academy_campaign.py +++ b/tests/test_academy_campaign.py @@ 
-222,3 +222,132 @@ def test_validate_campaign_rejects_duplicate_mcp_server_names(tmp_path) -> None: campaign = load_campaign(campaign_path) with pytest.raises(RuntimeError, match="MCP server names must be unique"): validate_campaign(campaign, 1) + + +def test_agent_allowed_tools_parses(tmp_path) -> None: + campaign_path = tmp_path / "campaign.jsonc" + campaign_path.write_text( + json.dumps( + { + "run_id": "allowed-tools-ok", + "user_task": "test", + "prompt_profile": "prompt.json", + "mcp_servers": [ + {"name": "general", "command": "python -m chemgraph.mcp.mcp_tools"}, + ], + "agents": [ + { + "name": "agent-a", + "role": "Role", + "mission": "Do the task.", + "allowed_peers": [], + "mcp_servers": ["general"], + "allowed_tools": ["run_ase", "extract_output_json"], + }, + ], + }, + ), + encoding="utf-8", + ) + + campaign = load_campaign(campaign_path) + validate_campaign(campaign, 1) + + assert campaign.agents[0].allowed_tools == ( + "run_ase", + "extract_output_json", + ) + + +def test_agent_allowed_tools_defaults_to_empty(tmp_path) -> None: + campaign_path = tmp_path / "campaign.jsonc" + campaign_path.write_text( + json.dumps( + { + "run_id": "allowed-tools-default", + "user_task": "test", + "prompt_profile": "prompt.json", + "mcp_servers": [ + {"name": "general", "command": "python -m chemgraph.mcp.mcp_tools"}, + ], + "agents": [ + { + "name": "agent-a", + "role": "Role", + "mission": "Do the task.", + "allowed_peers": [], + "mcp_servers": ["general"], + }, + ], + }, + ), + encoding="utf-8", + ) + + campaign = load_campaign(campaign_path) + validate_campaign(campaign, 1) + + assert campaign.agents[0].allowed_tools == () + + +def test_validate_campaign_rejects_duplicate_allowed_tools(tmp_path) -> None: + campaign_path = tmp_path / "campaign.jsonc" + campaign_path.write_text( + json.dumps( + { + "run_id": "duplicate-allowed-tools", + "user_task": "test", + "prompt_profile": "prompt.json", + "mcp_servers": [ + {"name": "general", "command": "python -m 
chemgraph.mcp.mcp_tools"}, + ], + "agents": [ + { + "name": "agent-a", + "role": "Role", + "mission": "Do the task.", + "allowed_peers": [], + "mcp_servers": ["general"], + "allowed_tools": ["run_ase", "run_ase"], + }, + ], + }, + ), + encoding="utf-8", + ) + + campaign = load_campaign(campaign_path) + with pytest.raises(RuntimeError, match="duplicate allowed_tools"): + validate_campaign(campaign, 1) + + +def test_validate_campaign_rejects_allowed_tools_without_servers(tmp_path) -> None: + campaign_path = tmp_path / "campaign.jsonc" + campaign_path.write_text( + json.dumps( + { + "run_id": "allowed-tools-no-servers", + "user_task": "test", + "prompt_profile": "prompt.json", + "mcp_servers": [], + "agents": [ + { + "name": "agent-a", + "role": "Role", + "mission": "Do the task.", + "allowed_peers": [], + "mcp_servers": [], + "allowed_tools": ["run_ase"], + }, + ], + }, + ), + encoding="utf-8", + ) + + campaign = load_campaign(campaign_path) + with pytest.raises( + RuntimeError, + match="allowed_tools but no mcp_servers", + ): + validate_campaign(campaign, 1) diff --git a/tests/test_academy_mcp_supervisor.py b/tests/test_academy_mcp_supervisor.py index 901cc78d..10e0c6fe 100644 --- a/tests/test_academy_mcp_supervisor.py +++ b/tests/test_academy_mcp_supervisor.py @@ -38,6 +38,35 @@ def echo(text: str) -> dict: ) +def _write_multi_tool_server(tmp_path: Path) -> None: + """A server that advertises three tools so allowed_tools can subset it.""" + (tmp_path / "multi_mcp.py").write_text( + """ +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("multi") + +@mcp.tool(name="alpha", description="Tool alpha.") +def alpha(text: str) -> dict: + return {"who": "alpha", "text": text} + +@mcp.tool(name="beta", description="Tool beta.") +def beta(text: str) -> dict: + return {"who": "beta", "text": text} + +@mcp.tool(name="gamma", description="Tool gamma.") +def gamma(text: str) -> dict: + return {"who": "gamma", "text": text} + +if __name__ == "__main__": + from 
chemgraph.mcp.server_utils import run_mcp_server + + run_mcp_server(mcp, default_port=0) +""", + encoding="utf-8", + ) + + @pytest.mark.asyncio async def test_mcp_supervisor_starts_server_and_gets_tools(tmp_path) -> None: _write_tiny_server(tmp_path) @@ -123,3 +152,105 @@ async def test_mcp_supervisor_rejects_unknown_server_request(tmp_path) -> None: await supervisor.get_tools(("missing",)) finally: await supervisor.shutdown() + + +@pytest.mark.asyncio +async def test_get_tools_returns_all_when_no_allowed_tools(tmp_path) -> None: + _write_multi_tool_server(tmp_path) + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="multi", + command=f"{sys.executable} -m multi_mcp", + env={"PYTHONPATH": _pythonpath(tmp_path)}, + ), + ], + run_dir=tmp_path / "run", + ) + try: + await supervisor.start_all() + tools = await supervisor.get_tools(("multi",)) + finally: + await supervisor.shutdown() + + assert {tool.name for tool in tools} == {"alpha", "beta", "gamma"} + + +@pytest.mark.asyncio +async def test_get_tools_filters_by_allowed_tools(tmp_path) -> None: + _write_multi_tool_server(tmp_path) + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="multi", + command=f"{sys.executable} -m multi_mcp", + env={"PYTHONPATH": _pythonpath(tmp_path)}, + ), + ], + run_dir=tmp_path / "run", + ) + try: + await supervisor.start_all() + tools = await supervisor.get_tools( + ("multi",), + allowed_tools=frozenset({"alpha", "gamma"}), + ) + finally: + await supervisor.shutdown() + + assert {tool.name for tool in tools} == {"alpha", "gamma"} + + +@pytest.mark.asyncio +async def test_get_tools_warns_on_whitelist_misses(tmp_path, caplog) -> None: + _write_multi_tool_server(tmp_path) + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="multi", + command=f"{sys.executable} -m multi_mcp", + env={"PYTHONPATH": _pythonpath(tmp_path)}, + ), + ], + run_dir=tmp_path / "run", + ) + try: + await supervisor.start_all() + with caplog.at_level("WARNING"): + tools = await 
supervisor.get_tools( + ("multi",), + allowed_tools=frozenset({"alpha", "does_not_exist"}), + ) + finally: + await supervisor.shutdown() + + assert {tool.name for tool in tools} == {"alpha"} + assert any( + "does_not_exist" in record.message for record in caplog.records + ) + + +@pytest.mark.asyncio +async def test_get_tools_empty_allowed_tools_returns_all(tmp_path) -> None: + """An empty whitelist is treated as None (no filter).""" + _write_multi_tool_server(tmp_path) + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="multi", + command=f"{sys.executable} -m multi_mcp", + env={"PYTHONPATH": _pythonpath(tmp_path)}, + ), + ], + run_dir=tmp_path / "run", + ) + try: + await supervisor.start_all() + tools = await supervisor.get_tools( + ("multi",), + allowed_tools=frozenset(), + ) + finally: + await supervisor.shutdown() + + assert {tool.name for tool in tools} == {"alpha", "beta", "gamma"} From 25565d0a55265ea38a1f3f9676cd50a4b2597846 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Mon, 15 Jun 2026 11:14:27 -0500 Subject: [PATCH 095/119] docs(example-002): document http(s)_proxy env vars for compute-node MACE downloads The in-process MACE path uses mace_mp(model="medium-mpa-0") which downloads its foundation model from GitHub on first use. Aurora and Polaris compute nodes can reach external sites only through the ALCF outbound proxy (proxy.alcf.anl.gov:3128); without these env vars the download hangs and the mace-agent reports failure. Add http_proxy / https_proxy / no_proxy to the compute env block, and to the optional MACE pre-warm snippet, so the documented commands work out of the box on both systems. No code changes. 
--- .../example-002-mace-ensemble-screening/e2e_guide.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md index 44c351f6..08031016 100644 --- a/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md +++ b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md @@ -132,9 +132,12 @@ make PREFIX="$REMOTE_ROOT/tools/redis" install The `mace_mp` calculator downloads its foundation model on first use into `~/.cache/mace`, so no manual MACE-model staging is needed for this example. First-call download can take a minute; pre-warm it once on the compute node -if you want to skip that wait at run time: +to skip that wait at run time. The compute node only reaches external sites +through the ALCF outbound proxy, so set the proxy env vars first: ```bash +export http_proxy="http://proxy.alcf.anl.gov:3128" +export https_proxy="http://proxy.alcf.anl.gov:3128" python -c "from mace.calculators import mace_mp; mace_mp(model='medium-mpa-0', device='cpu')" ``` @@ -192,6 +195,13 @@ export NUMEXPR_NUM_THREADS=64 export OMP_NUM_THREADS=1 export MKL_NUM_THREADS=1 +# Aurora/Polaris compute nodes reach external sites (GitHub, S3) only +# through the ALCF outbound proxy. Without these, mace_mp(model="medium-mpa-0") +# hangs trying to fetch the foundation model on first use. 
+export http_proxy="http://proxy.alcf.anl.gov:3128" +export https_proxy="http://proxy.alcf.anl.gov:3128" +export no_proxy="localhost,127.0.0.1" + export PATH="$REMOTE_ROOT/bin:$REMOTE_ROOT/tools/redis/bin:$PATH" chemgraph academy run-compute \ From 648b536cc8f804e749473b1a31d054155f086b75 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Mon, 15 Jun 2026 11:31:02 -0500 Subject: [PATCH 096/119] fix(tools): create output_results_file parent dir in run_ase_core run_ase_core opens output_results_file for write at the end of the simulation without first ensuring its parent directory exists. Agents and CLI users routinely point the output at a not-yet-created nested subdirectory of a shared run dir; the simulation then runs to completion only to fail with FileNotFoundError: [Errno 2] when it tries to persist results. Compute time wasted, error message blames the wrong layer. Add a single os.makedirs(..., exist_ok=True) on the resolved parent right after the .json extension check. Idempotent, harmless when the directory already exists, and surfaces any permission problem before the calculator gets loaded. Hit this on example-002 polaris run 012: mace-agent received a mace_output_directory resource pointing at academy_mace_outputs/, the agent passed output_results_file as academy_mace_outputs/MOL-002.json, the directory did not exist on disk yet, every run_ase call failed identically, mace-agent retried, kept failing. Test in tests/test_mcp.py mocks load_calculator and runs run_ase_core against a tmp_path / "deeply/nested/output.json" target. 
--- src/chemgraph/tools/ase_core.py | 9 +++++++ tests/test_mcp.py | 47 +++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/src/chemgraph/tools/ase_core.py b/src/chemgraph/tools/ase_core.py index 4e3dc915..14b028ed 100644 --- a/src/chemgraph/tools/ase_core.py +++ b/src/chemgraph/tools/ase_core.py @@ -350,6 +350,15 @@ def run_ase_core(params: ASEInputSchema) -> dict: "message": f"Output results file must end with '.json', got: {params.output_results_file}", } + # Make sure the destination directory exists before the simulation runs; + # otherwise the trailing ``open(output_results_file, "w")`` fails with + # FileNotFoundError after the calculation has already burned its + # compute time. Callers (LLM agents, scripts) routinely point at a + # not-yet-created subdirectory of a shared run dir, so create it now. + output_parent = os.path.dirname(os.path.abspath(output_results_file)) + if output_parent: + os.makedirs(output_parent, exist_ok=True) + calc, system_info, calc_model = load_calculator(calculator) if calc is None: diff --git a/tests/test_mcp.py b/tests/test_mcp.py index b4615871..ff3dbfbe 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -76,6 +76,53 @@ def fake_run_mace_core(params): assert result["full_output"] == {"ok": True} +def test_run_ase_core_creates_output_parent_directory(monkeypatch, tmp_path): + """run_ase_core should mkdir the output file's parent before writing. + + Academy agents and CLI users routinely point output_results_file at a + not-yet-existing nested subdirectory of a shared run dir. Without this, + the final ``open(output_results_file, "w")`` fails with + FileNotFoundError after the calculation has already burned its compute + time. + """ + from ase import Atoms + from ase.io import write as ase_write + + from chemgraph.schemas.ase_input import ASEInputSchema + from chemgraph.tools import ase_core + + # Real XYZ that ase.io.read can parse. 
+ input_path = tmp_path / "h2.xyz" + ase_write(input_path, Atoms(numbers=[1, 1], positions=[[0, 0, 0], [0, 0, 0.74]])) + + # Output path under a nested subdirectory that does NOT exist yet. + output_path = tmp_path / "deeply" / "nested" / "output.json" + assert not output_path.parent.exists() + + class _FakeCalc: + # ASE's Atoms.get_potential_energy invokes self._calc.get_potential_energy(atoms). + def get_potential_energy(self, _atoms=None, force_consistent=False): + return -1.234 + + def fake_load_calculator(_calculator): + return _FakeCalc(), {}, None + + monkeypatch.setattr(ase_core, "load_calculator", fake_load_calculator) + + params = ASEInputSchema( + input_structure_file=str(input_path), + output_results_file=str(output_path), + driver="energy", + calculator={"calculator_type": "emt"}, + ) + + result = ase_core.run_ase_core(params) + + assert result["status"] == "success", result + assert output_path.exists() + assert output_path.parent.is_dir() + + @pytest.mark.asyncio async def test_split_cif_dataset(tmp_path): """Test splitting a dataset of CIF files.""" From 04f1dbf7be7c703ad2e941f75afafb1c5638607d Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Mon, 15 Jun 2026 11:31:47 -0500 Subject: [PATCH 097/119] fix(academy): materialise shared_run resource directories at startup resolve_campaign_resources rewrites shared_run resource paths to absolute locations under /shared/ but never actually creates the directories on disk. Tools that get pointed at one of these resources have to guess whether their parent exists; the in-process run_ase tool, for example, did not, and example-002 polaris run 012 saw every mace-agent call fail with FileNotFoundError for a path under the campaign-declared mace_output_directory. Make academy uphold the natural contract: if a campaign declares a shared_run resource, the runtime guarantees the on-disk parent exists before any agent touches it. 
Specifically, after resolving the path: - kind: directory -> mkdir -p the resolved path itself - kind: file / json -> mkdir -p the resolved parent The file itself stays the responsibility of the agent that writes it. mkdir is idempotent so per-rank repetition is harmless. Test extends test_campaign_resources_resolve_to_shared_run_artifacts with on-disk assertions, and adds test_resolve_campaign_resources_skips_non_shared_run_paths confirming we do not create absolute / external paths. --- src/chemgraph/academy/core/campaign.py | 26 ++++++++++++++--- tests/test_academy_reasoning_phase2.py | 39 ++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/src/chemgraph/academy/core/campaign.py b/src/chemgraph/academy/core/campaign.py index 3b98fe71..b87a80da 100644 --- a/src/chemgraph/academy/core/campaign.py +++ b/src/chemgraph/academy/core/campaign.py @@ -165,7 +165,15 @@ def resolve_campaign_resources( *, shared_dir_name: str = 'shared', ) -> ChemGraphCampaign: - """Resolve explicit shared-run resource paths for one concrete run.""" + """Resolve explicit shared-run resource paths for one concrete run. + + Also pre-creates the on-disk directories these resources name so that + tools whose first action is to write under a declared output directory + do not fail with ``FileNotFoundError`` partway through. For ``kind: + directory`` resources the directory itself is created; for ``kind: + file`` and ``kind: json`` resources the file's parent directory is + created (the file itself is the agent's responsibility to write).
+ """ shared_root = (pathlib.Path(run_dir).resolve() / shared_dir_name) resources: dict[str, ResourceSpec] = {} @@ -177,17 +185,27 @@ def resolve_campaign_resources( resources[name] = spec continue path = pathlib.Path(spec.path) - resolved = path if path.is_absolute() else shared_root / path + resolved = (path if path.is_absolute() else shared_root / path).resolve() + _ensure_resource_dir(resolved, spec.kind) resources[name] = spec.model_copy( update={ - 'path': str(resolved.resolve()), - 'uri': spec.uri or _file_uri(resolved.resolve()), + 'path': str(resolved), + 'uri': spec.uri or _file_uri(resolved), }, ) return dataclasses.replace(campaign, resources=resources) +def _ensure_resource_dir(resolved: pathlib.Path, kind: str) -> None: + """Materialise on-disk directories for a resolved shared_run resource.""" + if kind == 'directory': + resolved.mkdir(parents=True, exist_ok=True) + else: + # 'file' and 'json': create the parent so the agent can write the file. + resolved.parent.mkdir(parents=True, exist_ok=True) + + def _file_uri(path: pathlib.Path) -> str: return path.resolve().as_uri() diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py index e4ce9198..bbe13a50 100644 --- a/tests/test_academy_reasoning_phase2.py +++ b/tests/test_academy_reasoning_phase2.py @@ -306,3 +306,42 @@ def test_campaign_resources_resolve_to_shared_run_artifacts(tmp_path) -> None: assert resolved.resources["mace_output_result_file"].path == str( tmp_path / "run-1" / "shared" / "academy_mace_outputs" / "mace_results.json", ) + + # The directory resource itself is materialised on disk so tools that + # expect to write into it do not hit FileNotFoundError on first use. + assert ( + tmp_path / "run-1" / "shared" / "academy_mace_structures" + ).is_dir() + # File resources get their parent directory materialised (the file + # itself is the agent's responsibility to write). 
+ assert ( + tmp_path / "run-1" / "shared" / "academy_mace_outputs" + ).is_dir() + assert not ( + tmp_path / "run-1" / "shared" / "academy_mace_outputs" / "mace_results.json" + ).exists() + + +def test_resolve_campaign_resources_skips_non_shared_run_paths(tmp_path) -> None: + """Only shared_run resources get on-disk materialisation.""" + spec = dataclasses.replace(_agent_spec(), resources=("local_dataset",)) + campaign = ChemGraphCampaign( + run_id="campaign-2", + user_task="Static dataset.", + initial_agent=spec.name, + prompt_profile=Path("prompt_profiles/default.json"), + agents=(spec,), + resources={ + "local_dataset": ResourceSpec( + kind="json", + path="/should/not/exist/data.json", + scope="absolute", + ), + }, + ) + + resolved = resolve_campaign_resources(campaign, tmp_path / "run-1") + + # The absolute path is preserved verbatim and no directory is created. + assert resolved.resources["local_dataset"].path == "/should/not/exist/data.json" + assert not Path("/should/not/exist").exists() From a2e27a97ec651820eb146daf8c830c5d46e9d8ca Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Mon, 15 Jun 2026 14:05:29 -0500 Subject: [PATCH 098/119] added a one off logging to run_ase_core --- src/chemgraph/tools/ase_core.py | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/chemgraph/tools/ase_core.py b/src/chemgraph/tools/ase_core.py index ad54d0c5..30f0ed90 100644 --- a/src/chemgraph/tools/ase_core.py +++ b/src/chemgraph/tools/ase_core.py @@ -322,13 +322,28 @@ def run_ase_core(params: ASEInputSchema) -> dict: dict Minimal result payload (status, message, key numbers). 
""" + import logging from ase.io import read from ase.optimize import BFGS, LBFGS, GPMin, FIRE, MDMin + # ---- file logger (cg_logs/) ---- + log_dir = os.path.join(os.getcwd(), "cg_logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, "ase_core.log") + logger = logging.getLogger(f"chemgraph.ase_core.{id(params)}") + logger.setLevel(logging.DEBUG) + logger.propagate = False + _fh = logging.FileHandler(log_file) + _fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) + logger.addHandler(_fh) + + logger.info("run_ase_core called with params: %s", params.model_dump_json()) + # ---- unpack params ---- try: calculator = params.calculator.model_dump() except Exception as e: + logger.error("Calculator validation failed: %s", e) return { "status": "failure", "error_type": "ValidationError", @@ -347,7 +362,11 @@ def run_ase_core(params: ASEInputSchema) -> dict: pressure = params.pressure # ---- input validation ---- + logger.info("driver=%s, input=%s, output=%s, optimizer=%s, fmax=%s, steps=%s", + driver, input_structure_file, output_results_file, optimizer, fmax, steps) + if not os.path.isfile(input_structure_file): + logger.error("Input file not found: %s", input_structure_file) return { "status": "failure", "error_type": "FileNotFoundError", @@ -355,15 +374,18 @@ def run_ase_core(params: ASEInputSchema) -> dict: } if not output_results_file.endswith(".json"): + logger.error("Invalid output file extension: %s", output_results_file) return { "status": "failure", "error_type": "ValueError", "message": f"Output results file must end with '.json', got: {params.output_results_file}", } + logger.info("Loading calculator: %s", calculator) calc, system_info, calc_model = load_calculator(calculator) if calc is None: + logger.error("Unsupported calculator: %s", calculator) return { "status": "failure", "error_type": "ValueError", @@ -372,16 +394,19 @@ def run_ase_core(params: ASEInputSchema) -> dict: "MACE (mace_mp, mace_off, 
mace_anicc), EMT, TBLite (GFN2-xTB, GFN1-xTB), NWChem and Orca" ), } + logger.info("Calculator loaded successfully: %s", type(calc).__name__) try: atoms = read(input_structure_file) except Exception as e: + logger.error("Failed to read input structure: %s", e) return { "status": "failure", "error_type": type(e).__name__, "message": f"Cannot read {input_structure_file} using ASE. Exception from ASE: {e}", } + logger.info("Read %d atoms from %s", len(atoms), input_structure_file) atoms.info.update(system_info) atoms.calc = calc @@ -389,7 +414,9 @@ def run_ase_core(params: ASEInputSchema) -> dict: # Driver: energy / dipole (single-point, no optimization) # ------------------------------------------------------------------ if driver in ("energy", "dipole"): + logger.info("Running single-point %s calculation", driver) energy = atoms.get_potential_energy() + logger.info("Single-point energy: %s eV", energy) final_structure = atoms_to_atomsdata(atoms) dipole: List[Optional[float]] = [None, None, None] @@ -414,6 +441,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: ) with open(output_results_file, "w", encoding="utf-8") as wf: wf.write(simulation_output.model_dump_json(indent=4)) + logger.info("Results saved to %s (wall_time=%.2fs)", output_results_file, wall_time) if driver == "energy": return { @@ -445,13 +473,16 @@ def run_ase_core(params: ASEInputSchema) -> dict: if optimizer_class is None: raise ValueError(f"Unsupported optimizer: {optimizer}") + logger.info("Running optimization with %s (fmax=%s, steps=%s)", optimizer, fmax, steps) if len(atoms) > 1: dyn = optimizer_class(atoms) converged = dyn.run(fmax=fmax, steps=steps) else: converged = True + logger.info("Optimization converged=%s", converged) single_point_energy = float(atoms.get_potential_energy()) + logger.info("Post-optimization energy: %s eV", single_point_energy) final_structure = AtomsData( numbers=atoms.numbers, positions=atoms.positions, @@ -466,6 +497,7 @@ def run_ase_core(params: ASEInputSchema) 
-> dict: # Vibrational / thermo / IR analysis # -------------------------------------------------------------- if driver in {"vib", "thermo", "ir"}: + logger.info("Starting vibrational analysis (driver=%s)", driver) from ase.vibrations import Vibrations from ase import units @@ -481,6 +513,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: vib = Vibrations(atoms, name=vib_name) vib.clean() vib.run() + logger.info("Vibrational analysis complete") vib_data = { "energies": [], @@ -527,6 +560,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: # ---- IR ---- if driver == "ir": + logger.info("Running IR calculation") from ase.vibrations import Infrared import matplotlib @@ -558,6 +592,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: fig.savefig(ir_plot_path, format="png", dpi=300) plt.close(fig) + logger.info("IR spectrum plot saved to %s", ir_plot_path) ir_data["IR Plot"] = f"Saved to {os.path.abspath(ir_plot_path)}" ir_data["Normal mode data"] = ( f"Normal modes saved as individual .traj files with prefix {mol_stem}_" @@ -565,6 +600,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: # ---- Thermochemistry ---- if driver == "thermo": + logger.info("Computing thermochemistry (T=%s K, P=%s Pa)", temperature, pressure) if len(atoms) == 1: thermo_data = { "enthalpy": single_point_energy, @@ -615,6 +651,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: # ---- serialise full output ---- end_time = time.time() wall_time = end_time - start_time + logger.info("Simulation finished (driver=%s, wall_time=%.2fs, converged=%s)", driver, wall_time, converged) simulation_output = ASEOutputSchema( input_structure_file=input_structure_file, @@ -671,6 +708,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: } except Exception as e: + logger.exception("run_ase_core failed with %s: %s", type(e).__name__, e) return { "status": "failure", "error_type": type(e).__name__, From f56713d3b9038e6c99eb8b7999ab204bbdcaa5d8 Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: 
Mon, 15 Jun 2026 14:17:48 -0500 Subject: [PATCH 099/119] added ppn to task spec in demo el --- scripts/demo/_demo_chemistry.py | 2 ++ scripts/demo/demo_ensemble_launcher_in_job_direct.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/demo/_demo_chemistry.py b/scripts/demo/_demo_chemistry.py index 4e2d1547..d4ca73cf 100644 --- a/scripts/demo/_demo_chemistry.py +++ b/scripts/demo/_demo_chemistry.py @@ -154,6 +154,7 @@ def submit_and_collect( output_dir: Path | str, inline: bool, timeout: float = 6000.0, + ppn: int = 1, ) -> list[dict]: """Submit one MACE thermo job per molecule, gather and summarise. @@ -177,6 +178,7 @@ def submit_and_collect( task_type="python", callable=_mace_worker, kwargs={"job": job}, + processes_per_node=ppn, ) for name, job in zip(names, jobs) ] diff --git a/scripts/demo/demo_ensemble_launcher_in_job_direct.py b/scripts/demo/demo_ensemble_launcher_in_job_direct.py index d7126924..42c4c73f 100644 --- a/scripts/demo/demo_ensemble_launcher_in_job_direct.py +++ b/scripts/demo/demo_ensemble_launcher_in_job_direct.py @@ -42,6 +42,8 @@ def main() -> None: parser.add_argument("--device", default=None) parser.add_argument("--output-dir", default="demo_el_out") parser.add_argument("--molecules", nargs="+", default=MOLECULE_NAMES) + parser.add_argument("--ppn", type=int, default=16, + help="Processes (cores) per node for each task") parser.add_argument("--timeout", type=float, default=6000.0) args = parser.parse_args() @@ -69,7 +71,7 @@ def main() -> None: "Install via scripts/hpc_setup/install_remote.sh on HPC." 
) - print(f"system={system} device={device} mode=managed") + print(f"system={system} device={device} ppn={args.ppn} mode=managed") from chemgraph.execution.config import get_backend @@ -82,6 +84,7 @@ def main() -> None: output_dir=args.output_dir, inline=False, timeout=args.timeout, + ppn=args.ppn, ) finally: backend.shutdown() From aefed79036c58cb0d3a90cd1a2725deb8fe6f9c7 Mon Sep 17 00:00:00 2001 From: Hari Date: Mon, 15 Jun 2026 19:53:36 +0000 Subject: [PATCH 100/119] added try except block in demo chemistry --- scripts/demo/_demo_chemistry.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/demo/_demo_chemistry.py b/scripts/demo/_demo_chemistry.py index d4ca73cf..065cf966 100644 --- a/scripts/demo/_demo_chemistry.py +++ b/scripts/demo/_demo_chemistry.py @@ -191,12 +191,16 @@ def submit_and_collect( results: list[dict] = [] for name, job, fut in zip(names, jobs, futures): print(f" waiting on {name}...", flush=True) - raw = fut.result(timeout=timeout) - if not isinstance(raw, dict): - raise RuntimeError(f"{name}: non-dict result {type(raw).__name__}: {raw!r}") - if raw.get("status") != "success": - raise RuntimeError(f"{name}: backend returned status={raw.get('status')!r}: {raw}") - results.append(_extract_properties(name, raw, job, inline=inline)) + try: + raw = fut.result(timeout=timeout) + if not isinstance(raw, dict): + raise RuntimeError(f"{name}: non-dict result {type(raw).__name__}: {raw!r}") + if raw.get("status") != "success": + raise RuntimeError(f"{name}: backend returned status={raw.get('status')!r}: {raw}") + results.append(_extract_properties(name, raw, job, inline=inline)) + except Exception as e: + print(f"collecting results for job {name} failed with error: {e}") + results.append(None) return results From 4622823bdcc33e52fce9753305d518ebbea2291c Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Tue, 16 Jun 2026 11:56:31 -0500 Subject: [PATCH 101/119] refactor(academy/dashboard): bundle UAN relay script into 
chemgraph The dashboard launcher previously required a separate `academy` source checkout on every remote system because it referenced `${academy_repo_root}/examples/09-polaris-lm-swarm/uan_http_relay.py` from the system profile. That dependency was undocumented in the example-002 e2e guide (which only tells users to sync ChemGraph), so a fresh user on Aurora hit "No such file or directory" trying to start the Mac-relay path. Move the relay script into the chemgraph package as a runtime template (stdlib-only, no imports beyond socket/threading). The dashboard launcher now materializes it onto the remote at $REMOTE_ROOT/.chemgraph/uan_http_relay.py before starting the relay, via a one-line ssh stdin pipe. start_relay accepts the resulting path as an argument instead of computing it from profile state. Side cleanup: * SystemProfile no longer has academy_repo_root; both aurora.template.json and polaris.template.json drop the field. * Polaris's relay_host_file used to land inside the academy checkout (`/academy/uan-relay-18186.host`); normalize to the same shape Aurora already used: directly under remote_root. * Dashboard metadata no longer writes academy_repo_root either; nothing downstream consumed it. Result: the second `academy` source checkout is no longer required on remote systems. Users only need ChemGraph synced. The Mac-relay path works the same way on any new host as long as the chemgraph package is installed. 102 academy + synth tests pass. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../academy/runtime/dashboard_launcher.py | 32 ++++++- .../runtime/profiles/aurora.template.json | 1 - .../runtime/profiles/polaris.template.json | 3 +- .../academy/runtime/profiles/system.py | 1 - .../runtime/templates/uan_http_relay.py | 96 +++++++++++++++++++ tests/test_academy_dashboard_launcher.py | 1 - 6 files changed, 125 insertions(+), 9 deletions(-) create mode 100644 src/chemgraph/academy/runtime/templates/uan_http_relay.py diff --git a/src/chemgraph/academy/runtime/dashboard_launcher.py b/src/chemgraph/academy/runtime/dashboard_launcher.py index 1869a1e1..0116176f 100644 --- a/src/chemgraph/academy/runtime/dashboard_launcher.py +++ b/src/chemgraph/academy/runtime/dashboard_launcher.py @@ -44,6 +44,29 @@ def parse_args() -> argparse.Namespace: def template(name: str) -> str: return files("chemgraph.academy.runtime.templates").joinpath(name).read_text() + +REMOTE_RELAY_SUBPATH = ".chemgraph/uan_http_relay.py" + + +def stage_relay_script(profile: SystemProfile, host: str, control_path: str) -> str: + """Copy the bundled UAN relay script to the remote host. + + The relay script is shipped inside the chemgraph package so we no longer + require a separate ``academy`` source checkout on the remote system. + We materialize it under ``$REMOTE_ROOT/.chemgraph/uan_http_relay.py`` + on every dashboard launch (idempotent overwrite), then return that + absolute path for the start_relay shell template to reference. 
+ """ + relay_dir = f"{profile.remote_root}/.chemgraph" + relay_path = f"{relay_dir}/uan_http_relay.py" + contents = template("uan_http_relay.py") + cmd = ( + f"mkdir -p {shlex.quote(relay_dir)} && " + f"cat > {shlex.quote(relay_path)}" + ) + ssh(host, cmd, control_path=control_path, input_text=contents) + return relay_path + def ssh(host: str, command: str | list[str] | None, *, control_path: str, input_text: str | None = None, check: bool = True, capture: bool = False, batch_mode: bool = True, extra: list[str] | None = None) -> subprocess.CompletedProcess[str]: cmd = ["ssh"] if batch_mode: @@ -62,8 +85,7 @@ def wrapper(profile: SystemProfile) -> str: .replace("%{venv_python}%", profile.venv_python) ) -def start_relay(profile: SystemProfile, host: str, control_path: str, args: argparse.Namespace, relay_port: int, relay_python: str, log_path: Path) -> subprocess.Popen[str]: - relay_script = f"{profile.academy_repo_root}/examples/09-polaris-lm-swarm/uan_http_relay.py" +def start_relay(profile: SystemProfile, host: str, control_path: str, args: argparse.Namespace, relay_port: int, relay_python: str, log_path: Path, relay_script: str) -> subprocess.Popen[str]: relay_args = ["bash", "-s", "--", profile.remote_root, relay_script, profile.relay_host_file, f"{profile.remote_root}/uan-relay-{relay_port}.pid", f"{profile.remote_root}/uan-relay-{relay_port}.log", str(relay_port), str(args.reverse_port), relay_python] log_path.parent.mkdir(parents=True, exist_ok=True) cmd = ["ssh", "-o", "BatchMode=yes", "-o", f"ControlPath={control_path}", "-o", "ControlMaster=auto", "-o", "ControlPersist=yes", "-o", "ServerAliveInterval=30", "-o", "ServerAliveCountMax=4", "-R", f"127.0.0.1:{args.reverse_port}:{args.local_argo_host}:{args.local_argo_port}", host, *relay_args] @@ -157,12 +179,14 @@ def main() -> int: ssh(remote_host, f"mkdir -p {shlex.quote(profile.remote_root + '/bin')} && cat > {shlex.quote(wrapper_path)} && chmod +x {shlex.quote(wrapper_path)}", control_path=control_path, 
input_text=wrapper(profile)) relay_host = None if args.lm_connect == "mac-argo-relay": + print(f"Staging UAN relay script under {profile.remote_root}/{REMOTE_RELAY_SUBPATH}...", flush=True) + relay_script = stage_relay_script(profile, remote_host, control_path) print(f"Starting {profile.name} UAN relay through {remote_host}...", flush=True) - relay_process = start_relay(profile, remote_host, control_path, args, relay_port, args.relay_python or profile.venv_python, Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log")) + relay_process = start_relay(profile, remote_host, control_path, args, relay_port, args.relay_python or profile.venv_python, Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log"), relay_script) relay_host = wait_relay(profile, remote_host, control_path, relay_port, relay_process, Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log")) lm_base_url = f"http://{relay_host}:{relay_port}/argoapi/v1" if relay_host else str(args.lm_base_url) print(f"Compute-node LM URL: {lm_base_url}", flush=True) - metadata = {"created_at": time.time(), "created_by": "chemgraph academy dashboard", "run_id": args.run_id, "system": profile.name, "campaign": args.campaign, "remote_run_dir": remote_run_dir, "remote_host": remote_host, "lm_connect": args.lm_connect, "lm_base_url": lm_base_url, "workspace_root": profile.remote_root, "academy_repo_root": profile.academy_repo_root, "chemgraph_repo_root": profile.repo_root} + metadata = {"created_at": time.time(), "created_by": "chemgraph academy dashboard", "run_id": args.run_id, "system": profile.name, "campaign": args.campaign, "remote_run_dir": remote_run_dir, "remote_host": remote_host, "lm_connect": args.lm_connect, "lm_base_url": lm_base_url, "workspace_root": profile.remote_root, "chemgraph_repo_root": profile.repo_root} if relay_host: metadata.update({"relay_host": relay_host, "relay_port": relay_port}) print(f"Writing run metadata: {remote_host}:{remote_run_dir}/dashboard_metadata.json", flush=True) diff --git 
a/src/chemgraph/academy/runtime/profiles/aurora.template.json b/src/chemgraph/academy/runtime/profiles/aurora.template.json index 3d6404de..1e3e40a5 100644 --- a/src/chemgraph/academy/runtime/profiles/aurora.template.json +++ b/src/chemgraph/academy/runtime/profiles/aurora.template.json @@ -2,7 +2,6 @@ "name": "aurora", "remote_host": "${ALCF_USER}@aurora.alcf.anl.gov", "remote_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}", - "academy_repo_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}/academy", "repo_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph", "run_root": "/flare/${ALCF_PROJECT}/${ALCF_USER}/runs", "relay_host_file": "/flare/${ALCF_PROJECT}/${ALCF_USER}/uan-relay-18186.host", diff --git a/src/chemgraph/academy/runtime/profiles/polaris.template.json b/src/chemgraph/academy/runtime/profiles/polaris.template.json index c1cf3dc1..7be57c92 100644 --- a/src/chemgraph/academy/runtime/profiles/polaris.template.json +++ b/src/chemgraph/academy/runtime/profiles/polaris.template.json @@ -2,10 +2,9 @@ "name": "polaris", "remote_host": "${ALCF_USER}@polaris.alcf.anl.gov", "remote_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}", - "academy_repo_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/academy", "repo_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph", "run_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/runs", - "relay_host_file": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/academy/uan-relay-18186.host", + "relay_host_file": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/uan-relay-18186.host", "relay_port": 18186, "venv_python": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/venvs/academy-swarm/bin/python", "redis_bin_dir": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", diff --git a/src/chemgraph/academy/runtime/profiles/system.py b/src/chemgraph/academy/runtime/profiles/system.py index fcddb3c8..02ed6dc0 100644 --- a/src/chemgraph/academy/runtime/profiles/system.py +++ b/src/chemgraph/academy/runtime/profiles/system.py @@ -20,7 +20,6 @@ class SystemProfile(BaseModel): name: str 
remote_host: str remote_root: str - academy_repo_root: str repo_root: str run_root: str relay_host_file: str diff --git a/src/chemgraph/academy/runtime/templates/uan_http_relay.py b/src/chemgraph/academy/runtime/templates/uan_http_relay.py new file mode 100644 index 00000000..8ce424fd --- /dev/null +++ b/src/chemgraph/academy/runtime/templates/uan_http_relay.py @@ -0,0 +1,96 @@ +"""Tiny TCP relay used by the dashboard launcher. + +Listens on a UAN-visible port and forwards every accepted connection to a +loopback service on the same UAN host. The dashboard launcher pairs this +with a reverse SSH tunnel (Mac argo-shim -> UAN loopback), so compute +nodes can curl http://<relay_host>:<relay_port>/argoapi/v1 and reach the developer's +local argo-shim. + +This file is materialised onto the remote system at runtime by +``chemgraph.academy.runtime.dashboard_launcher.stage_relay_script``. It was +previously expected to live in a sibling ``academy`` source checkout +under ``examples/09-polaris-lm-swarm/``; bundling it here removes the +need for that second checkout on remote hosts. + +The implementation is intentionally stdlib-only so the script runs under +any Python interpreter without pip-installing anything on the remote.
+""" + +from __future__ import annotations + +import argparse +import socket +import threading + + +def pump(src: socket.socket, dst: socket.socket) -> None: + try: + while True: + data = src.recv(65536) + if not data: + break + dst.sendall(data) + except OSError: + pass + finally: + try: + dst.shutdown(socket.SHUT_WR) + except OSError: + pass + + +def handle_client( + client: socket.socket, + target_host: str, + target_port: int, +) -> None: + with client: + try: + upstream = socket.create_connection((target_host, target_port)) + except OSError as e: + print(f'upstream connection failed: {e}', flush=True) + return + with upstream: + left = threading.Thread(target=pump, args=(client, upstream)) + right = threading.Thread(target=pump, args=(upstream, client)) + left.start() + right.start() + left.join() + right.join() + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description='Relay a UAN-reachable TCP port to a loopback service.', + ) + parser.add_argument('--listen-host', default='0.0.0.0') + parser.add_argument('--listen-port', type=int, required=True) + parser.add_argument('--target-host', default='127.0.0.1') + parser.add_argument('--target-port', type=int, required=True) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server: + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind((args.listen_host, args.listen_port)) + server.listen(128) + print( + f'relay listening on {args.listen_host}:{args.listen_port} ' + f'-> {args.target_host}:{args.target_port}', + flush=True, + ) + while True: + client, addr = server.accept() + print(f'accepted connection from {addr[0]}:{addr[1]}', flush=True) + thread = threading.Thread( + target=handle_client, + args=(client, args.target_host, args.target_port), + daemon=True, + ) + thread.start() + + +if __name__ == '__main__': + raise SystemExit(main()) diff --git 
a/tests/test_academy_dashboard_launcher.py b/tests/test_academy_dashboard_launcher.py index 587d2534..7836ef03 100644 --- a/tests/test_academy_dashboard_launcher.py +++ b/tests/test_academy_dashboard_launcher.py @@ -16,7 +16,6 @@ def _profile(tmp_path: Path) -> SystemProfile: name="test-system", remote_host="user@example", remote_root="/remote/root", - academy_repo_root="/remote/root/academy", repo_root="/remote/root/ChemGraph", run_root="/remote/root/runs", relay_host_file="/remote/root/relay.host", From 64d087a0cf478c9c3b547f0da3805b5fb9bcfd8d Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Tue, 16 Jun 2026 18:41:45 -0500 Subject: [PATCH 102/119] added -ppn and --ngpus_per_process to mcp demos --- .../demo_ensemble_launcher_in_job_agent.py | 14 ++++-- src/chemgraph/mcp/mace_mcp_hpc.py | 45 +++++++++++++------ 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/scripts/demo/demo_ensemble_launcher_in_job_agent.py b/scripts/demo/demo_ensemble_launcher_in_job_agent.py index a059d070..be280e1a 100644 --- a/scripts/demo/demo_ensemble_launcher_in_job_agent.py +++ b/scripts/demo/demo_ensemble_launcher_in_job_agent.py @@ -39,7 +39,8 @@ def _abort(msg: str) -> None: sys.exit(2) -async def amain(model: str, system: str, device: str, query: str, verbose: int) -> None: +async def amain(model: str, system: str, device: str, query: str, verbose: int, + *, ppn: int = 1, ngpus_per_process: int = 0) -> None: if verbose: logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG) @@ -58,7 +59,9 @@ async def amain(model: str, system: str, device: str, query: str, verbose: int) "ChemGraph MACE (EnsembleLauncher)": { "transport": "stdio", "command": python, - "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc"], + "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc", + "--ppn", str(ppn), + "--ngpus-per-process", str(ngpus_per_process)], "env": env, }, } @@ 
-107,6 +110,10 @@ def main() -> None: parser.add_argument("--model", default="gpt-4o-mini") parser.add_argument("--system", default=os.environ.get("COMPUTE_SYSTEM")) parser.add_argument("--device", default=None) + parser.add_argument("--ppn", type=int, default=1, + help="Processes per node for MCP backend tasks") + parser.add_argument("--ngpus-per-process", type=int, default=0, + help="GPUs per process for MCP backend tasks") parser.add_argument("--query", default=None) parser.add_argument("-v", "--verbose", action="count", default=0) args = parser.parse_args() @@ -120,7 +127,8 @@ def main() -> None: _abort(f"Unsupported --system: {system!r}") device = args.device or ("xpu" if system == "aurora" else "cuda") query = args.query or agent_prompt(device=device) - asyncio.run(amain(args.model, system, device, query, args.verbose)) + asyncio.run(amain(args.model, system, device, query, args.verbose, + ppn=args.ppn, ngpus_per_process=args.ngpus_per_process)) if __name__ == "__main__": diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index b18816a8..6e6f8776 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -18,6 +18,7 @@ import logging import os +import sys from pathlib import Path from chemgraph.execution.base import TaskSpec @@ -204,10 +205,6 @@ def _mace_transport_hook(task: TaskSpec) -> TaskSpec: # ── Single-structure tool ────────────────────────────────────────────── -@mcp.tool( - name="run_mace_single", - description="Run a single MACE calculation", -) def run_mace_single(params: mace_input_schema) -> dict: """Run a single MACE calculation on the configured backend. @@ -300,16 +297,6 @@ def _expand_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: ] -@mcp.schema_fanout_tool( - name="run_mace_ensemble", - description=( - "Run MACE calculations over every structure in a directory. 
" - "Local mode uses input_structure_directory; remote mode uses " - "remote_structure_directory (pre-stage files first with " - "transfer_files)." - ), - worker=_mace_worker, -) def run_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: return _expand_mace_ensemble(params) @@ -333,8 +320,38 @@ def run_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: if __name__ == "__main__": + import argparse as _ap + from chemgraph.mcp.server_utils import run_mcp_server + _parser = _ap.ArgumentParser(add_help=False) + _parser.add_argument("--ppn", type=int, default=1, + help="Processes per node for backend tasks") + _parser.add_argument("--ngpus-per-process", type=int, default=0, + help="GPUs per process for backend tasks") + _args, _remaining = _parser.parse_known_args() + sys.argv = [sys.argv[0]] + _remaining + + mcp.tool( + name="run_mace_single", + description="Run a single MACE calculation", + processes_per_node=_args.ppn, + gpus_per_task=_args.ngpus_per_process, + )(run_mace_single) + + mcp.schema_fanout_tool( + name="run_mace_ensemble", + description=( + "Run MACE calculations over every structure in a directory. " + "Local mode uses input_structure_directory; remote mode uses " + "remote_structure_directory (pre-stage files first with " + "transfer_files)." 
+ ), + worker=_mace_worker, + processes_per_node=_args.ppn, + gpus_per_task=_args.ngpus_per_process, + )(run_mace_ensemble) + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) try: From 09f972be27d94b0641a4d8d1cf124ee1b54b858b Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Tue, 16 Jun 2026 19:19:47 -0500 Subject: [PATCH 103/119] added a counter in cg mcp to make task_ids unique --- src/chemgraph/mcp/cg_fastmcp.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/chemgraph/mcp/cg_fastmcp.py b/src/chemgraph/mcp/cg_fastmcp.py index db2c3a23..f9a747a4 100644 --- a/src/chemgraph/mcp/cg_fastmcp.py +++ b/src/chemgraph/mcp/cg_fastmcp.py @@ -99,6 +99,7 @@ def __init__(self, **kwargs: Any) -> None: self._backend_kwargs: Optional[dict[str, Any]] = None self._tracker_kwargs: dict[str, Any] = {} self._pre_submit_hook: Optional[Callable] = None + self._task_counter: int = 0 # ── Backend lifecycle ─────────────────────────────────────────────── @@ -423,10 +424,12 @@ async def wrapper(params): from chemgraph.execution.utils import to_picklable self._ensure_backend() + self._task_counter += 1 + batch_counter = self._task_counter pending = [] for i, p in enumerate(params): task = TaskSpec( - task_id=f"{fn.__name__}_{i}", + task_id=f"{fn.__name__}_{batch_counter}_{i}", task_type="python", callable=fn, kwargs={param.name: to_picklable(p)}, @@ -553,12 +556,14 @@ async def wrapper(**kwargs): from chemgraph.execution.utils import to_picklable self._ensure_backend() + self._task_counter += 1 + batch_counter = self._task_counter ensemble_params = kwargs[param.name] items = expander(ensemble_params) pending = [] for i, item in enumerate(items): task = TaskSpec( - task_id=f"{tool_name}_{i}", + task_id=f"{tool_name}_{batch_counter}_{i}", task_type="python", callable=worker, kwargs={worker_param_name: to_picklable(item)}, @@ -602,8 +607,10 @@ def _make_backend_wrapper( @functools.wraps(fn) async def wrapper(**kwargs: Any) -> Any: 
self._ensure_backend() + self._task_counter += 1 + task_id = f"{fn.__name__}_{self._task_counter}" task = TaskSpec( - task_id=fn.__name__, + task_id=task_id, task_type="python", callable=fn, kwargs=to_picklable(kwargs), @@ -615,7 +622,7 @@ async def wrapper(**kwargs: Any) -> Any: if self._backend.is_async_remote: return await submit_or_gather( self._backend, - [({"task_id": fn.__name__}, fut)], + [({"task_id": task_id}, fut)], self._tracker, fn.__name__, ) From 507d51879f9fa0205597444c7ca9c82cb00456ac Mon Sep 17 00:00:00 2001 From: Hari Date: Wed, 17 Jun 2026 01:51:58 +0000 Subject: [PATCH 104/119] adding some temp cg config for argo --- .../demo_ensemble_launcher_in_job_agent.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/scripts/demo/demo_ensemble_launcher_in_job_agent.py b/scripts/demo/demo_ensemble_launcher_in_job_agent.py index be280e1a..58b98418 100644 --- a/scripts/demo/demo_ensemble_launcher_in_job_agent.py +++ b/scripts/demo/demo_ensemble_launcher_in_job_agent.py @@ -46,15 +46,11 @@ async def amain(model: str, system: str, device: str, query: str, verbose: int, logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG) python = sys.executable - env = { + env = os.environ.copy() + env.update({ "CHEMGRAPH_EXECUTION_BACKEND": "ensemble_launcher", "COMPUTE_SYSTEM": system, - "PATH": os.environ.get("PATH", ""), - "HOME": os.environ.get("HOME", ""), - "VIRTUAL_ENV": os.environ.get("VIRTUAL_ENV", ""), - "PBS_NODEFILE": os.environ.get("PBS_NODEFILE", ""), - "PBS_O_WORKDIR": os.environ.get("PBS_O_WORKDIR", ""), - } + }) server_configs = { "ChemGraph MACE (EnsembleLauncher)": { "transport": "stdio", @@ -81,13 +77,21 @@ async def amain(model: str, system: str, device: str, query: str, verbose: int, tools = await load_mcp_tools(session) print(f"Loaded {len(tools)} MCP tools: {[t.name for t in tools]}\n") + #cg = ChemGraph( + # model_name=model, + # workflow_type="single_agent", + #
structured_output=False, + # return_option="state", + # tools=tools, + # ) cg = ChemGraph( - model_name=model, - workflow_type="single_agent", - structured_output=False, - return_option="state", - tools=tools, - ) + model_name="argo:gpt-5.4", + workflow_type="single_agent", + structured_output=False, + return_option="state", + tools=tools, + base_url="http://127.0.0.1:12985/argoapi/v1" + ) print("Running agent...\n" + "=" * 60) result = await cg.run(query) @@ -123,7 +127,7 @@ def main() -> None: if not args.system: _abort("COMPUTE_SYSTEM env var not set and --system not given.") system = args.system.lower().strip() - if system not in ("polaris", "aurora"): + if system not in ("polaris", "aurora", "crux"): _abort(f"Unsupported --system: {system!r}") device = args.device or ("xpu" if system == "aurora" else "cuda") query = args.query or agent_prompt(device=device) From c8bb17021351d6c79cf4f605ff007e575215b7b2 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 10:40:13 -0500 Subject: [PATCH 105/119] Gate MACE inline-structure transport to no-shared-filesystem backends The MACE MCP server embedded each local structure inline and had the worker re-materialise it to /tmp on every tool call, regardless of backend. This is only needed for Globus Compute, whose workers run on a remote host. For local/parsl/ensemble_launcher backends (shared FS with the server) it was pure overhead -- extra serialization, redundant disk I/O, and a full_output read-back. Add a shares_filesystem capability to ExecutionBackend (True by default; False for Globus Compute, config-overridable via the shares_filesystem kwarg). The MACE transport hook now embeds inline only when the backend does not share the filesystem; the worker already no-ops its inline branch when the key is absent, so shared-FS backends read the input path directly. 
--- src/chemgraph/execution/base.py | 11 +++++++++++ .../execution/globus_compute_backend.py | 9 +++++++++ src/chemgraph/mcp/mace_mcp_hpc.py | 18 +++++++++++++++--- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py index c182b4cf..2e5e4962 100644 --- a/src/chemgraph/execution/base.py +++ b/src/chemgraph/execution/base.py @@ -111,6 +111,17 @@ def is_async_remote(self) -> bool: retrieval tools instead of blocking until completion.""" return False + @property + def shares_filesystem(self) -> bool: + """Whether workers see the same filesystem as the submitting server. + + When ``True`` (default), a path written by the server is readable by + the worker, so file-transport tricks (inline embedding, ``/tmp`` + re-materialisation) are unnecessary. Globus Compute overrides this to + ``False`` because its workers run on a remote host without a shared + filesystem.""" + return True + @abstractmethod def initialize(self, system: str = "local", **kwargs: Any) -> None: """Prepare the backend for accepting work. 
diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py index 6c810e46..48cf8143 100644 --- a/src/chemgraph/execution/globus_compute_backend.py +++ b/src/chemgraph/execution/globus_compute_backend.py @@ -52,11 +52,16 @@ def __init__(self) -> None: super().__init__() self._executor = None self._endpoint_id: str | None = None + self._shares_filesystem = False @property def is_async_remote(self) -> bool: return True + @property + def shares_filesystem(self) -> bool: + return self._shares_filesystem + # ── lifecycle ──────────────────────────────────────────────────────── def initialize(self, system: str = "local", **kwargs: Any) -> None: @@ -82,6 +87,10 @@ def initialize(self, system: str = "local", **kwargs: Any) -> None: if amqp_port is not None: executor_kwargs["amqp_port"] = int(amqp_port) + # Opt-in: a Globus Compute endpoint that shares an HPC filesystem with + # the MCP server can skip inline file embedding and read paths directly. + self._shares_filesystem = bool(kwargs.get("shares_filesystem", False)) + self._endpoint_id = endpoint_id self._executor = Executor(**executor_kwargs) self._initialized = True diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index b18816a8..afb41bf0 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -171,9 +171,19 @@ def _normalize_model(job: dict) -> None: job["model"] = "medium-mpa-0" +def _backend_shares_fs() -> bool: + """Whether the active backend shares the server's filesystem. + + When it does, inline embedding (and the worker's ``/tmp`` round-trip) + is unnecessary -- the worker reads ``input_structure_file`` directly. 
+ Defaults to ``True`` (skip embedding) when no backend exists yet.""" + backend = getattr(mcp, "_backend", None) + return getattr(backend, "shares_filesystem", True) + + def _mace_transport_hook(task: TaskSpec) -> TaskSpec: """Route single-tool calls to the dict-based worker and embed - local structures on whichever path is taken.""" + local structures only when the backend has no shared filesystem.""" logger.debug( "mace transport hook: task_id=%s callable=%s", task.task_id, @@ -187,13 +197,15 @@ def _mace_transport_hook(task: TaskSpec) -> TaskSpec: params.model_dump() if hasattr(params, "model_dump") else dict(params) ) _normalize_model(job) - _embed_inline_if_local(job) + if not _backend_shares_fs(): + _embed_inline_if_local(job) task.callable = _mace_worker task.kwargs = {"job": job} elif task.callable is _mace_worker: job = dict(task.kwargs.get("job", {})) _normalize_model(job) - _embed_inline_if_local(job) + if not _backend_shares_fs(): + _embed_inline_if_local(job) task.kwargs = {"job": job} return task From 4aead3e9e1a1494955ab9ff5eca69cafb5ceb2b3 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 10:40:13 -0500 Subject: [PATCH 106/119] Gate MACE inline-structure transport to no-shared-filesystem backends The MACE MCP server embedded each local structure inline and had the worker re-materialise it to /tmp on every tool call, regardless of backend. This is only needed for Globus Compute, whose workers run on a remote host. For local/parsl/ensemble_launcher backends (shared FS with the server) it was pure overhead -- extra serialization, redundant disk I/O, and a full_output read-back. Add a shares_filesystem capability to ExecutionBackend (True by default; False for Globus Compute, config-overridable via the shares_filesystem kwarg). 
The MACE transport hook now embeds inline only when the backend does not share the filesystem; the worker already no-ops its inline branch when the key is absent, so shared-FS backends read the input path directly. --- src/chemgraph/execution/base.py | 11 +++++++++++ .../execution/globus_compute_backend.py | 9 +++++++++ src/chemgraph/mcp/mace_mcp_hpc.py | 18 +++++++++++++++--- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py index c182b4cf..2e5e4962 100644 --- a/src/chemgraph/execution/base.py +++ b/src/chemgraph/execution/base.py @@ -111,6 +111,17 @@ def is_async_remote(self) -> bool: retrieval tools instead of blocking until completion.""" return False + @property + def shares_filesystem(self) -> bool: + """Whether workers see the same filesystem as the submitting server. + + When ``True`` (default), a path written by the server is readable by + the worker, so file-transport tricks (inline embedding, ``/tmp`` + re-materialisation) are unnecessary. Globus Compute overrides this to + ``False`` because its workers run on a remote host without a shared + filesystem.""" + return True + @abstractmethod def initialize(self, system: str = "local", **kwargs: Any) -> None: """Prepare the backend for accepting work. 
diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py index 6c810e46..48cf8143 100644 --- a/src/chemgraph/execution/globus_compute_backend.py +++ b/src/chemgraph/execution/globus_compute_backend.py @@ -52,11 +52,16 @@ def __init__(self) -> None: super().__init__() self._executor = None self._endpoint_id: str | None = None + self._shares_filesystem = False @property def is_async_remote(self) -> bool: return True + @property + def shares_filesystem(self) -> bool: + return self._shares_filesystem + # ── lifecycle ──────────────────────────────────────────────────────── def initialize(self, system: str = "local", **kwargs: Any) -> None: @@ -82,6 +87,10 @@ def initialize(self, system: str = "local", **kwargs: Any) -> None: if amqp_port is not None: executor_kwargs["amqp_port"] = int(amqp_port) + # Opt-in: a Globus Compute endpoint that shares an HPC filesystem with + # the MCP server can skip inline file embedding and read paths directly. + self._shares_filesystem = bool(kwargs.get("shares_filesystem", False)) + self._endpoint_id = endpoint_id self._executor = Executor(**executor_kwargs) self._initialized = True diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index 6e6f8776..eee47549 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -172,9 +172,19 @@ def _normalize_model(job: dict) -> None: job["model"] = "medium-mpa-0" +def _backend_shares_fs() -> bool: + """Whether the active backend shares the server's filesystem. + + When it does, inline embedding (and the worker's ``/tmp`` round-trip) + is unnecessary -- the worker reads ``input_structure_file`` directly. 
+ Defaults to ``True`` (skip embedding) when no backend exists yet.""" + backend = getattr(mcp, "_backend", None) + return getattr(backend, "shares_filesystem", True) + + def _mace_transport_hook(task: TaskSpec) -> TaskSpec: """Route single-tool calls to the dict-based worker and embed - local structures on whichever path is taken.""" + local structures only when the backend has no shared filesystem.""" logger.debug( "mace transport hook: task_id=%s callable=%s", task.task_id, @@ -188,13 +198,15 @@ def _mace_transport_hook(task: TaskSpec) -> TaskSpec: params.model_dump() if hasattr(params, "model_dump") else dict(params) ) _normalize_model(job) - _embed_inline_if_local(job) + if not _backend_shares_fs(): + _embed_inline_if_local(job) task.callable = _mace_worker task.kwargs = {"job": job} elif task.callable is _mace_worker: job = dict(task.kwargs.get("job", {})) _normalize_model(job) - _embed_inline_if_local(job) + if not _backend_shares_fs(): + _embed_inline_if_local(job) task.kwargs = {"job": job} return task From 1c054360644a95f304be80573208780daf245be4 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 11:12:10 -0500 Subject: [PATCH 107/119] Drop full_output read-back from MACE worker The worker embedded the entire output JSON into the returned result as full_output when an inline structure was used. Results are already persisted to output_result_file, so this just bloated the tool response. Return only what run_mace_core produces; drop the now-unused json import. --- src/chemgraph/mcp/mace_mcp_hpc.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index eee47549..6e511443 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -84,7 +84,6 @@ def _mace_worker(job: dict) -> dict: attach transport keys ``inline_structure`` / ``remote_structure_file`` before submission. 
""" - import json import tempfile job = dict(job) @@ -122,15 +121,6 @@ def _mace_worker(job: dict) -> dict: params = mace_input_schema(**job) result = run_mace_core(params) - - # When inline, embed full output so the caller doesn't need to read - # a file on the remote filesystem to recover the results. - if inline is not None and isinstance(result, dict): - out_file = job.get("output_result_file", "") - if os.path.isfile(out_file): - with open(out_file) as fh: - result["full_output"] = json.load(fh) - return result From 60b4225f680b653e5bd67b75024a7bb89348bec0 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 11:12:10 -0500 Subject: [PATCH 108/119] Drop full_output read-back from MACE worker The worker embedded the entire output JSON into the returned result as full_output when an inline structure was used. Results are already persisted to output_result_file, so this just bloated the tool response. Return only what run_mace_core produces; drop the now-unused json import. --- src/chemgraph/mcp/mace_mcp_hpc.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py index afb41bf0..c9c9a860 100644 --- a/src/chemgraph/mcp/mace_mcp_hpc.py +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -83,7 +83,6 @@ def _mace_worker(job: dict) -> dict: attach transport keys ``inline_structure`` / ``remote_structure_file`` before submission. """ - import json import tempfile job = dict(job) @@ -121,15 +120,6 @@ def _mace_worker(job: dict) -> dict: params = mace_input_schema(**job) result = run_mace_core(params) - - # When inline, embed full output so the caller doesn't need to read - # a file on the remote filesystem to recover the results. 
- if inline is not None and isinstance(result, dict): - out_file = job.get("output_result_file", "") - if os.path.isfile(out_file): - with open(out_file) as fh: - result["full_output"] = json.load(fh) - return result From debdab42986607edc3f64caeada5e11c501c7654 Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Wed, 17 Jun 2026 12:19:48 -0500 Subject: [PATCH 109/119] moved el orchestrator to a subprocess --- .../execution/ensemble_launcher_backend.py | 61 +++++++++++-------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index 20336f3d..8afc80c7 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -13,6 +13,8 @@ import logging import os +import subprocess +import tempfile import time import uuid from concurrent.futures import Future @@ -163,8 +165,8 @@ def initialize( client_only: bool = False, checkpoint_dir: Optional[str] = None, node_id: str = "global", - system_config=None, - launcher_config=None, + system_config: Optional[SystemConfig] = None, + launcher_config: Optional[LauncherConfig] = None, startup_delay: float = 10.0, **kwargs, ) -> None: @@ -211,27 +213,41 @@ def initialize( "(or set client_only=True with a checkpoint_dir)." 
) os.makedirs(launcher_config.checkpoint_dir, exist_ok=True) - self._orchestrator = EnsembleLauncher( - ensemble_file={}, - system_config=system_config, - launcher_config=launcher_config, - ) - self._orchestrator.start() - time.sleep(startup_delay) + with tempfile.TemporaryDirectory() as tmp_dir: + launcher_config_fname = os.path.join(tmp_dir, "launcher_config.json") + with open(launcher_config_fname, "w") as f: + f.write(launcher_config.model_dump_json()) + system_config_fname = os.path.join(tmp_dir, "system_config.json") + with open(system_config_fname, "w") as f: + f.write(system_config.model_dump_json()) + cmd = [ + "el", + "start", + "--system-config-file", + f"{system_config_fname}", + "--launcher-config-file", + f"{launcher_config_fname}", + ] + self._orchestrator = subprocess.Popen( + cmd, + stderr=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stdin=subprocess.DEVNULL, + ) + time.sleep(startup_delay) - self._client = ClusterClient( - checkpoint_dir=launcher_config.checkpoint_dir, - node_id=node_id, - ) - self._client.start() - self._initialized = True + self._client = ClusterClient( + checkpoint_dir=launcher_config.checkpoint_dir, + node_id=node_id, + ) + self._client.start() + self._initialized = True logger.info( "EnsembleLauncherBackend initialized in managed mode " "(system='%s', comm='%s', executor='%s', nodes=%s)", system_config.name, launcher_config.comm_name, launcher_config.task_executor_name, - len(self._orchestrator.nodes), ) def submit(self, task: TaskSpec) -> Future: @@ -293,14 +309,11 @@ def shutdown(self) -> None: orchestrator_ok = True if self._orchestrator is not None: try: - self._orchestrator.stop() - self._orchestrator = None - except Exception: - orchestrator_ok = False - logger.warning( - "Error stopping EnsembleLauncher orchestrator.", - exc_info=True, - ) + self._orchestrator.terminate() + self._orchestrator.wait(timeout=10.0) + finally: + if self._orchestrator.poll() is None: + self._orchestrator.kill() if client_ok and 
orchestrator_ok: logger.info("EnsembleLauncherBackend shut down.") From b43d390101ad4f230e981029429d5896ae0766e0 Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Wed, 17 Jun 2026 12:30:19 -0500 Subject: [PATCH 110/119] added logging in el backend --- src/chemgraph/execution/ensemble_launcher_backend.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index 8afc80c7..fbb7dd3e 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -228,6 +228,7 @@ def initialize( "--launcher-config-file", f"{launcher_config_fname}", ] + logger.info(f"Executing {cmd}") self._orchestrator = subprocess.Popen( cmd, stderr=subprocess.DEVNULL, @@ -236,6 +237,12 @@ def initialize( ) time.sleep(startup_delay) + if self._orchestrator.poll() is not None: + logger.error( + f"Starting el failed with error code: {self._orchestrator.poll()}" + ) + raise RuntimeError() + self._client = ClusterClient( checkpoint_dir=launcher_config.checkpoint_dir, node_id=node_id, From 3feedcd7a937da24494f67c312294339885dc57b Mon Sep 17 00:00:00 2001 From: harikrishna1410 Date: Wed, 17 Jun 2026 18:28:11 +0000 Subject: [PATCH 111/119] added better cleanup of el subprocess --- .../demo_ensemble_launcher_in_job_agent.py | 2 +- scripts/demo/demo_parsl_in_job_agent.py | 20 +++++++++++++------ .../execution/ensemble_launcher_backend.py | 12 ++++++++++- src/chemgraph/tools/parsl_tools.py | 8 ++++++-- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/scripts/demo/demo_ensemble_launcher_in_job_agent.py b/scripts/demo/demo_ensemble_launcher_in_job_agent.py index 58b98418..9b339c8c 100644 --- a/scripts/demo/demo_ensemble_launcher_in_job_agent.py +++ b/scripts/demo/demo_ensemble_launcher_in_job_agent.py @@ -90,7 +90,7 @@ async def amain(model: str, system: str, device: str, query: str, verbose: int, structured_output=False, 
return_option="state", tools=tools, - base_url="http://127.0.0.1:12985/argoapi/v1" + base_url="http://127.0.0.1:12986/argoapi/v1" ) print("Running agent...\n" + "=" * 60) diff --git a/scripts/demo/demo_parsl_in_job_agent.py b/scripts/demo/demo_parsl_in_job_agent.py index 3d05c650..5d21388b 100644 --- a/scripts/demo/demo_parsl_in_job_agent.py +++ b/scripts/demo/demo_parsl_in_job_agent.py @@ -81,13 +81,21 @@ async def amain(model: str, system: str, device: str, query: str, verbose: int) tools = await load_mcp_tools(session) print(f"Loaded {len(tools)} MCP tools: {[t.name for t in tools]}\n") + #cg = ChemGraph( + # model_name=model, + # workflow_type="single_agent", + # structured_output=False, + # return_option="state", + # tools=tools, + #) cg = ChemGraph( - model_name=model, - workflow_type="single_agent", - structured_output=False, - return_option="state", - tools=tools, - ) + model_name="argo:gpt-5.4", + workflow_type="single_agent", + structured_output=False, + return_option="state", + tools=tools, + base_url="http://127.0.0.1:12986/argoapi/v1" + ) print("Running agent...\n" + "=" * 60) result = await cg.run(query) diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index fbb7dd3e..26ba3bd7 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -17,6 +17,7 @@ import tempfile import time import uuid +import json from concurrent.futures import Future from typing import List, Literal, Optional, Union @@ -220,9 +221,13 @@ def initialize( system_config_fname = os.path.join(tmp_dir, "system_config.json") with open(system_config_fname, "w") as f: f.write(system_config.model_dump_json()) + ensemble_fname = os.path.join(tmp_dir,"ensemble_file.json") + with open(ensemble_fname, "w") as f: + json.dump({"ensembles":{}}, f) cmd = [ "el", "start", + ensemble_fname, "--system-config-file", f"{system_config_fname}", 
"--launcher-config-file", @@ -313,10 +318,15 @@ def shutdown(self) -> None: "Error tearing down EnsembleLauncher client.", exc_info=True ) + p = subprocess.Popen(["el","stop"]) + try: + p.wait(timeout=10.0) + except Exception: + pass + orchestrator_ok = True if self._orchestrator is not None: try: - self._orchestrator.terminate() self._orchestrator.wait(timeout=10.0) finally: if self._orchestrator.poll() is None: diff --git a/src/chemgraph/tools/parsl_tools.py b/src/chemgraph/tools/parsl_tools.py index 86e43a7f..6d8bbda7 100644 --- a/src/chemgraph/tools/parsl_tools.py +++ b/src/chemgraph/tools/parsl_tools.py @@ -85,5 +85,9 @@ def extract_output_json(json_file: str) -> dict: """Load simulation results from a JSON file produced by run_ase.""" import json - with open(json_file, "r") as f: - return json.load(f) + try: + with open(json_file, "r") as f: + ret = json.load(f) + except Exception as e: + ret = {} + return ret From a06952071d21b531ec8d0ece1b233876c7df6a71 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 23:05:57 -0500 Subject: [PATCH 112/119] Guard stdout during EnsembleLauncher teardown Under a stdio MCP server, the server's stdout is the JSON-RPC channel. EnsembleLauncher prints lifecycle notices ("Sent SIGTERM to launcher process ...") to stdout during orchestrator shutdown, which corrupted the protocol stream and crashed the client's message parser with a ValidationError / BrokenResourceError after an otherwise successful run. Add a fd-level stdout->stderr redirect context manager and wrap the client/orchestrator teardown calls in shutdown() with it, so the notices go to stderr instead of the JSON-RPC channel. fd-level dup2 (matching LocalBackend's worker-stdout guard) catches library/subprocess writes, not just Python-level sys.stdout. 
--- .../execution/ensemble_launcher_backend.py | 77 +++++++++++++------ 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index 20336f3d..4262fa92 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -11,8 +11,10 @@ from __future__ import annotations +import contextlib import logging import os +import sys import time import uuid from concurrent.futures import Future @@ -53,6 +55,34 @@ def _require_ensemble_launcher() -> None: ) +@contextlib.contextmanager +def _stdout_to_stderr(): + """Redirect this process's stdout fd to stderr for the duration. + + EnsembleLauncher prints lifecycle notices (e.g. "Sent SIGTERM to + launcher process …") to stdout. Under a stdio MCP server stdout IS the + JSON-RPC channel, so those lines corrupt the protocol stream and crash + the client's message parser. Redirect at the fd level (not + ``contextlib.redirect_stdout``) so library and subprocess writes are + caught too, then restore. + """ + try: + saved_fd = os.dup(sys.stdout.fileno()) + except (OSError, ValueError, AttributeError): + # stdout is not a real fd (e.g. captured in tests/notebooks) -- + # nothing to guard. 
+ yield + return + try: + sys.stdout.flush() + os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) + yield + finally: + sys.stdout.flush() + os.dup2(saved_fd, sys.stdout.fileno()) + os.close(saved_fd) + + def get_local_system_config(): _require_ensemble_launcher() system_config = SystemConfig( @@ -279,28 +309,31 @@ def submit(self, task: TaskSpec) -> Future: def shutdown(self) -> None: self._initialized = False - client_ok = True - if self._client is not None: - try: - self._client.teardown() - self._client = None - except Exception: - client_ok = False - logger.warning( - "Error tearing down EnsembleLauncher client.", exc_info=True - ) - - orchestrator_ok = True - if self._orchestrator is not None: - try: - self._orchestrator.stop() - self._orchestrator = None - except Exception: - orchestrator_ok = False - logger.warning( - "Error stopping EnsembleLauncher orchestrator.", - exc_info=True, - ) + # EnsembleLauncher prints teardown notices to stdout; guard the + # fd so they don't corrupt a stdio MCP server's JSON-RPC channel. + with _stdout_to_stderr(): + client_ok = True + if self._client is not None: + try: + self._client.teardown() + self._client = None + except Exception: + client_ok = False + logger.warning( + "Error tearing down EnsembleLauncher client.", exc_info=True + ) + + orchestrator_ok = True + if self._orchestrator is not None: + try: + self._orchestrator.stop() + self._orchestrator = None + except Exception: + orchestrator_ok = False + logger.warning( + "Error stopping EnsembleLauncher orchestrator.", + exc_info=True, + ) if client_ok and orchestrator_ok: logger.info("EnsembleLauncherBackend shut down.") From f639aa409d4a3c74a87c12bb7e7fad141e9888b3 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 23:08:44 -0500 Subject: [PATCH 113/119] Guard stdout during EnsembleLauncher teardown Layer an fd-level stdout->stderr redirect around shutdown() on top of the subprocess-based orchestrator rework. 
The orchestrator subprocess already redirects to DEVNULL, but two teardown paths can still print to the JSON-RPC stdout under a stdio MCP server: the in-process client.teardown() and the `el stop` helper (which inherits the parent's stdout). Wrapping the whole shutdown() in the fd guard covers both, preventing the ValidationError / BrokenResourceError crash after an otherwise successful run. --- .../execution/ensemble_launcher_backend.py | 79 +++++++++++++------ 1 file changed, 57 insertions(+), 22 deletions(-) diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py index 26ba3bd7..14ec756f 100644 --- a/src/chemgraph/execution/ensemble_launcher_backend.py +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -11,9 +11,11 @@ from __future__ import annotations +import contextlib import logging import os import subprocess +import sys import tempfile import time import uuid @@ -56,6 +58,35 @@ def _require_ensemble_launcher() -> None: ) +@contextlib.contextmanager +def _stdout_to_stderr(): + """Redirect this process's stdout fd to stderr for the duration. + + EnsembleLauncher (and its ``el stop`` helper) prints lifecycle notices + such as "Sent SIGTERM to launcher process …" to stdout. Under a stdio + MCP server stdout IS the JSON-RPC channel, so those lines corrupt the + protocol stream and crash the client's message parser. Redirect at the + fd level (not ``contextlib.redirect_stdout``) so in-process, + library, and inherited-stdout subprocess writes are all caught, then + restore. + """ + try: + saved_fd = os.dup(sys.stdout.fileno()) + except (OSError, ValueError, AttributeError): + # stdout is not a real fd (e.g. captured in tests/notebooks) -- + # nothing to guard. 
+ yield + return + try: + sys.stdout.flush() + os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) + yield + finally: + sys.stdout.flush() + os.dup2(saved_fd, sys.stdout.fileno()) + os.close(saved_fd) + + def get_local_system_config(): _require_ensemble_launcher() system_config = SystemConfig( @@ -307,30 +338,34 @@ def submit(self, task: TaskSpec) -> Future: def shutdown(self) -> None: self._initialized = False - client_ok = True - if self._client is not None: - try: - self._client.teardown() - self._client = None - except Exception: - client_ok = False - logger.warning( - "Error tearing down EnsembleLauncher client.", exc_info=True - ) - - p = subprocess.Popen(["el","stop"]) - try: - p.wait(timeout=10.0) - except Exception: - pass + # EnsembleLauncher (in-process teardown and the `el stop` helper) + # prints lifecycle notices to stdout; guard the fd so they don't + # corrupt a stdio MCP server's JSON-RPC channel. + with _stdout_to_stderr(): + client_ok = True + if self._client is not None: + try: + self._client.teardown() + self._client = None + except Exception: + client_ok = False + logger.warning( + "Error tearing down EnsembleLauncher client.", exc_info=True + ) - orchestrator_ok = True - if self._orchestrator is not None: + p = subprocess.Popen(["el", "stop"]) try: - self._orchestrator.wait(timeout=10.0) - finally: - if self._orchestrator.poll() is None: - self._orchestrator.kill() + p.wait(timeout=10.0) + except Exception: + pass + + orchestrator_ok = True + if self._orchestrator is not None: + try: + self._orchestrator.wait(timeout=10.0) + finally: + if self._orchestrator.poll() is None: + self._orchestrator.kill() if client_ok and orchestrator_ok: logger.info("EnsembleLauncherBackend shut down.") From 932e4c8fc4bdfdbc586936e29f3e99af1ce42811 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 23:35:00 -0500 Subject: [PATCH 114/119] Revert hardcoded argo model/base_url in Parsl agent demo demo_parsl_in_job_agent.py forced 
model_name="argo:gpt-5.4" and a local argoapi base_url, mirroring the temp config that had also landed in demo_ensemble_launcher_in_job_agent.py. Restore model selection from the --model flag (the amain(model=...) parameter) and drop the hardcoded base_url so the demo honours the user's chosen model. --- scripts/demo/demo_parsl_in_job_agent.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/scripts/demo/demo_parsl_in_job_agent.py b/scripts/demo/demo_parsl_in_job_agent.py index 5d21388b..3d05c650 100644 --- a/scripts/demo/demo_parsl_in_job_agent.py +++ b/scripts/demo/demo_parsl_in_job_agent.py @@ -81,21 +81,13 @@ async def amain(model: str, system: str, device: str, query: str, verbose: int) tools = await load_mcp_tools(session) print(f"Loaded {len(tools)} MCP tools: {[t.name for t in tools]}\n") - #cg = ChemGraph( - # model_name=model, - # workflow_type="single_agent", - # structured_output=False, - # return_option="state", - # tools=tools, - #) cg = ChemGraph( - model_name="argo:gpt-5.4", - workflow_type="single_agent", - structured_output=False, - return_option="state", - tools=tools, - base_url="http://127.0.0.1:12986/argoapi/v1" - ) + model_name=model, + workflow_type="single_agent", + structured_output=False, + return_option="state", + tools=tools, + ) print("Running agent...\n" + "=" * 60) result = await cg.run(query) From f3d036b20ea661a2bafb6527aff25d22a7b0c165 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 23:37:40 -0500 Subject: [PATCH 115/119] Remove machine-specific Crux SystemConfig tests TestELSystemConfigCrux asserted the exact Crux SystemConfig shape (ncpus==128, CPU-only) and the registry membership of "crux". These are tied to one specific machine's hardware layout and don't belong in the portable unit-test suite. The polaris references elsewhere are left as-is since they only pass "polaris" as an arbitrary system string to exercise generic GlobusCompute behaviour. 
--- tests/test_execution.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/tests/test_execution.py b/tests/test_execution.py index dd52f415..c662547c 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -412,31 +412,6 @@ def test_shell_task_missing_command(self): backend.shutdown() -class TestELSystemConfigCrux: - """EnsembleLauncher SystemConfig builder for Crux (CPU-only).""" - - def test_crux_in_registry(self): - from chemgraph.execution.ensemble_launcher_backend import ( - SYSTEM_CONFIG_REGISTRY, - ) - - assert "crux" in SYSTEM_CONFIG_REGISTRY - - def test_crux_system_config_cpu_only(self): - pytest.importorskip("ensemble_launcher") - from chemgraph.execution.ensemble_launcher_backend import ( - get_crux_system_config, - ) - - cfg = get_crux_system_config() - assert cfg.name == "crux" - assert cfg.ncpus == 128 - assert len(cfg.cpus) == 128 - # CPU-only: ngpus / gpus must not be populated - assert getattr(cfg, "ngpus", None) in (None, 0) - assert not getattr(cfg, "gpus", None) - - # ── GlobusComputeBackend tests ────────────────────────────────────────── From 3434915dff1e6867f2748f81b4c80dc90daeee28 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Wed, 17 Jun 2026 23:39:13 -0500 Subject: [PATCH 116/119] Skip TestELBackend when ensemble_launcher is unavailable EnsembleLauncher is an optional, HPC-only dependency (not on PyPI for Python 3.12), so instantiating EnsembleLauncherBackend() raised ImportError and hard-failed all nine TestELBackend cases in any env without it. Add pytest.importorskip("ensemble_launcher") in setup_class so the class skips cleanly, matching the guard already used by the GlobusCompute tests. 
--- tests/test_execution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_execution.py b/tests/test_execution.py index c662547c..05f3d355 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -204,6 +204,9 @@ def test_shell_task_missing_command(self): class TestELBackend: @classmethod def setup_class(cls): + # EnsembleLauncher is an optional, HPC-only dependency (not on PyPI + # for Python 3.12). Skip the whole class where it isn't installed. + pytest.importorskip("ensemble_launcher") project_root = str(Path(__file__).resolve().parent.parent) existing = os.environ.get("PYTHONPATH", "") os.environ["PYTHONPATH"] = ( From b5cac8b5e7a41fd91ab88ab26a52393869f61a22 Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 18 Jun 2026 10:27:05 -0500 Subject: [PATCH 117/119] fix(academy): honour [academy] optional-dep contract via lazy re-exports The chemgraph.academy package eagerly imported ChemGraphLogicalAgent, which in turn imports academy.agent. Any consumer that touched the package -- including chemgraph.cli.trace's --trace-dir code path and pytest collection -- crashed when the optional [academy] extra was not installed, despite the rest of the package (campaign spec, prompt profile, event log) being pure stdlib + pydantic. Fix in two layers: * src/chemgraph/academy/__init__.py and src/chemgraph/academy/core/__init__.py both grow a __getattr__- based lazy resolver for ChemGraphLogicalAgent. The pure modules stay eager so their public surface (~11 symbols) is reachable on a CPU-only checkout. Accessing the lazy symbol without the extra installed raises ImportError with an actionable `pip install 'chemgraph[academy]'` hint. * tests/test_academy_*.py and tests/test_tool_adapter_validation.py (9 modules total) gain `pytest.importorskip("academy")` at the top, so they skip cleanly when the extra is absent instead of erroring at collection time. 
Validated: * Without academy-py: chemgraph.academy imports, all eager symbols resolve, cli.trace loads, lazy ChemGraphLogicalAgent access raises with hint, all 9 academy test modules skip cleanly. * With academy-py: 48/48 academy tests pass. Two-tier __init__ was necessary because pulling chemgraph.academy.core.campaign transitively runs core/__init__.py, which previously eager-imported core.agent. Both __init__s now follow the same lazy pattern. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/chemgraph/academy/__init__.py | 71 +++++++++++++++++++-- src/chemgraph/academy/core/__init__.py | 55 +++++++++++++++- tests/test_academy_campaign.py | 7 ++ tests/test_academy_compute_launcher.py | 6 ++ tests/test_academy_dashboard.py | 7 ++ tests/test_academy_dashboard_launcher.py | 3 + tests/test_academy_exchange_registration.py | 5 ++ tests/test_academy_mcp_supervisor.py | 4 ++ tests/test_academy_payloads.py | 7 ++ tests/test_academy_reasoning_phase2.py | 3 + tests/test_tool_adapter_validation.py | 4 ++ 11 files changed, 163 insertions(+), 9 deletions(-) diff --git a/src/chemgraph/academy/__init__.py b/src/chemgraph/academy/__init__.py index fd0a6c4c..46c0964b 100644 --- a/src/chemgraph/academy/__init__.py +++ b/src/chemgraph/academy/__init__.py @@ -1,14 +1,33 @@ """Academy Agents integration for ChemGraph. -Provides agent classes and utilities for deploying ChemGraph workflows -across federated HPC infrastructure using the Academy framework. +Public re-exports come in two tiers so the package honours the +``[academy]`` optional-dep contract: -Requires the ``academy`` optional extra. +* **Eager** (pure stdlib + pydantic, always importable): + ``ChemGraphAgentSpec``, ``ChemGraphCampaign``, + ``ChemGraphDaemonConfig``, ``MCPServerSpec``, ``ResourceSpec``, + ``load_campaign``, ``resolve_campaign_resources``, + ``PromptProfile``, ``load_prompt_profile``, ``CampaignEvent``, + ``EventLog``. 
These let the dashboard, ``--trace-dir``, and the + observability tooling work on a checkout without ``academy-py`` + installed. +* **Lazy** (resolved via ``__getattr__`` on first access; requires + the ``[academy]`` extra): ``ChemGraphLogicalAgent``. Importing it + pulls in ``academy.agent``; without the extra installed, access + raises ``ImportError`` with a hint instead of crashing the package + import. + +This split exists because ``chemgraph.cli.trace`` (single-agent +``--trace-dir`` flow) and the test collector both touch +``chemgraph.academy`` via leaf submodules; eager-importing the +academy-py-dependent ``ChemGraphLogicalAgent`` here broke those code +paths for users without the optional extra. """ from __future__ import annotations -from chemgraph.academy.core.agent import ChemGraphLogicalAgent +from typing import TYPE_CHECKING, Any + from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.academy.core.campaign import ChemGraphCampaign from chemgraph.academy.core.campaign import ChemGraphDaemonConfig @@ -16,10 +35,48 @@ from chemgraph.academy.core.campaign import ResourceSpec from chemgraph.academy.core.campaign import load_campaign from chemgraph.academy.core.campaign import resolve_campaign_resources -from chemgraph.academy.observability.event_log import CampaignEvent -from chemgraph.academy.observability.event_log import EventLog from chemgraph.academy.core.prompt import PromptProfile from chemgraph.academy.core.prompt import load_prompt_profile +from chemgraph.academy.observability.event_log import CampaignEvent +from chemgraph.academy.observability.event_log import EventLog + + +if TYPE_CHECKING: + from chemgraph.academy.core.agent import ChemGraphLogicalAgent + + +_LAZY_EXPORTS: dict[str, tuple[str, str]] = { + # public name -> (module path, attribute in that module) + "ChemGraphLogicalAgent": ( + "chemgraph.academy.core.agent", + "ChemGraphLogicalAgent", + ), +} + + +def __getattr__(name: str) -> Any: + """Lazy resolver for 
academy-py-dependent re-exports. + + Called by Python only when ``name`` is not found among the eager + imports above. On ``ImportError`` we re-raise with an actionable + hint so the operator knows which extra to install. + """ + if name in _LAZY_EXPORTS: + module_path, attr = _LAZY_EXPORTS[name] + try: + from importlib import import_module + module = import_module(module_path) + except ImportError as exc: + raise ImportError( + f"Importing {name!r} from chemgraph.academy requires " + f"the 'academy' optional extra: " + f"`pip install 'chemgraph[academy]'`. " + f"Underlying error: {exc}" + ) from exc + return getattr(module, attr) + raise AttributeError( + f"module 'chemgraph.academy' has no attribute {name!r}" + ) __all__ = [ @@ -27,11 +84,11 @@ "ChemGraphAgentSpec", "ChemGraphCampaign", "ChemGraphDaemonConfig", + "ChemGraphLogicalAgent", "EventLog", "MCPServerSpec", "PromptProfile", "ResourceSpec", - "ChemGraphLogicalAgent", "load_campaign", "load_prompt_profile", "resolve_campaign_resources", diff --git a/src/chemgraph/academy/core/__init__.py b/src/chemgraph/academy/core/__init__.py index b47c1c18..5f7248dd 100644 --- a/src/chemgraph/academy/core/__init__.py +++ b/src/chemgraph/academy/core/__init__.py @@ -1,6 +1,25 @@ -"""Core ChemGraph Academy campaign contracts and agent logic.""" +"""Core ChemGraph Academy campaign contracts and agent logic. + +Re-exports split into two tiers to keep the ``[academy]`` optional-dep +contract: + +* **Eager** (pure stdlib + pydantic + langchain_core): the campaign + spec types, prompt profile, and reasoning-turn helpers. These are + what the dashboard, ``--trace-dir``, and the test collector touch + on a CPU-only checkout. +* **Lazy** (resolved via ``__getattr__``; requires the ``[academy]`` + extra because it depends on ``academy.agent.Agent``): + ``ChemGraphLogicalAgent``. 
+ +Without this split, importing ``chemgraph.academy.core.campaign`` +would transitively run ``core/__init__.py`` and pull in +``core.agent``, which fails when ``academy-py`` is not installed. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any -from chemgraph.academy.core.agent import ChemGraphLogicalAgent from chemgraph.academy.core.campaign import ChemGraphAgentSpec from chemgraph.academy.core.campaign import ChemGraphCampaign from chemgraph.academy.core.campaign import ChemGraphDaemonConfig @@ -13,6 +32,38 @@ from chemgraph.academy.core.turn import ReasoningTurnResult from chemgraph.academy.core.turn import run_academy_turn + +if TYPE_CHECKING: + from chemgraph.academy.core.agent import ChemGraphLogicalAgent + + +_LAZY_EXPORTS: dict[str, tuple[str, str]] = { + "ChemGraphLogicalAgent": ( + "chemgraph.academy.core.agent", + "ChemGraphLogicalAgent", + ), +} + + +def __getattr__(name: str) -> Any: + if name in _LAZY_EXPORTS: + module_path, attr = _LAZY_EXPORTS[name] + try: + from importlib import import_module + module = import_module(module_path) + except ImportError as exc: + raise ImportError( + f"Importing {name!r} from chemgraph.academy.core requires " + f"the 'academy' optional extra: " + f"`pip install 'chemgraph[academy]'`. " + f"Underlying error: {exc}" + ) from exc + return getattr(module, attr) + raise AttributeError( + f"module 'chemgraph.academy.core' has no attribute {name!r}" + ) + + __all__ = [ "ChemGraphAgentSpec", "ChemGraphCampaign", diff --git a/tests/test_academy_campaign.py b/tests/test_academy_campaign.py index ec102fb7..1fa2956e 100644 --- a/tests/test_academy_campaign.py +++ b/tests/test_academy_campaign.py @@ -4,6 +4,13 @@ import pytest +# Skip the whole module when the optional 'academy' extra is absent. 
+# Even though this file only touches the pure-stdlib parts of +# chemgraph.academy, the import guard is applied uniformly across the +# academy test suite so pytest collection stays clean on a CPU-only +# checkout without per-test bookkeeping. +pytest.importorskip("academy") + from chemgraph.academy.core.campaign import campaign_bootstrap_text from chemgraph.academy.core.campaign import load_campaign from chemgraph.academy.core.campaign import MCPServerSpec diff --git a/tests/test_academy_compute_launcher.py b/tests/test_academy_compute_launcher.py index 20b04ea8..d73098f3 100644 --- a/tests/test_academy_compute_launcher.py +++ b/tests/test_academy_compute_launcher.py @@ -2,6 +2,12 @@ from pathlib import Path +import pytest + +# Skip when the optional 'academy' extra is absent; the runtime +# subpackage imports academy.* at module level. +pytest.importorskip("academy") + from chemgraph.academy.runtime import compute_launcher from chemgraph.academy.runtime.compute_launcher import AllocationPlan diff --git a/tests/test_academy_dashboard.py b/tests/test_academy_dashboard.py index 6f62f37b..0abec32c 100644 --- a/tests/test_academy_dashboard.py +++ b/tests/test_academy_dashboard.py @@ -2,6 +2,13 @@ import json +import pytest + +# Skip when the optional 'academy' extra is absent. The dashboard +# module itself is pure stdlib, but the import guard is applied +# uniformly across the academy test suite. +pytest.importorskip("academy") + import chemgraph.academy.dashboard as dashboard from chemgraph.academy.observability.event_log import EventLog diff --git a/tests/test_academy_dashboard_launcher.py b/tests/test_academy_dashboard_launcher.py index 7836ef03..51dc1a96 100644 --- a/tests/test_academy_dashboard_launcher.py +++ b/tests/test_academy_dashboard_launcher.py @@ -7,6 +7,9 @@ import pytest +# Skip when the optional 'academy' extra is absent. 
+pytest.importorskip("academy") + from chemgraph.academy.runtime import dashboard_launcher from chemgraph.academy.runtime.profiles.system import SystemProfile diff --git a/tests/test_academy_exchange_registration.py b/tests/test_academy_exchange_registration.py index 39aa1509..0c1f95cb 100644 --- a/tests/test_academy_exchange_registration.py +++ b/tests/test_academy_exchange_registration.py @@ -3,6 +3,11 @@ from pathlib import Path import pytest + +# Skip when the optional 'academy' extra is absent; this module +# imports academy.exchange.* directly at top level. +pytest.importorskip("academy") + from academy.exchange.hybrid import HybridAgentRegistration from academy.exchange.local import LocalAgentRegistration from academy.exchange.redis import RedisAgentRegistration diff --git a/tests/test_academy_mcp_supervisor.py b/tests/test_academy_mcp_supervisor.py index 10e0c6fe..d42920a4 100644 --- a/tests/test_academy_mcp_supervisor.py +++ b/tests/test_academy_mcp_supervisor.py @@ -6,6 +6,10 @@ import pytest +# Skip when the optional 'academy' extra is absent; mcp_supervisor +# imports httpx (also in the extra) at module level. +pytest.importorskip("academy") + from chemgraph.academy.core.campaign import MCPServerSpec from chemgraph.academy.runtime.mcp_supervisor import MCPServerSupervisor diff --git a/tests/test_academy_payloads.py b/tests/test_academy_payloads.py index 8e9f82c8..b8c1d04c 100644 --- a/tests/test_academy_payloads.py +++ b/tests/test_academy_payloads.py @@ -1,5 +1,12 @@ from __future__ import annotations +import pytest + +# Skip when the optional 'academy' extra is absent. The event_log +# module itself is pure stdlib, but the import guard is applied +# uniformly across the academy test suite. 
+pytest.importorskip("academy") + from chemgraph.academy.observability.event_log import EventLog, read_events diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py index bbe13a50..1798679c 100644 --- a/tests/test_academy_reasoning_phase2.py +++ b/tests/test_academy_reasoning_phase2.py @@ -8,6 +8,9 @@ import pytest +# Skip when the optional 'academy' extra is absent. +pytest.importorskip("academy") + from chemgraph.academy.core import agent as agent_module from chemgraph.academy.core import turn as turn_module from chemgraph.academy.core.agent import ChemGraphLogicalAgent diff --git a/tests/test_tool_adapter_validation.py b/tests/test_tool_adapter_validation.py index 47dfed6a..8abf5a69 100644 --- a/tests/test_tool_adapter_validation.py +++ b/tests/test_tool_adapter_validation.py @@ -5,6 +5,10 @@ import pytest +# Skip when the optional 'academy' extra is absent; core.tools imports +# academy.agent at module level. +pytest.importorskip("academy") + from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools from chemgraph.academy.core.campaign import ChemGraphAgentSpec From d293cea9e559930cc68c77809c6dc0024470b98b Mon Sep 17 00:00:00 2001 From: Jinchu Li Date: Thu, 18 Jun 2026 10:27:20 -0500 Subject: [PATCH 118/119] fix(agent): restore calculator-context wiring lost in revert f1593ab f1593ab ("revert: restore llm_agent.py to pre-academy shape") was meant to drop temporary event-callback wiring and turn-primitive imports, but its rewrite also dropped the calculator-availability injection that was on the file at the dev-globus fork point. The deletion was a merge-conflict casualty, not an intentional removal: get_calculator_selection_context() in schemas/ase_input.py:145 was left as dead code, and the two regression tests test_single_agent_initialization_injects_calculator_availability (in tests/test_graph_constructors.py and tests/test_graphs.py) have been failing on dev-globus ever since. 
Restore the 28-line wiring block verbatim from origin/dev:src/chemgraph/agent/llm_agent.py:464-490, plus the matching `from chemgraph.schemas.ase_input import get_available_calculator_names, get_calculator_selection_context, get_default_calculator_name` import at the top. Position in __init__: after the human-supervised system_prompt adjustment (line 324) and before the structured-output gate (line 326). All three instance attributes the block references (self.workflow_type, self.planner_prompt, self.executor_prompt) are set earlier in __init__ (lines 299, 308, 309), so the wiring order is safe. Both calculator-injection regression tests now pass; the 48-test academy suite is unaffected. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/chemgraph/agent/llm_agent.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/chemgraph/agent/llm_agent.py b/src/chemgraph/agent/llm_agent.py index f736889d..f047cb14 100644 --- a/src/chemgraph/agent/llm_agent.py +++ b/src/chemgraph/agent/llm_agent.py @@ -24,6 +24,11 @@ ) +from chemgraph.schemas.ase_input import ( + get_available_calculator_names, + get_calculator_selection_context, + get_default_calculator_name, +) from chemgraph.prompt.single_agent_prompt import ( single_agent_prompt, get_single_agent_prompt, @@ -318,6 +323,33 @@ def __init__( if not self.human_supervised and self.system_prompt == single_agent_prompt: self.system_prompt = get_single_agent_prompt(human_supervised=False) + self.available_calculators = get_available_calculator_names() + self.default_calculator = get_default_calculator_name() + self.calculator_selection_context = get_calculator_selection_context() + + def append_calculator_context(prompt: str) -> str: + """Append calculator availability guidance to a prompt once. + + Parameters + ---------- + prompt : str + Prompt text to augment. + + Returns + ------- + str + Prompt with calculator-selection context appended. 
+ """ + if self.calculator_selection_context in prompt: + return prompt + return f"{prompt}{self.calculator_selection_context}" + + if self.workflow_type in {"single_agent", "mock_agent", "single_agent_mcp"}: + self.system_prompt = append_calculator_context(self.system_prompt) + elif self.workflow_type == "multi_agent": + self.planner_prompt = append_calculator_context(self.planner_prompt) + self.executor_prompt = append_calculator_context(self.executor_prompt) + if model_name in supported_argo_models: self.support_structured_output = False else: From 893285d55a6c7f3be236e97b20d84d751a099738 Mon Sep 17 00:00:00 2001 From: Thang Pham Date: Thu, 18 Jun 2026 07:57:35 -0500 Subject: [PATCH 119/119] Remove dead single_agent_architector workflow and fix pre-merge blockers Prepare PR #120 for merge by dropping a non-functional workflow and fixing event/HPC-pickle regressions. The academy optional-dependency and calculator-context fixes originally bundled here are already on dev-globus (b5cac8b, d293cea), so this commit no longer carries them. - Remove single_agent_architector workflow: it was not registered in workflow_map and required a tools module that does not exist. Drop the graph, the CLI entry, and the related test parameters (including a stale graspa_agent case). - events: only emit llm_decision when the model requested tool calls; a plain text answer has no decision to report. - mcp: add explicit pickle-by-reference fix for the gRASPA _ls_remote_files TaskSpec callable; document that decorator-registered worker callables are already fixed by schema_fanout_tool. - tests: update the MACE worker test for the dropped full_output field; note in conftest that academy-only modules self-skip via importorskip. 
--- src/chemgraph/agent/events.py | 9 +- src/chemgraph/cli/commands.py | 1 - .../graphs/single_agent_architector.py | 162 ------------------ src/chemgraph/mcp/graspa_mcp_hpc.py | 10 ++ src/chemgraph/mcp/xanes_mcp_hpc.py | 4 + tests/conftest.py | 4 + tests/test_graphs.py | 2 - tests/test_mcp.py | 4 +- 8 files changed, 26 insertions(+), 170 deletions(-) delete mode 100644 src/chemgraph/graphs/single_agent_architector.py diff --git a/src/chemgraph/agent/events.py b/src/chemgraph/agent/events.py index 88090877..1c3c2bf7 100644 --- a/src/chemgraph/agent/events.py +++ b/src/chemgraph/agent/events.py @@ -67,10 +67,11 @@ def on_llm_end(self, response, **kwargs) -> None: if isinstance(usage, dict): payload["llm_output"] = usage self._emit("llm_call_finished", payload) - self._emit( - "llm_decision", - {"tool_calls": _response_tool_calls(response) or []}, - ) + # Only surface an llm_decision when the model actually requested tool + # calls; a plain text answer has no decision to report. + tool_calls = _response_tool_calls(response) + if tool_calls: + self._emit("llm_decision", {"tool_calls": tool_calls}) def on_llm_error(self, error, **kwargs) -> None: self._emit("llm_call_failed", {"error": repr(error)}) diff --git a/src/chemgraph/cli/commands.py b/src/chemgraph/cli/commands.py index ad351e61..70934046 100644 --- a/src/chemgraph/cli/commands.py +++ b/src/chemgraph/cli/commands.py @@ -48,7 +48,6 @@ "graspa_mcp", "rag_agent", "single_agent_xanes", - "single_agent_architector", ] # Common aliases so users can type the "obvious" name. 
diff --git a/src/chemgraph/graphs/single_agent_architector.py b/src/chemgraph/graphs/single_agent_architector.py deleted file mode 100644 index 758ed5e6..00000000 --- a/src/chemgraph/graphs/single_agent_architector.py +++ /dev/null @@ -1,162 +0,0 @@ -from langgraph.graph import StateGraph, START, END -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import ToolNode -from chemgraph.tools.cheminformatics_tools import ( - molecule_name_to_smiles, - smiles_to_coordinate_file, -) - -try: - from chemgraph.tools.architector_tools import ( - visualize_molecule, - image_to_connection_points, - build_metal_complex, - ) -except ModuleNotFoundError: - def _missing_architector_tool(*_args, **_kwargs): - raise ImportError( - "single_agent_architector requires chemgraph.tools.architector_tools, " - "which is not available in this installation." - ) - - def visualize_molecule(smiles: str) -> str: - """Visualize a molecule for Architector workflows.""" - return _missing_architector_tool(smiles) - - def image_to_connection_points(image_path: str) -> str: - """Extract connection points from an image for Architector workflows.""" - return _missing_architector_tool(image_path) - - def build_metal_complex(specification: str) -> str: - """Build a metal complex for Architector workflows.""" - return _missing_architector_tool(specification) -from chemgraph.utils.logging_config import setup_logger -from chemgraph.state.state import State - -logger = setup_logger(__name__) - -single_agent_prompt = "" - -def route_tools(state: State): - """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. 
- - Parameters - ---------- - state : State - The current state containing messages and remaining steps - - Returns - ------- - str - Either 'tools' or 'done' based on the state conditions - """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state to tool_edge: {state}") - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - return "tools" - return "done" - - -def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): - """LLM node that processes messages and decides next actions. - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for processing - system_prompt : str - The system prompt to guide the LLM's behavior - tools : list, optional - List of tools available to the agent, by default None - - Returns - ------- - dict - Updated state containing the LLM's response - """ - - # Load default tools if no tool is specified. - if tools is None: - tools = [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - visualize_molecule, - image_to_connection_points, - build_metal_complex - ] - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - -def construct_single_agent_architector_graph( - llm: ChatOpenAI, - system_prompt: str = "", - tools: list = None, -): - """Construct a geometry optimization graph. 
- - Parameters - ---------- - llm : ChatOpenAI - The language model to use for the graph - system_prompt : str, optional - The system prompt to guide the LLM's behavior, by default single_agent_prompt - structured_output : bool, optional - Whether to use structured output, by default False - formatter_prompt : str, optional - The prompt to guide the LLM's formatting behavior, by default formatter_prompt - generate_report: bool, optional - Whether to generate a report, by default False - report_prompt: str, optional - The prompt to guide the LLM's report generation behavior, by default report_prompt - tool: list, optional - The list of tools for the main agent, by default None - Returns - ------- - StateGraph - The constructed single agent graph - """ - try: - logger.info("Constructing single agent graph") - checkpointer = MemorySaver() - if tools is None: - tools = [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - visualize_molecule, - image_to_connection_points, - build_metal_complex - ] - tool_node = ToolNode(tools=tools) - graph_builder = StateGraph(State) - - graph_builder.add_node( - "ChemGraphAgent", - lambda state: ChemGraphAgent(state, llm, system_prompt=system_prompt, tools=tools), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_edge(START, "ChemGraphAgent") - graph_builder.add_conditional_edges( - "ChemGraphAgent", - route_tools, - {"tools": "tools", "done": END}, - ) - graph_builder.add_edge("tools", "ChemGraphAgent") - graph_builder.add_edge("ChemGraphAgent", END) - - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("Graph construction completed") - return graph - except Exception as e: - logger.error(f"Error constructing graph: {str(e)}") - raise diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py index be7737a6..af33f792 100644 --- a/src/chemgraph/mcp/graspa_mcp_hpc.py +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -107,6 +107,10 @@ def _graspa_worker(job: dict) -> 
dict: } +# Note: ``_graspa_worker`` is registered via ``@mcp.schema_fanout_tool`` below, +# which fixes its module for pickling automatically; no explicit fix is needed. + + # ── Ensemble fanout ──────────────────────────────────────────────────── @@ -117,6 +121,12 @@ def _ls_remote_files(path: str) -> list[str]: ) +# Submitted as a bare ``callable=`` TaskSpec (not via a decorator), so it must +# be fixed explicitly for pickle-by-reference when run as ``__main__``. Mirrors +# the equivalent fix in mace_mcp_hpc.py. +CGFastMCP._fix_module_for_pickle(_ls_remote_files) + + def _expand_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: """Server-side expansion of an ensemble request into per-job dicts. diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py index 8583ae65..0c3008b5 100644 --- a/src/chemgraph/mcp/xanes_mcp_hpc.py +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -156,6 +156,10 @@ def _xanes_ensemble_worker(item: dict) -> dict: } +# Note: ``_xanes_ensemble_worker`` is registered via ``@mcp.schema_fanout_tool`` +# below, which fixes its module for pickling automatically. + + def _expand_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: """Server-side expansion: prepare per-structure run dirs and return one item per structure for the worker to execute.""" diff --git a/tests/conftest.py b/tests/conftest.py index 0de3313d..76b425d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,10 @@ # Configure pytest-asyncio #pytest_plugins = ("pytest_asyncio",) +# Test modules that require the optional ``academy`` extra guard themselves with +# ``pytest.importorskip("academy")`` at module top, so they skip cleanly (rather +# than erroring collection) when the extra is not installed. 
+ @pytest.fixture(autouse=True) def setup_test_env(): diff --git a/tests/test_graphs.py b/tests/test_graphs.py index d0ef2d53..f3a8fca6 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -37,7 +37,6 @@ def get_state(self, config): ("multi_agent", "construct_multi_agent_graph", {}), ("python_relp", "construct_relp_graph", {}), ("graspa", "construct_graspa_graph", {}), - ("graspa_agent", "construct_graspa_graph", {}), ("mock_agent", "construct_mock_agent_graph", {}), ( "single_agent_mcp", @@ -51,7 +50,6 @@ def get_state(self, config): ), ("rag_agent", "construct_rag_agent_graph", {}), ("single_agent_xanes", "construct_single_agent_xanes_graph", {}), - ("single_agent_architector", "construct_single_agent_architector_graph", {}), ], ) def test_graph_constructor_is_called( diff --git a/tests/test_mcp.py b/tests/test_mcp.py index ff3dbfbe..86d32558 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -72,8 +72,10 @@ def fake_run_mace_core(params): } ) + # The worker returns run_mace_core's result verbatim; full_output read-back + # was intentionally dropped. The inline output parent dir is asserted inside + # fake_run_mace_core above. assert result["status"] == "success" - assert result["full_output"] == {"ok": True} def test_run_ase_core_creates_output_parent_directory(monkeypatch, tmp_path):