From 4602960eb8c78259ca38a3f5c6d5610b3f77921a Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:02:19 +0800 Subject: [PATCH 1/8] fix: try fix dpa4 compile --- deepmd/pt/model/model/sezm_model.py | 499 +++++++++++++++++++++++----- 1 file changed, 425 insertions(+), 74 deletions(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 1a35196de0..3474e0aabd 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -87,12 +87,11 @@ function: * ``core_compute`` rebuilds a compact, GPU-friendly edge list from the - padded DeePMD neighbor list (``build_edge_list_from_nlist``), with - masked dummy edges appended so the edge tensor has a non-singular - symbolic lower bound (NOTE 10). Edge vectors come from - ``index_select`` on the extended coordinate tensor, which keeps the - gradient path back to coordinates explicit and safe under symbolic - shapes (NOTE 11). + padded DeePMD neighbor list (``build_edge_list_from_nlist``), with a + single masked dummy edge appended so the edge tensor is never empty + (NOTE 10). Edge vectors come from ``index_select`` on the extended + coordinate tensor, which keeps the gradient path back to coordinates + explicit and safe under symbolic shapes (NOTE 11). * The SeZM descriptor consumes the edge list and produces per-atom features. * The fitting network predicts per-atom energy; ``apply_out_stat`` adds @@ -323,20 +322,17 @@ In eval mode we merely detach; no ``create_graph`` is requested, so the compiled kernel never has to build a backward graph. -NOTE 10 -- Tail dummy edges ---------------------------- +NOTE 10 -- Tail dummy edge +-------------------------- -``build_edge_list_from_nlist`` appends two masked edges at the end of -every batch. Real edge compaction happens via +``build_edge_list_from_nlist`` appends exactly one masked edge at the +end of every batch. Real edge compaction happens via ``torch.nonzero(valid_mask)``, whose output length is data-dependent and can be zero in sparse or single-type systems. make_fx cannot trace an "if n_edges == 0: skip" branch symbolically; without the dummy it would fall back to concrete shape specialization and break -``dynamic=True``. A pair of dummy slots also gives Inductor's batched -matmul lowering a static ``E >= 2`` edge-axis bound, avoiding -data-dependent layout guards on ``E == 1``. Each dummy's ``edge_mask`` -is ``False`` so it contributes exactly zero to every downstream sum or -gather. +``dynamic=True``. The dummy's ``edge_mask`` is ``False`` so it +contributes exactly zero to every downstream sum or gather. NOTE 11 -- ``index_select`` for coordinate gradients ---------------------------------------------------- @@ -378,9 +374,6 @@ from einops import ( rearrange, ) -from packaging.version import ( - Version, -) from torch.fx.experimental.proxy_tensor import ( make_fx, ) @@ -447,6 +440,121 @@ _dynamo_cfg.optimize_ddp = False +# --------------------------------------------------------------------------- +# Multi-task compile sharing +# --------------------------------------------------------------------------- +# Maps (structure_key..., training, do_atomic_virial, has_coord_corr) to the +# compiled callable. Tasks whose descriptor AND fitting-net first child have +# the same Python-object identity (after share_params) reuse a single compiled +# graph, avoiding N×compile-cache OOM and N DDP graph boundaries (NCCL timeout). +_SEZM_COMPILE_CACHE: dict[tuple, Any] = {} + +# Maps structure_key -> task_buf_order so every instance in the same group +# knows which buffers were promoted and in what order. +_SEZM_TASK_BUF_ORDER: dict[tuple[int, ...], tuple[str, ...]] = {} + +# Prefix namespace for promoted buffer names. +_AM_PREFIX = "am/" # atomic_model registered buffer +_FIT_PREFIX = "fit/" # fitting_net registered buffer +_FIT_ATTR_PREFIX = "fit_attr/" # fitting_net plain tensor attribute (not in _buffers) + + +def _sezm_structure_key(model: "SeZMModel") -> tuple[int, ...]: + """Return a key that is equal iff two SeZMModel instances can share a compiled graph. + + After ``share_params``, the descriptor and fitting-net module objects + themselves remain *different* Python objects per task; only their + *submodules* (``_modules`` dict entries) are replaced with shared + references. Using ``id(descriptor)`` or ``id(fitting_net)`` would + therefore always differ between tasks and defeat the cache. + + Fix: use the id of the *first named child* of each module. After + ``share_params(level=0)``, those children are the same Python objects + for all tasks in the same structure group, giving matching keys. + + NOTE: only the FIRST child is sampled, assuming "first child shared => + whole module shared" (true for level=0). Under ``share_params(level=1)`` + only ``type_embedding`` is shared; if it is the first child, two tasks + whose other descriptor weights differ would collapse to the same key and + wrongly reuse one compiled graph. If level=1 + compile is ever used, key + on all param ids instead, e.g. ``frozenset(id(p) for p in desc.parameters())``. + """ + try: + desc = model.atomic_model.descriptor + desc_id = 0 + for _, child in desc.named_children(): + desc_id = id(child) + break + if desc_id == 0: + # Descriptor has no named children (unlikely); fall back. + desc_id = id(desc) + except AttributeError: + desc_id = 0 + try: + fitting = model.atomic_model.fitting_net + for _, child in fitting.named_children(): + return (desc_id, id(child)) + return (desc_id, id(fitting)) + except AttributeError: + return (desc_id, id(model)) + + +def _get_sezm_task_buf_names(model: "SeZMModel") -> tuple[str, ...]: + """Return the ordered names of per-task buffers to promote as FX placeholders. + + Always promotes: + * ``out_bias``, ``out_std`` on ``atomic_model`` — may be replaced + out-of-place by ``model_change_out_bias``, so the compiled graph must + never bake them as constants. + * ``bias_atom_e`` on the fitting net — task-specific per-type bias that + differs across tasks after ``share_params``. + * ``case_embd`` on the fitting net — task-identity vector used for + multi-task case conditioning; stored as a plain tensor attribute. + """ + names: list[str] = [] + try: + am = model.atomic_model + for bname in ("out_bias", "out_std"): + if am._buffers.get(bname) is not None: + names.append(_AM_PREFIX + bname) + try: + fitting = am.fitting_net + for bname in ("bias_atom_e",): + if fitting._buffers.get(bname) is not None: + names.append(_FIT_PREFIX + bname) + for aname in ("case_embd",): + val = getattr(fitting, aname, None) + if val is not None and torch.is_tensor(val): + names.append(_FIT_ATTR_PREFIX + aname) + except AttributeError: + pass + except AttributeError: + pass + return tuple(names) + + +def _get_sezm_task_buf_vals( + model: "SeZMModel", + names: tuple[str, ...], +) -> tuple[torch.Tensor, ...]: + """Return the current tensor values for the given promoted-buffer names.""" + if not names: + return () + am = model.atomic_model + try: + fitting = am.fitting_net + except AttributeError: + fitting = None + vals: list[torch.Tensor] = [] + for name in names: + if name.startswith(_AM_PREFIX): + vals.append(am._buffers[name[len(_AM_PREFIX) :]]) + elif name.startswith(_FIT_PREFIX): + vals.append(fitting._buffers[name[len(_FIT_PREFIX) :]]) # type: ignore[union-attr] + elif name.startswith(_FIT_ATTR_PREFIX): + vals.append(getattr(fitting, name[len(_FIT_ATTR_PREFIX) :])) + return tuple(vals) + def _parse_optional_env_bool(var_name: str) -> bool | None: """ @@ -624,6 +732,9 @@ def __init__( # compile products instead of evicting the other mode. object.__setattr__(self, "compiled_core_compute_cache", {}) object.__setattr__(self, "compiled_dens_compute", None) + # Maps cache_key -> task_buf_order for this instance so forward() + # knows which buffers to pass and in what order. + object.__setattr__(self, "_task_buf_order_cache", {}) # Training follows `use_compile`. Evaluation/inference reads # `DP_COMPILE_INFER` at init time and falls back to eager when unset. self._env_use_compile_infer: bool | None = _parse_optional_env_bool( @@ -1010,6 +1121,14 @@ def forward_common_after_nlist( extended_coord_corr=extended_coord_corr, ) compiled_core_compute = self.compiled_core_compute_cache[cache_key] + # Read current values of per-task buffers (optimizer steps + # update them in-place; out-of-place replacements from + # model_change_out_bias are captured because we read fresh + # each call rather than caching the values at compile time). + _task_buf_vals = _get_sezm_task_buf_vals( + self, + getattr(self, "_task_buf_order_cache", {}).get(cache_key, ()), + ) with nvtx_range("SeZM/core_compute"): if extended_coord_corr is None: model_predict_lower = compiled_core_compute( @@ -1020,6 +1139,7 @@ def forward_common_after_nlist( fp, ap, charge_spin, + *_task_buf_vals, ) else: model_predict_lower = compiled_core_compute( @@ -1031,6 +1151,7 @@ def forward_common_after_nlist( ap, charge_spin, extended_coord_corr, + *_task_buf_vals, ) if ( self._core_compute_pending_compile_t0 is not None @@ -1524,6 +1645,31 @@ def trace_and_compile( mode = "train" if self.training else "eval" has_coord_corr = extended_coord_corr is not None + _compile_t0 = time.perf_counter() + + # --- Check module-level shared cache first --- + # Tasks sharing the same descriptor+fitting structure (after share_params) + # should share one compiled graph. If a sibling task already compiled, + # populate this instance's per-instance caches and return immediately. + structure_key = _sezm_structure_key(self) + cache_key = (bool(self.training), bool(do_atomic_virial), has_coord_corr) + full_cache_key = structure_key + cache_key + if full_cache_key in _SEZM_COMPILE_CACHE: + self.compiled_core_compute_cache[cache_key] = _SEZM_COMPILE_CACHE[ + full_cache_key + ] + self._task_buf_order_cache[cache_key] = _SEZM_TASK_BUF_ORDER.get( + structure_key, () + ) + log.info( + "SeZM: reusing shared compiled graph " + "(mode=%s, atomic_virial=%s, coord_corr=%s)", + mode, + do_atomic_virial, + has_coord_corr, + ) + return + log.info( "SeZM: start tracing and compiling " "(mode=%s, atomic_virial=%s, coord_corr=%s)", @@ -1531,7 +1677,71 @@ def trace_and_compile( do_atomic_virial, has_coord_corr, ) - _compile_t0 = time.perf_counter() + + # --- Detect per-task buffers to promote as FX placeholders --- + # These buffers differ across tasks in the same structure group (they are + # NOT shared by share_params) or may be replaced out-of-place after + # compilation. Passing them as explicit arguments makes the compiled + # graph reusable across all tasks in the group. + task_buf_names = _get_sezm_task_buf_names(self) + task_buf_vals_trace = _get_sezm_task_buf_vals(self, task_buf_names) + + # Resolve module references once for the buffer-patching closures. + _am_patch = self.atomic_model + try: + _fitting_patch: torch.nn.Module | None = _am_patch.fitting_net + except AttributeError: + _fitting_patch = None + + def _patch_task_bufs( + vals: tuple[torch.Tensor, ...], + ) -> dict[str, torch.Tensor | None]: + """Temporarily replace model buffers/attrs with FX proxy tensors. + + Executed at trace time inside compute_fn. make_fx records the + proxy tensors as placeholder nodes, so the compiled graph reads them + as live inputs rather than baked-in constants. The ``finally`` + block in compute_fn always calls ``_restore_task_bufs`` to leave + the model in its original state after tracing. + """ + saved: dict[str, torch.Tensor | None] = {} + for name, val in zip(task_buf_names, vals): + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX) :] + saved[name] = _am_patch._buffers.get(actual) + _am_patch._buffers[actual] = val + elif name.startswith(_FIT_PREFIX): + actual = name[len(_FIT_PREFIX) :] + saved[name] = ( + _fitting_patch._buffers.get(actual) + if _fitting_patch is not None + else None + ) + if _fitting_patch is not None: + _fitting_patch._buffers[actual] = val + elif name.startswith(_FIT_ATTR_PREFIX): + actual = name[len(_FIT_ATTR_PREFIX) :] + saved[name] = getattr(_fitting_patch, actual, None) + if _fitting_patch is not None: + setattr(_fitting_patch, actual, val) + return saved + + def _restore_task_bufs( + saved: dict[str, torch.Tensor | None], + ) -> None: + """Restore original model buffers/attrs after tracing.""" + for name, orig in saved.items(): + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX) :] + _am_patch._buffers[actual] = orig + elif name.startswith(_FIT_PREFIX): + actual = name[len(_FIT_PREFIX) :] + if _fitting_patch is not None: + _fitting_patch._buffers[actual] = orig + elif name.startswith(_FIT_ATTR_PREFIX): + actual = name[len(_FIT_ATTR_PREFIX) :] + if _fitting_patch is not None: + setattr(_fitting_patch, actual, orig) need_coord_grad = self.do_grad_r() or self.do_grad_c() @@ -1552,6 +1762,13 @@ def _prepare_coord_for_trace(coord: torch.Tensor) -> torch.Tensor: else: return coord.detach() + # NOTE: compute_fn accepts *task_buf_vals after the fixed tensor args. + # make_fx treats each element as a separate placeholder so the compiled + # graph reads them as live inputs every call — not baked-in constants. + # The buffer-patching trick: at trace time the proxy tensors are written + # into _buffers / __dict__ so that downstream code (apply_out_stat, + # fitting_net.forward) reads the proxies and the ops are recorded in the + # FX graph. The finally block restores original state unconditionally. if extended_coord_corr is None: def compute_fn( @@ -1562,21 +1779,27 @@ def compute_fn( fp: torch.Tensor, ap: torch.Tensor, charge_spin: torch.Tensor, + *task_buf_vals: torch.Tensor, ) -> dict[str, torch.Tensor]: - return self.core_compute( - _prepare_coord_for_trace(extended_coord), - extended_atype, - nlist, - mapping=mapping, - fparam=fp, - aparam=ap, - charge_spin=charge_spin, - do_atomic_virial=do_atomic_virial, - extra_nlist_sort=self.need_sorted_nlist_for_lower(), - ) + _saved = _patch_task_bufs(task_buf_vals) + try: + return self.core_compute( + _prepare_coord_for_trace(extended_coord), + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + finally: + _restore_task_bufs(_saved) + else: - def compute_fn( + def compute_fn( # type: ignore[misc] extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, @@ -1585,22 +1808,27 @@ def compute_fn( ap: torch.Tensor, charge_spin: torch.Tensor, extended_coord_corr: torch.Tensor, + *task_buf_vals: torch.Tensor, ) -> dict[str, torch.Tensor]: # NOTE: Spin virial uses a coordinate correction derived from the # virtual-atom displacement. Keeping it as a tensor input lets the # compiled graph stay reusable across frames. - return self.core_compute( - _prepare_coord_for_trace(extended_coord), - extended_atype, - nlist, - mapping=mapping, - fparam=fp, - aparam=ap, - charge_spin=charge_spin, - do_atomic_virial=do_atomic_virial, - extra_nlist_sort=self.need_sorted_nlist_for_lower(), - extended_coord_corr=extended_coord_corr, - ) + _saved = _patch_task_bufs(task_buf_vals) + try: + return self.core_compute( + _prepare_coord_for_trace(extended_coord), + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + extended_coord_corr=extended_coord_corr, + ) + finally: + _restore_task_bufs(_saved) # NOTE: Always trace with a fixed batch size that is free of known # symbolic-shape collisions. @@ -1615,13 +1843,71 @@ def compute_fn( # ``torch.compile(dynamic=True)`` to reject later batches whose # nf differs from the traced constant. # - # If a future code change introduces a new explicit dimension of - # this size and compile starts failing with a similar shape - # mismatch, change this constant accordingly. + # The same aliasing hazard applies to the nloc / nsel axes of the + # nlist tensor (shape [nf, nloc, nsel]). When nloc == nsel in the + # trace inputs, make_fx folds them into a single symbol and the + # compiled kernel hard-codes nloc = nsel for every future call. + # This causes wrong energy/force shapes whenever a later batch + # arrives with nloc != nsel (only reproducible when a training + # task happens to have nloc == nsel in its first batch). + # Fix: drop one local atom row from the trace nlist so that + # nlist_for_trace.shape[1] = nloc - 1 != nsel. The nsel dimension + # is untouched so format_nlist takes the same no-op case-1 path. + # + # Task-buffer aliasing: task-specific buffers (out_bias, bias_atom_e, + # case_embd) are passed as extra arguments and also have static shapes. + # If trace_nf equals any of those static dims, make_fx would unify the + # nf symbol with that static dim — baking nf == static_dim into every + # compiled kernel. We therefore collect all static buffer dims and + # increment trace_nf until it is free of all conflicts. + _reserved_dims = {1, 2, 3, 9} + _static_buf_dims: set[int] = set() + for _tbv in task_buf_vals_trace: + for _d in _tbv.shape: + if _d > 1: + _static_buf_dims.add(_d) trace_nf = 5 + while trace_nf in _reserved_dims or trace_nf in _static_buf_dims: + trace_nf += 1 + coord_for_trace = extended_coord[:1].repeat(trace_nf, 1, 1) atype_for_trace = extended_atype[:1].repeat(trace_nf, 1) nlist_for_trace = nlist[:1].repeat(trace_nf, 1, 1) + if nlist_for_trace.shape[1] == nlist_for_trace.shape[2]: + # nloc == nsel: drop one local atom row so that dim-1 (nloc-1) + # differs from dim-2 (nsel), breaking the symbolic aliasing without + # altering the nsel dimension or the format_nlist code path. + # torch.compile(dynamic=True) treats nloc as fully dynamic, so the + # compiled graph works correctly for all nloc at runtime. + nlist_for_trace = nlist_for_trace[:, :-1, :] + # NOTE: Anti-alias nloc against promoted buffer dims. + # The promoted task buffers (out_bias, bias_atom_e, case_embd) are now + # FX placeholder inputs with their own symbolic dims. If nloc at trace + # time equals ntypes (out_bias/bias_atom_e dim) or dim_case_embd + # (case_embd dim), make_fx's ShapeEnv unifies the symbols and the + # compiled graph specialises on nloc == that_value. Every training + # batch with a different nloc then fails the guard and triggers a full + # Dynamo/Inductor recompile — NCCL timeout when one rank recompiles + # while others wait at allreduce. Dropping a row from the nlist + # changes the traced nloc without touching nsel or extended-atom dims, + # so the model's internal indexing paths are unaffected. + # NOTE: Anti-alias nloc against nall. When the first batch for a task + # has nloc == nall (a non-PBC frame with no ghost atoms), make_fx's + # duck sizing unifies the nloc symbol (nlist.shape[1]) with the nall + # symbol (extended_coord.shape[1]). The compiled graph then reads the + # slice end of ``extended_coord[:, :nloc]`` / ``extended_atype[:, :nloc]`` + # from the nall source, so every later PBC batch yields energy/force + # shaped [nf, nall, ...] instead of [nf, nloc, ...]. Dropping a local + # atom row makes nloc_trace = nall_trace - 1, breaking the merge. + _nall_trace = coord_for_trace.shape[1] + _trace_nloc = nlist_for_trace.shape[1] + while _trace_nloc > 1 and ( + _trace_nloc in _static_buf_dims + or _trace_nloc in _reserved_dims + or _trace_nloc == _nall_trace + ): + nlist_for_trace = nlist_for_trace[:, :-1, :] + _trace_nloc = nlist_for_trace.shape[1] mapping_for_trace = mapping[:1].repeat(trace_nf, 1) fp_for_trace = fp[:1].repeat(trace_nf, 1) ap_for_trace = ap[:1].repeat(trace_nf, 1, 1) @@ -1637,6 +1923,11 @@ def compute_fn( ] if extended_coord_corr is not None: trace_args.append(extended_coord_corr[:1].repeat(trace_nf, 1, 1)) + # Append task-buffer values last so they map to the *task_buf_vals + # varargs in compute_fn. Their shapes are static (they don't vary + # batch-to-batch), so passing the actual tensors is correct; make_fx + # will create one placeholder per element. + trace_args.extend(task_buf_vals_trace) # NOTE: Decompose ``silu_backward`` into primitive ops. # PyTorch ships forward and first-order backward for SiLU but no @@ -1658,12 +1949,52 @@ def compute_fn( # FakeTensors, so we need concrete values to resolve their # control flow exactly once; shapes become symbolic immediately # afterwards. - traced = make_fx( - compute_fn, - tracing_mode="symbolic", - _allow_non_fake_inputs=True, - decomposition_table=decomp_table, - )(*trace_args) + # NOTE: Disable duck sizing for the symbolic trace. + # make_fx builds a ShapeEnv whose ``create_symbol`` defaults every input + # dim to ``DimDynamic.DUCK``: with duck sizing on (the default), any two + # dims sharing a concrete value at trace time are assigned the SAME sympy + # symbol. That is the root of every "dim X got the size of dim Y" bug + # here -- nloc==nall, nloc==nsel, nall==ntypes, nf==static_buf_dim all + # silently fuse two axes and bake an equality the later batches violate. + # The trim hacks above only patch the handful of pairs we anticipated; + # turning duck sizing off gives every axis an independent symbol so e.g. + # ``extended_coord[:, :nloc]`` stays parametric on nloc no matter how the + # trace batch's sizes coincide. + # + # ShapeEnv reads ``self.duck_shape`` as the global switch, but in torch + # 2.11 there is no longer a ``_config.duck_shape`` knob nor an explicit + # ``__init__`` parameter -- it only flows in as a ``**kwargs`` entry + # forwarded to the internal init (``ShapeEnv(duck_shape=False)`` works). + # make_fx constructs the ShapeEnv with no args, so the only portable hook + # is to wrap ``ShapeEnv.__init__`` on the class object and inject + # ``duck_shape=False`` for the duration of this trace. Patching the + # class method covers every reference (proxy_tensor imported the same + # class object). Restored in ``finally`` so the later + # torch.compile / inductor stage and any other ShapeEnv are unaffected. + _ss_mod = None + _orig_se_init = None + try: + import torch.fx.experimental.symbolic_shapes as _ss_mod # type: ignore[no-redef] + except Exception: + _ss_mod = None + if _ss_mod is not None and hasattr(_ss_mod, "ShapeEnv"): + _orig_se_init = _ss_mod.ShapeEnv.__init__ + + def _no_duck_shapeenv_init(self, *args, **kwargs): # type: ignore[no-untyped-def] + kwargs.setdefault("duck_shape", False) + return _orig_se_init(self, *args, **kwargs) + + _ss_mod.ShapeEnv.__init__ = _no_duck_shapeenv_init + try: + traced = make_fx( + compute_fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + decomposition_table=decomp_table, + )(*trace_args) + finally: + if _orig_se_init is not None: + _ss_mod.ShapeEnv.__init__ = _orig_se_init # NOTE: Only strip autograd-inserted detach chains in training # mode. With ``create_graph=True`` make_fx wraps every saved @@ -1735,19 +2066,42 @@ def compute_fn( # ``(training, do_atomic_virial, has_coord_corr)`` so that distinct # graph topologies coexist without evicting each other on every # ``model.eval()`` / ``model.train()`` switch. - cache_key = (bool(self.training), bool(do_atomic_virial), has_coord_corr) # NOTE: ``dynamic=True`` emits a single kernel per traced # shape symbol, so changes in ``nframes``, ``nall`` or edge # count do not trigger recompiles; and the option dict above # disables every Inductor/Triton feature that has ever # interacted badly with ``make_fx`` + double backward in # this project. - self.compiled_core_compute_cache[cache_key] = torch.compile( + compiled = torch.compile( traced, backend="inductor", dynamic=True, options=compile_options, ) + # Populate both per-instance and module-level shared caches. + # The shared cache (_SEZM_COMPILE_CACHE) lets a second task with the + # same structure key skip re-tracing and re-compiling entirely. + self.compiled_core_compute_cache[cache_key] = compiled + self._task_buf_order_cache[cache_key] = task_buf_names + _SEZM_COMPILE_CACHE[full_cache_key] = compiled + _SEZM_TASK_BUF_ORDER[structure_key] = task_buf_names + # NOTE: No dist.barrier() here. + # The barrier premise is that all ranks reach trace_and_compile + # simultaneously. That is FALSE in several trainer code paths: + # + # 1. compute_or_load_stat (training.py:417) runs on rank 0 only. + # Rank 0 compiles → calls barrier → the other N-1 ranks are not + # inside trace_and_compile at that moment → deadlock. + # + # 2. Validation at disp_freq is rank-0-only inside the rank guard; + # if DP_COMPILE_INFER is set, same deadlock. + # + # Instead we rely on compilation being symmetric during the DDP + # training loop itself: all ranks pick the same task per step (same + # random seed), so they all hit trace_and_compile for the same task + # at the same step. The compile-time gap between ranks is on the + # order of seconds while the NCCL default timeout is 30 minutes, + # so no barrier is necessary for the training-loop case. # torch.compile is lazy; the "finished" log is emitted after the # first call triggers Inductor lowering (see forward_common). # ``pending_key`` pairs with ``pending_t0`` so the log is only @@ -1778,7 +2132,8 @@ def compile_dens(self) -> None: "epilogue_fusion": False, "triton.cudagraphs": False, "shape_padding": True, - "max_fusion_size": 64, + "max_fusion_size": 8, + "triton.persistent_reductions": False, }, ), ) @@ -1824,9 +2179,8 @@ def forward_common_lower_exportable( mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, - charge_spin: torch.Tensor | None = None, - *, do_atomic_virial: bool = False, + charge_spin: torch.Tensor | None = None, ) -> torch.nn.Module: """Trace ``forward_common_lower`` into an exportable FX ``GraphModule``. @@ -1881,8 +2235,9 @@ def fn( mapping_: torch.Tensor | None, fparam_: torch.Tensor | None, aparam_: torch.Tensor | None, - charge_spin_: torch.Tensor | None, + *maybe_charge_spin: torch.Tensor | None, ) -> dict[str, torch.Tensor]: + charge_spin_ = maybe_charge_spin[0] if maybe_charge_spin else None return lower_fn( ext_coord, ext_atype, @@ -1901,7 +2256,7 @@ def fn( dtype=extended_coord.dtype, device=extended_coord.device, ) - trace_inputs = (*trace_inputs, charge_spin) + trace_inputs = (*trace_inputs, charge_spin) return self._trace_lower_exportable( fn, @@ -1960,10 +2315,9 @@ def build_edge_list_from_nlist( Build a compact edge list from DeePMD padded neighbor list. Edge vectors are computed via ``index_select`` on ``extended_coord`` - so they remain differentiable w.r.t. the input coordinates. Two - masked dummy edges are always appended to avoid data-dependent empty-edge - branches that ``make_fx`` cannot trace and singular edge-axis guards - in Inductor's batched matmul lowering. + so they remain differentiable w.r.t. the input coordinates. One + masked dummy edge is always appended to avoid data-dependent empty-edge + branches that ``make_fx`` cannot trace. Parameters ---------- @@ -1977,11 +2331,11 @@ def build_edge_list_from_nlist( Returns ------- edge_index - Edge indices with shape (2, E+2) where E is valid edge count. + Edge indices with shape (2, E+1) where E is valid edge count. edge_vec - Edge vectors with shape (E+2, 3). + Edge vectors with shape (E+1, 3). edge_mask - Boolean mask with shape (E+2). The trailing elements are ``False``. + Boolean mask with shape (E+1,). The trailing element is ``False``. """ nf, nloc, nsel = nlist.shape n_actual = nf * nloc @@ -2033,22 +2387,19 @@ def build_edge_list_from_nlist( valid_idx = torch.nonzero(edge_mask_actual, as_tuple=False).flatten() - # === Step 3. Compact edges + append masked dummies === - # NOTE: Always append two masked dummy edges. + # === Step 3. Compact edges + append one masked dummy === + # NOTE: Always append exactly one masked dummy edge. # ``torch.nonzero(edge_mask_actual)`` produces a data-dependent # number of valid edges, which can be zero on sparse or # single-type systems. make_fx cannot trace an # ``if n_edges == 0: skip`` branch symbolically; without the # dummy it would fall back to concrete shape specialisation and - # break ``torch.compile(dynamic=True)`` for later batches. Two - # dummy edges keep the symbolic edge axis statically above one, - # which avoids Inductor bmm layout guards on ``E == 1``. Each + # break ``torch.compile(dynamic=True)`` for later batches. The # dummy edge copies entry 0 (any in-range index is fine) and # carries ``edge_mask=False`` so every downstream sum, gather # or scatter ignores it. - dummy_count = 2 padded_idx = torch.cat( - [valid_idx, torch.zeros(dummy_count, dtype=torch.long, device=device)] + [valid_idx, torch.zeros(1, dtype=torch.long, device=device)] ) src_sel = src_actual.index_select(0, padded_idx) dst_sel = dst_actual.index_select(0, padded_idx) @@ -2057,7 +2408,7 @@ def build_edge_list_from_nlist( edge_mask = torch.cat( [ torch.ones(valid_idx.shape[0], dtype=torch.bool, device=device), - torch.zeros(dummy_count, dtype=torch.bool, device=device), + torch.zeros(1, dtype=torch.bool, device=device), ] ) return edge_index, edge_vec_sel, edge_mask From c0565b84d98ea0692e5e3982c2fe0702704be6bf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Jun 2026 08:44:05 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/model/sezm_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 3474e0aabd..e5335f85f0 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -454,12 +454,12 @@ _SEZM_TASK_BUF_ORDER: dict[tuple[int, ...], tuple[str, ...]] = {} # Prefix namespace for promoted buffer names. -_AM_PREFIX = "am/" # atomic_model registered buffer -_FIT_PREFIX = "fit/" # fitting_net registered buffer +_AM_PREFIX = "am/" # atomic_model registered buffer +_FIT_PREFIX = "fit/" # fitting_net registered buffer _FIT_ATTR_PREFIX = "fit_attr/" # fitting_net plain tensor attribute (not in _buffers) -def _sezm_structure_key(model: "SeZMModel") -> tuple[int, ...]: +def _sezm_structure_key(model: SeZMModel) -> tuple[int, ...]: """Return a key that is equal iff two SeZMModel instances can share a compiled graph. After ``share_params``, the descriptor and fitting-net module objects @@ -499,7 +499,7 @@ def _sezm_structure_key(model: "SeZMModel") -> tuple[int, ...]: return (desc_id, id(model)) -def _get_sezm_task_buf_names(model: "SeZMModel") -> tuple[str, ...]: +def _get_sezm_task_buf_names(model: SeZMModel) -> tuple[str, ...]: """Return the ordered names of per-task buffers to promote as FX placeholders. Always promotes: @@ -534,7 +534,7 @@ def _get_sezm_task_buf_names(model: "SeZMModel") -> tuple[str, ...]: def _get_sezm_task_buf_vals( - model: "SeZMModel", + model: SeZMModel, names: tuple[str, ...], ) -> tuple[torch.Tensor, ...]: """Return the current tensor values for the given promoted-buffer names.""" From 86c3450cef109e879bad973fab5d909e2c88dd11 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:52:41 +0800 Subject: [PATCH 3/8] fix: import --- deepmd/pt/model/model/sezm_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index e5335f85f0..dec391c89b 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -374,6 +374,9 @@ from einops import ( rearrange, ) +from packaging.version import ( + Version, +) from torch.fx.experimental.proxy_tensor import ( make_fx, ) @@ -2179,8 +2182,9 @@ def forward_common_lower_exportable( mapping: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, - do_atomic_virial: bool = False, charge_spin: torch.Tensor | None = None, + *, + do_atomic_virial: bool = False, ) -> torch.nn.Module: """Trace ``forward_common_lower`` into an exportable FX ``GraphModule``. From 89a6397431be1ef0f15af79d8dd9fc51e64f4d88 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Mon, 1 Jun 2026 18:01:42 +0800 Subject: [PATCH 4/8] chore:lint --- deepmd/pt/model/model/sezm_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index dec391c89b..c4ccf0dded 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -449,7 +449,7 @@ # Maps (structure_key..., training, do_atomic_virial, has_coord_corr) to the # compiled callable. Tasks whose descriptor AND fitting-net first child have # the same Python-object identity (after share_params) reuse a single compiled -# graph, avoiding N×compile-cache OOM and N DDP graph boundaries (NCCL timeout). +# graph, avoiding Nx compile-cache OOM and N DDP graph boundaries (NCCL timeout). _SEZM_COMPILE_CACHE: dict[tuple, Any] = {} # Maps structure_key -> task_buf_order so every instance in the same group @@ -1983,7 +1983,7 @@ def compute_fn( # type: ignore[misc] if _ss_mod is not None and hasattr(_ss_mod, "ShapeEnv"): _orig_se_init = _ss_mod.ShapeEnv.__init__ - def _no_duck_shapeenv_init(self, *args, **kwargs): # type: ignore[no-untyped-def] + def _no_duck_shapeenv_init(self: Any, *args: Any, **kwargs: Any) -> None: kwargs.setdefault("duck_shape", False) return _orig_se_init(self, *args, **kwargs) From 8c74af6fc02d3c4339f7a41a7548627dd8d10938 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 2 Jun 2026 10:03:52 +0800 Subject: [PATCH 5/8] fix: charge spin param passing --- deepmd/pt/model/model/sezm_model.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index c4ccf0dded..fa30553e35 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -2239,9 +2239,8 @@ def fn( mapping_: torch.Tensor | None, fparam_: torch.Tensor | None, aparam_: torch.Tensor | None, - *maybe_charge_spin: torch.Tensor | None, + charge_spin_: torch.Tensor | None, ) -> dict[str, torch.Tensor]: - charge_spin_ = maybe_charge_spin[0] if maybe_charge_spin else None return lower_fn( ext_coord, ext_atype, @@ -2252,7 +2251,6 @@ def fn( charge_spin_, ) - trace_inputs = (extended_coord, extended_atype, nlist, mapping, fparam, aparam) if self.get_dim_chg_spin() > 0: charge_spin = self.convert_charge_spin( charge_spin, @@ -2260,7 +2258,18 @@ def fn( dtype=extended_coord.dtype, device=extended_coord.device, ) - trace_inputs = (*trace_inputs, charge_spin) + # Always include the charge_spin slot (possibly None) so the traced + # module's forward signature matches the 7-tuple the freeze pipeline + # passes at runtime, regardless of whether the model is conditioned. + trace_inputs = ( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + charge_spin, + ) return self._trace_lower_exportable( fn, From 0e5500bbe9e9e80d7c7aec06845680a62df7cf22 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:51:48 +0800 Subject: [PATCH 6/8] fix: update UT --- source/tests/pt/model/test_sezm_model.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py index 0e682ca803..b3ebcaf2ff 100644 --- a/source/tests/pt/model/test_sezm_model.py +++ b/source/tests/pt/model/test_sezm_model.py @@ -613,10 +613,10 @@ def test_fixed_edge_geometry_matches_standard_cache(self) -> None: wigner_calc=descriptor.wigner_calc, ) - # build_edge_list_from_nlist appends masked dummy edges; + # build_edge_list_from_nlist appends exactly one masked dummy edge; # compare only the real edges before the padded tail. n_real = cache_std.src.shape[0] - self.assertEqual(edge_mask.shape[0] - n_real, 2) + self.assertEqual(edge_mask.shape[0] - n_real, 1) self.assertFalse(edge_mask[n_real:].any().item()) self.assertTrue(torch.equal(cache_std.src, cache_sparse.src[:n_real])) self.assertTrue(torch.equal(cache_std.dst, cache_sparse.dst[:n_real])) @@ -898,13 +898,15 @@ def _build_wrapper(use_compile: bool) -> ModelWrapper: msg=f"multitask force mismatch at {branch}", ) - # === Step 3. Each compiled branch owns its own compile cache; the - # shared descriptor weights must not collapse them into one. - # Step 2 ran every branch in training mode with the default + # === Step 3. Each branch keeps its own per-instance cache dict, but + # branches that share descriptor + fitting (same Python-object + # identity after share_params) reuse a single compiled callable via + # the module-level ``_SEZM_COMPILE_CACHE``. This avoids the + # N x compile-cache OOM / N DDP graph boundary cost on multitask + # runs. Step 2 ran every branch in training mode with the default # ``do_atomic_virial=False`` and no coordinate correction, so each - # per-branch cache dict - # should hold exactly that one slot, and the compiled callables - # at that slot must be distinct across branches. === + # per-branch cache should hold exactly that one slot, and the + # compiled callable at that slot must be the *same* object. === cache1 = wrapper_cmp.model["water_1"].compiled_core_compute_cache cache2 = wrapper_cmp.model["water_2"].compiled_core_compute_cache self.assertIsNot(cache1, cache2) @@ -915,7 +917,7 @@ def _build_wrapper(use_compile: bool) -> ModelWrapper: c2 = cache2[train_key] self.assertIsNotNone(c1) self.assertIsNotNone(c2) - self.assertIsNot(c1, c2) + self.assertIs(c1, c2) # === Step 4. Per-task case embedding must differentiate outputs. === out_e1 = wrapper_eager.model["water_1"](coord, atype, box=box) From cde602bbb6b4c7a0a9fde8dd2cf805d8fe3e0faa Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:17:08 +0800 Subject: [PATCH 7/8] chore: refactor aliasing compile --- deepmd/pt/model/model/sezm_model.py | 293 ++++++++++++++++------------ 1 file changed, 165 insertions(+), 128 deletions(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index fa30553e35..05dac91e7a 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -601,6 +601,67 @@ def _check_compile_torch_version() -> None: ) +def _is_prime(n: int) -> bool: + """Return True when ``n`` is a prime integer (``n >= 2``).""" + if n < 2: + return False + if n < 4: + return True + if n % 2 == 0: + return False + k = 3 + while k * k <= n: + if n % k == 0: + return False + k += 2 + return True + + +def _next_safe_prime(start: int, forbidden: set[int]) -> int: + """Return the smallest prime ``>= max(start, 5)`` not in ``forbidden``. + + Used by :meth:`SeZMModel.trace_and_compile` to choose collision-free + trace-time sizes for ``nf``, ``nall`` and ``nloc``. Primes ``>= 5`` + avoid every dim PyTorch specializes on (``1`` → broadcasting, + ``2``/``3``/``9`` → Cartesian / virial / charge_spin literals baked + into model code) and guarantee distinct values, which suppresses + make_fx's duck-shape unification without needing the + ``ShapeEnv(duck_shape=False)`` patch. + """ + n = max(start, 5) + while not _is_prime(n) or n in forbidden: + n += 1 + return n + + +def _trace_pad_dim(t: torch.Tensor, dim: int, target: int) -> torch.Tensor: + """Pad or trim ``t`` along ``dim`` so ``t.shape[dim] == target``. + + Padding duplicates the last slice along ``dim``; trimming drops + trailing slices. Used to coerce real-data trace inputs into the + prime-numbered shapes chosen by :func:`_next_safe_prime`. + + Duplicating the last slice preserves valid index values inside + index-bearing tensors (``nlist`` neighbor indices, ``mapping`` + extended-to-local indices) because the duplicated row reuses the + previously-valid row's values. Trimming likewise never invalidates + indices. Only shapes flow downstream during ``make_fx`` tracing, + so the exact replicated/trimmed values do not affect the FX graph. + """ + cur = int(t.shape[dim]) + if cur == target: + return t + if cur > target: + sl: list[slice] = [slice(None)] * t.ndim + sl[dim] = slice(None, target) + return t[tuple(sl)] + sl = [slice(None)] * t.ndim + sl[dim] = slice(-1, None) + last = t[tuple(sl)] + repeats = target - cur + return torch.cat([t, *([last] * repeats)], dim=dim) + + def _strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None: """Strip ``aten.detach`` nodes that ``make_fx`` inserts for saved tensors. @@ -1833,88 +1894,92 @@ def compute_fn( # type: ignore[misc] finally: _restore_task_bufs(_saved) - # NOTE: Always trace with a fixed batch size that is free of known - # symbolic-shape collisions. + # NOTE: Choose trace shapes that are pairwise-distinct primes >= 5. # - # make_fx(tracing_mode="symbolic") replaces shapes with sympy - # symbols, but the moment a symbolic dim ends up equal to a - # *concrete* dim elsewhere in the same tensor it collapses into - # a constant and the graph specialises on that batch size. Known - # reserved dimensions include 1 (specialisation), 2 (charge/spin - # width), 3 (Cartesian coordinates), and 9 (virial tensor). Any - # of those collisions forces - # ``torch.compile(dynamic=True)`` to reject later batches whose - # nf differs from the traced constant. + # ``make_fx(tracing_mode="symbolic")`` introduces a sympy symbol per + # input dim. Two failure modes follow if those dims accidentally + # match each other or hit a PyTorch-internal "special" value: # - # The same aliasing hazard applies to the nloc / nsel axes of the - # nlist tensor (shape [nf, nloc, nsel]). When nloc == nsel in the - # trace inputs, make_fx folds them into a single symbol and the - # compiled kernel hard-codes nloc = nsel for every future call. - # This causes wrong energy/force shapes whenever a later batch - # arrives with nloc != nsel (only reproducible when a training - # task happens to have nloc == nsel in its first batch). - # Fix: drop one local atom row from the trace nlist so that - # nlist_for_trace.shape[1] = nloc - 1 != nsel. The nsel dimension - # is untouched so format_nlist takes the same no-op case-1 path. + # * Duck-shape unification: two input dims that share a concrete + # value at trace time get the SAME sympy symbol, baking an + # equality (``nloc == ntypes``, ``nloc == nall``, ...) the + # compiled graph will violate on later batches. + # * Size specialization: dims equal to ``1`` are baked as literal + # ``1`` regardless of duck-shape; values ``2``/``3``/``9`` are + # commonly literals inside the model (charge/spin width, + # Cartesian, virial) and may be unified with input symbols by + # ShapeEnv even with duck-shape off. # - # Task-buffer aliasing: task-specific buffers (out_bias, bias_atom_e, - # case_embd) are passed as extra arguments and also have static shapes. - # If trace_nf equals any of those static dims, make_fx would unify the - # nf symbol with that static dim — baking nf == static_dim into every - # compiled kernel. We therefore collect all static buffer dims and - # increment trace_nf until it is free of all conflicts. - _reserved_dims = {1, 2, 3, 9} - _static_buf_dims: set[int] = set() + # Picking pairwise-distinct primes ``>= 5`` for ``nf``, ``nall``, + # ``nloc`` rules out both failure modes in one stroke: no two + # symbols can fuse (distinct values), and no symbol can hit a + # special literal (``5+`` primes skip ``1``/``2``/``3``/``9``). + # ``nsel``, ``dim_fparam``, ``dim_aparam`` and ``dim_chg_spin`` are + # contractually fixed by the model and added to the forbidden set + # so the chosen primes never collide with them either. + _forbidden: set[int] = {1, 2, 3, 9} for _tbv in task_buf_vals_trace: for _d in _tbv.shape: if _d > 1: - _static_buf_dims.add(_d) - trace_nf = 5 - while trace_nf in _reserved_dims or trace_nf in _static_buf_dims: - trace_nf += 1 - - coord_for_trace = extended_coord[:1].repeat(trace_nf, 1, 1) - atype_for_trace = extended_atype[:1].repeat(trace_nf, 1) - nlist_for_trace = nlist[:1].repeat(trace_nf, 1, 1) - if nlist_for_trace.shape[1] == nlist_for_trace.shape[2]: - # nloc == nsel: drop one local atom row so that dim-1 (nloc-1) - # differs from dim-2 (nsel), breaking the symbolic aliasing without - # altering the nsel dimension or the format_nlist code path. - # torch.compile(dynamic=True) treats nloc as fully dynamic, so the - # compiled graph works correctly for all nloc at runtime. - nlist_for_trace = nlist_for_trace[:, :-1, :] - # NOTE: Anti-alias nloc against promoted buffer dims. - # The promoted task buffers (out_bias, bias_atom_e, case_embd) are now - # FX placeholder inputs with their own symbolic dims. If nloc at trace - # time equals ntypes (out_bias/bias_atom_e dim) or dim_case_embd - # (case_embd dim), make_fx's ShapeEnv unifies the symbols and the - # compiled graph specialises on nloc == that_value. Every training - # batch with a different nloc then fails the guard and triggers a full - # Dynamo/Inductor recompile — NCCL timeout when one rank recompiles - # while others wait at allreduce. Dropping a row from the nlist - # changes the traced nloc without touching nsel or extended-atom dims, - # so the model's internal indexing paths are unaffected. - # NOTE: Anti-alias nloc against nall. When the first batch for a task - # has nloc == nall (a non-PBC frame with no ghost atoms), make_fx's - # duck sizing unifies the nloc symbol (nlist.shape[1]) with the nall - # symbol (extended_coord.shape[1]). The compiled graph then reads the - # slice end of ``extended_coord[:, :nloc]`` / ``extended_atype[:, :nloc]`` - # from the nall source, so every later PBC batch yields energy/force - # shaped [nf, nall, ...] instead of [nf, nloc, ...]. Dropping a local - # atom row makes nloc_trace = nall_trace - 1, breaking the merge. - _nall_trace = coord_for_trace.shape[1] - _trace_nloc = nlist_for_trace.shape[1] - while _trace_nloc > 1 and ( - _trace_nloc in _static_buf_dims - or _trace_nloc in _reserved_dims - or _trace_nloc == _nall_trace - ): - nlist_for_trace = nlist_for_trace[:, :-1, :] - _trace_nloc = nlist_for_trace.shape[1] - mapping_for_trace = mapping[:1].repeat(trace_nf, 1) - fp_for_trace = fp[:1].repeat(trace_nf, 1) - ap_for_trace = ap[:1].repeat(trace_nf, 1, 1) - charge_spin_for_trace = charge_spin[:1].repeat(trace_nf, 1) + _forbidden.add(int(_d)) + # Model-contracted dims kept at their real values (changing them + # would break the model's own assertions about ``sel``, fparam / + # aparam widths, charge_spin dim). Add to forbidden so primes + # picked for free dims do not collide. + _nsel_real = int(nlist.shape[2]) + _dim_fp = int(fp.shape[1]) + _dim_ap = int(ap.shape[2]) + _dim_cs = int(charge_spin.shape[1]) + for _d in (_nsel_real, _dim_fp, _dim_ap, _dim_cs): + if _d > 1: + _forbidden.add(_d) + # Pick primes in physical order ``nf < nloc < nall``. The order + # ``trace_nloc < trace_nall`` matters: the model slices + # ``extended_atype[:, :nloc]`` to get local atoms; if + # ``trace_nloc > trace_nall`` the slice silently truncates at + # trace time, breaking the captured symbolic shape relation + # ``atype.shape[1] == nloc``. + trace_nf = _next_safe_prime(5, _forbidden) + _forbidden.add(trace_nf) + trace_nloc = _next_safe_prime(trace_nf + 1, _forbidden) + _forbidden.add(trace_nloc) + trace_nall = _next_safe_prime(trace_nloc + 1, _forbidden) + + # Build trace inputs by padding/trimming real-data tensors into + # the chosen prime shapes. ``_trace_pad_dim`` duplicates the + # last slice when padding so index-bearing tensors (``nlist`` + # neighbor indices, ``mapping`` extended-to-local indices) keep + # valid values -- the duplicated row references the same atoms + # the previous row referenced. + coord_for_trace = _trace_pad_dim(extended_coord[:1], 0, trace_nf) + coord_for_trace = _trace_pad_dim(coord_for_trace, 1, trace_nall) + atype_for_trace = _trace_pad_dim(extended_atype[:1], 0, trace_nf) + atype_for_trace = _trace_pad_dim(atype_for_trace, 1, trace_nall) + nlist_for_trace = _trace_pad_dim(nlist[:1], 0, trace_nf) + nlist_for_trace = _trace_pad_dim(nlist_for_trace, 1, trace_nloc) + # Real nlist values are in ``[-1, real_nall)`` (``-1`` marks + # padded slots, non-negative entries index into extended_coord). + # After trimming ``nall`` down to ``trace_nall`` some of those + # values can exceed ``trace_nall``, which would produce + # out-of-range gather indices in ``coord_flat.index_select(0, + # src_ext)`` during the trace pass. Clamp the upper bound to + # ``trace_nall - 1`` (the ``-1`` padding stays untouched since + # clamp only caps the high side). + nlist_for_trace = torch.clamp(nlist_for_trace, max=trace_nall - 1) + mapping_for_trace = _trace_pad_dim(mapping[:1], 0, trace_nf) + mapping_for_trace = _trace_pad_dim(mapping_for_trace, 1, trace_nall) + # Real mapping values are in ``[0, real_nloc)``. If + # ``trace_nloc < real_nloc`` they can exceed ``trace_nloc`` and + # silently propagate into ``src_local`` (used as a local-atom + # index downstream). Clamp to ``trace_nloc - 1``. + mapping_for_trace = torch.clamp( + mapping_for_trace, min=0, max=trace_nloc - 1 + ) + fp_for_trace = _trace_pad_dim(fp[:1], 0, trace_nf) + ap_for_trace = _trace_pad_dim(ap[:1], 0, trace_nf) + ap_for_trace = _trace_pad_dim(ap_for_trace, 1, trace_nloc) + charge_spin_for_trace = _trace_pad_dim(charge_spin[:1], 0, trace_nf) + trace_args = [ coord_for_trace, atype_for_trace, @@ -1925,7 +1990,9 @@ def compute_fn( # type: ignore[misc] charge_spin_for_trace, ] if extended_coord_corr is not None: - trace_args.append(extended_coord_corr[:1].repeat(trace_nf, 1, 1)) + corr_for_trace = _trace_pad_dim(extended_coord_corr[:1], 0, trace_nf) + corr_for_trace = _trace_pad_dim(corr_for_trace, 1, trace_nall) + trace_args.append(corr_for_trace) # Append task-buffer values last so they map to the *task_buf_vals # varargs in compute_fn. Their shapes are static (they don't vary # batch-to-batch), so passing the actual tensors is correct; make_fx @@ -1952,52 +2019,12 @@ def compute_fn( # type: ignore[misc] # FakeTensors, so we need concrete values to resolve their # control flow exactly once; shapes become symbolic immediately # afterwards. - # NOTE: Disable duck sizing for the symbolic trace. - # make_fx builds a ShapeEnv whose ``create_symbol`` defaults every input - # dim to ``DimDynamic.DUCK``: with duck sizing on (the default), any two - # dims sharing a concrete value at trace time are assigned the SAME sympy - # symbol. That is the root of every "dim X got the size of dim Y" bug - # here -- nloc==nall, nloc==nsel, nall==ntypes, nf==static_buf_dim all - # silently fuse two axes and bake an equality the later batches violate. - # The trim hacks above only patch the handful of pairs we anticipated; - # turning duck sizing off gives every axis an independent symbol so e.g. - # ``extended_coord[:, :nloc]`` stays parametric on nloc no matter how the - # trace batch's sizes coincide. - # - # ShapeEnv reads ``self.duck_shape`` as the global switch, but in torch - # 2.11 there is no longer a ``_config.duck_shape`` knob nor an explicit - # ``__init__`` parameter -- it only flows in as a ``**kwargs`` entry - # forwarded to the internal init (``ShapeEnv(duck_shape=False)`` works). - # make_fx constructs the ShapeEnv with no args, so the only portable hook - # is to wrap ``ShapeEnv.__init__`` on the class object and inject - # ``duck_shape=False`` for the duration of this trace. Patching the - # class method covers every reference (proxy_tensor imported the same - # class object). Restored in ``finally`` so the later - # torch.compile / inductor stage and any other ShapeEnv are unaffected. - _ss_mod = None - _orig_se_init = None - try: - import torch.fx.experimental.symbolic_shapes as _ss_mod # type: ignore[no-redef] - except Exception: - _ss_mod = None - if _ss_mod is not None and hasattr(_ss_mod, "ShapeEnv"): - _orig_se_init = _ss_mod.ShapeEnv.__init__ - - def _no_duck_shapeenv_init(self: Any, *args: Any, **kwargs: Any) -> None: - kwargs.setdefault("duck_shape", False) - return _orig_se_init(self, *args, **kwargs) - - _ss_mod.ShapeEnv.__init__ = _no_duck_shapeenv_init - try: - traced = make_fx( - compute_fn, - tracing_mode="symbolic", - _allow_non_fake_inputs=True, - decomposition_table=decomp_table, - )(*trace_args) - finally: - if _orig_se_init is not None: - _ss_mod.ShapeEnv.__init__ = _orig_se_init + traced = make_fx( + compute_fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + decomposition_table=decomp_table, + )(*trace_args) # NOTE: Only strip autograd-inserted detach chains in training # mode. With ``create_graph=True`` make_fx wraps every saved @@ -2351,7 +2378,6 @@ def build_edge_list_from_nlist( Boolean mask with shape (E+1,). The trailing element is ``False``. """ nf, nloc, nsel = nlist.shape - n_actual = nf * nloc device = extended_coord.device nall = extended_coord.shape[1] descriptor_model = self.atomic_model.descriptor @@ -2369,12 +2395,23 @@ def build_edge_list_from_nlist( # ``torch.where(valid_flat, neighbor_flat, 0)`` sanitises padded # ``-1`` entries before indexing so we never hit an out-of-range # gather; the corresponding edges are filtered out below anyway. - dst_actual = torch.arange( - n_actual, device=device, dtype=torch.long - ).repeat_interleave(nsel) + neighbor_flat = nlist.reshape(-1) + # ``dst_actual = arange(N*K) // K`` produces the same value + # sequence as ``arange(N).repeat_interleave(K)`` but its length + # is derived from ``neighbor_flat.shape[0]`` -- a single symbolic + # source shared with the ``torch.where`` below. The previous + # ``arange(nf*nloc).repeat_interleave(nsel)`` chain could + # decouple from ``nlist.numel()`` in the FX graph if any + # upstream code path ever specialized ``nloc`` at trace time; + # deriving from ``neighbor_flat.shape[0]`` makes the equality + # structural and survives any future change in trace-shape + # selection in ``trace_and_compile``. + dst_actual = ( + torch.arange(neighbor_flat.shape[0], device=device, dtype=torch.long) + // nsel + ) f_idx = dst_actual // nloc dst_local = dst_actual % nloc - neighbor_flat = nlist.reshape(-1) valid_flat = neighbor_flat >= 0 neighbor_safe = torch.where( valid_flat, neighbor_flat, torch.zeros_like(neighbor_flat) From 52bcebc1cc07a9fb599db3c1e9a57b148a6500ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:18:03 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/model/sezm_model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 05dac91e7a..e007ce4fad 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -1972,9 +1972,7 @@ def compute_fn( # type: ignore[misc] # ``trace_nloc < real_nloc`` they can exceed ``trace_nloc`` and # silently propagate into ``src_local`` (used as a local-atom # index downstream). Clamp to ``trace_nloc - 1``. - mapping_for_trace = torch.clamp( - mapping_for_trace, min=0, max=trace_nloc - 1 - ) + mapping_for_trace = torch.clamp(mapping_for_trace, min=0, max=trace_nloc - 1) fp_for_trace = _trace_pad_dim(fp[:1], 0, trace_nf) ap_for_trace = _trace_pad_dim(ap[:1], 0, trace_nf) ap_for_trace = _trace_pad_dim(ap_for_trace, 1, trace_nloc)