Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
968c7bb
Add deterministic execution mode for atomic operations
Apr 10, 2026
d5ef316
Fix deterministic launch edge cases
Apr 10, 2026
5ec9b25
Document deterministic options
Apr 10, 2026
72e8e3f
Add deterministic capture controls
Apr 10, 2026
48e7207
Support composite deterministic atomics
Apr 10, 2026
3170255
Refine deterministic record sizing
Apr 12, 2026
50b7ebc
Fix deterministic cache-hit launches
Apr 12, 2026
cc6720d
Benchmark deterministic graph replay
Apr 12, 2026
5b5038f
Optimize scalar deterministic reduction
Apr 12, 2026
e34a93c
Add deterministic mode levels
Apr 13, 2026
6974788
Support deterministic functions
mmacklin Apr 20, 2026
3d6d4ba
Clean deterministic ABI naming
mmacklin Apr 21, 2026
be06d95
Support deterministic struct targets
eric-heiden May 1, 2026
b010ea0
Address deterministic review comments
eric-heiden May 5, 2026
a925a72
Support deterministic sliced atomics
eric-heiden May 8, 2026
ff873d6
Update deterministic changelog ref
eric-heiden May 8, 2026
75c886e
Resolve deterministic merge conflicts
eric-heiden May 8, 2026
6e6c9e6
Resolve changelog merge conflict
eric-heiden May 8, 2026
8d74aa3
Resolve changelog conflict
eric-heiden May 8, 2026
218af6f
Fix deterministic counter edge cases
eric-heiden May 8, 2026
3833309
Avoid deterministic merge conflicts
eric-heiden May 8, 2026
768c669
Document deterministic helper flow
eric-heiden May 8, 2026
fdf8e8e
Fix deterministic option cache toggles
eric-heiden May 8, 2026
d2f6bd6
Guard integer atomics in counter passes
eric-heiden May 8, 2026
d8ce57d
Clarify deterministic comments
eric-heiden May 8, 2026
a7d8c1f
Address deterministic review cleanup
eric-heiden May 8, 2026
145bf6e
Address deterministic capture edge cases
eric-heiden May 9, 2026
28380ab
Avoid graph replay overflow sync
eric-heiden May 9, 2026
d079bf4
Clean up deterministic review nits
eric-heiden May 9, 2026
381a4aa
Avoid import-time benchmark options
eric-heiden May 9, 2026
c506742
Fix helper counter phase suppression
eric-heiden May 9, 2026
ce4ce4e
Clarify deterministic graph limits
eric-heiden May 9, 2026
85fac0c
Guard deterministic helper stores
eric-heiden May 9, 2026
2367ce1
Regenerate config docs
eric-heiden May 9, 2026
731ed6c
Address deterministic review feedback
eric-heiden May 9, 2026
f7fc45b
Guard deterministic capture limits
eric-heiden May 9, 2026
fe53ed4
Fix deterministic graph workspaces
eric-heiden May 9, 2026
2c38196
Update deterministic limitations doc
eric-heiden May 9, 2026
93a1b6d
Avoid sorting unused scatter records
eric-heiden May 9, 2026
adbb878
Fix deterministic enum test Sonar hotspot
eric-heiden May 9, 2026
a7c9335
Fix block-dependent tile codegen reuse
eric-heiden May 9, 2026
c7f4f96
Clarify deterministic scatter capture limit
eric-heiden May 9, 2026
659876e
Reduce launch benchmark overhead
eric-heiden May 9, 2026
771578a
Tighten deterministic atomic validation
eric-heiden May 9, 2026
ebe317a
Restore direct CUDA launch path
eric-heiden May 9, 2026
e2f5977
Avoid launch-time option resolution
eric-heiden May 9, 2026
4f0bfca
Document deterministic execution
eric-heiden May 9, 2026
803984c
Clarify deterministic counter slots
eric-heiden May 9, 2026
945f829
Reject counter graph capture
eric-heiden May 9, 2026
0b8fe83
Refocus determinism quick start
eric-heiden May 9, 2026
b3a79e0
Clarify boolean deterministic aliases
eric-heiden May 9, 2026
dc17c2c
Fix deterministic counter ASV
eric-heiden May 9, 2026
4b1d244
Fix deterministic benchmark capture
eric-heiden May 9, 2026
5f7ea05
Clarify determinism docs
eric-heiden May 9, 2026
18a1b51
Stabilize deterministic ASV params
eric-heiden May 9, 2026
c09fc8f
Clarify deterministic bool aliases
eric-heiden May 9, 2026
cd882e1
Add manual determinism docs
eric-heiden May 9, 2026
0abd155
Clarify deterministic reduction cost
eric-heiden May 9, 2026
40b670f
Expand determinism performance docs
eric-heiden May 9, 2026
4a153c6
Clarify determinism benchmark timing
eric-heiden May 9, 2026
c4718cc
Optimize deterministic counter writeback
eric-heiden May 10, 2026
edef2aa
Emit block tile strides symbolically
eric-heiden May 10, 2026
2e86921
Support deterministic counters in graphs
eric-heiden May 10, 2026
bc8a3ed
Support indexed deterministic counters
eric-heiden May 10, 2026
53755ea
Fix custom adjoint nondeterministic mode
eric-heiden May 10, 2026
e557947
Address deterministic review comments
eric-heiden May 11, 2026
be260a9
Remove unused deterministic counter export
eric-heiden May 11, 2026
2340d94
Clarify custom adjoint determinism
eric-heiden May 11, 2026
d4f5d83
Fix deterministic counter review nits
eric-heiden May 11, 2026
5d663b8
Guard APIC capture checks
eric-heiden May 11, 2026
39f41fc
Adapt deterministic sync to GitHub base
eric-heiden May 11, 2026
d0284f4
Align APIC capture guard with main
eric-heiden May 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

### Added

- Add deterministic execution mode for atomic operations via `wp.config.deterministic = True`.
Floating-point atomic accumulations use a scatter-sort-reduce strategy for bit-exact
reproducibility across runs. Counter/allocator atomics (where the return value is used)
use automatic two-pass execution with prefix-sum-based slot assignment. Configurable at
the global, module, and kernel level.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
- Add double-precision (`wp.float64`) support to `warp.fem`.
Precision is selected via the geometry (e.g. `scalar_type=wp.float64` on grid constructors)
and propagated automatically to function spaces, quadrature, fields, and integration kernels
Expand Down
2 changes: 2 additions & 0 deletions build_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def main(argv: list[str] | None = None) -> int:
"native/texture.cpp",
"native/mathdx.cpp",
"native/coloring.cpp",
"native/deterministic.cpp",
"native/fastcall.cpp",
]
warp_cpp_paths = [os.path.join(build_path, cpp) for cpp in cpp_sources]
Expand All @@ -533,6 +534,7 @@ def main(argv: list[str] | None = None) -> int:
else:
cuda_sources = [
"native/bvh.cu",
"native/deterministic.cu",
"native/mesh.cu",
"native/sort.cu",
"native/hashgrid.cu",
Expand Down
223 changes: 223 additions & 0 deletions design/deterministic-execution.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# Deterministic Execution Mode

**Status**: Implemented

## Motivation

GPU atomic operations on floating-point arrays are inherently non-deterministic:
threads execute in unpredictable order, and since float addition is
non-associative, different execution orderings produce different rounding,
yielding different results each run. This also applies to counter/slot-allocation
patterns (``slot = wp.atomic_add(counter, 0, 1)``) where the thread-to-slot
assignment varies across runs, causing downstream writes to differ.

Customers need bit-exact reproducibility for debugging, regression testing, and
certification workflows. The manual workaround --- rewriting algorithms to use
two-pass count-scan-write patterns or sorted reductions --- is painful and
error-prone. Users want a simple toggle that makes their existing code
deterministic without algorithm rewrites.

## Requirements

| ID | Requirement | Priority | Notes |
| --- | --- | --- | --- |
| R1 | ``wp.config.deterministic = True`` makes float atomic accumulations bit-exact reproducible across runs | Must | Core value proposition |
| R2 | Counter/allocator pattern (``slot = wp.atomic_add(counter, 0, 1)``) produces deterministic slot assignments | Must | Common in compaction, particle emission |
| R3 | Both patterns work in the same kernel simultaneously | Must | Real workloads mix accumulation and allocation |
| R4 | Integer atomics with unused return values incur no overhead | Must | Already associative+commutative |
| R5 | CPU execution unaffected (already sequential/deterministic) | Must | Zero overhead on CPU |
| R6 | Per-module and per-kernel granularity via ``module_options`` | Should | Allows selective opt-in |
| R7 | Backward pass (autodiff) gradient accumulation is also deterministic | Should | Adjoint atomics are Pattern A |
| R8 | Multiple target arrays in one kernel each get independent buffers | Must | Real kernels write to N arrays |

**Non-goals**:
- ``atomic_cas``/``atomic_exch`` determinism (inherently order-dependent).
- Tile-level atomic operations (``tile_atomic_add``).
- Kernels where counter contributions depend on scratch array writes within the
same kernel (Phase 0 suppresses all side effects; documented limitation).

## User Configuration

Deterministic mode can be enabled at three scopes:

- **Global**: set ``wp.config.deterministic = True`` before module compilation.
- **Global diagnostics**: set ``wp.config.deterministic_debug = True`` to emit
debug diagnostics for deterministic scatter overflow.
- **Per shared module**: call ``wp.set_module_options({"deterministic": True})``
in the Python module that defines the kernels, just like other module-level
options such as ``enable_backward``.
- **Per kernel**: use ``@wp.kernel(deterministic=True)`` and optionally set a
minimum per-target scatter buffer capacity with
``deterministic_capacity=...``.

Like ``enable_backward``, the setting participates in module compilation and
hashing. A kernel defined in a shared module inherits that module's
``deterministic`` option; a unique-module kernel can override it independently.

## Supported Operators

Deterministic mode currently handles these atomic builtins:

- ``atomic_add``
- ``atomic_sub``
- ``atomic_min``
- ``atomic_max``

Handling depends on how the atomic is used:

- **Pattern A: accumulation, return value unused**
Examples: ``wp.atomic_add(arr, i, value)``, ``arr[i] += value``.
Floating-point ``add/sub/min/max`` are redirected through scatter-sort-reduce.
Integer ``add/sub/min/max`` with unused return values are left on the normal
atomic path because the final value is already deterministic.
- **Pattern B: counter / allocator, return value consumed**
Example: ``slot = wp.atomic_add(counter, 0, 1)``.
This is handled with the two-pass count-scan-execute path.

Operators that are not supported by deterministic mode:

- ``atomic_cas``
- ``atomic_exch``
- Tile atomics such as ``tile_atomic_add``

Bitwise integer atomics (``atomic_and``, ``atomic_or``, ``atomic_xor``) are not
transformed because their final results are already deterministic for the
unused-return case.

## Design

### Approach

Two distinct atomic usage patterns require two strategies, both transparent to
the user:

**Pattern A --- Scatter-Sort-Reduce** (accumulation, return value unused):

Instead of performing ``atomic_add`` in-place during kernel execution, each
thread writes a ``(sort_key, value)`` record to a temporary scatter buffer. The
sort key packs ``(dest_index << 32 | thread_id)``. After the kernel completes,
a CUB radix sort orders the fixed-capacity scatter buffer, then a custom CUDA
kernel walks the sorted records and accumulates each destination's values
left-to-right in a fixed (thread-id) order. Unused buffer slots are initialized
with an invalid sentinel key and sort to the end. This avoids host-side scatter
count readbacks and keeps the path compatible with CUDA graph capture.

**Pattern B --- Two-Pass Execution** (counter/allocator, return value used):

The kernel runs twice:
1. *Phase 0 (counting)*: The kernel executes with all side effects suppressed.
Counter atomics record per-thread contributions to a scratch array instead
of performing the actual atomic.
2. *Prefix sum*: ``wp.utils.array_scan(contrib, prefix, inclusive=False)``
computes deterministic per-thread offsets.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
3. *Phase 1 (execution)*: The kernel re-executes. Counter atomics return the
deterministic offset from the prefix sum. All other operations (including
Pattern A scatters) execute normally.

This mirrors the well-established count-scan-write pattern already used by
``warp/_src/marching_cubes.py`` and the FEM geometry code, but applied
automatically.

### Alternatives Considered

| Alternative | Why rejected |
| --- | --- |
| Fixed-point integer accumulation | Loses precision/range; not general for all float types |
| Per-thread output arrays | ``O(threads * output_size)`` memory; doesn't scale |
| Serialized atomics (mutex) | Prohibitively slow on GPU |
| Kahan compensated summation | Reduces error but does not guarantee determinism |
| Taint-based selective side-effect suppression for Phase 0 | Significant codegen complexity for an edge case; deferred to a future version |

### Key Implementation Details

**Atomic classification** happens in ``Adjoint._emit_deterministic_atomic()``
during the codegen build phase. The ``_det_in_assign`` flag on the ``Adjoint``
(set by ``emit_Assign``, cleared after) distinguishes Pattern B (return consumed
in an assignment like ``slot = wp.atomic_add(...)``) from Pattern A (return
discarded, as in ``arr[i] += val`` or bare ``wp.atomic_add(...)``). Only
``atomic_add``, ``atomic_sub``, ``atomic_min``, and ``atomic_max`` are
intercepted. Integer atomics with unused return values are skipped entirely
(already deterministic).

**CPU/CUDA dual compilation** is handled with ``#ifdef __CUDA_ARCH__`` guards
in the generated function body. On CUDA the scatter or phase-branching code
executes; on CPU the normal ``wp::atomic_add(...)`` call is emitted in the
``#else`` branch. This is necessary because Warp generates a single function
body that compiles for both targets.

**Hidden kernel parameters** are appended to the CUDA kernel signature by
``codegen_kernel()`` after the user arguments. Pattern B kernels get
``_wp_det_phase``, ``_wp_det_contrib_N``, ``_wp_det_prefix_N``. Pattern A
targets get ``_wp_scatter_keys_N``, ``_wp_scatter_vals_N``,
``_wp_scatter_ctr_N``, ``_wp_scatter_overflow_N``, ``_wp_scatter_cap_N``, and
``_wp_det_debug``. The launch system
(``_launch_deterministic``) allocates these buffers and appends the
corresponding ctypes params to the launch args.

**Multiple reduction buffers**: each distinct ``(target array, value type,
reduction op)`` combination gets its own scatter buffer set. Multiple call
sites with the same target and reduction op share one buffer. The
``DeterministicMeta`` dataclass on the kernel's ``Adjoint`` tracks all scatter
and counter targets discovered during codegen.

**Scatter capacity**: each scatter target uses a fixed-capacity buffer sized
from a code-generated lower bound (static records-per-thread analysis), with
``deterministic_capacity`` acting as a per-target minimum capacity floor. On
overflow, new records are truncated, a device-side overflow flag is set, and
optional diagnostics may be emitted when ``wp.config.deterministic_debug`` is
enabled.

**Counter total writeback**: after the prefix sum in Phase 0, the launch system
copies the total count (last element of the inclusive scan) back to the actual
counter array so user code that reads it post-launch sees the correct value.

**Files added/modified**:

| File | Role |
| --- | --- |
| ``warp/config.py`` | ``deterministic`` global flag |
| ``warp/_src/deterministic.py`` | Dataclasses, buffer allocation, sort-reduce orchestration |
| ``warp/_src/codegen.py`` | Atomic classification, scatter/phase codegen, hidden kernel params |
| ``warp/_src/context.py`` | Module option, ``_launch_deterministic()`` orchestrator, ctypes bindings |
| ``warp/native/deterministic.h`` | Device-side ``wp::deterministic::scatter()`` template |
| ``warp/native/deterministic.cu`` | CUB radix sort + segmented reduce kernel |
| ``warp/native/deterministic.cpp`` | CPU stubs (linker satisfaction when CUDA unavailable) |

## Testing Strategy

23 tests in ``warp/tests/test_deterministic.py`` cover:

- **Bit-exact reproducibility** (Pattern A): launch the same kernel 10 times
with ``deterministic=True``, assert ``np.array_equal`` across all runs for
float32 scatter-add, ``+=`` syntax, float64, and atomic-sub.
- **Correctness** (Pattern A): compare GPU deterministic output against a
sequential CPU reference within ``rtol=1e-4``.
- **Multiple arrays**: kernel that atomically adds to three different output
arrays simultaneously.
- **Mixed reduce ops on one array**: ``atomic_add`` and ``atomic_max`` targeting
the same destination array are reduced independently.
- **Multi-dimensional indexing**: 2D array ``atomic_add`` with row/col indices.
- **Scatter capacity accounting**: kernels with more than two deterministic
scatters per thread to the same target do not overflow a fixed heuristic
buffer.
- **Counter reproducibility** (Pattern B): ``slot = atomic_add(counter, 0, 1);
output[slot] = data[tid]`` produces identical output arrays across 10 runs.
- **Phase 0 side-effect suppression**: non-counter array writes are skipped in
the counting pass.
- **Counter correctness**: verifies counter value equals N and output is a
permutation of input.
- **Conditional counter**: stream compaction (only elements above threshold),
verifying correct count and reproducible output.
- **Mixed pattern**: both counter and accumulation in one kernel.
- **Integer passthrough**: integer ``atomic_add`` with unused return incurs no
transformation; result matches ``np.bincount``.
- **Per-module override**: ``@wp.kernel(module_options={"deterministic": True},
module="unique")`` works with global config off.
- **Kernel decorator override**: ``@wp.kernel(deterministic=True)`` works with
global config off.
- **Recorded launch support**: ``wp.launch(..., record_cmd=True)`` works for
deterministic CUDA kernels.
- **Graph capture support**: deterministic scatter launches can be captured and
replayed with CUDA graphs.
- All tests run on both CPU and CUDA where applicable. Existing
``test_atomic.py`` (158 tests) passes with zero regressions.
Loading