-
Notifications
You must be signed in to change notification settings - Fork 617
Feat(pt): Support Density fitting #5465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 5 commits
6c86368
ac44cf0
29ed198
6f6493d
b29b467
d8413fd
f970284
f642468
a781429
3448481
937b53a
d48beba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -406,11 +406,7 @@ def eval( | |
| coords, atom_types, len(atom_types.shape) > 1 | ||
| ) | ||
| request_defs = self._get_request_defs(atomic) | ||
| if "spin" not in kwargs or kwargs["spin"] is None: | ||
| out = self._eval_func(self._eval_model, numb_test, natoms)( | ||
| coords, cells, atom_types, fparam, aparam, request_defs, charge_spin | ||
| ) | ||
| else: | ||
| if "spin" in kwargs and kwargs["spin"] is not None: | ||
| out = self._eval_func(self._eval_model_spin, numb_test, natoms)( | ||
| coords, | ||
| cells, | ||
|
|
@@ -421,6 +417,21 @@ def eval( | |
| request_defs, | ||
| charge_spin, | ||
| ) | ||
| elif "grid" in kwargs and kwargs["grid"] is not None: | ||
| out = self._eval_func(self._eval_model_density, numb_test, natoms)( | ||
| coords, | ||
| cells, | ||
| atom_types, | ||
| np.array(kwargs["grid"]), | ||
| fparam, | ||
| aparam, | ||
| request_defs, | ||
| ) | ||
| return {"density": out} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Return ndarray instead of tuple for
Suggested fix- return {"density": out}
+ return {"density": out[0]}Also applies to: 769-775 🤖 Prompt for AI Agents |
||
| else: | ||
| out = self._eval_func(self._eval_model, numb_test, natoms)( | ||
| coords, cells, atom_types, fparam, aparam, request_defs | ||
| ) | ||
|
Comment on lines
+432
to
+434
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pass
Suggested fix- out = self._eval_func(self._eval_model, numb_test, natoms)(
- coords, cells, atom_types, fparam, aparam, request_defs
- )
+ out = self._eval_func(self._eval_model, numb_test, natoms)(
+ coords, cells, atom_types, fparam, aparam, request_defs, charge_spin
+ )Also applies to: 527-536 🤖 Prompt for AI Agents |
||
| return dict( | ||
| zip( | ||
| [x.name for x in request_defs], | ||
|
|
@@ -688,6 +699,80 @@ def _eval_model_spin( | |
| ) # this is kinda hacky | ||
| return tuple(results) | ||
|
|
||
| def _eval_model_density( | ||
| self, | ||
| coords: np.ndarray, | ||
| cells: np.ndarray | None, | ||
| atom_types: np.ndarray, | ||
| grid: np.ndarray, | ||
| fparam: np.ndarray | None, | ||
| aparam: np.ndarray | None, | ||
| request_defs: list[OutputVariableDef], | ||
| ): | ||
| model = self.dp.to(DEVICE) | ||
|
|
||
| nframes = coords.shape[0] | ||
| if len(atom_types.shape) == 1: | ||
| natoms = len(atom_types) | ||
| atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) | ||
| else: | ||
| natoms = len(atom_types[0]) | ||
|
|
||
| coord_input = torch.tensor( | ||
| coords.reshape([nframes, natoms, 3]), | ||
| dtype=GLOBAL_PT_FLOAT_PRECISION, | ||
| device=DEVICE, | ||
| ) | ||
| type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE) | ||
| grid_input = torch.tensor( | ||
| grid.reshape([nframes, -1, 3]), | ||
| dtype=GLOBAL_PT_FLOAT_PRECISION, | ||
| device=DEVICE, | ||
| ) | ||
| ngrid = grid_input.shape[1] | ||
| if cells is not None: | ||
| box_input = torch.tensor( | ||
| cells.reshape([nframes, 3, 3]), | ||
| dtype=GLOBAL_PT_FLOAT_PRECISION, | ||
| device=DEVICE, | ||
| ) | ||
| else: | ||
| box_input = None | ||
| if fparam is not None: | ||
| fparam_input = to_torch_tensor( | ||
| fparam.reshape(nframes, self.get_dim_fparam()) | ||
| ) | ||
| else: | ||
| fparam_input = None | ||
| if aparam is not None: | ||
| aparam_input = to_torch_tensor( | ||
| aparam.reshape(nframes, natoms, self.get_dim_aparam()) | ||
| ) | ||
| else: | ||
| aparam_input = None | ||
|
|
||
| do_atomic_virial = any( | ||
| x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs | ||
| ) | ||
| batch_output = model( | ||
| coord_input, | ||
| type_input, | ||
| grid=grid_input, | ||
| box=box_input, | ||
| do_atomic_virial=do_atomic_virial, | ||
| fparam=fparam_input, | ||
| aparam=aparam_input, | ||
| ) | ||
| if isinstance(batch_output, tuple): | ||
| batch_output = batch_output[0] | ||
|
|
||
| results = [] | ||
| pt_name = "density" | ||
| density_shape = [nframes, ngrid] | ||
| out = batch_output[pt_name].reshape(density_shape).detach().cpu().numpy() | ||
| results.append(out) | ||
| return tuple(results) | ||
|
|
||
| def _get_output_shape( | ||
| self, odef: OutputVariableDef, nframes: int, natoms: int | ||
| ) -> list[int]: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,125 @@ | ||||||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||||||
| import torch | ||||||
|
|
||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||
| from deepmd.pt.loss.loss import ( | ||||||
| TaskLoss, | ||||||
| ) | ||||||
| from deepmd.pt.utils import ( | ||||||
| env, | ||||||
| ) | ||||||
| from deepmd.pt.utils.env import ( | ||||||
| GLOBAL_PT_FLOAT_PRECISION, | ||||||
| ) | ||||||
| from deepmd.utils.data import ( | ||||||
| DataRequirementItem, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| class GridDensityLoss(TaskLoss): | ||||||
| def __init__( | ||||||
| self, | ||||||
| starter_learning_rate=1.0, | ||||||
| start_pref_d=0.0, | ||||||
| limit_pref_d=0.0, | ||||||
| inference=False, | ||||||
| **kwargs, | ||||||
| ): | ||||||
| r"""Construct a layer to compute loss on grid density. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| starter_learning_rate : float | ||||||
| The learning rate at the start of the training. | ||||||
| start_pref_d : float | ||||||
| The prefactor of charge density loss at the start of the training. | ||||||
| limit_pref_d : float | ||||||
| The prefactor of charge density loss at the end of the training. | ||||||
| inference : bool | ||||||
| If true, it will output all losses found in output, ignoring the pre-factors. | ||||||
| **kwargs | ||||||
| Other keyword arguments. | ||||||
| """ | ||||||
| super().__init__() | ||||||
| self.starter_learning_rate = starter_learning_rate | ||||||
| self.has_d = (start_pref_d != 0.0 and limit_pref_d != 0.0) or inference | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix density-loss activation logic on Line 44. Using Suggested fix- self.has_d = (start_pref_d != 0.0 and limit_pref_d != 0.0) or inference
+ self.has_d = (start_pref_d != 0.0 or limit_pref_d != 0.0) or inference📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||
|
|
||||||
| self.start_pref_d = start_pref_d | ||||||
| self.limit_pref_d = limit_pref_d | ||||||
| self.inference = inference | ||||||
|
|
||||||
| def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): | ||||||
| """Return loss on energy and force. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| input_dict : dict[str, torch.Tensor] | ||||||
| Model inputs. | ||||||
| model : torch.nn.Module | ||||||
| Model to be used to output the predictions. | ||||||
| label : dict[str, torch.Tensor] | ||||||
| Labels. | ||||||
| natoms : int | ||||||
| The local atom number. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| model_pred: dict[str, torch.Tensor] | ||||||
| Model predictions. | ||||||
| loss: torch.Tensor | ||||||
| Loss for model to minimize. | ||||||
| more_loss: dict[str, torch.Tensor] | ||||||
| Other losses for display. | ||||||
| """ | ||||||
| model_pred = model(**input_dict) | ||||||
| coef = learning_rate / self.starter_learning_rate | ||||||
| pref_d = self.limit_pref_d + (self.start_pref_d - self.limit_pref_d) * coef | ||||||
|
|
||||||
| loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] | ||||||
| more_loss = {} | ||||||
| # more_loss['log_keys'] = [] # showed when validation on the fly | ||||||
| # more_loss['test_keys'] = [] # showed when doing dp test | ||||||
| atom_norm = 1.0 / natoms | ||||||
|
|
||||||
| if self.has_d and "density" in model_pred and "density" in label: | ||||||
| density_pred = model_pred["density"] | ||||||
| density_label = label["density"] | ||||||
| find_density = label.get("find_density", 0.0) | ||||||
| pref_d = pref_d * find_density | ||||||
| density_pred_reshape = density_pred.reshape(-1) | ||||||
| density_label_reshape = density_label.reshape(-1) | ||||||
| l2_density_loss = torch.square( | ||||||
| density_label_reshape - density_pred_reshape | ||||||
| ).mean() | ||||||
| rmse_d = l2_density_loss.sqrt() | ||||||
| more_loss["rmse_d"] = self.display_if_exist(rmse_d.detach(), find_density) | ||||||
| l1_density_loss = torch.abs( | ||||||
| density_label_reshape - density_pred_reshape | ||||||
| ).mean() | ||||||
| loss += (pref_d * l1_density_loss).to(GLOBAL_PT_FLOAT_PRECISION) | ||||||
| mae_d = l1_density_loss | ||||||
| more_loss["mae_d"] = self.display_if_exist(mae_d.detach(), find_density) | ||||||
| return model_pred, loss, more_loss | ||||||
|
|
||||||
| @property | ||||||
| def label_requirement(self) -> list[DataRequirementItem]: | ||||||
| """Return data label requirements needed for this loss calculation.""" | ||||||
| label_requirement = [] | ||||||
| label_requirement.append( | ||||||
| DataRequirementItem( | ||||||
| "grid", | ||||||
| ndof=3, | ||||||
| atomic=True, # the grid is defined for each atom, so it is atomic | ||||||
| must=True, | ||||||
| high_prec=True, | ||||||
| ) | ||||||
| ) | ||||||
| if self.has_d: | ||||||
| label_requirement.append( | ||||||
| DataRequirementItem( | ||||||
| "density", | ||||||
| ndof=1, | ||||||
| atomic=True, | ||||||
| must=False, | ||||||
| high_prec=True, | ||||||
| ) | ||||||
| ) | ||||||
| return label_requirement | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard against
grid=Nonebefore density extraction.if "grid" in kwargsalso matchesgrid=None, but the backend density branch only runs whengrid is not None; this can makeresults["density"]missing and raise at runtime. Use the same predicate here (kwargs.get("grid") is not None).Suggested fix
🤖 Prompt for AI Agents