fix: try fix dpa4 compile#5483
Conversation
for more information, see https://pre-commit.ci
📝 WalkthroughWalkthroughSeZMModel's compiled energy path now shares compiled callables across tasks with compatible descriptor and fitting modules. Per-task buffers are promoted to FX placeholders and passed as runtime inputs, reducing compile count. Trace shapes use prime dimensions for symbolic collision-avoidance. The edge-list builder simplifies from two dummy edges to one. Export signatures include charge_spin in the 7-tuple. Tests confirm callable reuse and edge counts. ChangesMulti-task compile cache with shared graphs and export adjustments
Edge list simplification from E+2 to E+1
🎯 4 (Complex) | ⏱️ ~50 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
| val = getattr(fitting, aname, None) | ||
| if val is not None and torch.is_tensor(val): | ||
| names.append(_FIT_ATTR_PREFIX + aname) | ||
| except AttributeError: |
| names.append(_FIT_ATTR_PREFIX + aname) | ||
| except AttributeError: | ||
| pass | ||
| except AttributeError: |
There was a problem hiding this comment.
Pull request overview
This PR attempts to improve/repair the PyTorch-compiled execution path for the SeZM/DPA4 model, primarily by reducing recompiles/OOM in multi-task setups and addressing symbolic-shape tracing issues in make_fx.
Changes:
- Add module-level compile sharing and promote selected per-task buffers (e.g.,
out_bias,bias_atom_e,case_embd) as FX inputs to enable compiled-graph reuse across shared-parameter tasks. - Add additional symbolic-shape anti-aliasing logic for trace inputs and temporarily disable
ShapeEnvduck sizing during tracing. - Change edge-list construction to append a single masked dummy edge (instead of two) and adjust related documentation/behavior.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| aparam: torch.Tensor | None = None, | ||
| charge_spin: torch.Tensor | None = None, | ||
| *, | ||
| do_atomic_virial: bool = False, | ||
| charge_spin: torch.Tensor | None = None, | ||
| ) -> torch.nn.Module: |
| # === Step 3. Compact edges + append one masked dummy === | ||
| # NOTE: Always append exactly one masked dummy edge. | ||
| # ``torch.nonzero(edge_mask_actual)`` produces a data-dependent | ||
| # number of valid edges, which can be zero on sparse or | ||
| # single-type systems. make_fx cannot trace an |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5483 +/- ##
==========================================
+ Coverage 81.34% 81.37% +0.02%
==========================================
Files 868 868
Lines 96373 96739 +366
Branches 4233 4233
==========================================
+ Hits 78399 78725 +326
- Misses 16675 16714 +39
- Partials 1299 1300 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/model/model/sezm_model.py (1)
2735-2744:⚠️ Potential issue | 🟠 Major | ⚡ Quick winClear the shared compile cache when the
enerhead is reset.This branch clears only
self.compiled_core_compute_cache, but Line 1721 can immediately repopulate it from_SEZM_COMPILE_CACHEwhen the structure key is unchanged. That bypasses the retrace promised by this method and can resurrect a callable traced against the pre-reset head.Suggested change
else: + structure_key = _sezm_structure_key(self) + stale_keys = [ + key + for key in _SEZM_COMPILE_CACHE + if key[: len(structure_key)] == structure_key + ] + for key in stale_keys: + _SEZM_COMPILE_CACHE.pop(key, None) + _SEZM_TASK_BUF_ORDER.pop(structure_key, None) self._core_compute_pending_compile_t0 = None self._core_compute_pending_compile_key = None # Drop every compile slot so the next forward retraces against the # reinitialised fitting head. self.compiled_core_compute_cache.clear() + self._task_buf_order_cache.clear()🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/sezm_model.py` around lines 2735 - 2744, The ener-head reset branch currently clears only instance cache (compiled_core_compute_cache) but leaves the shared cache (_SEZM_COMPILE_CACHE) intact so a subsequent lookup (see code that repopulates from _SEZM_COMPILE_CACHE using the structure key) can resurrect callables traced against the old head; update the else branch that resets the ener head to also invalidate the shared compile cache for the same structure key: when you set _core_compute_pending_compile_key to None and call compiled_core_compute_cache.clear(), also remove any entries in _SEZM_COMPILE_CACHE that correspond to the previous structure key (or clear the shared cache entirely) so retracing is forced (operate on the same key variable used to populate _SEZM_COMPILE_CACHE in the forward path).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/model/sezm_model.py`:
- Around line 465-483: The current _sezm_structure_key function risks returning
identical keys for different tasks because it only samples the first child of
descriptor and fitting_net; update it to either (preferred) derive the key from
all shared parameter objects (e.g., use a frozenset of ids for all parameters
that are shared between tasks) or (safer/alternative) detect whether the entire
descriptor and fitting_net stacks are fully shared and raise an error if not, so
partial sharing cannot collapse different tasks into the same
_SEZM_COMPILE_CACHE entry; locate the logic in _sezm_structure_key and use
SeZMModel.descriptor / SeZMModel.fitting_net plus the model.share_params
semantics to compute the full parameter-based key or to assert full-stack
sharing before returning a key for cache reuse.
- Line 1772: The loop using zip(task_buf_names, vals) can silently truncate if
the iterables differ; change it to zip(task_buf_names, vals, strict=True) to
make mismatched lengths raise an error. Locate the loop that iterates over
task_buf_names and vals (the line currently written as "for name, val in
zip(task_buf_names, vals):") and update it to include strict=True so any length
mismatch is detected immediately.
---
Outside diff comments:
In `@deepmd/pt/model/model/sezm_model.py`:
- Around line 2735-2744: The ener-head reset branch currently clears only
instance cache (compiled_core_compute_cache) but leaves the shared cache
(_SEZM_COMPILE_CACHE) intact so a subsequent lookup (see code that repopulates
from _SEZM_COMPILE_CACHE using the structure key) can resurrect callables traced
against the old head; update the else branch that resets the ener head to also
invalidate the shared compile cache for the same structure key: when you set
_core_compute_pending_compile_key to None and call
compiled_core_compute_cache.clear(), also remove any entries in
_SEZM_COMPILE_CACHE that correspond to the previous structure key (or clear the
shared cache entirely) so retracing is forced (operate on the same key variable
used to populate _SEZM_COMPILE_CACHE in the forward path).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: b86f7a23-46a3-4ec5-b8a9-d4a97ba0f001
📒 Files selected for processing (2)
deepmd/pt/model/model/sezm_model.pysource/tests/pt/model/test_sezm_model.py
| def _sezm_structure_key(model: SeZMModel) -> tuple[int, ...]: | ||
| """Return a key that is equal iff two SeZMModel instances can share a compiled graph. | ||
|
|
||
| After ``share_params``, the descriptor and fitting-net module objects | ||
| themselves remain *different* Python objects per task; only their | ||
| *submodules* (``_modules`` dict entries) are replaced with shared | ||
| references. Using ``id(descriptor)`` or ``id(fitting_net)`` would | ||
| therefore always differ between tasks and defeat the cache. | ||
|
|
||
| Fix: use the id of the *first named child* of each module. After | ||
| ``share_params(level=0)``, those children are the same Python objects | ||
| for all tasks in the same structure group, giving matching keys. | ||
|
|
||
| NOTE: only the FIRST child is sampled, assuming "first child shared => | ||
| whole module shared" (true for level=0). Under ``share_params(level=1)`` | ||
| only ``type_embedding`` is shared; if it is the first child, two tasks | ||
| whose other descriptor weights differ would collapse to the same key and | ||
| wrongly reuse one compiled graph. If level=1 + compile is ever used, key | ||
| on all param ids instead, e.g. ``frozenset(id(p) for p in desc.parameters())``. |
There was a problem hiding this comment.
Guard shared compile reuse to fully shared descriptor/fitting stacks.
This key samples only the first descriptor/fitting child, so a partially shared setup can collapse two different tasks into the same _SEZM_COMPILE_CACHE entry. In that case the reused callable still closes over the first task’s unshared modules/parameters, so the second task runs the wrong weights instead of just missing the cache. Either derive the key from all shared parameter objects or fail fast unless the whole descriptor and fitting stack is shared.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/model/sezm_model.py` around lines 465 - 483, The current
_sezm_structure_key function risks returning identical keys for different tasks
because it only samples the first child of descriptor and fitting_net; update it
to either (preferred) derive the key from all shared parameter objects (e.g.,
use a frozenset of ids for all parameters that are shared between tasks) or
(safer/alternative) detect whether the entire descriptor and fitting_net stacks
are fully shared and raise an error if not, so partial sharing cannot collapse
different tasks into the same _SEZM_COMPILE_CACHE entry; locate the logic in
_sezm_structure_key and use SeZMModel.descriptor / SeZMModel.fitting_net plus
the model.share_params semantics to compute the full parameter-based key or to
assert full-stack sharing before returning a key for cache reuse.
| the model in its original state after tracing. | ||
| """ | ||
| saved: dict[str, torch.Tensor | None] = {} | ||
| for name, val in zip(task_buf_names, vals): |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Show the exact code around the referenced line
sed -n '1755,1790p' deepmd/pt/model/model/sezm_model.py | nl -ba | sed -n '1,80p'
# 2) Check whether ruff is available, then run ruff check for the file
if ! command -v ruff >/dev/null 2>&1; then
pip install -q ruff
fi
ruff --version
ruff check deepmd/pt/model/model/sezm_model.py || true
# 3) (Optional) see if there are other zip(... ) patterns that might also be relevant
rg -n "zip\(" deepmd/pt/model/model/sezm_model.py | head -n 50Repository: deepmodeling/deepmd-kit
Length of output: 110
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="deepmd/pt/model/model/sezm_model.py"
START=1755
END=1790
# 1) Show exact code around the referenced line (cat -n works without nl)
sed -n "${START},${END}p" "$FILE" | cat -n
# 2) Run ruff check (install if needed)
if ! command -v ruff >/dev/null 2>&1; then
pip install -q ruff
fi
ruff --version
ruff check "$FILE" || true
# 3) Find all zip( occurrences in the file (quick scan)
rg -n "zip\(" "$FILE" || trueRepository: deepmodeling/deepmd-kit
Length of output: 24126
Make the promoted-buffer zip strict.
- Prevents silent truncation/mismatched placeholder mapping if
task_buf_namesandvalsever differ in length. - Fixes Ruff B905 (
zip()withoutstrict=) indeepmd/pt/model/model/sezm_model.py.
Suggested change
- for name, val in zip(task_buf_names, vals):
+ for name, val in zip(task_buf_names, vals, strict=True):As per coding guidelines, **/*.py: install linter and run ruff check . and ruff format . before committing changes or CI will fail.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| for name, val in zip(task_buf_names, vals): | |
| for name, val in zip(task_buf_names, vals, strict=True): |
🧰 Tools
🪛 Ruff (0.15.15)
[warning] 1772-1772: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/model/sezm_model.py` at line 1772, The loop using
zip(task_buf_names, vals) can silently truncate if the iterables differ; change
it to zip(task_buf_names, vals, strict=True) to make mismatched lengths raise an
error. Locate the loop that iterates over task_buf_names and vals (the line
currently written as "for name, val in zip(task_buf_names, vals):") and update
it to include strict=True so any length mismatch is detected immediately.
Summary by CodeRabbit
Bug Fixes
Performance
Changes