diff --git a/.gitignore b/.gitignore index 307e2461..91e70e1a 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,8 @@ opencode.json chemgraph_mcp_logs/ vllm/ logs/ +runs/ +**/*.model error_log.txt .env test.csv diff --git a/examples/academy/example-002-mace-ensemble-screening/README.md b/examples/academy/example-002-mace-ensemble-screening/README.md new file mode 100644 index 00000000..bd6f7104 --- /dev/null +++ b/examples/academy/example-002-mace-ensemble-screening/README.md @@ -0,0 +1,35 @@ +# Example 002: MACE Ensemble Screening + +This example demonstrates five persistent ChemGraph Academy logical agents +running under MPI: + +```text +coordinator-agent +structure-agent-a +structure-agent-b +mace-agent +assessment-agent +``` + +The coordinator delegates 20 SMILES candidates, structure agents generate XYZ +files, the MACE agent runs an ensemble energy screen, and the assessment agent +summarizes readiness/ranking evidence. + +The campaign assets are packaged under: + +```text +src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/ +``` + +Run it by campaign name: + +```bash +chemgraph academy run-compute \ + --system aurora \ + --run-id aurora-mace-ensemble-screening-001 \ + --campaign mace-ensemble-screening-20 \ + --lm-user +``` + +See `notes.md` for the high-level architecture notes. The internal E2E user +guide is intentionally not stored in this public example directory. diff --git a/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md new file mode 100644 index 00000000..08031016 --- /dev/null +++ b/examples/academy/example-002-mace-ensemble-screening/e2e_guide.md @@ -0,0 +1,316 @@ +# Example 002 E2E Guide + +This guide runs the `mace-ensemble-screening-20` ChemGraph Academy campaign on +Aurora or Polaris. 
The campaign starts five persistent logical agents under MPI: + +```text +coordinator-agent +structure-agent-a +structure-agent-b +mace-agent +assessment-agent +``` + +The coordinator delegates 20 SMILES candidates, structure agents generate XYZ +files, the MACE agent runs an ensemble energy screen, and the assessment agent +summarizes readiness/ranking evidence. + +## About The MACE Path + +This example deliberately runs MACE through the general `run_ase` tool +(`chemgraph.mcp.mcp_tools`), which executes MACE in-process inside the MCP +server. It does **not** exercise `chemgraph.mcp.mace_mcp_hpc` or the +Parsl/EnsembleLauncher/Globus Compute backends — those are being reworked in +a separate PR. Once that lands and the WorkerLost subprocess fix is folded +back in, this example can be switched back to the HPC MACE path. + +In-process MACE means each per-structure energy evaluation runs synchronously +in the mace-agent's MCP server process. A 20-structure screen completes in +a few minutes on CPU. + +## Configure Paths + +Set these values in each terminal before copying the commands below: + +```bash +export ALCF_PROJECT= +export ALCF_USER= +export ALCF_LOGIN= +export ARGO_USER= + +export LOCAL_CHEMGRAPH= +``` + +For Aurora: + +```bash +export ALCF_SYSTEM=aurora +export ALCF_HOST=aurora.alcf.anl.gov +export REMOTE_ROOT=/flare/$ALCF_PROJECT/$ALCF_USER +``` + +For Polaris: + +```bash +export ALCF_SYSTEM=polaris +export ALCF_HOST=polaris.alcf.anl.gov +export REMOTE_ROOT=/eagle/$ALCF_PROJECT/$ALCF_USER +``` + +`ALCF_USER` is the shared-filesystem path component. It may differ from the SSH +login and from the Argo user. 
+ +## One-Time Setup + +Sync ChemGraph: + +```bash +cd "$LOCAL_CHEMGRAPH" + +rsync -az --delete --delete-excluded \ + --exclude '.git/' \ + --exclude '__pycache__/' \ + --exclude '.pytest_cache/' \ + --exclude 'runs/' \ + --exclude 'venvs/' \ + --exclude '*.pyc' \ + ./ \ + "$ALCF_LOGIN@$ALCF_HOST:$REMOTE_ROOT/ChemGraph/" +``` + +Install ChemGraph dependencies on the remote system: + +```bash +ssh "$ALCF_LOGIN@$ALCF_HOST" +cd "$REMOTE_ROOT/ChemGraph" + +# Aurora: +module load frameworks + +# Polaris: +# module use /soft/modulefiles +# module load conda +# conda activate base + +source "$REMOTE_ROOT/venvs/academy-swarm/bin/activate" +python -m pip install -e ".[academy]" +``` + +Verify the campaign is visible: + +```bash +PYTHONDONTWRITEBYTECODE=1 PYTHONPATH=src \ +python -m chemgraph.cli.main academy campaigns +``` + +Expected: + +```text +mace-ensemble-screening-20 +``` + +Verify Redis: + +```bash +export PATH="$REMOTE_ROOT/tools/redis/bin:$PATH" +command -v redis-server +redis-server --version +``` + +If Redis is missing, build it once on a login/UAN node: + +```bash +cd "$REMOTE_ROOT" +mkdir -p src tools +cd src +test -d redis || git clone --depth 1 https://github.com/redis/redis.git +cd redis +make -j4 +make PREFIX="$REMOTE_ROOT/tools/redis" install +``` + +The `mace_mp` calculator downloads its foundation model on first use into +`~/.cache/mace`, so no manual MACE-model staging is needed for this example. +First-call download can take a minute; pre-warm it once on the compute node +to skip that wait at run time. 
The compute node only reaches external sites +through the ALCF outbound proxy, so set the proxy env vars first: + +```bash +export http_proxy="http://proxy.alcf.anl.gov:3128" +export https_proxy="http://proxy.alcf.anl.gov:3128" +python -c "from mace.calculators import mace_mp; mace_mp(model='medium-mpa-0', device='cpu')" +``` + +## Start argo-shim + +On the local machine: + +```bash +CELS_USERNAME="$ARGO_USER" \ +PYTHONPATH= \ +python -m argo_shim --no-auth --no-update-settings --port 18085 +``` + +## Start Dashboard + +Use a fresh run id: + +```bash +cd "$LOCAL_CHEMGRAPH" + +export RUN_ID="${ALCF_SYSTEM}-mace-ensemble-screening-001" + +PYTHONPATH=src python -m chemgraph.cli.main academy dashboard -- \ + --system "$ALCF_SYSTEM" \ + --remote-host "$ALCF_LOGIN@$ALCF_HOST" \ + --campaign mace-ensemble-screening-20 \ + --lm-connect mac-argo-relay \ + "$RUN_ID" +``` + +The dashboard command starts the local dashboard, an rsync mirror, an SSH +control connection, and a relay from compute nodes to local `argo-shim`. + +## Start The Campaign On Compute + +Run inside an interactive allocation: + +```bash +cd "$REMOTE_ROOT/ChemGraph" + +# Aurora: +module load frameworks + +# Polaris: +# module use /soft/modulefiles +# module load conda +# conda activate base + +source "$REMOTE_ROOT/venvs/academy-swarm/bin/activate" + +export RUN_ID="${ALCF_SYSTEM}-mace-ensemble-screening-001" + +export NUMEXPR_MAX_THREADS=256 +export NUMEXPR_NUM_THREADS=64 +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 + +# Aurora/Polaris compute nodes reach external sites (GitHub, S3) only +# through the ALCF outbound proxy. Without these, mace_mp(model="medium-mpa-0") +# hangs trying to fetch the foundation model on first use. 
+export http_proxy="http://proxy.alcf.anl.gov:3128" +export https_proxy="http://proxy.alcf.anl.gov:3128" +export no_proxy="localhost,127.0.0.1" + +export PATH="$REMOTE_ROOT/bin:$REMOTE_ROOT/tools/redis/bin:$PATH" + +chemgraph academy run-compute \ + --system "$ALCF_SYSTEM" \ + --run-id "$RUN_ID" \ + --campaign mace-ensemble-screening-20 \ + --lm-user "$ARGO_USER" +``` + +If the wrapper is installed but `chemgraph` is not on `PATH`, use: + +```bash +chemgraph-academy-run \ + --system "$ALCF_SYSTEM" \ + --run-id "$RUN_ID" \ + --campaign mace-ensemble-screening-20 \ + --lm-user "$ARGO_USER" +``` + +## Reopen A Local Dashboard + +Once the run has been synced locally: + +```bash +cd "$LOCAL_CHEMGRAPH" + +PYTHONPATH=src python -m chemgraph.cli.main academy dashboard -- \ + --system "$ALCF_SYSTEM" \ + --remote-host "$ALCF_LOGIN@$ALCF_HOST" \ + --campaign mace-ensemble-screening-20 \ + "$RUN_ID" \ + --local +``` + +## Dashboard For Traditional ChemGraph Runs + +The dashboard also renders single-agent ChemGraph runs that were not launched +through Academy. Pass `--trace-dir ` to `chemgraph run` to write the +events the dashboard needs (`events.jsonl`, `status.json`, `manifest.json`), +then point the dashboard at that directory. + +On-site at ANL, the simplest path is the built-in Argo support — no shim or +relay needed (set `ARGO_USER` once per shell, or in your shell profile): + +```bash +export ARGO_USER="$ARGO_USER" + +chemgraph run \ + -q "What is the SMILES for water" \ + -m "argo:gpt-5.4" \ + --trace-dir ./run-001 +``` + +Then serve the trace directory: + +```bash +chemgraph dashboard -- --run-dir ./run-001 --port 8765 +# Open http://127.0.0.1:8765 +``` + +The browser shows the same per-agent workflow inspector that Academy displays +for a logical-agent node (query → LLM call → tool calls → output), but at the +top level since the run only has one agent. Use a fresh `--trace-dir` per run +so multiple runs don't pile into one `events.jsonl`. 
+ +`--trace-dir` is currently only effective for the `single_agent` workflow. +Other workflows (`multi_agent`, `python_relp`, `graspa`, `rag_agent`, +`single_agent_xanes`, ...) run normally but don't yet emit dashboard events, +and the CLI prints a yellow warning for those. + +If the browser shows "Waiting for ChemGraph workflow execution events" after a +run completed successfully, the remote checkout is missing the +`llm_decision`-on-every-LLM-call fix. Sync the latest ChemGraph and clear +stale bytecode locally: + +```bash +find src/chemgraph -name __pycache__ -type d -exec rm -rf {} + +``` + +## Troubleshooting + +Check the relay from compute: + +```bash +UAN_RELAY_HOST="$(tr -d '[:space:]' < "$REMOTE_ROOT/uan-relay-18186.host")" +curl --noproxy '*' -I "http://${UAN_RELAY_HOST}:18186/v1/models" +``` + +Expected: + +```text +HTTP/1.1 200 OK +``` + +If the first model response is an Argo access-denied notice for ``, +the compute command was launched without `--lm-user "$ARGO_USER"`. Use a fresh +run id, or restart the dashboard with `--overwrite-run`, then rerun compute +with `--lm-user`. + +If imports are slow or NumExpr complains, set: + +```bash +export NUMEXPR_MAX_THREADS=256 +export NUMEXPR_NUM_THREADS=64 +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +``` + +If MACE energy evaluations are slow, the first call per worker pays a +one-time foundation-model download into `~/.cache/mace`. Pre-warm by +running the `mace_mp` snippet under "One-Time Setup" above on the compute node +before launching the campaign. diff --git a/examples/academy/example-002-mace-ensemble-screening/notes.md b/examples/academy/example-002-mace-ensemble-screening/notes.md new file mode 100644 index 00000000..4bd81cf6 --- /dev/null +++ b/examples/academy/example-002-mace-ensemble-screening/notes.md @@ -0,0 +1,29 @@ +# Notes + +This root example directory is for user-facing explanation only.
The CLI loads +the actual campaign from package data so installed ChemGraph environments can +run the same campaign without relying on a source checkout's root `examples/` +directory. + +Packaged assets: + +```text +src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/ + campaign.jsonc + lm_config.json + prompt_profiles/ + data/ + models/ +``` + +The campaign declares MCP server subprocesses for general ChemGraph tools, MACE +screening, and HPC utility inspection. The Academy runtime places one logical +agent per MPI rank, launches the declared MCP servers for each agent, and uses +Academy exchange handles for peer communication. + +Each agent's `allowed_tools` field acts as a per-agent whitelist drawn from +the union of the tools its `mcp_servers` advertise. In this example the +structure agents see only `molecule_name_to_smiles` + `smiles_to_coordinate_file`, +and the mace-agent sees only `run_ase` + `extract_output_json` — even though +all four come from the same `general` MCP server. Omit `allowed_tools` (or set +it to `[]`) to expose every tool the connected servers advertise. diff --git a/notebooks/3_Demo_using_MCP.ipynb b/notebooks/3_Demo_using_MCP.ipynb index ce37b46d..caf11cb0 100644 --- a/notebooks/3_Demo_using_MCP.ipynb +++ b/notebooks/3_Demo_using_MCP.ipynb @@ -2,190 +2,269 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "3b97dfba-13c9-49a4-bdce-efd5900dcafa", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/tpham2/work/projects/ChemGraph/env/chemgraph_env/lib/python3.10/site-packages/google/api_core/_python_version_support.py:266: FutureWarning: You are using a Python version (3.10.19) which Google will stop supporting in new releases of google.api_core once it reaches its end of life (2026-10-04). 
Please upgrade to the latest Python version, or at least Python 3.11, to continue receiving updates for google.api_core past that date.\n", - " warnings.warn(message, FutureWarning)\n", - "WARNING:root:fairchem is not installed. .\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "2026-01-22 11:50:08,686 - chemgraph.models.openai - INFO - OpenAI API key not found in environment variables.\n" + "Done creating client\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "INFO:chemgraph.models.openai:OpenAI API key not found in environment variables.\n" - ] - }, - { - "name": "stdin", - "output_type": "stream", - "text": [ - "Please enter your OpenAI API key: ········\n" + "2026-05-22 12:34:00,370 - chemgraph.graphs.single_agent - INFO - Constructing single agent graph\n", + "2026-05-22 12:34:00,372 - chemgraph.graphs.single_agent - INFO - Graph construction completed\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "2026-01-22 11:50:10,594 - chemgraph.models.openai - INFO - Loading OpenAI model: gpt-4o-mini\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:chemgraph.models.openai:Loading OpenAI model: gpt-4o-mini\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2026-01-22 11:50:10,710 - chemgraph.models.openai - INFO - Requested model: gpt-4o-mini\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:chemgraph.models.openai:Requested model: gpt-4o-mini\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2026-01-22 11:50:10,711 - chemgraph.models.openai - INFO - OpenAI model loaded successfully\n" + "Done getting tools\n", + "================================\u001b[1m Human Message \u001b[0m=================================\n", + "\n", + "Run a mace calculations with the same file, use energy for driver and small model. 
a cif file are located at /Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "Tool Calls:\n", + " run_mace_single (chatcmpl-tool-a42c48d32a55e54d)\n", + " Call ID: chatcmpl-tool-a42c48d32a55e54d\n", + " Args:\n", + " params: {'input_structure_file': '/Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif', 'driver': 'energy', 'model': 'small'}\n", + "=================================\u001b[1m Tool Message \u001b[0m=================================\n", + "Name: run_mace_single\n", + "\n", + "{\n", + " \"status\": \"success\",\n", + " \"message\": \"Simulation completed. Results saved to /Users/hari/projects/ChemGraph/notebooks/output.json\",\n", + " \"single_point_energy\": -295.75144320599975,\n", + " \"unit\": \"eV\"\n", + "}\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "\n", + "The MACE single‑point energy calculation completed successfully.\n", + "\n", + "**Result**\n", + "- **Energy:** -295.75144320599975 eV \n", + "- **Output file:** `/Users/hari/projects/ChemGraph/notebooks/output.json`\n", + "\n", + "If you need any other properties (e.g., forces, charge distribution) or would like to run additional calculations (geometry optimization, vibrational analysis, etc.), just let me know!\n", + "Done\n" ] - }, + } + ], + "source": [ + "import subprocess, time, os\n", + "from langchain_mcp_adapters.client import MultiServerMCPClient\n", + "from chemgraph.agent.llm_agent import ChemGraph\n", + "\n", + "prompt_single = \"Run a mace calculations with the same file, use energy for driver and small model. 
a cif file are located at /Users/hari/projects/ChemGraph/notebooks/cif_files/calf-20_pacmof.cif\"\n", + "\n", + "os.environ[\"ALCF_ACCESS_TOKEN\"]=\"= '3.11'", "parsl" @@ -92,6 +103,16 @@ where = ["src/"] [tool.setuptools.package-data] "chemgraph.eval" = ["data/*.json"] +"chemgraph.academy.campaigns" = [ + "example-*/*.json", + "example-*/*.jsonc", + "example-*/data/*.json", + "example-*/models/*", + "example-*/prompt_profiles/*.json", +] +"chemgraph.academy.runtime.profiles" = ["*.json"] +"chemgraph.academy.runtime.templates" = ["*"] +"chemgraph.academy.dashboard" = ["static/*"] "ui" = ["assets/*.png"] [tool.ruff] @@ -108,6 +129,10 @@ skip-magic-trailing-comma = false # Ensure Black-style formatting testpaths = ["tests"] markers = [ "llm: marks tests as requiring LLM API access (run with --run-llm)", + "globus_compute: marks tests requiring a live Globus Compute endpoint (run with --run-globus-compute)", + "parsl: marks tests requiring a live Parsl deployment (run with --run-parsl)", + "ensemble_launcher: marks tests requiring a live EnsembleLauncher deployment (run with --run-ensemble-launcher)", + "academy: marks tests requiring Academy agent infrastructure (run with --run-academy)", "asyncio: marks async tests", ] filterwarnings = [ diff --git a/scripts/demo/README.md b/scripts/demo/README.md new file mode 100644 index 00000000..2b73b216 --- /dev/null +++ b/scripts/demo/README.md @@ -0,0 +1,203 @@ +# ChemGraph execution-layer demonstration scripts + +Real-chemistry demos that exercise each `ExecutionBackend` end-to-end. +A 5-molecule library (H2O, CH4, NH3, CO2, ethanol) is screened for +thermochemistry with MACE-MP (`driver="thermo"` → optimize geometry + +vibrational frequencies + ideal-gas thermo at 298.15 K). Each script +writes a CSV of electronic energy, enthalpy, entropy, Gibbs free +energy per molecule and prints a fixed-width summary table. 
+ +These complement `scripts/smoke/`: + +| Directory | Purpose | Pass criterion | +|-----------|---------|---------------| +| `scripts/smoke/` | Regression validators on a trivial water payload | Exit 0 with every `[PASS]` | +| `scripts/demo/` | Realistic chemistry showcases | Useful property table; demos *fail loud* but their value is the output, not a green check | + +## Layout + +``` +scripts/demo/ +├── README.md (this file) +├── _demo_chemistry.py shared helpers (workload, formatting, agent prompt) +├── structures/ 5 .xyz fixtures (~50 lines each) +│ ├── water.xyz methane.xyz ammonia.xyz co2.xyz ethanol.xyz +├── demo_local_direct.py laptop, no LLM, no HPC +├── demo_local_agent.py laptop, LLM, no HPC +├── demo_globus_compute_direct.py laptop, no LLM, live GC endpoint +├── demo_globus_compute_agent.py laptop, LLM, live GC endpoint +├── demo_globus_transfer_direct.py laptop, no LLM, Globus Transfer + GC +├── demo_globus_transfer_agent.py laptop, LLM, Globus Transfer + GC +├── demo_parsl_in_job_direct.py inside qsub -I on Polaris/Aurora, no LLM +├── demo_parsl_in_job_agent.py inside qsub -I, LLM +├── demo_ensemble_launcher_in_job_direct.py inside qsub -I, no LLM +└── demo_ensemble_launcher_in_job_agent.py inside qsub -I, LLM +``` + +Direct demos call `chemgraph.execution.config.get_backend()` and +`backend.submit_batch(...)` directly. Agent demos spawn +`python -m chemgraph.mcp.mace_mcp_hpc` as a stdio subprocess and drive +it with a ChemGraph LLM agent over `langchain-mcp-adapters`. 
+ +## Environment-variable matrix + +| Variable | Required by | Notes | +|----------|-------------|-------| +| `GLOBUS_COMPUTE_ENDPOINT_ID` | `demo_globus_compute_*`, `demo_globus_transfer_*` | UUID from `globus-compute-endpoint start chemgraph-` | +| `GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID` | `demo_globus_transfer_*` | Globus Connect Personal on the laptop | +| `GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID` | `demo_globus_transfer_*` | HPC collection UUID (ALCF data portal) | +| `GLOBUS_TRANSFER_DESTINATION_BASE_PATH` | `demo_globus_transfer_*` | e.g. `/eagle/projects/MyProj/staging` | +| `COMPUTE_SYSTEM` | `demo_parsl_in_job_*`, `demo_ensemble_launcher_in_job_*` | `polaris` or `aurora` | +| `PBS_NODEFILE` | both in-job demos | Set automatically inside `qsub` — demos abort if missing | +| `CG_AMQP_PORT=443` | optional, Aurora | Use when outbound 5671 is blocked | +| LLM API key (e.g. `OPENAI_API_KEY`) | all `*_agent.py` | Match the `--model` flag | + +## Running + +### Laptop, no creds + +```bash +source .cg_env/bin/activate +python scripts/demo/demo_local_direct.py +# ~20s for the 5 molecules on CPU; writes demo_local_out/{demo_local.csv,*_thermo.json} +``` + +Sample output: +``` +=== Local backend thermo screen (cpu) === +molecule energy/eV enthalpy/eV S/(eV/K) G/eV #freqs wall/s conv +--------------------------------------------------------------------------------------------- +water -13.7861 -13.1063 0.001958 -13.6900 9 3.0 True +methane -23.1669 -21.8802 0.001931 -22.4559 15 3.6 True +ammonia -18.9970 -17.9888 0.001996 -18.5839 12 3.3 True +co2 -22.5459 -22.1320 0.002209 -22.7906 9 2.9 True +ethanol -46.2767 -44.0648 0.002820 -44.9056 27 3.3 True +``` + +### Laptop + LLM + +```bash +export OPENAI_API_KEY=... +python scripts/demo/demo_local_agent.py --model gpt-4o-mini +``` + +Agent will call `run_mace_single` 5 times via the MCP subprocess and +respond with a markdown table. 
+ +### Laptop → live Globus Compute endpoint + +```bash +export GLOBUS_COMPUTE_ENDPOINT_ID="" +export COMPUTE_SYSTEM=polaris # for logging +python scripts/demo/demo_globus_compute_direct.py # ~5-15 min first run (model download on remote) +python scripts/demo/demo_globus_compute_agent.py --model gpt-4o-mini +``` + +For Aurora add `--device xpu --amqp-port 443`. + +### Laptop → Globus Transfer + Globus Compute + +```bash +export GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID="" +export GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID="" +export GLOBUS_TRANSFER_DESTINATION_BASE_PATH=/eagle/projects/MyProj/staging +python scripts/demo/demo_globus_transfer_direct.py +python scripts/demo/demo_globus_transfer_agent.py --model gpt-4o-mini +``` + +The direct demo stages the 5 `.xyz` fixtures, then runs MACE in +*remote-path* mode (worker reads from the staged dir, no inline +embedding). The agent demo asks the LLM to call `transfer_files` and +then `run_mace_ensemble` itself. + +Remote-path mode has one quirk: `_mace_worker` only attaches +`full_output` back to the caller when an `inline_structure` is set +(see `src/chemgraph/mcp/mace_mcp_hpc.py:127-131`). So in +`demo_globus_transfer_direct.py` the printed table will have blank +thermo columns — the full JSON results sit on the HPC under +`/_thermo.json`. Pull them back with a +follow-up Globus Transfer if needed. 
+ +### Inside a PBS allocation on Polaris + +```bash +qsub -I -A -l select=1 -l walltime=01:00:00 -q debug -l filesystems=home:eagle +# Now on the compute node: +module load conda +conda activate base +source ~/chemgraph/venv/bin/activate +export COMPUTE_SYSTEM=polaris +cd ~/chemgraph/ChemGraph +python scripts/demo/demo_parsl_in_job_direct.py +python scripts/demo/demo_ensemble_launcher_in_job_direct.py +``` + +### Inside a PBS allocation on Aurora + +```bash +qsub -I -A -l select=1,walltime=01:00:00 -q debug -l filesystems=home:flare +module load frameworks +source ~/chemgraph/venv/bin/activate +export COMPUTE_SYSTEM=aurora +cd ~/chemgraph/ChemGraph +python scripts/demo/demo_parsl_in_job_direct.py --device xpu +python scripts/demo/demo_ensemble_launcher_in_job_direct.py --device xpu +``` + +### Inside a PBS allocation on Crux (CPU-only) + +```bash +qsub -I -A -l select=1 -l walltime=01:00:00 -q debug -l filesystems=home:eagle +cd /lus/eagle/projects/ChemGraph/thang/ChemGraph + +bash scripts/demo/run_crux_demo.sh # Parsl + EL, all 5 molecules +bash scripts/demo/run_crux_demo.sh --molecules water methane +bash scripts/demo/run_crux_demo.sh --parsl-only +bash scripts/demo/run_crux_demo.sh --el-only +``` + +The wrapper activates `.cg_crux_hpc/`, exports `COMPUTE_SYSTEM=crux`, and runs +`demo_parsl_in_job_direct.py` then `demo_ensemble_launcher_in_job_direct.py` +with `--device cpu`. CSVs land in `$PBS_O_WORKDIR/demo_{parsl,el}_out_crux/`. + +Agent variants on either system require an LLM key and follow the +same pattern as `demo_local_agent.py`. + +## Tips + +- `--molecules water methane` to run on a subset (faster iteration). +- `--output-dir /custom/path` to redirect CSV + per-molecule JSON. +- The first run on a fresh endpoint / fresh venv will be slow because + MACE-MP downloads a ~hundred-MB model. Subsequent runs hit the cache + at `~/.cache/mace/`. 
+ +## Known caveats + +- **`langchain-mcp-adapters` must be pinned to `0.1.14`** for the + `*_agent.py` scripts to import. Versions `>=0.2.0` import + `langchain_core.messages.content` (a 1.x API) which doesn't exist in + `langchain-core 0.3.x` — and `langgraph 0.4.7` (pinned in + `pyproject.toml`) constrains us to `langchain-core 0.3.x`. Fix in + `.cg_env`: + ```bash + pip install 'langchain-mcp-adapters==0.1.14' + ``` + This is an **env-only pin** — `pyproject.toml` still lists + `langchain-mcp-adapters` unpinned, so a fresh `pip install -e .` + will regress to `>=0.2`. Re-run the pin command after any clean env + rebuild. The durable fix (one-line edit to `pyproject.toml`) was + deferred per user request. +- `ensemble-launcher` is not on PyPI for Python 3.12; the in-job EL + demos only work on HPC where `scripts/hpc_setup/install_remote.sh` + builds it from source. + +## See also + +- `scripts/smoke/` — pass/fail regression validators (trivial payload). +- `scripts/hpc_setup/{README.md,e2e_test_runbook.md}` — install + ChemGraph + start a Globus Compute endpoint on Polaris/Aurora. +- `scripts/globus_compute_example/` — looser tutorial-style examples, + predecessors of these demos. +- `src/chemgraph/execution/` — the production backends the demos call. +- `src/chemgraph/mcp/mace_mcp_hpc.py` — the MCP server every agent + demo spawns as a subprocess. diff --git a/scripts/demo/_demo_chemistry.py b/scripts/demo/_demo_chemistry.py new file mode 100644 index 00000000..065cf966 --- /dev/null +++ b/scripts/demo/_demo_chemistry.py @@ -0,0 +1,281 @@ +"""Shared chemistry-screening helpers for scripts/demo/*. + +Each demo script in this directory is a thin wrapper around +``submit_and_collect`` -- the actual chemistry workload (a 5-molecule +thermochemistry screen) lives here so we don't duplicate it across +backends. 
+ +Workload +-------- +For each of {water, methane, ammonia, CO2, ethanol} a single MACE +``driver="thermo"`` job is submitted via the configured +``ExecutionBackend``. This drives ``chemgraph.mcp.mace_mcp_hpc._mace_worker`` +under the hood, which itself wraps ``chemgraph.tools.parsl_tools.run_mace_core`` +-> ``chemgraph.tools.ase_core.run_ase_core``. The ``thermo`` driver +optimises the geometry, computes vibrational frequencies, then derives +ideal-gas thermochemistry at the requested temperature/pressure +(``src/chemgraph/tools/ase_core.py:556-602``). + +Two modes +--------- +* ``inline=False`` -- the worker reads ``input_structure_file`` from a + shared filesystem (local, Parsl on a compute node, EL). The demo + reads the on-disk ``output_result_file`` JSON after the future + resolves. +* ``inline=True`` -- the structure is embedded in the payload via + ``atoms_to_atomsdata`` (Globus Compute, where the worker has no + access to the laptop FS). The worker materialises the structure in a + temp dir, runs MACE, then attaches the on-disk JSON back to the + result as ``full_output`` (see ``mace_mcp_hpc.py:127-131``). The demo + reads from ``raw["full_output"]``. 
+""" + +from __future__ import annotations + +import csv +import json +import os +from pathlib import Path +from typing import Any + +MOLECULE_NAMES: list[str] = ["water", "methane", "ammonia", "co2", "ethanol"] +_HERE = Path(__file__).resolve().parent +_STRUCTURES_DIR = _HERE / "structures" + + +def molecule_xyz_path(name: str) -> Path: + """Absolute path to the .xyz fixture for *name*.""" + p = _STRUCTURES_DIR / f"{name}.xyz" + if not p.is_file(): + raise FileNotFoundError(f"Missing structure fixture: {p}") + return p + + +def structures_dir() -> Path: + """Directory holding the per-molecule .xyz fixtures.""" + return _STRUCTURES_DIR + + +def build_thermo_job( + name: str, + *, + device: str, + output_dir: Path, + inline: bool, + model: str = "medium-mpa-0", + temperature: float = 298.15, + pressure: float = 101325.0, + fmax: float = 0.01, + steps: int = 200, +) -> dict: + """Build the job dict consumed by ``_mace_worker`` for one molecule. + + For ``inline=True`` the structure is embedded and the + ``output_result_file`` is left relative so the worker writes into + its own temp dir (and the on-disk JSON is shipped back to the + caller via the ``full_output`` key). + """ + xyz = molecule_xyz_path(name) + job: dict[str, Any] = { + "input_structure_file": str(xyz), + "driver": "thermo", + "model": model, + "device": device, + "temperature": temperature, + "pressure": pressure, + "fmax": fmax, + "steps": steps, + "optimizer": "lbfgs", + } + if inline: + # Worker resolves the (relative) output path against its own + # tempdir -- see mace_mcp_hpc._mace_worker:117-120. 
def submit_and_collect(
    backend,
    molecule_names: list[str] | None = None,
    *,
    device: str,
    output_dir: Path | str,
    inline: bool,
    timeout: float = 6000.0,
    ppn: int = 1,
) -> list[dict]:
    """Submit one MACE thermo job per molecule, gather and summarise.

    Args:
        backend: Object exposing ``submit_batch(tasks)``, which must return
            one future per task (each future is read with
            ``.result(timeout=...)``).
        molecule_names: Molecules to screen; falls back to
            ``MOLECULE_NAMES`` when ``None`` or empty.
        device: Device string forwarded into every job dict.
        output_dir: Directory for per-molecule JSON output; created if
            missing.
        inline: Forwarded to ``build_thermo_job`` and
            ``_extract_properties`` to select inline vs. on-disk mode.
        timeout: Seconds to wait on *each* future (not a global budget).
        ppn: Passed to ``TaskSpec`` as ``processes_per_node``.

    Returns:
        A list of per-molecule property dicts in submission order.

    Raises:
        RuntimeError: If any job fails, times out, returns a non-dict, or
            reports a status other than ``"success"``. Every future is
            still waited on first, so the error names *all* failed
            molecules -- demos should *fail loud*, not swallow.
    """
    from chemgraph.execution.base import TaskSpec
    from chemgraph.mcp.mace_mcp_hpc import _mace_worker

    names = molecule_names or MOLECULE_NAMES
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    jobs = [
        build_thermo_job(name, device=device, output_dir=output_dir, inline=inline)
        for name in names
    ]
    tasks = [
        TaskSpec(
            task_id=f"demo-thermo-{name}",
            task_type="python",
            callable=_mace_worker,
            kwargs={"job": job},
            processes_per_node=ppn,
        )
        for name, job in zip(names, jobs)
    ]
    print(
        f"\nSubmitting {len(tasks)} thermo jobs to backend={type(backend).__name__} "
        f"(device={device}, inline={inline})..."
    )
    futures = backend.submit_batch(tasks)

    results: list[dict] = []
    failures: list[tuple[str, Exception]] = []
    for name, job, fut in zip(names, jobs, futures):
        print(f"  waiting on {name}...", flush=True)
        try:
            raw = fut.result(timeout=timeout)
            if not isinstance(raw, dict):
                raise RuntimeError(f"{name}: non-dict result {type(raw).__name__}: {raw!r}")
            if raw.get("status") != "success":
                raise RuntimeError(f"{name}: backend returned status={raw.get('status')!r}: {raw}")
            results.append(_extract_properties(name, raw, job, inline=inline))
        except Exception as e:
            # Keep draining the remaining futures so one bad molecule does
            # not hide the status of the others; the failure is re-raised
            # collectively below instead of leaking a ``None`` placeholder
            # into the CSV / summary writers.
            print(f"collecting results for job {name} failed with error: {e}")
            failures.append((name, e))

    if failures:
        detail = "; ".join(
            f"{name} ({type(err).__name__}: {err})" for name, err in failures
        )
        raise RuntimeError(
            f"{len(failures)} of {len(names)} thermo job(s) failed: {detail}"
        ) from failures[0][1]
    return results
def print_summary(results: list[dict], title: str = "") -> None:
    """Print a fixed-width table of the screening results.

    Args:
        results: Per-molecule property dicts. Entries that are not dicts
            (e.g. ``None`` recorded for a job whose result could not be
            collected) are printed as a ``(failed)`` row instead of raising.
        title: Optional banner printed above the table.

    Missing or ``None`` values are shown as ``-``; floats are printed with a
    per-column precision, everything else via ``str``.
    """
    if title:
        print(f"\n=== {title} ===")
    if not results:
        print("(no results)")
        return
    header = (
        f"{'molecule':<10} {'energy/eV':>12} {'enthalpy/eV':>13} "
        f"{'S/(eV/K)':>12} {'G/eV':>12} {'#freqs':>7} {'wall/s':>8} {'conv':>5}"
    )
    print(header)
    print("-" * len(header))

    def fmt(val, w, p=4):
        if val is None:
            return f"{'-':>{w}}"
        if isinstance(val, float):
            return f"{val:>{w}.{p}f}"
        return f"{val!s:>{w}}"

    # (result key, column width, float precision) in header order.
    columns = (
        ("energy_eV", 12, 4),
        ("enthalpy_eV", 13, 4),
        ("entropy_eV_per_K", 12, 6),
        ("gibbs_free_energy_eV", 12, 4),
        ("n_frequencies", 7, 0),
        ("wall_time_s", 8, 1),
        ("converged", 5, 4),
    )

    for r in results:
        if not isinstance(r, dict):
            # A failed job may be recorded as None; keep the table aligned
            # rather than crashing on r['molecule'].
            r = {"molecule": "(failed)"}
        cells = " ".join(fmt(r.get(key), w, p) for key, w, p in columns)
        print(f"{r.get('molecule', '?')!s:<10} {cells}")
    print()
+ """ + files = ", ".join(str(molecule_xyz_path(n)) for n in MOLECULE_NAMES) + return ( + f"Using the MACE tool with driver='thermo', model='medium-mpa-0', " + f"device='{device}', temperature=298.15 K, pressure=101325 Pa, " + f"compute thermochemistry for each of these five molecules:\n" + f" - water: {molecule_xyz_path('water')}\n" + f" - methane: {molecule_xyz_path('methane')}\n" + f" - ammonia: {molecule_xyz_path('ammonia')}\n" + f" - CO2: {molecule_xyz_path('co2')}\n" + f" - ethanol: {molecule_xyz_path('ethanol')}\n" + f"Call run_mace_single once per molecule (do not batch them yourself). " + f"For each result, retrieve the optimized electronic energy, enthalpy, " + f"entropy and Gibbs free energy by reading the output JSON via " + f"extract_output_json. After all five complete, report a markdown table " + f"with columns: molecule, energy (eV), H (eV), G (eV), and wall-time then a one-line " + f"observation about which molecule has the lowest Gibbs free energy.\n\n" + f"(Structure paths for reference: {files})" + ) diff --git a/scripts/demo/demo_ensemble_launcher_in_job_agent.py b/scripts/demo/demo_ensemble_launcher_in_job_agent.py new file mode 100644 index 00000000..bced3fbf --- /dev/null +++ b/scripts/demo/demo_ensemble_launcher_in_job_agent.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +"""Agent + MCP + EnsembleLauncher demo on an HPC compute node. + +LLM agent on the compute node drives a local ``mace_mcp_hpc`` +subprocess whose backend is ``ensemble_launcher``. Same 5-molecule +thermo screen as the direct demo, but driven natural-language. + +Run inside ``qsub -I`` on Polaris/Aurora. LLM API key required. + +Run:: + + export COMPUTE_SYSTEM=polaris + export OPENAI_API_KEY=... 
+ python scripts/demo/demo_ensemble_launcher_in_job_agent.py --model gpt-4o-mini +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import logging +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import agent_prompt + +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_mcp_adapters.tools import load_mcp_tools + +from chemgraph.agent.llm_agent import ChemGraph + + +def _abort(msg: str) -> None: + print(f"[ABORT] {msg}") + sys.exit(2) + + +async def amain(model: str, system: str, device: str, query: str, verbose: int, + *, ppn: int = 1, ngpus_per_process: int = 0) -> None: + if verbose: + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") + logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG) + + python = sys.executable + env = os.environ.copy() + env.update({ + "CHEMGRAPH_EXECUTION_BACKEND": "ensemble_launcher", + "COMPUTE_SYSTEM": system, + }) + server_configs = { + "ChemGraph MACE (EnsembleLauncher)": { + "transport": "stdio", + "command": python, + "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc", + "--ppn", str(ppn), + "--ngpus-per-process", str(ngpus_per_process)], + "env": env, + }, + } + + print(f"LLM model: {model}") + print(f"System: {system}") + print(f"Device: {device}\n") + print("Query:\n" + "-" * 60) + print(query) + print("-" * 60 + "\n") + + client = MultiServerMCPClient(server_configs) + async with contextlib.AsyncExitStack() as stack: + session = await stack.enter_async_context( + client.session("ChemGraph MACE (EnsembleLauncher)") + ) + tools = await load_mcp_tools(session) + print(f"Loaded {len(tools)} MCP tools: {[t.name for t in tools]}\n") + + cg = ChemGraph( + model_name=model, + workflow_type="single_agent", + structured_output=False, + return_option="state", + tools=tools, + ) + + print("Running agent...\n" + "=" * 
def main() -> None:
    """Parse CLI options, validate the in-job environment, and run the agent.

    Exits with status 2 (via ``_abort``) when ``PBS_NODEFILE`` is unset, when
    no system is given through ``--system`` or ``COMPUTE_SYSTEM``, or when the
    system is not one of ``polaris``, ``aurora``, ``crux``.

    When ``--device`` is omitted, the device defaults per system:
    ``aurora`` -> ``xpu``, ``crux`` -> ``cpu``, anything else -> ``cuda``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model", default="gpt-4o-mini")
    parser.add_argument("--system", default=os.environ.get("COMPUTE_SYSTEM"))
    parser.add_argument("--device", default=None)
    parser.add_argument("--ppn", type=int, default=1,
                        help="Processes per node for MCP backend tasks")
    parser.add_argument("--ngpus-per-process", type=int, default=0,
                        help="GPUs per process for MCP backend tasks")
    parser.add_argument("--query", default=None)
    parser.add_argument("-v", "--verbose", action="count", default=0)
    args = parser.parse_args()

    if not os.environ.get("PBS_NODEFILE"):
        _abort("PBS_NODEFILE not set. Run inside `qsub -I`.")
    if not args.system:
        _abort("COMPUTE_SYSTEM env var not set and --system not given.")
    system = args.system.lower().strip()
    if system not in ("polaris", "aurora", "crux"):
        _abort(f"Unsupported --system: {system!r}")

    # Same per-system defaults as demo_ensemble_launcher_in_job_direct.py;
    # previously crux fell through to "cuda" here.
    if args.device:
        device = args.device
    elif system == "aurora":
        device = "xpu"
    elif system == "crux":
        device = "cpu"
    else:
        device = "cuda"

    query = args.query or agent_prompt(device=device)
    asyncio.run(amain(args.model, system, device, query, args.verbose,
                      ppn=args.ppn, ngpus_per_process=args.ngpus_per_process))
def main() -> None:
    """Run the managed-mode EnsembleLauncher thermo screen from the CLI.

    Validates the PBS environment and the target system (aborting with exit
    status 2 otherwise), checks that ``ensemble_launcher`` is importable,
    runs ``submit_and_collect`` on an ``ensemble_launcher`` backend, and then
    writes ``demo_el.csv`` under ``--output-dir`` and prints the summary
    table. The backend is shut down even if collection raises.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    for flag, options in (
        ("--system", {"default": os.environ.get("COMPUTE_SYSTEM")}),
        ("--device", {"default": None}),
        ("--output-dir", {"default": "demo_el_out"}),
        ("--molecules", {"nargs": "+", "default": MOLECULE_NAMES}),
        (
            "--ppn",
            {"type": int, "default": 16,
             "help": "Processes (cores) per node for each task"},
        ),
        ("--timeout", {"type": float, "default": 6000.0}),
    ):
        cli.add_argument(flag, **options)
    opts = cli.parse_args()

    if not os.environ.get("PBS_NODEFILE"):
        _abort("PBS_NODEFILE not set. Run inside `qsub -I`.")
    if not opts.system:
        _abort("COMPUTE_SYSTEM env var not set and --system not given.")
    system = opts.system.lower().strip()
    if system not in ("polaris", "aurora", "crux"):
        _abort(f"Unsupported --system: {system!r}")

    # An explicit --device wins; otherwise pick the per-system default.
    device = opts.device or {"aurora": "xpu", "crux": "cpu"}.get(system, "cuda")

    try:
        import ensemble_launcher  # noqa: F401
    except ImportError as exc:
        _abort(
            f"ensemble_launcher import failed: {exc}. "
            "Install via scripts/hpc_setup/install_remote.sh on HPC."
        )

    print(f"system={system} device={device} ppn={opts.ppn} mode=managed")

    from chemgraph.execution.config import get_backend

    el_backend = get_backend(backend_name="ensemble_launcher", system=system)
    try:
        rows = submit_and_collect(
            el_backend,
            molecule_names=opts.molecules,
            device=device,
            output_dir=opts.output_dir,
            inline=False,
            timeout=opts.timeout,
            ppn=opts.ppn,
        )
    finally:
        el_backend.shutdown()

    table_path = write_csv(rows, Path(opts.output_dir) / "demo_el.csv")
    banner = f"EnsembleLauncher thermo screen (system={system}, device={device})"
    print_summary(rows, title=banner)
    print(f"CSV: {table_path}")
# or other model creds + +Run:: + + python scripts/demo/demo_globus_compute_agent.py --model gpt-4o-mini + python scripts/demo/demo_globus_compute_agent.py --device xpu --model argo:gpt-4o +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import logging +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import agent_prompt + +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_mcp_adapters.tools import load_mcp_tools + +from chemgraph.agent.llm_agent import ChemGraph + + +async def amain(model: str, device: str, query: str, verbose: int) -> None: + if verbose: + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") + logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG) + + endpoint = os.environ["GLOBUS_COMPUTE_ENDPOINT_ID"] + os.environ["CHEMGRAPH_EXECUTION_BACKEND"] = "globus_compute" + + python = sys.executable + server_configs = { + "ChemGraph MACE (Globus Compute)": { + "transport": "stdio", + "command": python, + "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc"], + "env": { + "CHEMGRAPH_EXECUTION_BACKEND": "globus_compute", + "GLOBUS_COMPUTE_ENDPOINT_ID": endpoint, + # Forward optional knobs if set. + **({"CG_AMQP_PORT": os.environ["CG_AMQP_PORT"]} if "CG_AMQP_PORT" in os.environ else {}), + **({"COMPUTE_SYSTEM": os.environ["COMPUTE_SYSTEM"]} if "COMPUTE_SYSTEM" in os.environ else {}), + "PATH": os.environ.get("PATH", ""), + "HOME": os.environ.get("HOME", ""), + "VIRTUAL_ENV": os.environ.get("VIRTUAL_ENV", ""), + }, + }, + } + + print(f"LLM model: {model}") + print(f"GC endpoint: {endpoint[:8]}... 
def main() -> None:
    """CLI entry point for the Globus Compute agent demo.

    Parses ``--model``, ``--device``, ``--query`` and ``-v``. Prints an error
    and exits with status 2 when ``GLOBUS_COMPUTE_ENDPOINT_ID`` is unset or
    empty. Otherwise runs ``amain`` with the given query, or with the default
    ``agent_prompt`` for the chosen device.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--model",
        default="gpt-4o-mini",
        help="LLM model name (default: gpt-4o-mini)",
    )
    fallback_device = os.environ.get("CG_DEMO_DEVICE", "cuda")
    cli.add_argument(
        "--device",
        default=fallback_device,
        help="MACE device on the remote endpoint (default: cuda; use xpu on Aurora)",
    )
    cli.add_argument("--query", default=None, help="Override the default query")
    cli.add_argument("-v", "--verbose", action="count", default=0)
    opts = cli.parse_args()

    endpoint_id = os.environ.get("GLOBUS_COMPUTE_ENDPOINT_ID")
    if not endpoint_id:
        print("ERROR: export GLOBUS_COMPUTE_ENDPOINT_ID= first.")
        sys.exit(2)

    prompt = opts.query if opts.query else agent_prompt(device=opts.device)
    asyncio.run(amain(opts.model, opts.device, prompt, opts.verbose))
def main() -> None:
    """CLI entry point for the direct Globus Compute thermo screen.

    Exits with status 2 when ``GLOBUS_COMPUTE_ENDPOINT_ID`` is unset or empty.
    Otherwise builds a ``globus_compute`` backend (passing ``amqp_port`` only
    when one was given), runs ``submit_and_collect`` with ``inline=True``,
    shuts the backend down, writes ``demo_globus_compute.csv`` under
    ``--output-dir``, and prints the summary table.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--output-dir", default="demo_globus_compute_out")
    cli.add_argument("--molecules", nargs="+", default=MOLECULE_NAMES)
    cli.add_argument(
        "--device",
        default=os.environ.get("CG_DEMO_DEVICE", "cuda"),
        help="MACE device on the remote endpoint (default: cuda; use xpu on Aurora)",
    )
    # An unset variable (or 0) falls back to None, i.e. no port override.
    env_port = int(os.environ.get("CG_AMQP_PORT", "0")) or None
    cli.add_argument(
        "--amqp-port",
        type=int,
        default=env_port,
        help="Override AMQP port (set to 443 if 5671 is blocked, e.g. Aurora)",
    )
    cli.add_argument(
        "--timeout",
        type=float,
        default=6000.0,
        help="Per-task timeout in seconds (default 6000)",
    )
    opts = cli.parse_args()

    if not os.environ.get("GLOBUS_COMPUTE_ENDPOINT_ID"):
        print("ERROR: export GLOBUS_COMPUTE_ENDPOINT_ID= first.")
        sys.exit(2)

    from chemgraph.execution.config import get_backend

    extra: dict = {"amqp_port": opts.amqp_port} if opts.amqp_port else {}
    gc_backend = get_backend(backend_name="globus_compute", **extra)
    try:
        rows = submit_and_collect(
            gc_backend,
            molecule_names=opts.molecules,
            device=opts.device,
            output_dir=opts.output_dir,
            inline=True,
            timeout=opts.timeout,
        )
    finally:
        gc_backend.shutdown()

    table_path = write_csv(rows, Path(opts.output_dir) / "demo_globus_compute.csv")
    system_label = os.environ.get("COMPUTE_SYSTEM", "?")
    print_summary(
        rows,
        title=f"Globus Compute thermo screen (system={system_label}, device={opts.device})",
    )
    print(f"CSV written to: {table_path}")
+ export GLOBUS_TRANSFER_DESTINATION_BASE_PATH=/eagle/projects/MyProj/staging + export OPENAI_API_KEY=... # or any supported model + +Run:: + + python scripts/demo/demo_globus_transfer_agent.py --model gpt-4o-mini +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import logging +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from _demo_chemistry import MOLECULE_NAMES, structures_dir + +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_mcp_adapters.tools import load_mcp_tools + +from chemgraph.agent.llm_agent import ChemGraph + + +_TRANSFER_AGENT_PROMPT_TMPL = """\ +The following five molecule structure files live on the local filesystem: +{listing} + +Workflow: +1. Call `transfer_files` with `source_paths` set to that list of absolute + paths (you may pass them as one batch) to stage them on the remote + HPC endpoint. Use `wait=true` so the call blocks until SUCCEEDED. +2. From the transfer result, take the `remote_directory` value. +3. Call `run_mace_ensemble` with: + - remote_structure_directory = + - driver = "thermo" + - model = "medium-mpa-0" + - device = "{device}" + - temperature = 298.15 + - pressure = 101325 + This dispatches one MACE thermo job per file via Globus Compute. +4. If `run_mace_ensemble` returns a `batch_id`, poll `check_job_status` + until completed, then call `get_job_results` to retrieve the per-file + energies and thermochemistry. +5. Report a markdown table with columns: molecule | electronic energy (eV) | + Gibbs free energy (eV). Add a one-line observation about which + molecule has the lowest Gibbs free energy. 
def _agent_prompt(device: str) -> str:
    """Fill the transfer-agent prompt template for *device*.

    The ``listing`` placeholder receives one bullet line per entry of
    ``MOLECULE_NAMES``, each holding the path ``structures_dir() / "<name>.xyz"``.
    """
    bullet_lines = []
    for mol in MOLECULE_NAMES:
        xyz = str(structures_dir() / f"{mol}.xyz")
        bullet_lines.append(f" - {xyz}")
    return _TRANSFER_AGENT_PROMPT_TMPL.format(
        device=device,
        listing="\n".join(bullet_lines),
    )
def main() -> None:
    """CLI entry point for the Globus Transfer + Compute agent demo.

    Parses ``--model``, ``--device``, ``--query`` and ``-v``. Prints the
    missing names and exits with status 2 when any required Globus env var is
    unset or empty. Otherwise runs ``amain`` with the given query, or with
    the default ``_agent_prompt`` for the chosen device.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--model", default="gpt-4o-mini")
    cli.add_argument("--device", default=os.environ.get("CG_DEMO_DEVICE", "cuda"))
    cli.add_argument("--query", default=None)
    cli.add_argument("-v", "--verbose", action="count", default=0)
    opts = cli.parse_args()

    unset = []
    for var in (
        "GLOBUS_COMPUTE_ENDPOINT_ID",
        "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID",
        "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID",
        "GLOBUS_TRANSFER_DESTINATION_BASE_PATH",
    ):
        if not os.environ.get(var):
            unset.append(var)
    if unset:
        print(f"ERROR: missing env vars: {', '.join(unset)}")
        sys.exit(2)

    prompt = opts.query if opts.query else _agent_prompt(opts.device)
    asyncio.run(amain(opts.model, opts.device, prompt, opts.verbose))
def main() -> None:
    """Stage fixtures with Globus Transfer, then run MACE thermo via Globus Compute.

    Steps:
        1. Transfer the selected ``.xyz`` files to the remote collection and
           wait for the transfer to report ``SUCCEEDED``.
        2. Submit one ``_mace_worker`` task per staged file and wait for each
           result.
        3. Print the summary table and write ``demo_globus_transfer.csv``
           under ``--output-dir``.

    Exits with status 2 when a required env var is missing or no transfer
    manager is configured, and with status 1 when the transfer does not
    succeed.

    Raises:
        RuntimeError: if any compute task returns a non-dict or non-success
            result. The backend is shut down before the error propagates.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--output-dir", default="demo_globus_transfer_out")
    parser.add_argument("--molecules", nargs="+", default=MOLECULE_NAMES)
    parser.add_argument("--device", default=os.environ.get("CG_DEMO_DEVICE", "cuda"))
    parser.add_argument(
        "--amqp-port",
        type=int,
        default=int(os.environ.get("CG_AMQP_PORT", "0")) or None,
    )
    parser.add_argument(
        "--transfer-timeout",
        type=float,
        default=6000.0,
        help="Seconds to wait for the Globus Transfer task (default 6000).",
    )
    parser.add_argument(
        "--compute-timeout",
        type=float,
        default=6000.0,
        help="Seconds to wait for each MACE thermo task (default 6000).",
    )
    args = parser.parse_args()

    required = (
        "GLOBUS_COMPUTE_ENDPOINT_ID",
        "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID",
        "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID",
        "GLOBUS_TRANSFER_DESTINATION_BASE_PATH",
    )
    missing = [v for v in required if not os.environ.get(v)]
    if missing:
        print(f"ERROR: missing env vars: {', '.join(missing)}")
        sys.exit(2)

    from chemgraph.execution.base import TaskSpec
    from chemgraph.execution.config import get_backend, get_transfer_manager
    from chemgraph.mcp.mace_mcp_hpc import _mace_worker

    # ── 1. Stage all 5 .xyz files to the remote HPC collection ─────────
    print("\n[1/3] Submitting Globus Transfer for fixtures...")
    tm = get_transfer_manager()
    if tm is None:
        print("ERROR: get_transfer_manager() returned None.")
        sys.exit(2)

    local_paths = [str(molecule_xyz_path(n)) for n in args.molecules]
    transfer = tm.transfer_files(
        local_paths=local_paths,
        label=f"chemgraph-demo-{int(time.time())}",
    )
    print(f" task_id = {transfer.task_id}")
    print(f" remote_dir = {transfer.remote_directory}")
    print(f" waiting up to {args.transfer_timeout}s for SUCCEEDED...")
    status = tm.wait_for_transfer(
        transfer.task_id, timeout=args.transfer_timeout, poll_interval=5
    )
    if status.get("status") != "SUCCEEDED":
        print(f"ERROR: transfer did not succeed: {status}")
        sys.exit(1)
    # The counters are informational only; a missing key must not turn a
    # successful transfer into a KeyError.
    print(
        f" done: {status.get('files_transferred', '?')}/{status.get('files', '?')} files, "
        f"{status.get('bytes_transferred', '?')} bytes"
    )

    # ── 2. Submit one MACE thermo task per pre-staged file ─────────────
    print(f"\n[2/3] Dispatching {len(args.molecules)} MACE thermo jobs via Globus Compute...")
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    jobs = []
    tasks = []
    for name in args.molecules:
        remote_xyz = f"{transfer.remote_directory}/{name}.xyz"
        job = {
            # input_structure_file is ignored when remote_structure_file is set
            # (mace_mcp_hpc._mace_worker:92-99 overrides it). Pass a sentinel.
            "input_structure_file": f"remote::{name}",
            "remote_structure_file": remote_xyz,
            "output_result_file": f"{name}_thermo.json",
            "driver": "thermo",
            "model": "medium-mpa-0",
            "device": args.device,
            "temperature": 298.15,
            "pressure": 101325.0,
            "fmax": 0.01,
            "steps": 200,
            "optimizer": "lbfgs",
        }
        jobs.append(job)
        tasks.append(
            TaskSpec(
                task_id=f"demo-tr-{name}",
                task_type="python",
                callable=_mace_worker,
                kwargs={"job": job},
            )
        )

    backend_kwargs = {}
    if args.amqp_port:
        backend_kwargs["amqp_port"] = args.amqp_port
    backend = get_backend(backend_name="globus_compute", **backend_kwargs)

    results = []
    try:
        # Submission sits inside the try so the backend is shut down even
        # when submit_batch itself raises.
        futures = backend.submit_batch(tasks)
        for name, job, fut in zip(args.molecules, jobs, futures):
            print(f" waiting on {name}...", flush=True)
            raw = fut.result(timeout=args.compute_timeout)
            if not isinstance(raw, dict) or raw.get("status") != "success":
                raise RuntimeError(f"{name}: backend returned {raw!r}")
            # Remote-path mode: full_output is NOT attached (only inline triggers
            # the JSON round-trip). Convergence + thermo cannot be read here
            # without staging the JSON back -- see the note in the summary table.
            results.append(_extract_properties(name, raw, job, inline=True))
    finally:
        backend.shutdown()

    # ── 3. Report ──────────────────────────────────────────────────────
    print("\n[3/3] Results (remote-path mode -- full JSON stays on the HPC):")
    print_summary(
        results,
        title=f"Globus Transfer + Compute thermo screen (device={args.device})",
    )
    csv_path = write_csv(results, output_dir / "demo_globus_transfer.csv")
    print(f"CSV (per-call status; thermo values blank in remote-path mode): {csv_path}")
    print(
        f"\nNote: workers wrote full JSON results under {transfer.remote_directory} "
        f"on the HPC. To pull them back, you can run another Globus Transfer "
        f"job in the reverse direction."
    )
def main() -> None:
    """CLI entry point for the local-backend agent demo.

    Parses ``--model``, ``--device``, ``--query`` and ``-v``, then runs
    ``amain`` with the given query, or with the default ``agent_prompt`` for
    the chosen device.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    for flags, options in (
        (
            ("--model",),
            {
                "default": "argo:gpt-4o",
                "help": "LLM model name (default: argo:gpt-4o). "
                        "Try argo:gpt-4o, claude-sonnet-4-6, gpt-4o.",
            },
        ),
        (
            ("--device",),
            {
                "default": "cpu",
                "help": "MACE device passed to the agent prompt (default: cpu)",
            },
        ),
        (
            ("--query",),
            {
                "default": None,
                "help": "Override the natural-language query "
                        "(default: 5-molecule thermo screen)",
            },
        ),
        (
            ("-v", "--verbose"),
            {
                "action": "count",
                "default": 0,
                "help": "Increase verbosity (-v INFO, -vv DEBUG).",
            },
        ),
    ):
        cli.add_argument(*flags, **options)
    opts = cli.parse_args()

    prompt = opts.query if opts.query else agent_prompt(device=opts.device)
    asyncio.run(amain(opts.model, opts.device, prompt, opts.verbose))
def main() -> None:
    """Run the thermochemistry screen on the ``local`` backend and report it.

    Parses ``--output-dir``, ``--molecules`` and ``--device``, submits the
    selected molecules through ``submit_and_collect``, always shuts the
    backend down, then writes ``demo_local.csv`` under the output directory
    and prints a summary.
    """
    option_table = (
        (
            "--output-dir",
            {
                "default": "demo_local_out",
                "help": "Where per-molecule JSON + CSV land (default: ./demo_local_out)",
            },
        ),
        (
            "--molecules",
            {
                "nargs": "+",
                "default": MOLECULE_NAMES,
                "help": f"Subset to run (default: {MOLECULE_NAMES})",
            },
        ),
        (
            "--device",
            {
                "default": "cpu",
                "help": "MACE device (default: cpu; local Mac/CPU)",
            },
        ),
    )
    cli = argparse.ArgumentParser(description=__doc__)
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Imported only after argument parsing, as in the original script.
    from chemgraph.execution.config import get_backend

    local_backend = get_backend(backend_name="local", system="local")
    try:
        screen = submit_and_collect(
            local_backend,
            molecule_names=opts.molecules,
            device=opts.device,
            output_dir=opts.output_dir,
            inline=False,
            timeout=1200,
        )
    finally:
        # Release the backend even when a job fails or times out.
        local_backend.shutdown()

    written = write_csv(screen, Path(opts.output_dir) / "demo_local.csv")
    print_summary(screen, title=f"Local backend thermo screen ({opts.device})")
    print(f"CSV written to: {written}")
    print(f"Per-molecule JSON written under: {Path(opts.output_dir).resolve()}")
def _forward_env(*names: str) -> dict:
    """Return ``{name: value}`` for each listed variable that is set and non-empty.

    Unset (or empty) variables are left out instead of being forwarded as
    ``""``, so the MCP subprocess does not see e.g. an empty
    ``CHEMGRAPH_WORKER_INIT`` or ``CONDA_PREFIX`` that the parent never set.
    """
    return {name: os.environ[name] for name in names if os.environ.get(name)}


def _final_response(result):
    """Return the content of the last message that is not a tool call.

    ``result`` is expected to be a dict with a ``"messages"`` list whose items
    are either objects exposing ``content`` / ``tool_calls`` attributes or
    plain dicts with those keys. Returns ``None`` when ``result`` has a
    different shape or when no message carries final text.
    """
    if not (isinstance(result, dict) and "messages" in result):
        return None
    for msg in reversed(result["messages"]):
        if isinstance(msg, dict):
            content = msg.get("content", "")
            tool_calls = msg.get("tool_calls")
        else:
            content = getattr(msg, "content", None)
            tool_calls = getattr(msg, "tool_calls", None)
        if content and not tool_calls:
            return content
    return None


async def amain(model: str, system: str, device: str, query: str, verbose: int) -> None:
    """Spawn the MACE MCP server with the Parsl backend and run one agent query.

    Args:
        model: LLM model name handed to ``ChemGraph``.
        system: Value exported to the subprocess as ``COMPUTE_SYSTEM``.
        device: Device label; only echoed in the banner here (the query text
            is built by the caller).
        query: Natural-language request sent to the agent.
        verbose: 0 leaves logging untouched, 1 enables INFO, 2+ enables DEBUG
            on the ``chemgraph`` logger.
    """
    if verbose:
        logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s")
        logging.getLogger("chemgraph").setLevel(logging.INFO if verbose == 1 else logging.DEBUG)

    python = sys.executable
    env = {
        "CHEMGRAPH_EXECUTION_BACKEND": "parsl",
        "COMPUTE_SYSTEM": system,
        # Only forward what is actually set in this process.
        **_forward_env(
            "PATH",
            "HOME",
            "VIRTUAL_ENV",
            "CONDA_PREFIX",
            "CONDA_DEFAULT_ENV",
            "CHEMGRAPH_WORKER_INIT",
            "PBS_NODEFILE",
            "PBS_O_WORKDIR",
        ),
    }
    server_configs = {
        "ChemGraph MACE (Parsl)": {
            "transport": "stdio",
            "command": python,
            "args": ["-u", "-m", "chemgraph.mcp.mace_mcp_hpc"],
            "env": env,
        },
    }

    print(f"LLM model: {model}")
    print(f"System: {system}")
    print(f"Device: {device}\n")
    print("Query:\n" + "-" * 60)
    print(query)
    print("-" * 60 + "\n")

    client = MultiServerMCPClient(server_configs)
    async with contextlib.AsyncExitStack() as stack:
        session = await stack.enter_async_context(client.session("ChemGraph MACE (Parsl)"))
        tools = await load_mcp_tools(session)
        print(f"Loaded {len(tools)} MCP tools: {[t.name for t in tools]}\n")

        cg = ChemGraph(
            model_name=model,
            workflow_type="single_agent",
            structured_output=False,
            return_option="state",
            tools=tools,
        )

        print("Running agent...\n" + "=" * 60)
        result = await cg.run(query)
        print("=" * 60)

        response = _final_response(result)
        if response is not None:
            print(f"\nAgent response:\n{response}")
        else:
            # Either the result is not a message state, or no message carried
            # final text: show the raw result rather than printing nothing.
            print(f"\nResult:\n{result}")
def _default_device(system: str) -> str:
    """Return the MACE device used when ``--device`` is not given."""
    return {"aurora": "xpu", "crux": "cpu"}.get(system, "cuda")


def _resolve_run_dir(explicit: "str | None", pbs_workdir: "str | None") -> Path:
    """Pick the Parsl run directory.

    An explicit ``--run-dir`` is used exactly as given. Otherwise the
    directory is ``parsl_demo_runs`` under ``$PBS_O_WORKDIR`` when that is
    set, or under the current working directory.
    """
    if explicit:
        return Path(explicit)
    base = Path(pbs_workdir) if pbs_workdir else Path.cwd()
    return base / "parsl_demo_runs"


def main() -> None:
    """Run the thermochemistry screen through the Parsl backend inside a PBS job.

    Exits with status 2 (via ``_abort``) when ``PBS_NODEFILE`` is unset or not
    a file, when no system is given, or when the system is not one of
    polaris / aurora / crux. Otherwise submits the selected molecules, always
    shuts the backend down, writes ``demo_parsl.csv`` under ``--output-dir``
    and prints a summary.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--system",
        default=os.environ.get("COMPUTE_SYSTEM"),
        help="polaris | aurora | crux (default: $COMPUTE_SYSTEM)",
    )
    parser.add_argument(
        "--device",
        default=None,
        help="cuda (Polaris) | xpu (Aurora) | cpu (Crux)",
    )
    parser.add_argument("--output-dir", default="demo_parsl_out")
    parser.add_argument("--molecules", nargs="+", default=MOLECULE_NAMES)
    parser.add_argument(
        "--run-dir",
        default=None,
        help=(
            "Parsl run_dir, used exactly as given "
            "(default: $PBS_O_WORKDIR/parsl_demo_runs or ./parsl_demo_runs)."
        ),
    )
    parser.add_argument("--timeout", type=float, default=6000.0)
    args = parser.parse_args()

    pbs_nodefile = os.environ.get("PBS_NODEFILE")
    if not pbs_nodefile or not Path(pbs_nodefile).is_file():
        _abort("PBS_NODEFILE not set or missing. Run inside `qsub -I`.")
    if not args.system:
        _abort("COMPUTE_SYSTEM env var not set and --system not given.")
    system = args.system.lower().strip()
    if system not in ("polaris", "aurora", "crux"):
        _abort(f"Unsupported --system: {system!r} (expected polaris|aurora|crux)")
    device = args.device or _default_device(system)

    # Previously "parsl_demo_runs" was appended even to an explicit --run-dir,
    # so wrappers that create and report their own run directory ended up
    # with the Parsl files one level deeper than advertised.
    run_dir = str(_resolve_run_dir(args.run_dir, os.environ.get("PBS_O_WORKDIR")))
    Path(run_dir).mkdir(parents=True, exist_ok=True)

    print(f"system={system} device={device} run_dir={run_dir}")

    from chemgraph.execution.config import get_backend

    backend = get_backend(backend_name="parsl", system=system, run_dir=run_dir)
    try:
        results = submit_and_collect(
            backend,
            molecule_names=args.molecules,
            device=device,
            output_dir=args.output_dir,
            inline=False,
            timeout=args.timeout,
        )
    finally:
        # Release the backend even when a job fails or times out.
        backend.shutdown()

    csv_path = write_csv(results, Path(args.output_dir) / "demo_parsl.csv")
    print_summary(results, title=f"Parsl thermo screen (system={system}, device={device})")
    print(f"CSV: {csv_path}")
#!/usr/bin/env bash
# Run Parsl + EnsembleLauncher demo (5-molecule thermo screen, MACE on CPU)
# on a Crux compute node.
#
# Must be executed INSIDE an interactive PBS allocation on Crux:
#   qsub -I -A <project> -l select=1 -l walltime=01:00:00 -q debug
#   cd /lus/eagle/projects/ChemGraph/thang/ChemGraph
#   bash scripts/demo/run_crux_demo.sh                       # both backends
#   bash scripts/demo/run_crux_demo.sh --parsl-only
#   bash scripts/demo/run_crux_demo.sh --el-only
#   bash scripts/demo/run_crux_demo.sh --molecules water methane
#   bash scripts/demo/run_crux_demo.sh --timeout 600

set -euo pipefail

REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"

abort() {
  echo "[ABORT] $*" >&2
  exit 2
}

RUN_PARSL=1
RUN_EL=1
# Collected separately so option order on the command line cannot interleave
# molecule names with other forwarded flags.
MOLECULES=()
TIMEOUT_ARGS=()
while (( $# )); do
  case "$1" in
    --parsl-only) RUN_EL=0; shift ;;
    --el-only)    RUN_PARSL=0; shift ;;
    --molecules)
      shift
      while (( $# )) && [[ "$1" != --* ]]; do
        MOLECULES+=("$1")
        shift
      done
      (( ${#MOLECULES[@]} )) || abort "--molecules requires at least one name"
      ;;
    --timeout)
      # Check explicitly: under `set -u` a missing value would otherwise die
      # with an unhelpful "unbound variable" error.
      (( $# >= 2 )) || abort "--timeout requires a value"
      TIMEOUT_ARGS=(--timeout "$2")
      shift 2
      ;;
    -h|--help) sed -n '2,12p' "${BASH_SOURCE[0]}"; exit 0 ;;
    *) abort "Unknown argument: $1" ;;
  esac
done

# Arguments forwarded verbatim to both demo scripts.
PASSTHROUGH=()
if (( ${#MOLECULES[@]} )); then
  PASSTHROUGH+=(--molecules "${MOLECULES[@]}")
fi
if (( ${#TIMEOUT_ARGS[@]} )); then
  PASSTHROUGH+=("${TIMEOUT_ARGS[@]}")
fi

[[ -n "${PBS_NODEFILE:-}" && -f "${PBS_NODEFILE}" ]] \
  || abort "PBS_NODEFILE not set or missing -- run inside 'qsub -I' on Crux."

VENV_ACTIVATE="$REPO_ROOT/.cg_crux_hpc/bin/activate"
[[ -f "$VENV_ACTIVATE" ]] || abort "Missing venv activate script: $VENV_ACTIVATE"

if [[ "${VIRTUAL_ENV:-}" != "$REPO_ROOT/.cg_crux_hpc" ]]; then
  module load conda 2>/dev/null || true
  # shellcheck disable=SC1090
  source "$VENV_ACTIVATE"
fi

export COMPUTE_SYSTEM=crux
RUN_DIR="${PBS_O_WORKDIR:-$PWD}/parsl_demo_runs_crux"
PARSL_OUT="${PBS_O_WORKDIR:-$PWD}/demo_parsl_out_crux"
EL_OUT="${PBS_O_WORKDIR:-$PWD}/demo_el_out_crux"
mkdir -p "$RUN_DIR" "$PARSL_OUT" "$EL_OUT"

echo "REPO_ROOT=$REPO_ROOT"
echo "VIRTUAL_ENV=${VIRTUAL_ENV:-}"
echo "PBS_NODEFILE=$PBS_NODEFILE ($(wc -l <"$PBS_NODEFILE") node(s))"
echo "RUN_DIR=$RUN_DIR"
echo "PARSL_OUT=$PARSL_OUT EL_OUT=$EL_OUT"
echo

parsl_rc=0
el_rc=0

# ${arr[@]+"${arr[@]}"} expands to nothing for an empty array without
# tripping `set -u` on bash older than 4.4.
if (( RUN_PARSL )); then
  echo "=== Parsl demo (system=crux, device=cpu) ==="
  python "$REPO_ROOT/scripts/demo/demo_parsl_in_job_direct.py" \
    --system crux --device cpu --run-dir "$RUN_DIR" \
    --output-dir "$PARSL_OUT" ${PASSTHROUGH[@]+"${PASSTHROUGH[@]}"} \
    || parsl_rc=$?
  echo
fi

if (( RUN_EL )); then
  echo "=== EnsembleLauncher demo (managed, system=crux, device=cpu) ==="
  python "$REPO_ROOT/scripts/demo/demo_ensemble_launcher_in_job_direct.py" \
    --system crux --device cpu \
    --output-dir "$EL_OUT" ${PASSTHROUGH[@]+"${PASSTHROUGH[@]}"} \
    || el_rc=$?
  echo
fi

verdict() { (( $1 == 0 )) && echo PASS || echo "FAIL(rc=$1)"; }
echo "=== Summary ==="
if (( RUN_PARSL )); then
  echo "parsl = $(verdict $parsl_rc) (output: $PARSL_OUT)"
fi
if (( RUN_EL )); then
  echo "el = $(verdict $el_rc) (output: $EL_OUT)"
fi

# Exit with the larger of the two return codes.
if (( parsl_rc > el_rc )); then
  exit "$parsl_rc"
else
  exit "$el_rc"
fi
00000000..03120dab --- /dev/null +++ b/scripts/demo/structures/water.xyz @@ -0,0 +1,5 @@ +3 +water +O 0.0000000 0.0000000 0.0000000 +H 0.7570000 0.5860000 0.0000000 +H -0.7570000 0.5860000 0.0000000 diff --git a/scripts/smoke/README.md b/scripts/smoke/README.md new file mode 100644 index 00000000..07e83e12 --- /dev/null +++ b/scripts/smoke/README.md @@ -0,0 +1,141 @@ +# ChemGraph execution-layer smoke tests + +Self-contained scripts that exercise each ExecutionBackend live and emit +`[PASS]` / `[FAIL]` per check. Exit code is `0` only if every check passes +(`2` if required env vars are missing → "skip"). Use them for one-shot +validation after install, after a rebase, or before running real workloads. + +These are *not* pytest tests — they hit live infrastructure (process pools, +PBS allocations, Globus Compute endpoints, Globus Transfer). The mocked +unit suite still lives at `tests/test_execution.py`. + +## Script matrix + +| Script | Runs where | Backends | Live deps | +|--------|------------|----------|-----------| +| [`smoke_local.py`](smoke_local.py) | laptop | `local` | none | +| [`smoke_globus_compute.py`](smoke_globus_compute.py) | laptop | `globus_compute` | live GC endpoint | +| [`smoke_globus_transfer.py`](smoke_globus_transfer.py) | laptop | `GlobusTransferManager` (+ optional `globus_compute` MCP) | Globus collections on both ends | +| [`smoke_parsl_in_job.py`](smoke_parsl_in_job.py) | inside `qsub -I` on Polaris/Aurora | `parsl` | PBS allocation | +| [`smoke_ensemble_launcher_in_job.py`](smoke_ensemble_launcher_in_job.py) | inside `qsub -I` on Polaris/Aurora | `ensemble_launcher` (managed + client-only) | PBS allocation, `ensemble_launcher` built from source | + +`_smoke_utils.py` holds shared helpers (`SmokeReporter`, picklable trivial +callables). `water.xyz` is the shared 3-atom fixture. 
+ +## Environment-variable matrix + +| Variable | Required by | Notes | +|----------|-------------|-------| +| `GLOBUS_COMPUTE_ENDPOINT_ID` | `smoke_globus_compute.py`, `smoke_globus_transfer.py --with-mcp` | UUID printed by `globus-compute-endpoint start chemgraph-` | +| `GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID` | `smoke_globus_transfer.py` | Globus Connect Personal collection on the laptop | +| `GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID` | `smoke_globus_transfer.py` | HPC collection UUID (ALCF data portal) | +| `GLOBUS_TRANSFER_DESTINATION_BASE_PATH` | `smoke_globus_transfer.py` | e.g. `/eagle/projects/MyProj/staging` (Polaris), `/flare/projects/MyProj/staging` (Aurora) | +| `COMPUTE_SYSTEM` | `smoke_parsl_in_job.py`, `smoke_ensemble_launcher_in_job.py` | `polaris`, `aurora`, or `crux` | +| `PBS_NODEFILE` | both in-job scripts | Set automatically by PBS inside `qsub` — the scripts abort if missing | +| `CG_SMOKE_DEVICE` | optional, MACE device override | Defaults: `cuda` (Polaris/Globus), `xpu` (Aurora) | + +## Running + +### Laptop only (no creds) + +```bash +source .cg_env/bin/activate +python scripts/smoke/smoke_local.py # ~5s + first-run MACE model download +python scripts/smoke/smoke_local.py --quick # ~3s, skips MACE +``` + +### Laptop → live Globus Compute endpoint + +```bash +export GLOBUS_COMPUTE_ENDPOINT_ID="" +export COMPUTE_SYSTEM=polaris # or aurora +python scripts/smoke/smoke_globus_compute.py +python scripts/smoke/smoke_globus_compute.py --amqp 443 # Aurora (5671 blocked) +``` + +### Laptop → live Globus Transfer + +```bash +export GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID="" +export GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID="" +export GLOBUS_TRANSFER_DESTINATION_BASE_PATH=/eagle/projects/MyProj/staging +python scripts/smoke/smoke_globus_transfer.py # transfer only +python scripts/smoke/smoke_globus_transfer.py --with-mcp # also dispatch MACE ensemble in remote-path mode +``` + +First run triggers an OAuth flow; the token caches at 
+`~/.globus/chemgraph_transfer_tokens.json` for subsequent runs. + +### Inside a PBS allocation on Polaris + +```bash +qsub -I -A -l select=1 -l walltime=01:00:00 -q debug -l filesystems=home:eagle +# (now on the compute node) +module load conda +conda activate base +source ~/chemgraph/venv/bin/activate +export COMPUTE_SYSTEM=polaris +cd ~/chemgraph/ChemGraph + +python scripts/smoke/smoke_parsl_in_job.py +python scripts/smoke/smoke_ensemble_launcher_in_job.py --mode managed +``` + +### Inside a PBS allocation on Aurora + +```bash +qsub -I -A -l select=1,walltime=01:00:00 -q debug -l filesystems=home:flare +module load frameworks +source ~/chemgraph/venv/bin/activate +export COMPUTE_SYSTEM=aurora +cd ~/chemgraph/ChemGraph + +python scripts/smoke/smoke_parsl_in_job.py --device xpu +python scripts/smoke/smoke_ensemble_launcher_in_job.py --mode managed --device xpu +``` + +### Inside a PBS allocation on Crux (CPU-only) + +```bash +qsub -I -A -l select=1 -l walltime=00:30:00 -q debug -l filesystems=home:eagle +cd /lus/eagle/projects/ChemGraph/thang/ChemGraph + +bash scripts/smoke/run_crux_smoke.sh # both backends + MACE on CPU +bash scripts/smoke/run_crux_smoke.sh --quick # skip MACE +bash scripts/smoke/run_crux_smoke.sh --parsl-only +bash scripts/smoke/run_crux_smoke.sh --el-only +``` + +The wrapper activates `.cg_crux_hpc/`, exports `COMPUTE_SYSTEM=crux`, and runs +`smoke_parsl_in_job.py` then `smoke_ensemble_launcher_in_job.py` with +`--device cpu`. It exits non-zero if either backend fails. + +### EnsembleLauncher client-only mode + +Exercises `EnsembleLauncherBackend(client_only=True, ...)` introduced in +commit `bc54083c`. 
Requires two shells on the same compute node: + +```bash +# Shell A — start the orchestrator +cd $PBS_O_WORKDIR +python -m ensemble_launcher \ + --system $COMPUTE_SYSTEM \ + --checkpoint-dir $PBS_O_WORKDIR/el_ckpt \ + --node-id 0 + +# Shell B — connect this client to it +python scripts/smoke/smoke_ensemble_launcher_in_job.py \ + --mode client-only \ + --checkpoint-dir $PBS_O_WORKDIR/el_ckpt \ + --node-id 0 +``` + +The client-only run leaves the orchestrator in Shell A running; stop it +there with `Ctrl-C` when done. + +## See also + +- `scripts/hpc_setup/README.md` — install ChemGraph + Globus Compute endpoint on Polaris/Aurora +- `scripts/hpc_setup/e2e_test_runbook.md` — tier-by-tier manual runbook (these smoke scripts are the automation around Tiers 1, 2, and the gap tests) +- `scripts/globus_compute_example/` — tutorial-style demonstrations (longer-form than the smoke scripts) +- `src/chemgraph/execution/` — the production code paths these scripts call diff --git a/scripts/smoke/_smoke_utils.py b/scripts/smoke/_smoke_utils.py new file mode 100644 index 00000000..2036c689 --- /dev/null +++ b/scripts/smoke/_smoke_utils.py @@ -0,0 +1,125 @@ +"""Shared helpers for the scripts/smoke/* test scripts. + +A tiny PASS/FAIL reporter so every script has the same output shape and +exit code semantics. No external dependencies beyond the stdlib. 
class SmokeReporter:
    """Count and print ``[PASS]`` / ``[FAIL]`` results for one smoke script.

    Construction prints a title banner and starts a wall-clock timer.
    """

    def __init__(self, title: str) -> None:
        self.title = title
        self.passed = 0
        self.failed = 0
        self._t0 = time.monotonic()
        print(f"\n=== {title} ===")

    @contextmanager
    def check(self, name: str):
        """Run the ``with`` body as one named check.

        Any ``Exception`` raised by the body is recorded as a failure, printed
        with its traceback, and suppressed so later checks still run. Other
        ``BaseException`` subclasses (e.g. ``SystemExit``) propagate without
        being counted.
        """
        began = time.monotonic()
        try:
            yield
        except Exception as err:
            took = time.monotonic() - began
            self.failed += 1
            print(f"[FAIL] {name} ({took:.1f}s): {type(err).__name__}: {err}")
            traceback.print_exc()
            return
        took = time.monotonic() - began
        self.passed += 1
        print(f"[PASS] {name} ({took:.1f}s)")

    def summary_and_exit(self) -> None:
        """Print the pass/fail totals and exit: 0 if nothing failed, else 1."""
        ran = self.passed + self.failed
        wall = time.monotonic() - self._t0
        print(
            f"\n--- {self.title}: {self.passed}/{ran} passed, "
            f"{self.failed} failed ({wall:.1f}s total) ---"
        )
        exit_code = 1 if self.failed else 0
        sys.exit(exit_code)


def require_env(*names: str) -> dict[str, str]:
    """Return ``{name: value}`` for the listed environment variables.

    If any of them is unset or empty, print a ``[SKIP]`` notice naming the
    missing ones and exit with status 2.
    """
    import os

    found: dict[str, str] = {}
    absent: list[str] = []
    for var in names:
        value = os.environ.get(var)
        if value:
            found[var] = value
        else:
            absent.append(var)
    if absent:
        print(f"[SKIP] Missing required env vars: {', '.join(absent)}")
        print(" Export them and re-run.")
        sys.exit(2)
    return found


def water_xyz_path() -> Path:
    """Return the absolute path of ``water.xyz`` next to this module."""
    return Path(__file__).resolve().with_name("water.xyz")


# ── module-level helpers picklable across process / globus boundaries ──


def trivial_add(a: int, b: int) -> int:
    """Return ``a + b``."""
    total = a + b
    return total


def trivial_square(x: int) -> int:
    """Return ``x`` multiplied by itself."""
    squared = x * x
    return squared


def trivial_hostname() -> str:
    """Return the hostname of the machine executing this call."""
    from socket import gethostname

    return gethostname()


def trivial_env_probe() -> dict:
    """Collect basic facts about the executing process.

    Always reports hostname, Python version, pid, cwd and CPU affinity
    (``None`` where ``os.sched_getaffinity`` is unavailable or fails). Torch
    details are best effort: any exception while importing or querying torch
    is stored under ``"torch_error"`` instead of being raised.
    """
    import os
    import socket
    import sys

    report: dict = {
        "hostname": socket.gethostname(),
        "python": sys.version.split()[0],
        "pid": os.getpid(),
        "cwd": os.getcwd(),
    }

    try:
        affinity = sorted(os.sched_getaffinity(0))
    except (AttributeError, OSError):
        affinity = None
    report["sched_affinity"] = affinity

    try:
        import torch

        report["torch"] = torch.__version__
        if torch.cuda.is_available():
            report["cuda_devices"] = torch.cuda.device_count()
        else:
            report["cuda_devices"] = 0
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            report["xpu_devices"] = torch.xpu.device_count()
        else:
            report["xpu_devices"] = 0
    except Exception as exc:
        # Deliberately broad: the probe must never fail because of torch.
        report["torch_error"] = str(exc)
    return report


def ensure_on_worker_pythonpath() -> None:
    """Prepend this file's directory to the ``PYTHONPATH`` environment variable.

    Intended to let worker processes ``import _smoke_utils`` when unpickling
    tasks. Call it in the main process before creating a backend. Does
    nothing when the directory is already listed.
    """
    import os

    module_dir = str(Path(__file__).resolve().parent)
    current = os.environ.get("PYTHONPATH", "")
    entries = current.split(os.pathsep) if current else []
    if module_dir in entries:
        return
    os.environ["PYTHONPATH"] = os.pathsep.join([module_dir, *entries])
+ +VENV_ACTIVATE="$REPO_ROOT/.cg_crux_hpc/bin/activate" +[[ -f "$VENV_ACTIVATE" ]] || abort "Missing venv activate script: $VENV_ACTIVATE" + +if [[ "${VIRTUAL_ENV:-}" != "$REPO_ROOT/.cg_crux_hpc" ]]; then + module load conda 2>/dev/null || true + # shellcheck disable=SC1090 + source "$VENV_ACTIVATE" +fi + +export COMPUTE_SYSTEM=crux +RUN_DIR="${PBS_O_WORKDIR:-$PWD}/parsl_runs_smoke_crux" +mkdir -p "$RUN_DIR" + +echo "REPO_ROOT=$REPO_ROOT" +echo "VIRTUAL_ENV=${VIRTUAL_ENV:-}" +echo "PBS_NODEFILE=$PBS_NODEFILE ($(wc -l <"$PBS_NODEFILE") node(s))" +echo "RUN_DIR=$RUN_DIR" +echo + +parsl_rc=0 +el_rc=0 + +if (( RUN_PARSL )); then + echo "=== Parsl smoke (system=crux, device=cpu) ===" + python "$REPO_ROOT/scripts/smoke/smoke_parsl_in_job.py" \ + --system crux --device cpu --run-dir "$RUN_DIR" $QUICK \ + || parsl_rc=$? + echo +fi + +if (( RUN_EL )); then + echo "=== EnsembleLauncher smoke (managed, system=crux, device=cpu) ===" + python "$REPO_ROOT/scripts/smoke/smoke_ensemble_launcher_in_job.py" \ + --mode managed --system crux --device cpu $QUICK \ + || el_rc=$? + echo +fi + +verdict() { (( $1 == 0 )) && echo PASS || echo "FAIL(rc=$1)"; } +echo "=== Summary ===" +(( RUN_PARSL )) && echo "parsl = $(verdict $parsl_rc)" +(( RUN_EL )) && echo "el = $(verdict $el_rc)" + +(( parsl_rc > el_rc )) && exit "$parsl_rc" || exit "$el_rc" diff --git a/scripts/smoke/smoke_ensemble_launcher_in_job.py b/scripts/smoke/smoke_ensemble_launcher_in_job.py new file mode 100644 index 00000000..250e6ba1 --- /dev/null +++ b/scripts/smoke/smoke_ensemble_launcher_in_job.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +"""Smoke test for EnsembleLauncherBackend on an HPC compute node. + +Must run **inside** a PBS interactive allocation on Polaris or Aurora, +in a venv where ``ensemble_launcher`` is installed (it is built from +source by ``scripts/hpc_setup/install_remote.sh`` -- PyPI wheels only +support Python <3.12). 
def _abort(msg: str) -> None:
    """Print an ``[ABORT]`` line and terminate the script with exit status 2."""
    print(f"[ABORT] {msg}")
    sys.exit(2)


def _wait_for_checkpoint(checkpoint_dir: Path, timeout: float) -> None:
    """Wait until the orchestrator has written something to checkpoint_dir.

    The exact ready-marker shape depends on the ensemble_launcher
    version; we just wait for the directory to be non-empty.

    The directory is always checked at least once, so ``timeout <= 0`` means
    "check now, do not wait". Polls about once per second and never sleeps
    past the deadline. Calls ``_abort`` (exit status 2) if the directory is
    still missing or empty when the deadline passes.
    """
    deadline = time.monotonic() + timeout
    while True:
        if checkpoint_dir.is_dir() and any(checkpoint_dir.iterdir()):
            return
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Cap the nap so a short timeout is not stretched to a full second.
        time.sleep(min(1.0, remaining))
    _abort(
        f"No checkpoint files appeared under {checkpoint_dir} within {timeout}s. "
        "Start the orchestrator in another shell first; see --help."
    )
+ ) + + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.config import get_backend + + r = SmokeReporter( + f"smoke_ensemble_launcher_in_job (mode={args.mode}, system={system})" + ) + backend = None + + if args.mode == "managed": + with r.check("get_backend(ensemble_launcher, managed) initialises"): + backend = get_backend(backend_name="ensemble_launcher", system=system) + assert backend is not None + else: + if not args.checkpoint_dir: + _abort("--mode client-only requires --checkpoint-dir.") + ckpt = Path(args.checkpoint_dir).resolve() + with r.check( + f"orchestrator checkpoint dir is populated ({ckpt})" + ): + _wait_for_checkpoint(ckpt, args.wait_timeout) + with r.check("get_backend(ensemble_launcher, client_only=True) connects"): + backend = get_backend( + backend_name="ensemble_launcher", + system=system, + client_only=True, + checkpoint_dir=str(ckpt), + node_id=args.node_id, + ) + assert backend is not None + + if backend is None: + r.summary_and_exit() + return + + with r.check("python TaskSpec returns correct result"): + fut = backend.submit( + TaskSpec( + task_id="el-py", + task_type="python", + callable=trivial_square, + args=(11,), + ) + ) + assert fut.result(timeout=180) == 121 + + with r.check("python TaskSpec ran on a compute node"): + fut = backend.submit( + TaskSpec( + task_id="el-host", + task_type="python", + callable=trivial_hostname, + ) + ) + host = fut.result(timeout=180) + print(f" EL worker hostname = {host!r}") + + with r.check("shell TaskSpec runs"): + fut = backend.submit( + TaskSpec( + task_id="el-sh", + task_type="shell", + command="echo smoke_el_shell_ok", + ) + ) + # EL shell-task return shape depends on the version; just assert + # the future resolves without raising. 
+ fut.result(timeout=180) + + with r.check("submit_batch of 3 python tasks all resolve"): + futures = backend.submit_batch( + [ + TaskSpec( + task_id=f"el-batch-{i}", + task_type="python", + callable=trivial_add, + args=(i, 50), + ) + for i in range(3) + ] + ) + results = [f.result(timeout=240) for f in futures] + assert results == [50, 51, 52], results + + if not args.quick: + with r.check(f"MACE geometry opt on water (device={device}, converged)"): + from ase.io import read as ase_read + + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(str(water_xyz_path())) + inline = atoms_to_atomsdata(atoms).model_dump() + job = { + "input_structure_file": "ignored_by_inline_path", + "output_result_file": "water_smoke_el.json", + "driver": "opt", + "model": "medium-mpa-0", + "device": device, + "temperature": 298.15, + "pressure": 101325.0, + "fmax": 0.01, + "steps": 100, + "optimizer": "lbfgs", + "inline_structure": inline, + } + fut = backend.submit( + TaskSpec( + task_id="el-mace-opt", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + ) + out = fut.result(timeout=900) + assert out.get("status") == "success", f"opt failed: {out}" + energy = next( + (out[k] for k in ("single_point_energy", "energy", "final_energy") if k in out), + None, + ) + assert energy is not None and energy < 0, f"bad MACE result: {out}" + full = out.get("full_output") or {} + if full: + assert full.get("converged") is True, f"opt did not converge: {full.get('converged')!r}" + print( + f" water opt energy = {energy:.6f} eV " + f"(converged={full.get('converged')}, wall={full.get('wall_time')}s)" + ) + else: + print( + f" water opt energy = {energy:.6f} eV " + "(WARNING: full_output missing; convergence not verified inline)" + ) + + with r.check("backend.shutdown() is clean"): + if args.mode == "managed": + backend.shutdown() + else: + # In client-only mode, shutdown should NOT stop the 
orchestrator + # the user started -- it should only disconnect this client. + backend.shutdown() + print(" (client-only: orchestrator left running in the other shell)") + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/smoke_globus_compute.py b/scripts/smoke/smoke_globus_compute.py new file mode 100644 index 00000000..80226972 --- /dev/null +++ b/scripts/smoke/smoke_globus_compute.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python +"""Smoke test for the GlobusComputeBackend. + +Drives the production execution layer against a live Globus Compute +endpoint. Exits 0 on success, nonzero on any failure. + +Prereqs (env vars) +------------------ +- GLOBUS_COMPUTE_ENDPOINT_ID -- UUID printed by ``globus-compute-endpoint start``. +- (optional) COMPUTE_SYSTEM -- "polaris" or "aurora" (used for logging only). + +Run:: + + export GLOBUS_COMPUTE_ENDPOINT_ID="" + python scripts/smoke/smoke_globus_compute.py + python scripts/smoke/smoke_globus_compute.py --quick # skip MACE + python scripts/smoke/smoke_globus_compute.py --amqp 443 # firewalled networks +""" + +from __future__ import annotations + +import argparse +import os + +from _smoke_utils import ( + SmokeReporter, + require_env, + trivial_add, + trivial_env_probe, + trivial_hostname, +) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--quick", + action="store_true", + help="Skip MACE inference (Globus model download on remote endpoint is slow on first run).", + ) + parser.add_argument( + "--amqp", + type=int, + default=None, + help="AMQP port override. 
Set to 443 when outbound 5671 is blocked (Aurora).", + ) + args = parser.parse_args() + + require_env("GLOBUS_COMPUTE_ENDPOINT_ID") + + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.config import get_backend + + backend_kwargs: dict = {} + if args.amqp is not None: + backend_kwargs["amqp_port"] = args.amqp + + r = SmokeReporter( + f"smoke_globus_compute (system={os.environ.get('COMPUTE_SYSTEM', '?')}, " + f"endpoint={os.environ['GLOBUS_COMPUTE_ENDPOINT_ID'][:8]}...)" + ) + backend = None + local_hostname = trivial_hostname() + + with r.check("get_backend(globus_compute) initialises"): + backend = get_backend(backend_name="globus_compute", **backend_kwargs) + assert backend is not None + + if backend is None: + r.summary_and_exit() + return + + with r.check("check_endpoint_status() reports online"): + status = backend.check_endpoint_status() + # The SDK returns either a dict like {"status": "online"} or a + # string; both shapes count as healthy if "online" appears in the + # repr. "error" status means we cannot reach the endpoint. 
+ s = status.get("status") + assert s != "error", f"endpoint unreachable: {status}" + s_repr = str(s).lower() + assert "online" in s_repr or "ok" in s_repr or "running" in s_repr, ( + f"endpoint not online: {status}" + ) + print(f" endpoint status: {status}") + + with r.check("python TaskSpec (trivial_add) round-trips through Globus"): + fut = backend.submit( + TaskSpec( + task_id="gc-add", + task_type="python", + callable=trivial_add, + args=(40, 2), + ) + ) + result = fut.result(timeout=300) + assert result == 42, f"expected 42, got {result!r}" + + with r.check("python TaskSpec ran on the HPC node (hostname differs from laptop)"): + fut = backend.submit( + TaskSpec( + task_id="gc-host", + task_type="python", + callable=trivial_hostname, + ) + ) + remote_host = fut.result(timeout=300) + assert isinstance(remote_host, str) and remote_host, "empty hostname" + assert remote_host != local_hostname, ( + f"task ran on the laptop ({remote_host}), not the endpoint!" + ) + print(f" local={local_hostname!r} remote={remote_host!r}") + + with r.check("env probe: torch + accelerators visible on worker"): + fut = backend.submit( + TaskSpec( + task_id="gc-env", + task_type="python", + callable=trivial_env_probe, + ) + ) + info = fut.result(timeout=300) + assert isinstance(info, dict) + print(f" worker env: {info}") + + with r.check("shell TaskSpec returns SDK ShellResult"): + fut = backend.submit( + TaskSpec( + task_id="gc-sh", + task_type="shell", + command="echo smoke_globus_compute_shell_ok && hostname", + ) + ) + sh = fut.result(timeout=300) + # ShellFunction returns a ShellResult object with .stdout + stdout = getattr(sh, "stdout", str(sh)) + assert "smoke_globus_compute_shell_ok" in stdout, f"unexpected stdout: {stdout!r}" + print(f" remote shell stdout (truncated): {stdout[:120]!r}") + + with r.check("submit_batch of 3 python tasks all resolve"): + futures = backend.submit_batch( + [ + TaskSpec( + task_id=f"gc-batch-{i}", + task_type="python", + callable=trivial_add, + 
args=(i, 10), + ) + for i in range(3) + ] + ) + results = [f.result(timeout=300) for f in futures] + assert results == [10, 11, 12], f"expected [10,11,12], got {results}" + + if not args.quick: + with r.check("MACE geometry opt on water runs on Globus Compute (converged)"): + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + + # Worker pulls the structure from its own filesystem. Since + # the laptop's water.xyz is not on the HPC node, embed it + # inline the same way the pre-submit hook would. The + # ``full_output`` key in the result carries the on-disk JSON + # back to us (mace_mcp_hpc._mace_worker, lines 127-131) so + # we can check converged without a follow-up transfer. + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + from _smoke_utils import water_xyz_path + + atoms = ase_read(str(water_xyz_path())) + inline = atoms_to_atomsdata(atoms).model_dump() + + job = { + "input_structure_file": "ignored_by_inline_path", + "output_result_file": "water_smoke_gc.json", + "driver": "opt", + "model": "medium-mpa-0", + "device": os.environ.get("CG_SMOKE_DEVICE", "cuda"), + "temperature": 298.15, + "pressure": 101325.0, + "fmax": 0.01, + "steps": 100, + "optimizer": "lbfgs", + "inline_structure": inline, + } + fut = backend.submit( + TaskSpec( + task_id="gc-mace-water-opt", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + ) + # First MACE run on the endpoint downloads the model + opt loop. 
+ mace_out = fut.result(timeout=6000) + assert isinstance(mace_out, dict), type(mace_out) + assert mace_out.get("status") == "success", f"opt failed: {mace_out}" + energy = next( + (mace_out[k] for k in ("single_point_energy", "energy", "final_energy") if k in mace_out), + None, + ) + assert energy is not None and energy < 0, f"bad energy: {mace_out}" + + full = mace_out.get("full_output") or {} + if full: + assert full.get("converged") is True, f"opt did not converge: {full.get('converged')!r}" + assert full.get("success") is True, f"opt success=False: {full}" + print( + f" remote opt energy = {energy:.6f} eV " + f"(converged={full.get('converged')}, wall={full.get('wall_time')}s)" + ) + else: + # full_output is attached by _mace_worker only when inline_structure + # is set; we always pass inline above so this branch should not hit. + print( + f" remote opt energy = {energy:.6f} eV " + "(WARNING: full_output not returned; convergence not verified)" + ) + + with r.check("backend.shutdown() is clean"): + backend.shutdown() + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/smoke_globus_transfer.py b/scripts/smoke/smoke_globus_transfer.py new file mode 100644 index 00000000..5126c60c --- /dev/null +++ b/scripts/smoke/smoke_globus_transfer.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +"""Smoke test for GlobusTransferManager (+ optional MCP integration). + +Exercises the production transfer layer from the laptop. Exits 0 on +success, nonzero on any failure. + +Prereqs (env vars) +------------------ +- GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID -- local Globus collection UUID +- GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID -- HPC collection UUID +- GLOBUS_TRANSFER_DESTINATION_BASE_PATH -- e.g. /eagle/projects/MyProj/staging +- (for --with-mcp): GLOBUS_COMPUTE_ENDPOINT_ID and HPC venv with MACE + +First run triggers a Globus OAuth flow. Token caches at +~/.globus/chemgraph_transfer_tokens.json. 
+ +Run:: + + python scripts/smoke/smoke_globus_transfer.py + python scripts/smoke/smoke_globus_transfer.py --keep-remote # don't delete after + python scripts/smoke/smoke_globus_transfer.py --with-mcp # also exercise MCP ensemble in remote mode +""" + +from __future__ import annotations + +import argparse +import os +import time + +from _smoke_utils import SmokeReporter, require_env, water_xyz_path + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--with-mcp", + action="store_true", + help="Also exercise mace_mcp_hpc.run_mace_ensemble(remote_structure_directory=...).", + ) + parser.add_argument( + "--keep-remote", + action="store_true", + help="Don't attempt to delete the staged remote directory at the end.", + ) + parser.add_argument( + "--timeout", + type=float, + default=6000.0, + help="Per-transfer timeout in seconds (default 6000).", + ) + args = parser.parse_args() + + require_env( + "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID", + "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID", + "GLOBUS_TRANSFER_DESTINATION_BASE_PATH", + ) + if args.with_mcp: + require_env("GLOBUS_COMPUTE_ENDPOINT_ID") + + from chemgraph.execution.config import get_transfer_manager + + r = SmokeReporter("smoke_globus_transfer") + mgr = None + transfer_result = None + + with r.check("get_transfer_manager() returns a configured manager"): + mgr = get_transfer_manager() + assert mgr is not None, ( + "get_transfer_manager returned None -- check env vars are exported." 
+ ) + + if mgr is None: + r.summary_and_exit() + return + + with r.check("transfer_files(water.xyz) submits a Globus Transfer task"): + xyz = water_xyz_path() + assert xyz.is_file(), f"fixture missing: {xyz}" + transfer_result = mgr.transfer_files( + local_paths=[str(xyz)], + label=f"chemgraph-smoke-{int(time.time())}", + ) + assert transfer_result.task_id, "no task_id returned" + print(f" task_id = {transfer_result.task_id}") + print(f" remote_dir = {transfer_result.remote_directory}") + + with r.check(f"wait_for_transfer(timeout={args.timeout}s) reaches SUCCEEDED"): + assert transfer_result is not None + status = mgr.wait_for_transfer( + transfer_result.task_id, + timeout=args.timeout, + poll_interval=5, + ) + assert status.get("status") == "SUCCEEDED", f"final status: {status}" + assert status.get("files_transferred", 0) >= 1, status + print( + f" transferred {status['files_transferred']}/{status['files']} files, " + f"{status['bytes_transferred']} bytes" + ) + + with r.check("check_transfer_status() returns SUCCEEDED for completed task"): + assert transfer_result is not None + status = mgr.check_transfer_status(transfer_result.task_id) + assert status["status"] == "SUCCEEDED", status + + with r.check("list_remote_directory() finds the staged file"): + assert transfer_result is not None + entries = mgr.list_remote_directory(transfer_result.remote_directory) + names = {e["name"] for e in entries} + assert "water.xyz" in names, f"water.xyz not in {names!r}" + size = next((e["size"] for e in entries if e["name"] == "water.xyz"), 0) + print(f" remote water.xyz size = {size} bytes") + + if args.with_mcp: + with r.check("MCP run_mace_ensemble(remote_structure_directory=...) succeeds"): + # Drive the MCP server's tool function directly (in-process) -- + # the heavy work is dispatched to Globus Compute by the + # backend that mcp.init_backend() configured. 
+ from chemgraph.mcp.mace_mcp_hpc import ( + _expand_mace_ensemble, + _mace_worker, + mcp, + ) + from chemgraph.execution.base import TaskSpec + from chemgraph.schemas.mace_parsl_schema import ( + mace_input_schema_ensemble, + ) + + # Init the MCP server's backend (reads CHEMGRAPH_EXECUTION_BACKEND + # / GLOBUS_COMPUTE_ENDPOINT_ID exactly like the prod server does). + os.environ.setdefault("CHEMGRAPH_EXECUTION_BACKEND", "globus_compute") + mcp.init_backend() + try: + params = mace_input_schema_ensemble( + remote_structure_directory=transfer_result.remote_directory, + output_result_file="water_smoke_tr.json", + driver="opt", + model="medium-mpa-0", + device=os.environ.get("CG_SMOKE_DEVICE", "cuda"), + ) + jobs = _expand_mace_ensemble(params) + assert jobs, "no jobs expanded from remote dir" + assert all("remote_structure_file" in j for j in jobs), jobs[0] + # Submit each job through the same backend the MCP server uses. + futures = [ + mcp._backend.submit( + TaskSpec( + task_id=f"tr-mace-opt-{i}", + task_type="python", + callable=_mace_worker, + kwargs={"job": j}, + ) + ) + for i, j in enumerate(jobs) + ] + results = [f.result(timeout=6000) for f in futures] + assert all(isinstance(r, dict) for r in results), results + assert all(r.get("status") == "success" for r in results), [ + r.get("status") for r in results + ] + energies = [ + next( + (r[k] for k in ("single_point_energy", "energy", "final_energy") if k in r), + None, + ) + for r in results + ] + assert all(e is not None and e < 0 for e in energies), results + # Remote-path mode does NOT attach full_output (only the + # inline-structure path does -- see mace_mcp_hpc._mace_worker + # lines 127-131). Convergence can be verified after the fact + # by reading the per-structure JSON on the remote filesystem + # (e.g. via Globus Transfer back to the laptop) -- out of + # scope for this smoke test. 
+ print(f" remote MACE opt energies (eV): {energies}") + finally: + mcp.shutdown_backend() + + if not args.keep_remote and transfer_result is not None: + print( + f"\nNOTE: staged directory left in place at {transfer_result.remote_directory}\n" + " (the manager does not implement remote deletion). " + "Clean it up manually if needed." + ) + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/smoke_local.py b/scripts/smoke/smoke_local.py new file mode 100644 index 00000000..6058638e --- /dev/null +++ b/scripts/smoke/smoke_local.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python +"""Smoke test for the LocalBackend. + +Drives the production execution layer end-to-end on the laptop with no +HPC and no credentials. Exits 0 on success, nonzero on any failure. + +Checks +------ +1. ``get_backend(backend_name="local")`` initialises cleanly. +2. Python TaskSpec round-trip (callable returns correct result). +3. Shell TaskSpec round-trip (exit code 0). +4. ``submit_batch`` of three tasks all resolve. +5. ``JobTracker`` register_batch / get_status / get_results round-trip. +6. MACE worker path: build a job dict for ``water.xyz`` and submit it to + the local backend exactly as ``mace_mcp_hpc._mace_transport_hook`` would. 
+ +Run:: + + python scripts/smoke/smoke_local.py + python scripts/smoke/smoke_local.py --quick # skip the MACE check +""" + +from __future__ import annotations + +import argparse + +from _smoke_utils import ( + SmokeReporter, + trivial_add, + trivial_square, + water_xyz_path, +) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--quick", + action="store_true", + help="Skip MACE inference (saves ~30s on first run downloading the model).", + ) + args = parser.parse_args() + + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.config import get_backend + + r = SmokeReporter("smoke_local") + backend = None + + with r.check("get_backend(local) initialises"): + backend = get_backend(backend_name="local", system="local") + assert backend is not None, "backend is None" + + if backend is None: + r.summary_and_exit() + return + + with r.check("python TaskSpec returns correct result"): + fut = backend.submit( + TaskSpec( + task_id="py-1", + task_type="python", + callable=trivial_square, + args=(7,), + ) + ) + result = fut.result(timeout=30) + assert result == 49, f"expected 49, got {result!r}" + + with r.check("shell TaskSpec exits 0"): + fut = backend.submit( + TaskSpec( + task_id="sh-1", + task_type="shell", + command="echo smoke_local_shell_ok", + ) + ) + rc = fut.result(timeout=30) + assert rc == 0, f"expected exit 0, got {rc!r}" + + with r.check("submit_batch of 3 python tasks resolve"): + futures = backend.submit_batch( + [ + TaskSpec( + task_id=f"batch-{i}", + task_type="python", + callable=trivial_add, + args=(i, i + 1), + ) + for i in range(3) + ] + ) + results = [f.result(timeout=30) for f in futures] + assert results == [1, 3, 5], f"expected [1,3,5], got {results}" + + with r.check("JobTracker register_batch / get_results round-trip"): + from chemgraph.execution.job_tracker import JobTracker + + tracker = JobTracker() + fut = backend.submit( + TaskSpec( + task_id="tracked-1", + 
task_type="python", + callable=trivial_square, + args=(6,), + ) + ) + batch_id = tracker.register_batch( + tool_name="smoke_local", + pending_tasks=[({"task_id": "tracked-1"}, fut)], + ) + # Block on the future then ask the tracker for results. + fut.result(timeout=30) + out = tracker.get_results(batch_id) + assert out["status"] == "completed", f"status={out.get('status')}" + assert out["results"][0]["result"] == 36, out["results"] + + if not args.quick: + with r.check("MACE geometry opt: water.xyz on local backend (converged)"): + import json + + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + + xyz = water_xyz_path() + assert xyz.is_file(), f"fixture missing: {xyz}" + out_json = xyz.parent / "water_smoke_output.json" + job = { + "input_structure_file": str(xyz), + "output_result_file": str(out_json), + "driver": "opt", + "model": "medium-mpa-0", + "device": "cpu", + "temperature": 298.15, + "pressure": 101325.0, + "fmax": 0.01, + "steps": 100, + "optimizer": "lbfgs", + } + # Submit through the backend (not in-process) to prove the + # submission pipeline serializes the worker callable and the + # arg dict correctly. + fut = backend.submit( + TaskSpec( + task_id="mace-water-opt", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + ) + # First MACE run downloads the model; allow generous timeout. 
+ mace_out = fut.result(timeout=600) + assert isinstance(mace_out, dict), f"non-dict result: {type(mace_out)}" + assert mace_out.get("status") == "success", f"opt failed: {mace_out}" + energy = next( + (mace_out[k] for k in ("single_point_energy", "energy", "final_energy") if k in mace_out), + None, + ) + assert energy is not None, f"no energy in result keys={list(mace_out)}" + assert energy < 0, f"water energy should be negative, got {energy}" + + assert out_json.is_file(), f"opt output JSON not written: {out_json}" + with open(out_json) as fh: + full = json.load(fh) + assert full.get("converged") is True, f"opt did not converge: {full.get('converged')!r}" + assert full.get("success") is True, f"opt success=False: {full}" + print( + f" water opt energy = {energy:.6f} eV " + f"(converged={full.get('converged')}, wall={full.get('wall_time')}s)" + ) + + with r.check("backend.shutdown() is clean"): + backend.shutdown() + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/smoke_parsl_in_job.py b/scripts/smoke/smoke_parsl_in_job.py new file mode 100644 index 00000000..ee426162 --- /dev/null +++ b/scripts/smoke/smoke_parsl_in_job.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python +"""Smoke test for ParslBackend on an HPC compute node. + +Must run **inside** a PBS interactive allocation on Polaris or Aurora:: + + # Polaris + qsub -I -A -l select=1 -l walltime=01:00:00 -q debug + # Aurora + qsub -I -A -l select=1,walltime=01:00:00 -q debug -l filesystems=home:flare + +Inside the allocation:: + + module load conda # or `module load frameworks` on Aurora + source /bin/activate + export COMPUTE_SYSTEM=polaris # or aurora + python scripts/smoke/smoke_parsl_in_job.py + python scripts/smoke/smoke_parsl_in_job.py --quick + python scripts/smoke/smoke_parsl_in_job.py --device xpu # Aurora + +The script fails fast with a clear message if PBS_NODEFILE is missing. 
+""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +from _smoke_utils import ( + SmokeReporter, + ensure_on_worker_pythonpath, + trivial_add, + trivial_env_probe, + trivial_hostname, + trivial_square, + water_xyz_path, +) + +ensure_on_worker_pythonpath() + + +def _abort(msg: str) -> None: + print(f"[ABORT] {msg}") + sys.exit(2) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--system", + default=os.environ.get("COMPUTE_SYSTEM"), + help="polaris | aurora (default: COMPUTE_SYSTEM env var)", + ) + parser.add_argument( + "--device", + default=None, + help="MACE device: cuda (Polaris default), xpu (Aurora), or cpu.", + ) + parser.add_argument( + "--run-dir", + default=None, + help="Parsl run_dir (default: $PBS_O_WORKDIR/parsl_runs or ./parsl_runs).", + ) + parser.add_argument( + "--quick", + action="store_true", + help="Skip MACE inference.", + ) + args = parser.parse_args() + + pbs_nodefile = os.environ.get("PBS_NODEFILE") + if not pbs_nodefile or not Path(pbs_nodefile).is_file(): + _abort( + "PBS_NODEFILE not set or missing. This script must run inside a " + "PBS interactive allocation (qsub -I ...)." 
+ ) + + if not args.system: + _abort("COMPUTE_SYSTEM env var not set and --system not given.") + system = args.system.lower().strip() + if system not in ("polaris", "aurora", "crux"): + _abort(f"Unsupported --system: {system!r} (expected polaris|aurora|crux)") + + if args.device: + device = args.device + elif system == "aurora": + device = "xpu" + elif system == "crux": + device = "cpu" + else: + device = "cuda" + nodes = Path(pbs_nodefile).read_text().splitlines() + + run_dir = args.run_dir or os.environ.get("PBS_O_WORKDIR") + if run_dir: + run_dir = str(Path(run_dir) / "parsl_runs_smoke") + else: + run_dir = str(Path.cwd() / "parsl_runs_smoke") + Path(run_dir).mkdir(parents=True, exist_ok=True) + + print(f"system={system} device={device} nodes={len(nodes)} run_dir={run_dir}") + + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.config import get_backend + + r = SmokeReporter(f"smoke_parsl_in_job (system={system}, nodes={len(nodes)})") + backend = None + + with r.check("get_backend(parsl) initialises with HPC config"): + backend = get_backend( + backend_name="parsl", + system=system, + run_dir=run_dir, + ) + assert backend is not None + + if backend is None: + r.summary_and_exit() + return + + with r.check("python TaskSpec returns correct result"): + fut = backend.submit( + TaskSpec( + task_id="p-py", + task_type="python", + callable=trivial_square, + args=(9,), + ) + ) + assert fut.result(timeout=120) == 81 + + with r.check("python TaskSpec ran on a compute node (hostname != login)"): + fut = backend.submit( + TaskSpec( + task_id="p-host", + task_type="python", + callable=trivial_hostname, + ) + ) + host = fut.result(timeout=120) + print(f" parsl worker hostname = {host!r}") + assert isinstance(host, str) and host + + with r.check("worker env: torch + accelerators visible"): + fut = backend.submit( + TaskSpec( + task_id="p-env", + task_type="python", + callable=trivial_env_probe, + ) + ) + info = fut.result(timeout=120) + print(f" worker 
env: {info}") + # Polaris should show cuda; Aurora should show xpu; Crux is CPU-only. + if system == "polaris": + assert info.get("cuda_devices", 0) >= 1, info + elif system == "aurora": + assert info.get("xpu_devices", 0) >= 1, info + # Crux: CPU-only; no accelerator assertion. + + with r.check("shell TaskSpec exits 0"): + fut = backend.submit( + TaskSpec( + task_id="p-sh", + task_type="shell", + command="echo smoke_parsl_shell_ok && hostname", + ) + ) + rc = fut.result(timeout=120) + assert rc == 0, f"exit code = {rc}" + + with r.check("submit_batch of 4 python tasks all resolve"): + futures = backend.submit_batch( + [ + TaskSpec( + task_id=f"p-batch-{i}", + task_type="python", + callable=trivial_add, + args=(i, 100), + ) + for i in range(4) + ] + ) + results = [f.result(timeout=180) for f in futures] + assert results == [100, 101, 102, 103], results + + if not args.quick: + with r.check(f"MACE geometry opt on water (device={device}, converged)"): + from ase.io import read as ase_read + + from chemgraph.mcp.mace_mcp_hpc import _mace_worker + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(str(water_xyz_path())) + inline = atoms_to_atomsdata(atoms).model_dump() + job = { + "input_structure_file": "ignored_by_inline_path", + "output_result_file": "water_smoke_parsl.json", + "driver": "opt", + "model": "medium-mpa-0", + "device": device, + "temperature": 298.15, + "pressure": 101325.0, + "fmax": 0.01, + "steps": 100, + "optimizer": "lbfgs", + "inline_structure": inline, + } + fut = backend.submit( + TaskSpec( + task_id="p-mace-opt", + task_type="python", + callable=_mace_worker, + kwargs={"job": job}, + ) + ) + out = fut.result(timeout=900) + assert out.get("status") == "success", f"opt failed: {out}" + energy = next( + (out[k] for k in ("single_point_energy", "energy", "final_energy") if k in out), + None, + ) + assert energy is not None and energy < 0, f"bad MACE result: {out}" + full = out.get("full_output") or {} + if full: + assert 
full.get("converged") is True, f"opt did not converge: {full.get('converged')!r}" + print( + f" water opt energy = {energy:.6f} eV " + f"(converged={full.get('converged')}, wall={full.get('wall_time')}s)" + ) + else: + print( + f" water opt energy = {energy:.6f} eV " + "(WARNING: full_output missing; convergence not verified inline)" + ) + + with r.check("backend.shutdown() is clean"): + backend.shutdown() + + r.summary_and_exit() + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke/water.xyz b/scripts/smoke/water.xyz new file mode 100644 index 00000000..baec6e18 --- /dev/null +++ b/scripts/smoke/water.xyz @@ -0,0 +1,5 @@ +3 +water molecule +O 0.0000000 0.0000000 0.0000000 +H 0.7570000 0.5860000 0.0000000 +H -0.7570000 0.5860000 0.0000000 diff --git a/src/chemgraph/academy/__init__.py b/src/chemgraph/academy/__init__.py new file mode 100644 index 00000000..46c0964b --- /dev/null +++ b/src/chemgraph/academy/__init__.py @@ -0,0 +1,95 @@ +"""Academy Agents integration for ChemGraph. + +Public re-exports come in two tiers so the package honours the +``[academy]`` optional-dep contract: + +* **Eager** (pure stdlib + pydantic, always importable): + ``ChemGraphAgentSpec``, ``ChemGraphCampaign``, + ``ChemGraphDaemonConfig``, ``MCPServerSpec``, ``ResourceSpec``, + ``load_campaign``, ``resolve_campaign_resources``, + ``PromptProfile``, ``load_prompt_profile``, ``CampaignEvent``, + ``EventLog``. These let the dashboard, ``--trace-dir``, and the + observability tooling work on a checkout without ``academy-py`` + installed. +* **Lazy** (resolved via ``__getattr__`` on first access; requires + the ``[academy]`` extra): ``ChemGraphLogicalAgent``. Importing it + pulls in ``academy.agent``; without the extra installed, access + raises ``ImportError`` with a hint instead of crashing the package + import. 
+ +This split exists because ``chemgraph.cli.trace`` (single-agent +``--trace-dir`` flow) and the test collector both touch +``chemgraph.academy`` via leaf submodules; eager-importing the +academy-py-dependent ``ChemGraphLogicalAgent`` here broke those code +paths for users without the optional extra. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.core.campaign import MCPServerSpec +from chemgraph.academy.core.campaign import ResourceSpec +from chemgraph.academy.core.campaign import load_campaign +from chemgraph.academy.core.campaign import resolve_campaign_resources +from chemgraph.academy.core.prompt import PromptProfile +from chemgraph.academy.core.prompt import load_prompt_profile +from chemgraph.academy.observability.event_log import CampaignEvent +from chemgraph.academy.observability.event_log import EventLog + + +if TYPE_CHECKING: + from chemgraph.academy.core.agent import ChemGraphLogicalAgent + + +_LAZY_EXPORTS: dict[str, tuple[str, str]] = { + # public name -> (module path, attribute in that module) + "ChemGraphLogicalAgent": ( + "chemgraph.academy.core.agent", + "ChemGraphLogicalAgent", + ), +} + + +def __getattr__(name: str) -> Any: + """Lazy resolver for academy-py-dependent re-exports. + + Called by Python only when ``name`` is not found among the eager + imports above. On ``ImportError`` we re-raise with an actionable + hint so the operator knows which extra to install. 
+ """ + if name in _LAZY_EXPORTS: + module_path, attr = _LAZY_EXPORTS[name] + try: + from importlib import import_module + module = import_module(module_path) + except ImportError as exc: + raise ImportError( + f"Importing {name!r} from chemgraph.academy requires " + f"the 'academy' optional extra: " + f"`pip install 'chemgraph[academy]'`. " + f"Underlying error: {exc}" + ) from exc + return getattr(module, attr) + raise AttributeError( + f"module 'chemgraph.academy' has no attribute {name!r}" + ) + + +__all__ = [ + "CampaignEvent", + "ChemGraphAgentSpec", + "ChemGraphCampaign", + "ChemGraphDaemonConfig", + "ChemGraphLogicalAgent", + "EventLog", + "MCPServerSpec", + "PromptProfile", + "ResourceSpec", + "load_campaign", + "load_prompt_profile", + "resolve_campaign_resources", +] diff --git a/src/chemgraph/academy/campaigns/__init__.py b/src/chemgraph/academy/campaigns/__init__.py new file mode 100644 index 00000000..8c3f5cd6 --- /dev/null +++ b/src/chemgraph/academy/campaigns/__init__.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import dataclasses +from importlib import resources +from pathlib import Path + + +EXAMPLE_002 = 'example-002-mace-ensemble-screening' + +CAMPAIGNS = { + 'mace-ensemble-screening-20': f'{EXAMPLE_002}/campaign.jsonc', +} + +LM_CONFIG_TEMPLATES = { + 'argo-gpt54-mace-template': f'{EXAMPLE_002}/lm_config.json', +} + + +@dataclasses.dataclass(frozen=True) +class CampaignLaunchDefaults: + """Runtime defaults for a packaged ChemGraph Academy campaign.""" + + lm_config_template: str + agent_count: int + agents_per_node: int + max_decisions: int + + +CAMPAIGN_LAUNCH_DEFAULTS = { + 'mace-ensemble-screening-20': CampaignLaunchDefaults( + lm_config_template='argo-gpt54-mace-template', + agent_count=5, + agents_per_node=1, + max_decisions=24, + ), +} + + +def _resolve_campaign_asset( + path_or_name: str | Path, + known_assets: dict[str, str], +) -> Path: + value = str(path_or_name) + path = Path(value) + if path.exists(): + return 
path.resolve() + relative = known_assets.get(value) + if relative is None: + return path + return Path(str(resources.files(__package__).joinpath(relative))) + + +def resolve_campaign(path_or_name: str | Path) -> Path: + return _resolve_campaign_asset(path_or_name, CAMPAIGNS) + + +def resolve_lm_config_template(path_or_name: str | Path) -> Path: + return _resolve_campaign_asset(path_or_name, LM_CONFIG_TEMPLATES) + + +def list_campaigns() -> list[str]: + return sorted(CAMPAIGNS) + + +def campaign_launch_defaults(campaign: str) -> CampaignLaunchDefaults: + try: + return CAMPAIGN_LAUNCH_DEFAULTS[campaign] + except KeyError as exc: + raise KeyError( + f'No launch defaults for campaign {campaign!r}', + ) from exc diff --git a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc new file mode 100644 index 00000000..d0b8640f --- /dev/null +++ b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/campaign.jsonc @@ -0,0 +1,102 @@ +{ + // Campaign files support JSONC-style comments. + "run_id": "mace-ensemble-screening-20", + "user_task": "Given 20 staged SMILES candidates, generate 3D XYZ structures, run a per-structure MACE energy calculation through the run_ase tool, and rank candidates by calculation readiness and available MACE evidence.", + "prompt_profile": "prompt_profiles/default.json", + "initial_agent": "coordinator-agent", + "resources": { + // Resource fields: + // kind: "json" | "file" | "directory" + // scope: "campaign_file" | "shared_run" | "absolute" | "external" + // campaign_file: relative paths resolve next to this campaign file. + // shared_run: relative paths resolve under the run directory's shared/ subdirectory. + // absolute: path must already be absolute. + // external: runtime leaves path/uri unchanged. + // expose_content: only meaningful for kind="json"; true includes parsed JSON in the bootstrap task. 
+ "candidate_dataset": { + "kind": "json", + "path": "data/mace_screening_20_smiles.json", + "scope": "campaign_file", + "description": "The full input candidate list. Coordinator-agent may inspect it and delegate records by peer message.", + "expose_content": true + }, + "structure_output_directory": { + "kind": "directory", + "path": "academy_mace_structures", + "scope": "shared_run", + "description": "Shared run directory where generated XYZ coordinate files should be written." + }, + "mace_output_directory": { + "kind": "directory", + "path": "academy_mace_outputs", + "scope": "shared_run", + "description": "Shared run directory where mace-agent should write one JSON result file per structure (e.g. academy_mace_outputs/.json)." + } + }, + "mcp_servers": [ + // MCP server fields: + // command: launch command; runtime appends --transport/--host/--port. + // The HPC-specific servers (mace_mcp_hpc, hpc_misc_mcp) are intentionally + // omitted here because they go through chemgraph.execution.ParslBackend, + // which is being reworked in a separate PR. This example exercises the + // in-process MACE path through the general ``run_ase`` tool instead. + { + "name": "general", + "command": "python -m chemgraph.mcp.mcp_tools" + } + ], + "agents": [ + { + "name": "coordinator-agent", + "role": "MACEReadinessCoordinatorAgent", + "mission": "Coordinate the campaign from the bootstrap task. Send odd-numbered MOL candidates to structure-agent-a and even-numbered MOL candidates to structure-agent-b, including candidate_id, label, SMILES, and output_file. 
After structure evidence returns, ask mace-agent to run one MACE energy calculation per generated XYZ file using the run_ase tool with the mace_mp calculator on CPU, then ask assessment-agent for readiness/ranking evidence before submitting the final result.", + "allowed_peers": [ + "structure-agent-a", + "structure-agent-b", + "mace-agent", + "assessment-agent" + ], + "mcp_servers": [], + "resources": [ + "candidate_dataset", + "structure_output_directory", + "mace_output_directory" + ] + }, + { + "name": "structure-agent-a", + "role": "MolecularStructureWorkerAgent", + "mission": "Process only candidates assigned by coordinator-agent. Generate XYZ coordinate files, then report concise artifact evidence and failures back to coordinator-agent.", + "allowed_peers": ["coordinator-agent"], + "mcp_servers": ["general"], + "allowed_tools": ["molecule_name_to_smiles", "smiles_to_coordinate_file"], + "resources": [] + }, + { + "name": "structure-agent-b", + "role": "MolecularStructureWorkerAgent", + "mission": "Process only candidates assigned by coordinator-agent. Generate XYZ coordinate files, then report concise artifact evidence and failures back to coordinator-agent.", + "allowed_peers": ["coordinator-agent"], + "mcp_servers": ["general"], + "allowed_tools": ["molecule_name_to_smiles", "smiles_to_coordinate_file"], + "resources": [] + }, + { + "name": "mace-agent", + "role": "MACEEnergyAgent", + "mission": "Run MACE only after a concrete request from coordinator-agent. For each assigned XYZ file, call the run_ase tool with driver='energy', a calculator block of {'calculator_type': 'mace_mp', 'model': 'medium-mpa-0', 'device': 'cpu'}, the input_structure_file pointing at the XYZ, and an output_results_file under the requested output directory. 
Report started, completed, partial, or failed evidence back to coordinator-agent, including output paths and tool_result_ids; pending work is not a failure.", + "allowed_peers": ["coordinator-agent"], + "mcp_servers": ["general"], + "allowed_tools": ["run_ase", "extract_output_json"], + "resources": ["mace_output_directory"] + }, + { + "name": "assessment-agent", + "role": "ScreeningAssessmentAgent", + "mission": "Assess evidence received from coordinator-agent. Summarize structure coverage, MACE coverage, failures, ranking readiness, and pending work without treating pending MACE work as failure.", + "allowed_peers": ["coordinator-agent"], + "mcp_servers": [], + "resources": [] + } + ] +} diff --git a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json new file mode 100644 index 00000000..90bce655 --- /dev/null +++ b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/data/mace_screening_20_smiles.json @@ -0,0 +1,106 @@ +{ + "dataset_id": "mace-screening-20-smiles-v1", + "description": "Twenty small-molecule SMILES for a ChemGraph-native Academy MACE ensemble screening demo.", + "candidates": [ + { + "candidate_id": "MOL-001", + "label": "water", + "smiles": "O" + }, + { + "candidate_id": "MOL-002", + "label": "methane", + "smiles": "C" + }, + { + "candidate_id": "MOL-003", + "label": "ammonia", + "smiles": "N" + }, + { + "candidate_id": "MOL-004", + "label": "carbon_dioxide", + "smiles": "O=C=O" + }, + { + "candidate_id": "MOL-005", + "label": "methanol", + "smiles": "CO" + }, + { + "candidate_id": "MOL-006", + "label": "ethanol", + "smiles": "CCO" + }, + { + "candidate_id": "MOL-007", + "label": "acetone", + "smiles": "CC(=O)C" + }, + { + "candidate_id": "MOL-008", + "label": "acetic_acid", + "smiles": "CC(=O)O" + }, + { + "candidate_id": "MOL-009", + "label": "benzene", + "smiles": "c1ccccc1" 
+ }, + { + "candidate_id": "MOL-010", + "label": "toluene", + "smiles": "Cc1ccccc1" + }, + { + "candidate_id": "MOL-011", + "label": "phenol", + "smiles": "Oc1ccccc1" + }, + { + "candidate_id": "MOL-012", + "label": "aniline", + "smiles": "Nc1ccccc1" + }, + { + "candidate_id": "MOL-013", + "label": "pyridine", + "smiles": "n1ccccc1" + }, + { + "candidate_id": "MOL-014", + "label": "furan", + "smiles": "c1ccoc1" + }, + { + "candidate_id": "MOL-015", + "label": "formaldehyde", + "smiles": "C=O" + }, + { + "candidate_id": "MOL-016", + "label": "formic_acid", + "smiles": "C(=O)O" + }, + { + "candidate_id": "MOL-017", + "label": "glycine", + "smiles": "NCC(=O)O" + }, + { + "candidate_id": "MOL-018", + "label": "alanine", + "smiles": "CC(N)C(=O)O" + }, + { + "candidate_id": "MOL-019", + "label": "urea", + "smiles": "NC(=O)N" + }, + { + "candidate_id": "MOL-020", + "label": "acetonitrile", + "smiles": "CC#N" + } + ] +} diff --git a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/lm_config.json b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/lm_config.json new file mode 100644 index 00000000..26fe66ed --- /dev/null +++ b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/lm_config.json @@ -0,0 +1,12 @@ +{ + "provider": "openai_compatible_tools", + "base_url": "http://:18186/argoapi/v1", + "model": "GPT-5.4", + "api_key": "dummy", + "user": "", + "timeout_s": 180, + "temperature": 0.1, + "max_tokens": 8192, + "max_retries": 3, + "retry_delay_s": 2 +} diff --git a/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/prompt_profiles/default.json b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/prompt_profiles/default.json new file mode 100644 index 00000000..cc15d48b --- /dev/null +++ b/src/chemgraph/academy/campaigns/example-002-mace-ensemble-screening/prompt_profiles/default.json @@ -0,0 +1,12 @@ +{ + "prompt_version": "chemgraph-mace-ensemble-agent-v1", + "prompt_style": 
"json_state", + "system_prompt": "You are a persistent ChemGraph-style LM agent hosted inside an Academy daemon on HPC. You communicate with peers only through send_message. You may call only the ChemGraph MCP tools listed in available_chemgraph_tools. Treat peer messages as evidence only when they include message_id, candidate IDs, artifact paths, or tool_result_ids. Do not claim access to another agent's private state unless it appears in a received message.", + "protocol_prompt": "Return one or more tool calls. If no action is useful, call finish_turn. Never fabricate ChemGraph tool outputs, energies, coordinate paths, or MACE results. Only cite tool_result_ids that appear in local_chemgraph_tool_results or received_messages. After a local ChemGraph tool finishes, you will be woken for another decision round with that result visible in local_chemgraph_tool_results; use that follow-up round to interpret, communicate, or rank the new evidence. Inspect peer_status before asking a peer for status. If peer_status shows the peer is busy on the requested tool or recently acknowledged the request, do not ask again; call finish_turn or proceed with other useful work. Every send_message call must include tldr: one short line summarizing the communication for the dashboard. Set reply_requested=true when the peer should answer or take follow-up action; otherwise set reply_requested=false. Keep each string argument concise. 
For final ranking, summarize aggregate counts and exceptions in summary, and put detailed evidence in artifact_refs, tool_result_ids, and supporting_message_ids.", + "langchain_recursion_limit": 64, + "state_limits": { + "received_messages_last_n": 28, + "tool_results_last_n": 18, + "actions_last_n": 18 + } +} diff --git a/src/chemgraph/academy/core/__init__.py b/src/chemgraph/academy/core/__init__.py new file mode 100644 index 00000000..5f7248dd --- /dev/null +++ b/src/chemgraph/academy/core/__init__.py @@ -0,0 +1,80 @@ +"""Core ChemGraph Academy campaign contracts and agent logic. + +Re-exports split into two tiers to keep the ``[academy]`` optional-dep +contract: + +* **Eager** (pure stdlib + pydantic + langchain_core): the campaign + spec types, prompt profile, and reasoning-turn helpers. These are + what the dashboard, ``--trace-dir``, and the test collector touch + on a CPU-only checkout. +* **Lazy** (resolved via ``__getattr__``; requires the ``[academy]`` + extra because it depends on ``academy.agent.Agent``): + ``ChemGraphLogicalAgent``. + +Without this split, importing ``chemgraph.academy.core.campaign`` +would transitively run ``core/__init__.py`` and pull in +``core.agent``, which fails when ``academy-py`` is not installed. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.core.campaign import MCPServerSpec +from chemgraph.academy.core.campaign import ResourceSpec +from chemgraph.academy.core.campaign import load_campaign +from chemgraph.academy.core.campaign import resolve_campaign_resources +from chemgraph.academy.core.prompt import PromptProfile +from chemgraph.academy.core.prompt import load_prompt_profile +from chemgraph.academy.core.turn import ReasoningTurnResult +from chemgraph.academy.core.turn import run_academy_turn + + +if TYPE_CHECKING: + from chemgraph.academy.core.agent import ChemGraphLogicalAgent + + +_LAZY_EXPORTS: dict[str, tuple[str, str]] = { + "ChemGraphLogicalAgent": ( + "chemgraph.academy.core.agent", + "ChemGraphLogicalAgent", + ), +} + + +def __getattr__(name: str) -> Any: + if name in _LAZY_EXPORTS: + module_path, attr = _LAZY_EXPORTS[name] + try: + from importlib import import_module + module = import_module(module_path) + except ImportError as exc: + raise ImportError( + f"Importing {name!r} from chemgraph.academy.core requires " + f"the 'academy' optional extra: " + f"`pip install 'chemgraph[academy]'`. 
" + f"Underlying error: {exc}" + ) from exc + return getattr(module, attr) + raise AttributeError( + f"module 'chemgraph.academy.core' has no attribute {name!r}" + ) + + +__all__ = [ + "ChemGraphAgentSpec", + "ChemGraphCampaign", + "ChemGraphDaemonConfig", + "ChemGraphLogicalAgent", + "MCPServerSpec", + "PromptProfile", + "ReasoningTurnResult", + "ResourceSpec", + "load_campaign", + "load_prompt_profile", + "resolve_campaign_resources", + "run_academy_turn", +] diff --git a/src/chemgraph/academy/core/agent.py b/src/chemgraph/academy/core/agent.py new file mode 100644 index 00000000..6f2c81ca --- /dev/null +++ b/src/chemgraph/academy/core/agent.py @@ -0,0 +1,263 @@ +"""Persistent logical Academy agent for ChemGraph campaigns.""" + +from __future__ import annotations + +import asyncio +import time +from collections.abc import Mapping, Sequence +from pathlib import Path +from typing import Any + +from academy.agent import Agent, action +from academy.agent import loop +from academy.handle import Handle +from academy.identifier import AgentId +from langchain_core.tools import BaseTool + +from chemgraph.academy.core.peer_protocol import validate_message +from chemgraph.academy.observability.event_log import EventLog +from chemgraph.academy.observability.run_artifacts import write_status_snapshot +from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools +from chemgraph.academy.core.turn import run_academy_turn +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.prompt import PromptProfile +from chemgraph.models.settings import LLMSettings + + +class ChemGraphLogicalAgent(Agent): + """Persistent Academy logical agent for one ChemGraph campaign role.""" + + def __init__( + self, + spec: ChemGraphAgentSpec, + *, + campaign: ChemGraphCampaign, + llm_settings: LLMSettings, + prompt_profile: PromptProfile, + run_dir: Path, + max_decisions: int, + 
external_tools: Sequence[BaseTool] = (), + peer_agent_ids: Mapping[str, AgentId[Any]] | None = None, + placement: dict[str, Any] | None = None, + poll_timeout_s: float = 2.0, + idle_timeout_s: float = 120.0, + status_interval_s: float = 5.0, + ) -> None: + super().__init__() + self.spec = spec + self.campaign = campaign + self.llm_settings = llm_settings + self.prompt_profile = prompt_profile + self.run_dir = run_dir + self.max_decisions = max_decisions + self.external_tools = list(external_tools) + self.peer_agent_ids = dict(peer_agent_ids or {}) + self.placement = placement or {} + self.poll_timeout_s = poll_timeout_s + self.idle_timeout_s = idle_timeout_s + self.status_interval_s = status_interval_s + + self.peer_names = tuple(spec.allowed_peers) + self.peer_handles: dict[str, Handle[Any]] = {} + self.received_message_history: list[dict[str, Any]] = [] + self.outbox: list[dict[str, Any]] = [] + self.tool_results: list[dict[str, Any]] = [] + self.final_result: dict[str, Any] | None = None + self.round_index = 0 + self.finished = False + self.last_error: str | None = None + self._wake_event: asyncio.Event | None = None + + async def agent_on_startup(self) -> None: + self._wake_event = asyncio.Event() + self.peer_handles = { + name: Handle(agent_id) + for name, agent_id in self.peer_agent_ids.items() + if name in self.peer_names + } + self._trace( + 'agent_started', + { + 'role': self.spec.role, + 'tool_names': [tool.name for tool in self.external_tools], + 'allowed_peers': list(self.spec.allowed_peers), + 'placement': self.placement, + **self.placement, + }, + ) + + @action + async def receive_message(self, message: dict[str, Any]) -> None: + validate_message(message) + self.received_message_history.append(message) + self._trace('message_received', message) + if self._wake_event is not None: + self._wake_event.set() + + @action + async def get_status(self) -> dict[str, Any]: + return await self.report_state() + + @loop + async def deliberate(self, shutdown: 
asyncio.Event) -> None: + if self._wake_event is None: + raise RuntimeError('agent startup did not initialize wake state') + + decisions_completed = 0 + last_activity = time.monotonic() + last_status = 0.0 + + while not shutdown.is_set(): + if self._wake_event.is_set(): + self._wake_event.clear() + decisions_completed, self_wake = await self.run_decision_turn( + decisions_completed, + ) + last_activity = time.monotonic() + if self_wake: + self._wake_event.set() + await self.write_runtime_status() + if decisions_completed >= self.max_decisions: + self._trace( + 'max_decisions_reached', + {'decisions_completed': decisions_completed}, + ) + break + continue + + now = time.monotonic() + if now - last_status >= self.status_interval_s: + await self.write_runtime_status() + last_status = now + + if now - last_activity >= self.idle_timeout_s: + self._trace( + 'idle_timeout', + { + 'idle_timeout_s': self.idle_timeout_s, + 'decisions_completed': decisions_completed, + }, + ) + break + + try: + await asyncio.wait_for( + self._wake_event.wait(), + timeout=self.poll_timeout_s, + ) + except asyncio.TimeoutError: + pass + + self.finished = True + self._trace( + 'daemon_stopped', + { + 'decisions_completed': decisions_completed, + 'shutdown_requested': shutdown.is_set(), + }, + ) + await self.write_runtime_status() + self.agent_shutdown() + + async def write_runtime_status(self) -> None: + write_status_snapshot( + run_dir=self.run_dir, + campaign=self.campaign, + agent_state=await self.report_state(), + placement=self.placement, + ) + + async def run_decision_turn(self, decisions_completed: int) -> tuple[int, bool]: + self.round_index += 1 + try: + self_wake = await self._reasoning_round() + except Exception as exc: + self.last_error = repr(exc) + self._trace('agent_error', {'error': self.last_error}) + raise + return decisions_completed + 1, self_wake + + async def report_state(self) -> dict[str, Any]: + return { + 'agent_name': self.spec.name, + 'role': self.spec.role, + 
'status_updated_at': time.time(), + 'round': self.round_index, + 'finished': self.finished, + 'last_error': self.last_error, + } + + async def _reasoning_round(self) -> bool: + self._trace('round_started', {'round': self.round_index}) + tools = await build_chemgraph_reasoning_tools( + spec=self.spec, + run_dir=self.run_dir, + external_tools=self.external_tools, + peer_names=self.peer_names, + peer_handles=self.peer_handles, + outbox=self.outbox, + tool_results=self.tool_results, + get_round_index=lambda: self.round_index, + set_final_result=self._set_final_result, + trace=self._trace, + ) + result = await run_academy_turn( + campaign=self.campaign, + spec=self.spec, + llm_settings=self.llm_settings, + prompt_profile=self.prompt_profile, + run_dir=self.run_dir, + max_decisions=self.max_decisions, + tools=tools, + received_message_history=self.received_message_history, + outbox=self.outbox, + tool_results=self.tool_results, + get_final_result=lambda: self.final_result, + get_round_index=lambda: self.round_index, + trace=self._trace, + peer_names=self.peer_names, + ) + self._trace( + 'agent_decision', + { + 'mode': 'mpi_daemon', + 'wake_reason': f'daemon round {self.round_index}', + 'rationale': 'LM returned the listed tool calls for this daemon turn.', + 'round': self.round_index, + 'tool_names': list(result.executed_tool_names), + 'action_tools_called': list(result.action_tools_called), + 'science_tools_called': list(result.science_tools_called), + 'thread_id': result.thread_id, + 'engine': 'chemgraph_single_agent', + 'actions': [ + {'action': name} + for name in result.executed_tool_names + ], + }, + ) + self._trace('round_finished', {'round': self.round_index}) + if result.requested_self_wake: + self._trace( + 'self_wake_scheduled', + { + 'round': self.round_index, + 'reason': ( + 'local ChemGraph tool result is now available in ' + 'local_chemgraph_tool_results' + ), + }, + ) + return result.requested_self_wake + + def _set_final_result(self, result: dict[str, 
Any]) -> None: + self.final_result = result + + def _trace(self, event: str, payload: dict[str, Any]) -> None: + EventLog(self.run_dir / 'events.jsonl').emit( + event, # type: ignore[arg-type] + run_id=self.run_dir.name, + agent_id=self.spec.name, + role=self.spec.role, + payload=payload, + ) diff --git a/src/chemgraph/academy/core/campaign.py b/src/chemgraph/academy/core/campaign.py new file mode 100644 index 00000000..b87a80da --- /dev/null +++ b/src/chemgraph/academy/core/campaign.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import dataclasses +import json +import pathlib +from collections.abc import Mapping +from typing import Any + +from chemgraph.academy.campaigns import resolve_campaign +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +_REMOVED_CAMPAIGN_FIELDS = frozenset( + { + 'completion_criteria', + 'parameters', + 'routing_policy', + 'workflow_stages', + }, +) +_RESOURCE_KINDS = frozenset({'directory', 'file', 'json'}) +_RESOURCE_SCOPES = frozenset( + { + 'absolute', + 'campaign_file', + 'external', + 'shared_run', + }, +) + + +class MCPServerSpec(BaseModel): + """Campaign-declared MCP server subprocess available to agents.""" + + model_config = ConfigDict(extra='forbid') + + name: str = Field(min_length=1) + command: str = Field( + min_length=1, + description=( + "Shell command to launch the MCP server. Tokens after the first " + "are arguments. Do not include --transport/--host/--port; the " + "supervisor adds them." + ), + ) + env: dict[str, str] = Field(default_factory=dict) + + @field_validator('name', 'command') + @classmethod + def _non_empty(cls, value: str) -> str: + value = value.strip() + if not value: + raise ValueError('field must be non-empty') + return value + + +class ResourceSpec(BaseModel): + """Campaign-declared resource or artifact handle. + + The runtime resolves only these explicit ``path`` fields. It never scans + arbitrary campaign metadata looking for strings that might be paths. 
+ """ + + model_config = ConfigDict(extra='forbid') + + kind: str + path: str | None = None + uri: str | None = None + scope: str = 'campaign_file' + description: str = '' + expose_content: bool = False + + @field_validator('kind') + @classmethod + def _known_resource_kind(cls, value: str) -> str: + value = value.strip() + if value not in _RESOURCE_KINDS: + raise ValueError( + f'resource kind must be one of {sorted(_RESOURCE_KINDS)}', + ) + return value + + @field_validator('scope') + @classmethod + def _known_resource_scope(cls, value: str) -> str: + value = value.strip() + if value not in _RESOURCE_SCOPES: + raise ValueError( + f'resource scope must be one of {sorted(_RESOURCE_SCOPES)}', + ) + return value + + @field_validator('path', 'uri', 'description') + @classmethod + def _strip_optional_resource_field(cls, value: str | None) -> str | None: + if value is None: + return None + value = value.strip() + return value or None + + +@dataclasses.dataclass(frozen=True) +class ChemGraphAgentSpec: + name: str + role: str + mission: str + allowed_peers: tuple[str, ...] + mcp_servers: tuple[str, ...] = () + allowed_tools: tuple[str, ...] = () + """Optional per-agent whitelist of MCP tool names. + + Empty (the default) means the agent sees every tool advertised by the + servers listed in :attr:`mcp_servers`. When non-empty, only tools whose + name appears in this tuple are exposed to the agent; everything else + that the servers advertise is filtered out before reaching LangChain. + + The whitelist is flat and server-agnostic: a name matches any tool with + that name across the agent's connected servers. Duplicate tool names + across an agent's servers are still rejected by the supervisor (today's + behavior), so the whitelist does not introduce new ambiguity. + """ + resources: tuple[str, ...] 
= () + + +@dataclasses.dataclass(frozen=True) +class ChemGraphCampaign: + run_id: str + user_task: str + initial_agent: str + prompt_profile: pathlib.Path + agents: tuple[ChemGraphAgentSpec, ...] + mcp_servers: tuple[MCPServerSpec, ...] = () + resources: Mapping[str, ResourceSpec] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass(frozen=True) +class ChemGraphDaemonConfig: + run_dir: pathlib.Path + run_token: str + agent_count: int + campaign_config: pathlib.Path + lm_config: pathlib.Path + max_decisions: int + poll_timeout_s: float + idle_timeout_s: float + startup_timeout_s: float + completion_timeout_s: float + status_interval_s: float + redis_host: str + redis_port: int + redis_namespace: str + rank: int + local_rank: int | None + chemgraph_repo_root: pathlib.Path + exchange_type: str = 'redis' + + +def namespace_for_run(run_dir: pathlib.Path) -> str: + return f'academy-chemgraph-swarm:{run_dir.name}' + + +def resolve_campaign_resources( + campaign: ChemGraphCampaign, + run_dir: str | pathlib.Path, + *, + shared_dir_name: str = 'shared', +) -> ChemGraphCampaign: + """Resolve explicit shared-run resource paths for one concrete run. + + Also pre-creates the on-disk directories these resources name so that + tools whose first action is to write under a declared output directory + do not fail with ``FileNotFoundError`` partway through. For ``kind: + directory`` resources the directory itself is created; for ``kind: + file`` and ``kind: json`` resources the file's parent directory is + created (the file itself is the agent's responsibility to write). 
+ """ + shared_root = (pathlib.Path(run_dir).resolve() / shared_dir_name) + resources: dict[str, ResourceSpec] = {} + + for name, spec in campaign.resources.items(): + if spec.path is None: + resources[name] = spec + continue + if spec.scope != 'shared_run': + resources[name] = spec + continue + path = pathlib.Path(spec.path) + resolved = (path if path.is_absolute() else shared_root / path).resolve() + _ensure_resource_dir(resolved, spec.kind) + resources[name] = spec.model_copy( + update={ + 'path': str(resolved), + 'uri': spec.uri or _file_uri(resolved), + }, + ) + + return dataclasses.replace(campaign, resources=resources) + + +def _ensure_resource_dir(resolved: pathlib.Path, kind: str) -> None: + """Materialise on-disk directories for a resolved shared_run resource.""" + if kind == 'directory': + resolved.mkdir(parents=True, exist_ok=True) + else: + # 'file' and 'json': create the parent so the agent can write the file. + resolved.parent.mkdir(parents=True, exist_ok=True) + + +def _file_uri(path: pathlib.Path) -> str: + return path.resolve().as_uri() + + +def _resolve_resource_spec( + raw: Mapping[str, Any], + *, + campaign_path: pathlib.Path, +) -> ResourceSpec: + spec = ResourceSpec.model_validate(raw) + if spec.path is None: + return spec + if spec.scope == 'campaign_file': + path = pathlib.Path(spec.path) + resolved = path if path.is_absolute() else campaign_path.parent / path + resolved = resolved.resolve() + return spec.model_copy( + update={ + 'path': str(resolved), + 'uri': spec.uri or _file_uri(resolved), + }, + ) + if spec.scope == 'absolute': + path = pathlib.Path(spec.path) + if not path.is_absolute(): + raise RuntimeError( + f'absolute resource path must be absolute: {spec.path}', + ) + resolved = path.resolve() + return spec.model_copy( + update={ + 'path': str(resolved), + 'uri': spec.uri or _file_uri(resolved), + }, + ) + if spec.scope in {'shared_run', 'external'}: + return spec + + raise RuntimeError(f'unsupported resource scope 
{spec.scope!r}') + + +def load_campaign(path: str | pathlib.Path) -> ChemGraphCampaign: + path = resolve_campaign(path) + data = _load_jsonc(path) + _reject_removed_campaign_fields(data, campaign_path=path) + prompt_profile = _resolve_campaign_relative_path( + data.get('prompt_profile'), + campaign_path=path, + field_name='prompt_profile', + ) + + mcp_servers = tuple( + MCPServerSpec.model_validate(raw) + for raw in data.get('mcp_servers', ()) + ) + resources = { + name: _resolve_resource_spec(raw, campaign_path=path) + for name, raw in dict(data.get('resources', {})).items() + } + agents = [] + for item in data['agents']: + agents.append( + ChemGraphAgentSpec( + name=item['name'], + role=item['role'], + mission=item['mission'], + allowed_peers=tuple(item.get('allowed_peers', ())), + mcp_servers=tuple(item.get('mcp_servers', ())), + allowed_tools=tuple(item.get('allowed_tools', ())), + resources=tuple(item.get('resources', ())), + ), + ) + return ChemGraphCampaign( + run_id=data.get('run_id', path.stem), + user_task=data['user_task'], + initial_agent=data.get('initial_agent', agents[0].name), + prompt_profile=prompt_profile, + agents=tuple(agents), + mcp_servers=mcp_servers, + resources=resources, + ) + + +def _load_jsonc(path: pathlib.Path) -> dict[str, Any]: + """Load a campaign file that may contain JSONC-style comments.""" + data = json.loads(_strip_json_comments(path.read_text(encoding='utf-8'))) + if not isinstance(data, dict): + raise RuntimeError(f'campaign {path} must contain a JSON object') + return data + + +def _strip_json_comments(text: str) -> str: + """Remove // and /* */ comments without touching JSON string values.""" + out: list[str] = [] + in_string = False + escape = False + i = 0 + + while i < len(text): + char = text[i] + nxt = text[i + 1] if i + 1 < len(text) else '' + + if in_string: + out.append(char) + if escape: + escape = False + elif char == '\\': + escape = True + elif char == '"': + in_string = False + i += 1 + continue + + if char == 
'"': + in_string = True + out.append(char) + i += 1 + continue + + if char == '/' and nxt == '/': + i += 2 + while i < len(text) and text[i] not in '\r\n': + i += 1 + continue + + if char == '/' and nxt == '*': + i += 2 + while i < len(text): + if text[i] in '\r\n': + out.append(text[i]) + i += 1 + continue + if text[i] == '*' and i + 1 < len(text) and text[i + 1] == '/': + i += 2 + break + i += 1 + continue + + out.append(char) + i += 1 + + return ''.join(out) + + +def _reject_removed_campaign_fields( + data: Mapping[str, Any], + *, + campaign_path: pathlib.Path, +) -> None: + removed = sorted(_REMOVED_CAMPAIGN_FIELDS.intersection(data)) + if not removed: + return + raise RuntimeError( + f'campaign {campaign_path} uses removed structured orchestration ' + f'field(s): {removed}. Put simple natural-language coordination hints ' + 'in agent mission fields and enforce the communication graph with ' + 'allowed_peers.', + ) + + +def _resolve_campaign_relative_path( + raw: Any, + *, + campaign_path: pathlib.Path, + field_name: str, +) -> pathlib.Path: + if not isinstance(raw, str) or not raw.strip(): + raise RuntimeError(f'campaign requires non-empty {field_name!r}') + path = pathlib.Path(raw.strip()) + if not path.is_absolute(): + path = campaign_path.parent / path + return path.resolve() + + +def validate_campaign(campaign: ChemGraphCampaign, agent_count: int) -> None: + if len(campaign.agents) != agent_count: + raise RuntimeError( + f'campaign defines {len(campaign.agents)} agents but ' + f'agent_count={agent_count}', + ) + names = [agent.name for agent in campaign.agents] + if len(set(names)) != len(names): + raise RuntimeError('campaign agent names must be unique') + if campaign.initial_agent not in names: + raise RuntimeError( + f'initial_agent {campaign.initial_agent!r} is not an agent', + ) + server_names = [server.name for server in campaign.mcp_servers] + if len(set(server_names)) != len(server_names): + raise RuntimeError('campaign MCP server names must be 
def selected_agent(campaign: ChemGraphCampaign, rank: int) -> ChemGraphAgentSpec:
    """Return the agent spec at index ``rank`` in ``campaign.agents``.

    Raises:
        RuntimeError: If ``rank`` is negative or not smaller than the number
            of agents in the campaign.
    """
    roster = campaign.agents
    total = len(roster)
    if 0 <= rank < total:
        return roster[rank]
    raise RuntimeError(
        f'MPI rank {rank} has no agent. Launch exactly '
        f'{total} ranks for this campaign.',
    )
def build_message(
    *,
    sender: str,
    recipient: str,
    content: str,
    round_index: int | None = None,
    kind: str = 'message',
    tldr: str | None = None,
    artifact_refs: list[str] | None = None,
    tool_result_ids: list[str] | None = None,
    reply_requested: bool = False,
    reason: str | None = None,
    confidence: float | None = None,
) -> dict[str, Any]:
    """Create the structured message payload sent through Academy handles.

    The envelope always carries a fresh ``msg-<uuid4>`` id, the current
    ``time.time()`` timestamp, the routing fields and the two reference
    lists (empty lists when omitted). ``round``, ``tldr``, ``reason`` and
    ``confidence`` are added only when a non-``None`` value is supplied.
    """
    envelope: dict[str, Any] = {
        'message_id': f'msg-{uuid.uuid4()}',
        'timestamp': time.time(),
        'sender': sender,
        'recipient': recipient,
        'kind': kind,
        'content': content,
        'reply_requested': reply_requested,
        'artifact_refs': artifact_refs if artifact_refs else [],
        'tool_result_ids': tool_result_ids if tool_result_ids else [],
    }
    # Optional fields, in the order they appear in the payload.
    optional_fields = (
        ('round', round_index),
        ('tldr', tldr),
        ('reason', reason),
        ('confidence', confidence),
    )
    for key, value in optional_fields:
        if value is not None:
            envelope[key] = value
    return envelope
# Argument schema for the ``send_message`` action tool.
# NOTE: documented with ``#`` comments rather than a class docstring because
# pydantic copies a model docstring into the generated JSON schema.
class SendMessageArgs(BaseModel):
    # Unknown keys are rejected instead of being silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Routing and the short label; both must be non-empty.
    recipient: str = Field(min_length=1, description="Allowed peer agent name.")
    tldr: str = Field(min_length=1, max_length=160, description="One-line dashboard edge label.")
    content: str = Field(min_length=1, max_length=1800, description="Full peer message content.")
    # Reference lists default to fresh empty lists when omitted.
    artifact_refs: list[str] = Field(default_factory=list, description="Artifact path strings.")
    tool_result_ids: list[str] = Field(default_factory=list, description="ChemGraph tool_result_id strings.")
    reply_requested: bool = Field(
        default=False,
        description="True when this asks the peer to reply or act.",
    )
    # Required justification fields: no defaults, so callers must supply them.
    reason: str = Field(min_length=1, max_length=600, description="Why this peer needs the message now.")
    confidence: float = Field(ge=0, le=1, description="Numeric confidence from 0 to 1.")
async def build_chemgraph_reasoning_tools(
    *,
    spec: ChemGraphAgentSpec,
    run_dir: pathlib.Path,
    external_tools: Sequence[BaseTool] = (),
    peer_names: tuple[str, ...],
    peer_handles: Mapping[str, Handle[Any]],
    outbox: list[dict[str, Any]],
    tool_results: list[dict[str, Any]],
    get_round_index: Callable[[], int],
    set_final_result: SetFinalResultFn,
    trace: TraceFn,
) -> list[BaseTool]:
    """Build explicit tools for one ChemGraph-backed reasoning turn.

    Three action tools are created (``send_message``, ``submit_result`` and
    ``finish_turn``) and ``external_tools`` are appended after them.

    Args:
        spec: The agent the tools are built for; its name is the sender.
        run_dir: Run directory; sent messages are appended to
            ``messages.jsonl`` inside it.
        external_tools: Additional tools appended unchanged to the result.
        peer_names: Names this agent is allowed to message.
        peer_handles: Academy handles keyed by peer name.
        outbox: List that every recorded outgoing message is appended to.
        tool_results: Accepted for interface compatibility; not read here.
        get_round_index: Returns the current round number.
        set_final_result: Receives the payload built by ``submit_result``.
        trace: Event sink called as ``trace(event_name, payload)``.

    Returns:
        The action tools followed by ``external_tools``.
    """

    async def _send_message_impl(
        *,
        recipient: str,
        tldr: str,
        content: str,
        artifact_refs: list[str],
        tool_result_ids: list[str],
        reply_requested: bool,
        reason: str,
        confidence: float,
    ) -> dict[str, Any]:
        """Record one peer message and queue its delivery in the background."""
        if recipient not in peer_names:
            raise ValueError(
                f"{spec.name} tried to message disallowed peer {recipient}",
            )
        # Check deliverability before recording anything: previously the
        # message was added to the outbox, written to messages.jsonl and
        # traced as "message_sent" even when no handle existed for it.
        if recipient not in peer_handles:
            raise RuntimeError(f"No Academy handle for allowed peer {recipient}")
        handle = peer_handles[recipient]
        kind = "question" if reply_requested else "message"
        message = build_message(
            sender=spec.name,
            recipient=recipient,
            content=content,
            round_index=get_round_index(),
            kind=kind,
            tldr=tldr,
            artifact_refs=artifact_refs,
            tool_result_ids=tool_result_ids,
            reply_requested=reply_requested,
            reason=reason,
            confidence=confidence,
        )
        outbox.append(message)
        append_jsonl(run_dir / "messages.jsonl", message)
        trace("message_sent", message)
        task = asyncio.create_task(
            _deliver_message(
                recipient=recipient,
                message=message,
                handle=handle,
                trace=trace,
            ),
        )
        # Hold a strong reference until the task finishes so it is not
        # garbage-collected while still pending.
        _BACKGROUND_DELIVERIES.add(task)
        task.add_done_callback(_BACKGROUND_DELIVERIES.discard)
        return {
            "status": "sent",
            "delivery": "queued",
            "message_id": message["message_id"],
            "recipient": recipient,
        }

    async def _deliver_message(
        *,
        recipient: str,
        message: dict[str, Any],
        handle: Handle[Any],
        trace: TraceFn,
    ) -> None:
        """Invoke the peer's ``receive_message`` action and trace the outcome."""
        try:
            await handle.action("receive_message", message)
        except Exception as exc:  # noqa: BLE001 - preserve async delivery failure.
            trace(
                "message_delivery_failed",
                {
                    "recipient": recipient,
                    "message_id": message["message_id"],
                    "error": repr(exc),
                },
            )
            return
        trace(
            "message_delivered",
            {
                "recipient": recipient,
                "message_id": message["message_id"],
            },
        )

    def _validation_error_handler(tool_name: str) -> Callable[[ValidationError], dict[str, Any]]:
        """Return a handler that reports schema failures for ``tool_name``."""

        def handle(exc: ValidationError) -> dict[str, Any]:
            return _invalid_args_response(tool_name, exc, trace)

        return handle

    async def send_message(**kwargs: Any) -> dict[str, Any]:
        try:
            args = SendMessageArgs.model_validate(kwargs)
        except ValidationError as exc:
            return _invalid_args_response("send_message", exc, trace)
        if args.recipient not in peer_names:
            # Returned to the model as a structured error rather than raised.
            return _disallowed_recipient_response(
                "send_message",
                args.recipient,
                peer_names,
                trace,
            )
        return await _send_message_impl(
            recipient=args.recipient,
            tldr=args.tldr,
            content=args.content,
            artifact_refs=args.artifact_refs,
            tool_result_ids=args.tool_result_ids,
            reply_requested=args.reply_requested,
            reason=args.reason,
            confidence=args.confidence,
        )

    async def submit_result(**kwargs: Any) -> dict[str, Any]:
        try:
            args = SubmitResultArgs.model_validate(kwargs)
        except ValidationError as exc:
            return _invalid_args_response("submit_result", exc, trace)
        result = {
            "timestamp": time.time(),
            "round": get_round_index(),
            "hypothesis": args.summary,
            "summary": args.summary,
            "artifact_refs": args.artifact_refs,
            "tool_result_ids": args.tool_result_ids,
            "supporting_message_ids": args.supporting_message_ids,
            "supporting_tool_result_ids": args.tool_result_ids,
            "confidence": args.confidence,
            "reason": args.reason,
        }
        set_final_result(result)
        trace("belief_updated", result)
        return {"status": "submitted", "confidence": result["confidence"]}

    async def finish_turn(**kwargs: Any) -> dict[str, Any]:
        try:
            args = FinishTurnArgs.model_validate(kwargs)
        except ValidationError as exc:
            return _invalid_args_response("finish_turn", exc, trace)
        trace("turn_finished_without_external_action", {"reason": args.reason})
        return {"status": "finished", "reason": args.reason}

    tools: list[BaseTool] = [
        StructuredTool.from_function(
            coroutine=send_message,
            name="send_message",
            description=(
                "Send tool-backed evidence, reasoning, or a request to one "
                "allowed peer. Always provide recipient, tldr, content, "
                "artifact_refs as an array of strings or [], tool_result_ids "
                "as an array of strings or [], reply_requested as true when "
                "the peer should respond, a non-empty reason, and numeric "
                "confidence from 0 to 1."
            ),
            args_schema=SendMessageArgs,
            handle_validation_error=_validation_error_handler("send_message"),
            metadata={"chemgraph_academy_tool_kind": "action_tool"},
        ),
        StructuredTool.from_function(
            coroutine=submit_result,
            name="submit_result",
            description=(
                "Submit this agent's current final answer or report. Cite peer "
                "message IDs and ChemGraph tool result IDs."
            ),
            args_schema=SubmitResultArgs,
            handle_validation_error=_validation_error_handler("submit_result"),
            return_direct=True,
            metadata={"chemgraph_academy_tool_kind": "action_tool"},
        ),
        StructuredTool.from_function(
            coroutine=finish_turn,
            name="finish_turn",
            description=(
                "End this decision turn when no tool, message, or report action "
                "is currently useful."
            ),
            args_schema=FinishTurnArgs,
            handle_validation_error=_validation_error_handler("finish_turn"),
            return_direct=True,
            metadata={"chemgraph_academy_tool_kind": "action_tool"},
        ),
    ]
    tools.extend(external_tools)

    return tools
async def run_academy_turn(
    *,
    campaign: ChemGraphCampaign,
    spec: ChemGraphAgentSpec,
    llm_settings: LLMSettings,
    prompt_profile: PromptProfile,
    run_dir: Path,
    max_decisions: int,
    tools: list[BaseTool],
    received_message_history: list[dict[str, Any]],
    outbox: list[dict[str, Any]],
    tool_results: list[dict[str, Any]],
    get_final_result: Callable[[], dict[str, Any] | None],
    get_round_index: Callable[[], int],
    trace: TraceFn,
    peer_names: tuple[str, ...] = (),
) -> ReasoningTurnResult:
    """Run one reasoning turn for ``spec`` and summarise which tools ran.

    The agent's visible state is serialised to JSON and passed as the query
    to ``run_turn``. Executed tool names are split into action tools (names
    in ``ACTION_TOOL_NAMES``) and science tools (everything else).

    Raises:
        RuntimeError: If the turn finished without executing any tool.
    """
    round_index = get_round_index()
    thread_id = f"{spec.name}-round-{round_index}"
    trace(
        "chemgraph_reasoning_turn_started",
        {
            "round": round_index,
            "thread_id": thread_id,
            "tool_names": [t.name for t in tools],
        },
    )

    def on_event(event: str, payload: dict) -> None:
        # Tag every event forwarded from the turn with this round.
        trace(event, {"round": round_index, **payload})

    available_tool_names = tuple(
        candidate.name
        for candidate in tools
        if candidate.name not in ACTION_TOOL_NAMES
    )
    state = _state(
        campaign,
        spec,
        prompt_profile,
        run_dir,
        max_decisions,
        round_index,
        received_message_history,
        outbox,
        tool_results,
        get_final_result,
        peer_names,
        available_tool_names,
    )
    query = json.dumps(state, sort_keys=True)
    result = await run_turn(
        query=query,
        tools=tools,
        model_name=llm_settings.model,
        base_url=llm_settings.base_url,
        api_key=llm_settings.api_key,
        argo_user=llm_settings.user,
        system_prompt=prompt_profile.system_prompt,
        recursion_limit=prompt_profile.langchain_recursion_limit,
        thread_id=thread_id,
        terminal_tool_names=TERMINAL_TOOL_NAMES,
        on_event=on_event,
    )

    executed = result.executed_tool_names
    if not executed:
        raise RuntimeError("ChemGraph reasoning turn returned without calling an Academy action or science tool; call finish_turn when no external action is useful.")

    # Partition executed tools, preserving call order within each group.
    action_names: list[str] = []
    science_names: list[str] = []
    for tool_name in executed:
        if tool_name in ACTION_TOOL_NAMES:
            action_names.append(tool_name)
        else:
            science_names.append(tool_name)
    action_tools = tuple(action_names)
    science_tools = tuple(science_names)

    out = ReasoningTurnResult(
        final_text=result.final_text,
        executed_tool_names=executed,
        action_tools_called=action_tools,
        science_tools_called=science_tools,
        requested_finish=result.terminal_tool in TERMINAL_TOOL_NAMES,
        requested_self_wake=bool(science_tools),
        thread_id=result.thread_id,
    )
    trace(
        "chemgraph_reasoning_turn_finished",
        {
            "round": round_index,
            "thread_id": out.thread_id,
            "action_tools_called": list(action_tools),
            "science_tools_called": list(science_tools),
            "requested_finish": out.requested_finish,
            "requested_self_wake": out.requested_self_wake,
        },
    )
    return out
m.get("recipient"), "reply_requested": bool(m.get("reply_requested")), "tldr": m.get("tldr") or _preview(m.get("content")), "message_id": m.get("message_id"), "timestamp": m.get("timestamp")} for m in outbox[-limit:]] + actions += [{"type": "tool_call", "tool_name": r.get("tool_name"), "tool_result_id": r.get("tool_result_id"), "status": r.get("status"), "timestamp": r.get("timestamp")} for r in tool_results[-limit:]] + return sorted(actions, key=lambda i: float(i.get("timestamp") or 0.0))[-limit:] + +def _status(run_dir: Path, peer: str, *, now: float) -> dict[str, Any]: + data = read_json_file(run_dir / "agent_status" / f"{peer}.json", default={}) + timestamp = _float(data.get("status_updated_at")) + state = "unknown" if not data else "error" if data.get("last_error") else "finished" if data.get("finished") else "idle" + return {"state": state, "round": data.get("round"), "finished": bool(data.get("finished")) if data else False, "last_error": data.get("last_error"), "seconds_since_update": None if timestamp is None else max(0.0, round(now - timestamp, 3))} + + +def _tail(items: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]: + return items[-limit:] if limit else [] + + +def _float(value: Any) -> float | None: + try: + return None if value is None or isinstance(value, bool) else float(value) + except (TypeError, ValueError): + return None + + +def _preview(value: Any, *, max_chars: int = 160) -> str: + text = "" if value is None else str(value) + return text if len(text) <= max_chars else text[: max_chars - 1] + "..." 
diff --git a/src/chemgraph/academy/dashboard/__init__.py b/src/chemgraph/academy/dashboard/__init__.py new file mode 100644 index 00000000..27eb75b0 --- /dev/null +++ b/src/chemgraph/academy/dashboard/__init__.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from chemgraph.academy.dashboard.server import DashboardHandler +from chemgraph.academy.dashboard.server import events_payload +from chemgraph.academy.dashboard.server import main +from chemgraph.academy.dashboard.server import parse_args +from chemgraph.academy.dashboard.server import serve_dashboard +from chemgraph.academy.dashboard.server import snapshot +from chemgraph.academy.dashboard.server import status_payload + +__all__ = [ + 'DashboardHandler', + 'events_payload', + 'main', + 'parse_args', + 'serve_dashboard', + 'snapshot', + 'status_payload', +] diff --git a/src/chemgraph/academy/dashboard/__main__.py b/src/chemgraph/academy/dashboard/__main__.py new file mode 100644 index 00000000..a9021b2a --- /dev/null +++ b/src/chemgraph/academy/dashboard/__main__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from chemgraph.academy.dashboard.server import main + +if __name__ == '__main__': + raise SystemExit(main()) diff --git a/src/chemgraph/academy/dashboard/server.py b/src/chemgraph/academy/dashboard/server.py new file mode 100644 index 00000000..3c50741f --- /dev/null +++ b/src/chemgraph/academy/dashboard/server.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import argparse +import socket +import json +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from importlib.resources import files +from pathlib import Path +from typing import Any + +from chemgraph.academy.observability.event_log import read_events +from chemgraph.academy.observability.run_files import read_json_file +from chemgraph.academy.observability.run_artifacts import write_run_artifacts + +_STATIC_CACHE: dict[str, bytes] = {} + + +def _static_file(name: str, content_type: str) -> 
def status_payload(handler: DashboardHandler) -> dict[str, Any]:
    """Build the ``/api/status`` response for the handler's run directory.

    ``status.json`` is read on a best-effort basis: a missing, unreadable,
    undecodable or syntactically invalid file, or one whose top-level JSON
    value is not an object, is treated as an empty status rather than
    failing the request.

    Returns:
        A dict with ``run_dir``, ``updated``, ``schema``, the raw ``status``
        mapping, and the ``placement`` and ``summary`` entries produced by
        ``write_run_artifacts``.
    """
    run_dir = handler.run_dir
    status_path = run_dir / "status.json"
    status: dict[str, Any] = {}
    try:
        loaded = json.loads(status_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: file missing or unreadable. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        loaded = None
    if isinstance(loaded, dict):
        # Anything else (list, null, number, ...) has no .get() and would
        # have crashed the lookups below.
        status = loaded
    artifacts = write_run_artifacts(run_dir)
    manifest = read_json_file(run_dir / "manifest.json", default={})
    updated = status.get("updated") or status.get("timestamp")
    schema = (
        status.get("mode")
        or (manifest.get("mode") if isinstance(manifest, dict) else None)
        or "canonical_events"
    )
    return {
        "run_dir": str(run_dir),
        "updated": updated,
        "schema": schema,
        "status": status,
        "placement": artifacts["placement"],
        "summary": artifacts["summary"],
    }
@@ -0,0 +1,3072 @@ + let snapshot = null; + let snapshotIdentity = null; + let selectedAgent = null; + let selectedEdgeKey = null; + let selectedActivityEventKey = null; + let timelineIndex = null; + let followLatest = true; + let graphMode = 'recent'; + let isReplaying = false; + let replayTimer = null; + let replayStartedAtMs = 0; + let replayStartTimestamp = null; + let replayStartIndex = 0; + let graphView = null; + let graphPanDrag = null; + let embeddedWorkflowView = null; + let embeddedWorkflowPanDrag = null; + let selectedEmbeddedWorkflowEventKey = null; + let embeddedWorkflowAnchorKey = null; + let workflowPanelOpen = true; + let workflowPanelFrame = {x: 28, y: 76, width: 1120, height: 640}; + let workflowPanelDrag = null; + let workflowPanelResizeDrag = null; + let detailResizeDrag = null; + let lastRenderedDetailIdentity = null; + let lastEmbeddedWorkflowInspectorIdentity = null; + const recentMessageWindow = 4; + const actionToolNames = new Set(['send_message', 'submit_result', 'finish_turn']); + const renderedHtmlCache = new WeakMap(); + + const esc = (s) => String(s ?? '').replace(/[&<>"']/g, c => ({'&':'&','<':'<','>':'>','"':'"',"'":'''}[c])); + const trunc = (s, n=180) => { + s = String(s ?? ''); + return s.length > n ? s.slice(0, n - 1) + '…' : s; + }; + const formatTime = (timestamp) => timestamp ? new Date(timestamp * 1000).toLocaleTimeString() : '-'; + const eventTimestamp = (event) => { + const value = Number(event?.timestamp); + return Number.isFinite(value) ? 
value : null; + }; + const replaySpeed = () => Number(document.getElementById('replaySpeed')?.value || 25); + const eventHoldSeconds = () => Number(document.getElementById('eventHold')?.value || 8); + + function estimateLabelWidth(text) { + return Math.min(460, Math.max(20, String(text || '').length * 6.2 + 14)); + } + + function labelBox(x, y, text) { + const width = estimateLabelWidth(text); + const height = 18; + return { + x1: x - width / 2, + y1: y - height + 5, + x2: x + width / 2, + y2: y + 7, + }; + } + + function boxesOverlap(a, b, pad = 5) { + return !(a.x2 + pad < b.x1 || b.x2 + pad < a.x1 || a.y2 + pad < b.y1 || b.y2 + pad < a.y1); + } + + function placeEdgeLabel(baseX, baseY, rawLabel, occupiedBoxes, force = false) { + const text = trunc(String(rawLabel || ''), force ? 96 : 72); + if (!text) return null; + const candidates = [ + [0, 0], [0, -18], [0, 18], [30, -12], [-30, 12], + [44, 18], [-44, -18], [0, -34], [0, 34], + ]; + for (const [dx, dy] of candidates) { + const x = baseX + dx; + const y = baseY + dy; + const box = labelBox(x, y, text); + const collision = occupiedBoxes.some(other => boxesOverlap(box, other)); + if (force || !collision) { + occupiedBoxes.push(box); + return {x, y, text}; + } + } + return null; + } + + async function load() { + const [statusRes, eventsRes] = await Promise.all([ + fetch('/api/status'), + fetch('/api/events'), + ]); + const statusData = await statusRes.json(); + const eventsData = await eventsRes.json(); + const nextSnapshot = {...statusData, events: eventsData.events || []}; + const nextIdentity = identityForSnapshot(nextSnapshot); + const previousEventCount = snapshot?.events?.length || 0; + const nextEventCount = nextSnapshot.events.length; + if ( + snapshotIdentity !== null + && ( + nextIdentity !== snapshotIdentity + || nextEventCount < previousEventCount + ) + ) { + resetInteractionState(); + } + snapshotIdentity = nextIdentity; + snapshot = nextSnapshot; + const latest = allEvents().length - 1; + if 
(followLatest || timelineIndex === null) { + timelineIndex = latest; + } else { + timelineIndex = Math.min(timelineIndex, latest); + } + render(); + } + + function identityForSnapshot(data) { + return [ + data?.schema || '', + data?.run_dir || '', + ].join('|'); + } + + function resetInteractionState() { + stopReplay(false); + selectedAgent = null; + selectedEdgeKey = null; + selectedActivityEventKey = null; + timelineIndex = null; + followLatest = true; + graphView = null; + graphPanDrag = null; + embeddedWorkflowView = null; + embeddedWorkflowPanDrag = null; + lastRenderedDetailIdentity = null; + lastEmbeddedWorkflowInspectorIdentity = null; + selectedEmbeddedWorkflowEventKey = null; + embeddedWorkflowAnchorKey = null; + } + + function allEvents() { + return snapshot?.events || []; + } + + function isWorkflowMode() { + return snapshot?.schema === 'chemgraph_workflow'; + } + + function currentEventIndex() { + const events = allEvents(); + if (!events.length) return -1; + if (timelineIndex === null) return events.length - 1; + return Math.max(0, Math.min(timelineIndex, events.length - 1)); + } + + function visibleEvents() { + const index = currentEventIndex(); + return index < 0 ? [] : allEvents().slice(0, index + 1); + } + + function currentEvent() { + const index = currentEventIndex(); + return index < 0 ? null : allEvents()[index]; + } + + function eventKey(event) { + if (!event) return ''; + if (event.event_id) return String(event.event_id); + const payload = JSON.stringify(event.payload || {}); + return [ + event.timestamp ?? 
'', + event.event || '', + event.agent_id || '', + event.correlation_id || '', + payload.slice(0, 220), + ].join('|'); + } + + function firstTimestamp() { + const first = allEvents().find(event => eventTimestamp(event) !== null); + return eventTimestamp(first); + } + + function currentTimestamp() { + return eventTimestamp(currentEvent()); + } + + function eventIndexAtTimestamp(timestamp) { + const events = allEvents(); + if (!events.length) return -1; + if (timestamp === null || timestamp === undefined || Number.isNaN(timestamp)) { + return Math.min(events.length - 1, replayStartIndex + 1); + } + let index = 0; + for (let i = 0; i < events.length; i += 1) { + const ts = eventTimestamp(events[i]); + if (ts === null) { + index = i; + continue; + } + if (ts <= timestamp) index = i; + else break; + } + return index; + } + + function activeWindowEvents(multiplier = 1) { + const events = visibleEvents(); + const now = currentTimestamp(); + if (now === null) return events.slice(-Math.max(1, recentMessageWindow)); + const hold = eventHoldSeconds() * multiplier; + return events.filter(event => { + const ts = eventTimestamp(event); + return ts !== null && ts <= now && now - ts <= hold; + }); + } + + function eventsOf(type) { + return visibleEvents().filter(e => e.event === type); + } + + function graphMessageEvents() { + const sent = eventsOf('message_sent'); + if (graphMode === 'cumulative') return sent; + if (graphMode === 'current') { + const event = currentEvent(); + const activeMessages = activeWindowEvents(1).filter(e => e.event === 'message_sent'); + if (!event) return activeMessages; + if (activeMessages.length) return activeMessages; + if (event.event === 'message_sent') return [event]; + if (event.event === 'message_received') { + const messageId = event.payload?.message_id; + return sent.filter(item => item.payload?.message_id === messageId).slice(-1); + } + return []; + } + const now = currentTimestamp(); + if (now !== null) { + const windowSeconds = 
Math.max(eventHoldSeconds() * 2, 8); + const windowed = sent.filter(event => { + const ts = eventTimestamp(event); + return ts !== null && ts <= now && now - ts <= windowSeconds; + }); + if (windowed.length) return windowed; + } + return sent.slice(-recentMessageWindow); + } + + function graphModeLabel() { + if (graphMode === 'current') return `showing active events for ${eventHoldSeconds()}s`; + if (graphMode === 'cumulative') return 'showing all prior communication'; + return `showing recent communication window`; + } + + function latestEventOf(type, agentId = null) { + const matches = visibleEvents().filter(e => e.event === type && (!agentId || e.agent_id === agentId)); + return matches.length ? matches[matches.length - 1] : null; + } + + function agents() { + if (!snapshot) return []; + const specs = snapshot.status?.agents || []; + const currentEvents = visibleEvents(); + const visiblePlacements = {}; + currentEvents.forEach(event => { + if (event.event !== 'agent_started' || !event.agent_id) return; + const placement = event.payload?.placement; + if (placement) visiblePlacements[event.agent_id] = placement; + }); + const finalPlacements = snapshot.placement?.agents || {}; + return specs.map(spec => { + const agentId = spec.agent_id || spec.agent_name || spec.name; + return { + ...spec, + agent_id: agentId, + agent_name: spec.agent_name || agentId, + ...agentStateAt(agentId, currentEvents), + placement: visiblePlacements[agentId] || finalPlacements[agentId] || spec.placement || {}, + }; + }).filter(agent => agent.agent_id); + } + + function agentStateAt(agentId, events) { + const state = { + started: false, + last_error: null, + decision_count: 0, + received_message_count: 0, + outbox_count: 0, + tool_started_count: 0, + tool_finished_count: 0, + }; + events.forEach(event => { + if (event.agent_id !== agentId) return; + if (event.event === 'agent_started') state.started = true; + if (event.event === 'agent_error') state.last_error = event.payload?.error || 
'agent_error'; + if (event.event === 'agent_decision') state.decision_count += 1; + if (event.event === 'message_received') state.received_message_count += 1; + if (event.event === 'message_sent') state.outbox_count += 1; + if (event.event === 'tool_call_started') state.tool_started_count += 1; + if (event.event === 'tool_call_finished' || event.event === 'tool_call_failed') state.tool_finished_count += 1; + }); + return state; + } + + function agentHost(agent) { + return agent?.placement?.short_hostname || agent?.placement?.hostname || (agent?.started ? 'unknown host' : 'pending'); + } + + function hostColor(index) { + const colors = ['#dbeafe', '#dcfce7', '#fef3c7', '#fce7f3', '#e0e7ff', '#ccfbf1', '#fee2e2', '#ede9fe']; + return colors[index % colors.length]; + } + + function hostStroke(index) { + const colors = ['#2563eb', '#16a34a', '#d97706', '#db2777', '#4f46e5', '#0f766e', '#dc2626', '#7c3aed']; + return colors[index % colors.length]; + } + + function render() { + const detailScroll = captureDetailScrollSnapshot(); + document.getElementById('updated').textContent = snapshot.updated ? new Date(snapshot.updated * 1000).toLocaleTimeString() : ''; + document.getElementById('runPath').textContent = snapshot.run_dir || ''; + document.getElementById('graphTitle').textContent = isWorkflowMode() ? 'ChemGraph Workflow' : 'Agent Graph'; + renderTimeline(); + renderMetrics(); + renderGraph(); + renderAgentPicker(); + renderDetail(); + renderEmbeddedWorkflowPanel(); + restoreDetailScrollSnapshot(detailScroll); + lastRenderedDetailIdentity = currentDetailIdentity(); + } + + function renderTimeline() { + const events = allEvents(); + const slider = document.getElementById('timeSlider'); + const index = currentEventIndex(); + slider.max = String(Math.max(0, events.length - 1)); + slider.value = String(Math.max(0, index)); + slider.disabled = events.length === 0; + const event = index >= 0 ? events[index] : null; + const mode = isReplaying ? 'replay' : followLatest ? 
'latest' : `event ${index + 1}`; + document.getElementById('timeLabel').textContent = `${mode} / ${events.length}`; + document.getElementById('timeEvent').textContent = event + ? `${formatTime(event.timestamp)} ${event.event}${event.agent_id ? ` · ${event.agent_id}` : ''} · ${graphModeLabel()}` + : ''; + document.getElementById('playReplay').textContent = isReplaying ? 'Pause' : 'Replay'; + document.querySelectorAll('#graphMode button').forEach(button => { + button.classList.toggle('active', button.dataset.mode === graphMode); + }); + } + + function renderMetrics() { + if (isWorkflowMode()) { + renderWorkflowMetrics(); + return; + } + const events = visibleEvents(); + const counts = {}; + events.forEach(event => { counts[event.event] = (counts[event.event] || 0) + 1; }); + const currentAgents = agents(); + const startedAgents = currentAgents.filter(agent => agent.started); + const hostByAgent = new Map(currentAgents.map(agent => [agent.agent_id, agentHost(agent)])); + const hosts = new Set(startedAgents.map(agentHost).filter(host => host && host !== 'pending')); + const finish = latestEventOf('campaign_finished')?.payload || {}; + const messageEvents = events.filter(event => event.event === 'message_sent'); + const crossNodeMessages = messageEvents.filter(event => { + const p = event.payload || {}; + const senderHost = hostByAgent.get(p.sender); + const recipientHost = hostByAgent.get(p.recipient); + return senderHost && recipientHost && senderHost !== recipientHost; + }); + const maceResults = events.filter(event => ( + ['tool_call_finished', 'chemgraph_job_result'].includes(event.event) + && event.payload?.tool_name === 'run_mace_ensemble' + )); + const values = [ + ['Finish', finish.reason || 'running'], + ['Decisions', counts.agent_decision || 0], + ['Agents / Hosts', `${startedAgents.length} / ${hosts.size}`], + ['Errors', counts.agent_error || 0], + ['Messages', messageEvents.length], + ['Cross-node', crossNodeMessages.length], + ['Tool calls', 
counts.tool_call_started || 0], + ['Workflows', counts.workflow_started || 0], + ]; + document.getElementById('metrics').innerHTML = values.map(([k,v]) => ` +
${esc(k)}
${esc(v)}
+ `).join(''); + document.getElementById('proof').innerHTML = crossNodeMessages.length + ? `cross-node messages=${crossNodeMessages.length}` + : ''; + } + + function renderWorkflowMetrics() { + const events = visibleEvents(); + const counts = {}; + events.forEach(event => { counts[event.event] = (counts[event.event] || 0) + 1; }); + const status = snapshot.status || {}; + const finish = events.filter(event => event.event === 'workflow_finished').slice(-1)[0]?.payload || {}; + const toolResults = events.filter(event => event.event === 'tool_call_finished' && event.payload?.runtime); + const tokenEvents = workflowTokenEvents(events); + const tokenTotals = summedTokenCounts(events); + const values = [ + ['Status', finish.status || status.status || 'running'], + ['Workflow', status.workflow_type || finish.workflow_type || '-'], + ['Events', events.length], + ['LM calls', tokenEvents.length || (counts.llm_decision || 0)], + ['LM tokens', tokenTotals ? formatTokenCount(tokenTotals.total) : '-'], + ['Tool results', toolResults.length], + ['Errors', finish.status === 'failed' ? 1 : 0], + ['Model', status.model_name || '-'], + ['Span', trunc(status.workflow_span_id || finish.span_id || '-', 18)], + ]; + document.getElementById('metrics').innerHTML = values.map(([k,v]) => ` +
${esc(k)}
${esc(v)}
+ `).join(''); + document.getElementById('proof').innerHTML = 'local ChemGraph workflow'; + } + + function workflowGraphEvents() { + return visibleEvents().filter(event => isWorkflowEvent(event)); + } + + function activeEmbeddedWorkflowContext() { + if (isWorkflowMode()) return null; + const activity = selectedActivityEvent(); + if (activity) { + const events = workflowEventsForSelection(activity); + if (events.length) { + return { + events, + anchorKey: eventKey(activity), + title: isWorkflowEvent(activity) + ? `ChemGraph: ${workflowAgentId(activity) || activity.agent_id || 'workflow'}` + : `ChemGraph: ${activity.agent_id || 'agent'}`, + meta: embeddedWorkflowMeta(events), + }; + } + } + return null; + } + + function embeddedWorkflowMeta(events) { + const flow = workflowFlowGraph(events); + const tokenEventCount = workflowTokenEvents(events).length; + const llmCount = tokenEventCount || events.filter(event => event.event === 'llm_decision').length; + const toolCount = flow.nodes.filter(node => node.type === 'tool').length; + const tokenTotals = summedTokenCounts(events); + const first = events[0]; + const p = first?.payload || {}; + return [ + p.thread_id || (p.round !== undefined ? `round ${p.round}` : ''), + `${llmCount} LM`, + tokenTotals ? 
`${formatTokenCount(tokenTotals.total)} tok` : '', + `${toolCount} tools/actions`, + `${events.length} events`, + ].filter(Boolean).join(' · '); + } + + function renderEmbeddedWorkflowPanel() { + const context = activeEmbeddedWorkflowContext(); + const panel = document.getElementById('workflowFloatingPanel'); + const tab = document.getElementById('workflowFloatingTab'); + if (!context || !context.events.length) { + panel.classList.add('hidden'); + tab.classList.add('hidden'); + return; + } + if (!workflowPanelOpen) { + panel.classList.add('hidden'); + tab.classList.remove('hidden'); + return; + } + tab.classList.add('hidden'); + panel.classList.remove('hidden'); + applyWorkflowPanelFrame(); + if (embeddedWorkflowAnchorKey !== context.anchorKey) { + embeddedWorkflowAnchorKey = context.anchorKey; + selectedEmbeddedWorkflowEventKey = null; + embeddedWorkflowView = null; + } + if ( + selectedEmbeddedWorkflowEventKey + && !context.events.some(event => eventKey(event) === selectedEmbeddedWorkflowEventKey) + ) { + selectedEmbeddedWorkflowEventKey = null; + } + document.getElementById('workflowFloatingTitle').textContent = context.title; + document.getElementById('workflowFloatingMeta').textContent = context.meta; + renderEmbeddedWorkflowGraph(context.events); + renderEmbeddedWorkflowInspector(context.events); + } + + function currentEmbeddedWorkflowInspectorIdentity(events) { + const selected = selectedEmbeddedWorkflowEvent(events); + return [ + embeddedWorkflowAnchorKey || '', + selected ? 
eventKey(selected) : 'summary', + ].join('|'); + } + + function applyWorkflowPanelFrame() { + const panel = document.getElementById('workflowFloatingPanel'); + const maxWidth = Math.max(520, window.innerWidth - 24); + const maxHeight = Math.max(320, window.innerHeight - 24); + workflowPanelFrame.width = Math.min(Math.max(workflowPanelFrame.width, 520), maxWidth); + workflowPanelFrame.height = Math.min(Math.max(workflowPanelFrame.height, 320), maxHeight); + workflowPanelFrame.x = Math.min(Math.max(workflowPanelFrame.x, 8), window.innerWidth - 80); + workflowPanelFrame.y = Math.min(Math.max(workflowPanelFrame.y, 8), window.innerHeight - 56); + panel.style.left = `${workflowPanelFrame.x}px`; + panel.style.top = `${workflowPanelFrame.y}px`; + panel.style.width = `${workflowPanelFrame.width}px`; + panel.style.height = `${workflowPanelFrame.height}px`; + } + + function renderEmbeddedWorkflowGraph(events) { + const svg = document.getElementById('embeddedWorkflowGraph'); + const empty = document.getElementById('embeddedWorkflowEmpty'); + const flow = workflowFlowGraph(events); + const nodes = flow.nodes; + const edges = flow.edges; + if (!nodes.length) { + svg.innerHTML = ''; + empty.textContent = 'No ChemGraph workflow nodes visible for this selection.'; + empty.classList.remove('hidden'); + return; + } + empty.classList.add('hidden'); + + const nodeW = 190; + const nodeH = 66; + const columnGap = 112; + const toolLaneGap = 82; + const toolStartY = 260; + const maxColumn = Math.max(...nodes.map(node => node.column || 0)); + const maxToolLanes = Math.max(1, ...nodes.filter(node => node.type === 'tool').map(node => node.laneCount || 1)); + const width = Math.max(1120, 120 + (maxColumn + 1) * (nodeW + columnGap)); + const height = Math.max(540, toolStartY + maxToolLanes * toolLaneGap + 100); + const yByType = {input: 236, lm: 128, output: 236}; + const positions = new Map(); + nodes.forEach(node => { + const column = node.column || 0; + const y = node.type === 'tool' + ? 
toolStartY + (node.laneIndex || 0) * toolLaneGap + : (yByType[node.type] || 236); + positions.set(node.id, { + x: 96 + nodeW / 2 + column * (nodeW + columnGap), + y, + }); + }); + + const selectedEvent = selectedEmbeddedWorkflowEvent(events); + const selectedNodeId = selectedEvent ? workflowFlowNodeId(selectedEvent) : null; + const current = currentEvent(); + const currentNodeId = current ? workflowFlowNodeId(current) : null; + const selectedEdgeIds = new Set(); + if (selectedNodeId) { + edges.forEach(edge => { + if (edge.from === selectedNodeId || edge.to === selectedNodeId) { + selectedEdgeIds.add(`${edge.from}->${edge.to}`); + } + }); + } + const nodeById = new Map(nodes.map(node => [node.id, node])); + const edgeSvg = edges.map(edge => { + const prev = nodeById.get(edge.from); + const node = nodeById.get(edge.to); + const source = positions.get(edge.from); + const target = positions.get(edge.to); + if (!prev || !node || !source || !target) return ''; + const startX = source.x + nodeW / 2; + const endX = target.x - nodeW / 2; + const midX = (startX + endX) / 2; + const controlY = Math.min(source.y, target.y) - 54; + const path = `M ${startX.toFixed(1)} ${source.y.toFixed(1)} Q ${midX.toFixed(1)} ${controlY.toFixed(1)} ${endX.toFixed(1)} ${target.y.toFixed(1)}`; + const cls = [ + 'workflow-edge', + node.id === currentNodeId || prev.id === currentNodeId ? 'current' : '', + selectedEdgeIds.has(`${edge.from}->${edge.to}`) ? 'related' : '', + ].filter(Boolean).join(' '); + return ` + + ${esc(prev.title)} -> ${esc(node.title)} + + `; + }).join(''); + + const nodeSvg = nodes.map(node => { + const pos = positions.get(node.id); + const classes = [ + 'workflow-node', + node.type, + node.toolClass || '', + node.failed ? 'error' : '', + node.id === currentNodeId ? 'current' : '', + node.id === selectedNodeId ? 
'selected' : '', + ].filter(Boolean).join(' '); + return ` + + + ${esc(trunc(node.title, 25))} + ${esc(trunc(node.meta, 34))} + ${esc(formatTime(node.event.timestamp))} + ${esc(formatWorkflowEvent(node.event))} + + `; + }).join(''); + + ensureEmbeddedWorkflowView(width, height); + svg.innerHTML = ` + + + + + + + ChemGraph turn · ${nodes.length} node(s) · ${events.length} visible event(s) + + ${edgeSvg} + ${nodeSvg} + `; + updateEmbeddedWorkflowViewBox(); + svg.querySelectorAll('[data-embedded-workflow-event-key]').forEach(node => { + node.addEventListener('click', event => { + selectedEmbeddedWorkflowEventKey = node.dataset.embeddedWorkflowEventKey; + renderEmbeddedWorkflowPanel(); + event.stopPropagation(); + }); + }); + } + + function selectedEmbeddedWorkflowEvent(events) { + if (!selectedEmbeddedWorkflowEventKey) return null; + return events.find(event => eventKey(event) === selectedEmbeddedWorkflowEventKey) || null; + } + + function renderEmbeddedWorkflowInspector(events) { + const title = document.getElementById('embeddedWorkflowInspectorTitle'); + const meta = document.getElementById('embeddedWorkflowInspectorMeta'); + const body = document.getElementById('embeddedWorkflowInspectorBody'); + const identity = currentEmbeddedWorkflowInspectorIdentity(events); + const previousIdentity = lastEmbeddedWorkflowInspectorIdentity; + const previousScrollTop = body.scrollTop; + const previousScrollLeft = body.scrollLeft; + const event = selectedEmbeddedWorkflowEvent(events); + if (!event) { + const tokenTotals = summedTokenCounts(events); + const flow = workflowFlowGraph(events); + const html = detailRich( + detailSection('Turn Summary', detailKvGrid([ + ['Visible events', events.length], + ['Nodes', flow.nodes.length], + ['LM calls', workflowTokenEvents(events).length || events.filter(item => item.event === 'llm_decision').length], + ['Input tokens', tokenTotals?.input ?? '-'], + ['Output tokens', tokenTotals?.output ?? '-'], + ['Total tokens', tokenTotals?.total ?? 
'-'], + ]), 'info'), + detailSection( + 'Inspect', + paragraphsHtml('Click a ChemGraph node in this panel to inspect its LM tokens, tool arguments, output, and payload without changing the outer dashboard selection.'), + ), + ); + title.textContent = 'ChemGraph Inspector'; + meta.textContent = 'Select an LM, tool, action, or output node.'; + setStableHtml(body, html, identity === previousIdentity); + if (identity === previousIdentity) { + body.scrollTop = previousScrollTop; + body.scrollLeft = previousScrollLeft; + } + lastEmbeddedWorkflowInspectorIdentity = identity; + return; + } + const html = detailRich( + chemgraphNodeDetailHtml(event), + payloadDetailHtml(event.payload || {}), + chemgraphNodeContextHtml(event), + ); + title.textContent = chemgraphNodeDetailTitle(event); + meta.textContent = `${formatTime(event.timestamp)} · ${event.event}`; + setStableHtml(body, html, identity === previousIdentity); + if (identity === previousIdentity) { + body.scrollTop = previousScrollTop; + body.scrollLeft = previousScrollLeft; + } + lastEmbeddedWorkflowInspectorIdentity = identity; + } + + function ensureEmbeddedWorkflowView(width, height) { + const padX = Math.max(180, width * 0.08); + const padY = Math.max(100, height * 0.12); + const bounds = { + x: -padX, + y: -padY, + width: width + padX * 2, + height: height + padY * 2, + }; + if (!embeddedWorkflowView) { + embeddedWorkflowView = { + x: bounds.x, + y: bounds.y, + width: bounds.width, + height: bounds.height, + layoutWidth: width, + layoutHeight: height, + boundsX: bounds.x, + boundsY: bounds.y, + boundsWidth: bounds.width, + boundsHeight: bounds.height, + }; + return; + } + if ( + embeddedWorkflowView.layoutWidth !== width + || embeddedWorkflowView.layoutHeight !== height + ) { + const nextView = preserveViewForLayoutChange( + embeddedWorkflowView, + bounds, + width, + height, + ); + embeddedWorkflowView = { + ...embeddedWorkflowView, + ...nextView, + layoutWidth: width, + layoutHeight: height, + boundsX: bounds.x, 
+ boundsY: bounds.y, + boundsWidth: bounds.width, + boundsHeight: bounds.height, + }; + clampEmbeddedWorkflowView(); + } + } + + function updateEmbeddedWorkflowViewBox() { + const svg = document.getElementById('embeddedWorkflowGraph'); + if (!embeddedWorkflowView) return; + svg.setAttribute( + 'viewBox', + `${embeddedWorkflowView.x.toFixed(1)} ${embeddedWorkflowView.y.toFixed(1)} ${embeddedWorkflowView.width.toFixed(1)} ${embeddedWorkflowView.height.toFixed(1)}` + ); + } + + function clampEmbeddedWorkflowView() { + if (!embeddedWorkflowView) return; + const boundsX = embeddedWorkflowView.boundsX ?? 0; + const boundsY = embeddedWorkflowView.boundsY ?? 0; + const boundsWidth = embeddedWorkflowView.boundsWidth ?? embeddedWorkflowView.layoutWidth; + const boundsHeight = embeddedWorkflowView.boundsHeight ?? embeddedWorkflowView.layoutHeight; + embeddedWorkflowView.width = Math.min(boundsWidth, Math.max(embeddedWorkflowView.layoutWidth / 12, embeddedWorkflowView.width)); + embeddedWorkflowView.height = Math.min(boundsHeight, Math.max(embeddedWorkflowView.layoutHeight / 12, embeddedWorkflowView.height)); + embeddedWorkflowView.x = Math.min(Math.max(boundsX, embeddedWorkflowView.x), boundsX + boundsWidth - embeddedWorkflowView.width); + embeddedWorkflowView.y = Math.min(Math.max(boundsY, embeddedWorkflowView.y), boundsY + boundsHeight - embeddedWorkflowView.height); + } + + function zoomEmbeddedWorkflow(factor) { + if (!embeddedWorkflowView) return; + const centerX = embeddedWorkflowView.x + embeddedWorkflowView.width / 2; + const centerY = embeddedWorkflowView.y + embeddedWorkflowView.height / 2; + embeddedWorkflowView.width *= factor; + embeddedWorkflowView.height *= factor; + embeddedWorkflowView.x = centerX - embeddedWorkflowView.width / 2; + embeddedWorkflowView.y = centerY - embeddedWorkflowView.height / 2; + clampEmbeddedWorkflowView(); + updateEmbeddedWorkflowViewBox(); + } + + function resetEmbeddedWorkflowView() { + embeddedWorkflowView = null; + 
renderEmbeddedWorkflowPanel(); + } + + function renderWorkflowGraph() { + const svg = document.getElementById('graph'); + const events = workflowGraphEvents(); + document.getElementById('hostLegend').innerHTML = ` + query + LM + action tool + science tool + output + failure + `; + if (!events.length) { + svg.setAttribute('viewBox', '0 0 1000 260'); + svg.innerHTML = 'No ChemGraph workflow events yet.'; + return; + } + + const flow = workflowFlowGraph(events); + const nodes = flow.nodes; + const edges = flow.edges; + if (!nodes.length) { + svg.setAttribute('viewBox', '0 0 1000 260'); + svg.innerHTML = 'Waiting for ChemGraph workflow execution events.'; + return; + } + + const nodeW = 184; + const nodeH = 64; + const columnGap = 96; + const toolLaneGap = 76; + const toolStartY = 230; + const maxColumn = Math.max(...nodes.map(node => node.column || 0)); + const maxToolLanes = Math.max(1, ...nodes.filter(node => node.type === 'tool').map(node => node.laneCount || 1)); + const width = Math.max(1040, 120 + (maxColumn + 1) * (nodeW + columnGap)); + const height = Math.max(500, toolStartY + maxToolLanes * toolLaneGap + 80); + const yByType = {input: 220, lm: 126, output: 220}; + const positions = new Map(); + nodes.forEach(node => { + const column = node.column || 0; + const y = node.type === 'tool' + ? toolStartY + (node.laneIndex || 0) * toolLaneGap + : (yByType[node.type] || 220); + positions.set(node.id, { + x: 96 + nodeW / 2 + column * (nodeW + columnGap), + y, + }); + }); + const current = currentEvent(); + const selectedEvent = selectedActivityEvent(); + const currentNodeId = current ? workflowFlowNodeId(current) : null; + const selectedNodeId = selectedEvent ? 
workflowFlowNodeId(selectedEvent) : null; + const selectedEdgeIds = new Set(); + if (selectedNodeId) { + edges.forEach(edge => { + if (edge.from === selectedNodeId || edge.to === selectedNodeId) { + selectedEdgeIds.add(`${edge.from}->${edge.to}`); + } + }); + } + + const nodeById = new Map(nodes.map(node => [node.id, node])); + const edgeSvg = edges.map(edge => { + const prev = nodeById.get(edge.from); + const node = nodeById.get(edge.to); + const source = positions.get(edge.from); + const target = positions.get(edge.to); + if (!prev || !node || !source || !target) return ''; + const startX = source.x + nodeW / 2; + const endX = target.x - nodeW / 2; + const midX = (startX + endX) / 2; + const controlY = Math.min(source.y, target.y) - 46; + const path = `M ${startX.toFixed(1)} ${source.y.toFixed(1)} Q ${midX.toFixed(1)} ${controlY.toFixed(1)} ${endX.toFixed(1)} ${target.y.toFixed(1)}`; + const cls = [ + 'workflow-edge', + node.id === currentNodeId || prev.id === currentNodeId ? 'current' : '', + selectedEdgeIds.has(`${edge.from}->${edge.to}`) ? 'related' : '', + ].filter(Boolean).join(' '); + return ` + + ${esc(prev.title)} -> ${esc(node.title)} + + `; + }).join(''); + + const nodeSvg = nodes.map(node => { + const pos = positions.get(node.id); + const classes = [ + 'workflow-node', + node.type, + node.toolClass || '', + node.failed ? 'error' : '', + node.id === currentNodeId ? 'current' : '', + node.id === selectedNodeId ? 
'selected' : '', + ].filter(Boolean).join(' '); + return ` + + + ${esc(trunc(node.title, 24))} + ${esc(trunc(node.meta, 32))} + ${esc(formatTime(node.event.timestamp))} + ${esc(formatWorkflowEvent(node.event))} + + `; + }).join(''); + + svg.style.minHeight = `${height}px`; + ensureGraphView(width, height); + svg.innerHTML = ` + + + + + + + ${esc(snapshot.status?.workflow_type || 'ChemGraph workflow')} · ${nodes.length} flow node(s) · ${events.length} visible event(s) + + ${edgeSvg} + ${nodeSvg} + `; + updateGraphViewBox(); + svg.querySelectorAll('[data-activity-event-key]').forEach(activityEl => { + activityEl.addEventListener('click', event => { + selectedActivityEventKey = activityEl.dataset.activityEventKey; + selectedAgent = null; + selectedEdgeKey = null; + event.stopPropagation(); + render(); + }); + }); + } + + function workflowFlowGraph(events) { + const nodes = []; + const edges = []; + const toolNodeIndexes = new Map(); + const hasGraphWork = events.some(event => ( + event.event === 'llm_decision' + || event.event === 'workflow_output' + || event.event === 'run_finished' + || event.event.startsWith('tool_call_') + )); + const runStart = events.find(event => event.event === 'run_started') + || events.find(event => event.event === 'workflow_started'); + let lastColumn = -1; + let lastNodeIds = []; + let currentLmId = null; + let currentToolBatchIds = []; + let currentToolBatchColumn = null; + let lmTurn = 0; + + function addEdge(from, to) { + if (!from || !to || from === to) return; + if (!edges.some(edge => edge.from === from && edge.to === to)) { + edges.push({from, to}); + } + } + + function addEdges(fromIds, to) { + Array.from(new Set(fromIds.filter(Boolean))).forEach(from => addEdge(from, to)); + } + + function updateCurrentToolBatchLanes() { + currentToolBatchIds.forEach((id, laneIndex) => { + const index = toolNodeIndexes.get(id); + if (index === undefined) return; + nodes[index].laneIndex = laneIndex; + nodes[index].laneCount = 
currentToolBatchIds.length; + }); + } + + if (runStart && hasGraphWork) { + const node = workflowFlowNode(runStart, 0); + node.column = 0; + nodes.push(node); + lastColumn = 0; + lastNodeIds = [node.id]; + } + + events.forEach(event => { + if (event.event === 'llm_decision') { + lmTurn += 1; + const afterToolBatch = currentToolBatchIds.length > 0; + const node = workflowFlowNode(event, nodes.length, {lmTurn}); + node.column = afterToolBatch + ? currentToolBatchColumn + 1 + : lastColumn + 1; + nodes.push(node); + addEdges(afterToolBatch ? currentToolBatchIds : lastNodeIds, node.id); + lastColumn = node.column; + lastNodeIds = [node.id]; + currentLmId = node.id; + currentToolBatchIds = []; + currentToolBatchColumn = null; + return; + } + if (event.event.startsWith('tool_call_')) { + if (currentToolBatchColumn === null) { + currentToolBatchColumn = lastColumn + 1; + } + const node = workflowFlowNode(event, nodes.length); + if (toolNodeIndexes.has(node.id)) { + const index = toolNodeIndexes.get(node.id); + node.column = nodes[index].column; + node.laneIndex = nodes[index].laneIndex; + node.laneCount = nodes[index].laneCount; + nodes[index] = node; + } else { + node.column = currentToolBatchColumn; + node.laneIndex = currentToolBatchIds.length; + node.laneCount = currentToolBatchIds.length + 1; + toolNodeIndexes.set(node.id, nodes.length); + currentToolBatchIds.push(node.id); + nodes.push(node); + addEdges(currentLmId ? [currentLmId] : lastNodeIds, node.id); + updateCurrentToolBatchLanes(); + } + } + }); + + const output = events.filter(event => event.event === 'workflow_output').slice(-1)[0] + || ( + nodes.length + ? events.filter(event => event.event === 'workflow_finished' || event.event === 'run_finished').slice(-1)[0] + : null + ); + if (output) { + const afterToolBatch = currentToolBatchIds.length > 0; + const node = workflowFlowNode(output, nodes.length); + node.column = afterToolBatch + ? 
currentToolBatchColumn + 1 + : lastColumn + 1; + nodes.push(node); + addEdges(afterToolBatch ? currentToolBatchIds : lastNodeIds, node.id); + } + return {nodes, edges}; + } + + /* Returns the flow-graph node id for an event: null for a missing event, the fixed ids 'workflow-query' / 'workflow-output' for start and output/finish events, 'workflow-tool-<tool_call_id>' for tool_call_ events that carry a tool_call_id, otherwise payload.span_id or eventKey(event). */ function workflowFlowNodeId(event) { + const p = event?.payload || {}; + if (!event) return null; + if (event.event === 'run_started' || event.event === 'workflow_started') return 'workflow-query'; + if (event.event === 'workflow_output' || event.event === 'workflow_finished' || event.event === 'run_finished') return 'workflow-output'; + if (event.event.startsWith('tool_call_')) { + return p.tool_call_id ? `workflow-tool-${p.tool_call_id}` : (p.span_id || eventKey(event)); + } + return p.span_id || eventKey(event); + } + + /* Builds a comma-separated summary of tool calls. Each call is labelled by call.name, falling back to call.id and then 'tool'; a label that occurs more than once is shown as "<label> xN". */ function toolCallSummary(toolCalls) { + const counts = new Map(); + toolCalls + .map(call => call?.name || call?.id || 'tool') + .filter(Boolean) + .forEach(name => counts.set(name, (counts.get(name) || 0) + 1)); + const parts = Array.from(counts.entries()).map(([name, count]) => count > 1 ? `${name} x${count}` : name); + return parts.join(', '); + } + + /* Coerces value with Number() and returns it only when it is finite; otherwise returns null. */ function numberOrNull(value) { + const number = Number(value); + return Number.isFinite(number) ?
number : null; + } + + /* Returns the first finite numeric value stored in raw under one of the given field names (undefined, null and non-numeric entries are skipped), or null when raw is not an object or nothing matches. */ function tokenField(raw, names) { + if (!raw || typeof raw !== 'object') return null; + for (const name of names) { + if (raw[name] !== undefined && raw[name] !== null) { + const value = numberOrNull(raw[name]); + if (value !== null) return value; + } + } + return null; + } + + /* Normalizes event.payload.token_counts into {input, output, total, source, estimateScope, rawUsage, raw}. Reads input_tokens or prompt_tokens, and output_tokens or completion_tokens; when total_tokens is missing the total is derived as input + output. Returns null when token_counts is absent or holds no usable count. */ function llmTokenCounts(event) { + const raw = event?.payload?.token_counts; + if (!raw || typeof raw !== 'object') return null; + const input = tokenField(raw, ['input_tokens', 'prompt_tokens']); + const output = tokenField(raw, ['output_tokens', 'completion_tokens']); + let total = tokenField(raw, ['total_tokens']); + if (total === null && (input !== null || output !== null)) { + total = (input || 0) + (output || 0); + } + if (input === null && output === null && total === null) return null; + return { + input, + output, + total, + source: raw.source || 'unknown', + estimateScope: raw.estimate_scope || '', + rawUsage: raw.raw_usage, + raw, + }; + } + + /* Pairs every event with its normalized token counts and keeps only the events that have counts. */ function workflowTokenEvents(events) { + return events + .map(event => ({event, counts: llmTokenCounts(event)})) + .filter(item => item.counts); + } + + /* Sums input, output and total token counts over all events that carry counts and collects the distinct sources. A field that no event reported is returned as null. Returns null when no event has counts. */ function summedTokenCounts(events) { + const items = workflowTokenEvents(events); + if (!items.length) return null; + let input = 0; + let output = 0; + let total = 0; + let sawInput = false; + let sawOutput = false; + let sawTotal = false; + const sources = new Set(); + items.forEach(({counts}) => { + if (counts.input !== null) { + input += counts.input; + sawInput = true; + } + if (counts.output !== null) { + output += counts.output; + sawOutput = true; + } + if (counts.total !== null) { + total += counts.total; + sawTotal = true; + } + if (counts.source) sources.add(counts.source); + }); + /* fallback: derive the total when no event reported one */ if (!sawTotal && (sawInput || sawOutput)) total = input + output; + return { + input: sawInput ? input : null, + output: sawOutput ? output : null, + total: sawTotal || sawInput || sawOutput ?
total : null, + source: Array.from(sources).join('+') || 'unknown', + }; + } + + function formatTokenCount(value) { + const number = numberOrNull(value); + if (number === null) return '-'; + if (Math.abs(number) >= 1000000) return `${(number / 1000000).toFixed(1)}M`; + if (Math.abs(number) >= 10000) return `${Math.round(number / 1000)}k`; + if (Math.abs(number) >= 1000) return `${(number / 1000).toFixed(1)}k`; + return String(Math.round(number)); + } + + function tokenSourceLabel(source) { + if (source === 'local_estimate') return 'estimate'; + if (source === 'provider') return 'provider'; + return source || 'unknown'; + } + + function tokenSummary(event, {compact = false} = {}) { + const counts = llmTokenCounts(event); + if (!counts) return ''; + const source = tokenSourceLabel(counts.source); + if (compact) { + const sourceSuffix = source === 'estimate' ? ' · est' : source === 'provider' ? '' : ` · ${source}`; + return `tok ${formatTokenCount(counts.input)} in / ${formatTokenCount(counts.output)} out${sourceSuffix}`; + } + return [ + `input ${formatTokenCount(counts.input)}`, + `output ${formatTokenCount(counts.output)}`, + `total ${formatTokenCount(counts.total)}`, + source, + ].filter(Boolean).join(' · '); + } + + function promptDisclosureHtml(event) { + const messages = event?.payload?.prompt_messages; + if (!Array.isArray(messages) || !messages.length) return ''; + return ` +
+ Show full prompt (${messages.length} messages) +
${esc(formatJson(messages))}
+
+ `; + } + + /* Returns the 'Prompt' detail section for an event that has token counts ('' otherwise). When no prompt was captured the section carries an explanatory note and the 'warn' tone instead of 'info'. */ function promptDetailSection(event) { + if (!llmTokenCounts(event)) return ''; + const disclosure = promptDisclosureHtml(event); + return detailSection( + 'Prompt', + disclosure || paragraphsHtml('Full prompt was not captured for this event. Rerun with the updated ChemGraph observability code to populate this field.'), + disclosure ? 'info' : 'warn', + ); + } + + /* Returns the token-count detail section (input, output, total, source, estimate scope) for an event, followed by a 'Provider Usage' section when raw usage is present; '' when the event has no counts. The section tone is 'ok' only for provider-reported counts. */ function tokenDetailSection(event, title = 'Token Counts') { + const counts = llmTokenCounts(event); + if (!counts) return ''; + const body = [ + detailKvGrid([ + ['Input tokens', counts.input === null ? '-' : Math.round(counts.input)], + ['Output tokens', counts.output === null ? '-' : Math.round(counts.output)], + ['Total tokens', counts.total === null ? '-' : Math.round(counts.total)], + ['Source', tokenSourceLabel(counts.source)], + ['Estimate scope', counts.estimateScope || ''], + ]), + ].filter(Boolean).join(''); + return [ + detailSection(title, body, counts.source === 'provider' ? 'ok' : 'info'), + counts.rawUsage ? detailSection('Provider Usage', detailValueHtml(counts.rawUsage)) : '', + ].filter(Boolean).join(''); + } + + /* Classifies a tool event as 'action-tool' when its display name is in actionToolNames, otherwise 'science-tool'. */ function workflowToolKind(event) { + const name = toolDisplayName(event); + return actionToolNames.has(name) ? 'action-tool' : 'science-tool'; + } + + /* Short label ('action' or 'science') for the tool kind of an event. */ function workflowToolKindLabel(event) { + return workflowToolKind(event) === 'action-tool' ? 'action' : 'science'; + } + + /* Builds the flow-graph node descriptor ({id, type, title, meta, event, eventKey, failed, ...}) for an event. Start events become 'input' nodes, llm_decision events 'lm' nodes, tool_call_ events 'tool' nodes, and anything else an 'output' node. options.lmTurn, when given, numbers the LM turn in the title. */ function workflowFlowNode(event, index, options = {}) { + const p = event.payload || {}; + const key = eventKey(event); + if (event.event === 'run_started' || event.event === 'workflow_started') { + return { + id: workflowFlowNodeId(event), + type: 'input', + title: p.nested ? 'Prompt' : 'Query', + meta: p.query || p.thread_id || (p.round !== undefined ? `round ${p.round}` : snapshot.status?.query || 'user request'), + event, + eventKey: key, + failed: false, + }; + } + if (event.event === 'llm_decision') { + const calls = Array.isArray(p.tool_calls) ?
p.tool_calls : []; + const callNames = toolCallSummary(calls); + const tokens = tokenSummary(event, {compact: true}); + return { + id: workflowFlowNodeId(event), + type: 'lm', + title: `LM turn ${options.lmTurn || index}`, + meta: [tokens, callNames ? `calls: ${callNames}` : 'response'].filter(Boolean).join(' · '), + event, + eventKey: key, + failed: false, + }; + } + if (event.event.startsWith('tool_call_')) { + const failed = event.event === 'tool_call_failed' || p.status === 'failed'; + const kind = workflowToolKind(event); + return { + id: workflowFlowNodeId(event), + type: 'tool', + toolClass: kind, + title: toolDisplayName(event), + meta: `${workflowToolKindLabel(event)} · ${failed ? 'failed' : event.event === 'tool_call_started' ? 'running' : 'finished'}`, + event, + eventKey: key, + failed, + }; + } + return { + id: workflowFlowNodeId(event), + type: 'output', + title: 'Output', + meta: tokenSummary(event, {compact: true}) || p.content_preview || p.status || snapshot.status?.status || 'finished', + event, + eventKey: key, + failed: p.status === 'failed', + }; + } + + function renderGraph() { + if (isWorkflowMode()) { + renderWorkflowGraph(); + return; + } + const svg = document.getElementById('graph'); + const currentAgents = agents(); + if (!currentAgents.length) { + svg.setAttribute('viewBox', '0 0 1000 260'); + svg.innerHTML = 'No agents yet.'; + document.getElementById('hostLegend').innerHTML = 'Waiting for placement.'; + return; + } + + const byHost = new Map(); + currentAgents.forEach(agent => { + const host = agentHost(agent); + if (!byHost.has(host)) byHost.set(host, []); + byHost.get(host).push(agent); + }); + const hosts = Array.from(byHost.keys()).sort(); + const maxPerHost = Math.max(...hosts.map(host => byHost.get(host).length)); + const radial = currentAgents.length >= 2; + const width = radial + ? Math.max(1200, currentAgents.length * 220, hosts.length * 220) + : 1000; + const height = radial + ? 
Math.max(760, currentAgents.length * 135) + : Math.max(340, 110 + maxPerHost * 88); + const marginX = 42; + const top = 58; + const bottom = 34; + const laneGap = radial ? 8 : 14; + const laneWidth = (width - marginX * 2 - laneGap * Math.max(0, hosts.length - 1)) / hosts.length; + const nodeW = radial ? 132 : Math.min(154, Math.max(82, laneWidth - 18)); + const nodeH = radial ? 58 : 58; + const positions = new Map(); + svg.style.minWidth = ''; + svg.style.minHeight = `${height}px`; + + const hostIndex = new Map(hosts.map((host, index) => [host, index])); + const legendPrefix = ` + ChemGraph turn chip + ${radial ? 'radial host layout' : ''} + `; + const legend = legendPrefix + hosts.map((host, index) => ` + + + ${esc(host)} (${byHost.get(host).length}) + + `).join(''); + document.getElementById('hostLegend').innerHTML = legend; + + let bands = ''; + if (radial) { + const centerX = width / 2; + const centerY = height / 2; + const radiusScale = currentAgents.length < 5 ? 0.31 : 0.36; + const radiusX = width * radiusScale; + const radiusY = height * (currentAgents.length < 5 ? 0.30 : 0.34); + const hostCenters = new Map(); + bands = ` + + + ${currentAgents.length} daemon agents across ${hosts.length} host(s) + `; + hosts.forEach((host, index) => { + const angle = -Math.PI / 2 + (2 * Math.PI * index) / hosts.length; + const x = centerX + Math.cos(angle) * radiusX; + const y = centerY + Math.sin(angle) * radiusY; + hostCenters.set(host, {x, y, angle}); + const fill = hostColor(index); + const stroke = hostStroke(index); + bands += ` + + ${esc(trunc(host, 28))} + `; + }); + hosts.forEach((host, hIndex) => { + const list = byHost.get(host).slice().sort((a, b) => a.agent_id.localeCompare(b.agent_id)); + const center = hostCenters.get(host); + const spread = list.length > 1 ? 
Math.max(nodeH + 22, 84) : 0; + list.forEach((agent, index) => { + const offset = (index - (list.length - 1) / 2) * spread; + const tangentX = -Math.sin(center.angle); + const tangentY = Math.cos(center.angle); + const x = center.x + tangentX * offset; + const y = center.y + tangentY * offset; + positions.set(agent.agent_id, {x, y, host, hostIndex: hIndex, agent}); + }); + }); + } else { + bands = hosts.map((host, index) => { + const x = marginX + index * (laneWidth + laneGap); + const fill = hostColor(index); + const stroke = hostStroke(index); + const label = trunc(host, 34); + return ` + + ${esc(label)} + `; + }).join(''); + + hosts.forEach((host, hIndex) => { + const list = byHost.get(host).sort((a, b) => a.agent_id.localeCompare(b.agent_id)); + const x = marginX + hIndex * (laneWidth + laneGap) + laneWidth / 2; + list.forEach((agent, index) => { + const usable = height - top - bottom; + const y = top + ((index + 1) * usable) / (list.length + 1); + positions.set(agent.agent_id, {x, y, host, hostIndex: hIndex, agent}); + }); + }); + } + + const sent = graphMessageEvents(); + const recentIds = new Set(sent.slice(-10).map(e => e.payload?.message_id).filter(Boolean)); + const edgeMap = new Map(); + sent.forEach((event, index) => { + const p = event.payload || {}; + if (!positions.has(p.sender) || !positions.has(p.recipient)) return; + const key = `${p.sender}->${p.recipient}`; + const prev = edgeMap.get(key) || { + key, + sender: p.sender, + recipient: p.recipient, + count: 0, + latestIndex: -1, + latestMessageId: null, + latestTldr: '', + latestContent: '', + messages: [], + }; + prev.count += 1; + prev.latestIndex = index; + prev.latestMessageId = p.message_id; + prev.latestTldr = p.tldr || ''; + prev.latestContent = p.content || ''; + prev.messages.push(event); + edgeMap.set(key, prev); + }); + const edges = Array.from(edgeMap.values()).sort((a, b) => a.latestIndex - b.latestIndex); + if (selectedEdgeKey && !edgeMap.has(selectedEdgeKey)) { + selectedEdgeKey = 
null; + } + const allVisibleMessages = eventsOf('message_sent').length; + const showEdgeLabels = Boolean(selectedEdgeKey) || (graphMode !== 'cumulative' && edges.length <= 8); + + const labelBoxes = []; + const edgeSvg = edges.map((edge, index) => { + const source = positions.get(edge.sender); + const target = positions.get(edge.recipient); + const cross = source.host !== target.host; + const selectedByAgent = selectedAgent && (edge.sender === selectedAgent || edge.recipient === selectedAgent); + const selectedByEdge = selectedEdgeKey === edge.key; + const dimmed = (selectedAgent && !selectedByAgent) || (selectedEdgeKey && !selectedByEdge); + const recent = recentIds.has(edge.latestMessageId); + const start = edgeEndpoint(source, target, nodeW, nodeH, true); + const end = edgeEndpoint(source, target, nodeW, nodeH, false); + const hasReverse = edgeMap.has(`${edge.recipient}->${edge.sender}`); + const route = curvedRoute(start, end, edge, index, radial, width / 2, height / 2, hasReverse); + const path = route.path; + const cls = ['edge', cross ? 'cross-node' : '', recent ? 'recent' : '', selectedByEdge ? 'selected' : '', dimmed ? 'dimmed' : ''].filter(Boolean).join(' '); + const marker = cross ? 'url(#arrowCross)' : 'url(#arrow)'; + const labelX = route.labelX; + const labelY = route.labelY; + const edgeSummary = edge.latestTldr || ''; + const shouldShowSummary = edgeSummary && !dimmed && (selectedByEdge || recent || showEdgeLabels); + const rawLabel = shouldShowSummary ? edgeSummary : (selectedByEdge || showEdgeLabels ? edge.count : ''); + const placedLabel = placeEdgeLabel(labelX, labelY, rawLabel, labelBoxes, selectedByEdge); + const titleText = edgeSummary || edge.latestContent; + return ` + + ${esc(edge.sender)} -> ${esc(edge.recipient)} (${edge.count}) ${cross ? 'cross-node' : 'same-node'} ${esc(trunc(titleText, 180))} + + + ${placedLabel ? `${esc(placedLabel.text)}` : ''} + `; + }).join(''); + + const edgeHint = !edges.length ? ` + + ${allVisibleMessages ? 
'No routes visible in this graph mode. Try Recent or All.' : 'No messages visible at this time point.'} + + ` : ''; + const bubbleSvg = renderActivityBubbles(currentAgents, positions, nodeW, nodeH, width, height); + + const nodeSvg = currentAgents.map(agent => { + const pos = positions.get(agent.agent_id); + const current = currentEvent(); + const hIndex = hostIndex.get(pos.host) || 0; + const x = pos.x - nodeW / 2; + const y = pos.y - nodeH / 2; + const classes = [ + 'agent-node', + agent.agent_id === selectedAgent ? 'selected' : '', + current?.agent_id === agent.agent_id ? 'current' : '', + agent.last_error ? 'error' : '', + !agent.started ? 'pending' : '', + selectedEdgeKey && !selectedEdgeKey.split('->').includes(agent.agent_id) ? 'dimmed' : '', + ].filter(Boolean).join(' '); + const status = !agent.started ? 'pending' : agent.last_error ? 'error' : `${agent.decision_count || 0} decisions`; + return ` + + + ${esc(trunc(agent.agent_id, 20))} + ${esc(trunc(agent.role || '', 24))} + ${esc(status)} + ${esc(agent.agent_id)} host: ${esc(pos.host)} role: ${esc(agent.role || '')} + + `; + }).join(''); + + ensureGraphView(width, height); + svg.innerHTML = ` + + + + + + + + + ${bands} + ${edgeSvg} + ${edgeHint} + ${nodeSvg} + ${bubbleSvg} + `; + updateGraphViewBox(); + svg.querySelectorAll('[data-edge-key]').forEach(edgeEl => { + edgeEl.addEventListener('click', () => { + selectedEdgeKey = edgeEl.dataset.edgeKey; + selectedAgent = null; + selectedActivityEventKey = null; + render(); + }); + }); + svg.querySelectorAll('.agent-node').forEach(node => { + node.addEventListener('click', () => { + selectedAgent = node.dataset.agent; + selectedEdgeKey = null; + selectedActivityEventKey = null; + render(); + }); + node.addEventListener('keydown', event => { + if (event.key === 'Enter' || event.key === ' ') { + selectedAgent = node.dataset.agent; + selectedEdgeKey = null; + selectedActivityEventKey = null; + render(); + } + }); + }); + 
svg.querySelectorAll('[data-activity-event-key]').forEach(activityEl => { + activityEl.addEventListener('click', event => { + selectedActivityEventKey = activityEl.dataset.activityEventKey; + selectedAgent = null; + selectedEdgeKey = null; + event.stopPropagation(); + render(); + }); + }); + } + + /* Returns the point where an edge attaches to a node box: on the source node when isStart is true, otherwise on the target node. The edge leaves through the left/right side when the horizontal offset dominates and through the top/bottom otherwise, shifted by up to 18% of the other dimension toward the far node. */ function edgeEndpoint(source, target, nodeW, nodeH, isStart) { + const from = isStart ? source : target; + const to = isStart ? target : source; + const dx = to.x - from.x; + const dy = to.y - from.y; + if (Math.abs(dx) > Math.abs(dy)) { + return { + x: from.x + Math.sign(dx || 1) * nodeW / 2, + y: from.y + (dy / Math.max(Math.abs(dx), 1)) * nodeH * 0.18, + }; + } + return { + x: from.x + (dx / Math.max(Math.abs(dy), 1)) * nodeW * 0.18, + y: from.y + Math.sign(dy || 1) * nodeH / 2, + }; + } + + /* Deterministic string hash: hash * 31 + char code per character, truncated to 32 bits with | 0, returned as a non-negative number. */ function stableHash(text) { + let hash = 0; + for (const ch of String(text || '')) { + hash = ((hash << 5) - hash + ch.charCodeAt(0)) | 0; + } + return Math.abs(hash); + } + + /* Computes a quadratic curve between two edge endpoints and a label anchor for it. The bend side follows the ordering of edge.sender and edge.recipient, and a jitter derived from stableHash(edge.key) is added to the bend. */ function curvedRoute(start, end, edge, index, radial, centerX, centerY, hasReverse) { + const dx = end.x - start.x; + const dy = end.y - start.y; + const distance = Math.max(Math.hypot(dx, dy), 1); + const midX = (start.x + end.x) / 2; + const midY = (start.y + end.y) / 2; + const perpX = -dy / distance; + const perpY = dx / distance; + const direction = edge.sender < edge.recipient ? 1 : -1; + const jitter = ((stableHash(edge.key) % 5) - 2) * (radial ? 11 : 8); + let curve = Math.min(radial ? 230 : 130, Math.max(radial ? 78 : 42, distance * (radial ?
0.24 : 0.17))); + if (!hasReverse) curve *= 0.62; + curve = curve * direction + jitter; + + let controlX = midX + perpX * curve; + let controlY = midY + perpY * curve; + if (radial) { + const outX = midX - centerX; + const outY = midY - centerY; + const outDistance = Math.max(Math.hypot(outX, outY), 1); + const outward = Math.min(130, Math.max(36, distance * 0.16)); + controlX += (outX / outDistance) * outward; + controlY += (outY / outDistance) * outward; + } else { + controlX += ((index % 3) - 1) * 18; + } + + return { + path: `M ${start.x.toFixed(1)} ${start.y.toFixed(1)} Q ${controlX.toFixed(1)} ${controlY.toFixed(1)} ${end.x.toFixed(1)} ${end.y.toFixed(1)}`, + labelX: (midX * 0.65 + controlX * 0.35), + labelY: (midY * 0.65 + controlY * 0.35) - 4, + }; + } + + function renderActivityBubbles(currentAgents, positions, nodeW, nodeH, layoutWidth, layoutHeight) { + const grouped = new Map(); + function addGrouped(agentId, item) { + if (!agentId || !positions.has(agentId)) return; + if (!grouped.has(agentId)) grouped.set(agentId, []); + grouped.get(agentId).push(item); + } + activeWindowEvents(1).forEach(event => { + if (isWorkflowEvent(event)) return; + if (event.event.startsWith('tool_call_') || event.event === 'chemgraph_job_result') return; + const bubble = bubbleInfo(event); + if (!bubble) return; + addGrouped(event.agent_id, { + event, + bubble, + selected: eventKey(event) === selectedActivityEventKey, + sortTime: eventTimestamp(event) ?? 0, + }); + }); + currentAgents.forEach(agent => { + reasoningTurnsForAgent(agent.agent_id).slice(-3).forEach(turn => { + const representative = turn.representative; + if (!representative) return; + const selected = turn.events.some(event => eventKey(event) === selectedActivityEventKey); + addGrouped(agent.agent_id, { + event: representative, + bubble: chemgraphTurnBubbleInfo(turn), + selected, + sortTime: turn.lastTimestamp ?? turn.firstTimestamp ?? 
0, + }); + }); + }); + const rows = []; + grouped.forEach((items, agentId) => { + const pos = positions.get(agentId); + items + .slice() + .sort((a, b) => (a.sortTime || 0) - (b.sortTime || 0)) + .slice(-4) + .forEach((item, index) => { + const key = eventKey(item.event); + const selected = item.selected || key === selectedActivityEventKey; + const width = Math.max(64, Math.min(156, item.bubble.label.length * 7 + 18)); + let x = pos.x + nodeW / 2 + 8; + /* flip the bubble to the node's left side when it would overflow the right edge of the layout */ if (x + width > layoutWidth - 10) x = pos.x - nodeW / 2 - width - 8; + const y = Math.max(12, Math.min(layoutHeight - 24, pos.y - nodeH / 2 + index * 24)); + rows.push(` + + + ${esc(item.bubble.label)} + ${esc(item.bubble.title)} + + `); + }); + }); + return rows.join(''); + } + + /* Builds the bubble descriptor {className, label, title} for a ChemGraph turn. The label is "CG r<round> <n>t" where n is the science plus action tool count; the title is a multi-line summary; the error class is used when the turn status is 'failed' (case-insensitive). */ function chemgraphTurnBubbleInfo(turn) { + const round = turn.round === null || turn.round === undefined ? '-' : turn.round; + const toolCount = (turn.scienceToolCount || 0) + (turn.actionToolCount || 0); + const status = turn.status || 'running'; + const failed = String(status).toLowerCase() === 'failed'; + const label = `CG r${round} ${toolCount}t`; + const title = [ + `Open ChemGraph turn for ${workflowAgentId(turn.representative) || '-'}`, + `Round: ${round}`, + `Status: ${status}`, + `LM calls: ${turn.lmCount || 0}`, + `Science tools: ${turn.scienceToolCount || 0}`, + `Message actions: ${turn.actionToolCount || 0}`, + `Span: ${turn.spanId || '-'}`, + ].join('\n'); + return { + className: failed ? 'bubble-error' : 'bubble-chemgraph', + label, + title, + }; + } + + /* Resolves a tool's display name from the first non-empty payload field among tool_name, tool, name, result.tool_name, result.name and tool_result.tool_name; defaults to 'tool'. */ function toolDisplayName(event) { + const p = event.payload || {}; + return ( + p.tool_name + || p.tool + || p.name + || p.result?.tool_name + || p.result?.name + || p.tool_result?.tool_name + || 'tool' + ); + } + + /* Returns the bubble descriptor {className, label, title} for agent_decision, belief_updated and agent_error events, or null for any other event type. */ function bubbleInfo(event) { + const p = event.payload || {}; + if (event.event === 'agent_decision') { + const actions = Array.isArray(p.actions) ? p.actions.length : 0; + return { + className: 'bubble-decision', + label: actions ?
`decide ${actions}` : 'decide', + title: `${event.agent_id} decision\n${trunc(p.rationale || p.wake_reason || '', 220)}`, + }; + } + if (event.event === 'belief_updated') { + return { + className: 'bubble-belief', + label: 'belief', + title: `${event.agent_id} belief\n${formatBelief(p)}`, + }; + } + if (event.event === 'agent_error') { + return { + className: 'bubble-error', + label: 'error', + title: `${event.agent_id} error\n${p.error || formatJson(p)}`, + }; + } + return null; + } + + /* Creates the shared graphView state on first use, sized to the width x height layout plus padding on every side. When the layout size changes on a later call, the view is re-derived with preserveViewForLayoutChange and then clamped. Both the horizontal and the vertical padding switch to their larger values when width >= 1180. */ function ensureGraphView(width, height) { + const padX = width >= 1180 ? Math.max(420, width * 0.28) : 140; + const padY = width >= 1180 ? Math.max(220, height * 0.22) : 120; + const bounds = { + x: -padX, + y: -padY, + width: width + padX * 2, + height: height + padY * 2, + }; + if (!graphView) { + graphView = { + x: bounds.x, + y: bounds.y, + width: bounds.width, + height: bounds.height, + layoutWidth: width, + layoutHeight: height, + boundsX: bounds.x, + boundsY: bounds.y, + boundsWidth: bounds.width, + boundsHeight: bounds.height, + }; + return; + } + if ( + graphView.layoutWidth !== width + || graphView.layoutHeight !== height + ) { + const nextView = preserveViewForLayoutChange( + graphView, + bounds, + width, + height, + ); + graphView = { + ...graphView, + ...nextView, + layoutWidth: width, + layoutHeight: height, + boundsX: bounds.x, + boundsY: bounds.y, + boundsWidth: bounds.width, + boundsHeight: bounds.height, + }; + clampGraphView(); + } + } + + /* Maps an existing view rectangle onto new bounds, keeping the same zoom ratio (bounds size / view size) and the same relative center position; returns {x, y, width, height}. Missing fields on view fall back to the previous bounds or the new layout size. */ function preserveViewForLayoutChange(view, nextBounds, nextLayoutWidth, nextLayoutHeight) { + const previousBoundsWidth = view.boundsWidth || view.layoutWidth || nextLayoutWidth; + const previousBoundsHeight = view.boundsHeight || view.layoutHeight || nextLayoutHeight; + const zoomX = previousBoundsWidth / Math.max(view.width || previousBoundsWidth, 1); + const zoomY = previousBoundsHeight / Math.max(view.height || previousBoundsHeight, 1); + const centerXRatio = ( + (view.x || 0) + (view.width || previousBoundsWidth) / 2 -
(view.boundsX || 0) + ) / Math.max(previousBoundsWidth, 1); + const centerYRatio = ( + (view.y || 0) + (view.height || previousBoundsHeight) / 2 - (view.boundsY || 0) + ) / Math.max(previousBoundsHeight, 1); + const width = nextBounds.width / Math.max(zoomX, 1e-6); + const height = nextBounds.height / Math.max(zoomY, 1e-6); + const centerX = nextBounds.x + centerXRatio * nextBounds.width; + const centerY = nextBounds.y + centerYRatio * nextBounds.height; + return { + x: centerX - width / 2, + y: centerY - height / 2, + width, + height, + }; + } + + /* Writes the current graphView rectangle to the viewBox attribute of the #graph element, with one decimal place per value; does nothing when no view exists. */ function updateGraphViewBox() { + const svg = document.getElementById('graph'); + if (!graphView) return; + svg.setAttribute( + 'viewBox', + `${graphView.x.toFixed(1)} ${graphView.y.toFixed(1)} ${graphView.width.toFixed(1)} ${graphView.height.toFixed(1)}` + ); + } + + /* Constrains graphView in place: width and height are kept between one tenth of the layout size and the bounds size, and x/y are kept so the view rectangle stays inside the bounds. Does nothing when no view exists. */ function clampGraphView() { + if (!graphView) return; + const boundsX = graphView.boundsX ?? 0; + const boundsY = graphView.boundsY ?? 0; + const boundsWidth = graphView.boundsWidth ?? graphView.layoutWidth; + const boundsHeight = graphView.boundsHeight ??
graphView.layoutHeight; + graphView.width = Math.min(boundsWidth, Math.max(graphView.layoutWidth / 10, graphView.width)); + graphView.height = Math.min(boundsHeight, Math.max(graphView.layoutHeight / 10, graphView.height)); + graphView.x = Math.min(Math.max(boundsX, graphView.x), boundsX + boundsWidth - graphView.width); + graphView.y = Math.min(Math.max(boundsY, graphView.y), boundsY + boundsHeight - graphView.height); + } + + /* Scales the view rectangle's width and height by factor around its current center, then clamps the view and applies it to the viewBox. Does nothing when no view exists. */ function zoomGraph(factor) { + if (!graphView) return; + const centerX = graphView.x + graphView.width / 2; + const centerY = graphView.y + graphView.height / 2; + graphView.width *= factor; + graphView.height *= factor; + graphView.x = centerX - graphView.width / 2; + graphView.y = centerY - graphView.height / 2; + clampGraphView(); + updateGraphViewBox(); + } + + /* Discards the stored view and re-renders the graph. */ function resetGraphView() { + graphView = null; + renderGraph(); + } + + /* Refreshes the #agentSelect control: emptied and disabled in workflow mode; otherwise enabled, rebuilt from agents() and set to the selected agent (or '' when none is selected). */ function renderAgentPicker() { + const picker = document.getElementById('agentSelect'); + if (isWorkflowMode()) { + picker.innerHTML = ''; + picker.value = ''; + picker.disabled = true; + return; + } + picker.disabled = false; + const options = [''].concat( + agents().map(agent => ``) + ).join(''); + picker.innerHTML = options; + picker.value = selectedAgent || ''; + } + + /* Returns the agent whose agent_id equals selectedAgent, or null. */ function selectedState() { + return agents().find(agent => agent.agent_id === selectedAgent) || null; + } + + /* Returns the event whose key equals selectedActivityEventKey, or null when nothing is selected or no event matches. */ function selectedActivityEvent() { + if (!selectedActivityEventKey) return null; + return allEvents().find(event => eventKey(event) === selectedActivityEventKey) || null; + } + + /* Returns a string identifying what the detail pane currently shows, by priority: selected activity event, selected edge, selected agent, then the current event (with distinct prefixes for workflow mode and timeline mode). */ function currentDetailIdentity() { + if (selectedActivityEventKey) return `activity:${selectedActivityEventKey}`; + if (selectedEdgeKey) return `edge:${selectedEdgeKey}`; + if (selectedAgent) return `agent:${selectedAgent}`; + if (isWorkflowMode()) { + const event = currentEvent(); + return event ? `workflow-event:${eventKey(event)}` : 'workflow-empty'; + } + const event = currentEvent(); + return event ?
`timeline-event:${eventKey(event)}` : 'empty'; + } + + /* Records scrollTop and scrollLeft of the three detail blocks (those that exist in the document) together with lastRenderedDetailIdentity. */ function captureDetailScrollSnapshot() { + const blockIds = ['detailPrimary', 'detailSecondary', 'detailTertiary']; + const blocks = {}; + blockIds.forEach(id => { + const el = document.getElementById(id); + if (!el) return; + blocks[id] = { + scrollTop: el.scrollTop, + scrollLeft: el.scrollLeft, + }; + }); + return { + identity: lastRenderedDetailIdentity, + blocks, + }; + } + + /* Re-applies saved scroll offsets, but only when the snapshot's identity matches currentDetailIdentity(). Note: the parameter name shadows the outer snapshot variable inside this function. */ function restoreDetailScrollSnapshot(snapshot) { + if (!snapshot || snapshot.identity !== currentDetailIdentity()) return; + Object.entries(snapshot.blocks || {}).forEach(([id, pos]) => { + const el = document.getElementById(id); + if (!el) return; + el.scrollTop = pos.scrollTop || 0; + el.scrollLeft = pos.scrollLeft || 0; + }); + } + + /* Renders the detail pane by priority: the selected activity event, the current event in workflow mode, the selected edge, then the selected agent. Falls back to renderEmptyDetail() when there is no current event in workflow mode or no selected agent. */ function renderDetail() { + const activityEvent = selectedActivityEvent(); + if (activityEvent) { + renderTimelineEventDetail( + activityEvent, + allEvents().findIndex(event => eventKey(event) === selectedActivityEventKey), + ); + return; + } + if (isWorkflowMode()) { + const event = currentEvent(); + if (event) { + renderTimelineEventDetail(event); + } else { + renderEmptyDetail(); + } + return; + } + if (selectedEdgeKey) { + renderEdgeDetail(selectedEdgeKey); + return; + } + const agent = selectedState(); + if (!agent) { + renderEmptyDetail(); + return; + } + selectedAgent = agent.agent_id; + document.getElementById('agentSelect').value = selectedAgent; + document.getElementById('detailTitle').textContent = agent.agent_id; + document.getElementById('detailCards').innerHTML = detailCards([ + ['Role', agent.role || '-'], + ['Host', agentHost(agent)], + ['Decisions', agent.decision_count || 0], + ['Received / Sent', `${agent.received_message_count || 0} / ${agent.outbox_count || 0}`], + ['Tools', `${agent.tool_finished_count || 0} / ${agent.tool_started_count || 0}`], + ['State', agent.last_error ? 'error' : agent.started ?
'active' : 'pending'], + ]); + const current = currentEvent(); + if (current?.event === 'agent_decision' && current.agent_id === agent.agent_id) { + setDetailBlock('detailPrimaryTitle', 'Current Decision', 'detailPrimary', formatDecisionEvent(current)); + setDetailBlock('detailSecondaryTitle', 'Wake Context', 'detailSecondary', formatWakeEvents(current)); + const turnEvents = workflowEventsForSelection(current); + if (turnEvents.length) { + setDetailHtmlBlock( + 'detailTertiaryTitle', + 'ChemGraph Turn', + 'detailTertiary', + detailRich(detailSection( + 'ChemGraph Panel', + paragraphsHtml('Open in the floating ChemGraph panel. Click inner graph nodes there to inspect LM, tool, action, and output details inside the panel.'), + 'info', + )), + ); + } else { + const received = eventsOf('message_received').filter(e => e.agent_id === agent.agent_id).slice(-4); + setDetailBlock('detailTertiaryTitle', 'Recent Received Messages', 'detailTertiary', received.length + ? received.map(formatMessageEvent).join('\n\n') + : 'No received messages at this point in the timeline.'); + } + return; + } + const beliefEvents = eventsOf('belief_updated').filter(e => e.agent_id === agent.agent_id); + const latestBelief = beliefEvents.length ? beliefEvents[beliefEvents.length - 1].payload : null; + setDetailBlock('detailPrimaryTitle', 'Current Belief', 'detailPrimary', latestBelief + ? formatBelief(latestBelief) + : 'No belief recorded at this point in the timeline.'); + const received = eventsOf('message_received').filter(e => e.agent_id === agent.agent_id).slice(-6); + setDetailHtmlBlock('detailSecondaryTitle', 'Received Messages', 'detailSecondary', received.length + ? messageHistoryHtml(received) + : '
No received messages at this point in the timeline.
'); + const turns = reasoningTurnsForAgent(agent.agent_id); + setDetailHtmlBlock( + 'detailTertiaryTitle', + 'ChemGraph Turn Entries', + 'detailTertiary', + reasoningTurnListHtml(turns) + (turns.length + ? detailRich(detailSection( + 'Open Turn', + paragraphsHtml('Click a ChemGraph turn chip attached to this agent in the graph, or click a turn row above. The inner ChemGraph graph opens in the floating panel.'), + 'info', + )) + : '
No ChemGraph turns for this agent at this point in the timeline.
'), + ); + } + + function renderEdgeDetail(edgeKey) { + const [sender, recipient] = edgeKey.split('->'); + const currentAgents = agents(); + const senderAgent = currentAgents.find(agent => agent.agent_id === sender); + const recipientAgent = currentAgents.find(agent => agent.agent_id === recipient); + const senderHost = agentHost(senderAgent); + const recipientHost = agentHost(recipientAgent); + const messages = eventsOf('message_sent').filter(e => { + const p = e.payload || {}; + return p.sender === sender && p.recipient === recipient; + }); + const latest = messages.length ? messages[messages.length - 1] : null; + const latestPayload = latest?.payload || {}; + const route = senderHost && recipientHost && senderHost !== recipientHost ? 'cross-node' : 'same-node'; + document.getElementById('detailTitle').textContent = `${sender} -> ${recipient}`; + document.getElementById('detailCards').innerHTML = detailCards([ + ['Route', route], + ['Messages', messages.length], + ['From host', senderHost], + ['To host', recipientHost], + ]); + setDetailHtmlBlock('detailPrimaryTitle', 'Latest Message', 'detailPrimary', latest + ? messageDetailHtml(latest) + : '
No message visible at this point in the timeline.
'); + const history = messages.slice(-8); + setDetailHtmlBlock('detailSecondaryTitle', 'Message History', 'detailSecondary', history.length + ? messageHistoryHtml(history) + : '
No messages visible at this point in the timeline.
'); + const messageIds = new Set(messages.map(e => e.payload?.message_id).filter(Boolean)); + const relatedBeliefs = eventsOf('belief_updated').filter(e => { + const refs = e.payload?.supporting_message_ids || []; + return refs.some(ref => messageIds.has(ref)); + }); + setDetailBlock('detailTertiaryTitle', 'Beliefs Citing This Edge', 'detailTertiary', relatedBeliefs.length + ? relatedBeliefs.slice(-6).map(e => `${formatTime(e.timestamp)} ${e.agent_id}\n${formatBelief(e.payload)}`).join('\n\n') + : 'No belief cites this relationship at this point in the timeline.'); + } + + function renderEmptyDetail() { + const event = currentEvent(); + if (event) { + renderTimelineEventDetail(event); + return; + } + document.getElementById('detailTitle').textContent = 'Timeline Event'; + document.getElementById('detailCards').innerHTML = ''; + setDetailBlock('detailPrimaryTitle', 'State', 'detailPrimary', 'No events yet.'); + setDetailBlock('detailSecondaryTitle', 'Evidence', 'detailSecondary', ''); + setDetailBlock('detailTertiaryTitle', 'History', 'detailTertiary', ''); + } + + function renderTimelineEventDetail(event, indexOverride = null) { + const index = indexOverride ?? currentEventIndex(); + const isToolEvent = ['tool_call_started', 'tool_call_finished', 'tool_call_failed', 'chemgraph_job_result'].includes(event.event); + const isNestedWorkflowEvent = isWorkflowEvent(event); + document.getElementById('detailTitle').textContent = isNestedWorkflowEvent + ? chemgraphNodeDetailTitle(event) + : isToolEvent + ? 
`Tool: ${toolDisplayName(event)}` + : `Timeline Event ${index + 1}`; + const cards = [ + ['Event', event.event], + ['Time', formatTime(event.timestamp)], + ['Agent', event.agent_id || '-'], + ['Role', event.role || '-'], + ]; + if (isToolEvent) cards.push(['Tool', toolDisplayName(event)]); + if (isNestedWorkflowEvent) cards.push(['Runtime', event.payload?.runtime || '-']); + document.getElementById('detailCards').innerHTML = detailCards(cards); + if (isNestedWorkflowEvent) { + setDetailHtmlBlock('detailPrimaryTitle', chemgraphNodeDetailTitle(event), 'detailPrimary', chemgraphNodeDetailHtml(event)); + setDetailHtmlBlock('detailSecondaryTitle', 'Node Payload', 'detailSecondary', payloadDetailHtml(event.payload || {})); + setDetailHtmlBlock('detailTertiaryTitle', 'ChemGraph Context', 'detailTertiary', chemgraphNodeContextHtml(event)); + return; + } + if (event.event === 'agent_decision') { + setDetailBlock('detailPrimaryTitle', 'Agent Decision', 'detailPrimary', formatDecisionEvent(event)); + setDetailBlock('detailSecondaryTitle', 'Wake Context', 'detailSecondary', formatWakeEvents(event)); + const turnEvents = workflowEventsForSelection(event); + if (turnEvents.length) { + setDetailHtmlBlock('detailTertiaryTitle', 'ChemGraph Turn', 'detailTertiary', detailRich(detailSection( + 'ChemGraph Panel', + paragraphsHtml('The ChemGraph turn for this decision is shown in the floating panel. 
Click an inner node to inspect it inside the panel.'), + 'info', + ))); + } else { + setDetailBlock('detailTertiaryTitle', 'Raw Action Count', 'detailTertiary', `${event.payload?.actions?.length || 0} action(s) returned by LM.`); + } + return; + } + if (event.event === 'message_sent' || event.event === 'message_received') { + setDetailHtmlBlock('detailPrimaryTitle', 'Message', 'detailPrimary', messageDetailHtml(event)); + setDetailHtmlBlock('detailSecondaryTitle', 'Route', 'detailSecondary', routeDetailHtml(event)); + setDetailHtmlBlock('detailTertiaryTitle', 'Payload', 'detailTertiary', payloadDetailHtml(event.payload || {})); + return; + } + if (event.event === 'belief_updated') { + setDetailBlock('detailPrimaryTitle', 'Belief Update', 'detailPrimary', formatBelief(event.payload || {})); + setDetailBlock('detailSecondaryTitle', 'Supporting Messages', 'detailSecondary', (event.payload?.supporting_message_ids || []).join('\n') || 'No message refs.'); + setDetailBlock('detailTertiaryTitle', 'Supporting Artifacts', 'detailTertiary', (event.payload?.supporting_artifact_ids || []).join('\n') || 'No artifact refs.'); + return; + } + if (['tool_call_started', 'tool_call_finished', 'tool_call_failed', 'chemgraph_job_result'].includes(event.event)) { + setDetailHtmlBlock('detailPrimaryTitle', 'Tool Event', 'detailPrimary', toolDetailHtml(event)); + setDetailHtmlBlock('detailSecondaryTitle', 'Tool Payload', 'detailSecondary', payloadDetailHtml(event.payload || {})); + const nested = nestedWorkflowEventsForTool(event); + setDetailHtmlBlock('detailTertiaryTitle', nested.length ? 'ChemGraph Tool Trace' : 'Correlation', 'detailTertiary', nested.length + ? 
workflowHistoryHtml(nested) + : payloadDetailHtml({correlation_id: event.correlation_id || 'No correlation id.'})); + return; + } + setDetailHtmlBlock('detailPrimaryTitle', 'Event Payload', 'detailPrimary', payloadDetailHtml(event.payload || {})); + setDetailBlock('detailSecondaryTitle', 'Correlation', 'detailSecondary', event.correlation_id || 'No correlation id.'); + setDetailBlock('detailTertiaryTitle', 'Selection', 'detailTertiary', 'Click a node or edge to inspect derived agent or communication state.'); + } + + function detailCards(items) { + return items.map(([label, value]) => ` +
+
${esc(label)}
+
${esc(value)}
+
+ `).join(''); + } + + function setDetailBlock(titleId, title, bodyId, body) { + document.getElementById(titleId).textContent = title; + const el = document.getElementById(bodyId); + const text = body || ''; + if (el.textContent !== text) { + el.textContent = text; + renderedHtmlCache.delete(el); + } + } + + function setDetailHtmlBlock(titleId, title, bodyId, bodyHtml) { + document.getElementById(titleId).textContent = title; + setStableHtml(document.getElementById(bodyId), bodyHtml || '', true); + } + + function setStableHtml(el, html, preserveScroll = true) { + if (!el) return; + const next = html || ''; + if (renderedHtmlCache.get(el) === next) return; + const scrollTop = el.scrollTop; + const scrollLeft = el.scrollLeft; + el.innerHTML = next; + renderedHtmlCache.set(el, next); + if (preserveScroll) { + el.scrollTop = scrollTop; + el.scrollLeft = scrollLeft; + } + } + + function selectActivityEventKey(key) { + if (!key) return; + selectedActivityEventKey = key; + selectedAgent = null; + selectedEdgeKey = null; + const index = allEvents().findIndex(event => eventKey(event) === key); + if (index >= 0) { + followLatest = false; + timelineIndex = index; + const event = allEvents()[index]; + if (workflowEventsForSelection(event).length) { + workflowPanelOpen = true; + } + } + render(); + } + + function handleDetailPaneClick(event) { + const target = event.target.closest('[data-detail-activity-key]'); + if (!target) return; + event.preventDefault(); + event.stopPropagation(); + selectActivityEventKey(target.dataset.detailActivityKey); + } + + function detailRich(...parts) { + return `
${parts.filter(Boolean).join('')}
`; + } + + function detailSection(title, body, tone = '') { + return ` +
+
${esc(title)}
+ ${body || '
None
'} +
+ `; + } + + function detailKvGrid(rows) { + const visibleRows = rows.filter(([_, value]) => !isEmptyDetailValue(value)); + if (!visibleRows.length) return '
None
'; + return ` +
+ ${visibleRows.map(([label, value, kind]) => ` +
${esc(label)}
+
${detailValueHtml(value, kind)}
+ `).join('')} +
+ `; + } + + function isEmptyDetailValue(value) { + return value === undefined + || value === null + || value === '' + || (Array.isArray(value) && value.length === 0); + } + + function detailValueHtml(value, kind = '') { + if (value === undefined || value === null || value === '') return '-'; + if (Array.isArray(value)) { + if (!value.length) return 'none'; + if (value.every(item => ['string', 'number', 'boolean'].includes(typeof item))) { + return detailChips(value, kind); + } + return collapsedJsonHtml(`Array (${value.length})`, value); + } + if (typeof value === 'object') { + return collapsedJsonHtml(`Object (${Object.keys(value).length})`, value); + } + const text = String(value); + if (kind === 'text') return paragraphsHtml(text); + return esc(text); + } + + function detailChips(values, kind = '') { + const list = Array.isArray(values) ? values : [values]; + if (!list.length) return 'none'; + return ` +
+ ${list.map(value => `${esc(String(value))}`).join('')} +
+ `; + } + + function paragraphsHtml(text) { + const value = String(text || '').trim(); + if (!value) return '
None
'; + const paragraphs = value.split(/\n{2,}/).map(part => part.trim()).filter(Boolean); + return `
${paragraphs.map(part => `

${esc(part)}

`).join('')}
`; + } + + function collapsedJsonHtml(summary, value) { + return ` +
+ ${esc(summary)} +
${esc(formatJson(value))}
+
+ `; + } + + function rawJsonDetails(value) { + return ` +
+ Raw JSON +
${esc(formatJson(value))}
+
+ `; + } + + function statusTone(status) { + const value = String(status || '').toLowerCase(); + if (['ok', 'success', 'completed', 'finished'].includes(value)) return 'ok'; + if (['failed', 'failure', 'error'].includes(value)) return 'error'; + if (['running', 'pending', 'submitted'].includes(value)) return 'warn'; + return 'info'; + } + + function messageDetailHtml(event, {includeRaw = true} = {}) { + const p = event.payload || {}; + const sender = p.sender || event.agent_id || '-'; + const recipient = p.recipient || '-'; + const refs = [ + ...(Array.isArray(p.evidence_refs) ? p.evidence_refs : []), + ...(Array.isArray(p.tool_result_ids) ? p.tool_result_ids : []), + ...(Array.isArray(p.supporting_message_ids) ? p.supporting_message_ids : []), + ]; + return detailRich( + detailSection('Route', detailKvGrid([ + ['Direction', `${sender} -> ${recipient}`], + ['Time', formatTime(event.timestamp)], + ['Message id', p.message_id || '-', 'mono'], + ['Event', event.event], + ]), 'info'), + p.tldr ? detailSection('TLDR', paragraphsHtml(p.tldr), 'info') : '', + p.content ? detailSection('Content', paragraphsHtml(p.content)) : '', + p.reason ? detailSection('Reason', paragraphsHtml(p.reason), 'warn') : '', + refs.length ? detailSection('References', detailChips(refs, 'action')) : '', + includeRaw ? rawJsonDetails(p) : '', + ); + } + + function messageHistoryHtml(messages) { + if (!messages.length) return detailRich('
No messages visible.
'); + return detailRich(messages.map(event => { + const p = event.payload || {}; + const title = `${p.sender || event.agent_id || '-'} -> ${p.recipient || '-'}`; + const body = [ + detailKvGrid([ + ['Time', formatTime(event.timestamp)], + ['Message id', p.message_id || '-', 'mono'], + ]), + p.tldr ? paragraphsHtml(p.tldr) : paragraphsHtml(trunc(p.content || '', 360)), + ].join(''); + return detailSection(title, body, ''); + }).join('')); + } + + function routeDetailHtml(event) { + const p = event.payload || {}; + const currentAgents = agents(); + const sender = currentAgents.find(agent => agent.agent_id === p.sender); + const recipient = currentAgents.find(agent => agent.agent_id === p.recipient); + const senderHost = agentHost(sender); + const recipientHost = agentHost(recipient); + const route = senderHost && recipientHost && senderHost !== recipientHost ? 'cross-node' : 'same-node'; + return detailRich( + detailSection('Route', detailKvGrid([ + ['Type', route], + ['Sender host', senderHost], + ['Recipient host', recipientHost], + ['Message id', p.message_id || '-', 'mono'], + ]), route === 'cross-node' ? 'ok' : 'info'), + ); + } + + function issueListHtml(issues) { + if (!Array.isArray(issues) || !issues.length) { + return '
No issues recorded.
'; + } + return issues.map((issue, index) => detailSection( + `${index + 1}. ${issue.field || 'field'}`, + detailKvGrid([ + ['Expected', issue.expected || '-'], + ['Received type', issue.received_type || '-'], + ['Received', issue.received ?? '', 'mono'], + ['Defaulted to', issue.defaulted_to ?? '', 'mono'], + ['Normalized to', issue.normalized_to ?? '', 'mono'], + ['Dropped items', issue.dropped_items ?? ''], + ['Allowed peers', issue.allowed_peers || [], 'action'], + ]), + )).join(''); + } + + function workflowDetailHtml(event) { + const p = event.payload || {}; + const calls = Array.isArray(p.tool_calls) + ? p.tool_calls.map(call => call.name || call.id || 'tool').filter(Boolean) + : []; + return detailRich( + detailSection('Workflow', detailKvGrid([ + ['Status', p.status || event.event], + ['Runtime', p.runtime || '-'], + ['Workflow', p.workflow_type || '-'], + ['Node', p.workflow_node || '-'], + ['Round', p.round ?? '-'], + ['Thread', p.thread_id || '-', 'mono'], + ['Time', formatTime(event.timestamp)], + ]), statusTone(p.status)), + detailSection('Span', detailKvGrid([ + ['Span id', p.span_id || '-', 'mono'], + ['Parent span', p.parent_span_id || '-', 'mono'], + ['Correlation', event.correlation_id || '-', 'mono'], + ['Log dir', p.log_dir || '-', 'mono'], + ])), + calls.length ? detailSection('Tool Calls', detailChips(calls, 'science')) : '', + p.content_preview ? detailSection('Preview', paragraphsHtml(p.content_preview)) : '', + p.error ? 
detailSection('Error', paragraphsHtml(p.error), 'error') : '', + ); + } + + function chemgraphNodeDetailTitle(event) { + if (event.event === 'llm_decision') return 'LM Decision Node'; + if (event.event.startsWith('tool_call_')) return `Tool Node: ${toolDisplayName(event)}`; + if (event.event === 'workflow_started' || event.event === 'run_started') return 'Wake Input Node'; + if (event.event === 'workflow_finished' || event.event === 'workflow_output' || event.event === 'run_finished') return 'Output Node'; + return `ChemGraph Node: ${event.event}`; + } + + function chemgraphNodeDetailHtml(event) { + const p = event.payload || {}; + if (event.event === 'llm_decision') { + const calls = Array.isArray(p.tool_calls) ? p.tool_calls : []; + return detailRich( + detailSection('LM Decision', detailKvGrid([ + ['Agent', event.agent_id || p.agent_id || '-'], + ['Role', event.role || p.role || '-'], + ['Time', formatTime(event.timestamp)], + ['Message index', p.message_index ?? '-'], + ['Tool calls', calls.length], + ]), 'info'), + tokenDetailSection(event), + promptDetailSection(event), + calls.length ? detailSection('Requested Tool Calls', detailChips(calls.map(call => call.name || call.id || 'tool'), 'science')) : '', + calls.length ? detailSection('Call Arguments', calls.map((call, index) => detailSection( + `${index + 1}. ${call.name || call.id || 'tool'}`, + detailValueHtml(call.args || call.arguments || {}), + )).join('')) : '', + p.content_preview ? detailSection('Response Preview', paragraphsHtml(p.content_preview)) : '', + ); + } + if (event.event.startsWith('tool_call_')) { + return toolDetailHtml(event); + } + if (event.event === 'workflow_started' || event.event === 'run_started') { + return detailRich( + detailSection('Wake Input', detailKvGrid([ + ['Agent', event.agent_id || p.agent_id || '-'], + ['Role', event.role || p.role || '-'], + ['Round', p.round ?? 
'-'], + ['Thread', p.thread_id || '-', 'mono'], + ['Time', formatTime(event.timestamp)], + ['Tool count', Array.isArray(p.tool_names) ? p.tool_names.length : '-'], + ]), 'info'), + Array.isArray(p.tool_names) && p.tool_names.length + ? detailSection('Available Tools', detailChips(p.tool_names, 'science')) + : '', + p.query ? detailSection('Query', paragraphsHtml(p.query)) : '', + ); + } + if (event.event === 'workflow_finished' || event.event === 'workflow_output' || event.event === 'run_finished') { + return detailRich( + detailSection('Output', detailKvGrid([ + ['Status', p.status || event.event], + ['Agent', event.agent_id || p.agent_id || '-'], + ['Round', p.round ?? '-'], + ['Time', formatTime(event.timestamp)], + ]), statusTone(p.status)), + tokenDetailSection(event), + promptDetailSection(event), + p.content_preview ? detailSection('Preview', paragraphsHtml(p.content_preview)) : '', + p.error ? detailSection('Error', paragraphsHtml(p.error), 'error') : '', + ); + } + return workflowDetailHtml(event); + } + + function chemgraphNodeContextHtml(event) { + const p = event.payload || {}; + const turnEvents = workflowEventsForSelection(event); + const flow = workflowFlowGraph(turnEvents); + return detailRich( + detailSection('Turn Context', detailKvGrid([ + ['Agent', workflowAgentId(event) || '-'], + ['Runtime', p.runtime || '-'], + ['Round', p.round ?? '-'], + ['Thread', p.thread_id || '-', 'mono'], + ['Workflow span', workflowRootSpanId(event) || '-', 'mono'], + ['Selected span', p.span_id || event.correlation_id || '-', 'mono'], + ['Visible nodes', flow.nodes.length], + ['Visible events', turnEvents.length], + ]), 'info'), + ); + } + + function workflowHistoryHtml(events) { + if (!events.length) return detailRich('
No related workflow events visible.
'); + return detailRich(events.map(event => { + const p = event.payload || {}; + return detailSection(event.event, detailKvGrid([ + ['Time', formatTime(event.timestamp)], + ['Status', p.status || '-'], + ['Round', p.round ?? '-'], + ['Tool', p.tool_name || '-'], + ['Span', p.span_id || event.correlation_id || '-', 'mono'], + ]), statusTone(p.status)); + }).join('')); + } + + function toolDetailHtml(event) { + const p = event.payload || {}; + const status = p.status || p.result?.status || event.event; + return detailRich( + detailSection('Tool', detailKvGrid([ + ['Tool name', toolDisplayName(event)], + ['Status', status], + ['Event', event.event], + ['Time', formatTime(event.timestamp)], + ['Agent', event.agent_id || '-'], + ['Call id', p.tool_result_id || p.correlation_id || event.correlation_id || '-', 'mono'], + ]), statusTone(status)), + p.arguments ? detailSection('Arguments', detailValueHtml(p.arguments)) : '', + p.error ? detailSection('Error', paragraphsHtml(p.error), 'error') : '', + p.result ? 
detailSection('Result', detailValueHtml(p.result)) : '', + ); + } + + function payloadDetailHtml(payload) { + if (!payload || typeof payload !== 'object' || Array.isArray(payload)) { + return detailRich(detailSection('Value', detailValueHtml(payload)), rawJsonDetails(payload)); + } + const priority = [ + 'status', + 'tool_name', + 'tool_result_id', + 'message_id', + 'sender', + 'recipient', + 'tldr', + 'content_preview', + 'error', + 'reason', + 'runtime', + 'workflow_type', + 'workflow_node', + 'round', + 'thread_id', + 'span_id', + 'parent_span_id', + 'correlation_id', + 'log_dir', + 'run_dir', + 'model_name', + ]; + const prioritySet = new Set(priority); + const priorityRows = priority + .filter(key => Object.prototype.hasOwnProperty.call(payload, key)) + .map(key => [fieldLabel(key), payload[key], fieldKind(key)]); + const otherRows = Object.keys(payload) + .filter(key => !prioritySet.has(key)) + .sort() + .map(key => [fieldLabel(key), payload[key], fieldKind(key)]); + return detailRich( + priorityRows.length ? detailSection('Key Fields', detailKvGrid(priorityRows), statusTone(payload.status)) : '', + otherRows.length ? detailSection('Additional Fields', detailKvGrid(otherRows)) : '', + rawJsonDetails(payload), + ); + } + + function fieldLabel(key) { + return String(key || '').replaceAll('_', ' '); + } + + function fieldKind(key) { + if (/(^|_)(id|dir|path|file|span|thread|correlation)($|_)/.test(String(key))) return 'mono'; + if (['content', 'content_preview', 'reason', 'error', 'tldr'].includes(String(key))) return 'text'; + return ''; + } + + function formatBelief(payload) { + const lines = [ + `Hypothesis: ${payload.hypothesis || '-'}`, + `Confidence: ${payload.confidence ?? 
'-'}`, + ]; + if (payload.reason) lines.push(`Reason: ${payload.reason}`); + if (payload.supporting_message_ids?.length) lines.push(`Messages: ${payload.supporting_message_ids.join(', ')}`); + if (payload.supporting_artifact_ids?.length) lines.push(`Artifacts: ${payload.supporting_artifact_ids.join(', ')}`); + return lines.join('\n'); + } + + function formatMessageEvent(event) { + const p = event.payload || {}; + const lines = [ + `${formatTime(event.timestamp)} ${p.sender || event.agent_id || '-'} -> ${p.recipient || '-'}`, + ]; + if (p.tldr) lines.push(`TLDR: ${p.tldr}`); + if (p.content) lines.push(p.content); + if (p.evidence_refs?.length) lines.push(`Evidence: ${p.evidence_refs.join(', ')}`); + if (p.reason) lines.push(`Reason: ${p.reason}`); + return lines.filter(Boolean).join('\n'); + } + + function formatMessageRoute(event) { + const p = event.payload || {}; + const currentAgents = agents(); + const sender = currentAgents.find(agent => agent.agent_id === p.sender); + const recipient = currentAgents.find(agent => agent.agent_id === p.recipient); + const senderHost = agentHost(sender); + const recipientHost = agentHost(recipient); + const route = senderHost && recipientHost && senderHost !== recipientHost ? 'cross-node' : 'same-node'; + return [ + `Route: ${route}`, + `Sender host: ${senderHost}`, + `Recipient host: ${recipientHost}`, + `Message id: ${p.message_id || '-'}`, + ].join('\n'); + } + + function formatDecisionEvent(event) { + const p = event.payload || {}; + const lines = [ + `${formatTime(event.timestamp)} ${event.agent_id || '-'} decision`, + ]; + if (p.mode) lines.push(`Mode: ${p.mode}`); + if (p.wake_reason) lines.push(`Wake reason: ${p.wake_reason}`); + if (p.rationale) lines.push(`\nRationale:\n${p.rationale}`); + const actions = Array.isArray(p.actions) ? 
p.actions : []; + if (actions.length) { + lines.push('\nActions:'); + actions.forEach((action, index) => { + lines.push(formatDecisionToolCall(action, index + 1)); + }); + } else { + lines.push('\nActions: none'); + } + const ignored = Array.isArray(p.ignored_actions) ? p.ignored_actions : []; + if (ignored.length) { + lines.push(`\nIgnored actions: ${ignored.length}`); + } + return lines.join('\n'); + } + + function formatDecisionToolCall(action, number) { + const parts = [`${number}. ${action.action || 'unknown_action'}`]; + if (action.recipient) parts.push(`recipient=${action.recipient}`); + if (action.tool_name) parts.push(`tool=${action.tool_name}`); + if (action.confidence !== null && action.confidence !== undefined) parts.push(`confidence=${action.confidence}`); + let text = parts.join(' | '); + if (action.question) text += `\n Question: ${action.question}`; + if (action.content) text += `\n Content: ${action.content}`; + if (action.hypothesis) text += `\n Hypothesis: ${action.hypothesis}`; + if (action.reason) text += `\n Reason: ${action.reason}`; + if (action.evidence_refs?.length) text += `\n Evidence: ${action.evidence_refs.join(', ')}`; + if (action.supporting_message_ids?.length) text += `\n Messages: ${action.supporting_message_ids.join(', ')}`; + if (action.supporting_artifact_ids?.length) text += `\n Artifacts: ${action.supporting_artifact_ids.join(', ')}`; + if (action.arguments && Object.keys(action.arguments).length) { + text += `\n Arguments: ${formatJson(action.arguments)}`; + } + return text; + } + + function formatWakeEvents(event) { + const wakeEvents = event.payload?.wake_events || []; + if (!Array.isArray(wakeEvents) || !wakeEvents.length) { + return 'No wake events recorded for this decision.'; + } + return wakeEvents.map((wake, index) => { + const payload = wake.payload || {}; + const summary = [ + `${index + 1}. ${formatTime(wake.timestamp)} ${wake.event}${wake.agent_id ? 
` · ${wake.agent_id}` : ''}`, + ]; + if (payload.message_id) summary.push(` Message: ${payload.message_id}`); + if (payload.sender || payload.recipient) summary.push(` Route: ${payload.sender || '-'} -> ${payload.recipient || '-'}`); + if (payload.tool_name) summary.push(` Tool: ${payload.tool_name}`); + if (payload.content) summary.push(` Content: ${trunc(payload.content, 500)}`); + if (payload.result) summary.push(` Result: ${formatJson(payload.result)}`); + return summary.join('\n'); + }).join('\n\n'); + } + + function formatToolEvent(event) { + const p = event.payload || {}; + const status = p.status || p.result?.status || event.event; + const toolName = toolDisplayName(event); + const lines = [ + `${formatTime(event.timestamp)} ${event.event}`, + `Tool: ${toolName}`, + `Status: ${status}`, + ]; + if (event.correlation_id) lines.push(`Call: ${event.correlation_id}`); + const results = p.results || p.result?.results; + if (Array.isArray(results)) { + lines.push(`Results: ${results.map(item => `${item.index ?? 
'-'}:${item.status || item.error_type || '-'}`).join(', ')}`); + } + if (p.error) lines.push(`Error: ${p.error}`); + return lines.join('\n'); + } + + function isWorkflowEvent(event) { + const p = event.payload || {}; + const workflowNames = [ + 'run_started', + 'run_finished', + 'workflow_started', + 'workflow_node_started', + 'workflow_node_finished', + 'workflow_output', + 'workflow_finished', + 'llm_decision', + 'tool_call_started', + 'tool_call_finished', + 'tool_call_failed', + ]; + if (!workflowNames.includes(event.event)) return false; + return Boolean( + p.nested + || p.runtime + || p.workflow_type + || p.workflow_span_id + || p.span_id + || p.parent_span_id + || p.thread_id + || (event.agent_id && p.round !== undefined && p.round !== null) + ); + } + + function workflowSpanId(event) { + const p = event.payload || {}; + return p.span_id || event.correlation_id || null; + } + + function workflowParentSpanId(event) { + const p = event.payload || {}; + return p.parent_span_id || null; + } + + function workflowAgentId(event) { + const p = event.payload || {}; + return event.agent_id || p.agent_id || p.agent_name || p.agent || null; + } + + function workflowRootSpanId(event) { + const p = event.payload || {}; + if (p.workflow_span_id) return p.workflow_span_id; + if (p.thread_id) return p.thread_id; + if ((event.event === 'workflow_started' || event.event === 'workflow_finished') && p.span_id) { + return p.span_id; + } + if (p.parent_span_id || p.span_id || event.correlation_id) { + return p.parent_span_id || p.span_id || event.correlation_id; + } + const agentId = workflowAgentId(event); + if (agentId && p.round !== undefined && p.round !== null) { + return `${agentId}-round-${p.round}`; + } + return null; + } + + function workflowEventsForSpan(spanId) { + if (!spanId) return []; + return visibleEvents().filter(candidate => ( + isWorkflowEvent(candidate) && workflowRootSpanId(candidate) === spanId + )); + } + + function workflowEventsForSelection(event) { + if 
(!event) return []; + const p = event.payload || {}; + const spanId = p.workflow_span_id || (isWorkflowEvent(event) ? workflowRootSpanId(event) : null); + if (spanId) return workflowEventsForSpan(spanId); + const agentId = event.agent_id || p.agent_id || p.agent_name; + const round = p.round; + if (!agentId || round === undefined || round === null) return []; + return visibleEvents().filter(candidate => ( + isWorkflowEvent(candidate) + && workflowAgentId(candidate) === agentId + && candidate.payload?.round === round + )); + } + + function reasoningTurnsForAgent(agentId) { + if (!agentId) return []; + const bySpan = new Map(); + visibleEvents().forEach(event => { + if (!isWorkflowEvent(event) || workflowAgentId(event) !== agentId) return; + const spanId = workflowRootSpanId(event); + if (!spanId) return; + if (!bySpan.has(spanId)) { + bySpan.set(spanId, { + spanId, + events: [], + round: null, + threadId: '', + started: null, + finished: null, + firstTimestamp: null, + lastTimestamp: null, + }); + } + const turn = bySpan.get(spanId); + turn.events.push(event); + const p = event.payload || {}; + if (p.round !== undefined && p.round !== null) turn.round = p.round; + if (p.thread_id) turn.threadId = p.thread_id; + if (event.event === 'workflow_started') turn.started = event; + if (event.event === 'workflow_finished') turn.finished = event; + const ts = eventTimestamp(event); + if (ts !== null) { + turn.firstTimestamp = turn.firstTimestamp === null ? ts : Math.min(turn.firstTimestamp, ts); + turn.lastTimestamp = turn.lastTimestamp === null ? 
ts : Math.max(turn.lastTimestamp, ts); + } + }); + return Array.from(bySpan.values()) + .map(turn => { + const toolNodes = workflowFlowGraph(turn.events).nodes.filter(node => node.type === 'tool'); + const actionTools = toolNodes.filter(node => node.toolClass === 'action-tool'); + const scienceTools = toolNodes.filter(node => node.toolClass === 'science-tool'); + return { + ...turn, + status: turn.finished?.payload?.status || 'running', + representative: turn.started || turn.events[0], + lmCount: workflowTokenEvents(turn.events).length || turn.events.filter(event => event.event === 'llm_decision').length, + tokenTotals: summedTokenCounts(turn.events), + actionToolCount: actionTools.length, + scienceToolCount: scienceTools.length, + }; + }) + .sort((a, b) => (a.firstTimestamp ?? 0) - (b.firstTimestamp ?? 0)); + } + + function reasoningTurnListHtml(turns) { + if (!turns.length) return ''; + const rows = turns.slice(-6).reverse().map(turn => { + const key = eventKey(turn.representative); + const round = turn.round === null || turn.round === undefined ? '-' : turn.round; + const title = `Round ${round}`; + const meta = [ + `${turn.status}`, + `${turn.lmCount} LM`, + turn.tokenTotals ? `${formatTokenCount(turn.tokenTotals.total)} tok` : '', + `${turn.actionToolCount} action`, + `${turn.scienceToolCount} science`, + formatTime(turn.firstTimestamp), + ].filter(Boolean).join(' · '); + return ` + + `; + }).join(''); + return `
${rows}
`; + } + + function nestedWorkflowEventsForTool(event) { + const p = event.payload || {}; + const callId = p.tool_result_id || p.correlation_id || event.correlation_id; + if (!callId) return []; + return allEvents().filter(candidate => { + if (!isWorkflowEvent(candidate)) return false; + return workflowParentSpanId(candidate) === callId || candidate.payload?.parent_tool_name === toolDisplayName(event); + }); + } + + function relatedWorkflowEvents(event) { + const spanId = workflowSpanId(event); + const parentSpanId = workflowParentSpanId(event); + if (!spanId && !parentSpanId) return []; + return allEvents().filter(candidate => { + if (!isWorkflowEvent(candidate)) return false; + const candidateSpan = workflowSpanId(candidate); + const candidateParent = workflowParentSpanId(candidate); + return candidateSpan === spanId + || candidateParent === spanId + || (parentSpanId && (candidateSpan === parentSpanId || candidateParent === parentSpanId)); + }); + } + + function formatWorkflowEvent(event) { + const p = event.payload || {}; + const lines = [ + `${formatTime(event.timestamp)} ${event.event}`, + `Runtime: ${p.runtime || '-'}`, + `Span: ${p.span_id || '-'}`, + ]; + if (p.parent_span_id) lines.push(`Parent span: ${p.parent_span_id}`); + if (p.workflow_type) lines.push(`Workflow: ${p.workflow_type}`); + if (p.workflow_node) lines.push(`Node: ${p.workflow_node}`); + if (p.phase) lines.push(`Phase: ${p.phase}`); + if (p.status) lines.push(`Status: ${p.status}`); + if (p.model_name) lines.push(`Model: ${p.model_name}`); + if (p.log_dir) lines.push(`Log dir: ${p.log_dir}`); + if (Array.isArray(p.tool_calls) && p.tool_calls.length) { + lines.push(`Tool calls: ${p.tool_calls.map(call => call.name || call.id || 'tool').join(', ')}`); + } + if (p.tool_name) lines.push(`Tool: ${p.tool_name}`); + if (p.content_preview) lines.push(`Preview: ${p.content_preview}`); + if (p.error) lines.push(`Error: ${p.error}`); + return lines.join('\n'); + } + + function 
formatWorkflowHistory(event) { + const history = relatedWorkflowEvents(event); + return history.length + ? history.map(formatWorkflowEvent).join('\n\n') + : 'No related workflow events visible.'; + } + + function formatJson(value) { + try { + return JSON.stringify(value, null, 2); + } catch { + return String(value); + } + } + + function toggleReplay() { + if (isReplaying) { + stopReplay(true); + } else { + startReplay(); + } + } + + function startReplay() { + const events = allEvents(); + if (!events.length) return; + followLatest = false; + isReplaying = true; + graphMode = 'current'; + selectedEdgeKey = null; + replayStartIndex = Math.max(0, currentEventIndex()); + replayStartTimestamp = eventTimestamp(events[replayStartIndex]) ?? firstTimestamp(); + replayStartedAtMs = Date.now(); + if (replayTimer) window.clearInterval(replayTimer); + replayTimer = window.setInterval(tickReplay, 120); + tickReplay(); + } + + function stopReplay(renderNow = false) { + isReplaying = false; + if (replayTimer) { + window.clearInterval(replayTimer); + replayTimer = null; + } + if (renderNow) render(); + } + + function tickReplay() { + const events = allEvents(); + if (!events.length) { + stopReplay(true); + return; + } + if (replayStartTimestamp === null) { + timelineIndex = Math.min(events.length - 1, currentEventIndex() + 1); + } else { + const elapsedSeconds = ((Date.now() - replayStartedAtMs) / 1000) * replaySpeed(); + timelineIndex = Math.max( + replayStartIndex, + eventIndexAtTimestamp(replayStartTimestamp + elapsedSeconds) + ); + } + render(); + if (timelineIndex >= events.length - 1) stopReplay(false); + } + + function startGraphPan(event) { + if (!graphView || event.button !== 0) return; + if (event.target.closest('.agent-node') || event.target.closest('[data-edge-key]') || event.target.closest('.activity-bubble')) return; + graphPanDrag = { + clientX: event.clientX, + clientY: event.clientY, + startX: graphView.x, + startY: graphView.y, + }; + 
document.getElementById('graph').classList.add('panning'); + event.preventDefault(); + } + + function moveGraphPan(event) { + if (!graphPanDrag || !graphView) return; + const svg = document.getElementById('graph'); + const rect = svg.getBoundingClientRect(); + const dx = (event.clientX - graphPanDrag.clientX) * graphView.width / Math.max(rect.width, 1); + const dy = (event.clientY - graphPanDrag.clientY) * graphView.height / Math.max(rect.height, 1); + graphView.x = graphPanDrag.startX - dx; + graphView.y = graphPanDrag.startY - dy; + clampGraphView(); + updateGraphViewBox(); + } + + function stopGraphPan() { + graphPanDrag = null; + document.getElementById('graph').classList.remove('panning'); + } + + function startEmbeddedWorkflowPan(event) { + if (!embeddedWorkflowView || event.button !== 0) return; + if (event.target.closest('.workflow-node')) return; + embeddedWorkflowPanDrag = { + clientX: event.clientX, + clientY: event.clientY, + startX: embeddedWorkflowView.x, + startY: embeddedWorkflowView.y, + }; + document.getElementById('embeddedWorkflowGraph').classList.add('panning'); + event.preventDefault(); + } + + function moveEmbeddedWorkflowPan(event) { + if (!embeddedWorkflowPanDrag || !embeddedWorkflowView) return; + const svg = document.getElementById('embeddedWorkflowGraph'); + const rect = svg.getBoundingClientRect(); + const dx = (event.clientX - embeddedWorkflowPanDrag.clientX) * embeddedWorkflowView.width / Math.max(rect.width, 1); + const dy = (event.clientY - embeddedWorkflowPanDrag.clientY) * embeddedWorkflowView.height / Math.max(rect.height, 1); + embeddedWorkflowView.x = embeddedWorkflowPanDrag.startX - dx; + embeddedWorkflowView.y = embeddedWorkflowPanDrag.startY - dy; + clampEmbeddedWorkflowView(); + updateEmbeddedWorkflowViewBox(); + } + + function stopEmbeddedWorkflowPan() { + embeddedWorkflowPanDrag = null; + document.getElementById('embeddedWorkflowGraph').classList.remove('panning'); + } + + function startWorkflowPanelDrag(event) { + if 
(event.button !== 0 || event.target.closest('button')) return; + workflowPanelDrag = { + clientX: event.clientX, + clientY: event.clientY, + startX: workflowPanelFrame.x, + startY: workflowPanelFrame.y, + }; + event.preventDefault(); + } + + function moveWorkflowPanelDrag(event) { + if (!workflowPanelDrag) return; + workflowPanelFrame.x = workflowPanelDrag.startX + event.clientX - workflowPanelDrag.clientX; + workflowPanelFrame.y = workflowPanelDrag.startY + event.clientY - workflowPanelDrag.clientY; + applyWorkflowPanelFrame(); + } + + function stopWorkflowPanelDrag() { + workflowPanelDrag = null; + } + + function startWorkflowPanelResize(event) { + if (event.button !== 0) return; + workflowPanelResizeDrag = { + clientX: event.clientX, + clientY: event.clientY, + startWidth: workflowPanelFrame.width, + startHeight: workflowPanelFrame.height, + }; + event.preventDefault(); + event.stopPropagation(); + } + + function moveWorkflowPanelResize(event) { + if (!workflowPanelResizeDrag) return; + workflowPanelFrame.width = workflowPanelResizeDrag.startWidth + event.clientX - workflowPanelResizeDrag.clientX; + workflowPanelFrame.height = workflowPanelResizeDrag.startHeight + event.clientY - workflowPanelResizeDrag.clientY; + applyWorkflowPanelFrame(); + } + + function stopWorkflowPanelResize() { + workflowPanelResizeDrag = null; + } + + function startDetailResize(event) { + if (event.button !== 0) return; + const currentWidth = parseFloat(getComputedStyle(document.documentElement).getPropertyValue('--detail-width')) || 430; + detailResizeDrag = { + clientX: event.clientX, + startWidth: currentWidth, + }; + document.getElementById('detailResizer').classList.add('active'); + event.preventDefault(); + } + + function moveDetailResize(event) { + if (!detailResizeDrag) return; + const width = Math.min(760, Math.max(320, detailResizeDrag.startWidth - (event.clientX - detailResizeDrag.clientX))); + document.documentElement.style.setProperty('--detail-width', `${width}px`); + } + + 
function stopDetailResize() { + detailResizeDrag = null; + document.getElementById('detailResizer').classList.remove('active'); + } + + document.getElementById('refresh').addEventListener('click', load); + document.getElementById('playReplay').addEventListener('click', toggleReplay); + document.getElementById('timeSlider').addEventListener('input', e => { + stopReplay(false); + followLatest = false; + timelineIndex = Number(e.target.value); + render(); + }); + document.getElementById('latest').addEventListener('click', () => { + stopReplay(false); + followLatest = true; + timelineIndex = allEvents().length - 1; + render(); + }); + document.getElementById('eventHold').addEventListener('change', render); + document.getElementById('zoomIn').addEventListener('click', () => zoomGraph(0.86)); + document.getElementById('zoomOut').addEventListener('click', () => zoomGraph(1.16)); + document.getElementById('zoomReset').addEventListener('click', resetGraphView); + document.getElementById('graph').addEventListener('mousedown', startGraphPan); + document.getElementById('graph').addEventListener('wheel', event => { + if (!graphView) return; + event.preventDefault(); + zoomGraph(event.deltaY < 0 ? 0.94 : 1.06); + }, {passive: false}); + document.getElementById('embeddedWorkflowGraph').addEventListener('mousedown', startEmbeddedWorkflowPan); + document.getElementById('embeddedWorkflowGraph').addEventListener('wheel', event => { + if (!embeddedWorkflowView) return; + event.preventDefault(); + zoomEmbeddedWorkflow(event.deltaY < 0 ? 
0.94 : 1.06); + }, {passive: false}); + document.getElementById('workflowZoomIn').addEventListener('click', () => zoomEmbeddedWorkflow(0.86)); + document.getElementById('workflowZoomOut').addEventListener('click', () => zoomEmbeddedWorkflow(1.16)); + document.getElementById('workflowZoomReset').addEventListener('click', resetEmbeddedWorkflowView); + document.getElementById('workflowPanelClose').addEventListener('click', () => { + workflowPanelOpen = false; + renderEmbeddedWorkflowPanel(); + }); + document.getElementById('workflowFloatingTab').addEventListener('click', () => { + workflowPanelOpen = true; + renderEmbeddedWorkflowPanel(); + }); + document.getElementById('workflowFloatingHead').addEventListener('mousedown', startWorkflowPanelDrag); + document.getElementById('workflowPanelResize').addEventListener('mousedown', startWorkflowPanelResize); + document.getElementById('detailResizer').addEventListener('mousedown', startDetailResize); + window.addEventListener('mousemove', moveGraphPan); + window.addEventListener('mousemove', moveEmbeddedWorkflowPan); + window.addEventListener('mousemove', moveWorkflowPanelDrag); + window.addEventListener('mousemove', moveWorkflowPanelResize); + window.addEventListener('mousemove', moveDetailResize); + window.addEventListener('mouseup', stopGraphPan); + window.addEventListener('mouseup', stopEmbeddedWorkflowPan); + window.addEventListener('mouseup', stopWorkflowPanelDrag); + window.addEventListener('mouseup', stopWorkflowPanelResize); + window.addEventListener('mouseup', stopDetailResize); + document.getElementById('detailPrimary').addEventListener('click', handleDetailPaneClick); + document.getElementById('detailSecondary').addEventListener('click', handleDetailPaneClick); + document.getElementById('detailTertiary').addEventListener('click', handleDetailPaneClick); + document.querySelectorAll('#graphMode button').forEach(button => { + button.addEventListener('click', () => { + graphMode = button.dataset.mode; + 
selectedEdgeKey = null; + selectedActivityEventKey = null; + render(); + }); + }); + document.getElementById('agentSelect').addEventListener('change', e => { + selectedAgent = e.target.value || null; + selectedEdgeKey = null; + selectedActivityEventKey = null; + render(); + }); + load(); + setInterval(load, 2000); diff --git a/src/chemgraph/academy/dashboard/static/index.html b/src/chemgraph/academy/dashboard/static/index.html new file mode 100644 index 00000000..f26c106f --- /dev/null +++ b/src/chemgraph/academy/dashboard/static/index.html @@ -0,0 +1,703 @@ + + + + + + ChemGraph Academy Dashboard + + + +
+

ChemGraph Academy Dashboard

+
+ + +
+
+
+
+
+
+
Run State
+
+
+
+
+
+
+
+
+
Agent Graph
+
+
+
+
+ + + + + +
+
+ + + +
+
+ Graph + + + +
+
+
+ + +
+
+
+
+ +
+
+
+
+
+
+
+
+
Selection
+ +
+
+
+

State

+
+

Evidence

+
+

History

+
+
+
+
+
+ + + + + diff --git a/src/chemgraph/academy/observability/__init__.py b/src/chemgraph/academy/observability/__init__.py new file mode 100644 index 00000000..f6031201 --- /dev/null +++ b/src/chemgraph/academy/observability/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from chemgraph.academy.observability.event_log import CampaignEvent +from chemgraph.academy.observability.event_log import EventLog +from chemgraph.academy.observability.event_log import read_events + +__all__ = [ + 'CampaignEvent', + 'EventLog', + 'read_events', +] diff --git a/src/chemgraph/academy/observability/event_log.py b/src/chemgraph/academy/observability/event_log.py new file mode 100644 index 00000000..c42a41cc --- /dev/null +++ b/src/chemgraph/academy/observability/event_log.py @@ -0,0 +1,148 @@ +"""Shared event log for Academy-ChemGraph campaign runs. + +The dynamic campaign layer treats agent messages and ChemGraph job updates as +one append-only event stream. The dashboard and HPC run scripts can consume +this file without knowing which science use case created the event. 
+""" + +from __future__ import annotations + +import json +import time +import uuid +from pathlib import Path +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field + + +EventKind = Literal[ + "campaign_started", + "campaign_planned", + "campaign_finished", + "agent_started", + "agent_stopped", + "agent_decision", + "agent_error", + "message_sent", + "message_received", + "message_delivered", + "message_delivery_failed", + "belief_updated", + "tool_call_started", + "tool_call_finished", + "tool_call_failed", + "chemgraph_batch_submitted", + "chemgraph_job_status", + "chemgraph_job_result", + "chemgraph_transfer_submitted", + "chemgraph_transfer_done", + "round_started", + "round_finished", + "self_wake_scheduled", + "idle_timeout", + "max_decisions_reached", + "daemon_started", + "daemon_stopped", + "bootstrap_message_dispatched", + "llm_tool_calls", + "turn_finished_without_external_action", + "chemgraph_reasoning_turn_started", + "chemgraph_reasoning_turn_finished", + "run_started", + "run_finished", + "workflow_started", + "workflow_finished", + "workflow_node_started", + "workflow_node_finished", + "llm_call_started", + "llm_call_finished", + "llm_call_failed", + "llm_decision", + "workflow_output", +] + +__all__ = [ + 'CampaignEvent', + 'EventKind', + 'EventLog', + 'read_events', +] + + +class CampaignEvent(BaseModel): + """One durable event emitted by a campaign runtime.""" + + model_config = ConfigDict(extra="forbid") + + event_id: str = Field(default_factory=lambda: f"evt-{uuid.uuid4()}") + timestamp: float = Field(default_factory=time.time) + event: EventKind + run_id: str | None = None + agent_id: str | None = None + role: str | None = None + correlation_id: str | None = None + payload: dict[str, Any] = Field(default_factory=dict) + + +class EventLog: + """Append/read helper for campaign JSONL event logs.""" + + def __init__(self, path: str | Path) -> None: + self.path = Path(path) + + def append(self, event: CampaignEvent) 
-> CampaignEvent: + """Append *event* and return it.""" + self.path.parent.mkdir(parents=True, exist_ok=True) + with self.path.open("a", encoding="utf-8") as handle: + handle.write(event.model_dump_json()) + handle.write("\n") + return event + + def emit( + self, + event: EventKind, + *, + run_id: str | None = None, + agent_id: str | None = None, + role: str | None = None, + correlation_id: str | None = None, + payload: dict[str, Any] | None = None, + ) -> CampaignEvent: + """Build and append a :class:`CampaignEvent`.""" + return self.append( + CampaignEvent( + event=event, + run_id=run_id, + agent_id=agent_id, + role=role, + correlation_id=correlation_id, + payload=payload or {}, + ) + ) + + def read(self) -> list[CampaignEvent]: + """Read all valid JSONL events from the log.""" + return read_events(self.path) + + +def read_events(path: str | Path) -> list[CampaignEvent]: + """Read valid campaign events from *path*. + + Partially written or malformed lines are skipped so live dashboards can + poll while another process is appending. 
+ """ + event_path = Path(path) + if not event_path.exists(): + return [] + events: list[CampaignEvent] = [] + with event_path.open(encoding="utf-8") as handle: + for line in handle: + if not line.strip(): + continue + try: + payload = json.loads(line) + events.append(CampaignEvent.model_validate(payload)) + except (json.JSONDecodeError, ValueError): + continue + return events diff --git a/src/chemgraph/academy/observability/run_artifacts.py b/src/chemgraph/academy/observability/run_artifacts.py new file mode 100644 index 00000000..11fa8b4a --- /dev/null +++ b/src/chemgraph/academy/observability/run_artifacts.py @@ -0,0 +1,341 @@ +from __future__ import annotations + +import asyncio +import json +import pathlib +import shutil +import time +from collections import Counter +from typing import Any + +from chemgraph.academy.observability.event_log import CampaignEvent +from chemgraph.academy.observability.event_log import read_events +from chemgraph.academy.observability.run_files import append_jsonl +from chemgraph.academy.observability.run_files import write_json +from chemgraph.academy.observability.run_files import write_json_atomic +from chemgraph.academy.core.campaign import ChemGraphAgentSpec +from chemgraph.academy.core.campaign import ChemGraphCampaign +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.runtime.mpi import append_system_trace +from chemgraph.models.settings import LLMSettings + + +def write_run_artifacts(run_dir: str | pathlib.Path) -> dict[str, Any]: + """Write placement and summary artifacts.""" + root = pathlib.Path(run_dir) + events = read_events(root / "events.jsonl") + placement = build_placement(events, root / "status.json") + summary = summarize_events(events) + + write_json(root / "placement.json", placement) + write_json(root / "summary.json", summary) + return { + "placement": placement, + "summary": summary, + } + + +def build_placement( + events: list[CampaignEvent], + status_path: str | 
pathlib.Path | None = None, +) -> dict[str, Any]: + """Build agent placement proof from events and latest status.""" + agents: dict[str, dict[str, Any]] = {} + for event in events: + if event.event != "agent_started" or not event.agent_id: + continue + placement = event.payload.get("placement") + if isinstance(placement, dict): + agents[event.agent_id] = { + "agent_id": event.agent_id, + "role": event.role, + **placement, + } + + if status_path is not None: + path = pathlib.Path(status_path) + if path.exists(): + try: + status = json.loads(path.read_text(encoding="utf-8")) + states = status.get("agent_states", {}) + if isinstance(states, dict): + for agent_id, state in states.items(): + if not isinstance(state, dict): + continue + placement = state.get("placement") + if isinstance(placement, dict): + agents.setdefault( + agent_id, + { + "agent_id": agent_id, + "role": state.get("role"), + **placement, + }, + ) + except json.JSONDecodeError: + pass + + hostnames = sorted( + { + str(record.get("hostname")) + for record in agents.values() + if record.get("hostname") + }, + ) + return { + "agent_count": len(agents), + "hostnames": hostnames, + "distinct_hostname_count": len(hostnames), + "agents": dict(sorted(agents.items())), + } + + +def summarize_events(events: list[CampaignEvent]) -> dict[str, Any]: + """Return compact run summary from campaign events.""" + counts = Counter(event.event for event in events) + final_reports = _final_reports(events) + return { + "event_count": len(events), + "event_counts": dict(sorted(counts.items())), + "finish": _last_payload( + events, + {"campaign_finished", "workflow_finished", "run_finished"}, + ), + "agent_errors": _payloads_of(events, "agent_error"), + "message_count": counts.get("message_sent", 0), + "final_reports": final_reports, + "tool_results": _tool_result_summaries(events), + } + + +def _last_payload( + events: list[CampaignEvent], + kinds: set[str], +) -> dict[str, Any] | None: + payloads = [event.payload for event 
in events if event.event in kinds] + return payloads[-1] if payloads else None + + +def _payloads_of(events: list[CampaignEvent], kind: str) -> list[dict[str, Any]]: + return [ + { + "agent_id": event.agent_id, + "role": event.role, + **event.payload, + } + for event in events + if event.event == kind + ] + + +def _final_reports(events: list[CampaignEvent]) -> list[dict[str, Any]]: + reports = [] + for event in events: + payload = event.payload + if event.event == "belief_updated": + reports.append( + { + "agent_id": event.agent_id, + "summary": payload.get("summary") or payload.get("hypothesis"), + "confidence": payload.get("confidence"), + "supporting_message_ids": payload.get("supporting_message_ids", []), + "supporting_tool_result_ids": payload.get( + "supporting_tool_result_ids", + [], + ), + }, + ) + return reports[-10:] + + +def _tool_result_summaries(events: list[CampaignEvent]) -> list[dict[str, Any]]: + results = [] + for event in events: + if event.event != "tool_call_finished": + continue + payload = event.payload + results.append( + { + "timestamp": event.timestamp, + "agent_id": event.agent_id, + "tool_name": payload.get("tool_name"), + "tool_result_id": payload.get("tool_result_id"), + "status": payload.get("status"), + "content_preview": payload.get("content_preview"), + }, + ) + return results + + +def default_agent_state(spec: ChemGraphAgentSpec) -> dict[str, Any]: + return { + 'agent_name': spec.name, + 'role': spec.role, + 'status_updated_at': None, + 'round': 0, + 'finished': False, + 'last_error': None, + } + + +def write_status_snapshot( + *, + run_dir: pathlib.Path, + campaign: ChemGraphCampaign, + agent_state: dict[str, Any], + placement: dict[str, Any], +) -> None: + state_dir = run_dir / 'agent_status' + state_dir.mkdir(parents=True, exist_ok=True) + payload = dict(agent_state) + payload['placement'] = placement + write_json_atomic(state_dir / f'{agent_state["agent_name"]}.json', payload) + + states_by_agent: dict[str, dict[str, Any]] = 
{} + for path in state_dir.glob('*.json'): + try: + item = json.loads(path.read_text(encoding='utf-8')) + except json.JSONDecodeError: + continue + if isinstance(item, dict) and isinstance(item.get('agent_name'), str): + states_by_agent[item['agent_name']] = item + + agents = [] + placements = {} + for spec in campaign.agents: + state = states_by_agent.get(spec.name) or default_agent_state(spec) + agents.append(state) + if isinstance(state.get('placement'), dict): + placements[spec.name] = state['placement'] + + distinct_hosts = sorted( + { + item.get('short_hostname') or item.get('hostname') + for item in placements.values() + if item.get('short_hostname') or item.get('hostname') + }, + ) + placement_doc = { + 'agents': placements, + 'distinct_hostnames': distinct_hosts, + 'distinct_hostname_count': len(distinct_hosts), + } + write_json_atomic(run_dir / 'placement.json', placement_doc) + + converged = bool(agents) and all( + bool(item.get('finished')) for item in agents + ) + status = { + 'timestamp': time.time(), + 'mode': 'mpi_daemon', + 'campaign_kind': 'chemgraph_agent_swarm', + 'campaign': campaign.run_id, + 'agents': sorted(agents, key=lambda item: item['agent_name']), + 'placement': placement_doc, + 'converged': converged, + } + write_json_atomic(run_dir / 'status.json', status) + append_jsonl(run_dir / 'status_history.jsonl', status) + + +async def wait_for_agent_statuses_finished( + *, + run_dir: pathlib.Path, + campaign: ChemGraphCampaign, + timeout_s: float, +) -> bool: + deadline = time.monotonic() + timeout_s + state_dir = run_dir / 'agent_status' + expected = {spec.name for spec in campaign.agents} + while True: + finished = set() + for path in state_dir.glob('*.json'): + try: + item = json.loads(path.read_text(encoding='utf-8')) + except (OSError, json.JSONDecodeError): + continue + if item.get('finished') is True and item.get('agent_name') in expected: + finished.add(item['agent_name']) + if finished == expected: + return True + if time.monotonic() 
> deadline: + return False + await asyncio.sleep(0.5) + + +def clear_run_outputs(run_dir: pathlib.Path) -> None: + for name in ( + 'academy_registrations.json', + 'messages.jsonl', + 'events.jsonl', + 'placement.json', + 'status.json', + 'status_history.jsonl', + 'tool_results.jsonl', + ): + path = run_dir / name + if path.exists(): + path.unlink() + for dirname in ('agent_status', 'artifacts', 'shared'): + path = run_dir / dirname + if path.exists(): + shutil.rmtree(path) + + +def initialize_run_files( + *, + run_dir: pathlib.Path, + campaign: ChemGraphCampaign, + config: ChemGraphDaemonConfig, + llm_settings: LLMSettings, +) -> None: + run_dir.mkdir(parents=True, exist_ok=True) + clear_run_outputs(run_dir) + write_json( + run_dir / 'manifest.json', + { + 'run_dir': str(run_dir), + 'run_token': config.run_token, + 'mode': 'chemgraph_mpi_daemon', + 'agent_runtime': 'academy_runtime', + 'agent_count': config.agent_count, + 'max_decisions_per_agent': config.max_decisions, + 'campaign_config': ( + str(config.campaign_config) + if config.campaign_config is not None + else None + ), + 'prompt_profile': str(campaign.prompt_profile), + 'chemgraph_repo_root': str(config.chemgraph_repo_root), + 'communication_transport': f'academy_{config.exchange_type}_actions', + 'exchange_type': config.exchange_type, + 'redis_host': config.redis_host, + 'redis_port': config.redis_port, + 'redis_namespace': config.redis_namespace, + 'llm_model': llm_settings.model, + 'llm_base_url': llm_settings.base_url, + 'llm_provider': llm_settings.provider, + 'llm_user': llm_settings.user, + }, + ) + append_system_trace( + run_dir, + 'campaign_started', + { + 'mode': 'chemgraph_mpi_daemon', + 'agent_count': config.agent_count, + 'campaign': campaign.run_id, + }, + ) + append_system_trace( + run_dir, + 'campaign_planned', + { + 'agents': [spec.name for spec in campaign.agents], + 'roles': {spec.name: spec.role for spec in campaign.agents}, + 'mcp_servers': { + spec.name: list(spec.mcp_servers) + for 
spec in campaign.agents + }, + }, + ) diff --git a/src/chemgraph/academy/observability/run_files.py b/src/chemgraph/academy/observability/run_files.py new file mode 100644 index 00000000..bc128904 --- /dev/null +++ b/src/chemgraph/academy/observability/run_files.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import json +import os +import uuid +from pathlib import Path +from typing import Any + +__all__ = [ + 'append_jsonl', + 'read_json_file', + 'read_jsonl', + 'write_json', + 'write_json_atomic', +] + + +def write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open('w', encoding='utf-8') as fp: + json.dump(payload, fp, indent=2, sort_keys=True) + fp.write('\n') + + +def write_json_atomic(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_name(f'.{path.name}.{os.getpid()}.{uuid.uuid4()}.tmp') + tmp.write_text( + json.dumps(payload, indent=2, sort_keys=True) + '\n', + encoding='utf-8', + ) + tmp.replace(path) + + +def append_jsonl(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open('a', encoding='utf-8') as fp: + fp.write(json.dumps(payload, sort_keys=True)) + fp.write('\n') + + +def read_jsonl(path: Path) -> list[dict[str, Any]]: + if not path.exists(): + return [] + rows = [] + with path.open(encoding='utf-8') as fp: + for line in fp: + if not line.strip(): + continue + try: + item = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(item, dict): + rows.append(item) + return rows + + +def read_json_file(path: Path, *, default: dict[str, Any]) -> dict[str, Any]: + if not path.exists(): + return default + try: + payload = json.loads(path.read_text(encoding='utf-8')) + except json.JSONDecodeError: + return default + return payload if isinstance(payload, dict) else default diff --git a/src/chemgraph/academy/runtime/__init__.py b/src/chemgraph/academy/runtime/__init__.py new file 
mode 100644 index 00000000..ccc9bab8 --- /dev/null +++ b/src/chemgraph/academy/runtime/__init__.py @@ -0,0 +1 @@ +"""Runtime launch and MPI support for ChemGraph Academy campaigns.""" diff --git a/src/chemgraph/academy/runtime/compute_launcher.py b/src/chemgraph/academy/runtime/compute_launcher.py new file mode 100644 index 00000000..3ba9ad41 --- /dev/null +++ b/src/chemgraph/academy/runtime/compute_launcher.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import argparse +import dataclasses +import json +import os +import shutil +import socket +import subprocess +import sys +import time +from pathlib import Path +from typing import Any + +from chemgraph.academy.campaigns import campaign_launch_defaults +from chemgraph.academy.campaigns import resolve_campaign +from chemgraph.academy.campaigns import resolve_lm_config_template +from chemgraph.academy.runtime.profiles import list_builtin_system_profiles +from chemgraph.academy.runtime.profiles import load_system_profile +from chemgraph.academy.runtime.profiles.system import SystemProfile + + +DASHBOARD_METADATA_FILE = "dashboard_metadata.json" + + +@dataclasses.dataclass(frozen=True) +class AllocationPlan: + """Resolved parameters needed to launch one MPI-backed campaign.""" + + run_dir: Path + run_token: str + agent_count: int + agents_per_node: int + campaign_config: Path + lm_config: Path + max_decisions: int + poll_timeout_s: float + idle_timeout_s: float + startup_timeout_s: float + completion_timeout_s: float + status_interval_s: float + redis_host: str + redis_port: int + redis_bind: str + redis_protected_mode: str + redis_namespace: str + start_redis: bool + mpiexec: str + chemgraph_repo_root: Path + exchange_type: str = "redis" + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run a ChemGraph Academy campaign inside the current " + "HPC compute allocation." 
+ ), + ) + parser.add_argument( + "--system", + required=True, + help=( + "Built-in system profile or profile JSON path. Built-ins: " + + ", ".join(list_builtin_system_profiles()) + ), + ) + parser.add_argument("--run-id", required=True) + parser.add_argument("--campaign", required=True) + parser.add_argument("--run-dir") + parser.add_argument("--lm-base-url") + parser.add_argument("--relay-host") + parser.add_argument("--lm-model") + parser.add_argument("--lm-user") + parser.add_argument("--max-tokens", type=int) + parser.add_argument("--agent-count", type=int) + parser.add_argument("--agents-per-node", type=int) + parser.add_argument("--max-decisions", type=int) + parser.add_argument("--redis-port", type=int) + parser.add_argument( + "--exchange-type", + choices=("redis", "local", "hybrid"), + default="redis", + ) + parser.add_argument("--no-start-redis", action="store_true") + return parser.parse_args(argv) + + +def _prepend_path(name: str, entries: list[str]) -> None: + existing = os.environ.get(name, "") + values = [entry for entry in entries if entry] + if existing: + values.append(existing) + os.environ[name] = os.pathsep.join(values) + + +def _prepare_environment(profile: SystemProfile) -> None: + for name in profile.unset_env: + os.environ.pop(name, None) + _prepend_path("PATH", profile.path_entries) + _prepend_path("PYTHONPATH", profile.pythonpath_entries) + for name, value in profile.env.items(): + os.environ.setdefault(name, value) + os.environ["no_proxy"] = profile.no_proxy + os.environ["NO_PROXY"] = profile.no_proxy + + +def _load_dashboard_metadata(run_dir: Path) -> dict[str, Any]: + path = run_dir / DASHBOARD_METADATA_FILE + if not path.exists(): + return {} + data = json.loads(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise RuntimeError(f"{path} must contain a JSON object") + return data + + +def _relay_host_from_profile(profile: SystemProfile) -> str: + path = Path(profile.relay_host_file) + if not path.exists(): + raise 
RuntimeError( + "Could not determine UAN relay host. Start the Mac dashboard " + f"first, or pass --lm-base-url. Missing: {path}", + ) + host = path.read_text(encoding="utf-8").strip() + if not host: + raise RuntimeError(f"Relay host file is empty: {path}") + return host + + +def _resolve_lm_base_url( + *, + args: argparse.Namespace, + profile: SystemProfile, + metadata: dict[str, Any], +) -> str: + if args.lm_base_url: + return args.lm_base_url + value = metadata.get("lm_base_url") + if isinstance(value, str) and value.strip(): + return value.strip() + relay_host = args.relay_host or metadata.get("relay_host") + if not isinstance(relay_host, str) or not relay_host.strip(): + relay_host = _relay_host_from_profile(profile) + return f"http://{relay_host.strip()}:{profile.relay_port}/argoapi/v1" + + +def _write_lm_config( + *, + run_dir: Path, + template_name: str, + base_url: str, + lm_model: str | None, + lm_user: str | None, + max_tokens: int | None, +) -> Path: + template_path = resolve_lm_config_template(template_name) + data = json.loads(template_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + raise RuntimeError(f"LM template must contain a JSON object: {template_path}") + data["base_url"] = base_url + if lm_model: + data["model"] = lm_model + if lm_user: + data["user"] = lm_user + if max_tokens is not None: + data["max_tokens"] = max_tokens + + path = run_dir / "lm_config.json" + path.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8") + return path + + +def _export_workflow_lm_environment(lm_config: Path) -> None: + data = json.loads(lm_config.read_text(encoding="utf-8")) + values = { + "CHEMGRAPH_WORKFLOW_BASE_URL": data.get("base_url"), + "CHEMGRAPH_WORKFLOW_MODEL": data.get("model"), + "CHEMGRAPH_WORKFLOW_API_KEY": data.get("api_key"), + "CHEMGRAPH_WORKFLOW_ARGO_USER": data.get("user"), + "ARGO_USER": data.get("user"), + } + for name, value in values.items(): + if isinstance(value, str) and value: + 
def _run_token() -> str:
    """Return a launch token of the form ``<unix-seconds>-<pid>``."""
    seconds = int(time.time())
    pid = os.getpid()
    return f"{seconds}-{pid}"


def prepare_compute_launch(args: argparse.Namespace) -> AllocationPlan:
    """Build the allocation plan for a compute-node launch.

    Loads the system profile, applies its environment, creates the run
    directory, checks that dashboard metadata (if present) names the same
    campaign, writes and exports the LM configuration, and returns the
    settings ``run_allocation`` needs.

    Raises:
        RuntimeError: If the run metadata was created for a different
            campaign than ``--campaign``.
    """
    profile = load_system_profile(args.system)
    _prepare_environment(profile)
    defaults = campaign_launch_defaults(args.campaign)

    if args.run_dir:
        target = Path(args.run_dir)
    else:
        target = Path(profile.run_root) / args.run_id
    run_dir = target.resolve()
    run_dir.mkdir(parents=True, exist_ok=True)

    metadata = _load_dashboard_metadata(run_dir)
    recorded_campaign = metadata.get("campaign")
    if recorded_campaign and recorded_campaign != args.campaign:
        raise RuntimeError(
            f"Run metadata campaign {recorded_campaign!r} does not match "
            f"--campaign {args.campaign!r}",
        )

    base_url = _resolve_lm_base_url(args=args, profile=profile, metadata=metadata)
    lm_config = _write_lm_config(
        run_dir=run_dir,
        template_name=defaults.lm_config_template,
        base_url=base_url,
        lm_model=args.lm_model,
        lm_user=args.lm_user,
        max_tokens=args.max_tokens,
    )
    _export_workflow_lm_environment(lm_config)

    # Command-line values win; falsy ones (unset or 0) fall back to defaults.
    agent_count = args.agent_count or defaults.agent_count
    agents_per_node = args.agents_per_node or defaults.agents_per_node
    max_decisions = args.max_decisions or defaults.max_decisions
    redis_port = args.redis_port or profile.redis_port

    # Prefer the packaged campaign; otherwise treat --campaign as a path.
    packaged = resolve_campaign(args.campaign)
    campaign_config = packaged if packaged.exists() else Path(args.campaign).resolve()

    settings = {
        "run_dir": run_dir,
        "run_token": _run_token(),
        "agent_count": agent_count,
        "agents_per_node": agents_per_node,
        "campaign_config": campaign_config,
        "lm_config": lm_config,
        "max_decisions": max_decisions,
        "poll_timeout_s": 2.0,
        "idle_timeout_s": 600.0,
        "startup_timeout_s": 120.0,
        "completion_timeout_s": 60.0,
        "status_interval_s": 5.0,
        "redis_host": socket.getfqdn(),
        "redis_port": redis_port,
        "redis_bind": profile.redis_bind,
        "redis_protected_mode": profile.redis_protected_mode,
        "redis_namespace": f"academy-chemgraph-swarm:{args.run_id}",
        "start_redis": not args.no_start_redis,
        "mpiexec": profile.mpiexec,
        "chemgraph_repo_root": Path(profile.repo_root).resolve(),
        "exchange_type": args.exchange_type,
    }
    return AllocationPlan(**settings)
def wait_redis(host: str, port: int, run_dir: Path, timeout_s: float = 30.0) -> None:
    """Block until Redis at ``host:port`` answers ``PING``.

    Retries once per second. After ``timeout_s`` seconds the tail of
    ``run_dir/redis.log`` (if present) is printed to stderr and the last
    connection error is re-raised.
    """
    import redis

    # Socket timeouts keep one attempt from blocking past the deadline when
    # the host silently drops packets.
    client = redis.Redis(
        host=host,
        port=port,
        socket_connect_timeout=2.0,
        socket_timeout=2.0,
    )
    # Monotonic clock: the deadline must not move if wall time is adjusted.
    deadline = time.monotonic() + timeout_s
    try:
        while True:
            try:
                client.ping()
                return
            except (redis.RedisError, OSError):
                if time.monotonic() > deadline:
                    log = run_dir / "redis.log"
                    if log.exists():
                        print(
                            log.read_text(encoding="utf-8", errors="replace")[-4000:],
                            file=sys.stderr,
                        )
                    raise
                time.sleep(1)
    finally:
        client.close()


def _start_redis_server(plan: AllocationPlan) -> subprocess.Popen[bytes]:
    """Spawn a foreground ``redis-server`` logging to ``run_dir/redis.log``."""
    redis_server = shutil.which("redis-server")
    if redis_server is None:
        raise RuntimeError("redis-server is required unless --no-start-redis is set")
    command = [
        redis_server,
        "--bind",
        plan.redis_bind,
        "--port",
        str(plan.redis_port),
        "--protected-mode",
        plan.redis_protected_mode,
        "--save",
        "",
        "--appendonly",
        "no",
        "--daemonize",
        "no",
    ]
    # The child receives its own copy of the descriptor, so this process can
    # close its handle as soon as the spawn returns.
    with (plan.run_dir / "redis.log").open("ab") as redis_log:
        return subprocess.Popen(
            command,
            stdout=redis_log,
            stderr=subprocess.STDOUT,
        )


def _mpi_daemon_command(plan: AllocationPlan) -> list[str]:
    """Build the ``mpiexec`` command line that starts one daemon per rank."""
    daemon_args = [
        "--run-dir", str(plan.run_dir),
        "--run-token", plan.run_token,
        "--agent-count", str(plan.agent_count),
        "--campaign-config", str(plan.campaign_config),
        "--lm-config", str(plan.lm_config),
        "--max-decisions", str(plan.max_decisions),
        "--poll-timeout-s", str(plan.poll_timeout_s),
        "--idle-timeout-s", str(plan.idle_timeout_s),
        "--startup-timeout-s", str(plan.startup_timeout_s),
        "--completion-timeout-s", str(plan.completion_timeout_s),
        "--status-interval-s", str(plan.status_interval_s),
        "--redis-host", plan.redis_host,
        "--redis-port", str(plan.redis_port),
        "--redis-namespace", plan.redis_namespace,
        "--exchange-type", plan.exchange_type,
        "--chemgraph-repo-root", str(plan.chemgraph_repo_root),
    ]
    return [
        plan.mpiexec,
        "-n", str(plan.agent_count),
        "--ppn", str(plan.agents_per_node),
        sys.executable, "-m", "chemgraph.cli.main", "academy", "mpi-daemon", "--",
        *daemon_args,
    ]


def run_allocation(plan: AllocationPlan) -> int:
    """Start Redis if requested and run per-rank daemons under mpiexec.

    Writes ``redis.pid`` (when Redis is started here) and
    ``launch_command.txt`` into the run directory. Returns the ``mpiexec``
    exit code. A Redis server started here is always stopped before
    returning, including when startup or the launch fails.
    """
    import shlex

    plan.run_dir.mkdir(parents=True, exist_ok=True)
    uses_redis = plan.exchange_type in {"redis", "hybrid"}
    redis_proc: subprocess.Popen[bytes] | None = None
    try:
        if plan.start_redis and uses_redis:
            redis_proc = _start_redis_server(plan)
            # Written inside the try so a failure here still stops Redis.
            (plan.run_dir / "redis.pid").write_text(
                f"{redis_proc.pid}\n",
                encoding="utf-8",
            )
        if uses_redis:
            wait_redis(plan.redis_host, plan.redis_port, plan.run_dir)
        cmd = _mpi_daemon_command(plan)
        # shlex.join keeps the recorded command copy-pasteable when a path
        # or namespace contains spaces or shell metacharacters.
        (plan.run_dir / "launch_command.txt").write_text(
            shlex.join(cmd) + "\n",
            encoding="utf-8",
        )
        return subprocess.call(cmd)
    finally:
        if redis_proc is not None:
            redis_proc.terminate()
            try:
                redis_proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                redis_proc.kill()
                redis_proc.wait()


def main(argv: list[str] | None = None) -> int:
    """Parse arguments, print a launch summary and run the allocation."""
    args = parse_args(argv)
    plan = prepare_compute_launch(args)
    print(f"ChemGraph Academy run: {args.run_id}")
    print(f" system: {load_system_profile(args.system).name}")
    print(f" campaign: {args.campaign}")
    print(f" run dir: {plan.run_dir}")
    print(f" LM config: {plan.lm_config}")
    print(f" agents: {plan.agent_count}, agents_per_node: {plan.agents_per_node}")
    return run_allocation(plan)


if __name__ == "__main__":
    raise SystemExit(main())
build_exchange_factory +from chemgraph.academy.runtime.registration import load_academy_registrations +from chemgraph.academy.runtime.registration import wait_academy_registrations +from chemgraph.academy.runtime.registration import write_academy_registrations +from chemgraph.academy.observability.run_artifacts import initialize_run_files +from chemgraph.academy.observability.run_artifacts import ( + wait_for_agent_statuses_finished, +) +from chemgraph.academy.observability.run_artifacts import write_status_snapshot +from chemgraph.academy.core.campaign import campaign_bootstrap_text +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.core.campaign import load_campaign +from chemgraph.academy.core.campaign import namespace_for_run +from chemgraph.academy.core.campaign import resolve_campaign_resources +from chemgraph.academy.core.campaign import selected_agent +from chemgraph.academy.core.campaign import validate_campaign +from chemgraph.academy.campaigns import resolve_campaign +from chemgraph.academy.runtime.mpi import append_system_trace +from chemgraph.academy.runtime.mpi import local_rank_from_env +from chemgraph.academy.runtime.mpi import placement_payload +from chemgraph.academy.runtime.mpi import rank_from_env +from chemgraph.academy.core.agent import ChemGraphLogicalAgent +from chemgraph.academy.core.prompt import load_prompt_profile +from chemgraph.academy.runtime.mcp_supervisor import MCPServerSupervisor +from chemgraph.models.settings import load_lm_settings + + +async def run_daemon(config: ChemGraphDaemonConfig) -> int: + config.run_dir.mkdir(parents=True, exist_ok=True) + llm_settings = load_lm_settings(config.lm_config) + campaign = resolve_campaign_resources( + load_campaign(config.campaign_config), + config.run_dir, + ) + prompt_profile = load_prompt_profile(campaign.prompt_profile) + validate_campaign(campaign, config.agent_count) + agent_spec = selected_agent(campaign, config.rank) + placement = 
placement_payload(config, agent_spec.name) + supervisor = MCPServerSupervisor( + specs=[ + spec + for spec in campaign.mcp_servers + if spec.name in agent_spec.mcp_servers + ], + run_dir=config.run_dir / f'rank{config.rank}', + ) + + try: + await supervisor.start_all() + external_tools = await supervisor.get_tools( + agent_spec.mcp_servers, + allowed_tools=frozenset(agent_spec.allowed_tools) + if agent_spec.allowed_tools + else None, + ) + + academy_factory = build_exchange_factory(config) + if config.rank == 0: + initialize_run_files( + run_dir=config.run_dir, + campaign=campaign, + config=config, + llm_settings=llm_settings, + ) + registrar = await academy_factory.create_user_client( + name=f'{config.run_dir.name}-registrar', + start_listener=False, + ) + try: + registered = await registrar.register_agents( + [ + (ChemGraphLogicalAgent, spec.name) + for spec in campaign.agents + ], + ) + finally: + await registrar.close() + registrations = dict( + zip( + (spec.name for spec in campaign.agents), + registered, + strict=True, + ), + ) + write_academy_registrations( + run_dir=config.run_dir, + run_token=config.run_token, + registrations=registrations, + ) + else: + registrations = await wait_academy_registrations( + config.run_dir, + run_token=config.run_token, + timeout_s=config.startup_timeout_s, + ) + + if config.rank == 0: + registrations = load_academy_registrations( + config.run_dir, + run_token=config.run_token, + ) + registration = registrations[agent_spec.name] + peer_agent_ids = { + peer: registrations[peer].agent_id + for peer in agent_spec.allowed_peers + if peer in registrations + } + + agent = ChemGraphLogicalAgent( + agent_spec, + campaign=campaign, + llm_settings=llm_settings, + prompt_profile=prompt_profile, + run_dir=config.run_dir, + max_decisions=config.max_decisions, + external_tools=external_tools, + peer_agent_ids=peer_agent_ids, + placement=placement, + poll_timeout_s=config.poll_timeout_s, + idle_timeout_s=config.idle_timeout_s, + 
status_interval_s=config.status_interval_s, + ) + runtime_config = RuntimeConfig( + terminate_on_success=False, + terminate_on_error=False, + ) + runtime = Runtime( + agent, + exchange_factory=academy_factory, + registration=registration, + config=runtime_config, + ) + async with runtime: + await agent.write_runtime_status() + + if config.rank == 0: + bootstrap = build_message( + sender='campaign', + recipient=campaign.initial_agent, + content=campaign_bootstrap_text(campaign), + kind='message', + tldr='Campaign bootstrap', + reason='Initial campaign task dispatch.', + confidence=1.0, + ) + initial_handle: Handle[Any] = Handle( + registrations[campaign.initial_agent].agent_id, + ) + await initial_handle.action( + 'receive_message', + bootstrap, + ) + append_system_trace( + config.run_dir, + 'bootstrap_message_dispatched', + { + 'agent': campaign.initial_agent, + 'message_id': bootstrap['message_id'], + 'via': 'academy_action', + }, + ) + + await runtime.wait_shutdown() + + if config.rank == 0: + all_done = await wait_for_agent_statuses_finished( + run_dir=config.run_dir, + campaign=campaign, + timeout_s=config.completion_timeout_s, + ) + append_system_trace( + config.run_dir, + 'campaign_finished', + {'all_agents_done': all_done}, + ) + write_status_snapshot( + run_dir=config.run_dir, + campaign=campaign, + agent_state=await agent.report_state(), + placement=placement, + ) + return 0 + finally: + await supervisor.shutdown() + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description='Run one persistent ChemGraph-style agent daemon.', + ) + parser.add_argument('--run-dir', required=True) + parser.add_argument('--run-token', required=True) + parser.add_argument('--agent-count', type=int, default=5) + parser.add_argument('--campaign-config', required=True) + parser.add_argument('--lm-config', required=True) + parser.add_argument('--max-decisions', type=int, default=6) + parser.add_argument('--poll-timeout-s', type=float, default=2) + 
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the per-rank daemon command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is what existing callers rely on.
    """
    parser = argparse.ArgumentParser(
        description='Run one persistent ChemGraph-style agent daemon.',
    )
    parser.add_argument('--run-dir', required=True)
    parser.add_argument('--run-token', required=True)
    parser.add_argument('--agent-count', type=int, default=5)
    parser.add_argument('--campaign-config', required=True)
    parser.add_argument('--lm-config', required=True)
    parser.add_argument('--max-decisions', type=int, default=6)
    parser.add_argument('--poll-timeout-s', type=float, default=2)
    parser.add_argument('--idle-timeout-s', type=float, default=600)
    parser.add_argument('--startup-timeout-s', type=float, default=120)
    parser.add_argument('--completion-timeout-s', type=float, default=60)
    parser.add_argument('--status-interval-s', type=float, default=5)
    parser.add_argument('--redis-host', default='127.0.0.1')
    parser.add_argument('--redis-port', type=int, required=True)
    parser.add_argument('--redis-namespace')
    parser.add_argument(
        '--exchange-type',
        choices=('redis', 'local', 'hybrid'),
        default='redis',
    )
    parser.add_argument('--chemgraph-repo-root')
    return parser.parse_args(argv)


def config_from_args(args: argparse.Namespace) -> ChemGraphDaemonConfig:
    """Turn parsed arguments into a ``ChemGraphDaemonConfig``.

    Paths are resolved to absolute form. ``--campaign-config`` is first
    looked up through ``resolve_campaign`` and, if that result does not
    exist, used as a filesystem path. The rank and local rank come from the
    environment helpers, not from the command line.
    """
    run_dir = pathlib.Path(args.run_dir).resolve()
    resolved_campaign = resolve_campaign(args.campaign_config)
    campaign_config = (
        resolved_campaign.resolve()
        if resolved_campaign.exists()
        else pathlib.Path(args.campaign_config).resolve()
    )
    return ChemGraphDaemonConfig(
        run_dir=run_dir,
        run_token=args.run_token,
        agent_count=args.agent_count,
        campaign_config=campaign_config,
        lm_config=pathlib.Path(args.lm_config).resolve(),
        max_decisions=args.max_decisions,
        poll_timeout_s=args.poll_timeout_s,
        idle_timeout_s=args.idle_timeout_s,
        startup_timeout_s=args.startup_timeout_s,
        completion_timeout_s=args.completion_timeout_s,
        status_interval_s=args.status_interval_s,
        redis_host=args.redis_host,
        redis_port=args.redis_port,
        redis_namespace=args.redis_namespace or namespace_for_run(run_dir),
        exchange_type=args.exchange_type,
        rank=rank_from_env(),
        local_rank=local_rank_from_env(),
        chemgraph_repo_root=(
            pathlib.Path(args.chemgraph_repo_root).resolve()
            if args.chemgraph_repo_root
            else pathlib.Path.cwd().resolve()
        ),
    )


async def _main_async() -> int:
    """Run the daemon as a task that SIGINT or SIGTERM cancels.

    Returns the daemon's exit code, or 130 when the task was cancelled.
    """
    task = asyncio.create_task(run_daemon(config_from_args(parse_args())))
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, task.cancel)
        except (NotImplementedError, RuntimeError):
            # Best effort: if a handler cannot be installed, run without it.
            pass
    try:
        return await task
    except asyncio.CancelledError:
        return 130


def main() -> int:
    """Synchronous entry point: run ``_main_async`` on a new event loop."""
    return asyncio.run(_main_async())


if __name__ == '__main__':
    raise SystemExit(main())
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the dashboard launcher command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``.
    """
    p = argparse.ArgumentParser(prog="chemgraph academy dashboard")
    a = p.add_argument
    a("run_id")
    a("--system", default="aurora", help="Built-ins: " + ", ".join(list_builtin_system_profiles()))
    a("--campaign", default="mace-ensemble-screening-20")
    a("--lm-connect", choices=("mac-argo-relay", "direct"), default="mac-argo-relay")
    a("--lm-base-url")
    a("--remote-host")
    a("--ssh-control-path")
    a("--keep-ssh-master", action="store_true")
    a("--local-argo-host", default="127.0.0.1")
    a("--local-argo-port", type=int, default=18085)
    a("--reverse-port", type=int, default=18185)
    a("--relay-port", type=int)
    a("--relay-python")
    a("--rsync-interval-s", type=float, default=2.0)
    a("--local-mirror-root", default=str(Path.home() / "projects/chemgraph-academy/remote-runs"))
    a("--local-run-dir")
    a("--dashboard-host", default="127.0.0.1")
    a("--dashboard-port", type=int, default=8765)
    a("--local", action="store_true", help="Only serve an already mirrored local run.")
    a("--no-dashboard", action="store_true")
    a("--overwrite-run", action="store_true")
    return p.parse_args(argv)


def template(name: str) -> str:
    """Return the text of a file bundled in ``chemgraph.academy.runtime.templates``."""
    # Decode explicitly as UTF-8 so the result does not depend on the
    # locale's default encoding.
    resource = files("chemgraph.academy.runtime.templates").joinpath(name)
    return resource.read_text(encoding="utf-8")


REMOTE_RELAY_SUBPATH = ".chemgraph/uan_http_relay.py"


def stage_relay_script(profile: SystemProfile, host: str, control_path: str) -> str:
    """Copy the bundled UAN relay script to the remote host.

    The relay script is shipped inside the chemgraph package so we no longer
    require a separate ``academy`` source checkout on the remote system.
    We materialize it under ``$REMOTE_ROOT/.chemgraph/uan_http_relay.py``
    on every dashboard launch (idempotent overwrite), then return that
    absolute path for the start_relay shell template to reference.
    """
    relay_dir = f"{profile.remote_root}/.chemgraph"
    relay_path = f"{relay_dir}/uan_http_relay.py"
    contents = template("uan_http_relay.py")
    cmd = (
        f"mkdir -p {shlex.quote(relay_dir)} && "
        f"cat > {shlex.quote(relay_path)}"
    )
    ssh(host, cmd, control_path=control_path, input_text=contents)
    return relay_path


def ssh(
    host: str,
    command: str | list[str] | None,
    *,
    control_path: str,
    input_text: str | None = None,
    check: bool = True,
    capture: bool = False,
    batch_mode: bool = True,
    extra: list[str] | None = None,
) -> subprocess.CompletedProcess[str]:
    """Run ``ssh`` against ``host`` over a shared control connection.

    Args:
        host: SSH destination.
        command: Remote command as a string or argument list. ``None`` or
            empty sends no command (used for control operations in ``extra``).
        control_path: Value for the ``ControlPath`` option.
        input_text: Text fed to the remote command's stdin.
        check: Raise ``subprocess.CalledProcessError`` on a non-zero exit.
        capture: Capture stdout and stderr instead of inheriting them.
        batch_mode: Pass ``BatchMode=yes``.
        extra: Additional ssh arguments placed before the host.
    """
    cmd = ["ssh"]
    if batch_mode:
        cmd += ["-o", "BatchMode=yes"]
    cmd += [
        "-o", f"ControlPath={control_path}",
        "-o", "ControlMaster=auto",
        "-o", "ControlPersist=yes",
        "-o", "ServerAliveInterval=30",
        "-o", "ServerAliveCountMax=4",
    ]
    cmd += extra or []
    cmd.append(host)
    cmd += command if isinstance(command, list) else ([command] if command else [])
    return subprocess.run(
        cmd,
        input=input_text,
        text=True,
        check=check,
        stdout=subprocess.PIPE if capture else None,
        stderr=subprocess.PIPE if capture else None,
    )
def wrapper(profile: SystemProfile) -> str:
    """Render the compute wrapper script for ``profile``.

    Fills the ``%{path_prefix}%``, ``%{pythonpath}%`` and ``%{venv_python}%``
    placeholders of ``compute_wrapper.sh.tmpl``.
    """
    return (
        template("compute_wrapper.sh.tmpl")
        .replace("%{path_prefix}%", ":".join([profile.redis_bin_dir, f"{profile.remote_root}/bin"]))
        .replace("%{pythonpath}%", ":".join(profile.pythonpath_entries))
        .replace("%{venv_python}%", profile.venv_python)
    )


def start_relay(
    profile: SystemProfile,
    host: str,
    control_path: str,
    args: argparse.Namespace,
    relay_port: int,
    relay_python: str,
    log_path: Path,
    relay_script: str,
) -> subprocess.Popen[str]:
    """Start the remote relay over an SSH session with a reverse tunnel.

    Opens ``ssh -R`` from the remote ``args.reverse_port`` to the local Argo
    endpoint and feeds the ``start_relay.sh`` template to a remote
    ``bash -s``. Output of the SSH session goes to ``log_path``.

    Returns:
        The running SSH process. The caller is responsible for stopping it.
    """
    relay_args = [
        "bash", "-s", "--",
        profile.remote_root,
        relay_script,
        profile.relay_host_file,
        f"{profile.remote_root}/uan-relay-{relay_port}.pid",
        f"{profile.remote_root}/uan-relay-{relay_port}.log",
        str(relay_port),
        str(args.reverse_port),
        relay_python,
    ]
    log_path.parent.mkdir(parents=True, exist_ok=True)
    cmd = [
        "ssh",
        "-o", "BatchMode=yes",
        "-o", f"ControlPath={control_path}",
        "-o", "ControlMaster=auto",
        "-o", "ControlPersist=yes",
        "-o", "ServerAliveInterval=30",
        "-o", "ServerAliveCountMax=4",
        "-R", f"127.0.0.1:{args.reverse_port}:{args.local_argo_host}:{args.local_argo_port}",
        host,
        *relay_args,
    ]
    # The child gets its own copy of the log descriptor, so the handle
    # opened here is closed as soon as the spawn returns.
    with log_path.open("w", encoding="utf-8") as log_file:
        process = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            text=True,
        )
    assert process.stdin is not None
    try:
        process.stdin.write(template("start_relay.sh"))
        process.stdin.close()
    except OSError:
        # ssh went away before reading the script: reap it, then report.
        process.kill()
        process.wait()
        raise
    return process
def wait_relay(
    profile: SystemProfile,
    host: str,
    control_path: str,
    relay_port: int,
    process: subprocess.Popen[str],
    log_path: Path,
) -> str:
    """Wait until the remote relay answers and return its host name.

    Polls up to 60 times, one second apart. Each poll runs a remote check
    that reads the relay host file and requests ``/v1/models`` from the relay.

    Raises:
        RuntimeError: If the relay SSH session exits first or the polls run
            out. The message includes the local relay log.
    """
    print("Waiting for relay readiness...", flush=True)
    check = f"host=$(cat {shlex.quote(profile.relay_host_file)} 2>/dev/null || true); test -n \"$host\" && curl -fsS \"http://${{host}}:{relay_port}/v1/models\" >/dev/null"
    for _ in range(60):
        if ssh(host, check, control_path=control_path, check=False).returncode == 0:
            relay_host = ssh(host, ["cat", profile.relay_host_file], control_path=control_path, capture=True).stdout.strip()
            print(f"{profile.name} relay host: {relay_host}", flush=True)
            return relay_host
        if process.poll() is not None:
            raise RuntimeError("Relay SSH session exited before readiness. Local relay log:\n" + log_path.read_text(encoding="utf-8", errors="replace"))
        time.sleep(1)
    raise RuntimeError("Relay readiness timed out. Local relay log:\n" + log_path.read_text(encoding="utf-8", errors="replace"))


def start_rsync(
    host: str,
    control_path: str,
    remote_run_dir: str,
    local_run_dir: Path,
    interval_s: float,
    stop: threading.Event,
) -> None:
    """Mirror ``remote_run_dir`` into ``local_run_dir`` in the background.

    A daemon thread feeds the ``rsync_loop.sh`` template to ``bash -s`` in its
    own session, waits for ``stop`` to be set, then signals that process
    group (SIGTERM, and SIGKILL if it has not exited after five seconds).
    """
    local_run_dir.mkdir(parents=True, exist_ok=True)
    rsync_args = [host, control_path, remote_run_dir, str(local_run_dir), str(interval_s), str(local_run_dir / "rsync.log")]

    def loop() -> None:
        process = subprocess.Popen(["bash", "-s", "--", *rsync_args], stdin=subprocess.PIPE, text=True, start_new_session=True)
        assert process.stdin is not None
        process.stdin.write(template("rsync_loop.sh"))
        process.stdin.close()
        stop.wait()
        if process.poll() is None:
            # The group can exit between poll() and the signal; a vanished
            # group means there is nothing left to stop.
            try:
                os.killpg(process.pid, signal.SIGTERM)
            except ProcessLookupError:
                return
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                try:
                    os.killpg(process.pid, signal.SIGKILL)
                except ProcessLookupError:
                    pass

    threading.Thread(target=loop, name="chemgraph-academy-rsync", daemon=True).start()


def compute_lines(profile: SystemProfile, wrapper_path: str, run_id: str, campaign: str) -> list[str]:
    """Return the shell lines a user runs on the compute node.

    The first lines load modules (a conda set for the ``polaris`` profile,
    ``frameworks`` otherwise). They are followed by the ``run-compute``
    command and an alternative that calls the wrapper script directly.
    """
    lines = [" module use /soft/modulefiles", " module load conda", " conda activate base"] if profile.name == "polaris" else [" module load frameworks"]
    return lines + [
        f" source {profile.remote_root}/venvs/academy-swarm/bin/activate",
        f" export PATH={profile.remote_root}/bin:$PATH",
        " chemgraph academy run-compute \\",
        f" --system {profile.name} \\",
        f" --run-id {run_id} \\",
        f" --campaign {campaign}",
        "",
        "If PATH is not configured, use:",
        f" {wrapper_path} \\",
        f" --system {profile.name} \\",
        f" --run-id {run_id} \\",
        f" --campaign {campaign}",
    ]


def main() -> int:
    """Prepare a remote run and serve the local dashboard.

    With ``--local`` only an already mirrored run is served. Otherwise the
    launcher ensures an SSH control master, optionally clears old run
    artifacts, installs the compute wrapper, starts the relay when
    ``--lm-connect mac-argo-relay`` is used, writes the run metadata, starts
    the rsync mirror and finally serves the dashboard. The rsync loop, the
    relay tunnel and a control master started here are torn down on exit.
    """
    args = parse_args()
    profile = load_system_profile(args.system)
    # Called for its validation side effect; the defaults are not used here.
    campaign_launch_defaults(args.campaign)
    local_run_dir = Path(args.local_run_dir or Path(args.local_mirror_root) / args.run_id).expanduser()
    local_run_dir.mkdir(parents=True, exist_ok=True)
    if args.local:
        if args.overwrite_run:
            raise RuntimeError("--overwrite-run cannot be used with --local")
        return 0 if args.no_dashboard else serve_dashboard(run_dir=local_run_dir, host=args.dashboard_host, port=args.dashboard_port)
    if args.lm_connect == "direct" and not args.lm_base_url:
        raise RuntimeError("--lm-connect direct requires --lm-base-url")
    if args.lm_connect == "mac-argo-relay":
        try:
            with urllib.request.urlopen(f"http://{args.local_argo_host}:{args.local_argo_port}/v1/models", timeout=5) as response:
                if int(response.status) >= 300:
                    raise OSError
        except (OSError, urllib.error.URLError, urllib.error.HTTPError) as exc:
            raise RuntimeError("Local argo-shim is not reachable. Start it before using --lm-connect mac-argo-relay.") from exc

    remote_host = args.remote_host or profile.remote_host
    control_path = args.ssh_control_path or str(Path.home() / f".ssh/{profile.name}-dashboard-%r@%h:%p")
    relay_port = args.relay_port or profile.relay_port
    remote_run_dir = f"{profile.run_root}/{args.run_id}"
    relay_process: subprocess.Popen[str] | None = None
    stop = threading.Event()
    started_master = False
    try:
        Path(control_path).expanduser().parent.mkdir(parents=True, exist_ok=True)
        if ssh(remote_host, None, control_path=control_path, extra=["-O", "check"], check=False, batch_mode=False).returncode != 0:
            print(f"Starting SSH ControlMaster for {remote_host}...", flush=True)
            ssh(remote_host, None, control_path=control_path, extra=["-M", "-N", "-f", "-o", "ControlMaster=yes"], batch_mode=False)
            started_master = True
        if args.overwrite_run:
            if not args.run_id or "/" in args.run_id or args.run_id in {".", ".."}:
                raise RuntimeError(f"Refusing to overwrite unsafe run id: {args.run_id!r}")
            print("Deleting existing run artifacts because --overwrite-run was set:", flush=True)
            print(f" remote: {remote_host}:{remote_run_dir}", flush=True)
            print(f" local: {local_run_dir}", flush=True)
            # Remote side: move the run aside, then retry the removal with
            # increasing delays before recreating an empty run directory.
            delete = f"set -euo pipefail; run_root={shlex.quote(profile.run_root)}; run_id={shlex.quote(args.run_id)}; case \"$run_id\" in \"\"|.|..|*/*) echo \"unsafe run id\" >&2; exit 2;; esac; run_dir=\"$run_root/$run_id\"; trash_root=\"$run_root/.deleted-runs\"; if [ -e \"$run_dir\" ]; then mkdir -p \"$trash_root\"; trash_dir=\"$trash_root/${{run_id}}.$(date +%Y%m%d%H%M%S).$$\"; mv -- \"$run_dir\" \"$trash_dir\"; for delay in 0 1 2 5 10; do sleep \"$delay\"; if rm -rf -- \"$trash_dir\" 2>/dev/null; then break; fi; done; fi; mkdir -p \"$run_dir\""
            ssh(remote_host, delete, control_path=control_path)
            if local_run_dir.exists():
                shutil.rmtree(local_run_dir)
        wrapper_path = f"{profile.remote_root}/bin/chemgraph-academy-run"
        print(f"Installing compute wrapper at {wrapper_path}...", flush=True)
        ssh(remote_host, f"mkdir -p {shlex.quote(profile.remote_root + '/bin')} && cat > {shlex.quote(wrapper_path)} && chmod +x {shlex.quote(wrapper_path)}", control_path=control_path, input_text=wrapper(profile))
        relay_host = None
        if args.lm_connect == "mac-argo-relay":
            print(f"Staging UAN relay script under {profile.remote_root}/{REMOTE_RELAY_SUBPATH}...", flush=True)
            relay_script = stage_relay_script(profile, remote_host, control_path)
            print(f"Starting {profile.name} UAN relay through {remote_host}...", flush=True)
            # Built once so the start and wait steps cannot disagree on the log.
            relay_log = Path(f"/tmp/chemgraph-academy-{args.run_id}-relay.log")
            relay_process = start_relay(profile, remote_host, control_path, args, relay_port, args.relay_python or profile.venv_python, relay_log, relay_script)
            relay_host = wait_relay(profile, remote_host, control_path, relay_port, relay_process, relay_log)
        lm_base_url = f"http://{relay_host}:{relay_port}/argoapi/v1" if relay_host else str(args.lm_base_url)
        print(f"Compute-node LM URL: {lm_base_url}", flush=True)
        metadata = {
            "created_at": time.time(),
            "created_by": "chemgraph academy dashboard",
            "run_id": args.run_id,
            "system": profile.name,
            "campaign": args.campaign,
            "remote_run_dir": remote_run_dir,
            "remote_host": remote_host,
            "lm_connect": args.lm_connect,
            "lm_base_url": lm_base_url,
            "workspace_root": profile.remote_root,
            "chemgraph_repo_root": profile.repo_root,
        }
        if relay_host:
            metadata.update({"relay_host": relay_host, "relay_port": relay_port})
        print(f"Writing run metadata: {remote_host}:{remote_run_dir}/dashboard_metadata.json", flush=True)
        ssh(remote_host, f"mkdir -p {shlex.quote(remote_run_dir)} && cat > {shlex.quote(remote_run_dir + '/dashboard_metadata.json')}", control_path=control_path, input_text=json.dumps(metadata, indent=2) + "\n")
        print("Starting rsync mirror:", flush=True)
        print(f" {remote_host}:{remote_run_dir}/", flush=True)
        print(f" {local_run_dir}/", flush=True)
        start_rsync(remote_host, control_path, remote_run_dir, local_run_dir, args.rsync_interval_s, stop)
        print("\nDashboard launcher is ready.\n", flush=True)
        print(f"On the {profile.name} compute node, use:", flush=True)
        print("\n".join(compute_lines(profile, wrapper_path, args.run_id, args.campaign)), flush=True)
        if args.no_dashboard:
            return 0
        print(f"\nStarting dashboard at http://{args.dashboard_host}:{args.dashboard_port}", flush=True)
        print("Ctrl-C stops the local dashboard, rsync loop, and relay tunnel.", flush=True)
        return serve_dashboard(run_dir=local_run_dir, host=args.dashboard_host, port=args.dashboard_port)
    finally:
        stop.set()
        if relay_process is not None and relay_process.poll() is None:
            relay_process.terminate()
            try:
                relay_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                relay_process.kill()
        if started_master and not args.keep_ssh_master:
            ssh(remote_host, None, control_path=control_path, extra=["-O", "exit"], check=False, batch_mode=False)


if __name__ == "__main__":
    raise SystemExit(main())
def build_exchange_factory(config: ChemGraphDaemonConfig) -> Any:
    """Create the Academy exchange factory selected by ``config.exchange_type``.

    The factory module for the chosen type is imported on demand, so only
    the selected backend is loaded.

    Raises:
        ValueError: If the exchange type is not ``'redis'``, ``'local'`` or
            ``'hybrid'``.
    """
    kind = config.exchange_type

    # Reject unknown types up front; each remaining case returns directly.
    if kind not in ('redis', 'local', 'hybrid'):
        raise ValueError(
            f"Unsupported exchange type {kind!r}; expected one of "
            "'redis', 'local', 'hybrid'.",
        )

    if kind == 'local':
        from academy.exchange.local import LocalExchangeFactory

        return LocalExchangeFactory()

    if kind == 'hybrid':
        from academy.exchange.hybrid import HybridExchangeFactory

        return HybridExchangeFactory(
            redis_host=config.redis_host,
            redis_port=config.redis_port,
            namespace=config.redis_namespace,
        )

    from academy.exchange.redis import RedisExchangeFactory

    return RedisExchangeFactory(
        hostname=config.redis_host,
        port=config.redis_port,
    )
logger = logging.getLogger(__name__)

_READINESS_TIMEOUT_S = 30.0
_READINESS_POLL_INTERVAL_S = 0.25
_SHUTDOWN_TIMEOUT_S = 5.0


def _pick_free_port() -> int:
    """Return a TCP port that was free on 127.0.0.1 when probed.

    The probe socket is closed before returning, so the port is not
    reserved for the caller.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return int(sock.getsockname()[1])


class MCPServerSupervisor:
    """Per-rank MCP subprocess lifecycle and client wiring."""

    def __init__(self, specs: list[MCPServerSpec], run_dir: Path) -> None:
        self._specs = list(specs)
        self._run_dir = Path(run_dir)
        self._log_dir = self._run_dir / "mcp_logs"
        self._processes: dict[str, subprocess.Popen[bytes]] = {}
        self._log_handles: dict[str, object] = {}
        self._urls: dict[str, str] = {}

    @property
    def urls(self) -> dict[str, str]:
        """Copy of the server-name to URL mapping of started servers."""
        return dict(self._urls)

    async def start_all(self) -> dict[str, str]:
        """Spawn every configured server and wait until all respond.

        Each server is started on its own loopback port with the
        ``streamable_http`` transport, in a new session, with output
        appended to ``mcp_logs/<name>.log``.

        Returns:
            Mapping of server name to URL (empty when there are no specs).

        Raises:
            RuntimeError: If a server exits or does not respond in time.
            OSError: If a server command cannot be launched.
        """
        if not self._specs:
            return {}
        self._log_dir.mkdir(parents=True, exist_ok=True)
        for spec in self._specs:
            port = _pick_free_port()
            url = f"http://127.0.0.1:{port}/mcp/"
            cmd = shlex.split(spec.command) + [
                "--transport",
                "streamable_http",
                "--host",
                "127.0.0.1",
                "--port",
                str(port),
            ]
            env = {**os.environ, **spec.env}
            log_path = self._log_dir / f"{spec.name}.log"
            log_handle = log_path.open("ab")
            logger.info(
                "spawning MCP server %s on port %d: %s",
                spec.name,
                port,
                " ".join(cmd),
            )
            try:
                proc = subprocess.Popen(
                    cmd,
                    stdout=log_handle,
                    stderr=subprocess.STDOUT,
                    env=env,
                    start_new_session=True,
                )
            except Exception:
                # The handle is not registered yet, so shutdown() would
                # never close it; release it before propagating.
                log_handle.close()
                raise
            self._processes[spec.name] = proc
            self._log_handles[spec.name] = log_handle
            self._urls[spec.name] = url
        await self._await_all_ready()
        return dict(self._urls)

    async def get_tools(
        self,
        server_names: tuple[str, ...] | None = None,
        allowed_tools: frozenset[str] | None = None,
    ) -> list[BaseTool]:
        """Return LangChain tools advertised by the requested MCP servers.

        Parameters
        ----------
        server_names
            Optional subset of supervised servers to query. Defaults to all.
        allowed_tools
            Optional per-agent tool-name whitelist. When provided, tools
            advertised by the connected servers but whose name is not in the
            set are filtered out. When ``None`` (or empty), every tool the
            servers advertise is returned (legacy behavior).

        Raises
        ------
        RuntimeError
            If an unknown server is requested or two servers advertise a
            tool with the same name.
        """
        if not self._urls:
            return []
        wanted = tuple(server_names) if server_names else tuple(self._urls)
        unknown = sorted(set(wanted) - set(self._urls))
        if unknown:
            raise RuntimeError(
                f"agent requested unknown MCP servers: {unknown}; "
                f"available: {sorted(self._urls)}",
            )
        whitelist = frozenset(allowed_tools) if allowed_tools else None
        connections = {
            name: self._urls[name]
            for name in wanted
        }
        tools: list[BaseTool] = []
        tool_names: set[str] = set()
        matched_whitelist: set[str] = set()
        for server_name, url in connections.items():
            async with streamablehttp_client(url) as (read, write, _):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    listed = await session.list_tools()
                    for mcp_tool in listed.tools:
                        # Duplicates are rejected before whitelist filtering.
                        if mcp_tool.name in tool_names:
                            raise RuntimeError(
                                f"duplicate MCP tool name {mcp_tool.name!r} "
                                f"from server {server_name!r}",
                            )
                        tool_names.add(mcp_tool.name)
                        if whitelist is not None:
                            if mcp_tool.name not in whitelist:
                                continue
                            matched_whitelist.add(mcp_tool.name)
                        tools.append(
                            _langchain_tool(
                                server_name=server_name,
                                server_url=url,
                                tool_name=mcp_tool.name,
                                description=mcp_tool.description
                                or f"MCP tool {mcp_tool.name}.",
                                input_schema=mcp_tool.inputSchema,
                            ),
                        )
        if whitelist is not None:
            missing = sorted(whitelist - matched_whitelist)
            if missing:
                logger.warning(
                    "allowed_tools whitelist references tools not advertised "
                    "by the connected MCP servers; they will be silently "
                    "absent from the agent: %s",
                    missing,
                )
        return tools

    async def shutdown(self) -> None:
        """Stop all servers, close their logs and forget their state.

        Running servers are sent ``terminate()`` first. All of them share one
        ``_SHUTDOWN_TIMEOUT_S`` deadline, after which stragglers are killed.
        The waits are synchronous, so this coroutine does not yield while
        servers are exiting.
        """
        for name, proc in list(self._processes.items()):
            if proc.poll() is not None:
                continue
            with contextlib.suppress(ProcessLookupError):
                proc.terminate()

        deadline = time.monotonic() + _SHUTDOWN_TIMEOUT_S
        for name, proc in list(self._processes.items()):
            remaining = max(0.0, deadline - time.monotonic())
            try:
                proc.wait(timeout=remaining)
            except subprocess.TimeoutExpired:
                logger.warning("MCP server %s did not exit; killing", name)
                with contextlib.suppress(ProcessLookupError):
                    proc.kill()
                with contextlib.suppress(subprocess.TimeoutExpired):
                    proc.wait(timeout=2)

        for handle in self._log_handles.values():
            with contextlib.suppress(Exception):
                handle.close()
        self._processes.clear()
        self._log_handles.clear()
        self._urls.clear()

    async def _await_all_ready(self) -> None:
        """Poll each server URL until it responds or the timeout expires.

        Any HTTP response with a status below 500 counts as ready.

        Raises:
            RuntimeError: If a server process exits first, or some servers
                are still pending after ``_READINESS_TIMEOUT_S``. The
                message includes the tail of the affected logs.
        """
        deadline = time.monotonic() + _READINESS_TIMEOUT_S
        pending = dict(self._urls)
        async with httpx.AsyncClient(timeout=2.0) as client:
            while pending and time.monotonic() < deadline:
                ready_now: list[str] = []
                for name, url in pending.items():
                    proc = self._processes[name]
                    if proc.poll() is not None:
                        log_tail = self._tail_log(name)
                        raise RuntimeError(
                            f"MCP server {name!r} exited before readiness "
                            f"(returncode={proc.returncode}). Last log lines:\n"
                            f"{log_tail}",
                        )
                    try:
                        response = await client.get(url)
                        if response.status_code < 500:
                            ready_now.append(name)
                    except httpx.RequestError:
                        # Not accepting connections yet; try again next round.
                        pass
                for name in ready_now:
                    logger.info("MCP server %s ready at %s", name, pending[name])
                    pending.pop(name)
                if pending:
                    await asyncio.sleep(_READINESS_POLL_INTERVAL_S)
        if pending:
            stuck = sorted(pending)
            tails = "\n".join(
                f"=== {name} ===\n{self._tail_log(name)}"
                for name in stuck
            )
            raise RuntimeError(
                f"MCP servers did not become ready within "
                f"{_READINESS_TIMEOUT_S:.0f}s: {stuck}\n{tails}",
            )

    def _tail_log(self, name: str, n: int = 40) -> str:
        """Return the last ``n`` lines of a server's log, or a placeholder."""
        path = self._log_dir / f"{name}.log"
        if not path.exists():
            return "(no log file)"
        try:
            text = path.read_text(encoding="utf-8", errors="replace")
        except OSError:
            return "(log unreadable)"
        return "\n".join(text.splitlines()[-n:])


def _langchain_tool(
    *,
    server_name: str,
    server_url: str,
    tool_name: str,
    description: str,
    input_schema: dict[str, Any],
) -> BaseTool:
    """Wrap one MCP tool as a LangChain ``StructuredTool``.

    Calling the returned tool forwards its keyword arguments to the MCP
    server at ``server_url`` through ``_call_mcp_tool``.
    """
    async def call_mcp_tool(**kwargs: Any) -> Any:
        return await _call_mcp_tool(
            server_url=server_url,
            tool_name=tool_name,
            arguments=kwargs,
        )

    call_mcp_tool.__name__ = f"{server_name}_{tool_name}"
    return StructuredTool.from_function(
        coroutine=call_mcp_tool,
        name=tool_name,
        description=description,
        args_schema=input_schema,
        metadata={
            "chemgraph_academy_tool_kind": "science_tool",
            "mcp_server": server_name,
        },
    )


async def _call_mcp_tool(
    *,
    server_url: str,
    tool_name: str,
    arguments: dict[str, Any],
) -> Any:
    """Open a new MCP session, call ``tool_name`` and serialize the result."""
    async with streamablehttp_client(server_url) as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool(tool_name, arguments)
            return _serialize_call_tool_result(result)
def _json_safe(value: Any) -> Any:
    """Recursively reduce ``value`` to plain JSON-compatible Python data.

    Objects exposing ``model_dump`` are dumped with ``mode="json"``.  Dicts
    are rebuilt with stringified keys, lists and tuples become lists, and
    ``str``/``int``/``float``/``bool``/``None`` pass through unchanged.
    Anything else is replaced by its ``repr``.
    """
    if hasattr(value, "model_dump"):
        return value.model_dump(mode="json")

    if isinstance(value, dict):
        converted: dict[str, Any] = {}
        for key, item in value.items():
            converted[str(key)] = _json_safe(item)
        return converted

    if isinstance(value, (list, tuple)):
        return list(map(_json_safe, value))

    if value is None or isinstance(value, (str, int, float, bool)):
        return value

    return repr(value)
def placement_payload(config: Any, agent_name: str) -> dict[str, Any]:
    """Describe where this agent process is running, for run diagnostics.

    Parameters
    ----------
    config
        Object providing ``rank``, ``local_rank``, ``exchange_type``,
        ``redis_host``, ``redis_port`` and ``redis_namespace`` attributes.
    agent_name
        Name recorded under ``agent_name`` in the payload.

    Returns
    -------
    dict[str, Any]
        Host, process and exchange details, the subset of PBS/MPI environment
        variables that are set, and the non-empty lines of ``PBS_NODEFILE``.
        The node list is empty when the variable is unset or the file cannot
        be read.
    """
    host = socket.gethostname()
    pbs_keys = (
        'PBS_JOBID',
        'PBS_NODEFILE',
        'PBS_O_WORKDIR',
        'PBS_NCPUS',
        'PBS_NUM_NODES',
        'PBS_TASKNUM',
    )
    tracked_keys = (*pbs_keys, *MPI_RANK_ENV, *MPI_LOCAL_RANK_ENV)
    env = {key: os.environ[key] for key in tracked_keys if key in os.environ}

    nodes: list[str] = []
    nodefile = os.environ.get('PBS_NODEFILE')
    if nodefile:
        # Read directly instead of exists()-then-read: the payload is purely
        # diagnostic, so a missing or unreadable nodefile must not abort it.
        try:
            nodefile_text = pathlib.Path(nodefile).read_text()
        except OSError:
            nodefile_text = ''
        nodes = [
            line.strip()
            for line in nodefile_text.splitlines()
            if line.strip()
        ]

    return {
        'agent_name': agent_name,
        'hostname': host,
        'short_hostname': host.split('.')[0],
        'pid': os.getpid(),
        'cwd': os.getcwd(),
        'python_executable': sys.executable,
        'rank': config.rank,
        'local_rank': config.local_rank,
        'exchange_type': config.exchange_type,
        'redis_host': config.redis_host,
        'redis_port': config.redis_port,
        'redis_namespace': config.redis_namespace,
        'env': env,
        'pbs_nodefile_nodes': nodes,
    }
def resolve_builtin_system_profile(path_or_name: str | Path) -> Path:
    """Map a profile path or built-in profile name to a filesystem path.

    An existing path wins and is returned resolved.  Otherwise the value is
    looked up in ``BUILTIN_SYSTEM_PROFILES`` and the packaged template path is
    returned.  A value that is neither is returned as an unresolved ``Path``.
    """
    key = str(path_or_name)
    candidate = Path(key)
    if candidate.exists():
        return candidate.resolve()

    template_name = BUILTIN_SYSTEM_PROFILES.get(key)
    if template_name is not None:
        packaged = resources.files(__package__).joinpath(template_name)
        return Path(str(packaged))
    return candidate
"SETUPTOOLS_SCM_PRETEND_VERSION_FOR_ACADEMY_PY": "0.0.0+aurora" + }, + "unset_env": [ + "http_proxy", + "HTTP_PROXY", + "https_proxy", + "HTTPS_PROXY", + "all_proxy", + "ALL_PROXY" + ], + "no_proxy": "127.0.0.1,localhost,.alcf.anl.gov,*.alcf.anl.gov" +} diff --git a/src/chemgraph/academy/runtime/profiles/polaris.template.json b/src/chemgraph/academy/runtime/profiles/polaris.template.json new file mode 100644 index 00000000..7be57c92 --- /dev/null +++ b/src/chemgraph/academy/runtime/profiles/polaris.template.json @@ -0,0 +1,36 @@ +{ + "name": "polaris", + "remote_host": "${ALCF_USER}@polaris.alcf.anl.gov", + "remote_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}", + "repo_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph", + "run_root": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/runs", + "relay_host_file": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/uan-relay-18186.host", + "relay_port": 18186, + "venv_python": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/venvs/academy-swarm/bin/python", + "redis_bin_dir": "/eagle/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", + "redis_port": 6392, + "redis_bind": "0.0.0.0", + "redis_protected_mode": "no", + "mpiexec": "mpiexec", + "pythonpath_entries": [ + "/eagle/${ALCF_PROJECT}/${ALCF_USER}/ChemGraph/src" + ], + "path_entries": [ + "/eagle/${ALCF_PROJECT}/${ALCF_USER}/tools/redis/bin", + "/eagle/${ALCF_PROJECT}/${ALCF_USER}/bin" + ], + "env": { + "NUMEXPR_MAX_THREADS": "256", + "NUMEXPR_NUM_THREADS": "64", + "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_ACADEMY_PY": "0.0.0+polaris" + }, + "unset_env": [ + "http_proxy", + "HTTP_PROXY", + "https_proxy", + "HTTPS_PROXY", + "all_proxy", + "ALL_PROXY" + ], + "no_proxy": "127.0.0.1,localhost,.alcf.anl.gov,*.alcf.anl.gov" +} diff --git a/src/chemgraph/academy/runtime/profiles/system.py b/src/chemgraph/academy/runtime/profiles/system.py new file mode 100644 index 00000000..02ed6dc0 --- /dev/null +++ b/src/chemgraph/academy/runtime/profiles/system.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import json 
def load_system_profile(path_or_name: str | Path) -> SystemProfile:
    """Load a system profile, expanding environment variables in its text.

    The profile file (a path or a built-in profile name) is read as UTF-8 and
    passed through ``os.path.expandvars`` before being parsed as JSON and
    validated as a ``SystemProfile``.

    Raises
    ------
    ValueError
        If any ``${NAME}`` placeholder is still present after expansion.
    """
    profile_path = resolve_builtin_system_profile(path_or_name)
    raw_text = profile_path.read_text(encoding="utf-8")
    expanded_text = os.path.expandvars(raw_text)

    # expandvars leaves unknown variables untouched, so any surviving
    # ``${...}`` marks a variable the caller forgot to export.
    leftover_names = set(re.findall(r"\$\{([^}]+)\}", expanded_text))
    unresolved = sorted(leftover_names)
    if unresolved:
        raise ValueError(
            f"System profile {profile_path} contains unresolved environment "
            f"variables: {', '.join(unresolved)}",
        )

    return SystemProfile.model_validate(json.loads(expanded_text))
def _exchange_type_of(registration: AgentRegistration[Any]) -> str:
    """Return the ``exchange_type`` string carried by ``registration``.

    Raises
    ------
    TypeError
        If the attribute is missing or is not a ``str``.
    """
    kind = getattr(registration, 'exchange_type', None)
    if isinstance(kind, str):
        return kind
    raise TypeError(
        f'Registration {type(registration).__name__} has no string '
        '`exchange_type` field; cannot persist.',
    )
async def wait_academy_registrations(
    run_dir: pathlib.Path,
    *,
    run_token: str,
    timeout_s: float,
) -> dict[str, AgentRegistration[Any]]:
    """Poll for the run's registration file, then load and return it.

    The file location comes from ``academy_registration_path(run_dir)`` and
    is checked every 0.25 seconds.

    Raises
    ------
    TimeoutError
        If the file has not appeared once ``timeout_s`` seconds have elapsed.
    """
    path = academy_registration_path(run_dir)
    deadline = time.monotonic() + timeout_s
    while not path.exists():
        if time.monotonic() > deadline:
            raise TimeoutError(
                f'Timed out waiting for Academy registrations at {path}',
            )
        await asyncio.sleep(0.25)
    return load_academy_registrations(run_dir, run_token=run_token)
#!/bin/bash
# Periodically mirror a remote run directory into a local one over SSH.
#
# Positional arguments:
#   $1  HOST            ssh destination to pull from
#   $2  CONTROL_PATH    ssh ControlPath socket to share between rsync calls
#   $3  REMOTE_RUN_DIR  directory on HOST to copy from
#   $4  LOCAL_RUN_DIR   local directory to copy into (created if missing)
#   $5  INTERVAL_S      seconds to sleep between rsync passes
#   $6  LOG_PATH        file that receives rsync stdout/stderr (appended)
#
# The loop never exits on its own; the caller is expected to stop it.
set -euo pipefail

HOST="$1"
CONTROL_PATH="$2"
REMOTE_RUN_DIR="$3"
LOCAL_RUN_DIR="$4"
INTERVAL_S="$5"
LOG_PATH="$6"

mkdir -p "${LOCAL_RUN_DIR}"
while true; do
  # BatchMode=yes prevents ssh from prompting for credentials.
  # "|| true" keeps the loop alive under "set -e" when a pass fails;
  # the failure output is still captured in LOG_PATH.
  rsync -az --delete \
    -e "ssh -o BatchMode=yes -o ControlMaster=auto -o ControlPath=${CONTROL_PATH} -o ControlPersist=yes" \
    "${HOST}:${REMOTE_RUN_DIR}/" \
    "${LOCAL_RUN_DIR}/" \
    >> "${LOG_PATH}" 2>&1 || true
  sleep "${INTERVAL_S}"
done
+RELAY_LOG_FILE="$5" +RELAY_PORT="$6" +REVERSE_PORT="$7" +RELAY_PYTHON="$8" + +cd "${REMOTE_ROOT}" +UAN_HOST="$(hostname -f)" +printf '%s\n' "${UAN_HOST}" > "${RELAY_HOST_FILE}" + +if [ -f "${RELAY_PID_FILE}" ]; then + OLD_PID="$(cat "${RELAY_PID_FILE}" 2>/dev/null || true)" + case "${OLD_PID}" in + ''|*[!0-9]*) ;; + *) kill "${OLD_PID}" 2>/dev/null || true ;; + esac +fi + +"${RELAY_PYTHON}" "${RELAY_SCRIPT}" \ + --listen-host 0.0.0.0 \ + --listen-port "${RELAY_PORT}" \ + --target-host 127.0.0.1 \ + --target-port "${REVERSE_PORT}" \ + > "${RELAY_LOG_FILE}" 2>&1 & +RELAY_PID="$!" +printf '%s\n' "${RELAY_PID}" > "${RELAY_PID_FILE}" + +cleanup_remote() { + kill "${RELAY_PID}" 2>/dev/null || true +} +trap cleanup_remote EXIT + +deadline=$((SECONDS + 45)) +while ! curl -fsS "http://${UAN_HOST}:${RELAY_PORT}/v1/models" >/dev/null; do + if ! kill -0 "${RELAY_PID}" 2>/dev/null; then + echo "UAN relay process exited before readiness. Last relay log lines:" >&2 + tail -n 80 "${RELAY_LOG_FILE}" >&2 || true + exit 1 + fi + if [ "${SECONDS}" -gt "${deadline}" ]; then + echo "UAN relay did not become ready. Last relay log lines:" >&2 + tail -n 80 "${RELAY_LOG_FILE}" >&2 || true + exit 1 + fi + sleep 1 +done + +echo "UAN_RELAY_HOST=${UAN_HOST}" +echo "UAN relay ready at http://${UAN_HOST}:${RELAY_PORT}/argoapi/v1" + +while true; do + sleep 3600 +done diff --git a/src/chemgraph/academy/runtime/templates/uan_http_relay.py b/src/chemgraph/academy/runtime/templates/uan_http_relay.py new file mode 100644 index 00000000..8ce424fd --- /dev/null +++ b/src/chemgraph/academy/runtime/templates/uan_http_relay.py @@ -0,0 +1,96 @@ +"""Tiny TCP relay used by the dashboard launcher. + +Listens on a UAN-visible port and forwards every accepted connection to a +loopback service on the same UAN host. The dashboard launcher pairs this +with a reverse SSH tunnel (Mac argo-shim -> UAN loopback), so compute +nodes can curl http://:/argoapi/v1 and reach the developer's +local argo-shim. 
def pump(src: socket.socket, dst: socket.socket) -> None:
    """Copy bytes from ``src`` to ``dst`` until ``src`` reports end of stream.

    Data is read in chunks of up to 65536 bytes.  Socket errors end the copy
    quietly.  In every case the write side of ``dst`` is shut down afterwards
    so the peer sees end of stream; errors from that shutdown are ignored.
    """
    try:
        # recv() returns b'' once the peer has closed its sending side.
        for chunk in iter(lambda: src.recv(65536), b''):
            dst.sendall(chunk)
    except OSError:
        pass
    finally:
        try:
            dst.shutdown(socket.SHUT_WR)
        except OSError:
            pass
class _BaseDashboardEventCallback(BaseCallbackHandler):
    """Translate LangChain callback hooks into dashboard events.

    Each hook builds a small payload and passes it to ``on_event`` together
    with the ``thread_id`` supplied at construction time.  A failing
    ``on_event`` is logged at debug level using ``_failure_log_message`` and
    never propagates.
    """

    _failure_log_message = "dashboard event callback failed"

    def __init__(self, on_event: EventCallback, thread_id: str) -> None:
        self._thread_id = thread_id
        self._on_event = on_event

    def _emit(self, event: str, payload: dict[str, Any]) -> None:
        """Send one event; keys in ``payload`` override ``thread_id``."""
        try:
            enriched = {"thread_id": self._thread_id, **payload}
            self._on_event(event, enriched)
        except Exception:  # noqa: BLE001 - callbacks must not break the run.
            logger.debug(self._failure_log_message, exc_info=True)

    def on_chat_model_start(self, serialized, messages, **kwargs) -> None:
        """Report a chat-model call, counting the first message batch."""
        model = _serialized_name(serialized)
        if messages:
            message_count = len(messages[0])
        else:
            message_count = 0
        self._emit(
            "llm_call_started",
            {"model": model, "message_count": message_count},
        )

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        """Report a completion-model call, counting the prompts."""
        model = _serialized_name(serialized)
        message_count = len(prompts) if prompts else 0
        self._emit(
            "llm_call_started",
            {"model": model, "message_count": message_count},
        )

    def on_llm_end(self, response, **kwargs) -> None:
        """Report call completion and, if present, the requested tool calls."""
        llm_output = getattr(response, "llm_output", None)
        finished: dict[str, Any] = (
            {"llm_output": llm_output} if isinstance(llm_output, dict) else {}
        )
        self._emit("llm_call_finished", finished)

        # A plain text answer carries no tool calls, hence no decision event.
        requested = _response_tool_calls(response)
        if requested:
            self._emit("llm_decision", {"tool_calls": requested})

    def on_llm_error(self, error, **kwargs) -> None:
        """Report a failed model call with the error's ``repr``."""
        self._emit("llm_call_failed", {"error": repr(error)})

    def on_tool_start(self, serialized, input_str, **kwargs) -> None:
        """Report the start of a tool call with its serialized arguments."""
        tool_name = _serialized_name(serialized)
        arguments = _serialize_state(input_str)
        self._emit(
            "tool_call_started",
            {"tool_name": tool_name, "arguments": arguments},
        )

    def on_tool_end(self, output, **kwargs) -> None:
        """Report a finished tool call; ``tool_name`` is added when known."""
        finished: dict[str, Any] = {"result": _serialize_state(output)}
        tool_name = kwargs.get("name")
        if tool_name:
            finished["tool_name"] = tool_name
        self._emit("tool_call_finished", finished)

    def on_tool_error(self, error, **kwargs) -> None:
        """Report a failed tool call; ``tool_name`` is added when known."""
        failed = {"error": repr(error)}
        tool_name = kwargs.get("name")
        if tool_name:
            failed["tool_name"] = tool_name
        self._emit("tool_call_failed", failed)
def serialize_state(state, _seen=None):
    """Convert non-serializable objects in state to a JSON-friendly format.

    Parameters
    ----------
    state : Any
        The state object to be serialized. Can be a list, dict, or object
        with ``__dict__``.
    _seen : set of int, optional
        Internal bookkeeping: ids of the containers currently being
        serialized on the path from the root. Callers should not pass it.

    Returns
    -------
    Any
        A JSON-serializable version of the input state. Numbers, booleans
        and ``None`` are returned unchanged; lists and dicts are rebuilt
        recursively; objects with ``__dict__`` become a dict of their
        attributes; everything else is converted with ``str``. A container
        that contains itself (directly or indirectly) is replaced at the
        point of recursion by the string ``"<circular_reference>"`` instead
        of recursing until ``RecursionError``.
    """
    if state is None or isinstance(state, (int, float, bool)):
        return state

    is_list = isinstance(state, list)
    is_dict = isinstance(state, dict)
    if not (is_list or is_dict or hasattr(state, "__dict__")):
        return str(state)

    if _seen is None:
        _seen = set()
    marker = id(state)
    if marker in _seen:
        return "<circular_reference>"

    _seen.add(marker)
    try:
        if is_list:
            return [serialize_state(item, _seen) for item in state]
        mapping = state if is_dict else state.__dict__
        return {
            key: serialize_state(value, _seen)
            for key, value in mapping.items()
        }
    finally:
        # Track only the current path so that an object referenced twice,
        # but not cyclically, is still serialized in full both times.
        _seen.discard(marker)
- formatter_multi_prompt : str, optional - Formatter prompt for multi-agent workflows. - generate_report : bool, optional - Whether report generation is enabled. - report_prompt : str, optional - Prompt used by the report-generation workflow. - support_structured_output : bool, optional - Whether the selected model supports structured output. - tools : list, optional - Custom tool list for applicable workflows. - data_tools : list, optional - Additional data-analysis tools for MCP workflows. - session_store : SessionStore, optional - Existing session store instance. - enable_memory : bool, optional - Whether persistent session memory is enabled. - memory_db_path : str, optional - SQLite path for the session store. - log_dir : str, optional - Directory for run logs and artifacts. - max_retries : int, optional - LLM parse-retry limit for formatter/planner nodes. - human_input_handler : Callable[[str], str], optional - Callback used to answer graph human-interrupt prompts. - human_supervised : bool, optional - Whether to expose human-supervision tools to the agent. 
- """ # Always generate a unique identifier for this instance self.uuid = str(uuid.uuid4())[:8] @@ -454,6 +314,8 @@ def __init__( self.max_retries = max_retries self.human_input_handler = human_input_handler self.human_supervised = human_supervised + self.terminal_tool_names = tuple(terminal_tool_names) + self.on_event = on_event # When human supervision is disabled and the caller is using the # default system prompt, strip the ask_human instructions so the @@ -521,6 +383,7 @@ def append_calculator_context(prompt: str) -> str: self.tools, max_retries=self.max_retries, human_supervised=self.human_supervised, + terminal_tool_names=self.terminal_tool_names, ) elif self.workflow_type == "multi_agent": self.workflow = self.workflow_map[workflow_type]["constructor"]( @@ -588,18 +451,6 @@ def visualize(self, method: str = "ascii"): This method creates and displays a visual representation of the workflow graph using Mermaid diagrams. The visualization is shown in Jupyter notebooks. - Parameters - ---------- - method : str, optional - Visualization backend. ``"ascii"`` returns an ASCII graph; - any other value renders a Mermaid PNG in the active notebook. - - Returns - ------- - str or None - ASCII graph text when ``method`` is ``"ascii"``; otherwise - displays an image and returns ``None``. - Notes ----- Requires IPython and nest_asyncio to be installed. @@ -770,13 +621,7 @@ def session_id(self) -> str: return self.uuid def _ensure_session(self, query: str) -> None: - """Create a session record on first run if memory is enabled. - - Parameters - ---------- - query : str - User query used to generate the session title. 
- """ + """Create a session record on first run if memory is enabled.""" if self.session_store is None: return if self._session_created: @@ -794,15 +639,7 @@ def _ensure_session(self, query: str) -> None: logger.info(f"Created session {self.uuid}: {self._session_title}") def _save_messages_to_store(self, last_state: dict, query: str) -> None: - """Extract messages from workflow state and persist to session store. - - Parameters - ---------- - last_state : dict - Latest LangGraph state containing a ``messages`` sequence. - query : str - Original user query associated with the saved messages. - """ + """Extract messages from workflow state and persist to session store.""" if self.session_store is None or not self._session_created: return @@ -896,16 +733,6 @@ async def _call_human_input_handler(self, question: str) -> str: Raises :class:`HumanInputRequired` when no handler is configured, allowing external callers (CLI, UI) to catch it, prompt the user, and resume the graph. - - Parameters - ---------- - question : str - Prompt emitted by the graph for a human response. - - Returns - ------- - str - Human response returned by the configured handler. """ handler = self.human_input_handler if handler is None: @@ -936,21 +763,13 @@ async def run(self, query: str, config=None, resume_from: Optional[str] = None): Session ID to load context from. The previous conversation summary is prepended to the query. """ + from chemgraph.agent.turn import ( + _executed_tool_names, + _state_messages, + _terminal_tool_name, + ) def _validate_config(cfg): - """Normalize and validate the LangGraph run configuration. - - Parameters - ---------- - cfg : dict or None - User-provided configuration, optionally with top-level - ``thread_id``. - - Returns - ------- - dict - Config with ``configurable.thread_id`` and recursion limit set. 
- """ if cfg is None: cfg = {} if not isinstance(cfg, dict): @@ -969,21 +788,6 @@ def _validate_config(cfg): return cfg def _save_state_and_select_return(last_state, cfg): - """Persist the final state and apply the configured return option. - - Parameters - ---------- - last_state : dict - Final streamed graph state. - cfg : dict - LangGraph run configuration used to retrieve/write state. - - Returns - ------- - Any - Final message or serialized state, depending on - ``self.return_option``. - """ log_dir = self.log_dir if not log_dir: log_dir = "cg_logs" @@ -1004,18 +808,9 @@ def _save_state_and_select_return(last_state, cfg): async def _stream_until_interrupt(stream_input, cfg): """Stream the workflow until completion or an interrupt. - Parameters - ---------- - stream_input : dict or Command - Initial graph input or resume command to stream. - cfg : dict - LangGraph run configuration. - - Returns - ------- - tuple - ``(last_state, interrupt_value)`` where ``interrupt_value`` is - ``None`` when the graph completed normally. + Returns ``(last_state, interrupt_value)`` where + ``interrupt_value`` is ``None`` when the graph completed + normally. LangGraph's ``astream(stream_mode="values")`` does **not** raise ``GraphInterrupt``. 
Instead the stream emits a state @@ -1092,6 +887,13 @@ async def _stream_until_interrupt(stream_input, cfg): logger.debug("run called with config=%s", config) config = _validate_config(config) + thread_id = str(config["configurable"]["thread_id"]) + started = time.time() + event = self.on_event or (lambda _event, _payload: None) + if self.on_event: + callbacks = list(config.get("callbacks") or []) + callbacks.append(_AstreamEventCallback(self.on_event, thread_id)) + config["callbacks"] = callbacks logger.debug("validated config=%s", config) # Initialize logging directory before determining inputs or running workflow @@ -1114,6 +916,16 @@ async def _stream_until_interrupt(stream_input, cfg): logger.info(f"Injected context from session {resume_from}") inputs = {"messages": query} + event( + "workflow_started", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "tool_names": [ + getattr(tool, "name", str(tool)) for tool in self.tools or [] + ], + }, + ) try: last_state, interrupt_value = await _stream_until_interrupt(inputs, config) @@ -1163,6 +975,24 @@ async def _stream_until_interrupt(stream_input, cfg): # Save messages to persistent session store self._save_messages_to_store(last_state, query) + messages = _state_messages(last_state) + executed_tools = _executed_tool_names(messages) + terminal_tool = _terminal_tool_name( + executed_tools, + self.terminal_tool_names, + ) + event( + "workflow_finished", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "status": "completed", + "executed_tool_names": list(executed_tools), + "terminal_tool": terminal_tool, + "duration_s": round(time.time() - started, 3), + }, + ) + return _save_state_and_select_return(last_state, config) except HumanInputRequired: @@ -1170,6 +1000,16 @@ async def _stream_until_interrupt(stream_input, cfg): # caller (CLI / UI) can prompt the user and resume. 
raise except Exception as e: + event( + "workflow_finished", + { + "workflow_type": self.workflow_type, + "thread_id": thread_id, + "status": "failed", + "error": repr(e), + "duration_s": round(time.time() - started, 3), + }, + ) logger.error(f"Error running workflow {self.workflow_type}: {e}") raise @@ -1182,12 +1022,5 @@ class HumanInputRequired(Exception): """ def __init__(self, question: str): - """Initialize the exception with the pending human question. - - Parameters - ---------- - question : str - Question that should be presented to the user. - """ self.question = question super().__init__(question) diff --git a/src/chemgraph/agent/turn.py b/src/chemgraph/agent/turn.py new file mode 100644 index 00000000..e5652134 --- /dev/null +++ b/src/chemgraph/agent/turn.py @@ -0,0 +1,457 @@ +from __future__ import annotations + +import dataclasses +import datetime +import logging +import os +import time +import uuid +from typing import Any, Collection + +from chemgraph.graphs.single_agent import construct_single_agent_graph +from chemgraph.models.loader import load_chat_model +from chemgraph.models.settings import LLMSettings +from chemgraph.prompt.single_agent_prompt import ( + formatter_prompt as default_formatter_prompt, +) +from chemgraph.prompt.single_agent_prompt import report_prompt as default_report_prompt +from chemgraph.prompt.single_agent_prompt import single_agent_prompt + +logger = logging.getLogger(__name__) + + +def _is_mock_object(value) -> bool: + """Return True for unittest.mock objects without importing test-only APIs. + + Parameters + ---------- + value : Any + Object to inspect. + + Returns + ------- + bool + ``True`` when the object comes from ``unittest.mock``. + """ + return value.__class__.__module__.startswith("unittest.mock") + + +def serialize_state(state, *, max_depth: int = 50, _seen: set[int] | None = None): + """Convert non-serializable objects in state to a JSON-friendly format. 
+ + Parameters + ---------- + state : Any + The state object to be serialized. Can be a list, dict, or object with __dict__ + max_depth : int, optional + Maximum object nesting depth to serialize before falling back to a + placeholder. This prevents runaway recursion for complex graph objects. + + Returns + ------- + Any + A JSON-serializable version of the input state + """ + if _seen is None: + _seen = set() + + if max_depth < 0: + return f"" + + if isinstance(state, (str, int, float, bool)) or state is None: + return state + + if isinstance(state, (datetime.datetime, datetime.date)): + return state.isoformat() + + if _is_mock_object(state): + return str(state) + + state_id = id(state) + if state_id in _seen: + return f"" + + if isinstance(state, dict): + _seen.add(state_id) + try: + return { + str(key): serialize_state( + value, max_depth=max_depth - 1, _seen=_seen + ) + for key, value in state.items() + } + finally: + _seen.remove(state_id) + + if isinstance(state, (list, tuple, set, frozenset)): + _seen.add(state_id) + try: + return [ + serialize_state(item, max_depth=max_depth - 1, _seen=_seen) + for item in state + ] + finally: + _seen.remove(state_id) + + model_dump = getattr(state, "model_dump", None) + if callable(model_dump): + _seen.add(state_id) + try: + try: + dumped = model_dump(mode="json") + except TypeError: + dumped = model_dump() + return serialize_state(dumped, max_depth=max_depth - 1, _seen=_seen) + except Exception: + return str(state) + finally: + _seen.remove(state_id) + + if dataclasses.is_dataclass(state) and not isinstance(state, type): + _seen.add(state_id) + try: + return { + field.name: serialize_state( + getattr(state, field.name), + max_depth=max_depth - 1, + _seen=_seen, + ) + for field in dataclasses.fields(state) + } + finally: + _seen.remove(state_id) + + if hasattr(state, "__dict__"): + _seen.add(state_id) + try: + return { + str(key): serialize_state( + value, max_depth=max_depth - 1, _seen=_seen + ) + for key, value in 
vars(state).items() + } + finally: + _seen.remove(state_id) + + return str(state) + + +def _custom_openai_compatible_kwargs( + *, + model_name: str, + temperature: float, + base_url: str, + api_key: str, + max_tokens: int, + top_p: float, + frequency_penalty: float, + presence_penalty: float, + argo_user: str | None, +) -> dict: + kwargs = { + "model": model_name, + "temperature": temperature, + "base_url": base_url, + "api_key": api_key, + "max_tokens": max_tokens, + "top_p": top_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + } + user = argo_user or os.getenv("ARGO_USER") + if base_url and "argoapi" in base_url and user: + kwargs["model_kwargs"] = {"user": user} + return kwargs + + +@dataclasses.dataclass(frozen=True) +class TurnResult: + """Result of one bounded ChemGraph single-agent turn.""" + + final_text: str + state: dict[str, Any] + executed_tool_names: tuple[str, ...] + terminal_tool: str | None + thread_id: str + duration_s: float + + +def _serialized_name(serialized: Any) -> str | None: + if isinstance(serialized, dict): + return serialized.get("name") or serialized.get("id") + return None + + +def _message_tool_calls(message: Any) -> list[Any]: + if isinstance(message, dict): + calls = message.get("tool_calls") + else: + calls = getattr(message, "tool_calls", None) + return calls if isinstance(calls, list) else [] + + +def _response_tool_calls(response: Any) -> list[dict[str, str | None]]: + try: + generations = getattr(response, "generations", None) or [] + tool_calls: list[dict[str, str | None]] = [] + for generation_group in generations: + for generation in generation_group or []: + message = getattr(generation, "message", None) + for call in _message_tool_calls(message): + name = _call_name(call) + if not name: + continue + tool_calls.append( + { + "name": name, + "id": _call_id(call), + }, + ) + return tool_calls + except Exception: # noqa: BLE001 - event extraction must not break runs. 
+ logger.debug("failed to extract llm_decision tool calls", exc_info=True) + return [] + + +def _tool_message_name(message: Any) -> str | None: + if isinstance(message, dict): + name = message.get("name") + role = message.get("role") or message.get("type") + if name and role in {"tool", "tool_message", "ToolMessage"}: + return str(name) + return str(name) if name and not _message_tool_calls(message) else None + name = getattr(message, "name", None) + message_type = getattr(message, "type", None) + if name and message_type == "tool": + return str(name) + return str(name) if name and not _message_tool_calls(message) else None + + +def _call_name(call: Any) -> str | None: + if isinstance(call, dict): + if call.get("name"): + return str(call["name"]) + function = call.get("function") + if isinstance(function, dict) and function.get("name"): + return str(function["name"]) + name = getattr(call, "name", None) + return str(name) if name else None + + +def _call_id(call: Any) -> str | None: + if isinstance(call, dict): + value = call.get("id") or call.get("tool_call_id") + else: + value = getattr(call, "id", None) or getattr(call, "tool_call_id", None) + return str(value) if value else None + + +def _state_messages(state: Any) -> list[Any]: + if isinstance(state, dict): + messages = state.get("messages", []) + else: + messages = getattr(state, "messages", []) + return list(messages or []) + + +def _executed_tool_names(messages: list[Any]) -> tuple[str, ...]: + names: list[str] = [] + for message in messages: + name = _tool_message_name(message) + if name: + names.append(name) + if names: + return tuple(names) + for message in messages: + for call in _message_tool_calls(message): + if name := _call_name(call): + names.append(name) + return tuple(names) + + +def _terminal_tool_name( + executed_tool_names: tuple[str, ...], + terminal_tool_names: Collection[str], +) -> str | None: + terminal = set(terminal_tool_names) + for name in reversed(executed_tool_names): + if name in 
terminal: + return name + return None + + +def _message_text(message: Any) -> str: + content = message.get("content") if isinstance(message, dict) else getattr(message, "content", "") + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + parts.append(str(item.get("text") or item.get("content") or item)) + else: + parts.append(str(item)) + return "\n".join(parts) + return "" if content is None else str(content) + + +def _final_text(messages: list[Any]) -> str: + for message in reversed(messages): + message_type = ( + message.get("role") or message.get("type") + if isinstance(message, dict) + else getattr(message, "type", None) + ) + if message_type in {"ai", "assistant"}: + return _message_text(message) + return _message_text(messages[-1]) if messages else "" + + +def _load_turn_llm( + *, + model_name: str, + base_url: str | None, + api_key: str | None, + argo_user: str | None, +) -> Any: + temperature = 0.0 + try: + return load_chat_model( + settings=LLMSettings( + model=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + temperature=temperature, + ), + ) + except ValueError: + pass + + endpoint = os.getenv("VLLM_BASE_URL", base_url or "") + key = os.getenv("OPENAI_API_KEY", api_key or "dummy_vllm_key") + if not endpoint: + raise ValueError(f"Unsupported model or missing base URL for: {model_name}") + from langchain_openai import ChatOpenAI + + return ChatOpenAI( + **_custom_openai_compatible_kwargs( + model_name=model_name, + temperature=temperature, + base_url=endpoint, + api_key=key, + max_tokens=4000, + top_p=1.0, + frequency_penalty=0.0, + presence_penalty=0.0, + argo_user=argo_user, + ), + ) + + +from chemgraph.agent.events import EventCallback, _TurnEventCallback + + +async def run_turn( + *, + query: str, + tools: list[Any] | None = None, + model_name: str = "gpt-4o-mini", + base_url: str | None = None, + api_key: str | None = None, + argo_user: str | None = None, + 
system_prompt: str = single_agent_prompt, + formatter_prompt: str = default_formatter_prompt, + structured_output: bool = False, + generate_report: bool = False, + report_prompt: str = default_report_prompt, + recursion_limit: int = 50, + thread_id: str | None = None, + terminal_tool_names: Collection[str] = (), + human_supervised: bool = False, + on_event: EventCallback | None = None, +) -> TurnResult: + """Run one bounded single-agent ChemGraph LangGraph turn.""" + + started = time.time() + thread_id = thread_id or str(uuid.uuid4()) + callbacks = [_TurnEventCallback(on_event, thread_id)] if on_event else [] + event = on_event or (lambda _event, _payload: None) + event( + "workflow_started", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "tool_names": [getattr(tool, "name", str(tool)) for tool in tools or []], + }, + ) + llm = _load_turn_llm( + model_name=model_name, + base_url=base_url, + api_key=api_key, + argo_user=argo_user, + ) + workflow = construct_single_agent_graph( + llm, + system_prompt, + structured_output, + formatter_prompt, + generate_report, + report_prompt, + tools, + human_supervised=human_supervised, + terminal_tool_names=terminal_tool_names, + ) + config: dict[str, Any] = { + "configurable": {"thread_id": thread_id}, + "recursion_limit": recursion_limit, + } + if callbacks: + config["callbacks"] = callbacks + + last_state: Any = None + try: + async for state in workflow.astream( + {"messages": query}, + stream_mode="values", + config=config, + ): + last_state = state + except Exception as exc: + event( + "workflow_finished", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "status": "failed", + "error": repr(exc), + "duration_s": round(time.time() - started, 3), + }, + ) + raise + + if last_state is None: + raise RuntimeError("ChemGraph turn produced no states.") + + messages = _state_messages(last_state) + executed_tools = _executed_tool_names(messages) + terminal_tool = _terminal_tool_name(executed_tools, 
terminal_tool_names) + result = TurnResult( + final_text=_final_text(messages), + state=serialize_state(last_state), + executed_tool_names=executed_tools, + terminal_tool=terminal_tool, + thread_id=thread_id, + duration_s=round(time.time() - started, 3), + ) + event( + "workflow_finished", + { + "workflow_type": "single_agent", + "thread_id": thread_id, + "status": "completed", + "executed_tool_names": list(result.executed_tool_names), + "terminal_tool": terminal_tool, + "duration_s": result.duration_s, + }, + ) + return result + diff --git a/src/chemgraph/cli/commands.py b/src/chemgraph/cli/commands.py index abbd0fff..70934046 100644 --- a/src/chemgraph/cli/commands.py +++ b/src/chemgraph/cli/commands.py @@ -177,6 +177,7 @@ def initialize_agent( verbose: bool = False, human_supervised: bool = False, tools: Optional[list] = None, + on_event: Optional[Any] = None, ) -> Any: """Initialize a ChemGraph agent with progress indication. @@ -280,6 +281,7 @@ def _create_agent() -> Any: structured_output=structured_output, human_supervised=human_supervised, tools=tools, + on_event=on_event, ) try: diff --git a/src/chemgraph/cli/main.py b/src/chemgraph/cli/main.py index badf4168..788d5aa6 100644 --- a/src/chemgraph/cli/main.py +++ b/src/chemgraph/cli/main.py @@ -173,6 +173,15 @@ def _add_run_args(parser: argparse.ArgumentParser) -> None: default="ChemGraph General Tools", help="Display name for the MCP server connection (default: 'ChemGraph General Tools')", ) + parser.add_argument( + "--trace-dir", + type=str, + default=None, + help=( + "Write per-run events to this directory so the run is viewable " + "via 'chemgraph dashboard -- --run-dir '." 
+ ), + ) def create_argument_parser() -> argparse.ArgumentParser: @@ -237,6 +246,59 @@ def create_argument_parser() -> argparse.ArgumentParser: # ---- "models" subcommand --------------------------------------------- subparsers.add_parser("models", help="List all available LLM models.") + # ---- "dashboard" subcommands ---------------------------------------- + dashboard_parser = subparsers.add_parser( + "dashboard", + help="Serve the ChemGraph dashboard for a run directory.", + ) + dashboard_parser.add_argument( + "dashboard_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.academy.dashboard.", + ) + + # ---- "academy" subcommand ------------------------------------------- + academy_parser = subparsers.add_parser( + "academy", + help="Run and inspect Academy-backed ChemGraph agent campaigns.", + ) + academy_sub = academy_parser.add_subparsers(dest="academy_command") + + daemon_parser = academy_sub.add_parser( + "mpi-daemon", + help="Run one ChemGraph Academy agent daemon inside mpiexec.", + ) + daemon_parser.add_argument( + "daemon_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.academy.runtime.daemon.", + ) + + compute_parser = academy_sub.add_parser( + "run-compute", + help="Run a profile-backed ChemGraph Academy campaign in this allocation.", + ) + compute_parser.add_argument( + "compute_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.academy.runtime.compute_launcher.", + ) + + dashboard_parser = academy_sub.add_parser( + "dashboard", + help="Start the local dashboard launcher for a ChemGraph Academy run.", + ) + dashboard_parser.add_argument( + "dashboard_args", + nargs=argparse.REMAINDER, + help="Arguments forwarded to chemgraph.academy.runtime.dashboard_launcher.", + ) + + academy_sub.add_parser( + "campaigns", + help="List ChemGraph Academy campaign specs.", + ) + # ---- Legacy fallback args ------------------------------------------- # Also add run args to the top-level 
parser so that # `chemgraph -q "..."` keeps working without a subcommand. @@ -440,6 +502,33 @@ def _handle_run(args: argparse.Namespace) -> None: # Show banner console.print(create_banner()) + # ---- Optional run trace for the local dashboard -------------------- + trace = None + trace_dir = getattr(args, "trace_dir", None) or config.get("trace_dir") + if trace_dir: + from pathlib import Path + + from chemgraph.cli.trace import CLIRunTrace + + if args.workflow != "single_agent": + console.print( + "[yellow]--trace-dir is currently only effective for the " + "single_agent workflow; events will not be written for " + f"{args.workflow!r}.[/yellow]" + ) + else: + trace = CLIRunTrace( + Path(trace_dir), + model_name=args.model, + workflow_type=args.workflow, + query=args.query, + ) + trace.start() + console.print( + f"[dim]Tracing run to {trace.trace_dir}. " + f"View with: chemgraph dashboard -- --run-dir {trace.trace_dir}[/dim]" + ) + # Initialize agent agent = initialize_agent( args.model, @@ -453,18 +542,28 @@ def _handle_run(args: argparse.Namespace) -> None: verbose=(args.verbose > 0), human_supervised=args.human_supervised, tools=mcp_tools, + on_event=trace.on_event if trace else None, ) if not agent: + if trace is not None: + trace.finish(status="failed", error="agent_initialization_failed") sys.exit(1) # Execute query console.print(f"[bold blue]Query:[/bold blue] {args.query}") if args.resume: console.print(f"[bold blue]Resuming from:[/bold blue] {args.resume}") - result = run_query( - agent, args.query, verbose=(args.verbose > 0), resume_from=args.resume - ) + try: + result = run_query( + agent, args.query, verbose=(args.verbose > 0), resume_from=args.resume + ) + except Exception: + if trace is not None: + trace.finish(status="failed") + raise + if trace is not None: + trace.finish(status="completed") if result: format_response(result, verbose=(args.verbose > 0)) @@ -482,6 +581,62 @@ def _handle_run(args: argparse.Namespace) -> None: console.print("[dim]Thank 
you for using ChemGraph CLI![/dim]") +def _strip_remainder_separator(args: list[str]) -> list[str]: + """Remove an optional argparse remainder separator.""" + if args and args[0] == "--": + return args[1:] + return args + + +def _run_module_main(module_name: str, argv: list[str]) -> None: + """Run a module-level main() with forwarded command-line arguments.""" + import importlib + + module = importlib.import_module(module_name) + old_argv = sys.argv + try: + sys.argv = [f"chemgraph {module_name.rsplit('.', 1)[-1]}", *argv] + code = module.main() + finally: + sys.argv = old_argv + if isinstance(code, int) and code: + sys.exit(code) + + +def _handle_academy(args: argparse.Namespace) -> None: + """Handle Academy-backed ChemGraph campaign commands.""" + command = getattr(args, "academy_command", None) + if command == "mpi-daemon": + _run_module_main( + "chemgraph.academy.runtime.daemon", + _strip_remainder_separator(args.daemon_args), + ) + return + if command == "dashboard": + _run_module_main( + "chemgraph.academy.runtime.dashboard_launcher", + _strip_remainder_separator(args.dashboard_args), + ) + return + if command == "run-compute": + from chemgraph.academy.runtime.compute_launcher import main as compute_main + + code = compute_main(_strip_remainder_separator(args.compute_args)) + if code: + sys.exit(code) + return + if command == "campaigns": + from chemgraph.academy.campaigns import list_campaigns + + for name in list_campaigns(): + console.print(name) + return + console.print( + "Usage: chemgraph academy " + "{mpi-daemon,run-compute,dashboard,campaigns}.", + ) + + # --------------------------------------------------------------------------- # Main entry point # --------------------------------------------------------------------------- @@ -517,6 +672,15 @@ def main() -> None: elif args.command == "models": list_models() + elif args.command == "dashboard": + _run_module_main( + "chemgraph.academy.dashboard", + _strip_remainder_separator(args.dashboard_args), + ) 
+ + elif args.command == "academy": + _handle_academy(args) + elif args.command == "run": _handle_run(args) diff --git a/src/chemgraph/cli/mcp_utils.py b/src/chemgraph/cli/mcp_utils.py index ce287af9..752849e3 100644 --- a/src/chemgraph/cli/mcp_utils.py +++ b/src/chemgraph/cli/mcp_utils.py @@ -6,6 +6,7 @@ from __future__ import annotations +import os import shlex import time from typing import List, Optional @@ -15,6 +16,35 @@ from chemgraph.cli.formatting import console from chemgraph.utils.async_utils import run_async_callable +# Env vars that the MCP stdio subprocess may need. The MCP SDK's stdio +# transport inherits only a hard-coded whitelist of standard system vars +# (PATH, HOME, etc.) by default -- ChemGraph- and Globus-specific keys +# must be passed through explicitly or the spawned MCP server has no way +# to see what the user exported in their shell. +_FORWARDED_ENV_VARS = ( + # Shell essentials (so python and the user's HOME resolve correctly) + "PATH", + "HOME", + "USER", + "TMPDIR", + "LANG", + "LC_ALL", + "VIRTUAL_ENV", + "CONDA_PREFIX", + "CONDA_DEFAULT_ENV", + # ChemGraph runtime selection + "CHEMGRAPH_EXECUTION_BACKEND", + "CHEMGRAPH_LOG_DIR", + # Globus Compute + "GLOBUS_COMPUTE_ENDPOINT_ID", + # Globus Transfer + "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID", + "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID", + "GLOBUS_TRANSFER_DESTINATION_BASE_PATH", + # ALCF inference endpoints + "ALCF_ACCESS_TOKEN", +) + def load_mcp_tools_from_config( url: Optional[str] = None, @@ -65,11 +95,13 @@ def load_mcp_tools_from_config( transport_label = f"streamable_http @ {url}" elif command: parts = shlex.split(command) + env = {k: os.environ[k] for k in _FORWARDED_ENV_VARS if k in os.environ} connections = { server_name: { "command": parts[0], "args": parts[1:], "transport": "stdio", + "env": env, } } transport_label = f"stdio: {command}" diff --git a/src/chemgraph/cli/trace.py b/src/chemgraph/cli/trace.py new file mode 100644 index 00000000..9faae2cc --- /dev/null +++ 
b/src/chemgraph/cli/trace.py @@ -0,0 +1,114 @@ +"""Trace writer for traditional ChemGraph CLI runs. + +Bridges ChemGraph run events into the dashboard's on-disk schema +(`events.jsonl` + `status.json` + `manifest.json`), so the existing +``chemgraph dashboard`` browser UI can render a traditional ChemGraph run +without going through the Academy daemon path. +""" + +from __future__ import annotations + +import time +from pathlib import Path + +from chemgraph.academy.observability.event_log import EventLog +from chemgraph.academy.observability.run_files import write_json_atomic + + +_AGENT_ID = "chemgraph" +_AGENT_ROLE = "single_agent" + + +class CLIRunTrace: + """Writer for a single traditional ChemGraph run. + + Produces the on-disk layout the dashboard expects: + + :: + + /events.jsonl + /status.json + /manifest.json + + The ``status.json.mode`` field is ``"chemgraph_workflow"`` so the + dashboard renders the per-agent workflow inspector (the "inner tab" + you'd see if you clicked a logical-agent node in an Academy run). 
+ """ + + def __init__( + self, + trace_dir: Path, + *, + run_id: str | None = None, + model_name: str | None = None, + workflow_type: str | None = None, + query: str | None = None, + ) -> None: + self.trace_dir = Path(trace_dir) + self.run_id = run_id or self.trace_dir.name + self.model_name = model_name + self.workflow_type = workflow_type + self.query = query + self._log = EventLog(self.trace_dir / "events.jsonl") + + def start(self) -> None: + """Initialise the run directory and write the static metadata.""" + self.trace_dir.mkdir(parents=True, exist_ok=True) + write_json_atomic( + self.trace_dir / "manifest.json", + { + "mode": "chemgraph_workflow", + "run_id": self.run_id, + "model": self.model_name, + "workflow_type": self.workflow_type, + }, + ) + self._write_status() + self._log.emit( + "run_started", + run_id=self.run_id, + agent_id=_AGENT_ID, + role=_AGENT_ROLE, + payload={ + "model": self.model_name, + "workflow_type": self.workflow_type, + "query": self.query, + }, + ) + + def finish(self, *, status: str, error: str | None = None) -> None: + """Mark the run as completed and refresh ``status.json``.""" + self._log.emit( + "run_finished", + run_id=self.run_id, + agent_id=_AGENT_ID, + role=_AGENT_ROLE, + payload={"status": status, "error": error} if error else {"status": status}, + ) + self._write_status() + + def on_event(self, event: str, payload: dict) -> None: + """Callback handed to :class:`chemgraph.agent.llm_agent.ChemGraph`.""" + self._log.emit( + event, # type: ignore[arg-type] + run_id=self.run_id, + agent_id=_AGENT_ID, + role=_AGENT_ROLE, + payload=payload, + ) + + def _write_status(self) -> None: + write_json_atomic( + self.trace_dir / "status.json", + { + "mode": "chemgraph_workflow", + "updated": time.time(), + "agents": [ + { + "agent_id": _AGENT_ID, + "agent_name": _AGENT_ID, + "role": _AGENT_ROLE, + }, + ], + }, + ) diff --git a/src/chemgraph/execution/__init__.py b/src/chemgraph/execution/__init__.py new file mode 100644 index 
00000000..bd6d0ccf --- /dev/null +++ b/src/chemgraph/execution/__init__.py @@ -0,0 +1,35 @@ +"""Pluggable execution backends for ChemGraph HPC workloads. + +This package provides a backend-agnostic interface for submitting +computational tasks to different workflow managers (Parsl, +EnsembleLauncher, Globus Compute, local process pool). + +Quick start +----------- +>>> from chemgraph.execution import get_backend, TaskSpec +>>> backend = get_backend() # reads config.toml / env vars +>>> future = backend.submit(TaskSpec( +... task_id="test-1", +... task_type="python", +... callable=my_function, +... kwargs={"param": 42}, +... )) +>>> result = future.result() +>>> backend.shutdown() + +See Also +-------- +:mod:`chemgraph.execution.base` -- abstract classes +:mod:`chemgraph.execution.config` -- factory function +""" + +from chemgraph.execution.base import ExecutionBackend, TaskSpec +from chemgraph.execution.config import get_backend +from chemgraph.execution.job_tracker import JobTracker + +__all__ = [ + "ExecutionBackend", + "JobTracker", + "TaskSpec", + "get_backend", +] diff --git a/src/chemgraph/execution/base.py b/src/chemgraph/execution/base.py new file mode 100644 index 00000000..2e5e4962 --- /dev/null +++ b/src/chemgraph/execution/base.py @@ -0,0 +1,164 @@ +"""Abstract base classes for execution backends. + +This module defines the ``ExecutionBackend`` protocol and the ``TaskSpec`` +data model that all workflow managers (Parsl, EnsembleLauncher, local +process pool, etc.) must implement. Downstream code (MCP servers, tools) +only depends on these abstractions -- never on a concrete backend. 
+""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from concurrent.futures import Future +from typing import Any, Callable, Dict, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field + +logger = logging.getLogger(__name__) + + +class TaskSpec(BaseModel): + """Specification for a single unit of work to submit to a backend. + + Supports two execution modes: + + * **python** -- run a Python callable (``callable(*args, **kwargs)``) + * **shell** -- run a shell command string + + Resource hints (``num_nodes``, ``processes_per_node``, ``gpus_per_task``) + are advisory; backends may ignore hints they do not support. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + task_id: str = Field( + description="Unique identifier for this task within the batch." + ) + task_type: Literal["python", "shell"] = Field( + default="python", + description="Execution mode: 'python' for a callable, 'shell' for a command.", + ) + + # ── Python task fields ────────────────────────────────────────────── + callable: Optional[Callable[..., Any]] = Field( + default=None, + description="Python callable to execute (required when task_type='python').", + ) + args: tuple = Field( + default=(), + description="Positional arguments forwarded to the callable.", + ) + kwargs: dict = Field( + default_factory=dict, + description="Keyword arguments forwarded to the callable.", + ) + + # ── Shell task fields ─────────────────────────────────────────────── + command: Optional[str] = Field( + default=None, + description="Shell command to execute (required when task_type='shell').", + ) + working_dir: Optional[str] = Field( + default=None, + description="Working directory for the shell command.", + ) + stdout: Optional[str] = Field( + default=None, + description="Path to capture stdout (shell tasks).", + ) + stderr: Optional[str] = Field( + default=None, + description="Path to capture stderr (shell tasks).", + ) + + # ── Resource 
class ExecutionBackend(ABC):
    """Common contract for every workflow-manager adapter.

    Lifecycle
    ---------
    1. ``initialize(system, **kwargs)`` -- start the backend
    2. ``submit(task)`` / ``submit_batch(tasks)`` -- dispatch work
    3. ``shutdown()`` -- release resources

    Instances are also context managers: leaving the ``with`` block calls
    ``shutdown()``. Entering the block does *not* call ``initialize()``.
    """

    def __init__(self) -> None:
        # Concrete backends flip this once initialize() has succeeded.
        self._initialized: bool = False

    @property
    def is_async_remote(self) -> bool:
        """Whether work goes to a remote queue that may take minutes to hours.

        When ``True``, MCP tools should return right after submission and
        expose separate status/result retrieval tools instead of blocking.
        The base implementation answers ``False``.
        """
        return False

    @property
    def shares_filesystem(self) -> bool:
        """Whether workers see the same filesystem as the submitting server.

        When ``True`` (the default here), a path written by the server is
        readable by the worker, so file-transport tricks (inline embedding,
        ``/tmp`` re-materialisation) are unnecessary. Globus Compute
        overrides this to ``False`` because its workers run on a remote
        host without a shared filesystem.
        """
        return True

    @abstractmethod
    def initialize(self, system: str = "local", **kwargs: Any) -> None:
        """Prepare the backend for accepting work.

        Parameters
        ----------
        system : str
            Target HPC system name (e.g. ``"polaris"``, ``"aurora"``,
            ``"local"``). Backends may use this to load system-specific
            configurations.
        **kwargs
            Backend-specific options (worker_init, run_dir, etc.).
        """

    @abstractmethod
    def submit(self, task: TaskSpec) -> Future:
        """Submit one task and return a ``concurrent.futures.Future``.

        The future resolves to whatever the callable/command returns.
        """

    def submit_batch(self, tasks: list[TaskSpec]) -> list[Future]:
        """Submit several tasks and return their futures in submission order.

        This default just calls ``submit()`` once per task; backends may
        override it with an optimized batch path.
        """
        futures: list[Future] = []
        for spec in tasks:
            futures.append(self.submit(spec))
        return futures

    @abstractmethod
    def shutdown(self) -> None:
        """Release all resources held by the backend."""

    # ── Context-manager protocol ────────────────────────────────────────

    def __enter__(self) -> ExecutionBackend:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # noqa: ANN001
        # Always tear down; returning None lets any exception propagate.
        self.shutdown()
def get_backend(
    config_path: Optional[str] = None,
    backend_name: Optional[str] = None,
    system: Optional[str] = None,
    **kwargs: Any,
) -> ExecutionBackend:
    """Create and initialise an :class:`ExecutionBackend`.

    Resolution order for ``backend_name``:

    1. Explicit ``backend_name`` argument
    2. ``CHEMGRAPH_EXECUTION_BACKEND`` environment variable
    3. ``config.toml`` ``[execution] backend`` key
    4. ``"local"`` (safe fallback)

    Resolution order for ``system``:

    1. Explicit ``system`` argument
    2. ``COMPUTE_SYSTEM`` environment variable
    3. ``config.toml`` ``[execution] system`` key
    4. ``"local"``

    Backend-specific options come from the ``[execution.<backend>]`` table
    and are overridden by explicit ``**kwargs``.

    Parameters
    ----------
    config_path : str, optional
        Path to ``config.toml``. Auto-detected when omitted.
    backend_name : str, optional
        Force a specific backend.
    system : str, optional
        Target HPC system name.
    **kwargs
        Extra keyword arguments forwarded to
        :meth:`ExecutionBackend.initialize`. For a managed
        EnsembleLauncher backend, the options understood by
        ``get_launcher_config`` (``task_executor_name``,
        ``child_executor_policy``, ``policy_config``, ``checkpoint_dir``,
        ``mpi_flavour``) are used to build the launcher configuration
        instead.

    Returns
    -------
    ExecutionBackend
        A ready-to-use backend instance.

    Raises
    ------
    ValueError
        If the backend name is not in ``SUPPORTED_BACKENDS``, or if a
        managed EnsembleLauncher backend is requested for a system that has
        no registered system configuration.
    """
    cfg = _load_execution_config(config_path)

    # -- resolve backend name -------------------------------------------------
    resolved_backend = (
        backend_name
        or os.getenv("CHEMGRAPH_EXECUTION_BACKEND")
        or cfg.get("backend", "local")
    )
    resolved_backend = resolved_backend.lower().strip()

    if resolved_backend not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown execution backend '{resolved_backend}'. "
            f"Supported: {', '.join(SUPPORTED_BACKENDS)}"
        )

    # -- resolve system -------------------------------------------------------
    resolved_system = (
        system or os.getenv("COMPUTE_SYSTEM") or cfg.get("system", "local")
    )

    # -- merge backend-specific config ----------------------------------------
    backend_cfg = cfg.get(resolved_backend, {})
    merged_kwargs = {**backend_cfg, **kwargs}

    # Globus Compute: fall back to GLOBUS_COMPUTE_ENDPOINT_ID env var
    if resolved_backend == "globus_compute" and not merged_kwargs.get("endpoint_id"):
        env_id = os.getenv("GLOBUS_COMPUTE_ENDPOINT_ID")
        if env_id:
            merged_kwargs["endpoint_id"] = env_id

    # -- instantiate ----------------------------------------------------------
    logger.info(
        "Creating execution backend '%s' for system '%s'",
        resolved_backend,
        resolved_system,
    )

    if resolved_backend == "parsl":
        from chemgraph.execution.parsl_backend import ParslBackend

        backend = ParslBackend()

    elif resolved_backend == "ensemble_launcher":
        from chemgraph.execution.ensemble_launcher_backend import (
            SYSTEM_CONFIG_REGISTRY,
            EnsembleLauncherBackend,
            get_launcher_config,
        )

        backend = EnsembleLauncherBackend()

        # Client-only mode passes the options through as-is (it needs
        # checkpoint_dir); managed mode starts an orchestrator locally and
        # therefore needs a system config and a launcher config.
        if not merged_kwargs.get("client_only", False):
            # Explicit raise rather than ``assert`` so the check survives
            # ``python -O``.
            if resolved_system not in SYSTEM_CONFIG_REGISTRY:
                raise ValueError(
                    f"Unknown system {resolved_system}: "
                    f"only know {list(SYSTEM_CONFIG_REGISTRY.keys())}"
                )

            # Only these options are parameters of get_launcher_config();
            # anything else in the table (client_only, startup_delay,
            # node_id, ...) belongs to initialize() and would raise a
            # TypeError if splatted into the launcher-config builder.
            launcher_keys = (
                "task_executor_name",
                "child_executor_policy",
                "policy_config",
                "checkpoint_dir",
                "mpi_flavour",
            )
            launcher_cfg_kwargs = {
                key: value
                for key, value in merged_kwargs.items()
                if key in launcher_keys
            }
            init_kwargs = {
                key: value
                for key, value in merged_kwargs.items()
                if key not in launcher_keys
            }

            # System-appropriate MPI flavour: multi-node HPC systems need
            # mpich/hydra so child-spec JSON actually lands on remote /tmp;
            # "test" only works for single-host runs.
            if "mpi_flavour" not in launcher_cfg_kwargs:
                system_mpi_flavour = {
                    "aurora": "mpich",
                    "polaris": "mpich",
                    "crux": "mpich",
                    "local": "test",
                }
                launcher_cfg_kwargs["mpi_flavour"] = system_mpi_flavour.get(
                    resolved_system, "mpich"
                )

            init_kwargs["system_config"] = SYSTEM_CONFIG_REGISTRY[resolved_system]
            init_kwargs["launcher_config"] = get_launcher_config(
                **launcher_cfg_kwargs
            )
            merged_kwargs = init_kwargs

    elif resolved_backend == "globus_compute":
        from chemgraph.execution.globus_compute_backend import (
            GlobusComputeBackend,
        )

        backend = GlobusComputeBackend()

    elif resolved_backend == "local":
        from chemgraph.execution.local_backend import LocalBackend

        backend = LocalBackend()

    else:
        # Should be unreachable thanks to the validation above.
        raise ValueError(f"Unsupported backend: {resolved_backend}")

    backend.initialize(system=resolved_system, **merged_kwargs)
    return backend
+ + Environment variable overrides + ------------------------------ + ``GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID`` + ``GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID`` + ``GLOBUS_TRANSFER_DESTINATION_BASE_PATH`` + """ + cfg = _load_execution_config(config_path) + transfer_cfg = cfg.get("globus_transfer", {}) + merged = {**transfer_cfg, **kwargs} + + for key, env_var in ( + ("source_endpoint_id", "GLOBUS_TRANSFER_SOURCE_ENDPOINT_ID"), + ("destination_endpoint_id", "GLOBUS_TRANSFER_DESTINATION_ENDPOINT_ID"), + ("destination_base_path", "GLOBUS_TRANSFER_DESTINATION_BASE_PATH"), + ): + if not merged.get(key): + env_val = os.getenv(env_var) + if env_val: + merged[key] = env_val + + required = ( + "source_endpoint_id", + "destination_endpoint_id", + "destination_base_path", + ) + if not all(merged.get(k) for k in required): + logger.debug( + "Globus Transfer not configured (missing %s). " + "Transfer tools will not be registered.", + [k for k in required if not merged.get(k)], + ) + return None + + from chemgraph.execution.globus_transfer import GlobusTransferManager + + manager = GlobusTransferManager( + source_endpoint_id=merged["source_endpoint_id"], + destination_endpoint_id=merged["destination_endpoint_id"], + destination_base_path=merged["destination_base_path"], + source_base_path=merged.get("source_base_path"), + client_id=merged.get("client_id"), + ) + logger.info( + "GlobusTransferManager created: %s -> %s", + merged["source_endpoint_id"], + merged["destination_endpoint_id"], + ) + return manager diff --git a/src/chemgraph/execution/ensemble_launcher_backend.py b/src/chemgraph/execution/ensemble_launcher_backend.py new file mode 100644 index 00000000..14ec756f --- /dev/null +++ b/src/chemgraph/execution/ensemble_launcher_backend.py @@ -0,0 +1,405 @@ +"""EnsembleLauncher execution backend. + +Wraps `EnsembleLauncher `_ +to conform to the :class:`ExecutionBackend` interface. 
@contextlib.contextmanager
def _stdout_to_stderr():
    """Redirect this process's stdout fd to stderr for the duration.

    EnsembleLauncher (and its ``el stop`` helper) prints lifecycle notices
    such as "Sent SIGTERM to launcher process …" to stdout. Under a stdio
    MCP server stdout IS the JSON-RPC channel, so those lines corrupt the
    protocol stream and crash the client's message parser. Redirect at the
    fd level (not ``contextlib.redirect_stdout``) so in-process, library,
    and inherited-stdout subprocess writes are all caught, then restore.

    If either ``sys.stdout`` or ``sys.stderr`` is not backed by a real file
    descriptor (e.g. captured in tests/notebooks) the context manager is a
    no-op.
    """
    try:
        # Resolve both descriptors once, up front: a stream swapped in while
        # the block runs must not change which fd gets restored, and a
        # stderr without a real fd should degrade exactly like stdout does.
        stdout_fd = sys.stdout.fileno()
        stderr_fd = sys.stderr.fileno()
        saved_fd = os.dup(stdout_fd)
    except (OSError, ValueError, AttributeError):
        # Not real fds -- nothing to guard.
        yield
        return
    try:
        sys.stdout.flush()
        os.dup2(stderr_fd, stdout_fd)
        yield
    finally:
        try:
            # Push anything buffered during the block out to stderr before
            # the original target comes back.
            sys.stdout.flush()
        finally:
            # Restore and release even if the flush above failed, so stdout
            # never stays redirected and the saved fd never leaks.
            os.dup2(saved_fd, stdout_fd)
            os.close(saved_fd)
class EnsembleLauncherBackend(ExecutionBackend):
    """Execution backend that submits tasks through a :class:`ClusterClient`.

    Supports two initialization modes:

    **Client-only** — connect to a running EnsembleLauncher orchestrator::

        backend.initialize(client_only=True, checkpoint_dir="/path/to/running/el")

    **Managed** — start a local orchestrator, then connect::

        backend.initialize(system_config=..., launcher_config=...)

    In both modes the backend submits work through :class:`ClusterClient`.
    ``shutdown()`` tears down the client, runs the ``el stop`` helper, and
    waits for (or kills) an orchestrator process this backend started.
    """

    def __init__(self) -> None:
        _require_ensemble_launcher()
        super().__init__()
        self._orchestrator = None  # Popen handle of ``el start`` (managed mode)
        self._client = None

    def initialize(
        self,
        system: str = "local",
        *,
        client_only: bool = False,
        checkpoint_dir: Optional[str] = None,
        node_id: str = "global",
        system_config: Optional[SystemConfig] = None,
        launcher_config: Optional[LauncherConfig] = None,
        startup_delay: float = 10.0,
        **kwargs,
    ) -> None:
        """Prepare the backend for accepting work.

        Parameters
        ----------
        system : str
            Accepted for interface compatibility; not used by this backend.
        client_only : bool
            When ``True``, connect to a running orchestrator via
            *checkpoint_dir* — no orchestrator is started.
        checkpoint_dir : str
            Path to the orchestrator's checkpoint directory. Required
            when *client_only* is ``True``.
        node_id : str
            Orchestrator node to connect to (default ``"global"``).
        system_config, launcher_config
            Required for **managed** mode (``client_only=False``).
            The backend starts its own orchestrator with these.
        startup_delay : float
            Seconds to wait for the orchestrator to become ready
            (managed mode only).
        **kwargs
            Ignored; tolerated so shared configuration can be passed through.

        Raises
        ------
        ValueError
            If the arguments required by the selected mode are missing.
        RuntimeError
            If the ``el start`` process exits during *startup_delay*.
        """
        if client_only:
            # -- client-only mode ----------------------------------------------
            if checkpoint_dir is None:
                raise ValueError(
                    "client_only=True requires a checkpoint_dir pointing "
                    "to a running orchestrator."
                )
            self._client = ClusterClient(checkpoint_dir=checkpoint_dir, node_id=node_id)
            self._client.start()
            self._initialized = True
            logger.info(
                "EnsembleLauncherBackend initialized in client-only mode "
                "(checkpoint_dir='%s', node_id='%s')",
                checkpoint_dir,
                node_id,
            )
            return

        # -- managed mode: start orchestrator first ----------------------------
        if system_config is None or launcher_config is None:
            raise ValueError(
                "Managed mode requires system_config and launcher_config "
                "(or set client_only=True with a checkpoint_dir)."
            )
        os.makedirs(launcher_config.checkpoint_dir, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            launcher_config_fname = os.path.join(tmp_dir, "launcher_config.json")
            with open(launcher_config_fname, "w") as f:
                f.write(launcher_config.model_dump_json())
            system_config_fname = os.path.join(tmp_dir, "system_config.json")
            with open(system_config_fname, "w") as f:
                f.write(system_config.model_dump_json())
            # Start with an empty ensemble; tasks are submitted dynamically
            # through the ClusterClient afterwards.
            ensemble_fname = os.path.join(tmp_dir, "ensemble_file.json")
            with open(ensemble_fname, "w") as f:
                json.dump({"ensembles": {}}, f)

            cmd = [
                "el",
                "start",
                ensemble_fname,
                "--system-config-file",
                system_config_fname,
                "--launcher-config-file",
                launcher_config_fname,
            ]
            logger.info("Executing %s", cmd)
            self._orchestrator = subprocess.Popen(
                cmd,
                stderr=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stdin=subprocess.DEVNULL,
            )
            # Sleep inside the ``with`` so the temporary config files still
            # exist while the orchestrator is starting up and reading them.
            time.sleep(startup_delay)

        exit_code = self._orchestrator.poll()
        if exit_code is not None:
            logger.error("Starting el failed with error code: %s", exit_code)
            raise RuntimeError(
                f"'el start' exited with code {exit_code} before the "
                "orchestrator became ready."
            )

        self._client = ClusterClient(
            checkpoint_dir=launcher_config.checkpoint_dir,
            node_id=node_id,
        )
        self._client.start()
        self._initialized = True
        logger.info(
            "EnsembleLauncherBackend initialized in managed mode "
            "(system='%s', comm='%s', executor='%s')",
            system_config.name,
            launcher_config.comm_name,
            launcher_config.task_executor_name,
        )

    def submit(self, task: TaskSpec) -> Future:
        """Translate *task* into an EnsembleLauncher task and submit it.

        Raises
        ------
        RuntimeError
            If the backend has not been initialized.
        ValueError
            If the task lacks the callable/command its ``task_type``
            requires, or the ``task_type`` is unsupported.
        """
        if not self._initialized or self._client is None:
            raise RuntimeError(
                "EnsembleLauncherBackend is not initialized. Call initialize() first."
            )

        from ensemble_launcher.ensemble import Task as ELTask

        if task.task_type == "python":
            if task.callable is None:
                raise ValueError(
                    f"Task '{task.task_id}': task_type='python' requires a callable."
                )
            el_task = ELTask(
                task_id=task.task_id,
                nnodes=task.num_nodes,
                ppn=task.processes_per_node,
                executable=task.callable,
                args=task.args or (),
                kwargs=task.kwargs or {},
                env=task.env,
            )
            return self._client.submit(el_task)

        if task.task_type == "shell":
            if task.command is None:
                raise ValueError(
                    f"Task '{task.task_id}': task_type='shell' requires a command."
                )
            el_task = ELTask(
                task_id=task.task_id,
                nnodes=task.num_nodes,
                ppn=task.processes_per_node,
                executable=task.command,
                env=task.env,
            )
            return self._client.submit(el_task)

        raise ValueError(
            f"Task '{task.task_id}': unsupported task_type '{task.task_type}'."
        )

    def shutdown(self) -> None:
        """Tear down the client and stop the orchestrator; never raises.

        Failures are logged, and the handle that could not be released is
        kept so a later ``shutdown()`` call can retry it.
        """
        self._initialized = False
        # EnsembleLauncher (in-process teardown and the `el stop` helper)
        # prints lifecycle notices to stdout; guard the fd so they don't
        # corrupt a stdio MCP server's JSON-RPC channel.
        with _stdout_to_stderr():
            client_ok = True
            if self._client is not None:
                try:
                    self._client.teardown()
                    self._client = None
                except Exception:
                    client_ok = False
                    logger.warning(
                        "Error tearing down EnsembleLauncher client.", exc_info=True
                    )

            # NOTE(review): `el stop` is issued regardless of mode, as before
            # -- confirm whether client-only backends should skip it.
            try:
                stopper = subprocess.Popen(["el", "stop"])
                stopper.wait(timeout=10.0)
            except Exception:
                # Best effort: a missing or slow helper must not abort the
                # rest of the teardown.
                logger.debug("'el stop' did not complete cleanly.", exc_info=True)

            orchestrator_ok = True
            if self._orchestrator is not None:
                try:
                    self._orchestrator.wait(timeout=10.0)
                except subprocess.TimeoutExpired:
                    logger.warning(
                        "EnsembleLauncher orchestrator did not exit in time; killing it."
                    )
                    try:
                        self._orchestrator.kill()
                        self._orchestrator.wait(timeout=10.0)
                    except Exception:
                        orchestrator_ok = False
                        logger.warning(
                            "Could not kill EnsembleLauncher orchestrator.",
                            exc_info=True,
                        )
                except Exception:
                    orchestrator_ok = False
                    logger.warning(
                        "Error waiting for EnsembleLauncher orchestrator.",
                        exc_info=True,
                    )
                if orchestrator_ok:
                    self._orchestrator = None

        if client_ok and orchestrator_ok:
            logger.info("EnsembleLauncherBackend shut down.")
        else:
            logger.warning(
                "EnsembleLauncherBackend partially shut down. "
                "Call shutdown() again to retry failed teardown."
            )
+ + Avoids importing ``ensemble_launcher`` at module load time. + """ + + def __contains__(self, key: str) -> bool: + return key in _SYSTEM_CONFIG_BUILDERS + + def __getitem__(self, key: str): + if key not in _SYSTEM_CONFIG_BUILDERS: + raise KeyError(key) + return _SYSTEM_CONFIG_BUILDERS[key]() + + def keys(self): + return _SYSTEM_CONFIG_BUILDERS.keys() + + +SYSTEM_CONFIG_REGISTRY = _LazyRegistry() diff --git a/src/chemgraph/execution/globus_compute_backend.py b/src/chemgraph/execution/globus_compute_backend.py new file mode 100644 index 00000000..48cf8143 --- /dev/null +++ b/src/chemgraph/execution/globus_compute_backend.py @@ -0,0 +1,218 @@ +"""Globus Compute execution backend. + +Wraps the `Globus Compute SDK `_ +to conform to the :class:`ExecutionBackend` interface. Python tasks are +dispatched via :meth:`Executor.submit` and shell tasks via +:class:`ShellFunction`. + +Unlike the Parsl and EnsembleLauncher backends, Globus Compute does **not** +require an active PBS/Slurm allocation at submit time. A persistent +Globus Compute *endpoint* daemon running on the HPC login node +automatically provisions and manages batch jobs as tasks arrive. + +**Prerequisites** + +1. Install the SDK: ``pip install chemgraphagent[globus_compute]`` +2. On the HPC system, configure and start an endpoint:: + + globus-compute-endpoint configure chemgraph-polaris + globus-compute-endpoint start chemgraph-polaris + # -> prints the endpoint UUID + +3. Set ``endpoint_id`` in ``config.toml`` or pass it to + :func:`~chemgraph.execution.config.get_backend`. +""" + +from __future__ import annotations + +import logging +from concurrent.futures import Future +from typing import Any + +from chemgraph.execution.base import ExecutionBackend, TaskSpec + +logger = logging.getLogger(__name__) + + +class GlobusComputeBackend(ExecutionBackend): + """Execution backend that delegates work to Globus Compute. 
+ + Configuration + ------------- + The following ``kwargs`` are accepted by :meth:`initialize`: + + ``endpoint_id`` : str **required** + UUID of the Globus Compute endpoint to submit tasks to. + ``amqp_port`` : int, optional + Port for the AMQP result-streaming connection. Defaults to the + SDK default (5671). Set to ``443`` if outbound 5671 is blocked. + """ + + def __init__(self) -> None: + super().__init__() + self._executor = None + self._endpoint_id: str | None = None + self._shares_filesystem = False + + @property + def is_async_remote(self) -> bool: + return True + + @property + def shares_filesystem(self) -> bool: + return self._shares_filesystem + + # ── lifecycle ──────────────────────────────────────────────────────── + + def initialize(self, system: str = "local", **kwargs: Any) -> None: + try: + from globus_compute_sdk import Executor + except ImportError as exc: + raise ImportError( + "globus-compute-sdk is required for the GlobusComputeBackend. " + "Install it with: pip install chemgraphagent[globus_compute]" + ) from exc + + endpoint_id = kwargs.get("endpoint_id") + if not endpoint_id: + raise ValueError( + "GlobusComputeBackend requires an 'endpoint_id'. " + "Set it in config.toml under [execution.globus_compute] " + "or pass it directly to get_backend()." + ) + + executor_kwargs: dict[str, Any] = {"endpoint_id": endpoint_id} + + amqp_port = kwargs.get("amqp_port") + if amqp_port is not None: + executor_kwargs["amqp_port"] = int(amqp_port) + + # Opt-in: a Globus Compute endpoint that shares an HPC filesystem with + # the MCP server can skip inline file embedding and read paths directly. 
+ self._shares_filesystem = bool(kwargs.get("shares_filesystem", False)) + + self._endpoint_id = endpoint_id + self._executor = Executor(**executor_kwargs) + self._initialized = True + logger.info( + "GlobusComputeBackend initialized (system='%s', endpoint='%s')", + system, + endpoint_id, + ) + + # ── task submission ───────────────────────────────────────────────── + + def _ensure_executor(self) -> None: + """Re-create the Executor if it was shut down (e.g. after a + task failure).""" + from globus_compute_sdk import Executor + + if self._executor is None or getattr(self._executor, "_stopped", False): + logger.info("Re-creating Globus Compute Executor") + self._executor = Executor(endpoint_id=self._endpoint_id) + + @staticmethod + def _looks_like_stopped_executor(exc: BaseException) -> bool: + """Heuristic: did a submit fail because the Executor is shut down? + + The SDK does not expose a stable exception type for this state; + we match on common substrings observed in practice. + """ + msg = str(exc).lower() + return "shut down" in msg or "stopped" in msg or "closed" in msg + + def submit(self, task: TaskSpec) -> Future: + if not self._initialized or self._executor is None: + raise RuntimeError( + "GlobusComputeBackend is not initialized. Call initialize() first." + ) + + self._ensure_executor() + + if task.task_type == "python": + if task.callable is None: + raise ValueError( + f"Task '{task.task_id}': task_type='python' requires a callable." + ) + # Executor.submit() returns a ComputeFuture (a + # concurrent.futures.Future subclass), fully compatible + # with asyncio.wrap_future() used by gather_futures(). 
+ try: + return self._executor.submit(task.callable, *task.args, **task.kwargs) + except Exception as exc: + if not self._looks_like_stopped_executor(exc): + raise + logger.warning( + "Submit raised %s -- rebuilding Globus Compute Executor " + "and retrying once.", + type(exc).__name__, + ) + self._executor = None + self._ensure_executor() + return self._executor.submit(task.callable, *task.args, **task.kwargs) + + elif task.task_type == "shell": + if task.command is None: + raise ValueError( + f"Task '{task.task_id}': task_type='shell' requires a command." + ) + from globus_compute_sdk import ShellFunction + + shell_fn = ShellFunction(task.command) + try: + return self._executor.submit(shell_fn) + except Exception as exc: + if not self._looks_like_stopped_executor(exc): + raise + logger.warning( + "Submit raised %s -- rebuilding Globus Compute Executor " + "and retrying once.", + type(exc).__name__, + ) + self._executor = None + self._ensure_executor() + return self._executor.submit(shell_fn) + + else: + raise ValueError( + f"Task '{task.task_id}': unsupported task_type '{task.task_type}'." + ) + + # ── health check ──────────────────────────────────────────────────── + + def check_endpoint_status(self) -> dict: + """Check the status of the configured Globus Compute endpoint. + + Returns a dict with ``endpoint_id`` and ``status`` fields. + Useful as a pre-flight check before submitting tasks. 
def shutdown(self) -> None:
    """Stop the Executor (best effort) and mark the backend uninitialized.

    An error raised by ``Executor.shutdown()`` is logged with its
    traceback and otherwise ignored, so this method does not propagate
    Executor shutdown failures.
    """
    executor = self._executor
    if executor is not None:
        try:
            executor.shutdown()
        except Exception:
            logger.warning("Error during Globus Compute shutdown.", exc_info=True)
        else:
            logger.info("GlobusComputeBackend shut down.")
    # Reset state whether or not the Executor stopped cleanly.
    self._executor = None
    self._initialized = False
class GlobusTransferManager:
    """Manage file transfers between local and remote Globus collections.

    Parameters
    ----------
    source_endpoint_id : str
        UUID of the Globus collection on the submitting machine.
    destination_endpoint_id : str
        UUID of the Globus collection on the HPC system.
    destination_base_path : str
        Root directory on the destination where staged files are placed.
        Each transfer batch creates a subdirectory underneath.
    source_base_path : str, optional
        If provided, local paths are resolved relative to this directory.
    client_id : str, optional
        Globus app client ID for OAuth.  Defaults to the Globus Tutorial
        client.
    """

    # Key under which the token responses carry the Transfer tokens.
    _RESOURCE_SERVER = "transfer.api.globus.org"
    # Task states at which wait_for_transfer() stops polling.
    _TERMINAL_STATUSES = ("SUCCEEDED", "FAILED")

    def __init__(
        self,
        source_endpoint_id: str,
        destination_endpoint_id: str,
        destination_base_path: str,
        source_base_path: Optional[str] = None,
        client_id: Optional[str] = None,
    ) -> None:
        self.source_endpoint_id = source_endpoint_id
        self.destination_endpoint_id = destination_endpoint_id
        # Stored without trailing slashes so remote paths join cleanly.
        self.destination_base_path = destination_base_path.rstrip("/")
        self.source_base_path = source_base_path
        self._client_id = client_id or _DEFAULT_CLIENT_ID
        self._transfer_client = None

    # ── authentication ──────────────────────────────────────────────────

    def _get_transfer_client(self):
        """Lazily create an authenticated ``TransferClient``.

        Cached tokens are read from
        ``~/.globus/chemgraph_transfer_tokens.json``.  With no cached
        tokens an interactive login is started (URL printed, authorization
        code read from stdin).  Expired cached tokens are refreshed with
        the stored refresh token; if the refresh fails the stale access
        token is used and the failure is logged with its traceback.

        Raises
        ------
        ImportError
            If ``globus_sdk`` is not installed.
        """
        if self._transfer_client is not None:
            return self._transfer_client

        try:
            import globus_sdk
        except ImportError as exc:
            raise ImportError(
                "globus_sdk is required for Globus Transfer. "
                "Install it with: pip install globus-sdk"
            ) from exc

        client = globus_sdk.NativeAppAuthClient(self._client_id)
        client.oauth2_start_flow(
            requested_scopes=TRANSFER_SCOPE,
            refresh_tokens=True,
        )

        # Try loading cached tokens first
        token_file = Path.home() / ".globus" / "chemgraph_transfer_tokens.json"
        tokens = self._load_tokens(token_file)

        if tokens is None:
            # Interactive login required
            authorize_url = client.oauth2_get_authorize_url()
            logger.info(
                "Globus Transfer authentication required.\n"
                "Go to this URL and login:\n  %s",
                authorize_url,
            )
            print(
                "\nGlobus Transfer authentication required.\n"
                f"Go to this URL and login:\n  {authorize_url}\n"
            )
            auth_code = input("Enter the authorization code: ").strip()
            token_response = client.oauth2_exchange_code_for_tokens(auth_code)
            tokens = token_response.by_resource_server[self._RESOURCE_SERVER]
            self._save_tokens(token_file, tokens)
        elif tokens.get("expires_at_seconds", 0) < time.time():
            # Refresh if expired.  The auth client's refresh call takes the
            # refresh-token string itself; the previous call used a method
            # name the client does not provide, so every refresh ended in
            # the fallback below without anyone noticing.
            try:
                token_response = client.oauth2_refresh_token(
                    tokens["refresh_token"]
                )
                tokens = token_response.by_resource_server[self._RESOURCE_SERVER]
                self._save_tokens(token_file, tokens)
            except Exception:
                logger.warning(
                    "Token refresh failed, falling back to existing token.",
                    exc_info=True,
                )

        authorizer = globus_sdk.AccessTokenAuthorizer(tokens["access_token"])
        self._transfer_client = globus_sdk.TransferClient(authorizer=authorizer)
        return self._transfer_client

    @staticmethod
    def _load_tokens(path: Path) -> Optional[dict]:
        """Return the token dict cached at *path*, or ``None``.

        ``None`` is returned when the file is missing, unreadable, not
        valid JSON, or does not contain a JSON object.
        """
        if not path.is_file():
            return None
        import json

        try:
            with open(path) as f:
                tokens = json.load(f)
        except (json.JSONDecodeError, KeyError, OSError):
            return None
        # Callers use dict methods on the result; anything else is treated
        # as "no cache" so a fresh login is triggered.
        return tokens if isinstance(tokens, dict) else None

    @staticmethod
    def _save_tokens(path: Path, tokens: dict) -> None:
        """Write *tokens* as JSON to *path*, readable by the owner only."""
        import json
        import os

        path.parent.mkdir(parents=True, exist_ok=True)
        # Create the file with mode 0600 up front so the tokens are never
        # on disk with wider, umask-derived permissions.
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            json.dump(dict(tokens), f, indent=2)
        # The creation mode is ignored for a file that already existed.
        path.chmod(0o600)

    # ── transfers ───────────────────────────────────────────────────────

    @staticmethod
    def _assign_remote_names(basenames: list[str]) -> list[str]:
        """Return one unique remote file name per entry of *basenames*.

        The first occurrence of a name is kept; later duplicates get
        ``_1``, ``_2``, ... inserted before the first ``.``.  Generated
        names are checked against every name already handed out, so a
        duplicate of ``in.cif`` cannot collide with an input that is
        itself called ``in_1.cif``.
        """
        taken: set[str] = set()
        next_index: dict[str, int] = {}
        names: list[str] = []
        for base in basenames:
            candidate = base
            if candidate in taken:
                stem, dot, suffix = base.partition(".")
                index = next_index.get(base, 1)
                while True:
                    candidate = (
                        f"{stem}_{index}.{suffix}" if dot else f"{stem}_{index}"
                    )
                    index += 1
                    if candidate not in taken:
                        break
                next_index[base] = index
            taken.add(candidate)
            names.append(candidate)
        return names

    def transfer_files(
        self,
        local_paths: list[str],
        remote_subdir: Optional[str] = None,
        label: Optional[str] = None,
    ) -> TransferResult:
        """Submit a Globus Transfer task to stage files on the remote endpoint.

        Parameters
        ----------
        local_paths : list[str]
            Paths to local files to transfer; each is resolved to an
            absolute path before submission.
        remote_subdir : str, optional
            Subdirectory name under ``destination_base_path``.  A UUID-based
            name is generated if omitted.
        label : str, optional
            Human-readable label for the transfer task.

        Returns
        -------
        TransferResult
            Metadata including the Globus task ID and local-to-remote
            path mapping.
        """
        import globus_sdk

        tc = self._get_transfer_client()

        if remote_subdir is None:
            remote_subdir = f"batch_{uuid.uuid4().hex[:12]}"

        remote_dir = f"{self.destination_base_path}/{remote_subdir}"
        transfer_label = label or f"ChemGraph file staging ({remote_subdir})"

        tdata = globus_sdk.TransferData(
            tc,
            self.source_endpoint_id,
            self.destination_endpoint_id,
            label=transfer_label,
            sync_level="checksum",
        )

        # Disambiguate same-basename inputs (e.g. /a/in.cif and /b/in.cif);
        # otherwise two items would target the same destination path.
        resolved = [Path(local_path).resolve() for local_path in local_paths]
        remote_names = self._assign_remote_names([p.name for p in resolved])

        file_mapping: dict[str, str] = {}
        for p, remote_name in zip(resolved, remote_names):
            remote_path = f"{remote_dir}/{remote_name}"
            tdata.add_item(str(p), remote_path)
            file_mapping[str(p)] = remote_path

        result = tc.submit_transfer(tdata)
        task_id = result["task_id"]

        logger.info(
            "Globus Transfer submitted: task_id=%s, %d files -> %s",
            task_id,
            len(local_paths),
            remote_dir,
        )

        return TransferResult(
            task_id=task_id,
            source_endpoint_id=self.source_endpoint_id,
            destination_endpoint_id=self.destination_endpoint_id,
            file_mapping=file_mapping,
            remote_directory=remote_dir,
            label=transfer_label,
        )

    def check_transfer_status(self, task_id: str) -> dict[str, Any]:
        """Check the status of a Globus Transfer task.

        Returns
        -------
        dict
            Keys: ``task_id``, ``status``, ``nice_status``, ``bytes_transferred``,
            ``files``, ``files_transferred``.
        """
        tc = self._get_transfer_client()
        task = tc.get_task(task_id)
        return {
            "task_id": task_id,
            "status": task["status"],
            "nice_status": task.get("nice_status", ""),
            "bytes_transferred": task.get("bytes_transferred", 0),
            "files": task.get("files", 0),
            "files_transferred": task.get("files_transferred", 0),
        }

    def wait_for_transfer(
        self,
        task_id: str,
        timeout: float = 300,
        poll_interval: float = 5,
    ) -> dict[str, Any]:
        """Block until a transfer completes, fails, or times out.

        The status is always checked at least once, even with
        ``timeout=0``.

        Parameters
        ----------
        timeout : float
            Maximum seconds to wait (default 300).
        poll_interval : float
            Seconds between status checks (default 5).

        Returns
        -------
        dict
            Final transfer status.  ``timed_out=True`` is added only when
            the deadline passed while the task was still in a
            non-terminal state.
        """
        # Monotonic clock: the deadline must not move when the wall clock
        # is adjusted during the wait.
        deadline = time.monotonic() + timeout
        while True:
            status = self.check_transfer_status(task_id)
            if status["status"] in self._TERMINAL_STATUSES:
                return status
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                status["timed_out"] = True
                return status
            # Never sleep past the deadline.
            time.sleep(min(poll_interval, remaining))

    def list_remote_directory(self, path: str) -> list[dict[str, Any]]:
        """List files in a directory on the destination endpoint.

        Returns
        -------
        list[dict]
            Each dict has ``name``, ``type`` ("file" or "dir"), and ``size``.
        """
        tc = self._get_transfer_client()
        return [
            {
                "name": entry["name"],
                "type": entry["type"],
                "size": entry.get("size", 0),
            }
            for entry in tc.operation_ls(self.destination_endpoint_id, path=path)
        ]

    def get_remote_path(
        self,
        local_path: str,
        remote_subdir: Optional[str] = None,
    ) -> str:
        """Compute the remote path for a local file.

        Only the base name of *local_path* is used; the duplicate-name
        suffixes applied by :meth:`transfer_files` are not reflected here.
        """
        filename = Path(local_path).name
        if remote_subdir:
            return f"{self.destination_base_path}/{remote_subdir}/{filename}"
        return f"{self.destination_base_path}/{filename}"
class JobTracker:
    """Track submitted job batches and their futures.

    Batch lookup and mutation happen under an internal lock.  Persisting
    to disk is done after the lock is released, because :meth:`_save`
    acquires the same (non-reentrant) lock itself.

    Parameters
    ----------
    persist_file : Path or str, optional
        Path to a JSON file for persisting batch metadata across
        sessions.  When set, batches are saved after registration and
        after results are cached.  On init, existing batches are loaded.
    """

    def __init__(self, persist_file: Optional[Path | str] = None) -> None:
        self._batches: dict[str, TrackedBatch] = {}
        self._lock = threading.Lock()
        self._gc_lock = threading.Lock()
        self._persist_file = Path(persist_file) if persist_file else None
        self._gc_client = None  # lazily initialised Globus Compute Client

        if self._persist_file is not None:
            self._load()

    # ── Globus Compute client (lazy) ──────────────────────────────────

    def _get_gc_client(self):
        """Return a Globus Compute ``Client`` (created once, reused).

        Returns ``None`` (after logging a warning) when the client cannot
        be created, e.g. because the SDK is not installed.
        """
        if self._gc_client is not None:
            return self._gc_client
        with self._gc_lock:
            # Re-check under the lock: another thread may have created the
            # client while this one was waiting.
            if self._gc_client is None:
                try:
                    from globus_compute_sdk import Client

                    self._gc_client = Client()
                except Exception:
                    logger.warning(
                        "Could not create Globus Compute Client",
                        exc_info=True,
                    )
                    return None
        return self._gc_client

    # ── persistence ───────────────────────────────────────────────────

    def _save(self) -> None:
        """Write current batch metadata to *persist_file*.

        Futures and ``post_fn`` are not persisted.  Must be called
        without ``self._lock`` held.
        """
        if self._persist_file is None:
            return

        with self._lock:
            data: dict[str, Any] = {
                bid: {
                    "tool_name": batch.tool_name,
                    "submitted_at": batch.submitted_at.isoformat(),
                    "tasks": [
                        {
                            "task_id": t.task_id,
                            "meta": t.meta,
                            "globus_task_id": t.globus_task_id,
                            "result": t.result,
                        }
                        for t in batch.tasks
                    ],
                }
                for bid, batch in self._batches.items()
            }

        # Raw task results are arbitrary objects.  default=str is a
        # deliberate choice: a non-JSON value is stored as its string form
        # instead of raising TypeError out of get_status().  Serializing
        # before opening the file also means a failure can never leave a
        # truncated temp file behind.
        payload = json.dumps(data, indent=2, default=str)

        self._persist_file.parent.mkdir(parents=True, exist_ok=True)
        tmp = self._persist_file.with_suffix(".tmp")
        with open(tmp, "w") as f:
            f.write(payload)
        # Replace in one step so readers never see a half-written file.
        tmp.replace(self._persist_file)

    def _load(self) -> None:
        """Load batch metadata from *persist_file* (if it exists).

        Unreadable files and malformed batch records are skipped with a
        warning instead of failing construction of the tracker.
        """
        if self._persist_file is None or not self._persist_file.is_file():
            return

        try:
            with open(self._persist_file) as f:
                data = json.load(f)
        except (json.JSONDecodeError, OSError) as exc:
            logger.warning("Could not load job tracker state: %s", exc)
            return

        if not isinstance(data, dict):
            logger.warning(
                "Could not load job tracker state: %s",
                "top-level JSON value is not an object",
            )
            return

        orphaned: list[tuple[str, str]] = []  # (batch_id, task_id)
        loaded = 0
        with self._lock:
            for bid, info in data.items():
                if bid in self._batches:
                    continue  # don't overwrite live batches

                try:
                    tasks = [
                        TrackedTask(
                            task_id=t["task_id"],
                            meta=t.get("meta", {}),
                            future=None,
                            globus_task_id=t.get("globus_task_id"),
                            result=t.get("result"),
                        )
                        for t in info.get("tasks", [])
                    ]
                    batch = TrackedBatch(
                        batch_id=bid,
                        tool_name=info["tool_name"],
                        submitted_at=datetime.fromisoformat(info["submitted_at"]),
                        tasks=tasks,
                    )
                except (AttributeError, KeyError, TypeError, ValueError) as exc:
                    logger.warning(
                        "Skipping malformed batch '%s' in %s: %r",
                        bid,
                        self._persist_file,
                        exc,
                    )
                    continue

                # Tasks loaded from disk with no globus_task_id and no
                # cached result are orphaned: get_status() has neither a
                # future nor a Globus task to query for them.
                orphaned.extend(
                    (bid, t.task_id)
                    for t in tasks
                    if t.globus_task_id is None and t.result is None
                )
                self._batches[bid] = batch
                loaded += 1

        logger.info("Loaded %d batches from %s", loaded, self._persist_file)
        if orphaned:
            logger.warning(
                "%d task(s) reloaded without a Globus task_id -- their "
                "results cannot be recovered. Examples: %s",
                len(orphaned),
                ", ".join(f"{b}/{t}" for b, t in orphaned[:5]),
            )

    # ── registration ───────────────────────────────────────────────────

    def register_batch(
        self,
        tool_name: str,
        pending_tasks: list[tuple[dict, Future]],
        post_fn: Optional[Callable[[dict, Any], dict]] = None,
    ) -> str:
        """Register a batch of submitted tasks and return a batch ID.

        Parameters
        ----------
        tool_name : str
            Name of the MCP tool that submitted the batch.
        pending_tasks : list[tuple[dict, Future]]
            Each element is ``(metadata_dict, future)``.
        post_fn : callable, optional
            Post-processing function applied when collecting results.
            Called as ``post_fn(metadata, raw_result) -> dict``.

        Returns
        -------
        str
            A 12-character hexadecimal batch identifier.
        """
        batch_id = uuid.uuid4().hex[:12]
        tracked = [
            TrackedTask(
                task_id=meta.get("task_id", meta.get("structure", f"task_{i}")),
                meta=meta,
                future=fut,
            )
            for i, (meta, fut) in enumerate(pending_tasks)
        ]
        batch = TrackedBatch(
            batch_id=batch_id,
            tool_name=tool_name,
            submitted_at=datetime.now(timezone.utc),
            tasks=tracked,
            post_fn=post_fn,
        )
        with self._lock:
            self._batches[batch_id] = batch

        logger.info(
            "Registered batch '%s' (%s) with %d tasks",
            batch_id,
            tool_name,
            len(tracked),
        )

        # Wait briefly for the Executor background thread to set task_ids
        # on the ComputeFutures.  Capped at 3 s so the MCP tool response
        # isn't delayed excessively.
        self._wait_for_globus_task_ids(tracked, timeout=3.0)
        self._save()
        return batch_id

    def _wait_for_globus_task_ids(
        self, tasks: list[TrackedTask], timeout: float = 3.0
    ) -> None:
        """Wait up to *timeout* seconds for Globus ``task_id`` to appear
        on each ComputeFuture, then store them for persistence.

        Futures without a ``task_id`` attribute at all can never receive
        one, so they are not waited for.
        """
        deadline = time.monotonic() + timeout
        pending = [
            t
            for t in tasks
            if t.future is not None
            and t.globus_task_id is None
            # Without this filter every batch of plain futures stalled for
            # the full timeout and logged a misleading warning.
            and hasattr(t.future, "task_id")
        ]

        while pending and time.monotonic() < deadline:
            still_pending = []
            for t in pending:
                gc_id = getattr(t.future, "task_id", None)
                if gc_id is not None:
                    t.globus_task_id = str(gc_id)
                else:
                    still_pending.append(t)
            pending = still_pending
            if pending:
                time.sleep(0.25)

        if pending:
            # Warning rather than debug: tasks without a task_id at this
            # point are lost across a server restart.
            logger.warning(
                "%d task(s) did not receive a Globus task_id within %.1fs; "
                "they will be unrecoverable if the server restarts before "
                "the next get_status call",
                len(pending),
                timeout,
            )

    def _try_capture_globus_task_ids(self, tasks: list[TrackedTask]) -> bool:
        """Non-blocking: extract ``task_id`` from any ComputeFuture that
        has one available.  Returns True if any new IDs were captured."""
        captured = False
        for t in tasks:
            if t.globus_task_id is None and t.future is not None:
                gc_id = getattr(t.future, "task_id", None)
                if gc_id is not None:
                    t.globus_task_id = str(gc_id)
                    captured = True
        return captured

    # ── result helpers ─────────────────────────────────────────────────

    @staticmethod
    def _merge_raw_result(meta: dict, raw: Any) -> dict:
        """Combine task metadata with a raw task result.

        A dict result is merged over *meta* and gets ``status="success"``
        unless it already has a status; any other value is stored under
        ``"result"`` with ``status="success"``.
        """
        if isinstance(raw, dict):
            merged = {**meta, **raw}
            merged.setdefault("status", "success")
            return merged
        return {**meta, "result": raw, "status": "success"}

    def _result_from_future(self, batch: TrackedBatch, task: TrackedTask) -> dict:
        """Build the cached result dict for a finished live future.

        Any exception (from the task itself or from ``post_fn``) becomes
        a ``status="failure"`` record.
        """
        try:
            raw = task.future.result(timeout=0)
            if batch.post_fn is not None:
                return batch.post_fn(task.meta, raw)
            return self._merge_raw_result(task.meta, raw)
        except Exception as e:
            return {
                **task.meta,
                "status": "failure",
                "error_type": type(e).__name__,
                "message": str(e),
            }

    def _poll_globus_task(self, task: TrackedTask) -> bool:
        """Query Globus Compute for a task that has no live future.

        Caches the result on *task* when one is reported and returns
        True when Globus no longer reports the task as pending.
        ``post_fn`` is not applied on this path.
        """
        gc = self._get_gc_client()
        if gc is None:
            return False
        try:
            task_info = gc.get_task(task.globus_task_id)
            if task_info.get("pending", True):
                return False
            if "result" in task_info:
                task.result = self._merge_raw_result(task.meta, task_info["result"])
            elif "exception" in task_info:
                task.result = {
                    **task.meta,
                    "status": "failure",
                    "error_type": "RemoteException",
                    "message": str(task_info["exception"]),
                }
            return True
        except Exception as e:
            logger.warning(
                "Failed to query Globus task %s: %s",
                task.globus_task_id,
                e,
                exc_info=True,
            )
            return False

    # ── status ─────────────────────────────────────────────────────────

    def get_status(self, batch_id: str) -> dict:
        """Return the current status of a batch.

        For tasks loaded from disk (no in-memory ``Future``), queries
        Globus Compute directly if a ``globus_task_id`` is available.
        ``post_fn`` runs while the tracker lock is held, so it must not
        call back into this tracker.

        Returns
        -------
        dict
            Keys: ``batch_id``, ``tool_name``, ``submitted_at``,
            ``status``, ``total_tasks``, ``completed_tasks``,
            ``failed_tasks``, ``pending_tasks``, ``progress_pct``; or a
            single ``error`` key for an unknown batch.
        """
        done = 0
        failed = 0
        with self._lock:
            batch = self._batches.get(batch_id)
            if batch is None:
                return {"error": f"Unknown batch_id: '{batch_id}'"}

            total = len(batch.tasks)
            # Lazily capture Globus Compute task UUIDs (set asynchronously
            # by the Executor background thread after submission).
            dirty = self._try_capture_globus_task_ids(batch.tasks)

            for t in batch.tasks:
                task_done = False

                # --- live future path ---
                if t.future is not None and t.future.done():
                    task_done = True
                    if t.result is None:
                        t.result = self._result_from_future(batch, t)
                        dirty = True

                # --- loaded-from-disk path (no future, use Globus client) ---
                elif t.future is None and t.result is None and t.globus_task_id:
                    if self._poll_globus_task(t):
                        task_done = True
                        dirty = True

                # --- already have a cached result ---
                elif t.result is not None:
                    task_done = True

                if task_done:
                    done += 1
                    if t.result is not None and t.result.get("status") == "failure":
                        failed += 1

        # _save() takes self._lock itself, so it must run after release.
        if dirty:
            self._save()

        pending = total - done
        if pending == total:
            status = "pending"
        elif pending > 0:
            status = "running"
        elif failed == total:
            status = "failed"
        elif failed > 0:
            status = "partial"
        else:
            status = "completed"

        return {
            "batch_id": batch_id,
            "tool_name": batch.tool_name,
            "submitted_at": batch.submitted_at.isoformat(),
            "status": status,
            "total_tasks": total,
            "completed_tasks": done - failed,
            "failed_tasks": failed,
            "pending_tasks": pending,
            "progress_pct": round(done / total * 100, 1) if total else 0.0,
        }

    # ── results ────────────────────────────────────────────────────────

    def get_results(
        self, batch_id: str, include_partial: bool = False
    ) -> dict:
        """Collect results from a batch.

        Parameters
        ----------
        batch_id : str
            The batch identifier.
        include_partial : bool
            If ``True``, return results for completed tasks even if some
            are still pending.  If ``False`` (default) and the batch is
            not fully resolved, return a status message instead.

        Returns
        -------
        dict
            Contains ``status``, ``results`` list, and summary counts.
        """
        status_info = self.get_status(batch_id)
        if "error" in status_info:
            return status_info

        with self._lock:
            batch = self._batches.get(batch_id)
            if batch is None:
                return {"error": f"Unknown batch_id: '{batch_id}'"}

            if not include_partial and status_info["pending_tasks"] > 0:
                return {
                    **status_info,
                    "message": (
                        f"{status_info['pending_tasks']} of "
                        f"{status_info['total_tasks']} tasks still pending. "
                        f"Call check_job_status('{batch_id}') to monitor, "
                        f"or use include_partial=True to get partial results."
                    ),
                }

            results = [t.result for t in batch.tasks if t.result is not None]

        return {
            **status_info,
            "results": results,
        }

    # ── listing ────────────────────────────────────────────────────────

    def list_batches(self) -> list[dict]:
        """Return a summary of all tracked batches."""
        with self._lock:
            batch_ids = list(self._batches)

        # get_status() locks on its own, so it is called after release.
        return [self.get_status(bid) for bid in batch_ids]

    # ── cancellation ───────────────────────────────────────────────────

    def cancel_batch(self, batch_id: str) -> dict:
        """Attempt to cancel pending tasks in a batch.

        Returns a dict with the number of successfully cancelled tasks.
        Note: ``Future.cancel()`` only succeeds if the task has not yet
        started executing.  Tasks without a live future count as
        ``already_done`` only when a result is cached; otherwise they are
        reported under ``could_not_cancel``.
        """
        with self._lock:
            batch = self._batches.get(batch_id)
            if batch is None:
                return {"error": f"Unknown batch_id: '{batch_id}'"}

            cancelled = 0
            already_done = 0
            for t in batch.tasks:
                if t.future is None:
                    # Reloaded from disk: there is nothing to cancel
                    # locally, and it is only "done" if a result is cached.
                    if t.result is not None:
                        already_done += 1
                elif t.future.done():
                    already_done += 1
                elif t.future.cancel():
                    cancelled += 1

            return {
                "batch_id": batch_id,
                "cancelled": cancelled,
                "already_done": already_done,
                "could_not_cancel": len(batch.tasks) - cancelled - already_done,
            }

    # ── cleanup ────────────────────────────────────────────────────────

    def cleanup(self, max_age_hours: float = 24) -> int:
        """Remove completed batches older than *max_age_hours*.

        A batch counts as completed when every task has either a finished
        future or a cached result.  Returns the number of batches removed.
        """
        now = datetime.now(timezone.utc)

        with self._lock:
            to_remove = [
                bid
                for bid, batch in self._batches.items()
                if (now - batch.submitted_at).total_seconds() / 3600 > max_age_hours
                and all(
                    (t.future is not None and t.future.done())
                    or t.result is not None
                    for t in batch.tasks
                )
            ]
            for bid in to_remove:
                del self._batches[bid]

        if to_remove:
            logger.info("Cleaned up %d old batches", len(to_remove))
            self._save()
        return len(to_remove)
def _run_shell_task(
    command: str,
    working_dir: str | None,
    stdout_path: str | None,
    stderr_path: str | None,
) -> int:
    """Run *command* through the shell and return its exit code.

    When ``stdout_path`` / ``stderr_path`` are given, the matching stream
    is written to that file (opened for writing, so existing content is
    replaced); otherwise the stream is passed to ``subprocess.run`` as
    ``None``.  The command runs with ``check=True``, so a non-zero exit
    status raises ``subprocess.CalledProcessError`` rather than being
    returned.
    """
    import contextlib

    with contextlib.ExitStack() as opened:
        # Only open the files that were requested; ExitStack closes
        # whichever of them exist when the block ends.
        out_stream = (
            opened.enter_context(open(stdout_path, "w")) if stdout_path else None
        )
        err_stream = (
            opened.enter_context(open(stderr_path, "w")) if stderr_path else None
        )
        completed = subprocess.run(
            command,
            shell=True,
            cwd=working_dir,
            stdout=out_stream,
            stderr=err_stream,
            check=True,
        )
    return completed.returncode
def submit(self, task: TaskSpec) -> Future:
    """Schedule *task* on the process pool and return the pool's future.

    ``python`` tasks are dispatched through ``_run_python_task`` with the
    task's callable, args and kwargs; ``shell`` tasks through
    ``_run_shell_task`` with the command, working directory and
    stdout/stderr paths.

    Raises
    ------
    RuntimeError
        If the backend has not been initialized.
    ValueError
        If the task type is unsupported, or the callable/command it
        requires is missing.
    """
    pool = self._pool
    if not self._initialized or pool is None:
        raise RuntimeError(
            "LocalBackend is not initialized. Call initialize() first."
        )

    kind = task.task_type
    if kind == "python":
        if task.callable is None:
            raise ValueError(
                f"Task '{task.task_id}': task_type='python' requires a callable."
            )
        call = (_run_python_task, task.callable, task.args, task.kwargs)
    elif kind == "shell":
        if task.command is None:
            raise ValueError(
                f"Task '{task.task_id}': task_type='shell' requires a command."
            )
        call = (
            _run_shell_task,
            task.command,
            task.working_dir,
            task.stdout,
            task.stderr,
        )
    else:
        raise ValueError(
            f"Task '{task.task_id}': unsupported task_type '{task.task_type}'."
        )

    # Worker entry point first, then its positional arguments.
    return pool.submit(*call)
class ParslBackend(ExecutionBackend):
    """Execution backend that delegates work to Parsl.

    Configuration
    -------------
    The ``system`` argument passed to :meth:`initialize` is forwarded to
    :func:`chemgraph.hpc_configs.loader.load_parsl_config` which returns
    the appropriate ``parsl.config.Config``.

    Only the ``run_dir`` and ``worker_init`` keyword arguments are
    forwarded to the config loader.  Any other keyword argument is
    ignored, and a warning is logged so the drop is visible.
    """

    # Keyword arguments of ``initialize`` that are passed on to the loader.
    _LOADER_KWARGS = ("run_dir", "worker_init")

    def __init__(self) -> None:
        super().__init__()
        self._python_app = None
        self._bash_app = None

    def initialize(self, system: str = "polaris", **kwargs: Any) -> None:
        """Load the Parsl config for *system* and build the app wrappers.

        Parameters
        ----------
        system : str
            System name handed to ``load_parsl_config``.
        **kwargs
            ``run_dir`` and ``worker_init`` are forwarded to the config
            loader when not ``None``; everything else is ignored.

        Raises
        ------
        ImportError
            If Parsl is not installed.
        """
        try:
            import parsl
            from parsl import bash_app, python_app
        except ImportError as exc:
            raise ImportError(
                "Parsl is required for the ParslBackend. "
                "Install it with: pip install chemgraphagent[parsl]"
            ) from exc

        from chemgraph.hpc_configs.loader import load_parsl_config

        # Build kwargs for the config loader from the supported options only.
        loader_kwargs: dict[str, Any] = {}
        for key in self._LOADER_KWARGS:
            value = kwargs.pop(key, None)
            if value is not None:
                loader_kwargs[key] = value

        # Anything left over used to vanish silently; make the drop visible.
        if kwargs:
            logger.warning(
                "ParslBackend.initialize ignoring unsupported option(s): %s",
                ", ".join(sorted(kwargs)),
            )

        config = load_parsl_config(system, **loader_kwargs)
        parsl.load(config)

        # Create generic app wrappers ------------------------------------------
        # These are created once and reused for all submitted tasks.

        @python_app
        def _generic_python_app(fn, args, kwargs):
            """Execute an arbitrary callable on a Parsl worker."""
            return fn(*args, **kwargs)

        @bash_app
        def _generic_bash_app(command, stdout=None, stderr=None):
            """Execute a shell command string on a Parsl worker."""
            return command

        self._python_app = _generic_python_app
        self._bash_app = _generic_bash_app

        self._initialized = True
        logger.info("ParslBackend initialized for system '%s'", system)

    def submit(self, task: TaskSpec) -> Future:
        """Submit *task* through the matching Parsl app and return its future.

        Python tasks have their ``args`` / ``kwargs`` passed through
        ``to_picklable`` before dispatch.  Shell tasks honour
        ``task.working_dir`` by changing into it before the command runs,
        matching the local backend, which also receives the working
        directory.

        Raises
        ------
        RuntimeError
            If the backend has not been initialized.
        ValueError
            If the task type is unsupported or its payload is missing.
        """
        if (
            not self._initialized
            or self._python_app is None
            or self._bash_app is None
        ):
            raise RuntimeError(
                "ParslBackend is not initialized. Call initialize() first."
            )

        if task.task_type == "python":
            if task.callable is None:
                raise ValueError(
                    f"Task '{task.task_id}': task_type='python' requires a callable."
                )
            from chemgraph.execution.utils import to_picklable

            return self._python_app(
                task.callable, to_picklable(task.args), to_picklable(task.kwargs)
            )

        if task.task_type == "shell":
            if task.command is None:
                raise ValueError(
                    f"Task '{task.task_id}': task_type='shell' requires a command."
                )
            command = task.command
            if task.working_dir:
                import shlex

                # Quote the directory so spaces / metacharacters survive bash.
                command = f"cd {shlex.quote(str(task.working_dir))} && {command}"

            bash_kwargs: dict[str, Any] = {"command": command}
            if task.stdout:
                bash_kwargs["stdout"] = task.stdout
            if task.stderr:
                bash_kwargs["stderr"] = task.stderr
            return self._bash_app(**bash_kwargs)

        raise ValueError(
            f"Task '{task.task_id}': unsupported task_type '{task.task_type}'."
        )

    def shutdown(self) -> None:
        """Clean up the Parsl DataFlowKernel; errors are logged, not raised."""
        if not self._initialized:
            return
        try:
            import parsl

            # cleanup() stops executors and releases resources;
            # clear() only removes the DFK from the global registry.
            # Without cleanup(), Parsl logs
            # "Python is exiting with a DFK still running" at interpreter exit.
            try:
                parsl.dfk().cleanup()
            except Exception:
                logger.warning("Error during Parsl DFK cleanup.", exc_info=True)
            parsl.clear()
            logger.info("ParslBackend shut down.")
        except Exception:
            logger.warning("Error during Parsl shutdown.", exc_info=True)
        finally:
            # Drop the app wrappers too: they belong to the DFK just cleaned up.
            self._python_app = None
            self._bash_app = None
            self._initialized = False
def resolve_structure_files(
    input_source: str | list[str],
    extensions: set[str] | None = None,
) -> tuple[list[Path], Path]:
    """Resolve a directory path or file list into a list of structure files.

    Parameters
    ----------
    input_source : str or list[str]
        Either a directory path (all matching files will be collected)
        or an explicit list of file paths.
    extensions : set[str], optional
        File extensions to include when scanning a directory (e.g.
        ``{".cif", ".xyz"}``).  If *None* or empty, all files are included.
        Ignored when *input_source* is a list.

    Returns
    -------
    structure_files : list[Path]
        For a directory: the matching regular files, sorted.  For an
        explicit list: the given paths in the caller's order.
    output_dir : Path
        The scanned directory, or the parent of the first listed file
        (useful for placing output files).

    Raises
    ------
    ValueError
        If no files are found, if listed files do not exist, or if a
        non-list *input_source* is not a directory.
    """
    if isinstance(input_source, list):
        structure_files = [Path(p) for p in input_source]
        missing = [p for p in structure_files if not p.exists()]
        if missing:
            raise ValueError(f"The following input files are missing: {missing}")
        if not structure_files:
            raise ValueError("No structure files found to simulate.")
        return structure_files, structure_files[0].parent

    input_dir = Path(input_source)
    if not input_dir.is_dir():
        raise ValueError(f"'{input_dir}' is not a valid directory.")

    # Require a regular file in both modes: a sub-directory whose name
    # happens to end in a wanted suffix is not a structure file.
    structure_files = sorted(
        p
        for p in input_dir.iterdir()
        if p.is_file() and (not extensions or p.suffix in extensions)
    )
    if not structure_files:
        raise ValueError("No structure files found to simulate.")

    return structure_files, input_dir
+ """ + + async def _wait(meta: dict, fut: Future) -> dict: + try: + result = await asyncio.wrap_future(fut) + if post_fn is not None: + return post_fn(meta, result) + # Default: merge metadata with result (if result is a dict) + if isinstance(result, dict): + merged = {**meta, **result} + merged.setdefault("status", "success") + return merged + return {**meta, "result": result, "status": "success"} + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": str(e), + } + + coro = asyncio.gather(*(_wait(meta, fut) for meta, fut in pending)) + if timeout is not None: + return list(await asyncio.wait_for(coro, timeout=timeout)) + return list(await coro) + + +async def submit_or_gather( + backend: ExecutionBackend, + pending: list[tuple[dict, Future]], + tracker: JobTracker, + tool_name: str, + post_fn: Optional[Callable[[dict, Any], dict]] = None, +) -> dict: + """Gather results or register for async tracking, depending on the backend. + + When ``backend.is_async_remote`` is ``True``, the pending futures are + registered with the *tracker* and a submission confirmation is + returned immediately (non-blocking). Otherwise, results are gathered + synchronously via :func:`gather_futures`. + + Parameters + ---------- + backend : ExecutionBackend + The active execution backend. + pending : list[tuple[dict, Future]] + Each element is ``(metadata_dict, future)``. + tracker : JobTracker + The job tracker instance to register batches with. + tool_name : str + Name of the MCP tool submitting the batch. + post_fn : callable, optional + Post-processing function for results. + + Returns + ------- + dict + Either ``{"status": "submitted", "batch_id": ..., ...}`` for + async backends, or ``{"status": "completed", "results": ...}`` + for synchronous backends. 
def write_results_jsonl(
    results: list[dict],
    output_path: Path,
    append: bool = True,
) -> tuple[int, int]:
    """Write results to a JSONL file and return (success_count, total_count).

    All records are serialised before the file is opened, so a record
    that cannot be JSON-encoded raises without creating, truncating or
    partially writing *output_path*.

    Parameters
    ----------
    results : list[dict]
        Each dict should contain a ``"status"`` key.
    output_path : Path
        Path to the JSONL file.
    append : bool
        If *True* (default), append to an existing file; otherwise
        overwrite it.

    Returns
    -------
    success_count : int
        Number of records whose ``"status"`` equals ``"success"``.
    total_count : int
        Number of records written.

    Raises
    ------
    TypeError
        If a record contains a value ``json.dumps`` cannot encode.
    """
    # Serialise first: failures must not leave a half-written file behind.
    lines = [json.dumps(res) + "\n" for res in results]
    success_count = sum(1 for res in results if res.get("status") == "success")

    mode = "a" if append else "w"
    with open(output_path, mode, encoding="utf-8") as f:
        f.writelines(lines)

    return success_count, len(results)
def _message_tool_calls(message) -> list:
    """Return the ``tool_calls`` list of a dict or object message.

    Anything that is not a ``list`` (missing, ``None``, other types) is
    reported as an empty list.
    """
    if isinstance(message, dict):
        found = message.get("tool_calls")
    else:
        found = getattr(message, "tool_calls", None)
    if isinstance(found, list):
        return found
    return []


def _state_messages(state: State):
    """Return the message sequence held by *state*.

    A ``list`` state is returned as-is; otherwise the ``"messages"`` entry
    is returned, and an empty or missing entry raises ``ValueError``.
    """
    if isinstance(state, list):
        return state
    found = state.get("messages", [])
    if not found:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    return found


def _tool_result_names_after_latest_ai_tool_call(messages) -> set[str]:
    """Collect names of messages that follow the latest tool-call message.

    Walks *messages* from the end and stops at the first message carrying
    tool calls; every non-empty name seen before that point is collected
    as a string.
    """
    collected: set[str] = set()
    for entry in reversed(messages):
        if _message_tool_calls(entry):
            break
        label = _tool_message_name(entry)
        if label:
            collected.add(str(label))
    return collected
def route_after_tools(
    state: State,
    terminal_tool_names: Collection[str] = (),
):
    """Stop the graph after terminal tools; otherwise continue to the LLM.

    Parameters
    ----------
    state : State
        Graph state (or bare message list) to inspect.
    terminal_tool_names : Collection[str], optional
        Tool names that end the run once one of them appears among the
        tool results that follow the latest tool-call message.  A bare
        string is treated as a single tool name rather than being
        iterated character by character.

    Returns
    -------
    str
        ``"done"`` if a terminal tool produced a result in the latest
        tool round, else ``"continue"``.

    Notes
    -----
    NOTE(review): only tool-result *names* are checked here; whether the
    tool call succeeded is not inspected -- confirm that is intended.
    """
    if not terminal_tool_names:
        return "continue"
    executed_names = _tool_result_names_after_latest_ai_tool_call(
        _state_messages(state),
    )
    if isinstance(terminal_tool_names, str):
        # ``str`` is itself a Collection[str]; iterating it would yield
        # single characters that can never match a tool name.
        terminal_names = {terminal_tool_names}
    else:
        terminal_names = {str(name) for name in terminal_tool_names}
    return "done" if executed_names & terminal_names else "continue"
valid_report_tools = {"generate_html"} requested_tools = { call.get("name") - for call in getattr(ai_message, "tool_calls", []) + for call in tool_calls if isinstance(call, dict) } if not requested_tools or not requested_tools.issubset(valid_report_tools): @@ -411,6 +453,7 @@ def construct_single_agent_graph( tools: list = None, max_retries: int = 1, human_supervised: bool = False, + terminal_tool_names: Collection[str] = (), ): """Construct a geometry optimization graph. @@ -436,6 +479,9 @@ def construct_single_agent_graph( human_supervised : bool, optional Whether to include the ``ask_human`` tool so the agent can pause and request human input, by default False + terminal_tool_names : Collection[str], optional + Tool names that should terminate the graph after successful tool + execution instead of routing back to the LLM, by default empty. Returns ------- @@ -491,7 +537,11 @@ def construct_single_agent_graph( route_tools, {"tools": "tools", "done": "ReportAgent"}, ) - graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_conditional_edges( + "tools", + lambda state: route_after_tools(state, terminal_tool_names), + {"continue": "ChemGraphAgent", "done": END}, + ) graph_builder.add_conditional_edges( "ReportAgent", route_report_tools, @@ -508,7 +558,11 @@ def construct_single_agent_graph( route_tools, {"tools": "tools", "done": END}, ) - graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_conditional_edges( + "tools", + lambda state: route_after_tools(state, terminal_tool_names), + {"continue": "ChemGraphAgent", "done": END}, + ) graph = graph_builder.compile(checkpointer=checkpointer) logger.info("Graph construction completed") @@ -539,7 +593,11 @@ def construct_single_agent_graph( route_tools, {"tools": "tools", "done": "ResponseAgent"}, ) - graph_builder.add_edge("tools", "ChemGraphAgent") + graph_builder.add_conditional_edges( + "tools", + lambda state: route_after_tools(state, terminal_tool_names), + {"continue": 
"ChemGraphAgent", "done": END}, + ) graph_builder.add_edge(START, "ChemGraphAgent") graph_builder.add_edge("ResponseAgent", END) diff --git a/src/chemgraph/graphs/single_agent_architector.py b/src/chemgraph/graphs/single_agent_architector.py deleted file mode 100644 index 9e61747d..00000000 --- a/src/chemgraph/graphs/single_agent_architector.py +++ /dev/null @@ -1,143 +0,0 @@ -from langgraph.graph import StateGraph, START, END -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import ToolNode -from chemgraph.tools.cheminformatics_tools import ( - molecule_name_to_smiles, - smiles_to_coordinate_file, -) - -from chemgraph.tools.architector_tools import ( - visualize_molecule, - image_to_connection_points, - build_metal_complex -) -from chemgraph.utils.logging_config import setup_logger -from chemgraph.state.state import State - -logger = setup_logger(__name__) - -single_agent_prompt = "" - -def route_tools(state: State): - """Route to the 'tools' node if the last message has tool calls; otherwise, route to 'done'. - - Parameters - ---------- - state : State - The current state containing messages and remaining steps - - Returns - ------- - str - Either 'tools' or 'done' based on the state conditions - """ - if isinstance(state, list): - ai_message = state[-1] - elif messages := state.get("messages", []): - ai_message = messages[-1] - else: - raise ValueError(f"No messages found in input state to tool_edge: {state}") - if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0: - return "tools" - return "done" - - -def ChemGraphAgent(state: State, llm: ChatOpenAI, system_prompt: str, tools=None): - """LLM node that processes messages and decides next actions. 
- - Parameters - ---------- - state : State - The current state containing messages and remaining steps - llm : ChatOpenAI - The language model to use for processing - system_prompt : str - The system prompt to guide the LLM's behavior - tools : list, optional - List of tools available to the agent, by default None - - Returns - ------- - dict - Updated state containing the LLM's response - """ - - # Load default tools if no tool is specified. - if tools is None: - tools = [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - visualize_molecule, - image_to_connection_points, - build_metal_complex - ] - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"{state['messages']}"}, - ] - llm_with_tools = llm.bind_tools(tools=tools) - return {"messages": [llm_with_tools.invoke(messages)]} - -def construct_single_agent_architector_graph( - llm: ChatOpenAI, - system_prompt: str = "", - tools: list = None, -): - """Construct a geometry optimization graph. 
- - Parameters - ---------- - llm : ChatOpenAI - The language model to use for the graph - system_prompt : str, optional - The system prompt to guide the LLM's behavior, by default single_agent_prompt - structured_output : bool, optional - Whether to use structured output, by default False - formatter_prompt : str, optional - The prompt to guide the LLM's formatting behavior, by default formatter_prompt - generate_report: bool, optional - Whether to generate a report, by default False - report_prompt: str, optional - The prompt to guide the LLM's report generation behavior, by default report_prompt - tool: list, optional - The list of tools for the main agent, by default None - Returns - ------- - StateGraph - The constructed single agent graph - """ - try: - logger.info("Constructing single agent graph") - checkpointer = MemorySaver() - if tools is None: - tools = [ - molecule_name_to_smiles, - smiles_to_coordinate_file, - visualize_molecule, - image_to_connection_points, - build_metal_complex - ] - tool_node = ToolNode(tools=tools) - graph_builder = StateGraph(State) - - graph_builder.add_node( - "ChemGraphAgent", - lambda state: ChemGraphAgent(state, llm, system_prompt=system_prompt, tools=tools), - ) - graph_builder.add_node("tools", tool_node) - graph_builder.add_edge(START, "ChemGraphAgent") - graph_builder.add_conditional_edges( - "ChemGraphAgent", - route_tools, - {"tools": "tools", "done": END}, - ) - graph_builder.add_edge("tools", "ChemGraphAgent") - graph_builder.add_edge("ChemGraphAgent", END) - - graph = graph_builder.compile(checkpointer=checkpointer) - logger.info("Graph construction completed") - return graph - except Exception as e: - logger.error(f"Error constructing graph: {str(e)}") - raise diff --git a/src/chemgraph/hpc_configs/__init__.py b/src/chemgraph/hpc_configs/__init__.py new file mode 100644 index 00000000..32d8bc92 --- /dev/null +++ b/src/chemgraph/hpc_configs/__init__.py @@ -0,0 +1 @@ +"""HPC configuration factories for workflow 
managers.""" diff --git a/src/chemgraph/hpc_configs/aurora_parsl.py b/src/chemgraph/hpc_configs/aurora_parsl.py index 61793aaf..22824e38 100644 --- a/src/chemgraph/hpc_configs/aurora_parsl.py +++ b/src/chemgraph/hpc_configs/aurora_parsl.py @@ -5,9 +5,13 @@ from parsl.launchers import MpiExecLauncher from parsl.addresses import address_by_interface +from chemgraph.hpc_configs.loader import resolve_worker_init + def get_aurora_config( run_dir=None, + worker_init: str | None = None, + max_workers_per_node: int | None = None, ): """Create a Parsl configuration for Aurora PBS jobs. @@ -15,6 +19,11 @@ def get_aurora_config( ---------- run_dir : str, optional Directory used as Parsl's run directory and worker working directory. + worker_init : str, optional + Explicit shell snippet for worker init. When ``None`` (default), + :func:`resolve_worker_init` picks ``CHEMGRAPH_WORKER_INIT`` / + ``VIRTUAL_ENV`` / ``CONDA_PREFIX`` over the Aurora fallback + (``module load frameworks``). Returns ------- @@ -24,8 +33,13 @@ def get_aurora_config( if run_dir is None: run_dir = os.getcwd() - # Hard-wired worker_init for aurora - worker_init = f"export TMPDIR=/tmp; cd {run_dir}; module load frameworks" + if worker_init is None: + worker_init = resolve_worker_init(run_dir, fallback="module load frameworks") + + if max_workers_per_node is None: + max_workers_per_node = int( + os.getenv("CHEMGRAPH_PARSL_MAX_WORKERS_PER_NODE", "9") + ) # Get the number of nodes: node_file = os.getenv("PBS_NODEFILE") @@ -34,9 +48,9 @@ def get_aurora_config( node_list = f.readlines() num_nodes = len(node_list) else: - # Fallback for testing/local runs without PBS - raise ValueError("Warning: PBS_NODEFILE not found. Defaulting to 1 node.") - num_nodes = 1 + raise ValueError( + "PBS_NODEFILE not found. Cannot determine node count for Aurora." 
+ ) config = Config( executors=[ @@ -45,7 +59,7 @@ def get_aurora_config( heartbeat_period=30, heartbeat_threshold=240, available_accelerators=12, - max_workers_per_node=9, + max_workers_per_node=max_workers_per_node, address=address_by_interface('bond0'), provider=LocalProvider( nodes_per_block=num_nodes, diff --git a/src/chemgraph/hpc_configs/crux_parsl.py b/src/chemgraph/hpc_configs/crux_parsl.py new file mode 100644 index 00000000..e753ed3e --- /dev/null +++ b/src/chemgraph/hpc_configs/crux_parsl.py @@ -0,0 +1,76 @@ +import os +from parsl.config import Config +from parsl.providers import LocalProvider +from parsl.executors import HighThroughputExecutor +from parsl.launchers import MpiExecLauncher + +from chemgraph.hpc_configs.loader import resolve_worker_init + + +def get_crux_config( + run_dir=None, + max_workers_per_node: int = 16, + worker_init: str | None = None, +): + """Create a Parsl configuration for ALCF Crux PBS jobs. + + Crux is a CPU-only AMD EPYC system (no accelerators). + + Parameters + ---------- + run_dir : str, optional + Directory used as Parsl's run directory and worker working directory. + max_workers_per_node : int, optional + Number of concurrent workers per node. Defaults to 16 + (≈8 cores per worker on a 128-core node). + worker_init : str, optional + Explicit shell snippet for worker init. When ``None`` (default), + :func:`resolve_worker_init` picks ``CHEMGRAPH_WORKER_INIT`` / + ``VIRTUAL_ENV`` / ``CONDA_PREFIX`` over the Crux fallback. + + Returns + ------- + parsl.config.Config + Configured Parsl ``Config`` for Crux. 
def _conda_hook_script(conda_prefix: str, conda_exe: str) -> str:
    """Return the ``conda.sh`` path that workers should source.

    ``CONDA_PREFIX`` points at the *active environment*; for anything but
    the base environment that directory has no ``etc/profile.d/conda.sh``.
    The hook lives under the conda installation root, derived here from
    ``CONDA_EXE`` (``<root>/bin/conda``).  The first candidate that exists
    on disk wins; if none exists, the ``CONDA_PREFIX``-based path is
    returned.
    """
    prefix_hook = f"{conda_prefix}/etc/profile.d/conda.sh"
    candidates = []
    if conda_exe:
        conda_root = os.path.dirname(os.path.dirname(conda_exe))
        candidates.append(os.path.join(conda_root, "etc", "profile.d", "conda.sh"))
    candidates.append(prefix_hook)
    return next((c for c in candidates if os.path.isfile(c)), prefix_hook)


def resolve_worker_init(run_dir: str, fallback: str) -> str:
    """Build a Parsl ``worker_init`` shell snippet with layered precedence.

    Precedence (highest first):

    1. Environment variable ``CHEMGRAPH_WORKER_INIT`` -- if set and non-empty,
       used verbatim (after stripping surrounding whitespace).  Lets a user
       point Parsl workers at any env without editing code.
    2. Auto-detect the submitting process's Python env and emit an activate
       line for it: ``VIRTUAL_ENV`` first, then conda (requires both
       ``CONDA_PREFIX`` and ``CONDA_DEFAULT_ENV``).
    3. The system-specific *fallback* string passed by the caller (e.g.
       ``"module load conda; conda activate base"`` on Crux).

    The returned string is always prefixed with ``export TMPDIR=/tmp;
    cd {run_dir};`` so Parsl workers land in the same directory the
    submitter chose.  ``run_dir`` and the detected env paths are
    shell-quoted; paths made only of ordinary characters are emitted
    unchanged.

    Parameters
    ----------
    run_dir : str
        Directory the workers ``cd`` into.
    fallback : str
        Shell snippet used when no override or env is detected; emitted
        verbatim.

    Returns
    -------
    str
        The complete ``worker_init`` snippet.
    """
    import shlex  # local import keeps this block self-contained

    override = os.environ.get("CHEMGRAPH_WORKER_INIT", "").strip()
    if override:
        activate = override
    else:
        venv = os.environ.get("VIRTUAL_ENV", "").strip()
        conda_prefix = os.environ.get("CONDA_PREFIX", "").strip()
        conda_env = os.environ.get("CONDA_DEFAULT_ENV", "").strip()
        if venv:
            activate = f"source {shlex.quote(f'{venv}/bin/activate')}"
        elif conda_prefix and conda_env:
            conda_exe = os.environ.get("CONDA_EXE", "").strip()
            hook = _conda_hook_script(conda_prefix, conda_exe)
            activate = (
                f"source {shlex.quote(hook)} && "
                f"conda activate {shlex.quote(conda_env)}"
            )
        else:
            activate = fallback
    return f"export TMPDIR=/tmp; cd {shlex.quote(run_dir)}; {activate}"
def get_local_config(
    run_dir: str | None = None,
    max_workers: int = _DEFAULT_MAX_WORKERS,
    worker_init: str | None = None,
) -> Config:
    """Build a single-block Parsl ``Config`` for local execution.

    Parameters
    ----------
    run_dir : str, optional
        Parsl run directory; the current working directory when omitted.
    max_workers : int, optional
        Passed to the executor as ``max_workers_per_node``.
    worker_init : str, optional
        Shell snippet run before workers start.  When ``None``, it is
        produced by :func:`resolve_worker_init` with a ``"true"`` (no-op)
        fallback.

    Returns
    -------
    parsl.config.Config
        Config with one ``HighThroughputExecutor`` labelled
        ``"local_htex"`` backed by a ``LocalProvider``.
    """
    effective_run_dir = os.getcwd() if run_dir is None else run_dir

    if worker_init is not None:
        init_snippet = worker_init
    else:
        init_snippet = resolve_worker_init(effective_run_dir, fallback="true")

    logger.info("Creating local Parsl config with %d workers", max_workers)

    # One block that may scale down to zero, never above one.
    provider = LocalProvider(
        init_blocks=1,
        min_blocks=0,
        max_blocks=1,
        worker_init=init_snippet,
    )
    executor = HighThroughputExecutor(
        label="local_htex",
        max_workers_per_node=max_workers,
        provider=provider,
    )
    return Config(executors=[executor], run_dir=effective_run_dir)
def _register_fastmcp_dynamic_models() -> None:
    """Make pydantic models built by ``fastmcp.func_metadata`` pickle-by-qualname.

    FastMCP builds per-tool ``Arguments`` / ``Output`` classes via
    ``pydantic.create_model(__module__="mcp.server.fastmcp.utilities.func_metadata")``
    but never inserts them into that module's namespace.  Dill's by-qualname
    lookup then fails and either raises ``PicklingError`` or falls back to
    pickle-by-value, which walks ``__globals__`` and can hit other surprises.
    Wrapping ``func_metadata`` so the resulting models are inserted into the
    module's ``__dict__`` makes the lookup succeed regardless of how the
    pickle graph reaches the class.

    The patch is applied at most once per process: a marker attribute on the
    ``func_metadata`` module short-circuits repeated calls.
    """
    import sys

    from mcp.server.fastmcp.utilities import func_metadata as _fm

    # Already patched (marker set at the end of this function): nothing to do.
    if getattr(_fm, "_chemgraph_models_registered", False):
        return

    # Keep the unpatched callable; it is both delegated to and used below to
    # recognise modules that still hold a reference to it.
    _orig = _fm.func_metadata
    _mod_ns = sys.modules[_fm.__name__].__dict__

    def _register(model):
        # Publish *model* in the func_metadata module namespace under its
        # class name so a by-qualname lookup can find it.
        if model is None:
            return
        name = getattr(model, "__name__", None)
        # NOTE(review): an existing entry with the same name is left in place,
        # so a second model sharing a name is not registered -- confirm that
        # tool function names are unique enough for this to be safe.
        if name and name not in _mod_ns:
            _mod_ns[name] = model
            try:
                model.__module__ = _fm.__name__
            except (AttributeError, TypeError):
                # Best effort: some classes reject attribute assignment.
                pass

    def _patched(*args, **kwargs):
        # Delegate to the original, then register whichever models it built.
        meta = _orig(*args, **kwargs)
        _register(getattr(meta, "arg_model", None))
        _register(getattr(meta, "output_model", None))
        return meta

    _fm.func_metadata = _patched
    # Several fastmcp modules captured the original via
    # ``from mcp.server.fastmcp.utilities.func_metadata import func_metadata``
    # before this patch ran, so they hold their own bound name.  Rebind the
    # name in each known call site so every tool registration goes through
    # the wrapper.
    for _modname in (
        "mcp.server.fastmcp.tools.base",
        "mcp.server.fastmcp.prompts.base",
        "mcp.server.fastmcp.resources.templates",
    ):
        # Only modules that are already imported and still point at the
        # original callable are rebound; anything else is left untouched.
        _m = sys.modules.get(_modname)
        if _m is not None and getattr(_m, "func_metadata", None) is _orig:
            _m.func_metadata = _patched

    # Marker checked at the top of this function to keep the patch one-shot.
    _fm._chemgraph_models_registered = True
+ """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._backend = None + self._tracker = None + self._backend_kwargs: Optional[dict[str, Any]] = None + self._tracker_kwargs: dict[str, Any] = {} + self._pre_submit_hook: Optional[Callable] = None + self._task_counter: int = 0 + + # ── Backend lifecycle ─────────────────────────────────────────────── + + def init_backend( + self, + *, + tracker_kwargs: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Register backend configuration for lazy initialisation. + + The backend is not created until the first tool invocation, + so the MCP server can start accepting connections immediately. + + Parameters + ---------- + tracker_kwargs : dict, optional + Forwarded to :class:`~chemgraph.execution.job_tracker.JobTracker` + on first use. Use this to pass ``persist_file`` for cross-session + job state recovery. + **kwargs + Forwarded to :func:`~chemgraph.execution.config.get_backend`. + """ + self._backend_kwargs = kwargs + self._tracker_kwargs = tracker_kwargs or {} + self._register_job_tools() + logger.info("CGFastMCP backend configured (lazy init).") + + def _ensure_backend(self) -> None: + """Create the backend on first use.""" + if self._backend is not None: + return + if self._backend_kwargs is None: + raise RuntimeError( + "Backend not configured. Call init_backend() first." 
+ ) + from chemgraph.execution import JobTracker, get_backend + + self._backend = get_backend(**self._backend_kwargs) + self._tracker = JobTracker(**self._tracker_kwargs) + logger.info( + "CGFastMCP backend initialised: %s", type(self._backend).__name__ + ) + + def shutdown_backend(self) -> None: + """Shut down the execution backend and release resources.""" + if self._backend is not None: + try: + self._backend.shutdown() + except Exception: + logger.warning("Error during backend shutdown.", exc_info=True) + self._backend = None + self._tracker = None + self._backend_kwargs = None + self._tracker_kwargs = {} + logger.info("CGFastMCP backend shut down.") + + # ── Pre-submit transport hook ────────────────────────────────────── + + def set_pre_submit_hook(self, hook: Optional[Callable]) -> None: + """Register a hook that transforms each TaskSpec before submission. + + The hook receives the :class:`~chemgraph.execution.base.TaskSpec` + and must return one (possibly the same instance). Used for + transport concerns that should apply to every backend-submitted + tool on this server -- e.g. embedding a local structure file + into ``kwargs`` so a remote worker can materialise it, or + rewriting a local path to a pre-staged remote path. + + Pass ``None`` to clear the hook. + """ + self._pre_submit_hook = hook + + def _apply_pre_submit_hook(self, task): + """Run the registered pre-submit hook (no-op when unset). + + Hook exceptions are wrapped in a ``ValueError`` naming the hook + and the offending task_id, so they surface to the agent as a + structured error instead of an opaque traceback. 
+ """ + if self._pre_submit_hook is None: + return task + try: + return self._pre_submit_hook(task) + except Exception as exc: + hook_name = getattr( + self._pre_submit_hook, "__name__", repr(self._pre_submit_hook) + ) + task_id = getattr(task, "task_id", "") + logger.warning( + "Pre-submit hook %s failed for task %s", + hook_name, + task_id, + exc_info=True, + ) + raise ValueError( + f"Pre-submit hook '{hook_name}' failed for task '{task_id}': {exc}" + ) from exc + + # ── Job tracking tools ───────────────────────────────────────────── + + def _register_job_tools(self) -> None: + """Register job-management tools (status, results, cancel).""" + + @self.add_tool + def check_job_status(batch_id: str) -> dict: + """Check the status of a submitted job batch.""" + self._ensure_backend() + return self._tracker.get_status(batch_id) + + @self.add_tool + def get_job_results( + batch_id: str, include_partial: bool = False + ) -> dict: + """Retrieve results from a completed job batch.""" + self._ensure_backend() + return self._tracker.get_results( + batch_id, include_partial=include_partial + ) + + @self.add_tool + def list_jobs() -> list[dict]: + """List all tracked job batches.""" + self._ensure_backend() + batches = self._tracker.list_batches() + if not batches: + return [{"message": "No job batches tracked."}] + return batches + + @self.add_tool + def cancel_job(batch_id: str) -> dict: + """Cancel pending tasks in a job batch.""" + self._ensure_backend() + return self._tracker.cancel_batch(batch_id) + + @self.add_tool + def check_endpoint_status() -> dict: + """Check whether the remote compute endpoint is reachable.""" + self._ensure_backend() + if hasattr(self._backend, "check_endpoint_status"): + return self._backend.check_endpoint_status() + return {"status": "not_applicable", + "message": "This backend does not support endpoint status checks."} + + # ── Internal helpers ────────────────────────────────────────────── + + @staticmethod + def _fix_module_for_pickle(fn: 
Callable) -> None: + """Ensure *fn* is picklable when the MCP server runs as ``__main__``. + + Under ``python -m pkg.mod`` runpy sets ``__name__ == "__main__"`` + and populates both ``sys.modules["__main__"]`` and + ``sys.modules["pkg.mod"]`` -- but it does **not** attach + ``mod`` as an attribute of the parent package ``pkg``. Dill's + by-qualname pickling resolves ``pkg.mod.fn`` via + ``__import__("pkg", fromlist=["mod"])`` followed by + ``getattr(pkg, "mod")``, which fails for that reason and silently + falls back to pickle-by-value -- dragging the entire module's + globals (including the FastMCP dynamic ``arg_model`` classes) + into the byte stream. + + Three things must be true for dill to pickle ``fn`` by reference: + + 1. ``fn.__module__`` points at the real dotted name (not ``__main__``). + 2. ``sys.modules[fn.__module__]`` exists and contains ``fn`` as + an attribute. + 3. The parent package has the leaf module attached as an attribute + (so ``getattr(pkg, leaf)`` resolves to the same module object). + """ + if fn.__module__ == "__main__": + import sys + + spec = getattr(sys.modules.get("__main__"), "__spec__", None) + if spec and spec.name: + fn.__module__ = spec.name + target = sys.modules.get(spec.name) + if target is None: + target = sys.modules["__main__"] + sys.modules[spec.name] = target + elif getattr(target, fn.__name__, None) is not fn: + setattr(target, fn.__name__, fn) + # Attach the leaf module to its parent package so dill's + # ``__import__(parent, fromlist=[leaf])`` lookup succeeds. + if "." 
in spec.name: + parent_name, _, leaf = spec.name.rpartition(".") + parent = sys.modules.get(parent_name) + if parent is not None and getattr(parent, leaf, None) is not target: + setattr(parent, leaf, target) + + # ── Tool registration ─────────────────────────────────────────────── + + def tool( + self, + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + structured_output: Optional[bool] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a tool that runs on the execution backend. + + Same calling convention as :meth:`FastMCP.tool` — **parens are + required** (``@mcp.tool()``, not ``@mcp.tool``). + + The additional parameters (``num_nodes``, ``processes_per_node``, + ``gpus_per_task``, ``env``, ``working_dir``) are forwarded to the + :class:`~chemgraph.execution.base.TaskSpec` that wraps the + decorated function when it is invoked. + + Parameters + ---------- + name, title, description, annotations, structured_output + Passed through to :meth:`FastMCP.add_tool`. + num_nodes : int + Number of compute nodes (default ``1``). + processes_per_node : int + Processes per node (default ``1``). + gpus_per_task : int + GPUs per task (default ``0``). + env : dict, optional + Extra environment variables for the worker. + working_dir : str, optional + Working directory for the task. 
+ """ + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if title is not None: + fastmcp_kwargs["title"] = title + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + if structured_output is not None: + fastmcp_kwargs["structured_output"] = structured_output + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + def decorator(fn: Callable) -> Callable: + wrapper = self._make_backend_wrapper(fn, task_spec_kwargs) + self.add_tool(wrapper, **fastmcp_kwargs) + return fn + + return decorator + + # ── Ensemble tool registration ───────────────────────────────────── + + def ensemble_tool( + self, + name: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a fan-out tool that submits ``list[params]`` to the backend. + + Decorates ``fn(params: Schema) -> result``. The MCP tool schema + becomes ``list[Schema]`` — the LLM provides a list of jobs and + the framework submits each as a + :class:`~chemgraph.execution.base.TaskSpec`, then gathers results + via :func:`~chemgraph.execution.utils.submit_or_gather`. + + Parameters + ---------- + name, description, annotations + Passed through to :meth:`FastMCP.add_tool`. + num_nodes, processes_per_node, gpus_per_task, env, working_dir + Forwarded to :class:`~chemgraph.execution.base.TaskSpec`. 
+ """ + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + + def decorator(fn: Callable) -> Callable: + self._fix_module_for_pickle(fn) + sig = inspect.signature(fn) + params = list(sig.parameters.values()) + if len(params) != 1: + raise TypeError( + f"@ensemble_tool expects a function with exactly one " + f"parameter (the per-item schema), got {len(params)} " + f"on {fn.__qualname__}." + ) + param = params[0] + param_type = param.annotation + + async def wrapper(params): + from chemgraph.execution.utils import to_picklable + + self._ensure_backend() + self._task_counter += 1 + batch_counter = self._task_counter + pending = [] + for i, p in enumerate(params): + task = TaskSpec( + task_id=f"{fn.__name__}_{batch_counter}_{i}", + task_type="python", + callable=fn, + kwargs={param.name: to_picklable(p)}, + **task_spec_kwargs, + ) + task = self._apply_pre_submit_hook(task) + fut = self._backend.submit(task) + pending.append(({"index": i}, fut)) + + return await submit_or_gather( + self._backend, + pending, + self._tracker, + name or fn.__name__, + ) + + wrapper.__name__ = name or fn.__name__ + wrapper.__doc__ = fn.__doc__ + wrapper.__module__ = fn.__module__ + wrapper.__qualname__ = fn.__qualname__ + + new_param = inspect.Parameter( + "params", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=list[param_type], + ) + wrapper.__signature__ = inspect.Signature( + parameters=[new_param] + ) + + self.add_tool(wrapper, 
**fastmcp_kwargs) + return fn + + return decorator + + # ── Schema-driven fanout tool ────────────────────────────────────── + + def schema_fanout_tool( + self, + *, + worker: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + annotations: Optional[ToolAnnotations] = None, + # ── TaskSpec resource hints ────────────────────────────────── + num_nodes: int = 1, + processes_per_node: int = 1, + gpus_per_task: int = 0, + env: Optional[Dict[str, str]] = None, + working_dir: Optional[str] = None, + ) -> Callable: + """Register a fan-out tool driven by a single *ensemble* schema. + + The decorated function is an **expander**: it receives the + ensemble schema and returns a list of per-item arguments. The + framework calls ``worker(item)`` on the backend for each item, + gathers the results, and returns a batch summary -- same shape + as :meth:`ensemble_tool`. + + Unlike :meth:`ensemble_tool` (whose tool signature is + ``list[Schema]``), this preserves the ensemble schema as the + agent-facing API, so the LLM makes a single tool call against + e.g. ``input_structure_directory`` and server-side expansion + produces the per-file jobs. + + Parameters + ---------- + worker : Callable + The per-item function executed on the backend. Must take + a single positional argument (the item produced by the + expander). + name, description, annotations + Passed through to :meth:`FastMCP.add_tool`. + num_nodes, processes_per_node, gpus_per_task, env, working_dir + Forwarded to each :class:`~chemgraph.execution.base.TaskSpec`. 
+ """ + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather + + task_spec_kwargs: dict[str, Any] = { + "num_nodes": num_nodes, + "processes_per_node": processes_per_node, + "gpus_per_task": gpus_per_task, + "env": env or {}, + } + if working_dir is not None: + task_spec_kwargs["working_dir"] = working_dir + + fastmcp_kwargs: dict[str, Any] = {} + if name is not None: + fastmcp_kwargs["name"] = name + if description is not None: + fastmcp_kwargs["description"] = description + if annotations is not None: + fastmcp_kwargs["annotations"] = annotations + + # Worker is what actually runs on the backend, so it must be + # picklable from the MCP server's __main__ module. + self._fix_module_for_pickle(worker) + + worker_sig = inspect.signature(worker) + worker_params = list(worker_sig.parameters.values()) + if len(worker_params) != 1: + raise TypeError( + f"schema_fanout_tool worker must take exactly one " + f"parameter, got {len(worker_params)} on " + f"{worker.__qualname__}." + ) + worker_param_name = worker_params[0].name + + def decorator(expander: Callable) -> Callable: + sig = inspect.signature(expander) + params = list(sig.parameters.values()) + if len(params) != 1: + raise TypeError( + f"@schema_fanout_tool expander must take exactly one " + f"parameter (the ensemble schema), got {len(params)} " + f"on {expander.__qualname__}." 
+ ) + param = params[0] + tool_name = name or expander.__name__ + + async def wrapper(**kwargs): + from chemgraph.execution.utils import to_picklable + + self._ensure_backend() + self._task_counter += 1 + batch_counter = self._task_counter + ensemble_params = kwargs[param.name] + items = expander(ensemble_params) + pending = [] + for i, item in enumerate(items): + task = TaskSpec( + task_id=f"{tool_name}_{batch_counter}_{i}", + task_type="python", + callable=worker, + kwargs={worker_param_name: to_picklable(item)}, + **task_spec_kwargs, + ) + task = self._apply_pre_submit_hook(task) + fut = self._backend.submit(task) + pending.append(({"index": i}, fut)) + + return await submit_or_gather( + self._backend, + pending, + self._tracker, + tool_name, + ) + + wrapper.__name__ = tool_name + wrapper.__doc__ = expander.__doc__ + wrapper.__module__ = expander.__module__ + wrapper.__qualname__ = expander.__qualname__ + # Preserve the expander's input signature so FastMCP advertises + # the ensemble schema to the LLM, not the worker's per-item one. + # The wrapper returns a submit_or_gather batch summary, though, + # so it must not inherit the expander's list-of-jobs annotation. 
+ wrapper.__signature__ = sig.replace( + return_annotation=dict[str, Any] + ) + + self.add_tool(wrapper, **fastmcp_kwargs) + return expander + + return decorator + + # ── Internal ──────────────────────────────────────────────────────── + + def _make_backend_wrapper( + self, fn: Callable, task_spec_kwargs: dict[str, Any] + ) -> Callable: + """Build an async wrapper that submits *fn* to the backend.""" + from chemgraph.execution.base import TaskSpec + from chemgraph.execution.utils import submit_or_gather, to_picklable + + self._fix_module_for_pickle(fn) + + @functools.wraps(fn) + async def wrapper(**kwargs: Any) -> Any: + self._ensure_backend() + self._task_counter += 1 + task_id = f"{fn.__name__}_{self._task_counter}" + task = TaskSpec( + task_id=task_id, + task_type="python", + callable=fn, + kwargs=to_picklable(kwargs), + **task_spec_kwargs, + ) + task = self._apply_pre_submit_hook(task) + fut = self._backend.submit(task) + + if self._backend.is_async_remote: + return await submit_or_gather( + self._backend, + [({"task_id": task_id}, fut)], + self._tracker, + fn.__name__, + ) + + return await asyncio.wrap_future(fut) + + return wrapper diff --git a/src/chemgraph/mcp/graspa_mcp_hpc.py b/src/chemgraph/mcp/graspa_mcp_hpc.py new file mode 100644 index 00000000..af33f792 --- /dev/null +++ b/src/chemgraph/mcp/graspa_mcp_hpc.py @@ -0,0 +1,241 @@ +"""Backend-agnostic gRASPA MCP server. + +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. + +The ensemble expander emits one job per ``(structure, condition)`` pair +and supports both local input directories and pre-staged remote +directories (mirrors the MACE server's local/remote modes). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. 
+""" + +import logging +import os +from pathlib import Path + +from chemgraph.execution.base import TaskSpec +from chemgraph.execution.config import get_transfer_manager +from chemgraph.execution.utils import ( + make_per_structure_output, + resolve_structure_files, +) +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.transfer_tools import register_transfer_tools +from chemgraph.schemas.graspa_schema import graspa_input_schema_ensemble + +logger = logging.getLogger(__name__) + +_JOBS_FILE = Path("~/.chemgraph/graspa_jobs.json").expanduser() + +mcp = CGFastMCP( + name="ChemGraph Graspa Tools", + instructions=""" + You expose tools for running gRASPA simulations and reading + their results. The available tools are: + 1. run_graspa_ensemble: run gRASPA calculations over every + structure in a directory at one or more (T, P) conditions. + Local mode uses input_structures; remote mode uses + remote_structure_directory (pre-stage files first with + transfer_files). + 2. check_job_status / get_job_results / list_jobs / cancel_job: + HPC job batch management. Job state persists across sessions. + 3. transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on + the remote HPC filesystem before running ensembles in remote + mode. + + Guidelines: + - Use each tool only when its input schema matches the user + request. + - Do not guess numerical values; report tool errors exactly as + they occur. + - Keep responses compact -- full results are written to the + output files defined in the schemas. + - When returning paths, use absolute paths. + - Energies are in eV and wall times are in seconds. + - When a tool returns status='submitted' with a batch_id, use + check_job_status to poll for progress before calling + get_job_results. Job state is persisted across sessions. 
+ """, +) + + +# ── Worker (runs on the backend) ─────────────────────────────────────── + + +def _graspa_worker(job: dict) -> dict: + """Execute a single gRASPA simulation on a backend worker.""" + from chemgraph.schemas.graspa_schema import graspa_input_schema + from chemgraph.tools.graspa_tools import run_graspa_core + + job = dict(job) + structure = job.pop("_structure_name", None) + temperature = job.get("temperature") + pressure = job.get("pressure") + + remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "raspa.log"), + ) + + params = graspa_input_schema(**job) + result = run_graspa_core(params) + + if isinstance(result, dict): + merged = { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + **result, + } + merged.setdefault("status", "success") + return merged + return { + "structure": structure, + "temperature": temperature, + "pressure": pressure, + "result": result, + "status": "success", + } + + +# Note: ``_graspa_worker`` is registered via ``@mcp.schema_fanout_tool`` below, +# which fixes its module for pickling automatically; no explicit fix is needed. + + +# ── Ensemble fanout ──────────────────────────────────────────────────── + + +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) + ) + + +# Submitted as a bare ``callable=`` TaskSpec (not via a decorator), so it must +# be fixed explicitly for pickle-by-reference when run as ``__main__``. Mirrors +# the equivalent fix in mace_mcp_hpc.py. 
+CGFastMCP._fix_module_for_pickle(_ls_remote_files) + + +def _expand_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: + """Server-side expansion of an ensemble request into per-job dicts. + + Local mode: enumerates ``input_structures`` on this host. + Remote mode: submits a one-shot probe task to the backend to list + files under ``remote_structure_directory``, then builds per-file + jobs that the worker reads directly from the remote filesystem. + """ + base_output = Path(params.output_result_file) + + if params.remote_structure_directory: + remote_dir = params.remote_structure_directory + mcp._ensure_backend() + probe = TaskSpec( + task_id="ls_remote_dir", + task_type="python", + callable=_ls_remote_files, + kwargs={"path": remote_dir}, + ) + fut = mcp._backend.submit(probe) + try: + file_names = fut.result(timeout=30) + except Exception as exc: + raise RuntimeError( + f"Could not list remote directory {remote_dir}: {exc}" + ) from exc + + # Filter to CIF files (gRASPA expects CIFs). + file_names = [f for f in file_names if f.lower().endswith(".cif")] + if not file_names: + raise ValueError( + f"No CIF files found under remote directory {remote_dir}." + ) + + jobs = [] + for fname in file_names: + mof_name = Path(fname).stem + for condition in params.conditions: + per_output = make_per_structure_output(Path(fname), base_output) + jobs.append( + { + "_structure_name": mof_name, + "remote_structure_file": f"{remote_dir}/{fname}", + "output_result_file": str(per_output), + "temperature": condition.temperature, + "pressure": condition.pressure, + "adsorbate": params.adsorbate, + "n_cycles": params.n_cycles, + } + ) + return jobs + + if not params.input_structures: + raise ValueError( + "Either input_structures or remote_structure_directory " + "must be provided." 
+ ) + + structure_files, _ = resolve_structure_files( + params.input_structures, extensions={".cif"} + ) + jobs = [] + for struct_path in structure_files: + mof_name = struct_path.stem + for condition in params.conditions: + per_output = make_per_structure_output(struct_path, base_output) + jobs.append( + { + "_structure_name": mof_name, + "input_structure_file": str(struct_path), + "output_result_file": str(per_output), + "temperature": condition.temperature, + "pressure": condition.pressure, + "adsorbate": params.adsorbate, + "n_cycles": params.n_cycles, + } + ) + return jobs + + +@mcp.schema_fanout_tool( + name="run_graspa_ensemble", + description=( + "Run gRASPA calculations over every structure in a directory at " + "one or more (temperature, pressure) conditions. Local mode " + "uses input_structures; remote mode uses " + "remote_structure_directory (pre-stage files first with " + "transfer_files)." + ), + worker=_graspa_worker, +) +def run_graspa_ensemble(params: graspa_input_schema_ensemble) -> list[dict]: + return _expand_graspa_ensemble(params) + + +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on gRASPA MCP server.") + + +if __name__ == "__main__": + from chemgraph.mcp.server_utils import run_mcp_server + + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) + + try: + run_mcp_server(mcp, default_port=9001) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/mcp/graspa_mcp_parsl.py b/src/chemgraph/mcp/graspa_mcp_parsl.py index 3b55690a..378dc5ad 100644 --- a/src/chemgraph/mcp/graspa_mcp_parsl.py +++ b/src/chemgraph/mcp/graspa_mcp_parsl.py @@ -1,8 +1,19 @@ import asyncio import json import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.graspa_mcp_parsl is deprecated; use " + 
"chemgraph.mcp.graspa_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP import parsl diff --git a/src/chemgraph/mcp/hpc_misc_mcp.py b/src/chemgraph/mcp/hpc_misc_mcp.py new file mode 100644 index 00000000..106e5c52 --- /dev/null +++ b/src/chemgraph/mcp/hpc_misc_mcp.py @@ -0,0 +1,167 @@ +"""FastMCP tools for generic HPC run-artifact inspection.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from mcp.server.fastmcp import FastMCP + + +mcp = FastMCP( + name="ChemGraph HPC Misc Tools", + instructions=""" + You expose small, generic tools for inspecting files produced by HPC + calculations. These tools do not run chemistry; they help agents inspect + run artifacts without relying on simulation-specific readers. + """, +) + + +@mcp.tool( + name="inspect_json", + description=( + "Inspect a JSON file, a directory of JSON files, or a missing expected " + "JSON path. Returns compact summaries and nearby JSON files when the " + "requested path is absent." 
+ ), +) +def inspect_json( + path: str, + glob_pattern: str = "*.json", + max_files: int = 20, + max_preview_chars: int = 1200, + recursive: bool = False, +) -> dict[str, Any]: + """Inspect JSON artifacts without assuming one fixed output-file layout.""" + target = Path(path).expanduser() + if target.is_file(): + return { + "status": "ok", + "kind": "file", + "path": str(target), + "json": _load_json_summary( + target, + max_preview_chars=max_preview_chars, + ), + } + + if target.is_dir(): + files = _json_files( + target, + glob_pattern=glob_pattern, + max_files=max_files, + recursive=recursive, + ) + return { + "status": "ok", + "kind": "directory", + "path": str(target), + "glob_pattern": glob_pattern, + "recursive": recursive, + "file_count_returned": len(files), + "files": [ + { + "path": str(file), + "json": _load_json_summary( + file, + max_preview_chars=max_preview_chars, + ), + } + for file in files + ], + } + + parent = target.parent + nearby = ( + _json_files( + parent, + glob_pattern=glob_pattern, + max_files=max_files, + recursive=False, + ) + if parent.is_dir() + else [] + ) + return { + "status": "not_found", + "kind": "missing", + "path": str(target), + "parent_exists": parent.is_dir(), + "nearby_json_files": [str(file) for file in nearby], + } + + +def _json_files( + directory: Path, + *, + glob_pattern: str, + max_files: int, + recursive: bool, +) -> list[Path]: + if max_files < 1: + return [] + iterator = directory.rglob(glob_pattern) if recursive else directory.glob(glob_pattern) + return sorted(path for path in iterator if path.is_file())[:max_files] + + +def _load_json_summary(path: Path, *, max_preview_chars: int) -> dict[str, Any]: + try: + value = json.loads(path.read_text(encoding="utf-8")) + except Exception as exc: # noqa: BLE001 - report file/read/parse failure. 
+ return { + "status": "error", + "error": repr(exc), + } + return { + "status": "ok", + "summary": _summarize_json(value), + "preview": _json_preview(value, max_chars=max_preview_chars), + } + + +def _summarize_json(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + summary: dict[str, Any] = { + "type": "object", + "keys": sorted(str(key) for key in value.keys())[:40], + } + for key in ("status", "energy", "energy_ev", "driver", "model"): + if key in value: + summary[key] = value[key] + for key in ("results", "failures", "errors"): + nested = value.get(key) + if isinstance(nested, list): + summary[f"{key}_count"] = len(nested) + return summary + if isinstance(value, list): + return { + "type": "array", + "length": len(value), + "first_item": _summarize_json(value[0]) if value else None, + } + return { + "type": type(value).__name__, + "value": value, + } + + +def _json_preview(value: Any, *, max_chars: int) -> Any: + try: + text = json.dumps(value, sort_keys=True) + except TypeError: + text = repr(value) + if len(text) <= max_chars: + return value + return { + "truncated": True, + "chars": len(text), + "text": text[:max_chars], + } + + +if __name__ == "__main__": + from chemgraph.mcp.server_utils import run_mcp_server + + run_mcp_server(mcp, default_port=9020) diff --git a/src/chemgraph/mcp/job_tools.py b/src/chemgraph/mcp/job_tools.py new file mode 100644 index 00000000..6974aef1 --- /dev/null +++ b/src/chemgraph/mcp/job_tools.py @@ -0,0 +1,107 @@ +"""Shared MCP tools for job status tracking and result retrieval. + +Call :func:`register_job_tools` to add ``check_job_status``, +``get_job_results``, ``list_jobs``, ``cancel_job``, and (optionally) +``check_endpoint_status`` to any :class:`~mcp.server.fastmcp.FastMCP` +server instance. + +These tools are only useful when the execution backend is async-remote +(e.g. Globus Compute), but are registered unconditionally so the LLM +agent always has a consistent tool surface. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from mcp.server.fastmcp import FastMCP + + from chemgraph.execution.base import ExecutionBackend + from chemgraph.execution.job_tracker import JobTracker + + +def register_job_tools( + mcp: FastMCP, + tracker: JobTracker, + backend: ExecutionBackend, +) -> None: + """Register job-management MCP tools on *mcp*. + + Parameters + ---------- + mcp : FastMCP + The MCP server to register tools on. + tracker : JobTracker + The job tracker for this server process. + backend : ExecutionBackend + The active execution backend (used for endpoint health checks). + """ + + @mcp.tool( + name="check_job_status", + description=( + "Check the status of a previously submitted HPC job batch. " + "Returns progress information including how many tasks are " + "complete, failed, or still pending. Use this to poll " + "long-running remote compute jobs." + ), + ) + def check_job_status(batch_id: str) -> dict: + """Check the status of a submitted job batch.""" + return tracker.get_status(batch_id) + + @mcp.tool( + name="get_job_results", + description=( + "Retrieve results from a completed (or partially completed) " + "HPC job batch. By default, returns results only when all " + "tasks are done. Set include_partial=True to get results " + "for tasks that have finished so far." + ), + ) + def get_job_results( + batch_id: str, + include_partial: bool = False, + ) -> dict: + """Retrieve results from a job batch.""" + return tracker.get_results(batch_id, include_partial=include_partial) + + @mcp.tool( + name="list_jobs", + description=( + "List all tracked job batches with their current status. " + "Shows batch IDs, tool names, submission times, and progress." 
+ ), + ) + def list_jobs() -> list[dict]: + """List all tracked job batches.""" + batches = tracker.list_batches() + if not batches: + return [{"message": "No job batches tracked."}] + return batches + + @mcp.tool( + name="cancel_job", + description=( + "Cancel pending tasks in a job batch. Only tasks that have " + "not yet started executing can be cancelled." + ), + ) + def cancel_job(batch_id: str) -> dict: + """Cancel pending tasks in a job batch.""" + return tracker.cancel_batch(batch_id) + + if backend.is_async_remote and hasattr(backend, "check_endpoint_status"): + + @mcp.tool( + name="check_endpoint_status", + description=( + "Check whether the remote HPC compute endpoint is " + "reachable and accepting tasks. Use this as a pre-flight " + "check before submitting jobs." + ), + ) + def check_endpoint_status() -> dict: + """Check the remote compute endpoint status.""" + return backend.check_endpoint_status() diff --git a/src/chemgraph/mcp/mace_mcp_hpc.py b/src/chemgraph/mcp/mace_mcp_hpc.py new file mode 100644 index 00000000..c3bbe5fd --- /dev/null +++ b/src/chemgraph/mcp/mace_mcp_hpc.py @@ -0,0 +1,366 @@ +"""Backend-agnostic MACE MCP server. + +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. + +Transport (local-file embedding, pre-staged remote-path passthrough) +lives in a single pre-submit hook so the tool bodies stay simple. The +hook rewrites :class:`~chemgraph.execution.base.TaskSpec` instances +before submission to attach an inline structure when the input file +exists on the submitting host, leaving the path untouched when it +does not (assumed to be remote). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. 
+""" + +import logging +import os +import sys +from pathlib import Path + +from chemgraph.execution.base import TaskSpec +from chemgraph.execution.config import get_transfer_manager +from chemgraph.execution.utils import ( + make_per_structure_output, + resolve_structure_files, +) +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.transfer_tools import register_transfer_tools +from chemgraph.schemas.mace_parsl_schema import ( + mace_input_schema, + mace_input_schema_ensemble, +) +from chemgraph.tools.parsl_tools import extract_output_json, run_mace_core + +logger = logging.getLogger(__name__) + +_JOBS_FILE = Path("~/.chemgraph/mace_jobs.json").expanduser() +_MACE_MP_ALIASES = {"mace_mp", "mace-mp", "MACE-MP", "mace_MP"} + +mcp = CGFastMCP( + name="ChemGraph MACE Tools", + instructions=""" + You expose tools for running MACE simulations and reading their results. + The available tools are: + 1. run_mace_single: run a single MACE calculation. + 2. run_mace_ensemble: run MACE calculations over every structure in a + directory (local or pre-staged remote). + 3. extract_output_json: load simulation results from a JSON file. + 4. check_job_status / get_job_results / list_jobs / cancel_job: HPC + job batch management. Job state persists across sessions. + 5. transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on the + remote HPC filesystem before running ensembles in remote mode. + + Guidelines: + - Use each tool only when its input schema matches the user request. + - Do not guess numerical values; report tool errors exactly as they + occur. + - Keep responses compact -- full results are written to the output + files defined in the schemas. + - When returning paths, use absolute paths. + - Energies are in eV and wall times are in seconds. + - When a tool returns status='submitted' with a batch_id, call + get_job_results(batch_id) to retrieve results. 
If still pending, + report the batch_id so the user can check later -- job state is + persisted across sessions. + - For the `model` field, pass a MACE foundation model name (e.g. + 'medium-mpa-0'). 'mace_mp' is the calculator type, not a model + name -- do not pass it. + """, +) + + +# ── Worker (runs on the backend) ─────────────────────────────────────── + + +def _mace_worker(job: dict) -> dict: + """Execute a single MACE simulation on a backend worker. + + Accepts a *job dict* (not the schema) so the pre-submit hook can + attach transport keys ``inline_structure`` / ``remote_structure_file`` + before submission. + """ + import tempfile + + job = dict(job) + + # Pre-staged remote file: use the path directly on the worker FS. + remote_file = job.pop("remote_structure_file", None) + if remote_file is not None: + job["input_structure_file"] = remote_file + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + os.path.dirname(remote_file), + job.get("output_result_file", "output.json"), + ) + + # Inline structure: materialise on the worker's filesystem. 
+ inline = job.pop("inline_structure", None) + if inline is not None: + from ase import Atoms + from ase.io import write as ase_write + + atoms = Atoms( + numbers=inline["numbers"], + positions=inline["positions"], + cell=inline.get("cell"), + pbc=inline.get("pbc"), + ) + tmpdir = tempfile.mkdtemp(prefix="chemgraph_mace_") + xyz_path = os.path.join(tmpdir, "structure.xyz") + ase_write(xyz_path, atoms) + job["input_structure_file"] = xyz_path + if not os.path.isabs(job.get("output_result_file", "")): + job["output_result_file"] = os.path.join( + tmpdir, job.get("output_result_file", "output.json") + ) + + output_file = job.get("output_result_file") + if output_file: + os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True) + + params = mace_input_schema(**job) + result = run_mace_core(params) + return result + + +# Force pickle-by-reference for callables that the transport hook installs +# as `task.callable`. Without this, dill sees `__module__ == "__main__"` +# (this file is run as ``python -m chemgraph.mcp.mace_mcp_hpc``) and falls +# back to pickle-by-value, which walks the module's globals and tries to +# serialize the dynamic ``run_mace_singleArguments`` class held by +# ``mcp._tool_manager._tools[...].fn_metadata.arg_model`` -- that class +# was created by ``pydantic.create_model`` with a ``__module__`` it was +# never registered into, so dill raises a PicklingError. 
+CGFastMCP._fix_module_for_pickle(_mace_worker) + + +# ── Pre-submit transport hook ────────────────────────────────────────── + + +def _embed_inline_if_local(job: dict) -> None: + """Mutate *job* in-place: attach inline_structure when the input + file is readable on the submitting host (and no other transport + key has already been set).""" + if job.get("remote_structure_file") or job.get("inline_structure"): + return + input_file = job.get("input_structure_file") + if not input_file or not os.path.isfile(input_file): + return # remote path -- worker will read it directly + + from ase.io import read as ase_read + + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = ase_read(input_file) + job["inline_structure"] = atoms_to_atomsdata(atoms).model_dump() + + +def _normalize_model(job: dict) -> None: + """Map calculator-type aliases to a valid foundation model name.""" + if job.get("model") in _MACE_MP_ALIASES: + job["model"] = "medium-mpa-0" + + +def _backend_shares_fs() -> bool: + """Whether the active backend shares the server's filesystem. + + When it does, inline embedding (and the worker's ``/tmp`` round-trip) + is unnecessary -- the worker reads ``input_structure_file`` directly. 
+ Defaults to ``True`` (skip embedding) when no backend exists yet.""" + backend = getattr(mcp, "_backend", None) + return getattr(backend, "shares_filesystem", True) + + +def _mace_transport_hook(task: TaskSpec) -> TaskSpec: + """Route single-tool calls to the dict-based worker and embed + local structures only when the backend has no shared filesystem.""" + logger.debug( + "mace transport hook: task_id=%s callable=%s", + task.task_id, + getattr(task.callable, "__qualname__", task.callable), + ) + if task.callable is run_mace_single: + params = task.kwargs.get("params") + if params is None: + return task + job = ( + params.model_dump() if hasattr(params, "model_dump") else dict(params) + ) + _normalize_model(job) + if not _backend_shares_fs(): + _embed_inline_if_local(job) + task.callable = _mace_worker + task.kwargs = {"job": job} + elif task.callable is _mace_worker: + job = dict(task.kwargs.get("job", {})) + _normalize_model(job) + if not _backend_shares_fs(): + _embed_inline_if_local(job) + task.kwargs = {"job": job} + return task + + +mcp.set_pre_submit_hook(_mace_transport_hook) + + +# ── Single-structure tool ────────────────────────────────────────────── + + +def run_mace_single(params: mace_input_schema) -> dict: + """Run a single MACE calculation on the configured backend. + + The pre-submit hook rewrites this call to invoke ``_mace_worker`` + on the backend with a job dict that may carry an embedded inline + structure (when the input file exists locally) or a remote path + (when it does not). + """ + # Direct-call fallback path (no hook registered) -- normalises and + # delegates to the same worker. 
+ job = params.model_dump() + _normalize_model(job) + return _mace_worker(job) + + +# ── Ensemble fanout ──────────────────────────────────────────────────── + + +def _ls_remote_files(path: str) -> list[str]: + """Backend-side helper: list non-directory entries in *path*.""" + return sorted( + f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) + ) + + +CGFastMCP._fix_module_for_pickle(_ls_remote_files) + + +def _expand_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: + """Server-side expansion of an ensemble request into per-file jobs. + + Local mode: enumerates ``input_structure_directory`` on this host. + Remote mode: submits a one-shot probe task to the backend to list + files under ``remote_structure_directory``, then builds per-file + jobs that the worker reads directly from the remote filesystem. + """ + shared = { + "output_result_file": params.output_result_file, + "driver": params.driver, + "model": params.model, + "device": params.device, + "temperature": params.temperature, + "pressure": params.pressure, + "fmax": params.fmax, + "steps": params.steps, + "optimizer": params.optimizer, + } + base_output = Path(params.output_result_file) + + if params.remote_structure_directory: + remote_dir = params.remote_structure_directory + mcp._ensure_backend() + probe = TaskSpec( + task_id="ls_remote_dir", + task_type="python", + callable=_ls_remote_files, + kwargs={"path": remote_dir}, + ) + fut = mcp._backend.submit(probe) + try: + file_names = fut.result(timeout=30) + except Exception as exc: + raise RuntimeError( + f"Could not list remote directory {remote_dir}: {exc}" + ) from exc + + jobs = [] + for fname in file_names: + per_output = make_per_structure_output(Path(fname), base_output) + job = {**shared} + job["remote_structure_file"] = f"{remote_dir}/{fname}" + job["output_result_file"] = str(per_output) + jobs.append(job) + return jobs + + if not params.input_structure_directory: + raise ValueError( + "Either 
input_structure_directory or remote_structure_directory " + "must be provided." + ) + + structure_files, _ = resolve_structure_files(params.input_structure_directory) + return [ + { + **shared, + "input_structure_file": str(f), + "output_result_file": str(make_per_structure_output(f, base_output)), + } + for f in structure_files + ] + + +def run_mace_ensemble(params: mace_input_schema_ensemble) -> list[dict]: + return _expand_mace_ensemble(params) + + +# ── Orchestration tools (no backend involvement) ─────────────────────── + + +mcp.add_tool( + extract_output_json, + name="extract_output_json", + description="Load simulation results from an output JSON file.", +) + + +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on MACE MCP server.") + + +if __name__ == "__main__": + import argparse as _ap + + from chemgraph.mcp.server_utils import run_mcp_server + + _parser = _ap.ArgumentParser(add_help=False) + _parser.add_argument("--ppn", type=int, default=1, + help="Processes per node for backend tasks") + _parser.add_argument("--ngpus-per-process", type=int, default=0, + help="GPUs per process for backend tasks") + _args, _remaining = _parser.parse_known_args() + sys.argv = [sys.argv[0]] + _remaining + + mcp.tool( + name="run_mace_single", + description="Run a single MACE calculation", + processes_per_node=_args.ppn, + gpus_per_task=_args.ngpus_per_process, + )(run_mace_single) + + mcp.schema_fanout_tool( + name="run_mace_ensemble", + description=( + "Run MACE calculations over every structure in a directory. " + "Local mode uses input_structure_directory; remote mode uses " + "remote_structure_directory (pre-stage files first with " + "transfer_files)." 
+ ), + worker=_mace_worker, + processes_per_node=_args.ppn, + gpus_per_task=_args.ngpus_per_process, + )(run_mace_ensemble) + + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) + + try: + run_mcp_server(mcp, default_port=9004) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/mcp/mace_mcp_parsl.py b/src/chemgraph/mcp/mace_mcp_parsl.py index 4b3f03fc..42ae67c7 100644 --- a/src/chemgraph/mcp/mace_mcp_parsl.py +++ b/src/chemgraph/mcp/mace_mcp_parsl.py @@ -1,6 +1,17 @@ import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.mace_mcp_parsl is deprecated; use " + "chemgraph.mcp.mace_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP from parsl.config import Config from parsl.executors import HighThroughputExecutor diff --git a/src/chemgraph/mcp/server_utils.py b/src/chemgraph/mcp/server_utils.py index 91fce11e..71cc5c6d 100644 --- a/src/chemgraph/mcp/server_utils.py +++ b/src/chemgraph/mcp/server_utils.py @@ -84,6 +84,11 @@ def run_mcp_server( uvicorn.run(app, host=args.host, port=args.port) else: logging.info("Starting %s via stdio transport...", mcp.name) + # Under stdio, the server's stdout IS the JSON-RPC channel. Any + # unguarded print from a worker (e.g. mace's "cuequivariance ... + # will be disabled" notice) would corrupt it. setdefault so the + # user can override with CHEMGRAPH_LOCAL_SILENCE_STDOUT=0. 
+ os.environ.setdefault("CHEMGRAPH_LOCAL_SILENCE_STDOUT", "1") # FastMCP.run(transport='stdio') handles the stdio loop mcp.run(transport="stdio") diff --git a/src/chemgraph/mcp/transfer_tools.py b/src/chemgraph/mcp/transfer_tools.py new file mode 100644 index 00000000..79ae2323 --- /dev/null +++ b/src/chemgraph/mcp/transfer_tools.py @@ -0,0 +1,186 @@ +"""Shared MCP tools for Globus Transfer file staging. + +Call :func:`register_transfer_tools` to add ``transfer_files``, +``check_transfer_status``, and ``list_remote_files`` to any +:class:`~mcp.server.fastmcp.FastMCP` (or +:class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`) server instance. + +These tools allow an LLM agent to stage input files on a remote HPC +filesystem *before* submitting compute jobs, avoiding the overhead of +encoding large files inside Globus Compute function payloads. + +Note +---- +Transfer tools are orchestration tools (they call the Globus Transfer +API directly from the MCP server process), not compute tools, so they +are registered via :meth:`FastMCP.add_tool` rather than CGFastMCP's +backend-submitting ``@tool()`` decorator. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from mcp.server.fastmcp import FastMCP + + from chemgraph.execution.globus_transfer import GlobusTransferManager + +logger = logging.getLogger(__name__) + + +def register_transfer_tools( + mcp: FastMCP, + transfer_manager: GlobusTransferManager, +) -> None: + """Register file-transfer MCP tools on *mcp*. + + Parameters + ---------- + mcp : FastMCP + The MCP server to register tools on. May be a plain ``FastMCP`` + or a :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`; ``add_tool`` + is inherited so the same registration works either way. + transfer_manager : GlobusTransferManager + The configured transfer manager instance. 
+ """ + + def transfer_files( + source_paths: Union[str, list[str]], + extensions: Optional[list[str]] = None, + remote_subdir: Optional[str] = None, + wait: bool = True, + label: Optional[str] = None, + ) -> dict: + """Transfer files to the remote HPC endpoint via Globus Transfer. + + Parameters + ---------- + source_paths : str or list[str] + A directory path (all matching files transferred) or a list + of individual file paths. + extensions : list[str], optional + When *source_paths* is a directory, only transfer files with + these extensions (e.g. ``[".cif", ".xyz"]``). Ignored when + *source_paths* is a list. + remote_subdir : str, optional + Subdirectory name on the remote endpoint. Auto-generated if + omitted. + wait : bool + If True (default), block until the transfer completes. + label : str, optional + Human-readable label for the transfer task. + """ + if isinstance(source_paths, str): + src = Path(source_paths) + if src.is_dir(): + if extensions: + ext_set = { + e if e.startswith(".") else f".{e}" for e in extensions + } + files = sorted( + str(f) + for f in src.iterdir() + if f.is_file() and f.suffix.lower() in ext_set + ) + else: + files = sorted( + str(f) for f in src.iterdir() if f.is_file() + ) + if not files: + return { + "status": "error", + "message": f"No files found in {source_paths}" + + ( + f" with extensions {extensions}" + if extensions + else "" + ), + } + elif src.is_file(): + files = [str(src.resolve())] + else: + return { + "status": "error", + "message": f"Path not found: {source_paths}", + } + else: + files = [str(Path(p).resolve()) for p in source_paths] + + transfer_result = transfer_manager.transfer_files( + local_paths=files, + remote_subdir=remote_subdir, + label=label, + ) + + response = { + "task_id": transfer_result.task_id, + "remote_directory": transfer_result.remote_directory, + "file_count": len(files), + "file_mapping": transfer_result.file_mapping, + } + + if wait: + status = 
transfer_manager.wait_for_transfer(transfer_result.task_id) + response["status"] = ( + "completed" + if status["status"] == "SUCCEEDED" + else status["status"] + ) + response.update( + { + k: status[k] + for k in ("bytes_transferred", "files_transferred") + if k in status + } + ) + else: + response["status"] = "submitted" + + return response + + def check_transfer_status(task_id: str) -> dict: + """Check the status of a Globus Transfer task. + + Use to poll a non-blocking transfer submitted with ``wait=False``. + """ + return transfer_manager.check_transfer_status(task_id) + + def list_remote_files(remote_path: str) -> list[dict]: + """List files in a directory on the remote HPC endpoint. + + Useful to verify that files were staged correctly before + running ensemble calculations. + """ + return transfer_manager.list_remote_directory(remote_path) + + mcp.add_tool( + transfer_files, + name="transfer_files", + description=( + "Transfer local files to the remote HPC filesystem via " + "Globus Transfer. Use this to pre-stage structure files " + "before running ensemble calculations with " + "remote_structure_directory. Returns the remote directory " + "path and a mapping of local-to-remote file paths." + ), + ) + mcp.add_tool( + check_transfer_status, + name="check_transfer_status", + description=( + "Check the status of a Globus Transfer task. Use this to " + "poll a non-blocking transfer submitted with wait=False." + ), + ) + mcp.add_tool( + list_remote_files, + name="list_remote_files", + description=( + "List files in a directory on the remote HPC endpoint. " + "Useful to verify that files were staged correctly before " + "running ensemble calculations." + ), + ) diff --git a/src/chemgraph/mcp/xanes_mcp_hpc.py b/src/chemgraph/mcp/xanes_mcp_hpc.py new file mode 100644 index 00000000..0c3008b5 --- /dev/null +++ b/src/chemgraph/mcp/xanes_mcp_hpc.py @@ -0,0 +1,295 @@ +"""Backend-agnostic XANES/FDMNES MCP server. + +Uses :class:`~chemgraph.mcp.cg_fastmcp.CGFastMCP`. 
Tool functions are +plain computation -- the framework handles backend submission, future +resolution, and async job tracking. + +The ensemble expander runs server-side and prepares per-structure +FDMNES input files in ``runs_dir``; the worker (which runs on the +backend) executes FDMNES via subprocess and extracts convergence data. +This assumes the server and worker share a filesystem (true for any +Globus Compute endpoint on the same HPC where the MCP server runs; +Globus Transfer staging is a separate concern). + +Nothing requiring the backend is initialised at import time so worker +subprocesses (EnsembleLauncher, Globus Compute) can re-import this +module safely. +""" + +import logging +import subprocess +from pathlib import Path + +from chemgraph.execution.config import get_transfer_manager +from chemgraph.execution.utils import resolve_structure_files +from chemgraph.mcp.cg_fastmcp import CGFastMCP +from chemgraph.mcp.transfer_tools import register_transfer_tools +from chemgraph.schemas.xanes_schema import ( + mp_query_schema, + xanes_input_schema, + xanes_input_schema_ensemble, +) + +logger = logging.getLogger(__name__) + +_JOBS_FILE = Path("~/.chemgraph/xanes_jobs.json").expanduser() + +mcp = CGFastMCP( + name="ChemGraph XANES Tools", + instructions=""" + You expose tools for running XANES/FDMNES simulations. + The available tools are: + 1. run_xanes_single: run a single FDMNES calculation for one structure. + 2. run_xanes_ensemble: run FDMNES calculations over multiple structures + using the configured execution backend. + 3. fetch_mp_structures: fetch optimized structures from Materials Project. + 4. plot_xanes: generate normalized XANES plots for completed calculations. + 5. check_job_status / get_job_results / list_jobs / cancel_job: HPC + job batch management. Job state persists across sessions. + 6. 
transfer_files / check_transfer_status / list_remote_files + (when Globus Transfer is configured): stage input files on the + remote HPC filesystem before running ensembles. + + Guidelines: + - Use each tool only when its input schema matches the user request. + - Do not guess numerical values; report tool errors exactly as they occur. + - Keep responses compact -- full results are in the output directories. + - When returning paths, use absolute paths. + - Energies are in eV. + - When a tool returns status='submitted' with a batch_id, call + get_job_results(batch_id) to retrieve results. If the job is + still pending, report the batch_id to the user so they can + check later. Job state is persisted across sessions -- the + user can call list_jobs or get_job_results in a future session + to retrieve results. + """, +) + + +# ── Single-structure tool ────────────────────────────────────────────── + + +def _xanes_single_worker(params: xanes_input_schema) -> dict: + """Run a single FDMNES calculation on a backend worker.""" + from chemgraph.tools.xanes_tools import run_xanes_core + + result = run_xanes_core(params) + if isinstance(result, dict): + result.setdefault("status", "success") + return result + return {"status": "success", "result": result} + + +@mcp.tool( + name="run_xanes_single", + description="Run a single XANES/FDMNES calculation for one input structure.", +) +def run_xanes_single(params: xanes_input_schema): + """Run a single FDMNES calculation using the core engine. + + The CGFastMCP wrapper submits this call to the configured backend; + the body is the direct-call fallback when no backend is active. + """ + return _xanes_single_worker(params) + + +# ── Ensemble fanout ──────────────────────────────────────────────────── + + +def _xanes_ensemble_worker(item: dict) -> dict: + """Execute one prepared FDMNES run on the backend. 
+ + The expander has already written ``input_fdmnes.txt`` (or the + equivalent) into ``item['run_dir']``; this worker runs the binary + via subprocess and then extracts convergence data. + """ + from chemgraph.tools.xanes_tools import extract_conv + + run_dir = item["run_dir"] + fdmnes_exe = item["fdmnes_exe"] + meta = { + "structure": item.get("structure"), + "run_dir": run_dir, + "z_absorber": item.get("z_absorber"), + } + + stdout_path = Path(run_dir) / "fdmnes_stdout.txt" + stderr_path = Path(run_dir) / "fdmnes_stderr.txt" + try: + with open(stdout_path, "w") as out, open(stderr_path, "w") as err: + proc = subprocess.run( + [fdmnes_exe], + cwd=run_dir, + stdout=out, + stderr=err, + check=False, + ) + if proc.returncode != 0: + return { + **meta, + "status": "failure", + "error_type": "FDMNESExitCode", + "message": f"FDMNES exited with code {proc.returncode}", + "returncode": proc.returncode, + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"FDMNES launch failed: {e}", + } + + try: + conv_data = extract_conv(run_dir) + return { + **meta, + "status": "success", + "n_conv_files": len(conv_data), + } + except Exception as e: + return { + **meta, + "status": "failure", + "error_type": type(e).__name__, + "message": f"Post-processing failed: {e}", + } + + +# Note: ``_xanes_ensemble_worker`` is registered via ``@mcp.schema_fanout_tool`` +# below, which fixes its module for pickling automatically. 
+ + +def _expand_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: + """Server-side expansion: prepare per-structure run dirs and return + one item per structure for the worker to execute.""" + from ase.io import read as ase_read + + from chemgraph.tools.xanes_tools import write_fdmnes_input + + structure_files, output_dir = resolve_structure_files( + params.input_structures, + extensions={".cif", ".xyz", ".poscar"}, + ) + + runs_dir = output_dir / "fdmnes_batch_runs" + runs_dir.mkdir(parents=True, exist_ok=True) + + items: list[dict] = [] + for i, struct_path in enumerate(structure_files): + run_dir = runs_dir / f"run_{i}" + run_dir.mkdir(parents=True, exist_ok=True) + + atoms = ase_read(str(struct_path)) + z_abs = ( + params.z_absorber + if params.z_absorber is not None + else int(max(atoms.get_atomic_numbers())) + ) + + write_fdmnes_input( + ase_atoms=atoms, + z_absorber=z_abs, + input_file_dir=run_dir, + radius=params.radius, + magnetism=params.magnetism, + ) + + items.append( + { + "structure": struct_path.name, + "run_dir": str(run_dir), + "z_absorber": z_abs, + "fdmnes_exe": params.fdmnes_exe, + } + ) + + return items + + +@mcp.schema_fanout_tool( + name="run_xanes_ensemble", + description=( + "Run FDMNES/XANES calculations over every structure in an input " + "directory (or list of files). Each structure is prepared " + "server-side and submitted to the configured execution backend." 
+ ), + worker=_xanes_ensemble_worker, +) +def run_xanes_ensemble(params: xanes_input_schema_ensemble) -> list[dict]: + return _expand_xanes_ensemble(params) + + +# ── Orchestration tools (no backend involvement) ─────────────────────── + + +def fetch_mp_structures(params: mp_query_schema): + """Fetch structures from Materials Project and save as CIF files and pickle database.""" + from chemgraph.tools.xanes_tools import ( + _get_data_dir, + fetch_materials_project_data, + ) + + data_dir = _get_data_dir() + result = fetch_materials_project_data(params, data_dir) + return { + "status": "success", + "n_structures": result["n_structures"], + "chemsys": params.chemsys, + "output_dir": str(data_dir), + "structure_files": result["structure_files"], + "pickle_file": result["pickle_file"], + } + + +def plot_xanes(runs_dir: str): + """Generate XANES plots for all completed runs in a directory.""" + from chemgraph.tools.xanes_tools import ( + _get_data_dir, + plot_xanes_results, + ) + + runs_path = Path(runs_dir) + if not runs_path.is_dir(): + raise ValueError(f"'{runs_dir}' is not a valid directory.") + + data_dir = _get_data_dir() + result = plot_xanes_results(data_dir, runs_path) + return { + "status": "success", + "n_plots": result["n_plots"], + "n_failed": result["n_failed"], + "plot_files": result["plot_files"], + "failed": result["failed"], + } + + +mcp.add_tool( + fetch_mp_structures, + name="fetch_mp_structures", + description="Fetch optimized structures from Materials Project.", +) +mcp.add_tool( + plot_xanes, + name="plot_xanes", + description="Generate normalized XANES plots for completed FDMNES calculations.", +) + + +# ── Globus Transfer (registered only when configured) ────────────────── + +_transfer_manager = get_transfer_manager() +if _transfer_manager is not None: + register_transfer_tools(mcp, _transfer_manager) + logger.info("Registered Globus Transfer tools on XANES MCP server.") + + +if __name__ == "__main__": + from chemgraph.mcp.server_utils import 
run_mcp_server + + mcp.init_backend(tracker_kwargs={"persist_file": _JOBS_FILE}) + + try: + run_mcp_server(mcp, default_port=9007) + finally: + mcp.shutdown_backend() diff --git a/src/chemgraph/mcp/xanes_mcp_parsl.py b/src/chemgraph/mcp/xanes_mcp_parsl.py index 0ec794c1..b5f8729a 100644 --- a/src/chemgraph/mcp/xanes_mcp_parsl.py +++ b/src/chemgraph/mcp/xanes_mcp_parsl.py @@ -1,8 +1,19 @@ import asyncio import json import os +import warnings from pathlib import Path +warnings.warn( + "chemgraph.mcp.xanes_mcp_parsl is deprecated; use " + "chemgraph.mcp.xanes_mcp_hpc, which dispatches via the " + "chemgraph.execution backend abstraction (Parsl, EnsembleLauncher, " + "Globus Compute, or local). This module will be removed in a future " + "release.", + DeprecationWarning, + stacklevel=2, +) + from mcp.server.fastmcp import FastMCP import parsl diff --git a/src/chemgraph/models/loader.py b/src/chemgraph/models/loader.py index 07583777..64f0f105 100644 --- a/src/chemgraph/models/loader.py +++ b/src/chemgraph/models/loader.py @@ -14,29 +14,31 @@ from chemgraph.models.groq import load_groq_model from chemgraph.models.local_model import load_ollama_model from chemgraph.models.openai import load_openai_model +from chemgraph.models.settings import LLMSettings from chemgraph.models.supported_models import ( supported_alcf_models, supported_anthropic_models, supported_argo_models, supported_gemini_models, - supported_ollama_models, supported_openai_models, ) def load_chat_model( - model_name: str, + model_name: str | None = None, temperature: float = 0.0, base_url: Optional[str] = None, api_key: Optional[str] = None, argo_user: Optional[str] = None, + *, + settings: LLMSettings | None = None, ): """Load a LangChain chat model by provider auto-detection. Parameters ---------- - model_name : str + model_name : str, optional Model name from any supported provider list. temperature : float Sampling temperature (default 0.0 for deterministic output). 
@@ -46,6 +48,9 @@ def load_chat_model( API key override (falls back to environment variables). argo_user : str, optional Argo user identifier. + settings : LLMSettings, optional + Canonical endpoint settings. When provided, this overrides + model_name/base_url/api_key/argo_user. Returns ------- @@ -57,12 +62,25 @@ def load_chat_model( ValueError If the model name is not found in any supported provider list. """ + if settings is not None: + model_name = settings.model + base_url = settings.base_url + api_key = settings.api_key + argo_user = settings.argo_user + if settings.temperature is not None: + temperature = settings.temperature + + if model_name is None: + raise ValueError("load_chat_model requires model_name or settings") + if model_name in supported_openai_models or model_name in supported_argo_models: kwargs = { "model_name": model_name, "temperature": temperature, "base_url": base_url, } + if api_key is not None: + kwargs["api_key"] = api_key if argo_user is not None: kwargs["argo_user"] = argo_user return load_openai_model(**kwargs) @@ -87,5 +105,6 @@ def load_chat_model( else: raise ValueError( f"Model '{model_name}' not found in any supported model list. " - f"Use a model from: OpenAI, Anthropic, Gemini, groq:, argo:, ALCF, or Ollama." + "Use a model from: OpenAI, Anthropic, Gemini, groq:, " + "argo:, ALCF, or Ollama." ) diff --git a/src/chemgraph/models/openai.py b/src/chemgraph/models/openai.py index e48fdb27..bb33d38b 100644 --- a/src/chemgraph/models/openai.py +++ b/src/chemgraph/models/openai.py @@ -2,7 +2,10 @@ import os from getpass import getpass +from urllib.parse import urlparse + from langchain_openai import ChatOpenAI + from chemgraph.models.supported_models import ( ARGO_DEFAULT_BASE_URL, supported_openai_models, @@ -60,13 +63,20 @@ } +ARGO_LOCAL_OPENAI_MODEL_MAP = { + # argo-shim advertises GPT-5.4 with this casing. Lowercase gpt-5.4 is + # rejected by the upstream Argo API behind the shim. 
+ "argo:gpt-5.4": "GPT-5.4", +} + + def _normalize_argo_model(model_name: str, base_url: str) -> str: """Normalize an ``argo:``-prefixed model name for the target endpoint. - * Argo API (base_url contains ``argoapi``): map to internal wire - names via ``ARGO_MODEL_MAP`` (e.g. ``argo:gpt-4o`` -> ``gpt4o``). - * Other endpoints (ArgoProxy, custom): strip the ``argo:`` prefix - and send the remainder as-is (e.g. ``argo:gpt-4o`` -> ``gpt-4o``). + * Hosted Argo API endpoints use internal wire names via + ``ARGO_MODEL_MAP``. + * Argo shim, ArgoProxy, and custom OpenAI-compatible endpoints strip the + ``argo:`` prefix and keep the OpenAI-style name. Parameters ---------- @@ -83,18 +93,28 @@ def _normalize_argo_model(model_name: str, base_url: str) -> str: if not model_name.startswith("argo:"): return model_name - if base_url and "argoapi" in base_url: - # Argo API endpoint -- use the wire-name map - normalized = ARGO_MODEL_MAP.get(model_name) - if normalized: - logger.info("Normalized Argo model '%s' -> '%s'", model_name, normalized) - return normalized - # Fallback: strip prefix and remove punctuation - fallback = model_name.removeprefix("argo:").replace("-", "").replace(".", "") + model_format = os.getenv("CHEMGRAPH_ARGO_MODEL_FORMAT", "").lower() + if model_format == "shim": + return _normalize_argo_local_openai_model(model_name) + if model_format in {"openai", "openai-compatible"}: + stripped = model_name.removeprefix("argo:") + logger.info("Stripped argo: prefix '%s' -> '%s'", model_name, stripped) + return stripped + if model_format in {"wire", "argo"}: + return _normalize_argo_wire_model(model_name) + + if _is_local_http_endpoint(base_url): + stripped = _normalize_argo_local_openai_model(model_name) logger.info( - "Normalized Argo model '%s' -> '%s' (fallback)", model_name, fallback + "Using OpenAI-style Argo model for local endpoint '%s': '%s' -> '%s'", + base_url, + model_name, + stripped, ) - return fallback + return stripped + + if base_url and "argoapi" in 
base_url: + return _normalize_argo_wire_model(model_name) else: # Non-Argo-API endpoint -- strip prefix only stripped = model_name.removeprefix("argo:") @@ -102,6 +122,41 @@ def _normalize_argo_model(model_name: str, base_url: str) -> str: return stripped +def _normalize_argo_local_openai_model(model_name: str) -> str: + """Return the model name expected by local OpenAI-compatible Argo shims.""" + return ARGO_LOCAL_OPENAI_MODEL_MAP.get( + model_name, + model_name.removeprefix("argo:"), + ) + + +def _normalize_argo_wire_model(model_name: str) -> str: + """Return the hosted-Argo wire model for an ``argo:`` model name.""" + normalized = ARGO_MODEL_MAP.get(model_name) + if normalized: + logger.info("Normalized Argo model '%s' -> '%s'", model_name, normalized) + return normalized + + fallback = model_name.removeprefix("argo:").replace("-", "").replace(".", "") + logger.info( + "Normalized Argo model '%s' -> '%s' (fallback)", model_name, fallback + ) + return fallback + + +def _is_local_http_endpoint(base_url: str | None) -> bool: + """Return True for local HTTP endpoints such as ``argo-shim``.""" + if not base_url: + return False + parsed = urlparse(base_url) + return parsed.scheme == "http" and parsed.hostname in { + "localhost", + "127.0.0.1", + "::1", + "0.0.0.0", + } + + def load_openai_model( model_name: str, temperature: float, @@ -173,9 +228,13 @@ def load_openai_model( api_key = getpass("OpenAI API key: ") os.environ["OPENAI_API_KEY"] = api_key - if model_name not in supported_openai_models and model_name not in supported_argo_models: + if ( + model_name not in supported_openai_models + and model_name not in supported_argo_models + ): raise ValueError( - f"Unsupported model '{model_name}'. Supported models are: {supported_openai_models}." + f"Unsupported model '{model_name}'. " + f"Supported models are: {supported_openai_models}." 
) is_argo_endpoint = bool(base_url and "argoapi" in base_url) @@ -214,7 +273,7 @@ def load_openai_model( api_key=api_key, max_tokens=6000, ) - # No guarantee that api_key is valid, authentication happens only during invocation + # Authentication happens only during invocation. logger.info(f"Requested model: {model_name}") logger.info("OpenAI model loaded successfully") return llm diff --git a/src/chemgraph/models/settings.py b/src/chemgraph/models/settings.py new file mode 100644 index 00000000..e24951bf --- /dev/null +++ b/src/chemgraph/models/settings.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import dataclasses +import json +from collections.abc import Mapping +from pathlib import Path +from typing import Any + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib # type: ignore[no-redef] + + +@dataclasses.dataclass(frozen=True, init=False) +class LLMSettings: + """Fully resolved description of one LLM endpoint.""" + + model: str + base_url: str | None = None + api_key: str | None = None + argo_user: str | None = None + provider: str | None = None + timeout_s: float | None = None + temperature: float | None = None + max_tokens: int | None = None + max_retries: int | None = None + retry_delay_s: float | None = None + + def __init__( + self, + model: str, + base_url: str | None = None, + api_key: str | None = None, + argo_user: str | None = None, + provider: str | None = None, + timeout_s: float | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + max_retries: int | None = None, + retry_delay_s: float | None = None, + user: str | None = None, + ) -> None: + object.__setattr__(self, "model", model) + object.__setattr__(self, "base_url", base_url) + object.__setattr__(self, "api_key", api_key) + object.__setattr__(self, "argo_user", argo_user or user) + object.__setattr__(self, "provider", provider) + object.__setattr__(self, "timeout_s", timeout_s) + object.__setattr__(self, "temperature", 
temperature) + object.__setattr__(self, "max_tokens", max_tokens) + object.__setattr__(self, "max_retries", max_retries) + object.__setattr__(self, "retry_delay_s", retry_delay_s) + + @property + def user(self) -> str | None: + """Backward-compatible academy name for Argo user metadata.""" + return self.argo_user + + +def load_lm_settings(source: str | Path | Mapping[str, Any]) -> LLMSettings: + """Build LLMSettings from a JSON file, TOML file, or already-parsed dict.""" + if isinstance(source, Mapping): + return _from_mapping(source) + + path = Path(source) + text = path.read_text(encoding="utf-8") + if path.suffix.lower() == ".toml": + raw = tomllib.loads(text) + return _from_mapping(_extract_endpoint_from_cli_toml(raw)) + return _from_mapping(json.loads(text)) + + +def _from_mapping(data: Mapping[str, Any]) -> LLMSettings: + if not isinstance(data, Mapping): + raise ValueError("LM config must be a mapping/object") + + model = data.get("model") or data.get("model_name") + if not isinstance(model, str) or not model: + raise ValueError("LM config requires a non-empty 'model' field") + + provider = data.get("provider") + if provider is not None and provider != "openai_compatible_tools": + raise ValueError( + "LM config 'provider' must be 'openai_compatible_tools' or absent", + ) + + api_key = data.get("api_key") + if provider == "openai_compatible_tools" and not api_key: + raise ValueError( + "openai_compatible_tools provider requires api_key " + "(use 'dummy' for Argo shim routes that ignore auth)", + ) + + return LLMSettings( + model=str(model), + base_url=_str_or_none(data.get("base_url")), + api_key=_str_or_none(api_key), + argo_user=_str_or_none(data.get("user") or data.get("argo_user")), + provider=_str_or_none(provider), + timeout_s=_float_or_none(data.get("timeout_s")), + temperature=_float_or_none(data.get("temperature")), + max_tokens=_int_or_none(data.get("max_tokens")), + max_retries=_int_or_none(data.get("max_retries")), + 
retry_delay_s=_float_or_none(data.get("retry_delay_s")), + ) + + +def _extract_endpoint_from_cli_toml(raw: Mapping[str, Any]) -> dict[str, Any]: + """Pull LLM endpoint fields out of the CLI's nested TOML structure.""" + general = raw.get("general") or {} + api = raw.get("api") or {} + model = general.get("model") + argo_user = general.get("argo_user") or (api.get("argo") or {}).get("user") + + base_url = None + if isinstance(model, str): + if model.startswith("argo:"): + base_url = (api.get("argo") or {}).get("base_url") + else: + for section_name in ("openai", "anthropic", "gemini", "alcf", "ollama"): + section = api.get(section_name) or {} + if section.get("base_url"): + base_url = section["base_url"] + break + + return { + "model": model, + "base_url": base_url, + "argo_user": argo_user, + "api_key": (api.get(_provider_section_for(model)) or {}).get("api_key"), + } + + +def _provider_section_for(model: Any) -> str: + if isinstance(model, str): + if model.startswith("argo:"): + return "argo" + if model.startswith("groq:"): + return "groq" + return "openai" + + +def _str_or_none(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value or None + return str(value) or None + + +def _float_or_none(value: Any) -> float | None: + return None if value is None else float(value) + + +def _int_or_none(value: Any) -> int | None: + return None if value is None else int(value) diff --git a/src/chemgraph/schemas/calculators/mace_calc.py b/src/chemgraph/schemas/calculators/mace_calc.py index 2ad50216..712df5c2 100644 --- a/src/chemgraph/schemas/calculators/mace_calc.py +++ b/src/chemgraph/schemas/calculators/mace_calc.py @@ -1,13 +1,19 @@ """MACE foundation models parameters for ChemGraph Reference: https://github.com/ACEsuit/mace/blob/main/mace/calculators/foundations_models.py""" +import functools +import logging import os +import tempfile import threading +from contextlib import contextmanager from pathlib import Path from 
typing import Optional, Union from pydantic import BaseModel, Field import torch +_logger = logging.getLogger(__name__) + # Process-wide lock for MACE operations. # MACE model deserialization (torch.load) triggers torch.fx.symbolic_trace # inside Contraction.__init__, which temporarily patches @@ -18,6 +24,78 @@ _mace_lock = threading.Lock() +@functools.lru_cache(maxsize=1) +def _mace_lockfile_path() -> Optional[str]: + """Return the path of the per-node MACE init lock file, or ``None`` if + no writable directory is available. Memoised so we only resolve once.""" + candidates = [ + os.environ.get("CHEMGRAPH_MACE_LOCK_DIR"), + os.environ.get("TMPDIR"), + tempfile.gettempdir(), + str(Path.home() / ".cache" / "chemgraph"), + ] + uid = os.getuid() if hasattr(os, "getuid") else "unknown" + for d in candidates: + if not d: + continue + try: + Path(d).mkdir(parents=True, exist_ok=True) + path = str(Path(d) / f"chemgraph_mace_init.{uid}.lock") + # Touch to confirm we can write. + with open(path, "a"): + pass + return path + except OSError: + continue + return None + + +@contextmanager +def mace_loading_lock(): + """Serialize MACE model loads across both threads and processes on one node. + + EnsembleLauncher's ``AsyncProcessPool`` spawns multiple Python workers in + parallel; a per-process :data:`_mace_lock` is not enough because torch's + ``symbolic_trace`` patches ``torch.nn.Module.__call__`` at the class level + during MACE deserialization, and concurrent loads in sibling processes + racing on the same node can deadlock or trip the same NameError that #110 + describes. We add an ``fcntl.flock``-based file lock on top so that + siblings on the same node take turns. + + Degrades to thread-only locking when ``fcntl`` is unavailable (e.g. + Windows) or no writable lock directory exists. 
+ """ + try: + import fcntl + except ImportError: + fcntl = None # type: ignore[assignment] + + path = _mace_lockfile_path() if fcntl is not None else None + fh = None + try: + with _mace_lock: + if path is not None: + fh = open(path, "w") + try: + fcntl.flock(fh.fileno(), fcntl.LOCK_EX) + except OSError as exc: + _logger.warning( + "fcntl.flock on %s failed (%s); proceeding without " + "inter-process MACE serialization.", + path, + exc, + ) + fh.close() + fh = None + yield + finally: + if fh is not None: + try: + fcntl.flock(fh.fileno(), fcntl.LOCK_UN) + finally: + fh.close() + + class MaceCalc(BaseModel): """MACE (Message-passing Atomic and Continuous Environment) calculator configuration. diff --git a/src/chemgraph/schemas/graspa_schema.py b/src/chemgraph/schemas/graspa_schema.py index 9cd08231..996ec12b 100644 --- a/src/chemgraph/schemas/graspa_schema.py +++ b/src/chemgraph/schemas/graspa_schema.py @@ -46,7 +46,16 @@ class graspa_input_schema(BaseModel): class graspa_input_schema_ensemble(BaseModel): input_structures: Union[str, list[str]] = Field( - description="Path to a directory of CIF files OR a specific list of file paths." + default="", + description="Path to a directory of CIF files OR a specific list of file paths. Required unless remote_structure_directory is provided.", + ) + remote_structure_directory: str | None = Field( + default=None, + description=( + "Path to pre-staged CIF files on the remote HPC filesystem. " + "When provided, workers read structures directly from this path. " + "Use the transfer_files tool to stage files first." 
+ ), ) output_result_file: str = Field( default="raspa.log", diff --git a/src/chemgraph/schemas/mace_parsl_schema.py b/src/chemgraph/schemas/mace_parsl_schema.py index e04ddba6..17d5c54f 100644 --- a/src/chemgraph/schemas/mace_parsl_schema.py +++ b/src/chemgraph/schemas/mace_parsl_schema.py @@ -17,14 +17,20 @@ class mace_input_schema(BaseModel): default="output.json", description="Path to a JSON file where simulation results will be saved.", ) - driver: str = Field( + driver: str | None = Field( default=None, description="Specifies the type of simulation to run. Options: 'energy' for single-point energy calculations, 'opt' for geometry optimization, 'vib' for vibrational frequency analysis, and 'thermo' for thermochemical properties (including enthalpy, entropy, and Gibbs free energy).", ) model: str = Field( default="medium-mpa-0", - description="Path to the model. Default is medium-mpa-0." - "Options are 'small', 'medium', 'large', 'small-0b', 'medium-0b', 'small-0b2', 'medium-0b2','large-0b2', 'medium-0b3', 'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', 'mace-matpes-r2scan-0'", + description="MACE foundation model name or absolute local model file path " + "(NOT the calculator type). " + "Options: 'small', 'medium', 'large', 'small-0b', 'medium-0b', " + "'small-0b2', 'medium-0b2', 'large-0b2', 'medium-0b3', " + "'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', " + "'mace-matpes-r2scan-0', or an absolute path to a local .model file. " + "Default is 'medium-mpa-0'. " + "Do NOT pass 'mace_mp' — that is the calculator type, not a model name.", ) device: str = Field( default="cpu", @@ -54,20 +60,37 @@ class mace_input_schema(BaseModel): class mace_input_schema_ensemble(BaseModel): input_structure_directory: str = Field( - description="Path to a folder of input structures containing the atomic structure for the simulations." + default="", + description="Path to a local folder of input structures. 
Required unless remote_structure_directory is provided.", + ) + remote_structure_directory: str | None = Field( + default=None, + description=( + "Path to pre-staged structure files on the remote HPC filesystem. " + "When provided, workers read structures directly from this path " + "instead of using inline structure embedding. Use the " + "transfer_files tool to stage files first, then pass the " + "remote directory here." + ), ) output_result_file: str = Field( default="output.json", description="Path to a JSON file where simulation results will be saved.", ) - driver: str = Field( + driver: str | None = Field( default=None, description="Specifies the type of simulation to run. Options: 'energy' for single-point energy calculations, 'opt' for geometry optimization, 'vib' for vibrational frequency analysis, and 'thermo' for thermochemical properties (including enthalpy, entropy, and Gibbs free energy).", ) model: str = Field( default="medium-mpa-0", - description="Path to the model. Default is medium-mpa-0." - "Options are 'small', 'medium', 'large', 'small-0b', 'medium-0b', 'small-0b2', 'medium-0b2','large-0b2', 'medium-0b3', 'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', 'mace-matpes-r2scan-0'", + description="MACE foundation model name or absolute local model file path " + "(NOT the calculator type). " + "Options: 'small', 'medium', 'large', 'small-0b', 'medium-0b', " + "'small-0b2', 'medium-0b2', 'large-0b2', 'medium-0b3', " + "'medium-mpa-0', 'medium-omat-0', 'mace-matpes-pbe-0', " + "'mace-matpes-r2scan-0', or an absolute path to a local .model file. " + "Default is 'medium-mpa-0'. " + "Do NOT pass 'mace_mp' — that is the calculator type, not a model name.", ) device: str = Field( default="cpu", @@ -102,8 +125,12 @@ class mace_output_schema(BaseModel): output_result_file: str = Field( description="Path to a JSON file where simulation results is saved.", ) - model: str = Field( - default=None, description="Path to the model. Default is medium-mpa-0." 
+ model: str | None = Field( + default=None, + description=( + "MACE foundation model name or absolute local model file path. " + "Default is medium-mpa-0." + ), ) device: str = Field( default="cpu", @@ -143,7 +170,7 @@ class mace_output_schema(BaseModel): default="", description="Error captured during the simulation", ) - wall_time: float = Field( + wall_time: float | None = Field( default=None, description="Total wall time (in seconds) taken to complete the simulation.", ) diff --git a/src/chemgraph/tools/ase_core.py b/src/chemgraph/tools/ase_core.py index 4e3dc915..b1c483e3 100644 --- a/src/chemgraph/tools/ase_core.py +++ b/src/chemgraph/tools/ase_core.py @@ -10,6 +10,7 @@ import glob import json +import logging import os import shutil import tempfile @@ -22,6 +23,28 @@ from chemgraph.schemas.atomsdata import AtomsData from chemgraph.schemas.ase_input import ASEInputSchema, ASEOutputSchema +logger = logging.getLogger(__name__) + + +def _ensure_ase_core_file_log() -> None: + """Attach a single ``FileHandler`` to the ase_core logger. + + ``run_ase_core`` runs both in the MCP-server process (where + ``server_utils`` already configures root logging) and in worker + processes (Parsl / EnsembleLauncher / Globus Compute) that never go + through that setup, so we add our own file handler here. Idempotent: + a second call is a no-op, which avoids accumulating one open file + handle per invocation. Honors ``CHEMGRAPH_LOG_DIR`` when set. 
+ """ + if any(isinstance(h, logging.FileHandler) for h in logger.handlers): + return + log_dir = os.environ.get("CHEMGRAPH_LOG_DIR", os.path.join(os.getcwd(), "cg_logs")) + os.makedirs(log_dir, exist_ok=True) + fh = logging.FileHandler(os.path.join(log_dir, "ase_core.log")) + fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) + logger.addHandler(fh) + logger.setLevel(logging.DEBUG) + # --------------------------------------------------------------------------- # Path helpers @@ -202,7 +225,18 @@ def load_calculator(calculator: dict) -> tuple[object, dict, object]: if hasattr(calc, "get_atoms_properties"): extra_info = calc.get_atoms_properties() - return calc.get_calculator(), extra_info, calc + if "mace" in calc_type: + # MACE's torch.load + symbolic_trace is unsafe under concurrent loads, + # whether the concurrency is threads in one process or sibling processes + # spawned by the EnsembleLauncher process pool. See mace_calc._mace_lock. + from chemgraph.schemas.calculators.mace_calc import mace_loading_lock + + with mace_loading_lock(): + ase_calculator = calc.get_calculator() + else: + ase_calculator = calc.get_calculator() + + return ase_calculator, extra_info, calc # --------------------------------------------------------------------------- @@ -314,10 +348,16 @@ def run_ase_core(params: ASEInputSchema) -> dict: from ase.io import read from ase.optimize import BFGS, LBFGS, GPMin, FIRE, MDMin + # ---- file logger (cg_logs/) ---- + _ensure_ase_core_file_log() + + logger.info("run_ase_core called with params: %s", params.model_dump_json()) + # ---- unpack params ---- try: calculator = params.calculator.model_dump() except Exception as e: + logger.error("Calculator validation failed: %s", e) return { "status": "failure", "error_type": "ValidationError", @@ -336,7 +376,11 @@ def run_ase_core(params: ASEInputSchema) -> dict: pressure = params.pressure # ---- input validation ---- + logger.info("driver=%s, input=%s, output=%s, 
optimizer=%s, fmax=%s, steps=%s", + driver, input_structure_file, output_results_file, optimizer, fmax, steps) + if not os.path.isfile(input_structure_file): + logger.error("Input file not found: %s", input_structure_file) return { "status": "failure", "error_type": "FileNotFoundError", @@ -344,15 +388,27 @@ def run_ase_core(params: ASEInputSchema) -> dict: } if not output_results_file.endswith(".json"): + logger.error("Invalid output file extension: %s", output_results_file) return { "status": "failure", "error_type": "ValueError", "message": f"Output results file must end with '.json', got: {params.output_results_file}", } + # Make sure the destination directory exists before the simulation runs; + # otherwise the trailing ``open(output_results_file, "w")`` fails with + # FileNotFoundError after the calculation has already burned its + # compute time. Callers (LLM agents, scripts) routinely point at a + # not-yet-created subdirectory of a shared run dir, so create it now. + output_parent = os.path.dirname(os.path.abspath(output_results_file)) + if output_parent: + os.makedirs(output_parent, exist_ok=True) + + logger.info("Loading calculator: %s", calculator) calc, system_info, calc_model = load_calculator(calculator) if calc is None: + logger.error("Unsupported calculator: %s", calculator) return { "status": "failure", "error_type": "ValueError", @@ -361,16 +417,19 @@ def run_ase_core(params: ASEInputSchema) -> dict: "MACE (mace_mp, mace_off, mace_anicc), EMT, TBLite (GFN2-xTB, GFN1-xTB), NWChem and Orca" ), } + logger.info("Calculator loaded successfully: %s", type(calc).__name__) try: atoms = read(input_structure_file) except Exception as e: + logger.error("Failed to read input structure: %s", e) return { "status": "failure", "error_type": type(e).__name__, "message": f"Cannot read {input_structure_file} using ASE. 
Exception from ASE: {e}", } + logger.info("Read %d atoms from %s", len(atoms), input_structure_file) atoms.info.update(system_info) atoms.calc = calc @@ -378,7 +437,9 @@ def run_ase_core(params: ASEInputSchema) -> dict: # Driver: energy / dipole (single-point, no optimization) # ------------------------------------------------------------------ if driver in ("energy", "dipole"): + logger.info("Running single-point %s calculation", driver) energy = atoms.get_potential_energy() + logger.info("Single-point energy: %s eV", energy) final_structure = atoms_to_atomsdata(atoms) dipole: List[Optional[float]] = [None, None, None] @@ -403,6 +464,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: ) with open(output_results_file, "w", encoding="utf-8") as wf: wf.write(simulation_output.model_dump_json(indent=4)) + logger.info("Results saved to %s (wall_time=%.2fs)", output_results_file, wall_time) if driver == "energy": return { @@ -434,13 +496,16 @@ def run_ase_core(params: ASEInputSchema) -> dict: if optimizer_class is None: raise ValueError(f"Unsupported optimizer: {optimizer}") + logger.info("Running optimization with %s (fmax=%s, steps=%s)", optimizer, fmax, steps) if len(atoms) > 1: dyn = optimizer_class(atoms) converged = dyn.run(fmax=fmax, steps=steps) else: converged = True + logger.info("Optimization converged=%s", converged) single_point_energy = float(atoms.get_potential_energy()) + logger.info("Post-optimization energy: %s eV", single_point_energy) final_structure = AtomsData( numbers=atoms.numbers, positions=atoms.positions, @@ -455,6 +520,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: # Vibrational / thermo / IR analysis # -------------------------------------------------------------- if driver in {"vib", "thermo", "ir"}: + logger.info("Starting vibrational analysis (driver=%s)", driver) from ase.vibrations import Vibrations from ase import units @@ -470,6 +536,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: vib = Vibrations(atoms, 
name=vib_name) vib.clean() vib.run() + logger.info("Vibrational analysis complete") vib_data = { "energies": [], @@ -516,6 +583,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: # ---- IR ---- if driver == "ir": + logger.info("Running IR calculation") from ase.vibrations import Infrared import matplotlib @@ -547,6 +615,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: fig.savefig(ir_plot_path, format="png", dpi=300) plt.close(fig) + logger.info("IR spectrum plot saved to %s", ir_plot_path) ir_data["IR Plot"] = f"Saved to {os.path.abspath(ir_plot_path)}" ir_data["Normal mode data"] = ( f"Normal modes saved as individual .traj files with prefix {mol_stem}_" @@ -554,6 +623,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: # ---- Thermochemistry ---- if driver == "thermo": + logger.info("Computing thermochemistry (T=%s K, P=%s Pa)", temperature, pressure) if len(atoms) == 1: thermo_data = { "enthalpy": single_point_energy, @@ -604,6 +674,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: # ---- serialise full output ---- end_time = time.time() wall_time = end_time - start_time + logger.info("Simulation finished (driver=%s, wall_time=%.2fs, converged=%s)", driver, wall_time, converged) simulation_output = ASEOutputSchema( input_structure_file=input_structure_file, @@ -660,6 +731,7 @@ def run_ase_core(params: ASEInputSchema) -> dict: } except Exception as e: + logger.exception("run_ase_core failed with %s: %s", type(e).__name__, e) return { "status": "failure", "error_type": type(e).__name__, diff --git a/src/chemgraph/tools/ase_tools.py b/src/chemgraph/tools/ase_tools.py index ff4650a3..369c2753 100644 --- a/src/chemgraph/tools/ase_tools.py +++ b/src/chemgraph/tools/ase_tools.py @@ -13,7 +13,6 @@ from chemgraph.schemas.atomsdata import AtomsData from chemgraph.schemas.ase_input import ASEInputSchema -from chemgraph.schemas.calculators.mace_calc import _mace_lock from chemgraph.tools.ase_core import ( _resolve_path, atoms_to_atomsdata, @@ -166,8 +165,6 
@@ def run_ase(params: ASEInputSchema) -> dict: ValueError If the calculator is not supported or if the calculation fails """ - calc_type = params.calculator.calculator_type.lower() - if "mace" in calc_type: - with _mace_lock: - return run_ase_core(params) + # MACE thread/process serialization now lives in run_ase_core -> + # load_calculator, so this wrapper just delegates. return run_ase_core(params) diff --git a/src/chemgraph/tools/cheminformatics_core.py b/src/chemgraph/tools/cheminformatics_core.py index 0ffe13b3..321fc659 100644 --- a/src/chemgraph/tools/cheminformatics_core.py +++ b/src/chemgraph/tools/cheminformatics_core.py @@ -142,6 +142,9 @@ def smiles_to_coordinate_file_core( atoms = Atoms(numbers=numbers, positions=positions) final_output_file = _resolve_path(output_file) + parent = os.path.dirname(os.path.abspath(final_output_file)) + if parent: + os.makedirs(parent, exist_ok=True) ase_write(final_output_file, atoms) return { diff --git a/src/chemgraph/tools/parsl_tools.py b/src/chemgraph/tools/parsl_tools.py index 908ac29c..6d8bbda7 100644 --- a/src/chemgraph/tools/parsl_tools.py +++ b/src/chemgraph/tools/parsl_tools.py @@ -6,23 +6,25 @@ from __future__ import annotations -from chemgraph.tools.ase_core import run_ase_core +import logging + from chemgraph.schemas.ase_input import ASEInputSchema from chemgraph.schemas.mace_parsl_schema import ( mace_input_schema, - mace_input_schema_ensemble, mace_output_schema, ) +from chemgraph.tools.ase_core import run_ase_core # Re-export schemas so existing ``from chemgraph.tools.parsl_tools import …`` # statements continue to work. 
def extract_output_json(json_file: str) -> dict:
    """Load simulation results from a JSON file produced by run_ase.

    Parameters
    ----------
    json_file : str
        Path of the JSON file to read.

    Returns
    -------
    dict
        The decoded JSON content. An empty dict is returned when the file
        cannot be opened or its content cannot be decoded; the failure is
        logged as a warning instead of being dropped silently.
    """
    import json
    import logging

    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, TypeError, ValueError) as exc:
        # This loader is deliberately best-effort, so callers still get an
        # empty result. OSError covers missing/unreadable files, ValueError
        # covers json.JSONDecodeError and UnicodeDecodeError, and TypeError
        # covers a non-path argument such as None.
        logging.getLogger(__name__).warning(
            "Could not load simulation output from %s: %s", json_file, exc
        )
        return {}
def pytest_collection_modifyitems(config, items):
    """Skip opt-in test categories unless their command-line flag was passed.

    Items carrying the ``llm`` keyword are skipped unless ``--run-llm`` is
    set, and items carrying ``globus_compute`` are skipped unless
    ``--run-globus-compute`` is set.
    """
    # (keyword, skip marker) pairs for every category whose flag is absent.
    gates = []
    if not config.getoption("--run-llm"):
        gates.append(
            ("llm", pytest.mark.skip(reason="need --run-llm option to run"))
        )
    if not config.getoption("--run-globus-compute"):
        gates.append(
            (
                "globus_compute",
                pytest.mark.skip(reason="need --run-globus-compute option to run"),
            )
        )

    for item in items:
        for keyword, marker in gates:
            if keyword in item.keywords:
                item.add_marker(marker)
def test_builtin_mace_campaign_uses_star_coordinator_without_routing_policy() -> None:
    """The built-in MACE campaign is a star around the coordinator.

    Also checks that the bootstrap text carries neither ``parameters`` nor
    ``routing_policy``.
    """
    hub = "coordinator-agent"
    spokes = [
        "structure-agent-a",
        "structure-agent-b",
        "mace-agent",
        "assessment-agent",
    ]

    campaign = load_campaign("mace-ensemble-screening-20")
    validate_campaign(campaign, len(campaign.agents))

    assert campaign.initial_agent == hub
    agent_names = [agent.name for agent in campaign.agents]
    assert agent_names == [hub, *spokes]

    allowed = {}
    for agent in campaign.agents:
        allowed[agent.name] = set(agent.allowed_peers)
    # The coordinator may reach every worker; each worker only the coordinator.
    assert allowed[hub] == set(spokes)
    for spoke in spokes:
        assert allowed[spoke] == {hub}

    bootstrap = json.loads(campaign_bootstrap_text(campaign))
    for removed_key in ("parameters", "routing_policy"):
        assert removed_key not in bootstrap
def test_mcp_server_spec_validation() -> None:
    """Check MCPServerSpec defaults, required fields and extra-field rejection."""
    minimal = {"name": "general", "command": "python -m server"}

    # A minimal spec validates and gets an empty env mapping.
    assert MCPServerSpec.model_validate(minimal).env == {}

    # Omitting ``command`` is reported as a missing required field.
    with pytest.raises(ValueError, match="field required|Field required"):
        MCPServerSpec.model_validate({"name": "general"})

    # Unknown keys are not accepted.
    with pytest.raises(ValueError):
        MCPServerSpec.model_validate({**minimal, "extra": "bad"})
def test_validate_campaign_rejects_unknown_mcp_server(tmp_path) -> None:
    """An agent that names an undeclared MCP server fails validation."""
    agent = {
        "name": "agent-a",
        "role": "Role",
        "mission": "Do the task.",
        "allowed_peers": [],
        "mcp_servers": ["missing"],
    }
    document = {
        "run_id": "bad-server",
        "user_task": "test",
        "prompt_profile": "prompt.json",
        "mcp_servers": [],
        "agents": [agent],
    }
    campaign_path = tmp_path / "campaign.jsonc"
    campaign_path.write_text(json.dumps(document), encoding="utf-8")

    # Loading succeeds; the dangling reference is caught at validation time.
    campaign = load_campaign(campaign_path)
    with pytest.raises(RuntimeError, match="unknown MCP servers"):
        validate_campaign(campaign, 1)
def test_agent_allowed_tools_defaults_to_empty(tmp_path) -> None:
    """An agent declared without ``allowed_tools`` loads with an empty tuple."""
    server = {"name": "general", "command": "python -m chemgraph.mcp.mcp_tools"}
    agent = {
        "name": "agent-a",
        "role": "Role",
        "mission": "Do the task.",
        "allowed_peers": [],
        "mcp_servers": ["general"],
    }
    document = {
        "run_id": "allowed-tools-default",
        "user_task": "test",
        "prompt_profile": "prompt.json",
        "mcp_servers": [server],
        "agents": [agent],
    }
    campaign_path = tmp_path / "campaign.jsonc"
    campaign_path.write_text(json.dumps(document), encoding="utf-8")

    campaign = load_campaign(campaign_path)
    validate_campaign(campaign, 1)

    first_agent = campaign.agents[0]
    assert first_agent.allowed_tools == ()
def _plan(tmp_path: Path) -> AllocationPlan:
    """Build an AllocationPlan rooted at ``tmp_path`` with stub config files.

    Both the LM config and the campaign file are written as empty JSON
    objects before the plan is constructed.
    """
    lm_config = tmp_path / "lm.json"
    campaign = tmp_path / "campaign.jsonc"
    for stub in (lm_config, campaign):
        stub.write_text("{}\n", encoding="utf-8")

    settings = {
        "run_dir": tmp_path,
        "run_token": "token-1",
        "agent_count": 3,
        "agents_per_node": 1,
        "campaign_config": campaign,
        "lm_config": lm_config,
        "max_decisions": 7,
        "poll_timeout_s": 2.0,
        "idle_timeout_s": 600.0,
        "startup_timeout_s": 120.0,
        "completion_timeout_s": 60.0,
        "status_interval_s": 5.0,
        "redis_host": "redis-host",
        "redis_port": 6392,
        "redis_bind": "0.0.0.0",
        "redis_protected_mode": "no",
        "redis_namespace": "ns",
        "start_redis": False,
        "mpiexec": "mpiexec",
        "chemgraph_repo_root": tmp_path / "ChemGraph",
    }
    return AllocationPlan(**settings)
def test_dashboard_reads_canonical_events_jsonl(tmp_path) -> None:
    """Events written through EventLog are read back by ``events_payload``."""
    run_dir = tmp_path / "daemon-run"
    run_dir.mkdir()
    status = {"mode": "mpi_daemon", "timestamp": 10.0, "agents": []}
    (run_dir / "status.json").write_text(
        json.dumps(status) + "\n",
        encoding="utf-8",
    )

    # Both events come from the same agent.
    emitter = {"agent_id": "agent-00", "role": "scheduler observer"}
    log = EventLog(run_dir / "events.jsonl")
    log.emit(
        "agent_started",
        **emitter,
        payload={
            "role": "scheduler observer",
            "placement": {"hostname": "x1", "short_hostname": "x1"},
            "hostname": "x1",
            "short_hostname": "x1",
        },
    )
    log.emit(
        "agent_decision",
        **emitter,
        payload={
            "round": 1,
            "tool_names": ["send_message"],
            "actions": [{"action": "send_message"}],
        },
    )

    payload = dashboard.events_payload(run_dir)
    events = payload["events"]

    assert events[0]["event"] == "agent_started"
    assert events[0]["payload"]["placement"]["hostname"] == "x1"
    assert events[1]["event"] == "agent_decision"
    assert events[1]["payload"]["actions"] == [{"action": "send_message"}]
def test_dashboard_ignores_legacy_trace_jsonl(tmp_path) -> None:
    """A run directory holding only ``trace.jsonl`` yields no dashboard events."""
    run_dir = tmp_path / "old-run"
    run_dir.mkdir()

    legacy_record = {
        "timestamp": 1.0,
        "agent": "agent-00",
        "event": "daemon_started",
        "payload": {"hostname": "x0"},
    }
    trace_file = run_dir / "trace.jsonl"
    trace_file.write_text(json.dumps(legacy_record) + "\n", encoding="utf-8")

    payload = dashboard.events_payload(run_dir)
    assert payload["events"] == []
def _args(tmp_path: Path, **overrides) -> argparse.Namespace:
    """Return the argument namespace used by the dashboard-launcher tests.

    Path-valued defaults are placed under ``tmp_path``. Any keyword passed
    in ``overrides`` replaces the default of the same name or adds a new
    attribute.
    """
    defaults = dict(
        run_id="run-001",
        system="test-system",
        campaign="mace-ensemble-screening-20",
        lm_connect="direct",
        lm_base_url="http://lm.example/v1",
        remote_host=None,
        ssh_control_path=str(tmp_path / "ssh-control"),
        keep_ssh_master=False,
        local_argo_host="127.0.0.1",
        local_argo_port=18085,
        reverse_port=18185,
        relay_port=None,
        relay_python=None,
        rsync_interval_s=2.0,
        local_mirror_root=str(tmp_path / "mirror"),
        local_run_dir=None,
        dashboard_host="127.0.0.1",
        dashboard_port=8765,
        local=False,
        no_dashboard=True,
        overwrite_run=True,
    )
    # Later keys win, so overrides take precedence over the defaults.
    return argparse.Namespace(**{**defaults, **overrides})
def test_dashboard_launcher_overwrite_writes_remote_state(tmp_path, monkeypatch) -> None:
    """With overwrite enabled, ``main`` clears the mirror and writes remote state.

    The ssh helper is replaced by a recorder, so the checks below run
    against the commands and stdin text captured for each call.
    """
    local_run = tmp_path / "mirror" / "run-001"
    local_run.mkdir(parents=True)
    (local_run / "status.json").write_text("{}\n", encoding="utf-8")
    calls: list[dict] = []

    def fake_ssh(host, command, **kwargs):
        record = {"host": host, "command": command, **kwargs}
        calls.append(record)
        return subprocess.CompletedProcess(["ssh"], 0, stdout="")

    replacements = {
        "parse_args": lambda: _args(tmp_path),
        "load_system_profile": lambda _: _profile(tmp_path),
        "campaign_launch_defaults": lambda _: object(),
        "ssh": fake_ssh,
        "start_rsync": lambda *args, **kwargs: None,
    }
    for attribute, stand_in in replacements.items():
        monkeypatch.setattr(dashboard_launcher, attribute, stand_in)

    assert dashboard_launcher.main() == 0
    assert not local_run.exists()

    # Second ssh call: the remote run directory is moved aside, removed, recreated.
    delete_command = calls[1]["command"]
    for fragment in (
        'mv -- "$run_dir" "$trash_dir"',
        'rm -rf -- "$trash_dir"',
        'mkdir -p "$run_dir"',
    ):
        assert fragment in delete_command

    # Third ssh call: the wrapper script is sent on stdin and made executable.
    wrapper_call = calls[2]
    assert wrapper_call["command"].endswith("chmod +x /remote/root/bin/chemgraph-academy-run")
    assert "chemgraph.academy.runtime.compute_launcher" in wrapper_call["input_text"]

    # Fourth ssh call: run metadata is sent as JSON on stdin.
    metadata = json.loads(calls[3]["input_text"])
    assert metadata["run_id"] == "run-001"
    assert metadata["lm_base_url"] == "http://lm.example/v1"
    assert metadata["remote_run_dir"] == "/remote/root/runs/run-001"
_args(tmp_path, run_id="../bad"), + ) + monkeypatch.setattr(dashboard_launcher, "load_system_profile", lambda _: _profile(tmp_path)) + monkeypatch.setattr(dashboard_launcher, "campaign_launch_defaults", lambda _: object()) + monkeypatch.setattr( + dashboard_launcher, + "ssh", + lambda *args, **kwargs: subprocess.CompletedProcess(["ssh"], 0, stdout=""), + ) + + with pytest.raises(RuntimeError, match="unsafe run id"): + dashboard_launcher.main() diff --git a/tests/test_academy_exchange_registration.py b/tests/test_academy_exchange_registration.py new file mode 100644 index 00000000..0c1f95cb --- /dev/null +++ b/tests/test_academy_exchange_registration.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +# Skip when the optional 'academy' extra is absent; this module +# imports academy.exchange.* directly at top level. +pytest.importorskip("academy") + +from academy.exchange.hybrid import HybridAgentRegistration +from academy.exchange.local import LocalAgentRegistration +from academy.exchange.redis import RedisAgentRegistration +from academy.identifier import AgentId + +from chemgraph.academy.core.campaign import ChemGraphDaemonConfig +from chemgraph.academy.runtime.exchange import build_exchange_factory +from chemgraph.academy.runtime.registration import load_academy_registrations +from chemgraph.academy.runtime.registration import registration_payload +from chemgraph.academy.runtime.registration import write_academy_registrations + + +def _config(tmp_path: Path, exchange_type: str) -> ChemGraphDaemonConfig: + return ChemGraphDaemonConfig( + run_dir=tmp_path, + run_token='token-1', + agent_count=1, + campaign_config=tmp_path / 'campaign.jsonc', + lm_config=tmp_path / 'lm.json', + max_decisions=1, + poll_timeout_s=1.0, + idle_timeout_s=1.0, + startup_timeout_s=1.0, + completion_timeout_s=1.0, + status_interval_s=1.0, + redis_host='localhost', + redis_port=6392, + redis_namespace='ns', + rank=0, + local_rank=0, + 
def test_registration_payload_rejects_mixed_exchange_types() -> None:
    """Combining Redis and local registrations in one payload raises ValueError."""
    agent_kinds = (
        ('redis-agent', RedisAgentRegistration),
        ('local-agent', LocalAgentRegistration),
    )
    with pytest.raises(ValueError, match='mixed exchange types'):
        registration_payload(
            run_token='token-1',
            registrations={
                name: registration_cls(agent_id=AgentId.new(name))
                for name, registration_cls in agent_kinds
            },
        )
def _pythonpath(tmp_path: Path) -> str:
    """Return a PYTHONPATH value with ``tmp_path`` placed first.

    A non-empty ``PYTHONPATH`` already in the environment is appended after
    ``os.pathsep``; otherwise only ``tmp_path`` is returned.
    """
    inherited = os.environ.get("PYTHONPATH", "")
    head = str(tmp_path)
    if not inherited:
        return head
    return os.pathsep.join((head, inherited))
@pytest.mark.asyncio
async def test_mcp_supervisor_reports_server_exit_log_tail(tmp_path) -> None:
    """A server that prints and exits non-zero surfaces its output in the error."""
    # The command prints "boom" and exits with status 1 right away.
    failing_command = f"{sys.executable} -c \"print('boom'); raise SystemExit(1)\""
    failing_spec = MCPServerSpec(name="bad", command=failing_command)
    supervisor = MCPServerSupervisor([failing_spec], run_dir=tmp_path / "run")

    with pytest.raises(RuntimeError, match="boom"):
        await supervisor.start_all()

    await supervisor.shutdown()
@pytest.mark.asyncio
async def test_get_tools_filters_by_allowed_tools(tmp_path) -> None:
    """Only the tools named in ``allowed_tools`` are returned by ``get_tools``."""
    _write_multi_tool_server(tmp_path)
    multi_spec = MCPServerSpec(
        name="multi",
        command=f"{sys.executable} -m multi_mcp",
        env={"PYTHONPATH": _pythonpath(tmp_path)},
    )
    supervisor = MCPServerSupervisor([multi_spec], run_dir=tmp_path / "run")
    wanted = frozenset({"alpha", "gamma"})

    try:
        await supervisor.start_all()
        tools = await supervisor.get_tools(("multi",), allowed_tools=wanted)
    finally:
        # Always stop the server, even when start-up or listing fails.
        await supervisor.shutdown()

    returned_names = {tool.name for tool in tools}
    assert returned_names == {"alpha", "gamma"}
test_get_tools_empty_allowed_tools_returns_all(tmp_path) -> None: + """An empty whitelist is treated as None (no filter).""" + _write_multi_tool_server(tmp_path) + supervisor = MCPServerSupervisor( + [ + MCPServerSpec( + name="multi", + command=f"{sys.executable} -m multi_mcp", + env={"PYTHONPATH": _pythonpath(tmp_path)}, + ), + ], + run_dir=tmp_path / "run", + ) + try: + await supervisor.start_all() + tools = await supervisor.get_tools( + ("multi",), + allowed_tools=frozenset(), + ) + finally: + await supervisor.shutdown() + + assert {tool.name for tool in tools} == {"alpha", "beta", "gamma"} diff --git a/tests/test_academy_payloads.py b/tests/test_academy_payloads.py new file mode 100644 index 00000000..b8c1d04c --- /dev/null +++ b/tests/test_academy_payloads.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import pytest + +# Skip when the optional 'academy' extra is absent. The event_log +# module itself is pure stdlib, but the import guard is applied +# uniformly across the academy test suite. 
+pytest.importorskip("academy") + +from chemgraph.academy.observability.event_log import EventLog, read_events + + +def test_event_log_preserves_payload_shape(tmp_path) -> None: + log = EventLog(tmp_path / "events.jsonl") + + log.emit( + "message_sent", + run_id="run-1", + agent_id="agent-a", + role="Worker", + payload={ + "message_id": "msg-1", + "recipient": "agent-b", + "tldr": "short", + }, + ) + + event = read_events(tmp_path / "events.jsonl")[0] + assert event.event == "message_sent" + assert event.payload == { + "message_id": "msg-1", + "recipient": "agent-b", + "tldr": "short", + } diff --git a/tests/test_academy_reasoning_phase2.py b/tests/test_academy_reasoning_phase2.py new file mode 100644 index 00000000..1798679c --- /dev/null +++ b/tests/test_academy_reasoning_phase2.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import asyncio +import dataclasses +import json +from pathlib import Path +from typing import Any + +import pytest + +# Skip when the optional 'academy' extra is absent. 
+pytest.importorskip("academy") + +from chemgraph.academy.core import agent as agent_module +from chemgraph.academy.core import turn as turn_module +from chemgraph.academy.core.agent import ChemGraphLogicalAgent +from chemgraph.academy.core.campaign import ChemGraphAgentSpec, ChemGraphCampaign +from chemgraph.academy.core.campaign import ResourceSpec, resolve_campaign_resources +from chemgraph.academy.core.prompt import PromptProfile, PromptStateLimits +from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools +from chemgraph.academy.core.turn import ReasoningTurnResult, build_peer_status +from chemgraph.agent.turn import TurnResult +from chemgraph.models.settings import LLMSettings + + +def _agent_spec() -> ChemGraphAgentSpec: + return ChemGraphAgentSpec( + name="agent-a", + role="Worker", + mission="Use explicit tools only.", + allowed_peers=(), + mcp_servers=(), + ) + + +def _agent_spec_with_peer() -> ChemGraphAgentSpec: + return dataclasses.replace(_agent_spec(), allowed_peers=("agent-b",)) + + +def _campaign(spec: ChemGraphAgentSpec) -> ChemGraphCampaign: + return ChemGraphCampaign( + run_id="campaign-1", + user_task="Rank staged candidates.", + initial_agent=spec.name, + prompt_profile=Path("prompt_profiles/default.json"), + agents=(spec,), + ) + + +def _prompt_profile() -> PromptProfile: + return PromptProfile( + prompt_version="test", + prompt_style="json_state", + system_prompt="system prompt", + protocol_prompt="call finish_turn when idle", + langchain_recursion_limit=8, + state_limits=PromptStateLimits( + received_messages_last_n=1, + tool_results_last_n=1, + actions_last_n=2, + ), + ) + + +def _lm_settings() -> LLMSettings: + return LLMSettings( + base_url="http://127.0.0.1:18085/argoapi/v1", + model="GPT-5.4", + provider="openai_compatible_tools", + api_key="dummy", + user="test-user", + timeout_s=60, + temperature=0, + max_tokens=1024, + max_retries=1, + retry_delay_s=0, + ) + + +class _SlowPeerHandle: + def __init__(self) -> None: + 
self.delivered = asyncio.Event() + self.calls: list[tuple[str, dict]] = [] + + async def action(self, name: str, message: dict) -> None: + await asyncio.sleep(0.1) + self.calls.append((name, message)) + self.delivered.set() + + +@pytest.mark.asyncio +async def test_reasoning_adapter_finish_turn_traces(tmp_path) -> None: + traces: list[tuple[str, dict]] = [] + tools = await build_chemgraph_reasoning_tools( + spec=_agent_spec(), + run_dir=tmp_path, + peer_names=(), + peer_handles={}, + outbox=[], + tool_results=[], + get_round_index=lambda: 1, + set_final_result=lambda result: None, + trace=lambda event, payload: traces.append((event, payload)), + ) + + result = await next(t for t in tools if t.name == "finish_turn").ainvoke( + {"reason": "nothing useful now"}, + ) + + assert result == {"status": "finished", "reason": "nothing useful now"} + assert traces == [ + ( + "turn_finished_without_external_action", + {"reason": "nothing useful now"}, + ) + ] + + +@pytest.mark.asyncio +async def test_send_message_does_not_block_on_busy_peer(tmp_path) -> None: + peer = _SlowPeerHandle() + traces: list[tuple[str, dict]] = [] + outbox: list[dict] = [] + tools = await build_chemgraph_reasoning_tools( + spec=_agent_spec_with_peer(), + run_dir=tmp_path, + peer_names=("agent-b",), + peer_handles={"agent-b": peer}, + outbox=outbox, + tool_results=[], + get_round_index=lambda: 1, + set_final_result=lambda result: None, + trace=lambda event, payload: traces.append((event, payload)), + ) + + result = await asyncio.wait_for( + next(t for t in tools if t.name == "send_message").ainvoke( + { + "recipient": "agent-b", + "tldr": "short summary", + "content": "full message", + "artifact_refs": [], + "tool_result_ids": [], + "reply_requested": False, + "reason": "peer needs this evidence", + "confidence": 0.8, + }, + ), + timeout=0.05, + ) + + assert result["delivery"] == "queued" + assert len(outbox) == 1 + assert [name for name, _ in traces] == ["message_sent"] + await 
asyncio.wait_for(peer.delivered.wait(), timeout=1) + assert peer.calls[0][0] == "receive_message" + + +@pytest.mark.asyncio +async def test_run_academy_turn_maps_action_and_science_tools(monkeypatch, tmp_path) -> None: + async def fake_run_turn(**kwargs: Any) -> TurnResult: + payload = json.loads(kwargs["query"]) + assert payload["received_messages"] == [{"message_id": "new"}] + assert payload["local_chemgraph_tool_results"] == [{"tool_result_id": "new"}] + kwargs["on_event"]("workflow_started", {"thread_id": kwargs["thread_id"]}) + return TurnResult( + final_text="done", + state={"messages": []}, + executed_tool_names=("science_tool", "finish_turn"), + terminal_tool="finish_turn", + thread_id=kwargs["thread_id"], + duration_s=0.1, + ) + + monkeypatch.setattr(turn_module, "run_turn", fake_run_turn) + traces: list[tuple[str, dict]] = [] + result = await turn_module.run_academy_turn( + campaign=_campaign(_agent_spec()), + spec=_agent_spec(), + llm_settings=_lm_settings(), + prompt_profile=_prompt_profile(), + run_dir=tmp_path, + max_decisions=5, + tools=[], + received_message_history=[{"message_id": "old"}, {"message_id": "new"}], + outbox=[], + tool_results=[{"tool_result_id": "old"}, {"tool_result_id": "new"}], + get_final_result=lambda: {"summary": "current"}, + get_round_index=lambda: 2, + trace=lambda event, payload: traces.append((event, payload)), + ) + + assert result.action_tools_called == ("finish_turn",) + assert result.science_tools_called == ("science_tool",) + assert result.requested_finish is True + assert result.requested_self_wake is True + assert [event for event, _ in traces] == [ + "chemgraph_reasoning_turn_started", + "workflow_started", + "chemgraph_reasoning_turn_finished", + ] + + +@pytest.mark.asyncio +async def test_logical_agent_reasoning_round_calls_turn_runner(monkeypatch, tmp_path) -> None: + spec = _agent_spec() + agent = ChemGraphLogicalAgent( + spec, + campaign=_campaign(spec), + llm_settings=_lm_settings(), + 
prompt_profile=_prompt_profile(), + run_dir=tmp_path, + max_decisions=5, + ) + agent.round_index = 1 + + async def fake_tools(**kwargs: Any) -> list: + assert kwargs["spec"] is spec + return [] + + async def fake_turn(**kwargs: Any) -> ReasoningTurnResult: + assert kwargs["spec"] is spec + return ReasoningTurnResult( + final_text="done", + executed_tool_names=("science_tool", "finish_turn"), + action_tools_called=("finish_turn",), + science_tools_called=("science_tool",), + requested_finish=True, + requested_self_wake=True, + thread_id="agent-a-round-1", + ) + + monkeypatch.setattr(agent_module, "build_chemgraph_reasoning_tools", fake_tools) + monkeypatch.setattr(agent_module, "run_academy_turn", fake_turn) + + assert await agent._reasoning_round() is True + events = [ + json.loads(line)["event"] + for line in tmp_path.joinpath("events.jsonl").read_text().splitlines() + ] + assert events == [ + "round_started", + "agent_decision", + "round_finished", + "self_wake_scheduled", + ] + + +def test_build_peer_status_uses_agent_status_file(tmp_path) -> None: + state_dir = tmp_path / "agent_status" + state_dir.mkdir() + (state_dir / "agent-b.json").write_text( + json.dumps( + { + "round": 3, + "finished": False, + "last_error": None, + "status_updated_at": 100.0, + }, + ) + + "\n", + encoding="utf-8", + ) + + status = build_peer_status(run_dir=tmp_path, peer_names=("agent-b",)) + + assert status["agent-b"]["state"] == "idle" + assert status["agent-b"]["round"] == 3 + assert status["agent-b"]["last_error"] is None + + +def test_campaign_resources_resolve_to_shared_run_artifacts(tmp_path) -> None: + spec = dataclasses.replace( + _agent_spec(), + resources=("candidate_dataset", "structure_output_directory"), + ) + campaign = ChemGraphCampaign( + run_id="campaign-1", + user_task="Rank staged candidates.", + initial_agent=spec.name, + prompt_profile=Path("prompt_profiles/default.json"), + agents=(spec,), + resources={ + "candidate_dataset": ResourceSpec( + kind="json", + 
path="/source/data/candidates.json", + scope="absolute", + expose_content=True, + ), + "structure_output_directory": ResourceSpec( + kind="directory", + path="academy_mace_structures", + scope="shared_run", + ), + "mace_output_result_file": ResourceSpec( + kind="file", + path="academy_mace_outputs/mace_results.json", + scope="shared_run", + ), + }, + ) + + resolved = resolve_campaign_resources(campaign, tmp_path / "run-1") + + assert resolved.resources["candidate_dataset"].path == "/source/data/candidates.json" + assert resolved.resources["structure_output_directory"].path == str( + tmp_path / "run-1" / "shared" / "academy_mace_structures", + ) + assert resolved.resources["mace_output_result_file"].path == str( + tmp_path / "run-1" / "shared" / "academy_mace_outputs" / "mace_results.json", + ) + + # The directory resource itself is materialised on disk so tools that + # expect to write into it do not hit FileNotFoundError on first use. + assert ( + tmp_path / "run-1" / "shared" / "academy_mace_structures" + ).is_dir() + # File resources get their parent directory materialised (the file + # itself is the agent's responsibility to write). 
+ assert ( + tmp_path / "run-1" / "shared" / "academy_mace_outputs" + ).is_dir() + assert not ( + tmp_path / "run-1" / "shared" / "academy_mace_outputs" / "mace_results.json" + ).exists() + + +def test_resolve_campaign_resources_skips_non_shared_run_paths(tmp_path) -> None: + """Only shared_run resources get on-disk materialisation.""" + spec = dataclasses.replace(_agent_spec(), resources=("local_dataset",)) + campaign = ChemGraphCampaign( + run_id="campaign-2", + user_task="Static dataset.", + initial_agent=spec.name, + prompt_profile=Path("prompt_profiles/default.json"), + agents=(spec,), + resources={ + "local_dataset": ResourceSpec( + kind="json", + path="/should/not/exist/data.json", + scope="absolute", + ), + }, + ) + + resolved = resolve_campaign_resources(campaign, tmp_path / "run-1") + + # The absolute path is preserved verbatim and no directory is created. + assert resolved.resources["local_dataset"].path == "/should/not/exist/data.json" + assert not Path("/should/not/exist").exists() diff --git a/tests/test_agent_session.py b/tests/test_agent_session.py index f646c33d..5db9887e 100644 --- a/tests/test_agent_session.py +++ b/tests/test_agent_session.py @@ -14,9 +14,11 @@ import os import pytest +from types import SimpleNamespace from unittest.mock import Mock, patch -from chemgraph.agent.llm_agent import ChemGraph, serialize_state +from chemgraph.agent.llm_agent import ChemGraph +from chemgraph.agent.turn import TurnResult, serialize_state from chemgraph.memory.store import SessionStore @@ -44,16 +46,47 @@ def tmp_db(tmp_path): return str(tmp_path / "test_sessions.db") +class _GraphStreamCompatibleWorkflow: + def __init__(self): + self.side_effect = self.default_graph_stream + self.last_state = {"messages": []} + + async def default_graph_stream(self, **kwargs): + ai_msg = Mock() + ai_msg.type = "ai" + ai_msg.content = "Test response" + return TurnResult( + final_text="Test response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + 
terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) + + async def astream(self, inputs, *, stream_mode, config): + result = await self.side_effect( + query=inputs.get("messages"), + thread_id=str(config["configurable"]["thread_id"]), + ) + self.last_state = result.state + yield self.last_state + + def get_state(self, config): + return SimpleNamespace(values=self.last_state) + + @pytest.fixture def mock_agent_patches(): - """Patch LLM loading and graph construction for fast agent creation.""" + """Patch LLM loading and graph streaming for fast agent creation.""" with ( patch("chemgraph.agent.llm_agent.load_openai_model") as mock_load, - patch("chemgraph.agent.llm_agent.construct_single_agent_graph") as mock_graph, + patch("chemgraph.agent.llm_agent.construct_single_agent_graph") as mock_constructor, ): mock_load.return_value = Mock() - mock_graph.return_value = Mock() - yield mock_load, mock_graph + workflow = _GraphStreamCompatibleWorkflow() + mock_constructor.return_value = workflow + yield mock_load, workflow def _make_agent(clean_env, mock_agent_patches, tmp_db, **kwargs): @@ -62,6 +95,7 @@ def _make_agent(clean_env, mock_agent_patches, tmp_db, **kwargs): "model_name": "gpt-4o-mini", "enable_memory": True, "memory_db_path": tmp_db, + "log_dir": os.path.join(os.path.dirname(tmp_db), "logs"), } defaults.update(kwargs) agent = ChemGraph(**defaults) @@ -128,7 +162,7 @@ def test_uuid_set_when_log_dir_preset(self, mock_agent_patches, tmp_db): """uuid must be set even when CHEMGRAPH_LOG_DIR is already in env.""" os.environ["CHEMGRAPH_LOG_DIR"] = "/tmp/preset_log_dir" try: - agent = _make_agent(None, mock_agent_patches, tmp_db) + agent = _make_agent(None, mock_agent_patches, tmp_db, log_dir=None) assert agent.uuid is not None assert len(agent.uuid) == 8 assert agent.log_dir == "/tmp/preset_log_dir" @@ -350,8 +384,7 @@ def test_filename_includes_uuid( ): agent = _make_agent(clean_env, mock_agent_patches, tmp_db) - # Mock get_state to return 
something serializable - agent.workflow.get_state = Mock(return_value=Mock(values={"messages": []})) + agent._last_run_state = {"messages": []} log_dir = str(tmp_path / "test_logs") os.makedirs(log_dir, exist_ok=True) @@ -382,7 +415,7 @@ def test_no_overwrite_same_second( if "CHEMGRAPH_LOG_DIR" in os.environ: del os.environ["CHEMGRAPH_LOG_DIR"] a = _make_agent(clean_env, mock_agent_patches, tmp_db) - a.workflow.get_state = Mock(return_value=Mock(values={"messages": []})) + a._last_run_state = {"messages": []} a.log_dir = log_dir agents.append(a) @@ -402,23 +435,8 @@ def test_no_overwrite_same_second( class TestResumeFrom: def _make_streamable_agent(self, clean_env, mock_agent_patches, tmp_db): - """Create an agent with a mock async workflow.""" - agent = _make_agent(clean_env, mock_agent_patches, tmp_db) - - # Set up a mock astream that yields one state - ai_msg = Mock() - ai_msg.type = "ai" - ai_msg.content = "Test response" - ai_msg.pretty_print = Mock() - - final_state = {"messages": [ai_msg]} - - async def mock_astream(inputs, stream_mode, config): - yield final_state - - agent.workflow.astream = mock_astream - agent.workflow.get_state = Mock(return_value=Mock(values=final_state)) - return agent + """Create an agent whose run path is mocked through graph stream.""" + return _make_agent(clean_env, mock_agent_patches, tmp_db) @pytest.mark.asyncio async def test_resume_prepends_context(self, clean_env, mock_agent_patches, tmp_db): @@ -435,23 +453,24 @@ async def test_resume_prepends_context(self, clean_env, mock_agent_patches, tmp_ # Create second agent sharing the same DB agent2 = self._make_streamable_agent(clean_env, mock_agent_patches, tmp_db) - # Track what inputs are passed to astream + # Track what query is passed to graph stream. 
captured_inputs = [] - async def tracking_astream(inputs, stream_mode, config): - captured_inputs.append(inputs) + async def tracking_graph_stream(**kwargs): + captured_inputs.append({"messages": kwargs["query"]}) ai_msg = Mock() ai_msg.type = "ai" ai_msg.content = "Follow-up response" - ai_msg.pretty_print = Mock() - yield {"messages": [ai_msg]} - - agent2.workflow.astream = tracking_astream - agent2.workflow.get_state = Mock( - return_value=Mock( - values={"messages": [Mock(type="ai", content="Follow-up")]} + return TurnResult( + final_text="Follow-up response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, ) - ) + + mock_agent_patches[1].side_effect = tracking_graph_stream await agent2.run("Continue the analysis", resume_from=session_id) @@ -469,18 +488,21 @@ async def test_resume_from_nonexistent_session( captured_inputs = [] - async def tracking_astream(inputs, stream_mode, config): - captured_inputs.append(inputs) + async def tracking_graph_stream(**kwargs): + captured_inputs.append({"messages": kwargs["query"]}) ai_msg = Mock() ai_msg.type = "ai" ai_msg.content = "Response" - ai_msg.pretty_print = Mock() - yield {"messages": [ai_msg]} + return TurnResult( + final_text="Response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) - agent.workflow.astream = tracking_astream - agent.workflow.get_state = Mock( - return_value=Mock(values={"messages": [Mock(type="ai", content="resp")]}) - ) + mock_agent_patches[1].side_effect = tracking_graph_stream await agent.run("Hello", resume_from="nonexistent_id") @@ -495,21 +517,23 @@ async def test_resume_from_ignored_when_memory_disabled( ): agent = _make_agent(clean_env, mock_agent_patches, tmp_db, enable_memory=False) - ai_msg = Mock() - ai_msg.type = "ai" - ai_msg.content = "Response" - ai_msg.pretty_print = Mock() - captured_inputs = [] - async def 
tracking_astream(inputs, stream_mode, config): - captured_inputs.append(inputs) - yield {"messages": [ai_msg]} + async def tracking_graph_stream(**kwargs): + captured_inputs.append({"messages": kwargs["query"]}) + ai_msg = Mock() + ai_msg.type = "ai" + ai_msg.content = "Response" + return TurnResult( + final_text="Response", + state={"messages": [ai_msg]}, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) - agent.workflow.astream = tracking_astream - agent.workflow.get_state = Mock( - return_value=Mock(values={"messages": [ai_msg]}) - ) + mock_agent_patches[1].side_effect = tracking_graph_stream await agent.run("Hello", resume_from="some_id") @@ -528,7 +552,6 @@ async def test_full_lifecycle(self, clean_env, mock_agent_patches, tmp_db): """init -> run -> messages saved -> load_previous_context -> resume""" agent = _make_agent(clean_env, mock_agent_patches, tmp_db) - # Set up mock workflow human_msg = Mock() human_msg.type = "human" human_msg.content = "Calculate energy of H2" @@ -540,11 +563,17 @@ async def test_full_lifecycle(self, clean_env, mock_agent_patches, tmp_db): final_state = {"messages": [human_msg, ai_msg]} - async def mock_astream(inputs, stream_mode, config): - yield final_state + async def mock_graph_stream(**kwargs): + return TurnResult( + final_text=ai_msg.content, + state=final_state, + executed_tool_names=(), + terminal_tool=None, + thread_id=kwargs["thread_id"], + duration_s=0.0, + ) - agent.workflow.astream = mock_astream - agent.workflow.get_state = Mock(return_value=Mock(values=final_state)) + mock_agent_patches[1].side_effect = mock_graph_stream # Step 1: Run await agent.run("Calculate energy of H2") @@ -568,8 +597,7 @@ async def mock_astream(inputs, stream_mode, config): del os.environ["CHEMGRAPH_LOG_DIR"] agent2 = _make_agent(clean_env, mock_agent_patches, tmp_db) - agent2.workflow.astream = mock_astream - agent2.workflow.get_state = Mock(return_value=Mock(values=final_state)) + 
mock_agent_patches[1].side_effect = mock_graph_stream await agent2.run("Now optimize H2", resume_from=agent.uuid) diff --git a/tests/test_execution.py b/tests/test_execution.py new file mode 100644 index 00000000..05f3d355 --- /dev/null +++ b/tests/test_execution.py @@ -0,0 +1,1231 @@ +"""Tests for the chemgraph.execution abstraction layer. + +Tests cover: +- TaskSpec validation +- LocalBackend: python and shell tasks +- GlobusComputeBackend: python and shell tasks (mocked SDK) +- Backend factory (get_backend) +- Shared utilities: resolve_structure_files, gather_futures, write_results_jsonl +""" + +import asyncio +import json +import os +import sys +import tempfile +from concurrent.futures import Future +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from chemgraph.execution.base import TaskSpec +from chemgraph.execution.local_backend import LocalBackend +from chemgraph.execution.utils import ( + gather_futures, + make_per_structure_output, + resolve_structure_files, + write_results_jsonl, +) + +# ── TaskSpec tests ────────────────────────────────────────────────────── + + +class TestTaskSpec: + def test_python_task_minimal(self): + spec = TaskSpec(task_id="t1", task_type="python", callable=abs, args=(42,)) + assert spec.task_id == "t1" + assert spec.task_type == "python" + assert spec.callable is abs + assert spec.args == (42,) + + def test_shell_task_minimal(self): + spec = TaskSpec(task_id="t2", task_type="shell", command="echo hello") + assert spec.task_type == "shell" + assert spec.command == "echo hello" + + def test_defaults(self): + spec = TaskSpec(task_id="t3") + assert spec.task_type == "python" + assert spec.callable is None + assert spec.args == () + assert spec.kwargs == {} + assert spec.num_nodes == 1 + assert spec.processes_per_node == 1 + assert spec.gpus_per_task == 0 + + +# ── LocalBackend tests ────────────────────────────────────────────────── + + +def _square(x): + return x * x + + +def _add(a, b): + 
return a + b + + +def _failing_fn(): + raise ValueError("intentional test error") + + +class TestLocalBackend: + def test_python_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=2) + try: + task = TaskSpec( + task_id="sq", + task_type="python", + callable=_square, + args=(7,), + ) + fut = backend.submit(task) + assert isinstance(fut, Future) + assert fut.result(timeout=10) == 49 + finally: + backend.shutdown() + + def test_python_task_kwargs(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=2) + try: + task = TaskSpec( + task_id="add", + task_type="python", + callable=_add, + kwargs={"a": 3, "b": 5}, + ) + assert backend.submit(task).result(timeout=10) == 8 + finally: + backend.shutdown() + + def test_shell_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", delete=False + ) as f: + stdout_path = f.name + + task = TaskSpec( + task_id="echo", + task_type="shell", + command="echo hello_world", + stdout=stdout_path, + ) + fut = backend.submit(task) + fut.result(timeout=10) + + with open(stdout_path) as f: + assert "hello_world" in f.read() + finally: + backend.shutdown() + os.unlink(stdout_path) + + def test_submit_batch(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=4) + try: + tasks = [ + TaskSpec( + task_id=f"sq_{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(5) + ] + futures = backend.submit_batch(tasks) + assert len(futures) == 5 + results = [f.result(timeout=10) for f in futures] + assert results == [0, 1, 4, 9, 16] + finally: + backend.shutdown() + + def test_failing_task(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec( + task_id="fail", + task_type="python", + callable=_failing_fn, + ) + fut = backend.submit(task) + with pytest.raises(ValueError, 
match="intentional test error"): + fut.result(timeout=10) + finally: + backend.shutdown() + + def test_context_manager(self): + with LocalBackend() as backend: + backend.initialize(system="local", max_workers=1) + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + assert backend.submit(task).result(timeout=10) == 9 + + def test_not_initialized_raises(self): + backend = LocalBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_python_task_missing_callable(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + finally: + backend.shutdown() + + def test_shell_task_missing_command(self): + backend = LocalBackend() + backend.initialize(system="local", max_workers=1) + try: + task = TaskSpec(task_id="no_cmd", task_type="shell") + with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + finally: + backend.shutdown() + + +# ── EnsembleLauncherBackend tests ────────────────────────────────────────── + + +class TestELBackend: + @classmethod + def setup_class(cls): + # EnsembleLauncher is an optional, HPC-only dependency (not on PyPI + # for Python 3.12). Skip the whole class where it isn't installed. 
+ pytest.importorskip("ensemble_launcher") + project_root = str(Path(__file__).resolve().parent.parent) + existing = os.environ.get("PYTHONPATH", "") + os.environ["PYTHONPATH"] = ( + f"{project_root}:{existing}" if existing else project_root + ) + + def test_python_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="sq", + task_type="python", + callable=_square, + args=(7,), + ) + fut = backend.submit(task) + assert isinstance(fut, Future) + assert fut.result(timeout=10) == 49 + finally: + backend.shutdown() + + def test_python_task_kwargs(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="add", + task_type="python", + callable=_add, + kwargs={"a": 3, "b": 5}, + ) + assert backend.submit(task).result(timeout=10) == 8 + finally: + backend.shutdown() + + def test_shell_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="echo", + task_type="shell", + command="echo hello_world", + ) + fut = backend.submit(task) + result = fut.result(timeout=10) + assert result is not None + finally: + backend.shutdown() + + def test_submit_batch(self): + from 
chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + tasks = [ + TaskSpec( + task_id=f"sq_{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(5) + ] + futures = backend.submit_batch(tasks) + assert len(futures) == 5 + results = [f.result(timeout=10) for f in futures] + assert results == [0, 1, 4, 9, 16] + finally: + backend.shutdown() + + def test_failing_task(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec( + task_id="fail", + task_type="python", + callable=_failing_fn, + ) + fut = backend.submit(task) + with pytest.raises(Exception, match="intentional test error"): + fut.result(timeout=10) + finally: + backend.shutdown() + + def test_context_manager(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + with EnsembleLauncherBackend() as backend: + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + assert backend.submit(task).result(timeout=10) == 9 + + def test_not_initialized_raises(self): + from chemgraph.execution.ensemble_launcher_backend import ( + EnsembleLauncherBackend, + ) + + backend = EnsembleLauncherBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with 
pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_python_task_missing_callable(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + finally: + backend.shutdown() + + def test_shell_task_missing_command(self): + from chemgraph.execution.ensemble_launcher_backend import ( + SYSTEM_CONFIG_REGISTRY, + EnsembleLauncherBackend, + get_launcher_config, + ) + + backend = EnsembleLauncherBackend() + backend.initialize( + system="local", + system_config=SYSTEM_CONFIG_REGISTRY["local"], + launcher_config=get_launcher_config(), + ) + try: + task = TaskSpec(task_id="no_cmd", task_type="shell") + with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + finally: + backend.shutdown() + + +# ── GlobusComputeBackend tests ────────────────────────────────────────── + + +def _make_mock_gc_modules(): + """Create mock globus_compute_sdk module and its classes.""" + mock_sdk = MagicMock() + + # Mock Executor: instances track submit calls and return Futures + mock_executor_instance = MagicMock() + mock_future = Future() + mock_future.set_result(42) + mock_executor_instance.submit.return_value = mock_future + mock_sdk.Executor.return_value = mock_executor_instance + + # Mock ShellFunction + mock_shell_fn_instance = MagicMock() + mock_sdk.ShellFunction.return_value = mock_shell_fn_instance + + return mock_sdk, mock_executor_instance + + +class TestGlobusComputeBackend: + def _patch_and_import(self, mock_sdk): + """Patch globus_compute_sdk into sys.modules and import the backend.""" + with patch.dict(sys.modules, 
{"globus_compute_sdk": mock_sdk}): + # Force re-import to pick up the mock + import importlib + + import chemgraph.execution.globus_compute_backend as gc_mod + + importlib.reload(gc_mod) + return gc_mod.GlobusComputeBackend + + def test_initialize_success(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(system="polaris", endpoint_id="test-uuid-1234") + + assert backend._initialized is True + mock_sdk.Executor.assert_called_once_with(endpoint_id="test-uuid-1234") + + def test_initialize_with_amqp_port(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize( + system="polaris", + endpoint_id="test-uuid", + amqp_port=443, + ) + + mock_sdk.Executor.assert_called_once_with( + endpoint_id="test-uuid", amqp_port=443 + ) + + def test_initialize_missing_endpoint_id(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ValueError, match="endpoint_id"): + backend.initialize(system="polaris") + + def test_initialize_empty_endpoint_id(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ValueError, match="endpoint_id"): + backend.initialize(system="polaris", endpoint_id="") + + def test_initialize_import_error(self): + """Verify helpful error when 
globus-compute-sdk is not installed.""" + with patch.dict(sys.modules, {"globus_compute_sdk": None}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + with pytest.raises(ImportError, match="globus-compute-sdk"): + backend.initialize(endpoint_id="test-uuid") + + def test_submit_python_task(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="py1", + task_type="python", + callable=_square, + args=(7,), + ) + fut = backend.submit(task) + + assert isinstance(fut, Future) + mock_executor.submit.assert_called_once_with(_square, 7) + + def test_submit_python_task_with_kwargs(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="py2", + task_type="python", + callable=_add, + args=(3,), + kwargs={"b": 5}, + ) + backend.submit(task) + + mock_executor.submit.assert_called_once_with(_add, 3, b=5) + + def test_submit_shell_task(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="sh1", + task_type="shell", + command="echo hello", + ) + backend.submit(task) + + # ShellFunction should be constructed with the command + mock_sdk.ShellFunction.assert_called_once_with("echo hello") + # And then 
submitted via the executor + shell_fn_instance = mock_sdk.ShellFunction.return_value + mock_executor.submit.assert_called_once_with(shell_fn_instance) + + def test_submit_not_initialized(self): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + task = TaskSpec(task_id="x", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_submit_python_missing_callable(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="no_fn", task_type="python") + with pytest.raises(ValueError, match="requires a callable"): + backend.submit(task) + + def test_submit_shell_missing_command(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="no_cmd", task_type="shell") + with pytest.raises(ValueError, match="requires a command"): + backend.submit(task) + + def test_shutdown(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + assert backend._initialized is True + + backend.shutdown() + + assert backend._initialized is False + assert backend._executor is None + mock_executor.shutdown.assert_called_once() + + def test_shutdown_idempotent(self): + """Calling shutdown() when not initialized should not 
raise.""" + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.shutdown() # should be a no-op + assert backend._initialized is False + + def test_context_manager(self): + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + with GlobusComputeBackend() as backend: + backend.initialize(endpoint_id="test-uuid") + task = TaskSpec( + task_id="ctx", + task_type="python", + callable=_square, + args=(3,), + ) + backend.submit(task) + + # After exiting context, shutdown should have been called + mock_executor.shutdown.assert_called_once() + + +class TestGetBackendGlobusCompute: + def test_factory_creates_globus_compute_backend(self): + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend( + backend_name="globus_compute", + endpoint_id="factory-test-uuid", + ) + try: + assert isinstance(backend, GlobusComputeBackend) + assert backend._initialized is True + finally: + backend.shutdown() + + def test_factory_via_env_var(self): + mock_sdk, _ = _make_mock_gc_modules() + with ( + patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}), + patch.dict( + os.environ, + {"CHEMGRAPH_EXECUTION_BACKEND": "globus_compute"}, + ), + ): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend(endpoint_id="env-test-uuid") + try: + assert isinstance(backend, GlobusComputeBackend) + finally: + backend.shutdown() + + +# ── Factory tests ─────────────────────────────────────────────────────── + + +class TestGetBackend: + def 
test_local_backend_via_env(self): + with patch.dict(os.environ, {"CHEMGRAPH_EXECUTION_BACKEND": "local"}): + from chemgraph.execution.config import get_backend + + backend = get_backend() + try: + assert isinstance(backend, LocalBackend) + finally: + backend.shutdown() + + def test_explicit_backend_name(self): + from chemgraph.execution.config import get_backend + + backend = get_backend(backend_name="local", max_workers=2) + try: + assert isinstance(backend, LocalBackend) + finally: + backend.shutdown() + + def test_unsupported_backend_raises(self): + from chemgraph.execution.config import get_backend + + with pytest.raises(ValueError, match="Unknown execution backend"): + get_backend(backend_name="nonexistent") + + +# ── Utility tests ─────────────────────────────────────────────────────── + + +class TestResolveStructureFiles: + def test_from_directory(self, tmp_path): + for name in ["a.cif", "b.cif", "c.txt"]: + (tmp_path / name).write_text("dummy") + + files, out_dir = resolve_structure_files(str(tmp_path), extensions={".cif"}) + assert len(files) == 2 + assert out_dir == tmp_path + assert all(f.suffix == ".cif" for f in files) + + def test_from_file_list(self, tmp_path): + paths = [] + for name in ["x.xyz", "y.xyz"]: + p = tmp_path / name + p.write_text("dummy") + paths.append(str(p)) + + files, out_dir = resolve_structure_files(paths) + assert len(files) == 2 + assert out_dir == tmp_path + + def test_missing_file_raises(self, tmp_path): + with pytest.raises(ValueError, match="missing"): + resolve_structure_files([str(tmp_path / "noexist.cif")]) + + def test_empty_dir_raises(self, tmp_path): + with pytest.raises(ValueError, match="No structure files"): + resolve_structure_files(str(tmp_path), extensions={".cif"}) + + def test_invalid_dir_raises(self): + with pytest.raises(ValueError, match="not a valid directory"): + resolve_structure_files("/nonexistent/path") + + +class TestMakePerStructureOutput: + def test_basic(self): + result = make_per_structure_output( 
+ Path("/data/MOF-5.cif"), + Path("/results/output.json"), + ) + assert result == Path("/results/MOF-5_output.json") + + def test_no_suffix(self): + result = make_per_structure_output( + Path("/data/struct.xyz"), + Path("/results/result"), + ) + assert result == Path("/results/struct_result.json") + + +class TestGatherFutures: + @pytest.mark.asyncio + async def test_successful_futures(self): + loop = asyncio.get_event_loop() + + def _make_resolved(val): + f = Future() + f.set_result(val) + return f + + pending = [ + ({"name": "a"}, _make_resolved({"status": "success", "energy": -1.0})), + ({"name": "b"}, _make_resolved({"status": "success", "energy": -2.0})), + ] + results = await gather_futures(pending) + assert len(results) == 2 + assert results[0]["name"] == "a" + assert results[0]["energy"] == -1.0 + + @pytest.mark.asyncio + async def test_failed_future(self): + f = Future() + f.set_exception(RuntimeError("boom")) + + pending = [({"name": "fail"}, f)] + results = await gather_futures(pending) + assert results[0]["status"] == "failure" + assert results[0]["error_type"] == "RuntimeError" + assert "boom" in results[0]["message"] + + @pytest.mark.asyncio + async def test_with_post_fn(self): + f = Future() + f.set_result(42) + + def post(meta, result): + return {**meta, "doubled": result * 2, "status": "success"} + + results = await gather_futures([({"id": "x"}, f)], post_fn=post) + assert results[0]["doubled"] == 84 + + +class TestWriteResultsJsonl: + def test_write_and_count(self, tmp_path): + results = [ + {"status": "success", "value": 1}, + {"status": "failure", "error": "bad"}, + {"status": "success", "value": 2}, + ] + path = tmp_path / "results.jsonl" + success, total = write_results_jsonl(results, path) + assert success == 2 + assert total == 3 + + lines = path.read_text().strip().split("\n") + assert len(lines) == 3 + assert json.loads(lines[0])["value"] == 1 + + def test_append_mode(self, tmp_path): + path = tmp_path / "results.jsonl" + 
write_results_jsonl([{"status": "success"}], path) + write_results_jsonl([{"status": "success"}], path, append=True) + + lines = path.read_text().strip().split("\n") + assert len(lines) == 2 + + +# ── Layer 2: GlobusComputeBackend unit-test gap coverage ──────────────── + + +class TestGlobusComputeBackendGaps: + """Additional mocked tests covering gaps in the original test suite.""" + + def test_submit_unsupported_task_type(self): + """The else branch in submit() should raise for unknown task_type.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="bad_type", + task_type="python", + callable=_square, + args=(1,), + ) + # Bypass Pydantic validation to force an invalid task_type + object.__setattr__(task, "task_type", "mpi") + + with pytest.raises(ValueError, match="unsupported task_type"): + backend.submit(task) + + def test_submit_batch_delegates(self): + """submit_batch (inherited from base) should call submit() N times.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + tasks = [ + TaskSpec( + task_id=f"t{i}", + task_type="python", + callable=_square, + args=(i,), + ) + for i in range(3) + ] + futures = backend.submit_batch(tasks) + + assert len(futures) == 3 + assert mock_executor.submit.call_count == 3 + + def test_amqp_port_string_coercion(self): + """amqp_port from config.toml arrives as a string; must be coerced to int.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from 
chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid", amqp_port="443") + + mock_sdk.Executor.assert_called_once_with( + endpoint_id="test-uuid", amqp_port=443 + ) + + def test_shutdown_executor_raises(self): + """If executor.shutdown() raises, the error is swallowed and state resets.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + mock_executor.shutdown.side_effect = RuntimeError("connection lost") + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + # Should NOT raise + backend.shutdown() + + assert backend._initialized is False + assert backend._executor is None + + +class TestGetBackendGlobusComputeGaps: + """Additional factory tests for config merging and TOML-driven creation.""" + + def test_factory_kwargs_override_config(self, tmp_path): + """Explicit kwargs should override values from config.toml.""" + config_file = tmp_path / "config.toml" + config_file.write_text( + "[execution]\n" + 'backend = "globus_compute"\n\n' + "[execution.globus_compute]\n" + 'endpoint_id = "config-uuid"\n' + "amqp_port = 5671\n" + ) + + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend( + config_path=str(config_file), + endpoint_id="kwarg-uuid", + ) + try: + assert isinstance(backend, GlobusComputeBackend) + # kwarg-uuid should win over config-uuid; amqp_port from config + mock_sdk.Executor.assert_called_once_with( + endpoint_id="kwarg-uuid", + amqp_port=5671, + ) + finally: + backend.shutdown() + + def test_factory_config_toml_driven(self, tmp_path): + 
"""get_backend() with only a config.toml path should work end-to-end.""" + config_file = tmp_path / "config.toml" + config_file.write_text( + "[execution]\n" + 'backend = "globus_compute"\n\n' + "[execution.globus_compute]\n" + 'endpoint_id = "toml-uuid"\n' + ) + + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.config import get_backend + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = get_backend(config_path=str(config_file)) + try: + assert isinstance(backend, GlobusComputeBackend) + assert backend._initialized is True + mock_sdk.Executor.assert_called_once_with(endpoint_id="toml-uuid") + finally: + backend.shutdown() + + +# ── Layer 3: Globus Compute integration tests (real endpoint) ─────────── + + +@pytest.fixture +def globus_backend(): + """Provide an initialized GlobusComputeBackend connected to a real endpoint. + + Skips the test if GLOBUS_COMPUTE_ENDPOINT_ID is not set or the SDK is + not installed. + """ + endpoint_id = os.environ.get("GLOBUS_COMPUTE_ENDPOINT_ID") + if not endpoint_id: + pytest.skip("GLOBUS_COMPUTE_ENDPOINT_ID env var not set") + + try: + from chemgraph.execution.config import get_backend + except ImportError: + pytest.skip("chemgraph.execution not available") + + try: + backend = get_backend(backend_name="globus_compute", endpoint_id=endpoint_id) + except ImportError: + pytest.skip("globus-compute-sdk not installed") + + yield backend + backend.shutdown() + + +def _gc_double(x): + """Trivial function for Globus Compute integration tests.""" + return x * 2 + + +def _gc_square(x): + """Square function for Globus Compute integration tests.""" + return x * x + + +def _gc_identity(x): + """Identity function for Globus Compute integration tests.""" + return x + + +@pytest.mark.globus_compute +class TestGlobusComputeIntegration: + """Integration tests that submit work to a real Globus Compute endpoint. 
+ + These are skipped by default. Run with:: + + GLOBUS_COMPUTE_ENDPOINT_ID= pytest --run-globus-compute -k Integration + """ + + def test_python_task_roundtrip(self, globus_backend): + """Submit a trivial Python callable and verify the result.""" + task = TaskSpec( + task_id="roundtrip", + task_type="python", + callable=_gc_double, + args=(21,), + ) + fut = globus_backend.submit(task) + result = fut.result(timeout=120) + assert result == 42 + + def test_shell_task_roundtrip(self, globus_backend): + """Submit a shell command and verify the output.""" + task = TaskSpec( + task_id="shell_rt", + task_type="shell", + command="echo hello_globus", + ) + fut = globus_backend.submit(task) + result = fut.result(timeout=120) + # ShellFunction returns a ShellResult; stdout should contain the string + assert "hello_globus" in str(result) + + def test_batch_submission(self, globus_backend): + """Submit a batch of tasks and verify all results.""" + tasks = [ + TaskSpec( + task_id=f"batch_{i}", + task_type="python", + callable=_gc_square, + args=(i,), + ) + for i in range(5) + ] + futures = globus_backend.submit_batch(tasks) + results = [f.result(timeout=120) for f in futures] + assert results == [0, 1, 4, 9, 16] + + @pytest.mark.asyncio + async def test_gather_futures_with_real_endpoint(self, globus_backend): + """Verify gather_futures works with real ComputeFuture objects.""" + tasks = [ + TaskSpec( + task_id=f"gf_{i}", + task_type="python", + callable=_gc_identity, + args=(i,), + ) + for i in range(3) + ] + futs = globus_backend.submit_batch(tasks) + pending = [({"index": i}, f) for i, f in enumerate(futs)] + + results = await gather_futures(pending) + assert len(results) == 3 + assert all("index" in r for r in results) + + +# ── Layer 4: Edge-case and error-handling tests ───────────────────────── + + +class TestGlobusComputeEdgeCases: + """Mocked tests for error paths and edge conditions.""" + + def test_submit_after_shutdown(self): + """Submitting after shutdown() should 
raise RuntimeError.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + backend.shutdown() + + task = TaskSpec(task_id="late", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="not initialized"): + backend.submit(task) + + def test_double_initialize(self): + """Calling initialize() twice should succeed and create a new executor.""" + mock_sdk, _ = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="uuid-1") + backend.initialize(endpoint_id="uuid-2") + + assert backend._initialized is True + assert mock_sdk.Executor.call_count == 2 + backend.shutdown() + + def test_context_manager_with_exception(self): + """shutdown() must be called even when the body raises.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + with pytest.raises(ValueError, match="intentional"): + with GlobusComputeBackend() as backend: + backend.initialize(endpoint_id="test-uuid") + raise ValueError("intentional") + + mock_executor.shutdown.assert_called_once() + + def test_executor_submit_raises_propagates(self): + """Errors from executor.submit() should propagate to the caller.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + mock_executor.submit.side_effect = RuntimeError("endpoint unavailable") + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + 
backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="err", callable=_square, args=(1,)) + with pytest.raises(RuntimeError, match="endpoint unavailable"): + backend.submit(task) + + def test_submit_with_resource_hints(self): + """Resource hints are advisory and should not break submission.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec( + task_id="hints", + task_type="python", + callable=_square, + args=(5,), + num_nodes=4, + processes_per_node=32, + gpus_per_task=4, + ) + fut = backend.submit(task) + assert isinstance(fut, Future) + # Resource hints should NOT be passed to executor.submit + mock_executor.submit.assert_called_once_with(_square, 5) + + def test_failed_future_result(self): + """A future that resolves to an exception should be retrievable.""" + mock_sdk, mock_executor = _make_mock_gc_modules() + failed_future = Future() + failed_future.set_exception(RuntimeError("task exploded")) + mock_executor.submit.return_value = failed_future + + with patch.dict(sys.modules, {"globus_compute_sdk": mock_sdk}): + from chemgraph.execution.globus_compute_backend import ( + GlobusComputeBackend, + ) + + backend = GlobusComputeBackend() + backend.initialize(endpoint_id="test-uuid") + + task = TaskSpec(task_id="fail", callable=_square, args=(1,)) + fut = backend.submit(task) + + with pytest.raises(RuntimeError, match="task exploded"): + fut.result(timeout=5) diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 86426c23..f3a8fca6 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -1,56 +1,184 @@ +from types import SimpleNamespace + import pytest +from langchain_core.messages import AIMessage + +from chemgraph.agent import llm_agent from chemgraph.agent.llm_agent import 
ChemGraph -WORKFLOWS = [ - "single_agent", "multi_agent", "python_relp", "graspa", - "mock_agent", "single_agent_mcp", "graspa_mcp", -] -@pytest.mark.parametrize("workflow_type", WORKFLOWS) -def test_constructor_is_called(monkeypatch, workflow_type): - called_data = {} +class _DummyTool: + def __init__(self, name): + self.name = name + + +class _FakeWorkflow: + def __init__(self): + self.astream_calls = [] + self.last_state = {"messages": [AIMessage(content="done")]} + + async def astream(self, inputs, *, stream_mode, config): + self.astream_calls.append( + {"inputs": inputs, "stream_mode": stream_mode, "config": config}, + ) + for callback in config.get("callbacks", []): + callback.on_chat_model_start({"name": "FakeChatModel"}, [["hello"]]) + callback.on_llm_end(SimpleNamespace(generations=[])) + yield self.last_state + + def get_state(self, config): + return SimpleNamespace(values=self.last_state) + + +@pytest.mark.parametrize( + ("workflow_type", "constructor_attr", "kwargs"), + [ + ("single_agent", "construct_single_agent_graph", {}), + ("multi_agent", "construct_multi_agent_graph", {}), + ("python_relp", "construct_relp_graph", {}), + ("graspa", "construct_graspa_graph", {}), + ("mock_agent", "construct_mock_agent_graph", {}), + ( + "single_agent_mcp", + "construct_single_agent_mcp_graph", + {"tools": [_DummyTool("mcp_tool")]}, + ), + ( + "graspa_mcp", + "construct_graspa_mcp_graph", + {"tools": [_DummyTool("executor")], "data_tools": [_DummyTool("analysis")]}, + ), + ("rag_agent", "construct_rag_agent_graph", {}), + ("single_agent_xanes", "construct_single_agent_xanes_graph", {}), + ], +) +def test_graph_constructor_is_called( + monkeypatch, + tmp_path, + workflow_type, + constructor_attr, + kwargs, +): + called = {} + workflow = _FakeWorkflow() + + def fake_constructor(*args, **constructor_kwargs): + called["args"] = args + called["kwargs"] = constructor_kwargs + return workflow - def fake_constructor(*args, **kwargs): - called_data["args"] = args - 
called_data["kwargs"] = kwargs - return f"WORKFLOW-SENTINEL-{workflow_type}" - - mapping = { - "single_agent": "construct_single_agent_graph", - "multi_agent": "construct_multi_agent_graph", - "python_relp": "construct_relp_graph", - "graspa": "construct_graspa_graph", - "mock_agent": "construct_mock_agent_graph", - "single_agent_mcp": "construct_single_agent_mcp_graph", - "graspa_mcp": "construct_graspa_mcp_graph", - } - - constructor_attr = mapping[workflow_type] - - # Patch the graph constructor monkeypatch.setattr(f"chemgraph.agent.llm_agent.{constructor_attr}", fake_constructor) monkeypatch.setattr( "chemgraph.agent.llm_agent.load_openai_model", - lambda **kwargs: "FAKE_LLM", + lambda **_kwargs: "FAKE_LLM", + ) + + cg = ChemGraph( + model_name="gpt-4o-mini", + workflow_type=workflow_type, + enable_memory=False, + log_dir=str(tmp_path / "logs"), + **kwargs, + ) + + assert cg.workflow is workflow + args = called.get("args", ()) + constructor_kwargs = called.get("kwargs", {}) + assert (args and args[0] == "FAKE_LLM") or constructor_kwargs.get("llm") == "FAKE_LLM" + + +@pytest.mark.asyncio +async def test_graph_backed_run_uses_astream_and_emits_events(monkeypatch, tmp_path): + workflow = _FakeWorkflow() + events = [] + + monkeypatch.setattr( + "chemgraph.agent.llm_agent.construct_single_agent_graph", + lambda *_args, **_kwargs: workflow, + ) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda **_kwargs: "FAKE_LLM", + ) + + cg = ChemGraph( + model_name="gpt-4o-mini", + workflow_type="single_agent", + enable_memory=False, + log_dir=str(tmp_path / "logs"), + return_option="last_message", + on_event=lambda event, payload: events.append((event, payload)), ) + response = await cg.run("hello", config={"thread_id": "test-thread"}) - # Set up inputs - test_tools = ["DUMMY_TOOL"] - kwargs = {"tools": test_tools, "data_tools": test_tools} if "_mcp" in workflow_type else {} + assert response.content == "done" + assert 
workflow.astream_calls[0]["inputs"] == {"messages": "hello"} + assert workflow.astream_calls[0]["stream_mode"] == "values" + assert workflow.astream_calls[0]["config"]["configurable"]["thread_id"] == "test-thread" + assert [event for event, _payload in events] == [ + "workflow_started", + "llm_call_started", + "llm_call_finished", + "workflow_finished", + ] - # Initialize - cg = ChemGraph(model_name="gpt-4o-mini", workflow_type=workflow_type, **kwargs) - # Assertions - assert cg.workflow == f"WORKFLOW-SENTINEL-{workflow_type}" - - # Check if LLM was passed as the first positional arg or a keyword arg - args = called_data.get("args", []) - kwargs_called = called_data.get("kwargs", {}) - - llm_passed = (len(args) > 0 and args[0] == "FAKE_LLM") or (kwargs_called.get("llm") == "FAKE_LLM") - assert llm_passed, f"LLM not passed to {workflow_type} constructor" +def test_single_agent_initialization_injects_calculator_availability(monkeypatch, tmp_path): + called = {} + + def fake_constructor(*args, **kwargs): + called["args"] = (args, kwargs) + return _FakeWorkflow() + + monkeypatch.setattr( + "chemgraph.agent.llm_agent.construct_single_agent_graph", + fake_constructor, + ) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda **_kwargs: "FAKE_LLM", + ) + + cg = ChemGraph( + model_name="gpt-4o-mini", + workflow_type="single_agent", + enable_memory=False, + log_dir=str(tmp_path / "logs"), + ) + + args_tuple, _ = called["args"] + system_prompt = args_tuple[1] + assert "Calculator availability detected during ChemGraph initialization" in system_prompt + assert cg.default_calculator in system_prompt + assert cg.default_calculator in cg.available_calculators + + +def test_rag_and_xanes_default_prompts_are_preserved(monkeypatch, tmp_path): + captured = {} + + def fake_constructor(*args, **kwargs): + captured[kwargs.get("system_prompt")] = True + return _FakeWorkflow() + + monkeypatch.setattr("chemgraph.agent.llm_agent.construct_rag_agent_graph", 
fake_constructor) + monkeypatch.setattr("chemgraph.agent.llm_agent.construct_single_agent_xanes_graph", fake_constructor) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda **_kwargs: "FAKE_LLM", + ) + + ChemGraph( + model_name="gpt-4o-mini", + workflow_type="rag_agent", + enable_memory=False, + log_dir=str(tmp_path / "rag-logs"), + ) + ChemGraph( + model_name="gpt-4o-mini", + workflow_type="single_agent_xanes", + enable_memory=False, + log_dir=str(tmp_path / "xanes-logs"), + ) - # Specific check for MCP tool passing - if workflow_type == "graspa_mcp": - assert kwargs_called.get("executor_tools") == test_tools \ No newline at end of file + assert llm_agent.rag_agent_prompt in captured + assert llm_agent.default_xanes_single_agent_prompt in captured diff --git a/tests/test_job_tracker.py b/tests/test_job_tracker.py new file mode 100644 index 00000000..cee3d081 --- /dev/null +++ b/tests/test_job_tracker.py @@ -0,0 +1,394 @@ +"""Tests for the JobTracker and submit_or_gather utilities.""" + +import asyncio +from concurrent.futures import Future +from unittest.mock import MagicMock + +import pytest + +from chemgraph.execution.job_tracker import JobTracker +from chemgraph.execution.utils import gather_futures, submit_or_gather + + +# ── Helpers ──────────────────────────────────────────────────────────── + + +def _make_done_future(result): + """Create a Future that is already resolved with *result*.""" + fut = Future() + fut.set_result(result) + return fut + + +def _make_failed_future(exc): + """Create a Future that is already resolved with an exception.""" + fut = Future() + fut.set_exception(exc) + return fut + + +def _make_pending_future(): + """Create a Future that is not yet resolved.""" + return Future() + + +# ── JobTracker.register_batch ────────────────────────────────────────── + + +class TestRegisterBatch: + def test_returns_batch_id(self): + tracker = JobTracker() + fut = _make_pending_future() + batch_id = 
tracker.register_batch( + "test_tool", [({"key": "val"}, fut)] + ) + assert isinstance(batch_id, str) + assert len(batch_id) == 12 + + def test_stores_tasks(self): + tracker = JobTracker() + futs = [_make_pending_future() for _ in range(3)] + pending = [({"idx": i}, f) for i, f in enumerate(futs)] + batch_id = tracker.register_batch("test_tool", pending) + + status = tracker.get_status(batch_id) + assert status["total_tasks"] == 3 + + def test_multiple_batches_unique_ids(self): + tracker = JobTracker() + ids = set() + for _ in range(10): + bid = tracker.register_batch( + "tool", [({"x": 1}, _make_pending_future())] + ) + ids.add(bid) + assert len(ids) == 10 + + +# ── JobTracker.get_status ────────────────────────────────────────────── + + +class TestGetStatus: + def test_all_pending(self): + tracker = JobTracker() + pending = [({"i": i}, _make_pending_future()) for i in range(3)] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "pending" + assert status["total_tasks"] == 3 + assert status["completed_tasks"] == 0 + assert status["pending_tasks"] == 3 + assert status["progress_pct"] == 0.0 + + def test_all_completed(self): + tracker = JobTracker() + pending = [ + ({"i": i}, _make_done_future({"val": i})) for i in range(3) + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + assert status["completed_tasks"] == 3 + assert status["failed_tasks"] == 0 + assert status["pending_tasks"] == 0 + assert status["progress_pct"] == 100.0 + + def test_partial_done(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 0})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "running" + assert status["completed_tasks"] == 1 + assert status["pending_tasks"] == 1 + assert 
status["progress_pct"] == 50.0 + + def test_all_failed(self): + tracker = JobTracker() + pending = [ + ({"i": i}, _make_failed_future(ValueError(f"err_{i}"))) + for i in range(2) + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "failed" + assert status["failed_tasks"] == 2 + + def test_mixed_success_and_failure(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 0})), + ({"i": 1}, _make_failed_future(RuntimeError("boom"))), + ] + batch_id = tracker.register_batch("tool", pending) + + status = tracker.get_status(batch_id) + assert status["status"] == "partial" + assert status["completed_tasks"] == 1 + assert status["failed_tasks"] == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + status = tracker.get_status("nonexistent") + assert "error" in status + + def test_with_post_fn(self): + def post_fn(meta, result): + return {"custom": True, "status": "success", **meta} + + tracker = JobTracker() + pending = [({"i": 0}, _make_done_future({"raw": 1}))] + batch_id = tracker.register_batch("tool", pending, post_fn=post_fn) + + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + + +# ── JobTracker.get_results ───────────────────────────────────────────── + + +class TestGetResults: + def test_returns_results_when_complete(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_done_future({"val": 20})), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id) + assert "results" in result + assert len(result["results"]) == 2 + + def test_blocks_when_pending_and_partial_false(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id, include_partial=False) + 
assert "results" not in result + assert "message" in result + assert "still pending" in result["message"] + + def test_returns_partial_when_requested(self): + tracker = JobTracker() + pending = [ + ({"i": 0}, _make_done_future({"val": 10})), + ({"i": 1}, _make_pending_future()), + ] + batch_id = tracker.register_batch("tool", pending) + + result = tracker.get_results(batch_id, include_partial=True) + assert "results" in result + assert len(result["results"]) == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + result = tracker.get_results("nonexistent") + assert "error" in result + + +# ── JobTracker.list_batches ──────────────────────────────────────────── + + +class TestListBatches: + def test_empty(self): + tracker = JobTracker() + assert tracker.list_batches() == [] + + def test_multiple_batches(self): + tracker = JobTracker() + tracker.register_batch("tool_a", [({"x": 1}, _make_pending_future())]) + tracker.register_batch("tool_b", [({"x": 2}, _make_done_future(42))]) + + batches = tracker.list_batches() + assert len(batches) == 2 + tool_names = {b["tool_name"] for b in batches} + assert tool_names == {"tool_a", "tool_b"} + + +# ── JobTracker.cancel_batch ──────────────────────────────────────────── + + +class TestCancelBatch: + def test_cancel_pending(self): + tracker = JobTracker() + fut = _make_pending_future() + batch_id = tracker.register_batch("tool", [({"i": 0}, fut)]) + + result = tracker.cancel_batch(batch_id) + # Future.cancel() may or may not succeed depending on state, + # but the call should not raise + assert "batch_id" in result + + def test_cancel_already_done(self): + tracker = JobTracker() + fut = _make_done_future({"val": 1}) + batch_id = tracker.register_batch("tool", [({"i": 0}, fut)]) + + result = tracker.cancel_batch(batch_id) + assert result["already_done"] == 1 + + def test_unknown_batch_id(self): + tracker = JobTracker() + result = tracker.cancel_batch("nonexistent") + assert "error" in result + + +# ── 
JobTracker.cleanup ───────────────────────────────────────────────── + + +class TestCleanup: + def test_removes_old_completed(self): + tracker = JobTracker() + batch_id = tracker.register_batch( + "tool", [({"i": 0}, _make_done_future(1))] + ) + + # Force the submitted_at to be old + batch = tracker._batches[batch_id] + from datetime import timedelta + + batch.submitted_at -= timedelta(hours=25) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 1 + assert tracker.list_batches() == [] + + def test_keeps_recent(self): + tracker = JobTracker() + tracker.register_batch("tool", [({"i": 0}, _make_done_future(1))]) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 0 + assert len(tracker.list_batches()) == 1 + + def test_keeps_pending(self): + tracker = JobTracker() + batch_id = tracker.register_batch( + "tool", [({"i": 0}, _make_pending_future())] + ) + + batch = tracker._batches[batch_id] + from datetime import timedelta + + batch.submitted_at -= timedelta(hours=25) + + removed = tracker.cleanup(max_age_hours=24) + assert removed == 0 + + +# ── gather_futures with timeout ──────────────────────────────────────── + + +class TestGatherFuturesTimeout: + def test_completes_within_timeout(self): + pending = [ + ({"i": 0}, _make_done_future({"val": 1})), + ({"i": 1}, _make_done_future({"val": 2})), + ] + results = asyncio.get_event_loop().run_until_complete( + gather_futures(pending, timeout=5.0) + ) + assert len(results) == 2 + + def test_timeout_raises(self): + pending = [({"i": 0}, _make_pending_future())] + with pytest.raises(asyncio.TimeoutError): + asyncio.get_event_loop().run_until_complete( + gather_futures(pending, timeout=0.1) + ) + + def test_no_timeout_default(self): + pending = [({"i": 0}, _make_done_future(42))] + results = asyncio.get_event_loop().run_until_complete( + gather_futures(pending) + ) + assert len(results) == 1 + + +# ── submit_or_gather ─────────────────────────────────────────────────── + + +class 
TestSubmitOrGather: + def test_sync_backend_returns_completed(self): + backend = MagicMock() + backend.is_async_remote = False + + tracker = JobTracker() + pending = [({"i": 0}, _make_done_future({"val": 10}))] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, pending, tracker, "test_tool") + ) + assert result["status"] == "completed" + assert "results" in result + assert len(result["results"]) == 1 + + def test_async_backend_returns_submitted(self): + backend = MagicMock() + backend.is_async_remote = True + + tracker = JobTracker() + pending = [({"i": 0}, _make_pending_future())] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, pending, tracker, "test_tool") + ) + assert result["status"] == "submitted" + assert "batch_id" in result + assert result["n_tasks"] == 1 + assert "check_job_status" in result["message"] + + def test_async_backend_batch_trackable(self): + backend = MagicMock() + backend.is_async_remote = True + + tracker = JobTracker() + fut = _make_done_future({"val": 99}) + pending = [({"i": 0}, fut)] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather(backend, pending, tracker, "test_tool") + ) + batch_id = result["batch_id"] + + # Verify the batch is tracked and status works + status = tracker.get_status(batch_id) + assert status["status"] == "completed" + + # Verify results can be retrieved + results = tracker.get_results(batch_id) + assert "results" in results + assert len(results["results"]) == 1 + + def test_async_backend_with_post_fn(self): + backend = MagicMock() + backend.is_async_remote = True + + def post_fn(meta, result): + return {"processed": True, "status": "success"} + + tracker = JobTracker() + fut = _make_done_future({"raw": 1}) + pending = [({"i": 0}, fut)] + + result = asyncio.get_event_loop().run_until_complete( + submit_or_gather( + backend, pending, tracker, "test_tool", post_fn=post_fn, + ) + ) + batch_id = result["batch_id"] + + 
results = tracker.get_results(batch_id) + assert results["results"][0]["processed"] is True diff --git a/tests/test_llm_agent.py b/tests/test_llm_agent.py index 8d46339d..a27eca6c 100644 --- a/tests/test_llm_agent.py +++ b/tests/test_llm_agent.py @@ -1,8 +1,12 @@ -import pytest import asyncio -from chemgraph.agent.llm_agent import ChemGraph +import json +from types import SimpleNamespace from unittest.mock import Mock, patch + +import pytest from langchain_core.messages import AIMessage +from chemgraph.agent.llm_agent import ChemGraph +from chemgraph.agent.turn import _TurnEventCallback @pytest.fixture @@ -10,23 +14,170 @@ def mock_llm(): return Mock() -def test_chemgraph_initialization(): +def test_chemgraph_initialization(tmp_path): with patch("chemgraph.agent.llm_agent.load_openai_model") as mock_load: mock_load.return_value = Mock() - agent = ChemGraph(model_name="gpt-4o-mini") + agent = ChemGraph( + model_name="gpt-4o-mini", + enable_memory=False, + log_dir=str(tmp_path / "logs"), + ) assert hasattr(agent, "workflow") -def test_agent_query(mock_llm): - with patch("chemgraph.agent.llm_agent.load_openai_model") as mock_load: +def test_agent_query(mock_llm, tmp_path): + with patch("chemgraph.agent.llm_agent.load_openai_model") as mock_init_load, patch( + "chemgraph.models.loader.load_openai_model" + ) as mock_turn_load: # Set up the mock chain mock_chain = Mock() mock_chain.invoke.return_value = AIMessage(content="Test response") mock_llm.bind_tools.return_value = mock_chain - mock_load.return_value = mock_llm + mock_init_load.return_value = mock_llm + mock_turn_load.return_value = mock_llm - agent = ChemGraph(model_name="gpt-4o-mini") + agent = ChemGraph( + model_name="gpt-4o-mini", + enable_memory=False, + log_dir=str(tmp_path / "logs"), + ) response = asyncio.run(agent.run("What is the SMILES string for water?")) assert isinstance(response, AIMessage) assert response.content == "Test response" mock_llm.bind_tools.assert_called_once() 
mock_chain.invoke.assert_called_once() + + +def test_turn_event_callback_emits_llm_decision_for_tool_calls(): + events = [] + callback = _TurnEventCallback( + lambda event, payload: events.append((event, payload)), + "thread-1", + ) + response = SimpleNamespace( + llm_output={"token_usage": {"total_tokens": 12}}, + generations=[ + [ + SimpleNamespace( + message=SimpleNamespace( + tool_calls=[ + {"name": "molecule_name_to_smiles", "id": "call-1"}, + { + "function": {"name": "smiles_to_coordinate_file"}, + "tool_call_id": "call-2", + }, + ], + ), + ), + ], + ], + ) + + callback.on_llm_end(response) + + assert events == [ + ( + "llm_call_finished", + { + "thread_id": "thread-1", + "llm_output": {"token_usage": {"total_tokens": 12}}, + }, + ), + ( + "llm_decision", + { + "thread_id": "thread-1", + "tool_calls": [ + {"name": "molecule_name_to_smiles", "id": "call-1"}, + {"name": "smiles_to_coordinate_file", "id": "call-2"}, + ], + }, + ), + ] + + +def test_turn_event_callback_skips_llm_decision_without_tool_calls(): + events = [] + callback = _TurnEventCallback( + lambda event, payload: events.append((event, payload)), + "thread-1", + ) + + callback.on_llm_end( + SimpleNamespace(generations=[[SimpleNamespace(message=AIMessage(content="done"))]]), + ) + + assert [event for event, _payload in events] == ["llm_call_finished"] + + +def test_turn_event_callback_ignores_llm_decision_extraction_errors(): + class BrokenGenerationGroup: + def __iter__(self): + raise RuntimeError("broken response") + + events = [] + callback = _TurnEventCallback( + lambda event, payload: events.append((event, payload)), + "thread-1", + ) + + callback.on_llm_end(SimpleNamespace(generations=[BrokenGenerationGroup()])) + + assert [event for event, _payload in events] == ["llm_call_finished"] + + +@pytest.mark.asyncio +async def test_cli_trace_events_are_emitted_from_astream_path(monkeypatch, tmp_path): + from chemgraph.cli.trace import CLIRunTrace + + class FakeWorkflow: + def __init__(self): + 
self.state = {"messages": [AIMessage(content="done")]} + + async def astream(self, inputs, *, stream_mode, config): + for callback in config.get("callbacks", []): + callback.on_chat_model_start({"name": "FakeChatModel"}, [["hello"]]) + callback.on_llm_end(SimpleNamespace(generations=[])) + yield self.state + + def get_state(self, config): + return SimpleNamespace(values=self.state) + + monkeypatch.setattr( + "chemgraph.agent.llm_agent.construct_single_agent_graph", + lambda *_args, **_kwargs: FakeWorkflow(), + ) + monkeypatch.setattr( + "chemgraph.agent.llm_agent.load_openai_model", + lambda **_kwargs: Mock(), + ) + + trace = CLIRunTrace( + tmp_path / "trace", + run_id="trace-test", + model_name="gpt-4o-mini", + workflow_type="single_agent", + query="x", + ) + trace.start() + agent = ChemGraph( + model_name="gpt-4o-mini", + workflow_type="single_agent", + enable_memory=False, + log_dir=str(tmp_path / "logs"), + on_event=trace.on_event, + ) + await agent.run("x") + trace.finish(status="completed") + + events = [ + json.loads(line)["event"] + for line in (tmp_path / "trace" / "events.jsonl").read_text().splitlines() + ] + assert events == [ + "run_started", + "workflow_started", + "llm_call_started", + "llm_call_finished", + "workflow_finished", + "run_finished", + ] diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 66cab765..86d32558 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,13 +1,16 @@ """Test suite for MCP servers.""" +import inspect import json from pathlib import Path +from typing import Any import pytest try: from mcp.types import TextContent from fastmcp import Client + from chemgraph.mcp.cg_fastmcp import CGFastMCP from chemgraph.mcp.mcp_tools import mcp from chemgraph.mcp.data_analysis_mcp import mcp as data_mcp except ModuleNotFoundError: @@ -16,6 +19,112 @@ TEST_DIR = Path(__file__).parent +def _fanout_worker(item: dict) -> dict: + return item + + +def test_schema_fanout_tool_advertises_batch_result_signature(monkeypatch): + 
"""Fanout tools expose an ensemble input but return batch summaries.""" + local_mcp = CGFastMCP(name="test") + captured = {} + + def capture_tool(fn, **kwargs): + captured["fn"] = fn + captured["kwargs"] = kwargs + + monkeypatch.setattr(local_mcp, "add_tool", capture_tool) + + @local_mcp.schema_fanout_tool(name="fanout", worker=_fanout_worker) + def fanout(params: dict) -> list[dict]: + return [params] + + sig = inspect.signature(captured["fn"]) + + assert list(sig.parameters) == ["params"] + assert sig.parameters["params"].annotation is dict + assert sig.return_annotation == dict[str, Any] + + +def test_mace_worker_creates_inline_output_parent(monkeypatch): + from ase import Atoms + + from chemgraph.mcp import mace_mcp_hpc + from chemgraph.tools.ase_core import atoms_to_atomsdata + + atoms = Atoms(numbers=[1, 1], positions=[[0, 0, 0], [0, 0, 0.74]]) + output_file = "nested/family/output.json" + + def fake_run_mace_core(params): + output_path = Path(params.output_result_file) + assert output_path.parent.is_dir() + output_path.write_text('{"ok": true}', encoding="utf-8") + return {"status": "success"} + + monkeypatch.setattr(mace_mcp_hpc, "run_mace_core", fake_run_mace_core) + + result = mace_mcp_hpc._mace_worker( + { + "inline_structure": atoms_to_atomsdata(atoms).model_dump(), + "output_result_file": output_file, + "driver": "energy", + "model": "small", + "device": "cpu", + } + ) + + # The worker returns run_mace_core's result verbatim; full_output read-back + # was intentionally dropped. The inline output parent dir is asserted inside + # fake_run_mace_core above. + assert result["status"] == "success" + + +def test_run_ase_core_creates_output_parent_directory(monkeypatch, tmp_path): + """run_ase_core should mkdir the output file's parent before writing. + + Academy agents and CLI users routinely point output_results_file at a + not-yet-existing nested subdirectory of a shared run dir. 
Without this, + the final ``open(output_results_file, "w")`` fails with + FileNotFoundError after the calculation has already burned its compute + time. + """ + from ase import Atoms + from ase.io import write as ase_write + + from chemgraph.schemas.ase_input import ASEInputSchema + from chemgraph.tools import ase_core + + # Real XYZ that ase.io.read can parse. + input_path = tmp_path / "h2.xyz" + ase_write(input_path, Atoms(numbers=[1, 1], positions=[[0, 0, 0], [0, 0, 0.74]])) + + # Output path under a nested subdirectory that does NOT exist yet. + output_path = tmp_path / "deeply" / "nested" / "output.json" + assert not output_path.parent.exists() + + class _FakeCalc: + # ASE's Atoms.get_potential_energy invokes self._calc.get_potential_energy(atoms). + def get_potential_energy(self, _atoms=None, force_consistent=False): + return -1.234 + + def fake_load_calculator(_calculator): + return _FakeCalc(), {}, None + + monkeypatch.setattr(ase_core, "load_calculator", fake_load_calculator) + + params = ASEInputSchema( + input_structure_file=str(input_path), + output_results_file=str(output_path), + driver="energy", + calculator={"calculator_type": "emt"}, + ) + + result = ase_core.run_ase_core(params) + + assert result["status"] == "success", result + assert output_path.exists() + assert output_path.parent.is_dir() + + @pytest.mark.asyncio async def test_split_cif_dataset(tmp_path): """Test splitting a dataset of CIF files.""" diff --git a/tests/test_openai_model_normalization.py b/tests/test_openai_model_normalization.py new file mode 100644 index 00000000..0ff55624 --- /dev/null +++ b/tests/test_openai_model_normalization.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from chemgraph.models.openai import _normalize_argo_model + + +def test_local_argo_shim_keeps_openai_style_model_name(monkeypatch): + monkeypatch.delenv("CHEMGRAPH_ARGO_MODEL_FORMAT", raising=False) + + assert ( + _normalize_argo_model( + "argo:gpt-4o-mini", + 
"http://127.0.0.1:18085/argoapi/v1", + ) + == "gpt-4o-mini" + ) + + +def test_local_argo_shim_uses_advertised_gpt54_model_name(monkeypatch): + monkeypatch.delenv("CHEMGRAPH_ARGO_MODEL_FORMAT", raising=False) + + assert ( + _normalize_argo_model( + "argo:gpt-5.4", + "http://127.0.0.1:18085/argoapi/v1", + ) + == "GPT-5.4" + ) + + +def test_hosted_argo_endpoint_uses_wire_model_name(monkeypatch): + monkeypatch.delenv("CHEMGRAPH_ARGO_MODEL_FORMAT", raising=False) + + assert ( + _normalize_argo_model( + "argo:gpt-4o-mini", + "https://apps.inside.anl.gov/argoapi/v1", + ) + == "gpt4omini" + ) + + +def test_argo_model_format_env_override(monkeypatch): + monkeypatch.setenv("CHEMGRAPH_ARGO_MODEL_FORMAT", "openai") + assert ( + _normalize_argo_model( + "argo:gpt-4o-mini", + "https://apps.inside.anl.gov/argoapi/v1", + ) + == "gpt-4o-mini" + ) + + +def test_argo_model_format_shim_override_uses_local_alias(monkeypatch): + monkeypatch.setenv("CHEMGRAPH_ARGO_MODEL_FORMAT", "shim") + assert ( + _normalize_argo_model( + "argo:gpt-5.4", + "https://apps.inside.anl.gov/argoapi/v1", + ) + == "GPT-5.4" + ) diff --git a/tests/test_tool_adapter_validation.py b/tests/test_tool_adapter_validation.py new file mode 100644 index 00000000..8abf5a69 --- /dev/null +++ b/tests/test_tool_adapter_validation.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import json +from typing import Any + +import pytest + +# Skip when the optional 'academy' extra is absent; core.tools imports +# academy.agent at module level. 
+pytest.importorskip("academy") + +from chemgraph.academy.core.tools import build_chemgraph_reasoning_tools +from chemgraph.academy.core.campaign import ChemGraphAgentSpec + + +class _FakePeerHandle: + def __init__(self) -> None: + self.calls: list[tuple[str, dict[str, Any]]] = [] + + async def action(self, name: str, payload: dict[str, Any]) -> None: + self.calls.append((name, payload)) + + +def _agent_spec() -> ChemGraphAgentSpec: + return ChemGraphAgentSpec( + name="agent-a", + role="Worker", + mission="Use explicit tools only.", + allowed_peers=("agent-b",), + mcp_servers=(), + ) + + +async def _build_tools(tmp_path): + traces: list[tuple[str, dict[str, Any]]] = [] + outbox: list[dict[str, Any]] = [] + peer_handle = _FakePeerHandle() + tools = await build_chemgraph_reasoning_tools( + spec=_agent_spec(), + run_dir=tmp_path, + peer_names=("agent-b",), + peer_handles={"agent-b": peer_handle}, + outbox=outbox, + tool_results=[], + get_round_index=lambda: 1, + set_final_result=lambda result: None, + trace=lambda event, payload: traces.append((event, payload)), + ) + return { + "tools": {tool.name: tool for tool in tools}, + "traces": traces, + "outbox": outbox, + "peer_handle": peer_handle, + } + + +@pytest.mark.asyncio +async def test_send_message_invalid_args_return_structured_tool_error(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["send_message"].ainvoke( + { + "recipient": "agent-b", + "tldr": "invalid confidence", + "content": "content", + "artifact_refs": [], + "tool_result_ids": [], + "reason": "exercise validation", + "confidence": 1.5, + } + ) + + assert result["status"] == "error" + assert result["error_type"] == "invalid_tool_arguments" + assert result["errors"][0]["field"] == "confidence" + assert env["outbox"] == [] + assert env["peer_handle"].calls == [] + assert env["traces"] == [ + ( + "tool_call_failed", + { + "tool_name": "send_message", + "status": "failed", + "error": "invalid_tool_arguments", + 
"error_type": "invalid_tool_arguments", + "errors": result["errors"], + }, + ) + ] + + +@pytest.mark.asyncio +async def test_send_message_disallowed_recipient_does_not_deliver(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["send_message"].ainvoke( + { + "recipient": "not-a-peer", + "tldr": "wrong peer", + "content": "content", + "artifact_refs": [], + "tool_result_ids": [], + "reason": "exercise validation", + "confidence": 0.8, + } + ) + + assert result == { + "status": "error", + "tool_name": "send_message", + "error": "disallowed_recipient", + "error_type": "disallowed_recipient", + "recipient": "not-a-peer", + "allowed_peers": ["agent-b"], + } + assert env["outbox"] == [] + assert env["peer_handle"].calls == [] + assert env["traces"][0][0] == "tool_call_failed" + assert env["traces"][0][1]["error_type"] == "disallowed_recipient" + + +@pytest.mark.asyncio +async def test_send_message_request_requires_tldr(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["send_message"].ainvoke( + { + "recipient": "agent-b", + "tldr": "", + "content": "What happened?", + "artifact_refs": [], + "tool_result_ids": [], + "reply_requested": True, + "reason": "need a peer check", + "confidence": 0.5, + } + ) + + assert result["status"] == "error" + assert result["error_type"] == "invalid_tool_arguments" + assert result["errors"][0]["field"] == "tldr" + assert env["outbox"] == [] + assert env["peer_handle"].calls == [] + + +@pytest.mark.asyncio +async def test_send_message_reply_requested_marks_question(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["send_message"].ainvoke( + { + "recipient": "agent-b", + "tldr": "need status", + "content": "Please send current status.", + "artifact_refs": [], + "tool_result_ids": [], + "reply_requested": True, + "reason": "the report needs the peer status", + "confidence": 0.7, + } + ) + + assert result["status"] == "sent" + assert 
env["outbox"][0]["reply_requested"] is True + assert env["outbox"][0]["kind"] == "question" + + +@pytest.mark.asyncio +async def test_valid_send_message_still_delivers(tmp_path) -> None: + env = await _build_tools(tmp_path) + + result = await env["tools"]["send_message"].ainvoke( + { + "recipient": "agent-b", + "tldr": "candidate ready", + "content": "Candidate C1 has a usable artifact.", + "artifact_refs": ["artifacts/c1.xyz"], + "tool_result_ids": ["tool-1"], + "reply_requested": False, + "reason": "peer needs the result", + "confidence": 0.9, + } + ) + + assert result["status"] == "sent" + assert result["recipient"] == "agent-b" + assert len(env["outbox"]) == 1 + assert env["outbox"][0]["reply_requested"] is False + assert env["peer_handle"].calls[0][0] == "receive_message" + assert env["peer_handle"].calls[0][1]["message_id"] == result["message_id"] + assert [event for event, _ in env["traces"]] == [ + "message_sent", + "message_delivered", + ] + assert { + json.loads(line)["message_id"] + for line in tmp_path.joinpath("messages.jsonl").read_text().splitlines() + } == {result["message_id"]}