
fix: try fix dpa4 compile #5483

Open
anyangml wants to merge 8 commits into deepmodeling:master from anyangml:fix/dpa4-multitask-compile

Conversation

@anyangml
Collaborator

@anyangml anyangml commented Jun 1, 2026

Summary by CodeRabbit

  • Bug Fixes

    • Fixed edge-list building behavior for improved consistency
  • Performance

    • Optimized multi-task model compilation efficiency through shared caching mechanisms
  • Changes

    • Updated model export signature to consistently include charge_spin parameter

Copilot AI review requested due to automatic review settings June 1, 2026 08:43
@anyangml anyangml marked this pull request as draft June 1, 2026 08:43
@dosubot dosubot Bot added the bug label Jun 1, 2026
@github-actions github-actions Bot added the Python label Jun 1, 2026
@coderabbitai
Contributor

coderabbitai Bot commented Jun 1, 2026

📝 Walkthrough

SeZMModel's compiled energy path now shares compiled callables across tasks whose descriptor and fitting modules are compatible. Per-task buffers are promoted to FX placeholders and passed as runtime inputs, which reduces the number of compilations. Trace shapes use prime-sized dimensions so that symbolic dimensions do not alias static buffer dimensions. The edge-list builder now appends one dummy edge instead of two. Export signatures include charge_spin in the 7-tuple. Tests confirm callable reuse and the new edge counts.
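The prime-dimension idea can be illustrated without PyTorch. The following is a minimal sketch, not the PR's code: the helper names `is_prime`, `next_free_prime`, `pick_trace_dims`, and `pad_or_trim` are hypothetical, and the forbidden sizes are made up. It picks distinct primes that avoid a set of static sizes, then pads or trims a list by repeating its last element, mirroring the "duplicate the last slice" behaviour described in the walkthrough.

```python
def is_prime(n: int) -> bool:
    """Trial-division primality test; adequate for small trace sizes."""
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


def next_free_prime(start: int, forbidden: set[int]) -> int:
    """Smallest prime >= start that is not in `forbidden`."""
    n = max(start, 2)
    while not is_prime(n) or n in forbidden:
        n += 1
    return n


def pick_trace_dims(n_dims: int, forbidden: set[int], start: int = 2) -> list[int]:
    """Pick `n_dims` distinct primes, none of which appears in `forbidden`."""
    taken = set(forbidden)
    dims = []
    for _ in range(n_dims):
        p = next_free_prime(start, taken)
        dims.append(p)
        taken.add(p)  # later picks must also differ from earlier ones
    return dims


def pad_or_trim(rows: list, size: int) -> list:
    """Trim to `size`, or pad by repeating the last row (rows must be non-empty)."""
    if len(rows) >= size:
        return rows[:size]
    return rows + [rows[-1]] * (size - len(rows))


# Hypothetical static sizes that trace dimensions must not coincide with.
static_sizes = {2, 3, 5, 128}
nframes, nloc, nall = pick_trace_dims(3, static_sizes)
padded = pad_or_trim([1, 2, 3], 5)
trimmed = pad_or_trim([1, 2, 3], 2)
```

Because each chosen size is distinct and absent from the forbidden set, a tracer that specialises on equal sizes has no accidental equality to latch onto in this toy setting.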

Changes

Multi-task compile cache with shared graphs and export adjustments

| Layer | File(s) | Summary |
|---|---|---|
| Cache infrastructure and helper functions | deepmd/pt/model/model/sezm_model.py | Module-level caches keyed by (structure_key, compile_mode) enable compiled-graph reuse. Helpers detect task-specific buffers (out_bias, out_std, bias_atom_e, case_embd), derive structure keys from first named child modules, and fetch current per-instance buffer values. |
| Trace shape hardening with prime dimensions | deepmd/pt/model/model/sezm_model.py | New utilities compute collision-free prime dimensions for frame count, local rows, and all atoms. Trace inputs are padded or trimmed to these prime sizes by duplicating the last slice, avoiding symbolic-dimension aliasing with static buffer dimensions. |
| trace_and_compile refactor with buffer detection and patching | deepmd/pt/model/model/sezm_model.py | Method checks module-level shared cache for early-return hit, computes structure/compile keys, detects promoted buffer names, and defines patch/restore closures that temporarily swap buffers into the model during tracing so the FX graph reads them as live inputs instead of constants. |
| Compute function closures with patching and prime-based shapes | deepmd/pt/model/model/sezm_model.py | Both coordinate-correction code paths patch task buffers at trace time with finally-block restore. The extended_coord_corr path additionally computes forbidden dimensions from record size, selects prime frame/row/all dimensions, and constructs padded/trimmed trace inputs with clamped indices. |
| Trace input construction with promoted buffer tensors | deepmd/pt/model/model/sezm_model.py | Trace input argument list appends per-task buffer tensor values after fixed tensor args so make_fx creates separate FX placeholders for each promoted element rather than baking them as constants. |
| Cache population and compiler tuning | deepmd/pt/model/model/sezm_model.py | After torch.compile, populates both shared module-level and per-instance caches with compiled callable and buffer-order metadata. Removes prior barrier assumptions with notes on distributed deadlock risk, and reduces compile_dens max_fusion_size. |
| Runtime buffer values in compiled forward path | deepmd/pt/model/model/sezm_model.py | During compiled ener forward, reads current per-task promoted buffer values for the active compile slot, then supplies them as extra positional varargs to the compiled callable, eliminating per-task recompilation. |
| Export path charge_spin threading | deepmd/pt/model/model/sezm_model.py | forward_common_lower_exportable adjusts traced signature to include charge_spin in the 7-tuple (converting None when needed), matching the runtime freeze pipeline's expected signature shape. |
| Multitask compile cache test coverage | source/tests/pt/model/test_sezm_model.py | Test assertions verify that compiled callables stored in different branches' per-instance caches are the same object (not distinct), confirming shared callable reuse while branch-specific dict instances remain different. |

Edge list simplification from E+2 to E+1

| Layer | File(s) | Summary |
|---|---|---|
| Edge list implementation and documentation | deepmd/pt/model/model/sezm_model.py | Edge compaction now appends exactly one masked dummy edge with updated pad-index construction. Index arithmetic for coordinate gradients derives dst_actual from neighbor_flat shape. Docstrings and return-shape documentation updated from E+2 to E+1 conventions across multiple locations. |
| Edge list test expectations | source/tests/pt/model/test_sezm_model.py | Fixed-edge-geometry test assertions updated to expect single padded tail element instead of two, aligning test validation with the E+1 dummy-edge padding policy. |
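The E+1 convention can be illustrated with plain lists. This is a hedged sketch only: `build_edge_list` is a hypothetical helper, and using `(0, 0)` as the dummy edge is an assumption made for the example rather than something the page confirms. It compacts the edges selected by a mask and always appends one masked dummy edge, so the output is never empty even when no edge is valid.

```python
def build_edge_list(src: list[int], dst: list[int], mask: list[bool]):
    """Compact masked edges, then append exactly one masked dummy edge."""
    if not (len(src) == len(dst) == len(mask)):
        raise ValueError("src, dst and mask must have equal length")
    edges = [(s, d) for s, d, keep in zip(src, dst, mask) if keep]
    valid = [True] * len(edges)
    edges.append((0, 0))  # dummy edge; the index values are placeholders
    valid.append(False)   # masked out, so it contributes nothing downstream
    return edges, valid


edges, valid = build_edge_list([0, 1, 2], [1, 2, 0], [True, False, True])
empty_edges, empty_valid = build_edge_list([], [], [])
```

With E valid edges the result has E+1 entries, and the empty case still yields one entry, which is the property the single-dummy policy is meant to guarantee.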

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5457: Both PRs refactor the FX/compiled forward and trace paths to enable safe multi-task reuse by keying compiled-graph caches on model structure identity and promoting task-specific buffers to dynamic inputs/varargs during tracing.

Suggested reviewers

  • wanghan-iapcm
  • njzjz-bot
  • njzjz
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Title check | ❓ Inconclusive | The title 'fix: try fix dpa4 compile' is vague and generic, using non-descriptive language ('try fix') that doesn't convey meaningful information about the actual changes. | Replace with a more specific title that describes the main change, such as 'fix: implement multi-task compile sharing for SeZM models' or 'fix: optimize SeZM compiled path for shared descriptor+fitting structures'. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 82.61% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


try:
    val = getattr(fitting, aname, None)
    if val is not None and torch.is_tensor(val):
        names.append(_FIT_ATTR_PREFIX + aname)
except AttributeError:
    pass
Contributor

Copilot AI left a comment


Pull request overview

This PR attempts to improve/repair the PyTorch-compiled execution path for the SeZM/DPA4 model, primarily by reducing recompiles/OOM in multi-task setups and addressing symbolic-shape tracing issues in make_fx.

Changes:

  • Add module-level compile sharing and promote selected per-task buffers (e.g., out_bias, bias_atom_e, case_embd) as FX inputs to enable compiled-graph reuse across shared-parameter tasks.
  • Add additional symbolic-shape anti-aliasing logic for trace inputs and temporarily disable ShapeEnv duck sizing during tracing.
  • Change edge-list construction to append a single masked dummy edge (instead of two) and adjust related documentation/behavior.


Comment thread deepmd/pt/model/model/sezm_model.py
Comment on lines 2181 to 2184
aparam: torch.Tensor | None = None,
charge_spin: torch.Tensor | None = None,
*,
do_atomic_virial: bool = False,
charge_spin: torch.Tensor | None = None,
) -> torch.nn.Module:
Comment thread deepmd/pt/model/model/sezm_model.py Outdated
Comment on lines +2390 to 2394
# === Step 3. Compact edges + append one masked dummy ===
# NOTE: Always append exactly one masked dummy edge.
# ``torch.nonzero(edge_mask_actual)`` produces a data-dependent
# number of valid edges, which can be zero on sparse or
# single-type systems. make_fx cannot trace an
@codecov

codecov Bot commented Jun 1, 2026

Codecov Report

❌ Patch coverage is 91.19171% with 17 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.37%. Comparing base (967e525) to head (52bcebc).
⚠️ Report is 4 commits behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| deepmd/pt/model/model/sezm_model.py | 91.19% | 17 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5483      +/-   ##
==========================================
+ Coverage   81.34%   81.37%   +0.02%     
==========================================
  Files         868      868              
  Lines       96373    96739     +366     
  Branches     4233     4233              
==========================================
+ Hits        78399    78725     +326     
- Misses      16675    16714      +39     
- Partials     1299     1300       +1     


@anyangml anyangml marked this pull request as ready for review June 2, 2026 10:29
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/model/model/sezm_model.py (1)

2735-2744: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Clear the shared compile cache when the ener head is reset.

This branch clears only self.compiled_core_compute_cache, but Line 1721 can immediately repopulate it from _SEZM_COMPILE_CACHE when the structure key is unchanged. That bypasses the retrace promised by this method and can resurrect a callable traced against the pre-reset head.

Suggested change
         else:
+            structure_key = _sezm_structure_key(self)
+            stale_keys = [
+                key
+                for key in _SEZM_COMPILE_CACHE
+                if key[: len(structure_key)] == structure_key
+            ]
+            for key in stale_keys:
+                _SEZM_COMPILE_CACHE.pop(key, None)
+            _SEZM_TASK_BUF_ORDER.pop(structure_key, None)
             self._core_compute_pending_compile_t0 = None
             self._core_compute_pending_compile_key = None
             # Drop every compile slot so the next forward retraces against the
             # reinitialised fitting head.
             self.compiled_core_compute_cache.clear()
+            self._task_buf_order_cache.clear()
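The prefix-based invalidation in the suggested change can be tried in isolation. The sketch below is a toy under assumptions: the dictionaries stand in for `_SEZM_COMPILE_CACHE` and `_SEZM_TASK_BUF_ORDER`, the cached values are placeholder strings, and the key layout `structure_key + (compile_mode,)` is assumed from the slicing in the suggestion rather than confirmed by the page.

```python
# Hypothetical shared cache; keys are assumed to be structure_key + (compile_mode,).
shared_cache = {
    (11, 22, "default"): "callable_A_default",
    (11, 22, "max-autotune"): "callable_A_tuned",
    (33, 44, "default"): "callable_B_default",
}
buf_order = {(11, 22): ["out_bias"], (33, 44): ["out_bias"]}


def invalidate(structure_key: tuple) -> int:
    """Drop every shared entry whose key starts with `structure_key`."""
    n = len(structure_key)
    # Collect first, then pop, so the dict is not mutated while iterating.
    stale = [key for key in shared_cache if key[:n] == structure_key]
    for key in stale:
        shared_cache.pop(key, None)
    buf_order.pop(structure_key, None)
    return len(stale)


removed = invalidate((11, 22))
```

Only the entries for the reset structure are dropped; other structure groups keep their compiled callables, which is less disruptive than clearing the whole shared cache.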
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/sezm_model.py` around lines 2735 - 2744, The ener-head
reset branch currently clears only instance cache (compiled_core_compute_cache)
but leaves the shared cache (_SEZM_COMPILE_CACHE) intact so a subsequent lookup
(see code that repopulates from _SEZM_COMPILE_CACHE using the structure key) can
resurrect callables traced against the old head; update the else branch that
resets the ener head to also invalidate the shared compile cache for the same
structure key: when you set _core_compute_pending_compile_key to None and call
compiled_core_compute_cache.clear(), also remove any entries in
_SEZM_COMPILE_CACHE that correspond to the previous structure key (or clear the
shared cache entirely) so retracing is forced (operate on the same key variable
used to populate _SEZM_COMPILE_CACHE in the forward path).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b86f7a23-46a3-4ec5-b8a9-d4a97ba0f001

📥 Commits

Reviewing files that changed from the base of the PR and between 967e525 and 52bcebc.

📒 Files selected for processing (2)
  • deepmd/pt/model/model/sezm_model.py
  • source/tests/pt/model/test_sezm_model.py

Comment on lines +465 to +483
def _sezm_structure_key(model: SeZMModel) -> tuple[int, ...]:
"""Return a key that is equal iff two SeZMModel instances can share a compiled graph.

After ``share_params``, the descriptor and fitting-net module objects
themselves remain *different* Python objects per task; only their
*submodules* (``_modules`` dict entries) are replaced with shared
references. Using ``id(descriptor)`` or ``id(fitting_net)`` would
therefore always differ between tasks and defeat the cache.

Fix: use the id of the *first named child* of each module. After
``share_params(level=0)``, those children are the same Python objects
for all tasks in the same structure group, giving matching keys.

NOTE: only the FIRST child is sampled, assuming "first child shared =>
whole module shared" (true for level=0). Under ``share_params(level=1)``
only ``type_embedding`` is shared; if it is the first child, two tasks
whose other descriptor weights differ would collapse to the same key and
wrongly reuse one compiled graph. If level=1 + compile is ever used, key
on all param ids instead, e.g. ``frozenset(id(p) for p in desc.parameters())``.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Guard shared compile reuse to fully shared descriptor/fitting stacks.

This key samples only the first descriptor/fitting child, so a partially shared setup can collapse two different tasks into the same _SEZM_COMPILE_CACHE entry. In that case the reused callable still closes over the first task’s unshared modules/parameters, so the second task runs the wrong weights instead of just missing the cache. Either derive the key from all shared parameter objects or fail fast unless the whole descriptor and fitting stack is shared.
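The collision described here can be reproduced with toy objects. This is a hedged sketch, not the PR's code: `Leaf`, `Stack`, `first_child_key`, and `all_children_key` are hypothetical, and the child name `other_block` is invented for the example. The docstring suggests keying on parameter ids (`frozenset(id(p) for p in desc.parameters())`); the toy keys on child ids instead, which shows the same effect without PyTorch.

```python
class Leaf:
    """Stand-in for a submodule that owns weights."""

    def __init__(self, name: str):
        self.name = name


class Stack:
    """Stand-in for a descriptor: an ordered mapping of child modules."""

    def __init__(self, children: dict[str, Leaf]):
        self.children = children


def first_child_key(stack: Stack) -> tuple[int, ...]:
    """Key that samples only the first child, as in the code under review."""
    first = next(iter(stack.children.values()))
    return (id(first),)


def all_children_key(stack: Stack) -> tuple[int, ...]:
    """Key that covers every child, so partial sharing yields distinct keys."""
    return tuple(id(child) for child in stack.children.values())


# Partial sharing: only the first child is the same object in both tasks.
shared_embedding = Leaf("type_embedding")
desc_a = Stack({"type_embedding": shared_embedding, "other_block": Leaf("a")})
desc_b = Stack({"type_embedding": shared_embedding, "other_block": Leaf("b")})

collides = first_child_key(desc_a) == first_child_key(desc_b)
distinct = all_children_key(desc_a) != all_children_key(desc_b)
```

Under partial sharing the first-child key matches for both tasks even though their remaining children differ, whereas the full key separates them and would cause a cache miss instead of a wrong-weights hit.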

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/sezm_model.py` around lines 465 - 483, The current
_sezm_structure_key function risks returning identical keys for different tasks
because it only samples the first child of descriptor and fitting_net; update it
to either (preferred) derive the key from all shared parameter objects (e.g.,
use a frozenset of ids for all parameters that are shared between tasks) or
(safer/alternative) detect whether the entire descriptor and fitting_net stacks
are fully shared and raise an error if not, so partial sharing cannot collapse
different tasks into the same _SEZM_COMPILE_CACHE entry; locate the logic in
_sezm_structure_key and use SeZMModel.descriptor / SeZMModel.fitting_net plus
the model.share_params semantics to compute the full parameter-based key or to
assert full-stack sharing before returning a key for cache reuse.

the model in its original state after tracing.
"""
saved: dict[str, torch.Tensor | None] = {}
for name, val in zip(task_buf_names, vals):

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 1) Show the exact code around the referenced line
sed -n '1755,1790p' deepmd/pt/model/model/sezm_model.py | nl -ba | sed -n '1,80p'

# 2) Check whether ruff is available, then run ruff check for the file
if ! command -v ruff >/dev/null 2>&1; then
  pip install -q ruff
fi

ruff --version
ruff check deepmd/pt/model/model/sezm_model.py || true

# 3) (Optional) see if there are other zip(... ) patterns that might also be relevant
rg -n "zip\(" deepmd/pt/model/model/sezm_model.py | head -n 50

Repository: deepmodeling/deepmd-kit

Length of output: 110


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="deepmd/pt/model/model/sezm_model.py"
START=1755
END=1790

# 1) Show exact code around the referenced line (cat -n works without nl)
sed -n "${START},${END}p" "$FILE" | cat -n

# 2) Run ruff check (install if needed)
if ! command -v ruff >/dev/null 2>&1; then
  pip install -q ruff
fi

ruff --version
ruff check "$FILE" || true

# 3) Find all zip( occurrences in the file (quick scan)
rg -n "zip\(" "$FILE" || true

Repository: deepmodeling/deepmd-kit

Length of output: 24126


Make the promoted-buffer zip strict.

  • Prevents silent truncation/mismatched placeholder mapping if task_buf_names and vals ever differ in length.
  • Fixes Ruff B905 (zip() without strict=) in deepmd/pt/model/model/sezm_model.py.
Suggested change
-            for name, val in zip(task_buf_names, vals):
+            for name, val in zip(task_buf_names, vals, strict=True):

As per coding guidelines, **/*.py: install linter and run ruff check . and ruff format . before committing changes or CI will fail.

🧰 Tools
🪛 Ruff (0.15.15)

[warning] 1772-1772: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/sezm_model.py` at line 1772, The loop using
zip(task_buf_names, vals) can silently truncate if the iterables differ; change
it to zip(task_buf_names, vals, strict=True) to make mismatched lengths raise an
error. Locate the loop that iterates over task_buf_names and vals (the line
currently written as "for name, val in zip(task_buf_names, vals):") and update
it to include strict=True so any length mismatch is detected immediately.


3 participants