diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py
index b2d107647..377d3189c 100644
--- a/mujoco_warp/_src/collision_driver.py
+++ b/mujoco_warp/_src/collision_driver.py
@@ -34,6 +34,7 @@
 from mujoco_warp._src.types import Model
 from mujoco_warp._src.types import mat23
 from mujoco_warp._src.types import mat63
+from mujoco_warp._src.types import vec5
 from mujoco_warp._src.warp_util import event_scope
 
 wp.set_module_options({"enable_backward": False})
@@ -748,6 +749,251 @@ def _narrowphase(m: Model, d: Data, ctx: CollisionContext):
   flex_narrowphase(m, d)
 
 
+# Maximum geomcollisionid packed into sort key. Primitive box<>box generates at
+# most 8 contacts, so larger geomcollisionid values need not contribute to the
+# deterministic ordering key.
+_CONTACT_SORT_GCID_MAX = 8
+
+
+@wp.kernel
+def _compute_contact_sort_keys(
+  # Model:
+  ngeom: int,
+  # Data in:
+  contact_geom_in: wp.array[wp.vec2i],
+  contact_worldid_in: wp.array[int],
+  contact_geomcollisionid_in: wp.array[int],
+  nacon_in: wp.array[int],
+  # In:
+  gcid_max: int,
+  # Out:
+  sort_keys_out: wp.array[int],
+  sort_indices_out: wp.array[int],
+):
+  """Compute composite sort keys for deterministic contact ordering.
+
+  Active contacts get the key
+  ((worldid * ngeom + geom0) * ngeom + geom1) * gcid_max + min(gcid, gcid_max - 1).
+  Slots at or beyond nacon get INT_MAX. Every slot's index is initialised to
+  its own position.
+  """
+  cid = wp.tid()
+  sort_indices_out[cid] = cid
+  if cid >= nacon_in[0]:
+    sort_keys_out[cid] = 2147483647  # INT_MAX: inactive contacts sort to end
+    return
+  geom = contact_geom_in[cid]
+  wid = contact_worldid_in[cid]
+  # clamp so gcid never spills into the geom1 digit of the key
+  gcid = wp.min(contact_geomcollisionid_in[cid], gcid_max - 1)
+  sort_keys_out[cid] = ((wid * ngeom + geom[0]) * ngeom + geom[1]) * gcid_max + gcid
+
+
+@wp.kernel
+def _permute_contacts(
+  # Data in:
+  nacon_in: wp.array[int],
+  # In:
+  perm_in: wp.array[int],
+  src_dist_in: wp.array[float],
+  src_pos_in: wp.array[wp.vec3],
+  src_frame_in: wp.array[wp.mat33],
+  src_includemargin_in: wp.array[float],
+  src_friction_in: wp.array[vec5],
+  src_solref_in: wp.array[wp.vec2],
+  src_solreffriction_in: wp.array[wp.vec2],
+  src_solimp_in: wp.array[vec5],
+  src_dim_in: wp.array[int],
+  src_geom_in: wp.array[wp.vec2i],
+  src_flex_in: wp.array[wp.vec2i],
+  src_vert_in: wp.array[wp.vec2i],
+  src_worldid_in: wp.array[int],
+  src_type_in: wp.array[int],
+  src_gcid_in: wp.array[int],
+  src_efc_in: wp.array2d[int],
+  # Out:
+  dst_dist_out: wp.array[float],
+  dst_pos_out: wp.array[wp.vec3],
+  dst_frame_out: wp.array[wp.mat33],
+  dst_includemargin_out: wp.array[float],
+  dst_friction_out: wp.array[vec5],
+  dst_solref_out: wp.array[wp.vec2],
+  dst_solreffriction_out: wp.array[wp.vec2],
+  dst_solimp_out: wp.array[vec5],
+  dst_dim_out: wp.array[int],
+  dst_geom_out: wp.array[wp.vec2i],
+  dst_flex_out: wp.array[wp.vec2i],
+  dst_vert_out: wp.array[wp.vec2i],
+  dst_worldid_out: wp.array[int],
+  dst_type_out: wp.array[int],
+  dst_gcid_out: wp.array[int],
+  dst_efc_out: wp.array2d[int],
+):
+  """Permute contact fields using sorted indices.
+
+  Gathers every field of source contact perm_in[cid] into destination slot cid.
+  Slots at or beyond nacon are left untouched.
+  """
+  cid = wp.tid()
+  if cid >= nacon_in[0]:
+    return
+  src = perm_in[cid]
+  dst_dist_out[cid] = src_dist_in[src]
+  dst_pos_out[cid] = src_pos_in[src]
+  dst_frame_out[cid] = src_frame_in[src]
+  dst_includemargin_out[cid] = src_includemargin_in[src]
+  dst_friction_out[cid] = src_friction_in[src]
+  dst_solref_out[cid] = src_solref_in[src]
+  dst_solreffriction_out[cid] = src_solreffriction_in[src]
+  dst_solimp_out[cid] = src_solimp_in[src]
+  dst_dim_out[cid] = src_dim_in[src]
+  dst_geom_out[cid] = src_geom_in[src]
+  dst_flex_out[cid] = src_flex_in[src]
+  dst_vert_out[cid] = src_vert_in[src]
+  dst_worldid_out[cid] = src_worldid_in[src]
+  dst_type_out[cid] = src_type_in[src]
+  dst_gcid_out[cid] = src_gcid_in[src]
+  for j in range(src_efc_in.shape[1]):
+    dst_efc_out[cid, j] = src_efc_in[src, j]
+
+
+def _sort_contacts(m: Model, d: Data):
+  """Sort contacts by (worldid, geom0, geom1, geomcollisionid) for determinism.
+
+  Contact arrays in `d.contact` are reordered in place. The geomcollisionid
+  term is dropped from the key when the full key would not fit in int32.
+
+  Args:
+    m: model; provides ngeom.
+    d: data whose contact arrays are sorted.
+
+  Raises:
+    ValueError: if nworld * ngeom**2 exceeds INT_MAX, so the sort key cannot
+      be represented without int32 overflow.
+  """
+  if d.naconmax == 0:
+    return
+
+  # Sort keys are int32. Evaluate the key-space size with Python ints so the
+  # check itself cannot wrap.
+  int_max = 2**31 - 1
+  npair = int(d.nworld) * int(m.ngeom) * int(m.ngeom)
+  if npair > int_max:
+    # Dropping geomcollisionid is not enough here: wrapped keys would let
+    # distinct (worldid, geom0, geom1) tuples tie and keep arrival order.
+    raise ValueError(
+      f"deterministic contact sort: nworld * ngeom**2 = {npair} exceeds int32 key range ({int_max})"
+    )
+
+  # Fall back to no-gcid key if the full key would overflow.
+  gcid_max = _CONTACT_SORT_GCID_MAX
+  if npair * gcid_max > int_max:
+    gcid_max = 1
+
+  # Allocate sort buffers (radix_sort_pairs needs 2x capacity for internal use).
+  sort_keys = wp.empty(2 * d.naconmax, dtype=int)
+  sort_indices = wp.empty(2 * d.naconmax, dtype=int)
+
+  # Step 1: Compute sort keys and initialise indices to identity.
+  wp.launch(
+    _compute_contact_sort_keys,
+    dim=d.naconmax,
+    inputs=[
+      m.ngeom,
+      d.contact.geom,
+      d.contact.worldid,
+      d.contact.geomcollisionid,
+      d.nacon,
+      gcid_max,
+    ],
+    outputs=[sort_keys, sort_indices],
+  )
+
+  # Step 2: Stable radix sort on keys, carrying indices.
+  wp.utils.radix_sort_pairs(sort_keys, sort_indices, d.naconmax)
+
+  # TODO(team): investigate a single kernel that copies all contact fields to scratch.
+  # Step 3: Copy contact fields to temporary buffers.
+  tmp_dist = wp.empty_like(d.contact.dist)
+  tmp_pos = wp.empty_like(d.contact.pos)
+  tmp_frame = wp.empty_like(d.contact.frame)
+  tmp_includemargin = wp.empty_like(d.contact.includemargin)
+  tmp_friction = wp.empty_like(d.contact.friction)
+  tmp_solref = wp.empty_like(d.contact.solref)
+  tmp_solreffriction = wp.empty_like(d.contact.solreffriction)
+  tmp_solimp = wp.empty_like(d.contact.solimp)
+  tmp_dim = wp.empty_like(d.contact.dim)
+  tmp_geom = wp.empty_like(d.contact.geom)
+  tmp_flex = wp.empty_like(d.contact.flex)
+  tmp_vert = wp.empty_like(d.contact.vert)
+  tmp_worldid = wp.empty_like(d.contact.worldid)
+  tmp_type = wp.empty_like(d.contact.type)
+  tmp_gcid = wp.empty_like(d.contact.geomcollisionid)
+  tmp_efc = wp.empty_like(d.contact.efc_address)
+
+  wp.copy(tmp_dist, d.contact.dist)
+  wp.copy(tmp_pos, d.contact.pos)
+  wp.copy(tmp_frame, d.contact.frame)
+  wp.copy(tmp_includemargin, d.contact.includemargin)
+  wp.copy(tmp_friction, d.contact.friction)
+  wp.copy(tmp_solref, d.contact.solref)
+  wp.copy(tmp_solreffriction, d.contact.solreffriction)
+  wp.copy(tmp_solimp, d.contact.solimp)
+  wp.copy(tmp_dim, d.contact.dim)
+  wp.copy(tmp_geom, d.contact.geom)
+  wp.copy(tmp_flex, d.contact.flex)
+  wp.copy(tmp_vert, d.contact.vert)
+  wp.copy(tmp_worldid, d.contact.worldid)
+  wp.copy(tmp_type, d.contact.type)
+  wp.copy(tmp_gcid, d.contact.geomcollisionid)
+  wp.copy(tmp_efc, d.contact.efc_address)
+
+  # Step 4: Gather-permute from temp buffers back into contact arrays.
+  wp.launch(
+    _permute_contacts,
+    dim=d.naconmax,
+    inputs=[
+      d.nacon,
+      sort_indices,
+      tmp_dist,
+      tmp_pos,
+      tmp_frame,
+      tmp_includemargin,
+      tmp_friction,
+      tmp_solref,
+      tmp_solreffriction,
+      tmp_solimp,
+      tmp_dim,
+      tmp_geom,
+      tmp_flex,
+      tmp_vert,
+      tmp_worldid,
+      tmp_type,
+      tmp_gcid,
+      tmp_efc,
+    ],
+    outputs=[
+      d.contact.dist,
+      d.contact.pos,
+      d.contact.frame,
+      d.contact.includemargin,
+      d.contact.friction,
+      d.contact.solref,
+      d.contact.solreffriction,
+      d.contact.solimp,
+      d.contact.dim,
+      d.contact.geom,
+      d.contact.flex,
+      d.contact.vert,
+      d.contact.worldid,
+      d.contact.type,
+      d.contact.geomcollisionid,
+      d.contact.efc_address,
+    ],
+  )
+
+
 @event_scope
 def collision(m: Model, d: Data):
   """Runs the full collision detection pipeline.
@@ -782,5 +1028,8 @@ def collision(m: Model, d: Data):
 
   _narrowphase(m, d, ctx)
 
+  if m.opt.deterministic:
+    _sort_contacts(m, d)
+
   if m.callback.contactfilter:
     m.callback.contactfilter(m, d)
diff --git a/mujoco_warp/_src/determinism_test.py b/mujoco_warp/_src/determinism_test.py
new file mode 100644
index 000000000..5ece990c7
--- /dev/null
+++ b/mujoco_warp/_src/determinism_test.py
@@ -0,0 +1,206 @@
+# Copyright 2026 The Newton Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for GPU determinism (contact sorting)."""
+
+import numpy as np
+import warp as wp
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import mujoco_warp as mjw
+from mujoco_warp import test_data
+from mujoco_warp._src import collision_driver
+
+_NSTEPS = 10
+_CONTACT_FIELDS = (
+  "dist",
+  "pos",
+  "frame",
+  "includemargin",
+  "friction",
+  "solref",
+  "solreffriction",
+  "solimp",
+  "dim",
+  "geom",
+  "flex",
+  "vert",
+  "efc_address",
+  "worldid",
+  "type",
+  "geomcollisionid",
+)
+
+
+def _run_and_collect_contacts(path, nworld, nsteps, deterministic):
+  """Run `nsteps` steps and return nacon plus the active-contact arrays of the last step."""
+  _, _, m, d = test_data.fixture(path=path, nworld=nworld)
+  m.opt.deterministic = deterministic
+  for _ in range(nsteps):
+    mjw.step(m, d)
+  nacon = d.nacon.numpy()[0]
+  return {
+    "nacon": nacon,
+    "geom": d.contact.geom.numpy()[:nacon].copy(),
+    "dist": d.contact.dist.numpy()[:nacon].copy(),
+    "pos": d.contact.pos.numpy()[:nacon].copy(),
+    "frame": d.contact.frame.numpy()[:nacon].copy(),
+    "dim": d.contact.dim.numpy()[:nacon].copy(),
+    "worldid": d.contact.worldid.numpy()[:nacon].copy(),
+    "geomcollisionid": d.contact.geomcollisionid.numpy()[:nacon].copy(),
+  }
+
+
+def _copy_contact_fields(d):
+  """Return copies of every contact array."""
+  return {field: getattr(d.contact, field).numpy().copy() for field in _CONTACT_FIELDS}
+
+
+def _write_contact_fields(d, contact_fields):
+  """Write full contact arrays back to device memory."""
+  for field, values in contact_fields.items():
+    arr = getattr(d.contact, field)
+    wp.copy(arr, wp.array(values, dtype=arr.dtype, device=arr.device))
+
+
+def _permute_active_contacts(contact_fields, nacon, perm):
+  """Return a copy with the active contacts permuted by `perm`."""
+  permuted = {field: values.copy() for field, values in contact_fields.items()}
+  for field, values in permuted.items():
+    values[:nacon] = values[perm]
+  return permuted
+
+
+def _sorted_contact_order(contact_fields, nacon):
+  """Return stable sorted indices for the active contacts."""
+  geom = contact_fields["geom"]
+  worldid = contact_fields["worldid"]
+  geomcollisionid = contact_fields["geomcollisionid"]
+  return sorted(
+    range(nacon),
+    key=lambda idx: (
+      int(worldid[idx]),
+      int(geom[idx, 0]),
+      int(geom[idx, 1]),
+      int(geomcollisionid[idx]),
+    ),
+  )
+
+
+class ContactSortDeterminismTest(parameterized.TestCase):
+  """Tests that contact sorting produces deterministic contact ordering."""
+
+  @parameterized.parameters(
+    ("collision.xml", 1),
+    ("collision.xml", 4),
+    ("humanoid/humanoid.xml", 1),
+    ("humanoid/humanoid.xml", 4),
+  )
+  def test_contact_ordering_deterministic(self, path, nworld):
+    """Contacts are bitwise identical across multiple runs."""
+    nruns = 3
+    results = [_run_and_collect_contacts(path, nworld, _NSTEPS, True) for _ in range(nruns)]
+
+    # Verify contacts were generated.
+    self.assertGreater(results[0]["nacon"], 0, f"No contacts for {path}")
+
+    for run in range(1, nruns):
+      self.assertEqual(results[0]["nacon"], results[run]["nacon"])
+      np.testing.assert_array_equal(
+        results[0]["geom"],
+        results[run]["geom"],
+        err_msg=f"Contact geom ordering differs: run 0 vs run {run}",
+      )
+
+  @parameterized.parameters(
+    ("collision.xml", 1),
+    ("humanoid/humanoid.xml", 1),
+  )
+  def test_contact_fields_deterministic(self, path, nworld):
+    """All contact fields are bitwise identical across runs."""
+    nruns = 3
+    results = [_run_and_collect_contacts(path, nworld, _NSTEPS, True) for _ in range(nruns)]
+
+    self.assertGreater(results[0]["nacon"], 0)
+
+    for run in range(1, nruns):
+      self.assertEqual(results[0]["nacon"], results[run]["nacon"])
+      for field in ("dist", "pos", "frame", "geom", "dim", "worldid", "geomcollisionid"):
+        np.testing.assert_array_equal(
+          results[0][field],
+          results[run][field],
+          err_msg=f"{field} differs: run 0 vs run {run}",
+        )
+
+  def test_contacts_sorted_by_geom(self):
+    """Contacts are sorted by (worldid, geom0, geom1) after deterministic step."""
+    result = _run_and_collect_contacts("collision.xml", 1, _NSTEPS, True)
+
+    nacon = result["nacon"]
+    self.assertGreater(nacon, 1)
+
+    geom = result["geom"]
+    worldid = result["worldid"]
+
+    # Verify sorted: (worldid, geom0, geom1) is non-decreasing.
+    for i in range(1, nacon):
+      key_prev = (worldid[i - 1], geom[i - 1, 0], geom[i - 1, 1])
+      key_curr = (worldid[i], geom[i, 0], geom[i, 1])
+      self.assertLessEqual(
+        key_prev,
+        key_curr,
+        f"Contacts not sorted at index {i}: {key_prev} > {key_curr}",
+      )
+
+  def test_sort_contacts_reorders_mixed_contacts(self):
+    """Sorting restores deterministic contact order after contacts are mixed."""
+    _, _, m, d = test_data.fixture(path="collision.xml", nworld=4)
+    m.opt.deterministic = False
+
+    mjw.forward(m, d)
+
+    nacon = d.nacon.numpy()[0]
+    self.assertGreaterEqual(nacon, 5)
+
+    original = _copy_contact_fields(d)
+    perm = np.concatenate((np.arange(1, nacon, 2), np.arange(0, nacon, 2)))
+    self.assertFalse(np.array_equal(perm, np.arange(nacon)))
+
+    mixed = _permute_active_contacts(original, nacon, perm)
+    _write_contact_fields(d, mixed)
+
+    expected_order = _sorted_contact_order(mixed, nacon)
+    expected = _permute_active_contacts(mixed, nacon, expected_order)
+
+    collision_driver._sort_contacts(m, d)
+
+    actual = _copy_contact_fields(d)
+    self.assertEqual(d.nacon.numpy()[0], nacon)
+
+    for field in _CONTACT_FIELDS:
+      np.testing.assert_array_equal(
+        actual[field][:nacon],
+        expected[field][:nacon],
+        err_msg=f"{field} was not permuted into deterministic order",
+      )
+
+  def test_deterministic_flag_default_false(self):
+    """The deterministic flag defaults to False."""
+    _, _, m, _ = test_data.fixture(path="collision.xml")
+    self.assertFalse(m.opt.deterministic)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py
index 0c5a1f658..c8ce72a34 100644
--- a/mujoco_warp/_src/io.py
+++ b/mujoco_warp/_src/io.py
@@ -197,6 +197,7 @@ def _check_friction(name: str, id_: int, condim: int, friction, checks):
     opt.contact_sensor_maxmatch = mjm.numeric_data[mjm.numeric_adr[contact_sensor_maxmatch_id]]
   else:
     opt.contact_sensor_maxmatch = 64
+  opt.deterministic = False
 
   # place opt on device
   for f in dataclasses.fields(types.Option):
@@ -2527,6 +2528,7 @@ def override_model(model: types.Model | mujoco.MjModel, overrides: dict[str, Any "opt.ls_parallel", "opt.graph_conditional", "opt.contact_sensor_maxmatch", + "opt.deterministic", } mj_only_fields = {"opt.jacobian"} diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 7a9dc2a59..7a89792c9 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -740,6 +740,9 @@ class Option: zeros out the contacts at each step) contact_sensor_maxmatch: max number of contacts considered by contact sensor matching criteria contacts matched after this value is exceded will be ignored + deterministic: enable deterministic contact ordering after narrowphase + TODO update this description as more parts of the + simulation pipeline gain optional deterministic results """ timestep: array("*", float) @@ -770,6 +773,7 @@ class Option: graph_conditional: bool run_collision_detection: bool contact_sensor_maxmatch: int + deterministic: bool @dataclasses.dataclass