Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 89 additions & 7 deletions source/api_cc/src/DeepSpinPTExpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,43 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
int nloc = nall_real - nghost_real;
int nframes = 1;

// Build spin tensor for real atoms using bkw_map
std::vector<VALUETYPE> dspin(static_cast<size_t>(nall_real) * 3);
for (int ii = 0; ii < nall_real; ++ii) {
// Phantom-atom padding for the empty-subdomain corner case
// (``nloc_real == 0``). Multi-rank spin MD can land a rank with zero
// real local atoms when atoms migrate to other subdomains. The
// with-comm AOTI artifact, traced with ``nloc_min=1`` and lowered by
// inductor with an even stricter ``nloc >= 2`` runtime-check
// (silently bypassed because ``AOTI_RUNTIME_CHECK_INPUTS`` is unset by
// default), then SIGFPEs at runtime with an "integer divide by zero"
// inside inductor-generated shape arithmetic that uses ``nloc`` as a
// divisor. The failure is intermittent because inductor re-codegens
// across runs and only some compiles emit the offending divide.
//
// Fix: prepend two phantom atoms with no neighbours so the AOTI graph
// runs with ``nloc == 2``. The phantoms have an empty nlist row and
// therefore contribute zero atomic energy / force / virial, preserving
// the physically-correct "this rank has no real atoms" semantics.
// ``nlocal`` in the comm tensors is set to ``2`` so border_op writes
// received ghost features past the phantom slots; outputs are stripped
// of the phantom prefix before being scattered back to LAMMPS atoms
// via ``select_map``.
const int phantom_n = (nloc_real == 0 && nall_real > 0) ? 2 : 0;
Comment thread
wanghan-iapcm marked this conversation as resolved.
Comment thread
wanghan-iapcm marked this conversation as resolved.
if (phantom_n > 0) {
dcoord.insert(dcoord.begin(), static_cast<size_t>(phantom_n) * 3,
static_cast<VALUETYPE>(0));
datype.insert(datype.begin(), static_cast<size_t>(phantom_n), 0);
nall_real += phantom_n;
nloc_real = phantom_n;
nloc = nall_real - nghost_real;
}

// Build spin tensor for real atoms using bkw_map (skip phantom prefix
// which keeps zero spin).
std::vector<VALUETYPE> dspin(static_cast<size_t>(nall_real) * 3,
static_cast<VALUETYPE>(0));
for (int ii = phantom_n; ii < nall_real; ++ii) {
for (int dd = 0; dd < 3; ++dd) {
dspin[static_cast<size_t>(ii) * 3 + dd] =
spin[static_cast<size_t>(bkw_map[ii]) * 3 + dd];
spin[static_cast<size_t>(bkw_map[ii - phantom_n]) * 3 + dd];
}
}

Expand Down Expand Up @@ -445,11 +476,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();

// Rebuild mapping tensor
// Rebuild mapping tensor. Phantom slots (when phantom_n > 0) get
// identity entries — they index into their own row and never appear
// in any other atom's nlist (their nlist rows are all -1 below).
if (lmp_list.mapping) {
std::vector<std::int64_t> mapping(nall_real);
for (int ii = 0; ii < nall_real; ii++) {
mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]];
for (int ii = 0; ii < phantom_n; ii++) {
mapping[ii] = ii;
}
for (int ii = phantom_n; ii < nall_real; ii++) {
mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii - phantom_n]]];
}
Comment thread
wanghan-iapcm marked this conversation as resolved.
mapping_tensor =
torch::from_blob(mapping.data(), {1, nall_real}, int_option)
Expand All @@ -472,8 +508,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
}

// Flatten raw nlist — the .pt2 model sorts by distance on-device.
// Phantom rows (all -1) are prepended below so the AOTI graph sees
// nloc == phantom_n + nloc_real_orig instead of 0.
firstneigh_tensor =
createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device);
if (phantom_n > 0) {
auto phantom_rows = torch::full(
{1, phantom_n, nnei}, static_cast<std::int64_t>(-1),
torch::TensorOptions().dtype(torch::kInt64).device(device));
firstneigh_tensor = torch::cat({phantom_rows, firstneigh_tensor}, 1);
}
}

// Build fparam/aparam tensors
Expand Down Expand Up @@ -566,6 +610,23 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
ener.assign(flat_energy_.data_ptr<ENERGYTYPE>(),
flat_energy_.data_ptr<ENERGYTYPE>() + flat_energy_.numel());

// Zero the reduced energy on an empty rank. Phantoms have constant
// atomic outputs (per-type bias + zero-neighbour MLP) that flow into
// ``energy_redu`` -- and on the spin path the SpinModel doubles atoms
// so the bias contribution appears for both real and spin phantom
// halves; subtracting only the real-half exposed by
// ``output_map["energy"]`` after the ``[:, :nloc]`` slice leaves the
// spin-half leaking into the MPI-reduced LAMMPS total. The physical
// contribution of a rank with no real local atoms is zero by
// definition, so just clear ``ener`` directly.
//
// Forces, force_mag, and virial are unaffected because phantom atomic
// outputs are coord-independent (no neighbours) so their derivatives
// are zero -- no analogous correction is needed.
if (phantom_n > 0) {
std::fill(ener.begin(), ener.end(), static_cast<ENERGYTYPE>(0));
}

// Extract force: energy_derv_r (nf, nall, 1, 3) -> (nf, nall, 3)
torch::Tensor force_tensor =
output_map["energy_derv_r"].squeeze(-2).view({-1}).to(floatType);
Expand All @@ -588,6 +649,17 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

// Strip the phantom prefix (see phantom-atom padding comment near
// ``select_real_atoms_coord``) so the ``bkw_map`` lookup below sees
// only the real / ghost atoms it was built for. The phantom slots
// carry zero forces because their nlist rows were all -1 — they
// produce no neighbour contributions, so dropping them is exact.
if (phantom_n > 0) {
dforce.erase(dforce.begin(), dforce.begin() + phantom_n * 3);
dforce_mag.erase(dforce_mag.begin(), dforce_mag.begin() + phantom_n * 3);
nall_real -= phantom_n;
}

// bkw map: map force from real atoms back to full atom list
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
force_mag.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
Expand All @@ -612,6 +684,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
cpu_atom_virial_.data_ptr<VALUETYPE>(),
cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());

// Strip the phantom prefix from atomic outputs as well (see force
// block above). Phantom slots carry zero atomic energy / virial
// because their nlist rows were all -1.
if (phantom_n > 0) {
datom_energy.erase(datom_energy.begin(),
datom_energy.begin() + phantom_n);
datom_virial.erase(datom_virial.begin(),
datom_virial.begin() + phantom_n * 9);
}

atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
select_map<VALUETYPE>(atom_energy, datom_energy, bkw_map, 1, nframes,
Expand Down
Loading