Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,9 @@ def forward_common_atomic(
if self.atom_excl is not None:
atom_mask *= self.atom_excl(atype)

for kk in ret_dict.keys():
for kk in list(ret_dict.keys()):
if kk.startswith("_"):
continue
out_shape = ret_dict[kk].shape
out_shape2 = 1
for ss in out_shape[2:]:
Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def forward_atomic(
"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
Comment thread
wanghan-iapcm marked this conversation as resolved.
extended_coord = extended_coord.detach().clone().requires_grad_(True)

# Handle default chg_spin if descriptor supports it
if self.add_chg_spin_ebd and charge_spin is None:
Expand Down Expand Up @@ -302,6 +302,8 @@ def forward_atomic(
fparam=fparam,
aparam=aparam,
)
if self.do_grad_r() or self.do_grad_c():
fit_ret["_force_coord"] = extended_coord
if self.enable_eval_fitting_last_layer_hook:
assert "middle_output" in fit_ret, (
"eval_fitting_last_layer not supported for this fitting net!"
Expand Down
6 changes: 4 additions & 2 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def forward_atomic(
the result dict, defined by the fitting net output def.
"""
nframes, nloc, nnei = nlist.shape
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
extended_coord = extended_coord.detach().clone().requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
Expand Down Expand Up @@ -304,6 +304,8 @@ def forward_atomic(
dim=0,
),
} # (nframes, nloc, 1)
if self.do_grad_r() or self.do_grad_c():
fit_ret["_force_coord"] = extended_coord
return fit_ret

def apply_out_stat(
Expand Down
7 changes: 5 additions & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ def forward_atomic(
) -> dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.view(nframes, -1, 3)
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
Comment thread
wanghan-iapcm marked this conversation as resolved.
extended_coord = extended_coord.detach().clone().requires_grad_(True)

# this will mask all -1 in the nlist
mask = nlist >= 0
Expand Down Expand Up @@ -313,6 +313,9 @@ def forward_atomic(
dim=-1,
).unsqueeze(-1)

if self.do_grad_r() or self.do_grad_c():
atomic_energy = atomic_energy + 0.0 * extended_coord.sum()[..., None, None]
return {"energy": atomic_energy, "_force_coord": extended_coord}
return {"energy": atomic_energy}

def _pair_tabulated_inter(
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def forward_common_lower(
cc_ext, _, fp, ap, input_prec = self._input_type_cast(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
del fparam, aparam
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext,
extended_atype,
Expand All @@ -316,10 +316,11 @@ def forward_common_lower(
comm_dict=comm_dict,
charge_spin=charge_spin,
)
force_coord = atomic_ret.pop("_force_coord", cc_ext)
model_predict = fit_output_to_model_output(
atomic_ret,
self.atomic_output_def(),
cc_ext,
force_coord,
do_atomic_virial=do_atomic_virial,
create_graph=self.training,
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
Expand Down
28 changes: 28 additions & 0 deletions source/tests/pt/model/test_dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,34 @@ def test_self_consistency(self) -> None:
to_numpy_array(ret1["energy"]),
)

def test_forward_common_atomic_accepts_leaf_view_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
md0 = DPAtomicModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE)

coord = to_torch_tensor(self.coord_ext).requires_grad_(True)
coord_view = coord.view(self.nf, self.nall, 3)
args = [
coord_view,
to_torch_tensor(self.atype_ext),
to_torch_tensor(self.nlist),
]
atomic_ret = md0.forward_atomic(*args)
ret = md0.forward_common_atomic(*args)

self.assertIn("_force_coord", atomic_ret)
self.assertIn("energy", ret)
Comment thread
wanghan-iapcm marked this conversation as resolved.

def test_dp_consistency(self) -> None:
nf, nloc, nnei = self.nlist.shape
ds = DPDescrptSeA(
Expand Down
35 changes: 35 additions & 0 deletions source/tests/pt/model/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,41 @@ def test_self_consistency(self) -> None:
atol=self.atol,
)

def test_forward_lower_accepts_leaf_view_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = EnergyFittingNet(
self.nt,
ds.get_dim_out(),
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE)

coord_ext, atype_ext, _ = extend_coord_with_ghosts(
to_torch_tensor(self.coord),
to_torch_tensor(self.atype),
to_torch_tensor(self.cell),
self.rcut,
)
nlist = build_neighbor_list(
coord_ext,
atype_ext,
self.nloc,
self.rcut,
self.sel,
distinguish_types=(not md0.mixed_types()),
)
coord_view = coord_ext.requires_grad_(True).view(self.nf, -1, 3)

ret = md0.forward_lower(coord_view, atype_ext, nlist, do_atomic_virial=True)

self.assertIn("extended_force", ret)
self.assertIn("virial", ret)

def test_dp_consistency(self) -> None:
nf, nloc = self.atype.shape
nfp, nap = 2, 3
Expand Down
Loading