Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ indent-style = "space" # Use spaces for indentation
skip-magic-trailing-comma = false # Ensure Black-style formatting

[tool.pytest.ini_options]
testpaths = ["tests"]
markers = [
"llm: marks tests as requiring LLM API access (run with --run-llm)",
"asyncio: marks async tests",
Expand Down
8 changes: 7 additions & 1 deletion src/chemgraph/mcp/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,13 +456,19 @@ async def run_ase(params: ASEInputSchema) -> dict:
atomsdata=final_structure
)

multiplicity = (
getattr(calc_model, "get_multiplicity", lambda: None)()
or 1
)
spin_S = (multiplicity - 1) / 2.0

thermo = IdealGasThermo(
vib_energies=energies,
potentialenergy=single_point_energy,
atoms=atoms,
geometry=geometry,
symmetrynumber=symmetrynumber,
spin=0, # Only support spin=0
spin=spin_S,
)
thermo_data = {
"enthalpy": float(
Expand Down
42 changes: 35 additions & 7 deletions src/chemgraph/schemas/calculators/fairchem_calc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from typing import Optional, Dict, Any
from typing import Any, Optional, Dict
import torch
import logging

Expand All @@ -22,8 +22,11 @@ class FAIRChemCalc(BaseModel):
Must match available tasks in the model.
seed : int, optional
Seed for model reproducibility. Default is 42.
spin : int, optional
Spin multiplicity. Default is 1.
multiplicity : int, optional
Spin multiplicity (2S+1) of the system. Default is 1 (singlet).
UMA/OMOL reads this from ``atoms.info["spin"]``; the schema field is named
``multiplicity`` for consistency with other calculators (TBLite, ORCA).
The deprecated alias ``spin=`` is still accepted as input.
charge : int, optional
System charge. Default is 0.
model_name: str
Expand All @@ -41,7 +44,14 @@ class FAIRChemCalc(BaseModel):
description="Prediction task. Options are 'omol', 'omat', 'oc20', 'odac', or 'omc",
)
seed: int = Field(default=42, description="Random seed for inference reproducibility.")
spin: Optional[int] = Field(default=1, description="Total spin multiplicity of the system.")
multiplicity: Optional[int] = Field(
default=1,
description=(
"Spin multiplicity (2S+1) of the system. Default 1 (singlet). "
"Passed to UMA via atoms.info['spin']."
),
ge=1,
)
charge: Optional[int] = Field(default=0, description="Total system charge.")
model_name: str = Field(
default="uma-s-1p1", description="Model names. Options are 'uma-s-1p1' and 'uma-m-1'"
Expand All @@ -54,6 +64,16 @@ class FAIRChemCalc(BaseModel):
default="default", description="Settings for inference. Can be 'default' or 'turbo'"
)

@model_validator(mode="before")
@classmethod
def _accept_spin_alias(cls, data: Any) -> Any:
if isinstance(data, dict) and "spin" in data and "multiplicity" not in data:
logging.warning(
"FAIRChemCalc: field 'spin' is deprecated; use 'multiplicity' instead."
)
data["multiplicity"] = data.pop("spin")
return data

def get_calculator(self) -> Any:
"""Return a configured FAIRChemCalculator.

Expand Down Expand Up @@ -83,8 +103,16 @@ def get_calculator(self) -> Any:
)

def get_atoms_properties(self) -> Dict[str, Optional[int]]:
"""Return atom-level info keys to inject into atoms.info."""
"""Return atom-level info keys to inject into atoms.info.

UMA/OMOL reads spin multiplicity from ``atoms.info["spin"]``; we keep
that key name here even though our schema field is ``multiplicity``.
"""
return {
"spin": self.spin,
"spin": self.multiplicity,
"charge": self.charge,
}

def get_multiplicity(self) -> Optional[int]:
"""Return spin multiplicity (2S+1) for thermochemistry."""
return self.multiplicity
37 changes: 36 additions & 1 deletion src/chemgraph/schemas/calculators/nwchem_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ class NWChemCalc(BaseModel):
command : str, optional
Command to execute NWChem (e.g., 'nwchem PREFIX.nwi > PREFIX.nwo'),
by default None
charge : int, optional
Total charge of the system, by default None
multiplicity : int, optional
Spin multiplicity (2S+1) of the system, by default None.
For molecular theories ('dft', 'scf', 'mp2', 'ccsd', 'tce', 'tddft') this is
injected into the theory block as ``mult``. For 'scf', NWChem expects
``nopen`` (number of unpaired electrons); set ``scf={'nopen': N}`` manually
if you need finer control.
"""

calculator_type: str = Field(
Expand Down Expand Up @@ -61,6 +69,18 @@ class NWChemCalc(BaseModel):
default=None,
description="Command to execute NWChem (e.g., 'nwchem PREFIX.nwi > PREFIX.nwo').",
)
charge: Optional[int] = Field(
default=None, description="Total charge of the system."
)
multiplicity: Optional[int] = Field(
default=None,
description=(
"Spin multiplicity (2S+1). Injected into the theory block as 'mult' "
"for dft/mp2/ccsd/tce/tddft; for 'scf' theory NWChem expects 'nopen' "
"(unpaired electrons) which is not auto-set here."
),
ge=1,
)

def get_calculator(self):
"""Get an ASE-compatible NWChem calculator instance.
Expand All @@ -80,11 +100,26 @@ def get_calculator(self):
"Invalid calculator_type. The only valid option is 'nwchem'."
)

return NWChem(
kwargs = dict(
theory=self.theory,
xc=self.xc,
basis=self.basis,
kpts=self.kpts,
directory=self.directory,
command=self.command,
)

# NWChem accepts charge/multiplicity inside the theory-specific block.
block: Dict[str, Union[int, str]] = {}
if self.charge is not None:
block["charge"] = self.charge
if self.multiplicity is not None and self.theory != "scf":
block["mult"] = self.multiplicity
if block and self.theory in {"dft", "mp2", "ccsd", "tce", "tddft"}:
kwargs[self.theory] = block

return NWChem(**kwargs)

def get_multiplicity(self) -> Optional[int]:
"""Return spin multiplicity (2S+1) for thermochemistry."""
return self.multiplicity
4 changes: 4 additions & 0 deletions src/chemgraph/schemas/calculators/orca_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,7 @@ def get_calculator(self):
directory=self.directory,
profile=profile,
)

def get_multiplicity(self) -> Optional[int]:
"""Return spin multiplicity (2S+1) for thermochemistry."""
return self.multiplicity
16 changes: 16 additions & 0 deletions src/chemgraph/schemas/calculators/psi4_calc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from pydantic import BaseModel, Field


Expand Down Expand Up @@ -28,6 +30,10 @@ class Psi4Calc(BaseModel):
'cd' (Cholesky Decomposition), by default 'pk'
maxiter : int, optional
Maximum number of SCF iterations, by default 50
charge : int, optional
Total charge of the system, by default 0
multiplicity : int, optional
Spin multiplicity (2S+1) of the system, by default 1 (singlet)
"""

calculator_type: str = Field(
Expand Down Expand Up @@ -58,6 +64,10 @@ class Psi4Calc(BaseModel):
maxiter: int = Field(
default=50, description="Maximum number of SCF iterations. Default is 50."
)
charge: int = Field(default=0, description="Total charge of the system.")
multiplicity: int = Field(
default=1, description="Spin multiplicity (2S+1) of the system.", ge=1
)

def get_calculator(self) -> dict:
"""Get a dictionary of PSI4 calculation parameters.
Expand All @@ -77,5 +87,11 @@ def get_calculator(self) -> dict:
"reference": self.reference,
"scf_type": self.scf_type,
"maxiter": self.maxiter,
"charge": self.charge,
"multiplicity": self.multiplicity,
}
return params

def get_multiplicity(self) -> Optional[int]:
"""Return spin multiplicity (2S+1) for thermochemistry."""
return self.multiplicity
4 changes: 4 additions & 0 deletions src/chemgraph/schemas/calculators/tblite_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,7 @@ def get_calculator(self):
cache_api=self.cache_api,
verbosity=self.verbosity,
)

def get_multiplicity(self) -> Optional[int]:
"""Return spin multiplicity (2S+1) for thermochemistry."""
return self.multiplicity
10 changes: 9 additions & 1 deletion src/chemgraph/tools/ase_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,13 +583,21 @@ def _run_ase_impl(params: ASEInputSchema):
{"atomsdata": final_structure}
)

# IdealGasThermo expects total spin S; calculators expose
# multiplicity (2S+1) via get_multiplicity() when supported.
multiplicity = (
getattr(calc_model, "get_multiplicity", lambda: None)()
or 1
)
spin_S = (multiplicity - 1) / 2.0

thermo = IdealGasThermo(
vib_energies=energies,
potentialenergy=single_point_energy,
atoms=atoms,
geometry=geometry,
symmetrynumber=symmetrynumber,
spin=0, # Only support spin=0
spin=spin_S,
)
thermo_data = {
"enthalpy": float(
Expand Down
2 changes: 1 addition & 1 deletion src/chemgraph/utils/tool_call_eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module for quick LLM evaluations"""

from deepdiff import DeepDiff
from chemgraph.models.ase_input import ASEInputSchema
from chemgraph.schemas.ase_input import ASEInputSchema


def remove_ignored_fields(obj, ignored_keys=("cell", "pbc")):
Expand Down