Skip to content
79 changes: 72 additions & 7 deletions source/api_cc/src/DeepSpinPTExpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,43 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
int nloc = nall_real - nghost_real;
int nframes = 1;

// Build spin tensor for real atoms using bkw_map
std::vector<VALUETYPE> dspin(static_cast<size_t>(nall_real) * 3);
for (int ii = 0; ii < nall_real; ++ii) {
// Phantom-atom padding for the empty-subdomain corner case
// (``nloc_real == 0``). In multi-rank spin MD a rank can end up with
// zero real local atoms when atoms migrate to other subdomains. The
// with-comm AOTI artifact is traced with ``nloc_min=1``, and inductor
// lowers it with an even stricter ``nloc >= 2`` runtime check. That
// check is silently bypassed because ``AOTI_RUNTIME_CHECK_INPUTS`` is
// unset by default, so with ``nloc == 0`` the artifact SIGFPEs at
// runtime with an "integer divide by zero" inside inductor-generated
// shape arithmetic that uses ``nloc`` as a divisor. The failure is
// intermittent because inductor re-generates code across runs and only
// some compiles emit the offending divide.
//
// Fix: prepend two phantom atoms with no neighbours so the AOTI graph
// runs with ``nloc == 2``. The phantoms have an empty nlist row and
// therefore contribute zero atomic energy / force / virial, preserving
// the physically-correct "this rank has no real atoms" semantics.
// ``nlocal`` in the comm tensors is set to ``2`` so border_op writes
// received ghost features past the phantom slots; outputs are stripped
// of the phantom prefix before being scattered back to LAMMPS atoms
// via ``select_map``.
const int phantom_n = (nloc_real == 0 && nall_real > 0) ? 2 : 0;
Comment thread
wanghan-iapcm marked this conversation as resolved.
Comment thread
wanghan-iapcm marked this conversation as resolved.
if (phantom_n > 0) {
dcoord.insert(dcoord.begin(), static_cast<size_t>(phantom_n) * 3,
static_cast<VALUETYPE>(0));
datype.insert(datype.begin(), static_cast<size_t>(phantom_n), 0);
nall_real += phantom_n;
nloc_real = phantom_n;
nloc = nall_real - nghost_real;
}

// Build spin tensor for real atoms using bkw_map (skip phantom prefix
// which keeps zero spin).
std::vector<VALUETYPE> dspin(static_cast<size_t>(nall_real) * 3,
static_cast<VALUETYPE>(0));
for (int ii = phantom_n; ii < nall_real; ++ii) {
for (int dd = 0; dd < 3; ++dd) {
dspin[static_cast<size_t>(ii) * 3 + dd] =
spin[static_cast<size_t>(bkw_map[ii]) * 3 + dd];
spin[static_cast<size_t>(bkw_map[ii - phantom_n]) * 3 + dd];
}
}

Expand Down Expand Up @@ -445,11 +476,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();

// Rebuild mapping tensor
// Rebuild mapping tensor. Phantom slots (when phantom_n > 0) get
// identity entries — they index into their own row and never appear
// in any other atom's nlist (their nlist rows are all -1 below).
if (lmp_list.mapping) {
std::vector<std::int64_t> mapping(nall_real);
for (int ii = 0; ii < nall_real; ii++) {
mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]];
for (int ii = 0; ii < phantom_n; ii++) {
mapping[ii] = ii;
}
for (int ii = phantom_n; ii < nall_real; ii++) {
mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii - phantom_n]]];
}
Comment thread
wanghan-iapcm marked this conversation as resolved.
mapping_tensor =
torch::from_blob(mapping.data(), {1, nall_real}, int_option)
Expand All @@ -472,8 +508,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
}

// Flatten raw nlist — the .pt2 model sorts by distance on-device.
// When phantom_n > 0 (only possible when the rank had zero real local
// atoms), phantom rows (all -1) are prepended below so the AOTI graph
// sees nloc == phantom_n instead of 0.
firstneigh_tensor =
createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device);
if (phantom_n > 0) {
auto phantom_rows = torch::full(
{1, phantom_n, nnei}, static_cast<std::int64_t>(-1),
torch::TensorOptions().dtype(torch::kInt64).device(device));
firstneigh_tensor = torch::cat({phantom_rows, firstneigh_tensor}, 1);
}
}

// Build fparam/aparam tensors
Expand Down Expand Up @@ -588,6 +632,17 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());

// Strip the phantom prefix (see phantom-atom padding comment near
// ``select_real_atoms_coord``) so the ``bkw_map`` lookup below sees
// only the real / ghost atoms it was built for. The phantom slots
// carry zero forces because their nlist rows were all -1 — they
// produce no neighbour contributions, so dropping them is exact.
if (phantom_n > 0) {
dforce.erase(dforce.begin(), dforce.begin() + phantom_n * 3);
dforce_mag.erase(dforce_mag.begin(), dforce_mag.begin() + phantom_n * 3);
nall_real -= phantom_n;
}

// bkw map: map force from real atoms back to full atom list
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
force_mag.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
Expand All @@ -612,6 +667,16 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
cpu_atom_virial_.data_ptr<VALUETYPE>(),
cpu_atom_virial_.data_ptr<VALUETYPE>() + cpu_atom_virial_.numel());

// Strip the phantom prefix from atomic outputs as well (see force
// block above). Phantom slots carry zero atomic energy / virial
// because their nlist rows were all -1.
if (phantom_n > 0) {
datom_energy.erase(datom_energy.begin(),
datom_energy.begin() + phantom_n);
datom_virial.erase(datom_virial.begin(),
datom_virial.begin() + phantom_n * 9);
}

atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
select_map<VALUETYPE>(atom_energy, datom_energy, bkw_map, 1, nframes,
Expand Down
Loading