diff --git a/pyproject.toml b/pyproject.toml index d61f391..447c9a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "ase", "rdkit", "pymatgen", - "mace-torch", + "mace-torch", # disabled for UMA-only CPU environment "globus_sdk", "streamlit==1.48.1", "stmol==0.0.9", @@ -48,7 +48,6 @@ dependencies = [ "mcp", "fastmcp", "pytest-asyncio", - "grandalf==0.8" ] [project.optional-dependencies] @@ -78,7 +77,11 @@ rag = [ "sentence-transformers>=2.2.2", "pymupdf>=1.24.0", ] - +phonon = [ + "grandalf==0.8", + "phonopy>=2.33.0", + "seekpath>=2.2.1" +] [project.urls] "Homepage" = "https://github.com/argonne-lcf/ChemGraph" "Repository" = "https://github.com/argonne-lcf/ChemGraph" diff --git a/src/chemgraph/cli/formatting.py b/src/chemgraph/cli/formatting.py index be3b8cc..956da2e 100644 --- a/src/chemgraph/cli/formatting.py +++ b/src/chemgraph/cli/formatting.py @@ -214,18 +214,34 @@ def format_response(result: Any, verbose: bool = False) -> None: # Find the final AI response final_answer = "" for message in reversed(messages): + is_ai = False + raw_content = "" if hasattr(message, "content") and hasattr(message, "type"): - if message.type == "ai" and message.content.strip(): - content = message.content.strip() - if not _is_atomic_json(content): - final_answer = content - break + raw_content = message.content + is_ai = (message.type == "ai") elif isinstance(message, dict): - if message.get("type") == "ai" and message.get("content", "").strip(): - content = message["content"].strip() - if not _is_atomic_json(content): - final_answer = content - break + raw_content = message.get("content", "") + is_ai = (message.get("type") == "ai") + + if isinstance(raw_content, list): + text_parts = [] + for item in raw_content: + if isinstance(item, str): + text_parts.append(item) + elif isinstance(item, dict) and "text" in item: + text_parts.append(item["text"]) + elif isinstance(item, dict) and "image_url" in item: + # Depending on structure, image_url might be a string or dict + url = item["image_url"] if isinstance(item["image_url"], str) else item["image_url"].get("url", "") + text_parts.append(f"![image]({url})") + content = "\n".join(text_parts).strip() + else: + content = str(raw_content).strip() + + if is_ai and content: + if not _is_atomic_json(content): + final_answer = content + break if final_answer: console.print( @@ -239,16 +255,26 @@ def format_response(result: Any, verbose: bool = False) -> None: # Check for structure data (valid JSON with atomic keys) for message in messages: - content = "" + raw_content = "" if hasattr(message, "content"): - content = message.content + raw_content = message.content elif isinstance(message, dict): - content = message.get("content", "") + raw_content = message.get("content", "") + + content_str = "" + if isinstance(raw_content, list): + for item in raw_content: + text = item if isinstance(item, str) else item.get("text", "") if isinstance(item, dict) else "" + if text and _is_atomic_json(text): + content_str = text + break + else: + content_str = str(raw_content) - if content and _is_atomic_json(content): + if content_str and _is_atomic_json(content_str): console.print( Panel( - Syntax(content, "json", theme="monokai"), + Syntax(content_str, "json", theme="monokai"), title="Molecular Structure Data", style="cyan", ) diff --git a/src/chemgraph/graphs/multi_agent.py b/src/chemgraph/graphs/multi_agent.py index 507e54b..fc9fd45 100644 --- a/src/chemgraph/graphs/multi_agent.py +++ b/src/chemgraph/graphs/multi_agent.py @@ -812,11 +812,13 @@ def construct_multi_agent_graph( smiles_to_coordinate_file, ) from chemgraph.tools.generic_tools import calculator + from chemgraph.tools.phonopy_tools import run_phonopy executor_tools = [ molecule_name_to_smiles, smiles_to_coordinate_file, run_ase, + run_phonopy, extract_output_json, calculator, ] diff --git a/src/chemgraph/graphs/single_agent.py b/src/chemgraph/graphs/single_agent.py index 7be83d7..bbfd549 100644 --- a/src/chemgraph/graphs/single_agent.py +++ b/src/chemgraph/graphs/single_agent.py @@ -12,6 +12,7 @@ molecule_name_to_smiles, smiles_to_coordinate_file, ) +from chemgraph.tools.phonopy_tools import run_phonopy from chemgraph.tools.report_tools import generate_html from chemgraph.tools.generic_tools import calculator, ask_human from chemgraph.prompt.single_agent_prompt import ( @@ -261,6 +262,7 @@ def ChemGraphAgent( tools = [ smiles_to_coordinate_file, run_ase, + run_phonopy, molecule_name_to_smiles, extract_output_json, calculator, @@ -450,6 +452,7 @@ def construct_single_agent_graph( smiles_to_coordinate_file, molecule_name_to_smiles, run_ase, + run_phonopy, extract_output_json, calculator, ] diff --git a/src/chemgraph/mcp/mcp_tools.py b/src/chemgraph/mcp/mcp_tools.py index 4265025..c001585 100644 --- a/src/chemgraph/mcp/mcp_tools.py +++ b/src/chemgraph/mcp/mcp_tools.py @@ -16,6 +16,8 @@ smiles_to_coordinate_file_core, ) from chemgraph.schemas.ase_input import ASEInputSchema +from chemgraph.tools.phonopy_core import run_phonopy_core +from chemgraph.schemas.phonopy_schema import PhonopyInputSchema mcp = FastMCP( @@ -136,6 +138,31 @@ async def run_ase(params: ASEInputSchema) -> dict: return run_ase_core(params) +@mcp.tool( + name="run_phonopy", + description="Run Phonopy calculations using specified input parameters.", +) +async def run_phonopy(params: PhonopyInputSchema) -> dict: + """Run Phonopy calculations using specified input parameters. + + Parameters + ---------- + params : PhonopyInputSchema + Input parameters for the Phonopy calculation + + Returns + ------- + dict + Output containing calculation status + """ + import io + from contextlib import redirect_stdout + + f = io.StringIO() + with redirect_stdout(f): + return run_phonopy_core(params) + + if __name__ == "__main__": from chemgraph.mcp.server_utils import run_mcp_server diff --git a/src/chemgraph/schemas/calculators/tblite_calc.py b/src/chemgraph/schemas/calculators/tblite_calc.py index 547c867..7e66649 100644 --- a/src/chemgraph/schemas/calculators/tblite_calc.py +++ b/src/chemgraph/schemas/calculators/tblite_calc.py @@ -4,13 +4,6 @@ from typing import List, Optional import logging -try: - from tblite.ase import TBLite -except ImportError: - logging.warning( - "tblite is not installed. If you want to use tblite, please install it using 'pip install tblite'." - ) - class TBLiteCalc(BaseModel): """TBLite tight-binding calculator configuration. @@ -96,6 +89,7 @@ def get_calculator(self): An ASE-compatible TBLite calculator instance with the specified configuration parameters """ + from tblite.ase import TBLite return TBLite( method=self.method, charge=self.charge, diff --git a/src/chemgraph/schemas/phonopy_schema.py b/src/chemgraph/schemas/phonopy_schema.py new file mode 100644 index 0000000..873dbaa --- /dev/null +++ b/src/chemgraph/schemas/phonopy_schema.py @@ -0,0 +1,136 @@ +import json +from pydantic import BaseModel, Field, field_validator, model_validator +from typing import Optional, List, Any +from chemgraph.schemas.atomsdata import AtomsData +from chemgraph.schemas.ase_input import CalculatorUnion, default_calculator, _calculator_key, available_calculator_classes + +class PhonopyInputSchema(BaseModel): + """ + Schema for defining input parameters used in Phonopy phonon calculations. + """ + input_structure_file: str = Field( + description="Path to the input coordinate file (e.g., CIF, XYZ, POSCAR) containing the relaxed atomic structure." + ) + output_results_file: str = Field( + default="phonopy_results.json", + description="Path to a JSON file where phonon simulation results will be saved.", + ) + supercell_matrix: Optional[List[int]] = Field( + default=None, + description="A list of 3 integers representing the diagonal of the supercell matrix (e.g., [2, 2, 2]). If not provided, it will be determined automatically to ensure dimensions are at least 10 Å.", + ) + is_2d: bool = Field( + default=False, + description="Set to true if the material is 2D. This forces the supercell matrix in the Z-direction to be 1.", + ) + calculator: CalculatorUnion = Field( + default_factory=default_calculator, + description="The ASE calculator to be used for force evaluations.", + ) + mesh: List[int] = Field( + default=[10, 10, 10], + description="Sampling mesh for reciprocal space (e.g., [10, 10, 10]).", + ) + calculate_dos: bool = Field( + default=True, + description="Whether to calculate and plot the Total Density of States (DOS).", + ) + calculate_thermal_properties: bool = Field( + default=True, + description="Whether to calculate and plot Thermal Properties (Free energy, Entropy, Heat Capacity).", + ) + calculate_band_structure: bool = Field( + default=True, + description="Whether to calculate and plot the phonon band structure (dispersion curve).", + ) + band_paths: Optional[List[List[List[float]]]] = Field( + default=None, + description="A list of q-point paths in reciprocal space (e.g., [[[0,0,0], [0.5,0,0], [0.333,0.333,0], [0,0,0]]]). If None but calculate_band_structure is True, it will automatically determine the standard high-symmetry paths using seekpath.", + ) + band_labels: Optional[List[str]] = Field( + default=None, + description="List of labels for the q-points in band_paths (e.g., ['Gamma', 'M', 'K', 'Gamma']). Total number of labels should match the total number of unique q-points in the path.", + ) + band_npoints: int = Field( + default=51, + description="Number of q-points to sample along the paths between high-symmetry points for the band structure.", + ) + supercell_target_length: float = Field( + default=10.0, + description="Target minimum length (in Å) for the auto-generated supercell matrix dimensions.", + ) + save_vasp_files: bool = Field( + default=True, + description="Whether to save the VASP-format FORCE_CONSTANTS and POSCAR-* files.", + ) + symprec: float = Field( + default=1e-5, + description="Symmetry tolerance used in Phonopy.", + ) + dft_phonon_file: Optional[str] = Field( + default=None, + description="Path to a DFT phonon data file (e.g., band.dat or results.dat) to compare with the calculated phonon dispersion.", + ) + + @model_validator(mode="before") + @classmethod + def _validate_calculator_type(cls, data: Any): + if not isinstance(data, dict): + return data + + calc = data.get("calculator") + if calc is None: + calc = default_calculator() + data["calculator"] = calc + + available_calcs = { + _calculator_key(c.__name__.removesuffix("Calc")): c + for c in available_calculator_classes + } + + if isinstance(calc, dict): + calc_name = calc.get("calculator_type") + if not calc_name: + raise ValueError("Calculator dictionary must have a 'calculator_type' key.") + + calc_key = _calculator_key(calc_name) + if calc_key not in available_calcs: + raise ValueError(f"Calculator {calc_name} is not an allowed or available calculator.") + + init_args = calc.copy() + init_args.pop("calculator_type", None) + data["calculator"] = available_calcs[calc_key](**init_args) + return data + + elif hasattr(calc, "__class__"): + calc_type_name = calc.__class__.__name__ + calc_key = _calculator_key(calc_type_name.removesuffix("Calc")) + if calc_key not in available_calcs: + raise ValueError(f"Calculator {calc_type_name} is not an allowed or available calculator.") + return data + + +class PhonopyOutputSchema(BaseModel): + """ + Schema for defining outputs from Phonopy phonon simulations. + """ + input_structure_file: str = Field(description="Path to the input coordinate file.") + simulation_input: PhonopyInputSchema = Field(description="Input parameters used.") + success: bool = Field(default=False, description="Indicates if the simulation finished correctly.") + error: str = Field(default="", description="Error captured during the simulation.") + supercell_matrix_used: List[int] = Field(default=[], description="The actual supercell matrix used.") + thermal_properties_plot: Optional[str] = Field(default=None, description="Path to the saved thermal properties plot.") + dos_plot: Optional[str] = Field(default=None, description="Path to the saved DOS plot.") + band_structure_plot: Optional[str] = Field(default=None, description="Path to the saved band structure plot.") + band_yaml: Optional[str] = Field(default=None, description="Path to the saved band.yaml file containing frequencies and k-points.") + phonopy_yaml: Optional[str] = Field(default=None, description="Path to the generated phonopy.yaml file containing force constants.") + force_constants_file: Optional[str] = Field(default=None, description="Path to the saved FORCE_CONSTANTS file.") + poscar_files: Optional[List[str]] = Field(default=None, description="List of paths to the saved POSCAR-* supercell files.") + wall_time: float = Field(default=None, description="Total wall time (in seconds) taken to complete the simulation.") + + @field_validator("error", mode="before") + @classmethod + def _coerce_error_to_string(cls, v: Any) -> str: + if v is None: + return "" + return v if isinstance(v, str) else str(v) diff --git a/src/chemgraph/tools/phonopy_core.py b/src/chemgraph/tools/phonopy_core.py new file mode 100644 index 0000000..8185245 --- /dev/null +++ b/src/chemgraph/tools/phonopy_core.py @@ -0,0 +1,390 @@ +"""Core simulation functions for phonopy calculations. + +This module provides the core logic for running phonon calculations +using phonopy and ASE calculators. +""" + +from __future__ import annotations + +import os +import time +import warnings +from pathlib import Path +from typing import List, Optional + +import numpy as np + +warnings.filterwarnings("ignore", module="phonopy") + +from chemgraph.schemas.phonopy_schema import PhonopyInputSchema, PhonopyOutputSchema +from chemgraph.tools.ase_core import _resolve_path, load_calculator + +def _auto_supercell_matrix(atoms, target_length: float = 10.0, is_2d: bool = False) -> List[int]: + """Calculate the optimal supercell matrix to ensure cell dimensions >= target_length.""" + lengths = atoms.cell.lengths() + matrix = [1, 1, 1] + for i in range(3): + if is_2d and i == 2: + matrix[i] = 1 + else: + if lengths[i] > 0: + matrix[i] = int(np.ceil(target_length / lengths[i])) + else: + matrix[i] = 1 + return matrix + + +def run_phonopy_core(params: PhonopyInputSchema) -> dict: + """Run a phonopy calculation using an ASE calculator for forces. + + Parameters + ---------- + params : PhonopyInputSchema + Fully validated phonopy input. + + Returns + ------- + dict + Result payload including file paths and status. + """ + try: + import phonopy + from phonopy import Phonopy + from phonopy.structure.atoms import PhonopyAtoms + except ImportError: + return { + "status": "failure", + "error_type": "ImportError", + "message": "Phonopy is not installed. Please install it using 'pip install phonopy'." + } + + from ase.io import read + from ase import Atoms + + start_time = time.time() + input_structure_file = params.input_structure_file + + if not os.path.isfile(input_structure_file): + return { + "status": "failure", + "error_type": "FileNotFoundError", + "message": f"Input structure file {input_structure_file} does not exist.", + } + + try: + atoms = read(input_structure_file) + except Exception as e: + return { + "status": "failure", + "error_type": type(e).__name__, + "message": f"Cannot read {input_structure_file} using ASE. Exception from ASE: {e}", + } + + # Determine Supercell Matrix + if params.supercell_matrix: + supercell_matrix_list = params.supercell_matrix + supercell_matrix = np.diag(supercell_matrix_list) + else: + supercell_matrix_list = _auto_supercell_matrix(atoms, target_length=params.supercell_target_length, is_2d=params.is_2d) + supercell_matrix = np.diag(supercell_matrix_list) + + # Initialize Phonopy + unitcell = PhonopyAtoms(symbols=atoms.get_chemical_symbols(), + cell=atoms.cell[:], + scaled_positions=atoms.get_scaled_positions()) + + phonon = Phonopy(unitcell, + supercell_matrix=supercell_matrix, + primitive_matrix='P', + symprec=params.symprec) + + phonon.generate_displacements(distance=0.01) + supercells = phonon.supercells_with_displacements + + # Load ASE Calculator + try: + calculator_dict = params.calculator.model_dump() + calc, system_info, calc_model = load_calculator(calculator_dict) + except Exception as e: + return { + "status": "failure", + "error_type": "ValueError", + "message": f"Error loading calculator: {e}", + } + + if calc is None: + return { + "status": "failure", + "error_type": "ValueError", + "message": "Unsupported calculator.", + } + + # Calculate Forces for Displaced Supercells + force_sets = [] + for scell in supercells: + if scell is None: + continue + ase_scell = Atoms(symbols=scell.symbols, + scaled_positions=scell.scaled_positions, + cell=scell.cell, + pbc=True) + ase_scell.info.update(system_info) + ase_scell.calc = calc + forces = ase_scell.get_forces() + force_sets.append(forces) + + # Set forces and produce force constants + phonon.forces = force_sets + phonon.produce_force_constants() + + mol_stem = Path(input_structure_file).stem if input_structure_file else "structure" + + # Save phonopy yaml + phonopy_yaml_path = _resolve_path(f"phonopy_{mol_stem}.yaml") + phonon.save(phonopy_yaml_path) + + fc_file_path = None + poscar_file_paths = [] + if params.save_vasp_files: + from phonopy.file_IO import write_FORCE_CONSTANTS + from phonopy.interface.vasp import write_vasp + + fc_file_path = _resolve_path(f"FORCE_CONSTANTS_{mol_stem}") + write_FORCE_CONSTANTS(phonon.force_constants, filename=fc_file_path) + + sposcar_path = _resolve_path(f"SPOSCAR_{mol_stem}") + write_vasp(sposcar_path, phonon.supercell) + poscar_file_paths.append(sposcar_path) + + for i, cell in enumerate(phonon.supercells_with_displacements): + if cell is not None: + p_path = _resolve_path(f"POSCAR-{i+1:03d}_{mol_stem}") + write_vasp(p_path, cell) + poscar_file_paths.append(p_path) + + dos_plot_path = None + tp_plot_path = None + + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + # Calculate DOS + if params.calculate_dos: + phonon.run_mesh(params.mesh, is_gamma_center=True) + phonon.run_total_dos() + dos_dict = phonon.get_total_dos_dict() + + fig, ax = plt.subplots() + ax.plot(dos_dict['frequency_points'], dos_dict['total_dos']) + ax.set_xlabel("Frequency (THz)") + ax.set_ylabel("Total DOS") + ax.set_title("Phonon Density of States") + ax.grid(True) + dos_plot_path = _resolve_path(f"dos_{mol_stem}.png") + fig.savefig(dos_plot_path, dpi=300) + plt.close(fig) + + # Calculate Thermal Properties + if params.calculate_thermal_properties: + phonon.run_thermal_properties() + tp_dict = phonon.get_thermal_properties_dict() + + fig, ax1 = plt.subplots() + ax1.plot(tp_dict['temperatures'], tp_dict['free_energy'], 'r-', label='Free Energy (kJ/mol)') + ax1.set_xlabel("Temperature (K)") + ax1.set_ylabel("Free Energy (kJ/mol)", color='r') + ax1.tick_params(axis='y', labelcolor='r') + + ax2 = ax1.twinx() + ax2.plot(tp_dict['temperatures'], tp_dict['entropy'], 'b--', label='Entropy (J/K/mol)') + ax2.plot(tp_dict['temperatures'], tp_dict['heat_capacity'], 'g-.', label='Heat Capacity (J/K/mol)') + ax2.set_ylabel("Entropy / Heat Capacity", color='b') + ax2.tick_params(axis='y', labelcolor='b') + + fig.tight_layout() + tp_plot_path = _resolve_path(f"thermal_properties_{mol_stem}.png") + fig.savefig(tp_plot_path, dpi=300) + plt.close(fig) + + # Calculate Band Structure + bs_plot_path = None + band_yaml_path = None + if params.calculate_band_structure: + band_yaml_path = _resolve_path(f"band_{mol_stem}.yaml") + if params.band_paths: + from phonopy.phonon.band_structure import get_band_qpoints + bands = get_band_qpoints(params.band_paths, npoints=params.band_npoints) + phonon.run_band_structure( + bands, + labels=params.band_labels, + with_eigenvectors=False + ) + phonon.write_yaml_band_structure(filename=band_yaml_path) + else: + # Fallback to automatic path generation using seekpath + phonon.auto_band_structure(npoints=params.band_npoints, with_eigenvectors=False, plot=False, write_yaml=True, filename=band_yaml_path) + + bs_plot_path = _resolve_path(f"band_structure_{mol_stem}.png") + if params.dft_phonon_file and os.path.isfile(params.dft_phonon_file): + # Custom comparison plot + bs_dict = phonon.get_band_structure_dict() + phonopy_distances = bs_dict['distances'] # list of 1D arrays + phonopy_frequencies = bs_dict['frequencies'] # list of 2D arrays: (npoints, nbands) + + fig, ax = plt.subplots(figsize=(9, 6.5), dpi=300) + + # Style the spines + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['left'].set_color('#2c3e50') + ax.spines['bottom'].set_color('#2c3e50') + ax.spines['left'].set_linewidth(1.2) + ax.spines['bottom'].set_linewidth(1.2) + + color_phonopy = '#1abc9c' # Deep teal/emerald for FAIRChem + color_dft = '#e74c3c' # Coral/red for DFT + + # Read DFT data + try: + with open(params.dft_phonon_file) as f: + dft_lines = f.readlines() + + dft_segments = [] + current_segment = [] + for line in dft_lines: + line_str = line.strip() + if not line_str: + if current_segment: + dft_segments.append(current_segment) + current_segment = [] + continue + if line_str.startswith('#'): + continue + parts = line_str.split() + if len(parts) >= 2: + try: + current_segment.append([float(x) for x in parts]) + except ValueError: + pass + if current_segment: + dft_segments.append(current_segment) + dft_segments = [seg for seg in dft_segments if len(seg) > 0] + + is_column_based = False + if dft_segments and len(dft_segments[0][0]) > 2: + is_column_based = True + + if is_column_based: + # Column-based format: x y1 y2 y3 ... + data_pts = [] + for seg in dft_segments: + data_pts.extend(seg) + data_pts = np.array(data_pts) + dft_x = data_pts[:, 0] + nbands_dft = data_pts.shape[1] - 1 + for b_idx in range(nbands_dft): + label = 'DFT (reference)' if b_idx == 0 else "" + ax.plot(dft_x, data_pts[:, b_idx + 1], color=color_dft, linestyle='--', linewidth=1.6, label=label, alpha=0.8) + else: + # Segment-based format: x y + npaths = len(phonopy_distances) + nbands_dft = len(dft_segments) // npaths + for b_idx in range(nbands_dft): + seg_x = [] + seg_y = [] + for p_idx in range(npaths): + if b_idx * npaths + p_idx < len(dft_segments): + seg = dft_segments[b_idx * npaths + p_idx] + seg_x.extend([pt[0] for pt in seg]) + seg_y.extend([pt[1] for pt in seg]) + label = 'DFT (reference)' if b_idx == 0 else "" + ax.plot(seg_x, seg_y, color=color_dft, linestyle='--', linewidth=1.6, label=label, alpha=0.8) + except Exception as dft_err: + print(f"Error parsing DFT phonon file: {dft_err}") + + # Plot Phonopy (FAIRChem/OMat) + nbands_phonopy = phonopy_frequencies[0].shape[1] + flat_distances = [] + for p_idx, dist_array in enumerate(phonopy_distances): + for b_idx in range(nbands_phonopy): + label = 'FAIRChem (OMat)' if (p_idx == 0 and b_idx == 0) else "" + ax.plot(dist_array, phonopy_frequencies[p_idx][:, b_idx], color=color_phonopy, linestyle='-', linewidth=2.2, label=label, alpha=0.9) + flat_distances.extend(dist_array) + + # Draw vertical lines for high-symmetry points + high_sym_positions = [phonopy_distances[0][0]] + for dist_array in phonopy_distances: + high_sym_positions.append(dist_array[-1]) + + for pos in high_sym_positions: + ax.axvline(x=pos, color='#bdc3c7', linestyle=':', linewidth=1.2) + + # Set x-ticks and limits + ax.set_xticks(high_sym_positions) + labels = None + if params.band_labels: + labels = [] + for label in params.band_labels: + if label.lower() in ('gamma', 'g'): + labels.append(r'$\Gamma$') + else: + labels.append(label) + elif hasattr(phonon, 'band_structure') and getattr(phonon.band_structure, 'labels', None) is not None: + labels = phonon.band_structure.labels + + if labels and len(labels) == len(high_sym_positions): + ax.set_xticklabels(labels, fontsize=12, fontweight='bold', color='#2c3e50') + ax.set_xlim(0, max(flat_distances)) + + # Labels and titles + ax.set_ylabel('Frequency (THz)', fontsize=13, fontweight='bold', color='#2c3e50') + ax.set_title(f'Phonon Dispersion: {mol_stem}', fontsize=16, fontweight='bold', pad=20, color='#2c3e50') + ax.text(0.5, 1.02, f'FAIRChem (OMat) vs DFT (from {os.path.basename(params.dft_phonon_file)})', transform=ax.transAxes, + ha='center', fontsize=11, color='#7f8c8d') + + # Grid and Legend + ax.grid(True, which='both', linestyle=':', alpha=0.5, color='#ecf0f1') + ax.legend(frameon=True, facecolor='white', edgecolor='#ecf0f1', fontsize=11, loc='upper right', shadow=True) + + plt.tight_layout() + fig.savefig(bs_plot_path, dpi=300) + plt.close(fig) + else: + bs_plot = phonon.plot_band_structure() + bs_plot.savefig(bs_plot_path, dpi=300) + bs_plot.close() + + end_time = time.time() + wall_time = end_time - start_time + + # Save output schema + output_results_file = _resolve_path(params.output_results_file) + + simulation_output = PhonopyOutputSchema( + input_structure_file=input_structure_file, + simulation_input=params, + success=True, + supercell_matrix_used=supercell_matrix_list, + thermal_properties_plot=tp_plot_path, + dos_plot=dos_plot_path, + band_structure_plot=bs_plot_path, + phonopy_yaml=phonopy_yaml_path, + wall_time=wall_time, + ) + + with open(output_results_file, "w", encoding="utf-8") as wf: + wf.write(simulation_output.model_dump_json(indent=4)) + + return { + "status": "success", + "message": f"Phonopy simulation completed. Results saved to {os.path.abspath(output_results_file)}", + "phonopy_yaml": os.path.abspath(phonopy_yaml_path), + "dos_plot": os.path.abspath(dos_plot_path) if dos_plot_path else None, + "thermal_properties_plot": os.path.abspath(tp_plot_path) if tp_plot_path else None, + "band_structure_plot": os.path.abspath(bs_plot_path) if bs_plot_path else None, + "band_yaml": os.path.abspath(band_yaml_path) if band_yaml_path else None, + "force_constants_file": os.path.abspath(fc_file_path) if fc_file_path else None, + "poscar_files": [os.path.abspath(p) for p in poscar_file_paths] if poscar_file_paths else None, + "supercell_matrix_used": supercell_matrix_list, + } diff --git a/src/chemgraph/tools/phonopy_tools.py b/src/chemgraph/tools/phonopy_tools.py new file mode 100644 index 0000000..202973d --- /dev/null +++ b/src/chemgraph/tools/phonopy_tools.py @@ -0,0 +1,36 @@ +"""LangChain ``@tool`` wrappers over :mod:`chemgraph.tools.phonopy_core`.""" + +from __future__ import annotations + +from langchain_core.tools import tool + +from chemgraph.schemas.phonopy_schema import PhonopyInputSchema +from chemgraph.tools.phonopy_core import run_phonopy_core +from chemgraph.schemas.calculators.mace_calc import _mace_lock + + +@tool +def run_phonopy(params: PhonopyInputSchema) -> dict: + """Run Phonopy calculations using specified input parameters. + + Parameters + ---------- + params : PhonopyInputSchema + Input parameters for the Phonopy calculation, including the + structure file, supercell configuration, and calculator. + + Returns + ------- + dict + Output containing calculation results, paths to plots, and status. + + Raises + ------ + ValueError + If the calculator is not supported or if the calculation fails. + """ + calc_type = params.calculator.calculator_type.lower() + if "mace" in calc_type: + with _mace_lock: + return run_phonopy_core(params) + return run_phonopy_core(params)