Warp determinism #1355
Draft: mmacklin wants to merge 72 commits into NVIDIA:main from mmacklin:warp-deterministic
+7,808
−99
Changes from 4 commits
Commits
968c7bb
Add deterministic execution mode for atomic operations
d5ef316
Fix deterministic launch edge cases
5ec9b25
Document deterministic options
72e8e3f
Add deterministic capture controls
48e7207
Support composite deterministic atomics
3170255
Refine deterministic record sizing
50b7ebc
Fix deterministic cache-hit launches
cc6720d
Benchmark deterministic graph replay
5b5038f
Optimize scalar deterministic reduction
e34a93c
Add deterministic mode levels
6974788
Support deterministic functions
mmacklin 3d6d4ba
Clean deterministic ABI naming
mmacklin be06d95
Support deterministic struct targets
eric-heiden b010ea0
Address deterministic review comments
eric-heiden a925a72
Support deterministic sliced atomics
eric-heiden ff873d6
Update deterministic changelog ref
eric-heiden 75c886e
Resolve deterministic merge conflicts
eric-heiden 6e6c9e6
Resolve changelog merge conflict
eric-heiden 8d74aa3
Resolve changelog conflict
eric-heiden 218af6f
Fix deterministic counter edge cases
eric-heiden 3833309
Avoid deterministic merge conflicts
eric-heiden 768c669
Document deterministic helper flow
eric-heiden fdf8e8e
Fix deterministic option cache toggles
eric-heiden d2f6bd6
Guard integer atomics in counter passes
eric-heiden d8ce57d
Clarify deterministic comments
eric-heiden a7d8c1f
Address deterministic review cleanup
eric-heiden 145bf6e
Address deterministic capture edge cases
eric-heiden 28380ab
Avoid graph replay overflow sync
eric-heiden d079bf4
Clean up deterministic review nits
eric-heiden 381a4aa
Avoid import-time benchmark options
eric-heiden c506742
Fix helper counter phase suppression
eric-heiden ce4ce4e
Clarify deterministic graph limits
eric-heiden 85fac0c
Guard deterministic helper stores
eric-heiden 2367ce1
Regenerate config docs
eric-heiden 731ed6c
Address deterministic review feedback
eric-heiden f7fc45b
Guard deterministic capture limits
eric-heiden fe53ed4
Fix deterministic graph workspaces
eric-heiden 2c38196
Update deterministic limitations doc
eric-heiden 93a1b6d
Avoid sorting unused scatter records
eric-heiden adbb878
Fix deterministic enum test Sonar hotspot
eric-heiden a7c9335
Fix block-dependent tile codegen reuse
eric-heiden c7f4f96
Clarify deterministic scatter capture limit
eric-heiden 659876e
Reduce launch benchmark overhead
eric-heiden 771578a
Tighten deterministic atomic validation
eric-heiden ebe317a
Restore direct CUDA launch path
eric-heiden e2f5977
Avoid launch-time option resolution
eric-heiden 4f0bfca
Document deterministic execution
eric-heiden 803984c
Clarify deterministic counter slots
eric-heiden 945f829
Reject counter graph capture
eric-heiden 0b8fe83
Refocus determinism quick start
eric-heiden b3a79e0
Clarify boolean deterministic aliases
eric-heiden dc17c2c
Fix deterministic counter ASV
eric-heiden 4b1d244
Fix deterministic benchmark capture
eric-heiden 5f7ea05
Clarify determinism docs
eric-heiden 18a1b51
Stabilize deterministic ASV params
eric-heiden c09fc8f
Clarify deterministic bool aliases
eric-heiden cd882e1
Add manual determinism docs
eric-heiden 0abd155
Clarify deterministic reduction cost
eric-heiden 40b670f
Expand determinism performance docs
eric-heiden 4a153c6
Clarify determinism benchmark timing
eric-heiden c4718cc
Optimize deterministic counter writeback
eric-heiden edef2aa
Emit block tile strides symbolically
eric-heiden 2e86921
Support deterministic counters in graphs
eric-heiden bc8a3ed
Support indexed deterministic counters
eric-heiden 53755ea
Fix custom adjoint nondeterministic mode
eric-heiden e557947
Address deterministic review comments
eric-heiden be260a9
Remove unused deterministic counter export
eric-heiden 2340d94
Clarify custom adjoint determinism
eric-heiden d4f5d83
Fix deterministic counter review nits
eric-heiden 5d663b8
Guard APIC capture checks
eric-heiden 39f41fc
Adapt deterministic sync to GitHub base
eric-heiden d0284f4
Align APIC capture guard with main
eric-heiden
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,223 @@ | ||
| # Deterministic Execution Mode | ||
|
|
||
| **Status**: Implemented | ||
|
|
||
| ## Motivation | ||
|
|
||
| GPU atomic operations on floating-point arrays are inherently non-deterministic: | ||
| threads execute in unpredictable order, and since float addition is | ||
| non-associative, different execution orderings produce different rounding, | ||
| yielding different results each run. This also applies to counter/slot-allocation | ||
| patterns (``slot = wp.atomic_add(counter, 0, 1)``) where the thread-to-slot | ||
| assignment varies across runs, causing downstream writes to differ. | ||
|
|
||
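The non-associativity claim above can be checked without a GPU. The following is a minimal sketch in plain Python (IEEE 754 doubles, not Warp code): it sums the same three values with two different groupings, which is what two different thread orderings amount to for an atomic accumulation.

```python
# Same three addends, two groupings. Each intermediate sum is rounded,
# so the grouping (i.e. the execution order) changes the final bits.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right == right_to_left)  # False
print(left_to_right, right_to_left)    # 0.6000000000000001 0.6
```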
| Customers need bit-exact reproducibility for debugging, regression testing, and | ||
| certification workflows. The manual workaround --- rewriting algorithms to use | ||
| two-pass count-scan-write patterns or sorted reductions --- is painful and | ||
| error-prone. Users want a simple toggle that makes their existing code | ||
| deterministic without algorithm rewrites. | ||
|
|
||
| ## Requirements | ||
|
|
||
| | ID | Requirement | Priority | Notes | | ||
| | --- | --- | --- | --- | | ||
| | R1 | ``wp.config.deterministic = True`` makes float atomic accumulations bit-exact reproducible across runs | Must | Core value proposition | | ||
| | R2 | Counter/allocator pattern (``slot = wp.atomic_add(counter, 0, 1)``) produces deterministic slot assignments | Must | Common in compaction, particle emission | | ||
| | R3 | Both patterns work in the same kernel simultaneously | Must | Real workloads mix accumulation and allocation | | ||
| | R4 | Integer atomics with unused return values incur no overhead | Must | Already associative+commutative | | ||
| | R5 | CPU execution unaffected (already sequential/deterministic) | Must | Zero overhead on CPU | | ||
| | R6 | Per-module and per-kernel granularity via ``module_options`` | Should | Allows selective opt-in | | ||
| | R7 | Backward pass (autodiff) gradient accumulation is also deterministic | Should | Adjoint atomics are Pattern A | | ||
| | R8 | Multiple target arrays in one kernel each get independent buffers | Must | Real kernels write to N arrays | | ||
|
|
||
| **Non-goals**: | ||
| - ``atomic_cas``/``atomic_exch`` determinism (inherently order-dependent). | ||
| - Tile-level atomic operations (``tile_atomic_add``). | ||
| - Kernels where counter contributions depend on scratch array writes within the | ||
| same kernel (Phase 0 suppresses all side effects; documented limitation). | ||
|
|
||
| ## User Configuration | ||
|
|
||
| Deterministic mode can be enabled at three scopes (global, per shared module, per kernel), with a separate global flag for diagnostics: | ||
|
|
||
| - **Global**: set ``wp.config.deterministic = True`` before module compilation. | ||
| - **Global diagnostics**: set ``wp.config.deterministic_debug = True`` to emit | ||
| debug diagnostics for deterministic scatter overflow. | ||
| - **Per shared module**: call ``wp.set_module_options({"deterministic": True})`` | ||
| in the Python module that defines the kernels, just like other module-level | ||
| options such as ``enable_backward``. | ||
| - **Per kernel**: use ``@wp.kernel(deterministic=True)`` and optionally set a | ||
| minimum per-target scatter buffer capacity with | ||
| ``deterministic_capacity=...``. | ||
|
|
||
| Like ``enable_backward``, the setting participates in module compilation and | ||
| hashing. A kernel defined in a shared module inherits that module's | ||
| ``deterministic`` option; a unique-module kernel can override it independently. | ||
|
|
||
| ## Supported Operators | ||
|
|
||
| Deterministic mode currently handles these atomic builtins: | ||
|
|
||
| - ``atomic_add`` | ||
| - ``atomic_sub`` | ||
| - ``atomic_min`` | ||
| - ``atomic_max`` | ||
|
|
||
| Handling depends on how the atomic is used: | ||
|
|
||
| - **Pattern A: accumulation, return value unused** | ||
| Examples: ``wp.atomic_add(arr, i, value)``, ``arr[i] += value``. | ||
| Floating-point ``add/sub/min/max`` are redirected through scatter-sort-reduce. | ||
| Integer ``add/sub/min/max`` with unused return values are left on the normal | ||
| atomic path because the final value is already deterministic. | ||
| - **Pattern B: counter / allocator, return value consumed** | ||
| Example: ``slot = wp.atomic_add(counter, 0, 1)``. | ||
| This is handled with the two-pass count-scan-execute path. | ||
|
|
||
| Operators that are not supported by deterministic mode: | ||
|
|
||
| - ``atomic_cas`` | ||
| - ``atomic_exch`` | ||
| - Tile atomics such as ``tile_atomic_add`` | ||
|
|
||
| Bitwise integer atomics (``atomic_and``, ``atomic_or``, ``atomic_xor``) are not | ||
| transformed because their final results are already deterministic for the | ||
| unused-return case. | ||
|
|
||
| ## Design | ||
|
|
||
| ### Approach | ||
|
|
||
| Two distinct atomic usage patterns require two strategies, both transparent to | ||
| the user: | ||
|
|
||
| **Pattern A --- Scatter-Sort-Reduce** (accumulation, return value unused): | ||
|
|
||
| Instead of performing ``atomic_add`` in-place during kernel execution, each | ||
| thread writes a ``(sort_key, value)`` record to a temporary scatter buffer. The | ||
| sort key packs ``(dest_index << 32 | thread_id)``. After the kernel completes, | ||
| a CUB radix sort orders the fixed-capacity scatter buffer, then a custom CUDA | ||
| kernel walks the sorted records and accumulates each destination's values | ||
| left-to-right in a fixed (thread-id) order. Unused buffer slots are initialized | ||
| with an invalid sentinel key and sort to the end. Because the sort always | ||
| covers the full fixed-capacity buffer, no host-side scatter count readback is | ||
| needed, which keeps the path compatible with CUDA graph capture. | ||
|
|
||
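The scatter-sort-reduce idea can be illustrated on the host. This is a sketch in plain Python, not the CUB-based implementation from this PR: the function names are hypothetical, and it assumes one record per thread and thread ids that fit in 32 bits. Records arrive in an arbitrary order, get a packed key of the form `(dest_index << 32) | thread_id`, and are reduced in sorted order, so the arrival order no longer affects the result.

```python
import random


def scatter_sort_reduce(records, num_dest):
    """Reduce (dest_index, thread_id, value) records in a fixed order.

    `records` may arrive in any order; the packed key restores a canonical one.
    """
    # Pack destination into the high 32 bits and thread id into the low 32 bits.
    keyed = [((dest << 32) | tid, value) for dest, tid, value in records]
    # Sorting groups records by destination, then orders them by thread id.
    keyed.sort(key=lambda kv: kv[0])
    out = [0.0] * num_dest
    # Walk the sorted records and accumulate left to right.
    for key, value in keyed:
        out[key >> 32] += value
    return out


def run_with_arrival_order(seed, num_threads=64, num_dest=4):
    """Simulate one launch whose threads finish in a seed-dependent order."""
    records = [(tid % num_dest, tid, 1.0 / (tid + 1)) for tid in range(num_threads)]
    random.Random(seed).shuffle(records)
    return scatter_sort_reduce(records, num_dest)


# Different arrival orders produce bit-identical sums.
print(run_with_arrival_order(0) == run_with_arrival_order(1))
```

The sketch omits the fixed-capacity buffer and sentinel keys; it only shows why sorting on the packed key fixes the summation order.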
| **Pattern B --- Two-Pass Execution** (counter/allocator, return value used): | ||
|
|
||
| The kernel runs twice: | ||
| 1. *Phase 0 (counting)*: The kernel executes with all side effects suppressed. | ||
| Counter atomics record per-thread contributions to a scratch array instead | ||
| of performing the actual atomic. | ||
| 2. *Prefix sum*: ``wp.utils.array_scan(contrib, prefix, inclusive=False)`` | ||
| computes deterministic per-thread offsets. | ||
|
|
||
| 3. *Phase 1 (execution)*: The kernel re-executes. Counter atomics return the | ||
| deterministic offset from the prefix sum. All other operations (including | ||
| Pattern A scatters) execute normally. | ||
|
|
||
| This mirrors the well-established count-scan-write pattern already used by | ||
| ``warp/_src/marching_cubes.py`` and the FEM geometry code, but applied | ||
| automatically. | ||
|
|
||
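The three steps above can be sketched as a sequential simulation. This is plain Python rather than generated kernel code, and `exclusive_scan` and `compact` are hypothetical names used only for illustration. The example is a stream compaction: each "thread" that passes the threshold would normally call `slot = wp.atomic_add(counter, 0, 1)`.

```python
def exclusive_scan(values):
    """Return per-element offsets: out[i] = sum(values[:i])."""
    out = []
    running = 0
    for v in values:
        out.append(running)
        running += v
    return out


def compact(data, threshold):
    """Count-scan-execute compaction with slots assigned in thread-id order."""
    # Phase 0 (counting): record each thread's counter contribution only.
    contrib = [1 if x > threshold else 0 for x in data]
    # Prefix sum: deterministic per-thread offsets.
    prefix = exclusive_scan(contrib)
    # Total count, i.e. the value the real counter should hold afterwards.
    total = sum(contrib)
    # Phase 1 (execution): each thread uses its precomputed slot.
    out = [None] * total
    for tid, x in enumerate(data):
        if contrib[tid]:
            out[prefix[tid]] = x
    return out, total


print(compact([5, 1, 7, 3, 9], threshold=4))  # ([5, 7, 9], 3)
```

Because slots come from the scan rather than from the order in which atomics land, the output layout depends only on thread ids.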
| ### Alternatives Considered | ||
|
|
||
| | Alternative | Why rejected | | ||
| | --- | --- | | ||
| | Fixed-point integer accumulation | Loses precision/range; not general for all float types | | ||
| | Per-thread output arrays | ``O(threads * output_size)`` memory; doesn't scale | | ||
| | Serialized atomics (mutex) | Prohibitively slow on GPU | | ||
| | Kahan compensated summation | Reduces error but does not guarantee determinism | | ||
| | Taint-based selective side-effect suppression for Phase 0 | Significant codegen complexity for an edge case; deferred to a future version | | ||
|
|
||
| ### Key Implementation Details | ||
|
|
||
| **Atomic classification** happens in ``Adjoint._emit_deterministic_atomic()`` | ||
| during the codegen build phase. The ``_det_in_assign`` flag on the ``Adjoint`` | ||
| (set by ``emit_Assign``, cleared after) distinguishes Pattern B (return consumed | ||
| in an assignment like ``slot = wp.atomic_add(...)``) from Pattern A (return | ||
| discarded, as in ``arr[i] += val`` or bare ``wp.atomic_add(...)``). Only | ||
| ``atomic_add``, ``atomic_sub``, ``atomic_min``, and ``atomic_max`` are | ||
| intercepted. Integer atomics with unused return values are skipped entirely | ||
| (already deterministic). | ||
|
|
||
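The classification rules in the paragraph above amount to a small decision table. The sketch below restates them in plain Python; it is not the body of ``Adjoint._emit_deterministic_atomic()``, the function name and return strings are hypothetical, and it does not model any further validation the code generator may perform.

```python
# Only these builtins are intercepted, per the paragraph above.
INTERCEPTED = {"atomic_add", "atomic_sub", "atomic_min", "atomic_max"}


def classify_atomic(op, is_float, return_used):
    """Map one atomic call site to the strategy described in this document."""
    if op not in INTERCEPTED:
        # e.g. atomic_cas, atomic_exch, bitwise atomics: left untouched.
        return "not transformed"
    if return_used:
        # slot = wp.atomic_add(counter, 0, 1)
        return "pattern B: two-pass counter"
    if is_float:
        # wp.atomic_add(arr, i, value) or arr[i] += value on float data
        return "pattern A: scatter-sort-reduce"
    # Integer accumulation with the return value unused.
    return "normal atomic (already deterministic)"


print(classify_atomic("atomic_add", is_float=True, return_used=False))
print(classify_atomic("atomic_add", is_float=False, return_used=True))
```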
| **CPU/CUDA dual compilation** is handled with ``#ifdef __CUDA_ARCH__`` guards | ||
| in the generated function body. On CUDA the scatter or phase-branching code | ||
| executes; on CPU the normal ``wp::atomic_add(...)`` call is emitted in the | ||
| ``#else`` branch. This is necessary because Warp generates a single function | ||
| body that compiles for both targets. | ||
|
|
||
| **Hidden kernel parameters** are appended to the CUDA kernel signature by | ||
| ``codegen_kernel()`` after the user arguments. Pattern B kernels get | ||
| ``_wp_det_phase``, ``_wp_det_contrib_N``, ``_wp_det_prefix_N``. Pattern A | ||
| targets get ``_wp_scatter_keys_N``, ``_wp_scatter_vals_N``, | ||
| ``_wp_scatter_ctr_N``, ``_wp_scatter_overflow_N``, ``_wp_scatter_cap_N``, and | ||
| ``_wp_det_debug``. The launch system | ||
| (``_launch_deterministic``) allocates these buffers and appends the | ||
| corresponding ctypes params to the launch args. | ||
|
|
||
| **Multiple reduction buffers**: each distinct ``(target array, value type, | ||
| reduction op)`` combination gets its own scatter buffer set. Multiple call | ||
| sites with the same target and reduction op share one buffer. The | ||
| ``DeterministicMeta`` dataclass on the kernel's ``Adjoint`` tracks all scatter | ||
| and counter targets discovered during codegen. | ||
|
|
||
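The buffer-sharing rule can be illustrated with a dictionary keyed on the triple named above. This is a sketch only: `assign_scatter_buffers` is a hypothetical helper, and the strings stand in for whatever the ``DeterministicMeta`` dataclass actually stores.

```python
def assign_scatter_buffers(call_sites):
    """Give each distinct (target, value type, reduction op) one buffer index.

    Call sites that repeat a triple reuse the index assigned the first time.
    """
    buffers = {}
    for target, value_type, op in call_sites:
        buffers.setdefault((target, value_type, op), len(buffers))
    return buffers


sites = [
    ("out", "float32", "add"),
    ("out", "float32", "add"),   # same target and op: shares buffer 0
    ("out", "float32", "max"),   # same target, different op: new buffer
    ("grad", "float32", "add"),  # different target: new buffer
]
print(assign_scatter_buffers(sites))
```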
| **Scatter capacity**: each scatter target uses a fixed-capacity buffer sized | ||
| from a code-generated lower bound (static records-per-thread analysis), with | ||
| ``deterministic_capacity`` acting as a per-target minimum capacity floor. On | ||
| overflow, new records are truncated, a device-side overflow flag is set, and | ||
| optional diagnostics may be emitted when ``wp.config.deterministic_debug`` is | ||
| enabled. | ||
|
|
||
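A host-side model of the capacity and overflow behavior described above might look like the following. It is a sketch under assumptions: the class name is hypothetical, `lower_bound` stands in for the code-generated estimate, and `capacity_floor` plays the role of ``deterministic_capacity``.

```python
class ScatterBuffer:
    """Fixed-capacity record buffer that truncates and flags on overflow."""

    def __init__(self, lower_bound, capacity_floor=0):
        # The user-supplied floor can only raise the capacity, never lower it.
        self.capacity = max(lower_bound, capacity_floor)
        self.records = []
        self.overflow = False

    def scatter(self, key, value):
        """Append a record, or drop it and set the overflow flag when full."""
        if len(self.records) >= self.capacity:
            self.overflow = True
            return False
        self.records.append((key, value))
        return True


buf = ScatterBuffer(lower_bound=2)
for i in range(3):
    buf.scatter(i, float(i))
print(len(buf.records), buf.overflow)  # 2 True
```

In this model a caller would check `overflow` after the launch; the document's equivalent is the device-side overflow flag plus the optional ``wp.config.deterministic_debug`` diagnostics.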
| **Counter total writeback**: after the prefix sum that follows Phase 0, the | ||
| launch system copies the total count (the sum of all per-thread contributions) | ||
| back to the actual counter array so user code that reads it post-launch sees | ||
| the correct value. | ||
|
|
||
| **Files added/modified**: | ||
|
|
||
| | File | Role | | ||
| | --- | --- | | ||
| | ``warp/config.py`` | ``deterministic`` global flag | | ||
| | ``warp/_src/deterministic.py`` | Dataclasses, buffer allocation, sort-reduce orchestration | | ||
| | ``warp/_src/codegen.py`` | Atomic classification, scatter/phase codegen, hidden kernel params | | ||
| | ``warp/_src/context.py`` | Module option, ``_launch_deterministic()`` orchestrator, ctypes bindings | | ||
| | ``warp/native/deterministic.h`` | Device-side ``wp::deterministic::scatter()`` template | | ||
| | ``warp/native/deterministic.cu`` | CUB radix sort + segmented reduce kernel | | ||
| | ``warp/native/deterministic.cpp`` | CPU stubs (linker satisfaction when CUDA unavailable) | | ||
|
|
||
| ## Testing Strategy | ||
|
|
||
| 23 tests in ``warp/tests/test_deterministic.py`` cover: | ||
|
|
||
| - **Bit-exact reproducibility** (Pattern A): launch the same kernel 10 times | ||
| with ``deterministic=True``, assert ``np.array_equal`` across all runs for | ||
| float32 scatter-add, ``+=`` syntax, float64, and atomic-sub. | ||
| - **Correctness** (Pattern A): compare GPU deterministic output against a | ||
| sequential CPU reference within ``rtol=1e-4``. | ||
| - **Multiple arrays**: kernel that atomically adds to three different output | ||
| arrays simultaneously. | ||
| - **Mixed reduce ops on one array**: ``atomic_add`` and ``atomic_max`` targeting | ||
| the same destination array are reduced independently. | ||
| - **Multi-dimensional indexing**: 2D array ``atomic_add`` with row/col indices. | ||
| - **Scatter capacity accounting**: kernels with more than two deterministic | ||
| scatters per thread to the same target do not overflow a fixed heuristic | ||
| buffer. | ||
| - **Counter reproducibility** (Pattern B): ``slot = atomic_add(counter, 0, 1); | ||
| output[slot] = data[tid]`` produces identical output arrays across 10 runs. | ||
| - **Phase 0 side-effect suppression**: non-counter array writes are skipped in | ||
| the counting pass. | ||
| - **Counter correctness**: verifies counter value equals N and output is a | ||
| permutation of input. | ||
| - **Conditional counter**: stream compaction (only elements above threshold), | ||
| verifying correct count and reproducible output. | ||
| - **Mixed pattern**: both counter and accumulation in one kernel. | ||
| - **Integer passthrough**: integer ``atomic_add`` with unused return incurs no | ||
| transformation; result matches ``np.bincount``. | ||
| - **Per-module override**: ``@wp.kernel(module_options={"deterministic": True}, | ||
| module="unique")`` works with global config off. | ||
| - **Kernel decorator override**: ``@wp.kernel(deterministic=True)`` works with | ||
| global config off. | ||
| - **Recorded launch support**: ``wp.launch(..., record_cmd=True)`` works for | ||
| deterministic CUDA kernels. | ||
| - **Graph capture support**: deterministic scatter launches can be captured and | ||
| replayed with CUDA graphs. | ||
| - All tests run on both CPU and CUDA where applicable. Existing | ||
| ``test_atomic.py`` (158 tests) passes with zero regressions. | ||