Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies = [
"ase",
"rdkit",
"pymatgen",
"mace-torch",
"mace-torch", # disabled for UMA-only CPU environment
"globus_sdk",
"streamlit==1.48.1",
"stmol==0.0.9",
Expand All @@ -48,7 +48,6 @@ dependencies = [
"mcp",
"fastmcp",
"pytest-asyncio",
"grandalf==0.8"
]

[project.optional-dependencies]
Expand Down Expand Up @@ -78,7 +77,11 @@ rag = [
"sentence-transformers>=2.2.2",
"pymupdf>=1.24.0",
]

phonon = [
"grandalf==0.8",
"phonopy>=2.33.0",
"seekpath>=2.2.1"
]
[project.urls]
"Homepage" = "https://github.com/argonne-lcf/ChemGraph"
"Repository" = "https://github.com/argonne-lcf/ChemGraph"
Expand Down
56 changes: 41 additions & 15 deletions src/chemgraph/cli/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,34 @@ def format_response(result: Any, verbose: bool = False) -> None:
# Find the final AI response
final_answer = ""
for message in reversed(messages):
is_ai = False
raw_content = ""
if hasattr(message, "content") and hasattr(message, "type"):
if message.type == "ai" and message.content.strip():
content = message.content.strip()
if not _is_atomic_json(content):
final_answer = content
break
raw_content = message.content
is_ai = (message.type == "ai")
elif isinstance(message, dict):
if message.get("type") == "ai" and message.get("content", "").strip():
content = message["content"].strip()
if not _is_atomic_json(content):
final_answer = content
break
raw_content = message.get("content", "")
is_ai = (message.get("type") == "ai")

if isinstance(raw_content, list):
text_parts = []
for item in raw_content:
if isinstance(item, str):
text_parts.append(item)
elif isinstance(item, dict) and "text" in item:
text_parts.append(item["text"])
elif isinstance(item, dict) and "image_url" in item:
# Depending on structure, image_url might be a string or dict
url = item["image_url"] if isinstance(item["image_url"], str) else item["image_url"].get("url", "")
text_parts.append(f"![image]({url})")
content = "\n".join(text_parts).strip()
else:
content = str(raw_content).strip()

if is_ai and content:
if not _is_atomic_json(content):
final_answer = content
break

if final_answer:
console.print(
Expand All @@ -239,16 +255,26 @@ def format_response(result: Any, verbose: bool = False) -> None:

# Check for structure data (valid JSON with atomic keys)
for message in messages:
content = ""
raw_content = ""
if hasattr(message, "content"):
content = message.content
raw_content = message.content
elif isinstance(message, dict):
content = message.get("content", "")
raw_content = message.get("content", "")

content_str = ""
if isinstance(raw_content, list):
for item in raw_content:
text = item if isinstance(item, str) else item.get("text", "") if isinstance(item, dict) else ""
if text and _is_atomic_json(text):
content_str = text
break
else:
content_str = str(raw_content)

if content and _is_atomic_json(content):
if content_str and _is_atomic_json(content_str):
console.print(
Panel(
Syntax(content, "json", theme="monokai"),
Syntax(content_str, "json", theme="monokai"),
title="Molecular Structure Data",
style="cyan",
)
Expand Down
2 changes: 2 additions & 0 deletions src/chemgraph/graphs/multi_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,11 +812,13 @@ def construct_multi_agent_graph(
smiles_to_coordinate_file,
)
from chemgraph.tools.generic_tools import calculator
from chemgraph.tools.phonopy_tools import run_phonopy

executor_tools = [
molecule_name_to_smiles,
smiles_to_coordinate_file,
run_ase,
run_phonopy,
extract_output_json,
calculator,
]
Expand Down
3 changes: 3 additions & 0 deletions src/chemgraph/graphs/single_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
molecule_name_to_smiles,
smiles_to_coordinate_file,
)
from chemgraph.tools.phonopy_tools import run_phonopy
from chemgraph.tools.report_tools import generate_html
from chemgraph.tools.generic_tools import calculator, ask_human
from chemgraph.prompt.single_agent_prompt import (
Expand Down Expand Up @@ -261,6 +262,7 @@ def ChemGraphAgent(
tools = [
smiles_to_coordinate_file,
run_ase,
run_phonopy,
molecule_name_to_smiles,
extract_output_json,
calculator,
Expand Down Expand Up @@ -450,6 +452,7 @@ def construct_single_agent_graph(
smiles_to_coordinate_file,
molecule_name_to_smiles,
run_ase,
run_phonopy,
extract_output_json,
calculator,
]
Expand Down
27 changes: 27 additions & 0 deletions src/chemgraph/mcp/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
smiles_to_coordinate_file_core,
)
from chemgraph.schemas.ase_input import ASEInputSchema
from chemgraph.tools.phonopy_core import run_phonopy_core
from chemgraph.schemas.phonopy_schema import PhonopyInputSchema


mcp = FastMCP(
Expand Down Expand Up @@ -136,6 +138,31 @@ async def run_ase(params: ASEInputSchema) -> dict:
return run_ase_core(params)


@mcp.tool(
name="run_phonopy",
description="Run Phonopy calculations using specified input parameters.",
)
async def run_phonopy(params: PhonopyInputSchema) -> dict:
"""Run Phonopy calculations using specified input parameters.

Parameters
----------
params : PhonopyInputSchema
Input parameters for the Phonopy calculation

Returns
-------
dict
Output containing calculation status
"""
import io
from contextlib import redirect_stdout

f = io.StringIO()
with redirect_stdout(f):
return run_phonopy_core(params)


if __name__ == "__main__":
from chemgraph.mcp.server_utils import run_mcp_server

Expand Down
8 changes: 1 addition & 7 deletions src/chemgraph/schemas/calculators/tblite_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@
from typing import List, Optional
import logging

try:
from tblite.ase import TBLite
except ImportError:
logging.warning(
"tblite is not installed. If you want to use tblite, please install it using 'pip install tblite'."
)


class TBLiteCalc(BaseModel):
"""TBLite tight-binding calculator configuration.
Expand Down Expand Up @@ -96,6 +89,7 @@ def get_calculator(self):
An ASE-compatible TBLite calculator instance with the specified
configuration parameters
"""
from tblite.ase import TBLite
return TBLite(
method=self.method,
charge=self.charge,
Expand Down
136 changes: 136 additions & 0 deletions src/chemgraph/schemas/phonopy_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import json
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Optional, List, Any
from chemgraph.schemas.atomsdata import AtomsData
from chemgraph.schemas.ase_input import CalculatorUnion, default_calculator, _calculator_key, available_calculator_classes

class PhonopyInputSchema(BaseModel):
"""
Schema for defining input parameters used in Phonopy phonon calculations.
"""
input_structure_file: str = Field(
description="Path to the input coordinate file (e.g., CIF, XYZ, POSCAR) containing the relaxed atomic structure."
)
output_results_file: str = Field(
default="phonopy_results.json",
description="Path to a JSON file where phonon simulation results will be saved.",
)
supercell_matrix: Optional[List[int]] = Field(
default=None,
description="A list of 3 integers representing the diagonal of the supercell matrix (e.g., [2, 2, 2]). If not provided, it will be determined automatically to ensure dimensions are at least 10 Å.",
)
is_2d: bool = Field(
default=False,
description="Set to true if the material is 2D. This forces the supercell matrix in the Z-direction to be 1.",
)
calculator: CalculatorUnion = Field(
default_factory=default_calculator,
description="The ASE calculator to be used for force evaluations.",
)
mesh: List[int] = Field(
default=[10, 10, 10],
description="Sampling mesh for reciprocal space (e.g., [10, 10, 10]).",
)
calculate_dos: bool = Field(
default=True,
description="Whether to calculate and plot the Total Density of States (DOS).",
)
calculate_thermal_properties: bool = Field(
default=True,
description="Whether to calculate and plot Thermal Properties (Free energy, Entropy, Heat Capacity).",
)
calculate_band_structure: bool = Field(
default=True,
description="Whether to calculate and plot the phonon band structure (dispersion curve).",
)
band_paths: Optional[List[List[List[float]]]] = Field(
default=None,
description="A list of q-point paths in reciprocal space (e.g., [[[0,0,0], [0.5,0,0], [0.333,0.333,0], [0,0,0]]]). If None but calculate_band_structure is True, it will automatically determine the standard high-symmetry paths using seekpath.",
)
band_labels: Optional[List[str]] = Field(
default=None,
description="List of labels for the q-points in band_paths (e.g., ['Gamma', 'M', 'K', 'Gamma']). Total number of labels should match the total number of unique q-points in the path.",
)
band_npoints: int = Field(
default=51,
description="Number of q-points to sample along the paths between high-symmetry points for the band structure.",
)
supercell_target_length: float = Field(
default=10.0,
description="Target minimum length (in Å) for the auto-generated supercell matrix dimensions.",
)
save_vasp_files: bool = Field(
default=True,
description="Whether to save the VASP-format FORCE_CONSTANTS and POSCAR-* files.",
)
symprec: float = Field(
default=1e-5,
description="Symmetry tolerance used in Phonopy.",
)
dft_phonon_file: Optional[str] = Field(
default=None,
description="Path to a DFT phonon data file (e.g., band.dat or results.dat) to compare with the calculated phonon dispersion.",
)

@model_validator(mode="before")
@classmethod
def _validate_calculator_type(cls, data: Any):
if not isinstance(data, dict):
return data

calc = data.get("calculator")
if calc is None:
calc = default_calculator()
data["calculator"] = calc

available_calcs = {
_calculator_key(c.__name__.removesuffix("Calc")): c
for c in available_calculator_classes
}

if isinstance(calc, dict):
calc_name = calc.get("calculator_type")
if not calc_name:
raise ValueError("Calculator dictionary must have a 'calculator_type' key.")

calc_key = _calculator_key(calc_name)
if calc_key not in available_calcs:
raise ValueError(f"Calculator {calc_name} is not an allowed or available calculator.")

init_args = calc.copy()
init_args.pop("calculator_type", None)
data["calculator"] = available_calcs[calc_key](**init_args)
return data

elif hasattr(calc, "__class__"):
calc_type_name = calc.__class__.__name__
calc_key = _calculator_key(calc_type_name.removesuffix("Calc"))
if calc_key not in available_calcs:
raise ValueError(f"Calculator {calc_type_name} is not an allowed or available calculator.")
return data


class PhonopyOutputSchema(BaseModel):
"""
Schema for defining outputs from Phonopy phonon simulations.
"""
input_structure_file: str = Field(description="Path to the input coordinate file.")
simulation_input: PhonopyInputSchema = Field(description="Input parameters used.")
success: bool = Field(default=False, description="Indicates if the simulation finished correctly.")
error: str = Field(default="", description="Error captured during the simulation.")
supercell_matrix_used: List[int] = Field(default=[], description="The actual supercell matrix used.")
thermal_properties_plot: Optional[str] = Field(default=None, description="Path to the saved thermal properties plot.")
dos_plot: Optional[str] = Field(default=None, description="Path to the saved DOS plot.")
band_structure_plot: Optional[str] = Field(default=None, description="Path to the saved band structure plot.")
band_yaml: Optional[str] = Field(default=None, description="Path to the saved band.yaml file containing frequencies and k-points.")
phonopy_yaml: Optional[str] = Field(default=None, description="Path to the generated phonopy.yaml file containing force constants.")
force_constants_file: Optional[str] = Field(default=None, description="Path to the saved FORCE_CONSTANTS file.")
poscar_files: Optional[List[str]] = Field(default=None, description="List of paths to the saved POSCAR-* supercell files.")
wall_time: float = Field(default=None, description="Total wall time (in seconds) taken to complete the simulation.")

@field_validator("error", mode="before")
@classmethod
def _coerce_error_to_string(cls, v: Any) -> str:
if v is None:
return ""
return v if isinstance(v, str) else str(v)
Loading
Loading