Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,25 @@ def _make_dp_loader_set(
# LMDB path: single string → LmdbDataset
if isinstance(training_systems, str) and is_lmdb(training_systems):
auto_prob = training_dataset_params.get("auto_prob", None)
mixed_batch = training_dataset_params.get("mixed_batch", False)
train_data_single = LmdbDataset(
training_systems,
model_params_single["type_map"],
training_dataset_params["batch_size"],
mixed_batch=mixed_batch,
auto_prob_style=auto_prob,
)
if (
validation_systems is not None
and isinstance(validation_systems, str)
and is_lmdb(validation_systems)
):
val_mixed_batch = validation_dataset_params.get("mixed_batch", False)
validation_data_single = LmdbDataset(
validation_systems,
model_params_single["type_map"],
validation_dataset_params["batch_size"],
mixed_batch=val_mixed_batch,
)
elif validation_systems is not None:
validation_data_single = _make_dp_loader_set(
Expand Down
93 changes: 62 additions & 31 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,37 @@ def forward(
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms
# Normalization exponent controls loss scaling with system size:
# - norm_exp=2 (intensive_ener_virial=True): loss uses 1/N² scaling, making it independent of system size
# - norm_exp=1 (intensive_ener_virial=False, legacy): loss uses 1/N scaling, which varies with system size

# Detect mixed batch format
is_mixed_batch = "ptr" in input_dict and input_dict["ptr"] is not None

atom_norms = None
if is_mixed_batch:
ptr = input_dict["ptr"]
natoms_per_frame = ptr[1:] - ptr[:-1] # [nframes]
atom_norms = 1.0 / natoms_per_frame.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
atom_norm = None
else:
atom_norm = 1.0 / natoms
Comment thread
coderabbitai[bot] marked this conversation as resolved.
norm_exp = 2 if self.intensive_ener_virial else 1

def get_frame_norm(value: torch.Tensor) -> torch.Tensor:
    """Return the per-frame ``1 / natoms`` factors shaped to broadcast over ``value``.

    Only valid on the mixed-batch path, where the enclosing scope has set
    ``atom_norms`` to one factor per frame.

    Parameters
    ----------
    value : torch.Tensor
        Tensor whose device, dtype and rank the returned factors must match.

    Returns
    -------
    torch.Tensor
        ``atom_norms`` moved to ``value``'s device/dtype and viewed with
        ``value.dim() - 1`` trailing singleton dimensions.
    """
    assert atom_norms is not None
    # One leading (inferred) axis followed by singleton axes.
    trailing_ones = [1] * (value.dim() - 1)
    frame_scale = atom_norms.to(device=value.device, dtype=value.dtype)
    return frame_scale.view([-1, *trailing_ones])

def weighted_mean(value: torch.Tensor, power: int = 1) -> torch.Tensor:
    """Return the mean of ``value`` scaled by the atom normalization raised to ``power``.

    Parameters
    ----------
    value : torch.Tensor
        Per-frame quantities to average.
    power : int
        Exponent applied to the normalization factor.

    Returns
    -------
    torch.Tensor
        On the mixed-batch path (``atom_norms`` is set) each element is scaled
        by its own frame factor before averaging; otherwise the plain mean is
        multiplied by the shared ``atom_norm ** power``.
    """
    if atom_norms is not None:
        # Mixed batch: the scaling differs per frame, so apply it before the mean.
        per_frame_scale = get_frame_norm(value) ** power
        return (value * per_frame_scale).mean()
    assert atom_norm is not None
    return value.mean() * (atom_norm**power)

def normalized_rmse(diff: torch.Tensor) -> torch.Tensor:
    """Return the root-mean-square of ``diff`` with atom normalization applied.

    Parameters
    ----------
    diff : torch.Tensor
        Prediction-minus-label differences.

    Returns
    -------
    torch.Tensor
        On the mixed-batch path (``atom_norms`` is set) the differences are
        scaled per frame before squaring; otherwise the RMSE of the raw
        differences is multiplied by the shared ``atom_norm``.
    """
    if atom_norms is not None:
        # Mixed batch: normalize every frame by its own factor first.
        scaled_diff = diff * get_frame_norm(diff)
        return torch.mean(torch.square(scaled_diff)).sqrt()
    assert atom_norm is not None
    raw_rmse = torch.mean(torch.square(diff)).sqrt()
    return raw_rmse * atom_norm
if self.has_e and "energy" in model_pred and "energy" in label:
energy_pred = model_pred["energy"]
energy_label = label["energy"]
Expand All @@ -256,35 +282,37 @@ def forward(
energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1)
find_energy = label.get("find_energy", 0.0)
pref_e = pref_e * find_energy
diff_e = energy_pred - energy_label
if self.loss_func == "mse":
l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label))
square_ener_diff = torch.square(diff_e)
l2_ener_loss = torch.mean(square_ener_diff)
if not self.inference:
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
)
if not self.use_huber:
loss += atom_norm**norm_exp * (pref_e * l2_ener_loss)
loss += pref_e * weighted_mean(square_ener_diff, norm_exp)
else:
energy_norm = (
atom_norm if atom_norms is None else get_frame_norm(energy_pred)
)
l_huber_loss = custom_huber_loss(
atom_norm * energy_pred,
atom_norm * energy_label,
energy_norm * energy_pred,
energy_norm * energy_label,
delta=self._huber_delta_energy,
)
loss += pref_e * l_huber_loss
rmse_e = l2_ener_loss.sqrt() * atom_norm
rmse_e = normalized_rmse(diff_e)
more_loss["rmse_e"] = self.display_if_exist(
rmse_e.detach(), find_energy
)
# more_loss['log_keys'].append('rmse_e')
elif self.loss_func == "mae":
l1_ener_loss = F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="mean",
)
loss += atom_norm * (pref_e * l1_ener_loss)
abs_ener_diff = torch.abs(diff_e)
mae_e = weighted_mean(abs_ener_diff)
loss += pref_e * mae_e
more_loss["mae_e"] = self.display_if_exist(
l1_ener_loss.detach() * atom_norm,
mae_e.detach(),
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
Expand All @@ -293,9 +321,9 @@ def forward(
f"Loss type {self.loss_func} is not implemented for energy loss."
)
if mae:
mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
mae_e = weighted_mean(torch.abs(diff_e))
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
mae_e_all = torch.mean(torch.abs(diff_e))
more_loss["mae_e_all"] = self.display_if_exist(
mae_e_all.detach(), find_energy
)
Expand Down Expand Up @@ -439,41 +467,44 @@ def forward(
pref_v = pref_v * find_virial
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
if self.loss_func == "mse":
l2_virial_loss = torch.mean(torch.square(diff_v))
square_virial_diff = torch.square(diff_v)
l2_virial_loss = torch.mean(square_virial_diff)
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
if not self.use_huber:
loss += atom_norm**norm_exp * (pref_v * l2_virial_loss)
loss += pref_v * weighted_mean(square_virial_diff, norm_exp)
else:
virial = model_pred["virial"].reshape(-1, 9)
virial_label = label["virial"].reshape(-1, 9)
virial_norm = (
atom_norm if atom_norms is None else get_frame_norm(virial)
)
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
atom_norm * label["virial"].reshape(-1),
(virial_norm * virial).reshape(-1),
(virial_norm * virial_label).reshape(-1),
delta=self._huber_delta_virial,
)
loss += pref_v * l_huber_loss
rmse_v = l2_virial_loss.sqrt() * atom_norm
rmse_v = normalized_rmse(diff_v)
more_loss["rmse_v"] = self.display_if_exist(
rmse_v.detach(), find_virial
)
elif self.loss_func == "mae":
l1_virial_loss = F.l1_loss(
label["virial"].reshape(-1),
model_pred["virial"].reshape(-1),
reduction="mean",
)
loss += atom_norm * (pref_v * l1_virial_loss)
abs_virial_diff = torch.abs(diff_v)
mae_v = weighted_mean(abs_virial_diff)
loss += pref_v * mae_v
more_loss["mae_v"] = self.display_if_exist(
l1_virial_loss.detach() * atom_norm,
mae_v.detach(),
find_virial,
)
else:
raise NotImplementedError(
f"Loss type {self.loss_func} is not implemented for virial loss."
)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
mae_v = weighted_mean(torch.abs(diff_v))
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)

if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
Expand Down
145 changes: 145 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,151 @@ def forward_atomic(
)
return fit_ret

def forward_common_atomic_flat(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_batch: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
batch: torch.Tensor,
ptr: torch.Tensor,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
extended_ptr: torch.Tensor | None = None,
central_ext_index: torch.Tensor | None = None,
nlist_ext: torch.Tensor | None = None,
a_nlist: torch.Tensor | None = None,
a_nlist_ext: torch.Tensor | None = None,
nlist_mask: torch.Tensor | None = None,
a_nlist_mask: torch.Tensor | None = None,
edge_index: torch.Tensor | None = None,
angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Forward pass with flat batch format.

Parameters
----------
extended_coord : torch.Tensor
Extended coordinates [total_extended_atoms, 3].
extended_atype : torch.Tensor
Extended atom types [total_extended_atoms].
extended_batch : torch.Tensor
Frame assignment for extended atoms [total_extended_atoms].
nlist : torch.Tensor
Neighbor list [total_atoms, nnei].
mapping : torch.Tensor
Extended atom -> local flat index mapping [total_extended_atoms].
batch : torch.Tensor
Frame assignment for local atoms [total_atoms].
ptr : torch.Tensor
Frame boundaries [nframes + 1].
fparam : torch.Tensor | None
Frame parameters [nframes, ndf].
aparam : torch.Tensor | None
Atomic parameters [total_atoms, nda].
central_ext_index : torch.Tensor | None
Extended-atom indices corresponding to local atoms.
nlist_ext, a_nlist_ext : torch.Tensor | None
Edge and angle neighbor lists indexing concatenated extended atoms.
nlist_mask, a_nlist_mask : torch.Tensor | None
Valid-neighbor masks for flat edge and angle neighbor lists.
edge_index, angle_index : torch.Tensor | None
Dynamic graph indices produced by the flat graph preprocessor.

Returns
-------
result_dict : dict[str, torch.Tensor]
Model predictions in flat format.
"""
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)

if (
hasattr(self.fitting_net, "get_dim_fparam")
and self.fitting_net.get_dim_fparam() > 0
and fparam is None
):
default_fparam_tensor = self.fitting_net.get_default_fparam()
assert default_fparam_tensor is not None
fparam_input_for_des = torch.tile(
default_fparam_tensor.to(device=extended_coord.device).unsqueeze(0),
[ptr.numel() - 1, 1],
)
else:
fparam_input_for_des = fparam

# Descriptor and fitting both consume the flat atom layout.
descriptor_out = self.descriptor.forward_flat(
extended_coord,
extended_atype,
extended_batch,
nlist,
mapping,
batch,
ptr,
fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
central_ext_index=central_ext_index,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

descriptor = descriptor_out.get("descriptor")
rot_mat = descriptor_out.get("rot_mat")
g2 = descriptor_out.get("g2")
h2 = descriptor_out.get("h2")

if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor.detach())

if central_ext_index is None:
from deepmd.pt.utils.nlist import get_central_ext_index

central_ext_index = get_central_ext_index(extended_batch, ptr)
atype = extended_atype[central_ext_index]
else:
atype = extended_atype[central_ext_index]

fit_ret = self.fitting_net.forward_flat(
descriptor,
atype,
batch,
ptr,
gr=rot_mat,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
fit_ret = self.apply_out_stat(fit_ret, atype)

atom_mask = self.make_atom_mask(atype).to(torch.int32)
if self.atom_excl is not None:
atom_mask *= self.atom_excl(atype.unsqueeze(0)).squeeze(0)

for kk in fit_ret.keys():
out_shape = fit_ret[kk].shape
out_shape2 = 1
for ss in out_shape[1:]:
out_shape2 *= ss
fit_ret[kk] = (
fit_ret[kk].reshape([out_shape[0], out_shape2]) * atom_mask[:, None]
).view(out_shape)
fit_ret["mask"] = atom_mask

if self.enable_eval_fitting_last_layer_hook:
if "middle_output" in fit_ret:
self.eval_fitting_last_layer_list.append(
fit_ret.pop("middle_output").detach()
)

return fit_ret

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict]],
Expand Down
Loading