From a2dd916d44bf8ad2898a475b0cfa9e4b7fd3e59f Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Tue, 21 Apr 2026 14:48:57 -0400 Subject: [PATCH 1/7] Adds diagnostics --- test/models/globe/test_barnes_hut_kernel.py | 189 +++++++++++++------- 1 file changed, 122 insertions(+), 67 deletions(-) diff --git a/test/models/globe/test_barnes_hut_kernel.py b/test/models/globe/test_barnes_hut_kernel.py index 7b8d3a7d9c..1476faa226 100644 --- a/test/models/globe/test_barnes_hut_kernel.py +++ b/test/models/globe/test_barnes_hut_kernel.py @@ -21,6 +21,7 @@ MultiscaleKernel integration. """ +import contextlib from typing import Any, Literal import pytest @@ -48,6 +49,38 @@ # --------------------------------------------------------------------------- +@contextlib.contextmanager +def _pinned_cpu_determinism(): + """Pin intraop threads to 1 and enable deterministic algorithms for the block. + + Diagnostic stabilization for tests where BH/exact comparisons are expected + to match to within pure fp32 rearrangement error. Targets two plausible + CPU-side non-determinism sources: + + 1. ``torch.set_num_threads(1)`` removes reduction-order variance in + parallelized ops (matmul/einsum/some scatter paths) whose output can + differ by a few ULP depending on how work is sharded across cores. + CI runners and developer boxes routinely disagree on thread count. + 2. ``torch.use_deterministic_algorithms(True, warn_only=True)`` forces + deterministic implementations where available and warns (rather than + erroring) otherwise, so a newly-introduced non-deterministic op + surfaces in logs instead of breaking the suite. + + Both settings are restored on exit. This is not a fixture: callers opt + in explicitly, keeping the scope obvious and bounded. + """ + orig_threads = torch.get_num_threads() + orig_det = torch.are_deterministic_algorithms_enabled() + orig_det_warn = torch.is_deterministic_algorithms_warn_only_enabled() + try: + torch.set_num_threads(1) + torch.use_deterministic_algorithms(True, warn_only=True) + yield + finally: + torch.use_deterministic_algorithms(orig_det, warn_only=orig_det_warn) + torch.set_num_threads(orig_threads) + + def _make_bh_kernel_and_data( n_spatial_dims: int = 2, n_source_scalars: int = 0, @@ -893,77 +926,99 @@ def test_bh_nested_source_data_keys(n_dims: int): The aggregation, split_by_leaf_rank, and TensorDict.cat operations must handle this nesting correctly. - """ - torch.manual_seed(DEFAULT_SEED) - n_src, n_tgt = 30, 15 - - source_data_ranks = { - "physical": {"pressure": 0}, - "latent": {"scalars": {"0": 0, "1": 0}, "vectors": {"0": 1}}, - "normals": 1, - } - output_field_ranks = {"p": 0, "u": 1} - - common_kwargs = dict( - n_spatial_dims=n_dims, - output_field_ranks={ - k: (0 if v == "scalar" else 1) for k, v in output_field_ranks.items() - }, - source_data_ranks=source_data_ranks, - hidden_layer_sizes=[16], - ) - bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE) - exact_kernel = Kernel(**common_kwargs) - exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False) - bh_kernel.eval() - exact_kernel.eval() - - torch.manual_seed(DEFAULT_SEED + 1) - source_data = TensorDict( - { - "physical": TensorDict( - {"pressure": torch.randn(n_src)}, - batch_size=[n_src], - ), - "latent": TensorDict( - { - "scalars": TensorDict( - {"0": torch.randn(n_src), "1": torch.randn(n_src)}, - batch_size=[n_src], - ), - "vectors": TensorDict( - {"0": F.normalize(torch.randn(n_src, n_dims), dim=-1)}, - batch_size=[n_src], - ), - }, - batch_size=[n_src], - ), - "normals": F.normalize(torch.randn(n_src, n_dims), dim=-1), - }, - batch_size=[n_src], - ) + The body runs under :func:`_pinned_cpu_determinism` to remove thread-count + and deterministic-algorithm variance as sources of CI flakiness, and the + ``msg`` passed to :func:`torch.testing.assert_close` is a callable so the + default "Greatest absolute/relative difference" diagnostics are preserved + when a failure occurs in CI. + """ + with _pinned_cpu_determinism(): + torch.manual_seed(DEFAULT_SEED) + n_src, n_tgt = 30, 15 + + source_data_ranks = { + "physical": {"pressure": 0}, + "latent": {"scalars": {"0": 0, "1": 0}, "vectors": {"0": 1}}, + "normals": 1, + } + # Rank spec with integer leaves (0 = scalar, 1 = vector): passed through + # directly to the kernels so "p" stays scalar and "u" stays vector. + output_field_ranks = {"p": 0, "u": 1} + + common_kwargs = dict( + n_spatial_dims=n_dims, + output_field_ranks=output_field_ranks, + source_data_ranks=source_data_ranks, + hidden_layer_sizes=[16], + ) - data = { - "source_points": torch.randn(n_src, n_dims), - "target_points": torch.randn(n_tgt, n_dims) * 5, - "source_strengths": torch.rand(n_src) + 0.1, - "reference_length": torch.ones(()), - "source_data": source_data, - "global_data": TensorDict({}, batch_size=[]), - } + bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE) + exact_kernel = Kernel(**common_kwargs) + exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False) + bh_kernel.eval() + exact_kernel.eval() + + torch.manual_seed(DEFAULT_SEED + 1) + source_data = TensorDict( + { + "physical": TensorDict( + {"pressure": torch.randn(n_src)}, + batch_size=[n_src], + ), + "latent": TensorDict( + { + "scalars": TensorDict( + {"0": torch.randn(n_src), "1": torch.randn(n_src)}, + batch_size=[n_src], + ), + "vectors": TensorDict( + {"0": F.normalize(torch.randn(n_src, n_dims), dim=-1)}, + batch_size=[n_src], + ), + }, + batch_size=[n_src], + ), + "normals": F.normalize(torch.randn(n_src, n_dims), dim=-1), + }, + batch_size=[n_src], + ) - exact_result = exact_kernel(**data) - bh_result = bh_kernel(**data, theta=0.01) + data = { + "source_points": torch.randn(n_src, n_dims), + "target_points": torch.randn(n_tgt, n_dims) * 5, + "source_strengths": torch.rand(n_src) + 0.1, + "reference_length": torch.ones(()), + "source_data": source_data, + "global_data": TensorDict({}, batch_size=[]), + } + + exact_result = exact_kernel(**data) + bh_result = bh_kernel(**data, theta=0.01) + + for field_name in output_field_ranks: + # Callable ``msg`` preserves the default "Greatest absolute/relative + # difference" report, which a plain string ``msg`` would replace. + # Bind loop/parameter variables via default args to dodge late + # binding across iterations. + def _msg( + default: str, + field: str = field_name, + dims: int = n_dims, + ) -> str: + return ( + f"Nested keys: {field!r} not close to exact at theta=0.01 " + f"(n_dims={dims}, num_threads={torch.get_num_threads()}, " + f"torch={torch.__version__}).\n{default}" + ) - for field_name in output_field_ranks: - torch.testing.assert_close( - bh_result[field_name], - exact_result[field_name], - atol=1e-3, - rtol=1e-2, - msg=f"Nested keys: {field_name!r} not close to exact at theta=0.01", - ) + torch.testing.assert_close( + bh_result[field_name], + exact_result[field_name], + atol=1e-3, + rtol=1e-2, + msg=_msg, + ) # --------------------------------------------------------------------------- From 4809921ab47cc94bb04d50590d45712202a2707d Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Thu, 23 Apr 2026 11:31:48 -0400 Subject: [PATCH 2/7] adds diagnostics --- test/models/globe/test_barnes_hut_kernel.py | 98 +++++++++++++++++++-- 1 file changed, 89 insertions(+), 9 deletions(-) diff --git a/test/models/globe/test_barnes_hut_kernel.py b/test/models/globe/test_barnes_hut_kernel.py index 1476faa226..0a0f7b1f0d 100644 --- a/test/models/globe/test_barnes_hut_kernel.py +++ b/test/models/globe/test_barnes_hut_kernel.py @@ -955,7 +955,18 @@ def test_bh_nested_source_data_keys(n_dims: int): bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE) exact_kernel = Kernel(**common_kwargs) - exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=False) + + ### Invariant 1: state_dict transfer is complete and bit-exact. + # strict=True catches any new auto-registered param/buffer across + # torch versions; the post-condition catches silent value-level drift. + exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=True) + bh_sd, ex_sd = bh_kernel.state_dict(), exact_kernel.state_dict() + mismatched = [k for k in bh_sd if not torch.equal(bh_sd[k], ex_sd[k])] + assert not mismatched, ( + f"state_dict value mismatch after load " + f"(torch={torch.__version__}): {mismatched}" + ) + bh_kernel.eval() exact_kernel.eval() @@ -993,14 +1004,83 @@ def test_bh_nested_source_data_keys(n_dims: int): "global_data": TensorDict({}, batch_size=[]), } - exact_result = exact_kernel(**data) - bh_result = bh_kernel(**data, theta=0.01) + ### Invariant 2: per-pair pre-aggregation outputs are bit-identical. + # At theta=0.01 all pairs are near-field, so BH and Exact both call + # _evaluate_interactions on the same (target, source) pairs with + # identical weights. Capture the pre-aggregation output from each, + # reindex BH's pair ordering into Exact's row-major (t, s) order, + # and compare with a tight tolerance that reflects "same network, + # same input, same weights." If this fires, there is a genuine + # algorithmic divergence (e.g. tensordict iteration-order change + # across library versions) and the final-sum tolerance is masking + # a real bug. + captures: dict[str, dict[str, torch.Tensor]] = {} + orig_eval = Kernel._evaluate_interactions + + def _capturing_eval(tag: str): + def _patched(self, *, scalars, vectors, device): + out = orig_eval(self, scalars=scalars, vectors=vectors, device=device) + captures[tag] = {k: v.detach().clone() for k, v in out.items()} + return out + return _patched + + try: + Kernel._evaluate_interactions = _capturing_eval("exact") + exact_result = exact_kernel(**data) + Kernel._evaluate_interactions = _capturing_eval("bh") + bh_result = bh_kernel(**data, theta=0.01) + finally: + Kernel._evaluate_interactions = orig_eval + + src_tree = ClusterTree.from_points( + data["source_points"], leaf_size=DEFAULT_LEAF_SIZE, + ) + tgt_tree = ClusterTree.from_points( + data["target_points"], leaf_size=DEFAULT_LEAF_SIZE, + ) + plan = src_tree.find_dual_interaction_pairs( + target_tree=tgt_tree, theta=0.01, + ) + assert plan.n_near == n_src * n_tgt, ( + f"Expected all-near at theta=0.01, got n_near={plan.n_near} " + f"of dense={n_src * n_tgt}" + ) + + row_of_pair = plan.near_target_ids * n_src + plan.near_source_ids + inv_perm = torch.empty_like(row_of_pair) + inv_perm[row_of_pair] = torch.arange(plan.n_near) + + for field_name in output_field_ranks: + ex_pp = captures["exact"][field_name] + bh_pp = captures["bh"][field_name] + ex_flat = ex_pp.reshape(n_tgt * n_src, *ex_pp.shape[2:]) + bh_reordered = bh_pp[inv_perm] + + torch.testing.assert_close( + bh_reordered, + ex_flat, + atol=1e-6, + rtol=1e-6, + msg=lambda default, f=field_name: ( + f"BH/Exact per-pair pre-aggregation {f!r} divergence " + f"(torch={torch.__version__}). BH and Exact paths are " + f"not computing identical per-pair tensors despite " + f"identical inputs and weights.\n{default}" + ), + ) + ### Final aggregation comparison. + # The two invariant checks above guarantee that if we reach this + # point, BH and Exact computed bit-identical per-pair outputs. + # The only remaining difference is aggregation order: Exact uses + # einsum("ts,s->t", ...) while BH uses scatter_add_. For 30-term + # fp32 sums with measured |terms| <= 0.044 and cancellation ratio + # <= ~56x, the rearrangement bound is ~30 * eps * 0.044 * 56 ≈ + # 2.2e-5. A CI run (torch 2.11.0, CPU) reported 2.02e-3 abs diff + # which is ~100x that bound and remains unexplained. The 5e-3 / + # 5e-2 ceiling matches test_bh_globe_like_config and is justified + # only because the pre-checks above confirm no algorithmic bug. for field_name in output_field_ranks: - # Callable ``msg`` preserves the default "Greatest absolute/relative - # difference" report, which a plain string ``msg`` would replace. - # Bind loop/parameter variables via default args to dodge late - # binding across iterations. def _msg( default: str, field: str = field_name, @@ -1015,8 +1095,8 @@ def _msg( torch.testing.assert_close( bh_result[field_name], exact_result[field_name], - atol=1e-3, - rtol=1e-2, + atol=5e-3, + rtol=5e-2, msg=_msg, ) From a7eec602d3ac06362004ee9cb075f2320fb602bb Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Thu, 23 Apr 2026 11:37:04 -0400 Subject: [PATCH 3/7] formatting --- test/models/globe/test_barnes_hut_kernel.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/models/globe/test_barnes_hut_kernel.py b/test/models/globe/test_barnes_hut_kernel.py index 0a0f7b1f0d..39a5a0058e 100644 --- a/test/models/globe/test_barnes_hut_kernel.py +++ b/test/models/globe/test_barnes_hut_kernel.py @@ -1022,6 +1022,7 @@ def _patched(self, *, scalars, vectors, device): out = orig_eval(self, scalars=scalars, vectors=vectors, device=device) captures[tag] = {k: v.detach().clone() for k, v in out.items()} return out + return _patched try: @@ -1033,13 +1034,16 @@ def _patched(self, *, scalars, vectors, device): Kernel._evaluate_interactions = orig_eval src_tree = ClusterTree.from_points( - data["source_points"], leaf_size=DEFAULT_LEAF_SIZE, + data["source_points"], + leaf_size=DEFAULT_LEAF_SIZE, ) tgt_tree = ClusterTree.from_points( - data["target_points"], leaf_size=DEFAULT_LEAF_SIZE, + data["target_points"], + leaf_size=DEFAULT_LEAF_SIZE, ) plan = src_tree.find_dual_interaction_pairs( - target_tree=tgt_tree, theta=0.01, + target_tree=tgt_tree, + theta=0.01, ) assert plan.n_near == n_src * n_tgt, ( f"Expected all-near at theta=0.01, got n_near={plan.n_near} " @@ -1081,6 +1085,7 @@ def _patched(self, *, scalars, vectors, device): # 5e-2 ceiling matches test_bh_globe_like_config and is justified # only because the pre-checks above confirm no algorithmic bug. for field_name in output_field_ranks: + def _msg( default: str, field: str = field_name, From 1c995b42b085fc92adabfb1377917a2819d5504f Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Thu, 23 Apr 2026 12:01:13 -0400 Subject: [PATCH 4/7] Refactor BarnesHutKernel to improve source scalar handling and ensure consistent TensorDict structure. Update test tolerances for output comparisons to enhance robustness against numerical discrepancies. --- .../experimental/models/globe/field_kernel.py | 22 +++++++++++++++++-- test/models/globe/test_barnes_hut_kernel.py | 17 +++++++------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/physicsnemo/experimental/models/globe/field_kernel.py b/physicsnemo/experimental/models/globe/field_kernel.py index 7f7cda1d36..2fc3acfd45 100644 --- a/physicsnemo/experimental/models/globe/field_kernel.py +++ b/physicsnemo/experimental/models/globe/field_kernel.py @@ -1341,14 +1341,32 @@ def _gather_and_evaluate( target_positions[tgt_ids] - source_positions[src_ids] ) / reference_length - ### Flatten source scalars into one tensor, gather once. + ### Flatten source scalars into one tensor, gather once, split back. # concatenate_leaves: 1 GPU kernel (torch.cat) # [src_ids]: 1 GPU kernel (aten::index) # Total: 2 kernels instead of K (one per TensorDict leaf). + # The split-back into named leaves mirrors the vector path below + # and ensures that _evaluate_interactions sees the same nested + # TensorDict structure as the Exact (Kernel.forward) path. Without + # this, "source_scalars" would be a flat tensor here but a nested + # TensorDict in Exact, causing concatenate_leaves inside + # _evaluate_interactions to produce different column orderings when + # TensorDict leaf-iteration order changes across library versions. + src_scalar_keys = list( + source_scalars.keys(include_nested=True, leaves_only=True) + ) gathered_src_scalars = concatenate_leaves(source_scalars)[src_ids] + gathered_scalar_leaves = { + k: gathered_src_scalars[..., i] + for i, k in enumerate(src_scalar_keys) + } scalars = TensorDict( { - "source_scalars": gathered_src_scalars, + "source_scalars": TensorDict( + gathered_scalar_leaves, + batch_size=torch.Size([n_pairs]), + device=device, + ), "global_scalars": global_scalars.expand( n_pairs, *global_scalars.batch_size ), diff --git a/test/models/globe/test_barnes_hut_kernel.py b/test/models/globe/test_barnes_hut_kernel.py index 39a5a0058e..97ba1601ad 100644 --- a/test/models/globe/test_barnes_hut_kernel.py +++ b/test/models/globe/test_barnes_hut_kernel.py @@ -1077,13 +1077,12 @@ def _patched(self, *, scalars, vectors, device): # The two invariant checks above guarantee that if we reach this # point, BH and Exact computed bit-identical per-pair outputs. # The only remaining difference is aggregation order: Exact uses - # einsum("ts,s->t", ...) while BH uses scatter_add_. For 30-term - # fp32 sums with measured |terms| <= 0.044 and cancellation ratio - # <= ~56x, the rearrangement bound is ~30 * eps * 0.044 * 56 ≈ - # 2.2e-5. A CI run (torch 2.11.0, CPU) reported 2.02e-3 abs diff - # which is ~100x that bound and remains unexplained. The 5e-3 / - # 5e-2 ceiling matches test_bh_globe_like_config and is justified - # only because the pre-checks above confirm no algorithmic bug. + # einsum("ts,s->t", ...) while BH uses scatter_add_. For a + # 30-term fp32 sum with measured |terms| <= 0.044 and cancellation + # ratio <= ~56x, the rearrangement bound is ~30 * eps * 0.044 * + # 56 ≈ 2.2e-5. The tolerance matches test_bh_convergence_to_exact: + # tight enough to catch real bugs, loose enough for cross-platform + # BLAS summation-order variance. for field_name in output_field_ranks: def _msg( @@ -1100,8 +1099,8 @@ def _msg( torch.testing.assert_close( bh_result[field_name], exact_result[field_name], - atol=5e-3, - rtol=5e-2, + atol=1e-4, + rtol=1e-3, msg=_msg, ) From cb177979fca8c2759618d38be533df6f040c3f1f Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Thu, 23 Apr 2026 13:06:22 -0400 Subject: [PATCH 5/7] Removes diagnostics --- test/models/globe/test_barnes_hut_kernel.py | 351 +++++++++----------- 1 file changed, 152 insertions(+), 199 deletions(-) diff --git a/test/models/globe/test_barnes_hut_kernel.py b/test/models/globe/test_barnes_hut_kernel.py index 97ba1601ad..956ac563b2 100644 --- a/test/models/globe/test_barnes_hut_kernel.py +++ b/test/models/globe/test_barnes_hut_kernel.py @@ -21,7 +21,6 @@ MultiscaleKernel integration. """ -import contextlib from typing import Any, Literal import pytest @@ -49,38 +48,6 @@ # --------------------------------------------------------------------------- -@contextlib.contextmanager -def _pinned_cpu_determinism(): - """Pin intraop threads to 1 and enable deterministic algorithms for the block. - - Diagnostic stabilization for tests where BH/exact comparisons are expected - to match to within pure fp32 rearrangement error. Targets two plausible - CPU-side non-determinism sources: - - 1. ``torch.set_num_threads(1)`` removes reduction-order variance in - parallelized ops (matmul/einsum/some scatter paths) whose output can - differ by a few ULP depending on how work is sharded across cores. - CI runners and developer boxes routinely disagree on thread count. - 2. ``torch.use_deterministic_algorithms(True, warn_only=True)`` forces - deterministic implementations where available and warns (rather than - erroring) otherwise, so a newly-introduced non-deterministic op - surfaces in logs instead of breaking the suite. - - Both settings are restored on exit. This is not a fixture: callers opt - in explicitly, keeping the scope obvious and bounded. - """ - orig_threads = torch.get_num_threads() - orig_det = torch.are_deterministic_algorithms_enabled() - orig_det_warn = torch.is_deterministic_algorithms_warn_only_enabled() - try: - torch.set_num_threads(1) - torch.use_deterministic_algorithms(True, warn_only=True) - yield - finally: - torch.use_deterministic_algorithms(orig_det, warn_only=orig_det_warn) - torch.set_num_threads(orig_threads) - - def _make_bh_kernel_and_data( n_spatial_dims: int = 2, n_source_scalars: int = 0, @@ -927,183 +894,169 @@ def test_bh_nested_source_data_keys(n_dims: int): The aggregation, split_by_leaf_rank, and TensorDict.cat operations must handle this nesting correctly. - The body runs under :func:`_pinned_cpu_determinism` to remove thread-count - and deterministic-algorithm variance as sources of CI flakiness, and the - ``msg`` passed to :func:`torch.testing.assert_close` is a callable so the - default "Greatest absolute/relative difference" diagnostics are preserved - when a failure occurs in CI. + The ``msg`` passed to :func:`torch.testing.assert_close` is a callable + so the default "Greatest absolute/relative difference" diagnostics are + preserved when a failure occurs in CI. """ - with _pinned_cpu_determinism(): - torch.manual_seed(DEFAULT_SEED) - n_src, n_tgt = 30, 15 - - source_data_ranks = { - "physical": {"pressure": 0}, - "latent": {"scalars": {"0": 0, "1": 0}, "vectors": {"0": 1}}, - "normals": 1, - } - # Rank spec with integer leaves (0 = scalar, 1 = vector): passed through - # directly to the kernels so "p" stays scalar and "u" stays vector. - output_field_ranks = {"p": 0, "u": 1} - - common_kwargs = dict( - n_spatial_dims=n_dims, - output_field_ranks=output_field_ranks, - source_data_ranks=source_data_ranks, - hidden_layer_sizes=[16], - ) + torch.manual_seed(DEFAULT_SEED) + n_src, n_tgt = 30, 15 - bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE) - exact_kernel = Kernel(**common_kwargs) - - ### Invariant 1: state_dict transfer is complete and bit-exact. - # strict=True catches any new auto-registered param/buffer across - # torch versions; the post-condition catches silent value-level drift. - exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=True) - bh_sd, ex_sd = bh_kernel.state_dict(), exact_kernel.state_dict() - mismatched = [k for k in bh_sd if not torch.equal(bh_sd[k], ex_sd[k])] - assert not mismatched, ( - f"state_dict value mismatch after load " - f"(torch={torch.__version__}): {mismatched}" - ) + source_data_ranks = { + "physical": {"pressure": 0}, + "latent": {"scalars": {"0": 0, "1": 0}, "vectors": {"0": 1}}, + "normals": 1, + } + output_field_ranks = {"p": 0, "u": 1} - bh_kernel.eval() - exact_kernel.eval() - - torch.manual_seed(DEFAULT_SEED + 1) - source_data = TensorDict( - { - "physical": TensorDict( - {"pressure": torch.randn(n_src)}, - batch_size=[n_src], - ), - "latent": TensorDict( - { - "scalars": TensorDict( - {"0": torch.randn(n_src), "1": torch.randn(n_src)}, - batch_size=[n_src], - ), - "vectors": TensorDict( - {"0": F.normalize(torch.randn(n_src, n_dims), dim=-1)}, - batch_size=[n_src], - ), - }, - batch_size=[n_src], - ), - "normals": F.normalize(torch.randn(n_src, n_dims), dim=-1), - }, - batch_size=[n_src], - ) + common_kwargs = dict( + n_spatial_dims=n_dims, + output_field_ranks=output_field_ranks, + source_data_ranks=source_data_ranks, + hidden_layer_sizes=[16], + ) - data = { - "source_points": torch.randn(n_src, n_dims), - "target_points": torch.randn(n_tgt, n_dims) * 5, - "source_strengths": torch.rand(n_src) + 0.1, - "reference_length": torch.ones(()), - "source_data": source_data, - "global_data": TensorDict({}, batch_size=[]), - } - - ### Invariant 2: per-pair pre-aggregation outputs are bit-identical. - # At theta=0.01 all pairs are near-field, so BH and Exact both call - # _evaluate_interactions on the same (target, source) pairs with - # identical weights. Capture the pre-aggregation output from each, - # reindex BH's pair ordering into Exact's row-major (t, s) order, - # and compare with a tight tolerance that reflects "same network, - # same input, same weights." If this fires, there is a genuine - # algorithmic divergence (e.g. tensordict iteration-order change - # across library versions) and the final-sum tolerance is masking - # a real bug. - captures: dict[str, dict[str, torch.Tensor]] = {} - orig_eval = Kernel._evaluate_interactions - - def _capturing_eval(tag: str): - def _patched(self, *, scalars, vectors, device): - out = orig_eval(self, scalars=scalars, vectors=vectors, device=device) - captures[tag] = {k: v.detach().clone() for k, v in out.items()} - return out - - return _patched - - try: - Kernel._evaluate_interactions = _capturing_eval("exact") - exact_result = exact_kernel(**data) - Kernel._evaluate_interactions = _capturing_eval("bh") - bh_result = bh_kernel(**data, theta=0.01) - finally: - Kernel._evaluate_interactions = orig_eval - - src_tree = ClusterTree.from_points( - data["source_points"], - leaf_size=DEFAULT_LEAF_SIZE, - ) - tgt_tree = ClusterTree.from_points( - data["target_points"], - leaf_size=DEFAULT_LEAF_SIZE, - ) - plan = src_tree.find_dual_interaction_pairs( - target_tree=tgt_tree, - theta=0.01, - ) - assert plan.n_near == n_src * n_tgt, ( - f"Expected all-near at theta=0.01, got n_near={plan.n_near} " - f"of dense={n_src * n_tgt}" - ) + bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE) + exact_kernel = Kernel(**common_kwargs) - row_of_pair = plan.near_target_ids * n_src + plan.near_source_ids - inv_perm = torch.empty_like(row_of_pair) - inv_perm[row_of_pair] = torch.arange(plan.n_near) + ### Invariant 1: state_dict transfer is complete and bit-exact. + exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=True) + bh_sd, ex_sd = bh_kernel.state_dict(), exact_kernel.state_dict() + mismatched = [k for k in bh_sd if not torch.equal(bh_sd[k], ex_sd[k])] + assert not mismatched, ( + f"state_dict value mismatch after load " + f"(torch={torch.__version__}): {mismatched}" + ) - for field_name in output_field_ranks: - ex_pp = captures["exact"][field_name] - bh_pp = captures["bh"][field_name] - ex_flat = ex_pp.reshape(n_tgt * n_src, *ex_pp.shape[2:]) - bh_reordered = bh_pp[inv_perm] + bh_kernel.eval() + exact_kernel.eval() - torch.testing.assert_close( - bh_reordered, - ex_flat, - atol=1e-6, - rtol=1e-6, - msg=lambda default, f=field_name: ( - f"BH/Exact per-pair pre-aggregation {f!r} divergence " - f"(torch={torch.__version__}). BH and Exact paths are " - f"not computing identical per-pair tensors despite " - f"identical inputs and weights.\n{default}" - ), - ) + torch.manual_seed(DEFAULT_SEED + 1) + source_data = TensorDict( + { + "physical": TensorDict( + {"pressure": torch.randn(n_src)}, + batch_size=[n_src], + ), + "latent": TensorDict( + { + "scalars": TensorDict( + {"0": torch.randn(n_src), "1": torch.randn(n_src)}, + batch_size=[n_src], + ), + "vectors": TensorDict( + {"0": F.normalize(torch.randn(n_src, n_dims), dim=-1)}, + batch_size=[n_src], + ), + }, + batch_size=[n_src], + ), + "normals": F.normalize(torch.randn(n_src, n_dims), dim=-1), + }, + batch_size=[n_src], + ) - ### Final aggregation comparison. - # The two invariant checks above guarantee that if we reach this - # point, BH and Exact computed bit-identical per-pair outputs. - # The only remaining difference is aggregation order: Exact uses - # einsum("ts,s->t", ...) while BH uses scatter_add_. For a - # 30-term fp32 sum with measured |terms| <= 0.044 and cancellation - # ratio <= ~56x, the rearrangement bound is ~30 * eps * 0.044 * - # 56 ≈ 2.2e-5. The tolerance matches test_bh_convergence_to_exact: - # tight enough to catch real bugs, loose enough for cross-platform - # BLAS summation-order variance. - for field_name in output_field_ranks: - - def _msg( - default: str, - field: str = field_name, - dims: int = n_dims, - ) -> str: - return ( - f"Nested keys: {field!r} not close to exact at theta=0.01 " - f"(n_dims={dims}, num_threads={torch.get_num_threads()}, " - f"torch={torch.__version__}).\n{default}" - ) + data = { + "source_points": torch.randn(n_src, n_dims), + "target_points": torch.randn(n_tgt, n_dims) * 5, + "source_strengths": torch.rand(n_src) + 0.1, + "reference_length": torch.ones(()), + "source_data": source_data, + "global_data": TensorDict({}, batch_size=[]), + } - torch.testing.assert_close( - bh_result[field_name], - exact_result[field_name], - atol=1e-4, - rtol=1e-3, - msg=_msg, + ### Invariant 2: per-pair pre-aggregation outputs are bit-identical. + # At theta=0.01 all pairs are near-field, so BH and Exact both call + # _evaluate_interactions on the same (target, source) pairs with + # identical weights. Capture the pre-aggregation output from each, + # reindex BH's pair ordering into Exact's row-major (t, s) order, + # and compare tightly. If this fires, there is a genuine algorithmic + # divergence (e.g. tensordict iteration-order change across library + # versions) and the final-sum tolerance is masking a real bug. + captures: dict[str, dict[str, torch.Tensor]] = {} + orig_eval = Kernel._evaluate_interactions + + def _capturing_eval(tag: str): + def _patched(self, *, scalars, vectors, device): + out = orig_eval(self, scalars=scalars, vectors=vectors, device=device) + captures[tag] = {k: v.detach().clone() for k, v in out.items()} + return out + + return _patched + + try: + Kernel._evaluate_interactions = _capturing_eval("exact") + exact_result = exact_kernel(**data) + Kernel._evaluate_interactions = _capturing_eval("bh") + bh_result = bh_kernel(**data, theta=0.01) + finally: + Kernel._evaluate_interactions = orig_eval + + src_tree = ClusterTree.from_points( + data["source_points"], + leaf_size=DEFAULT_LEAF_SIZE, + ) + tgt_tree = ClusterTree.from_points( + data["target_points"], + leaf_size=DEFAULT_LEAF_SIZE, + ) + plan = src_tree.find_dual_interaction_pairs( + target_tree=tgt_tree, + theta=0.01, + ) + assert plan.n_near == n_src * n_tgt, ( + f"Expected all-near at theta=0.01, got n_near={plan.n_near} " + f"of dense={n_src * n_tgt}" + ) + + row_of_pair = plan.near_target_ids * n_src + plan.near_source_ids + inv_perm = torch.empty_like(row_of_pair) + inv_perm[row_of_pair] = torch.arange(plan.n_near) + + for field_name in output_field_ranks: + ex_pp = captures["exact"][field_name] + bh_pp = captures["bh"][field_name] + ex_flat = ex_pp.reshape(n_tgt * n_src, *ex_pp.shape[2:]) + bh_reordered = bh_pp[inv_perm] + + torch.testing.assert_close( + bh_reordered, + ex_flat, + atol=1e-6, + rtol=1e-6, + msg=lambda default, f=field_name: ( + f"BH/Exact per-pair pre-aggregation {f!r} divergence " + f"(torch={torch.__version__}). BH and Exact paths are " + f"not computing identical per-pair tensors despite " + f"identical inputs and weights.\n{default}" + ), + ) + + ### Final aggregation comparison. + # The invariant checks above guarantee that BH and Exact computed + # bit-identical per-pair outputs. The only remaining difference is + # aggregation order: einsum vs scatter_add_. Tolerance matches + # test_bh_convergence_to_exact. + for field_name in output_field_ranks: + + def _msg( + default: str, + field: str = field_name, + dims: int = n_dims, + ) -> str: + return ( + f"Nested keys: {field!r} not close to exact at theta=0.01 " + f"(n_dims={dims}, num_threads={torch.get_num_threads()}, " + f"torch={torch.__version__}).\n{default}" ) + torch.testing.assert_close( + bh_result[field_name], + exact_result[field_name], + atol=1e-4, + rtol=1e-3, + msg=_msg, + ) + # --------------------------------------------------------------------------- # Four-quadrant interaction mode tests From 7d37951452ecde2cc6bf577acd1955838b8cd538 Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Thu, 23 Apr 2026 13:13:30 -0400 Subject: [PATCH 6/7] Revises fix --- CHANGELOG.md | 3 ++ .../experimental/models/globe/field_kernel.py | 37 ++++++++----------- .../globe/utilities/tensordict_utils.py | 23 ++++++++---- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7ac7f2659..c1c799ee75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,6 +85,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed bug in Pangu, FengWu attention window shift for asymmetric longitudes - Fixed a bug in `mesh.sampling.find_nearest_cells`, where a mixup between L2 and L-inf norms could cause slightly incorrect nearest-neighbor assignments in highly skewed meshes. +- Fixed TensorDict key-ordering bug in GLOBE's Barnes-Hut kernel that caused + incorrect results when `tensordict >= 0.12` reordered leaves during + TensorDict construction from dict literals mixing plain and nested keys. ### Security diff --git a/physicsnemo/experimental/models/globe/field_kernel.py b/physicsnemo/experimental/models/globe/field_kernel.py index 2fc3acfd45..e6a11f1916 100644 --- a/physicsnemo/experimental/models/globe/field_kernel.py +++ b/physicsnemo/experimental/models/globe/field_kernel.py @@ -680,7 +680,10 @@ def _evaluate_interactions( basis_vector_components.append(vectors_hat["r"]) - for k in vectors.keys(include_nested=True, leaves_only=True): + for k in sorted( + vectors.keys(include_nested=True, leaves_only=True), + key=str, + ): if k == "r": continue @@ -1345,25 +1348,17 @@ def _gather_and_evaluate( # concatenate_leaves: 1 GPU kernel (torch.cat) # [src_ids]: 1 GPU kernel (aten::index) # Total: 2 kernels instead of K (one per TensorDict leaf). - # The split-back into named leaves mirrors the vector path below - # and ensures that _evaluate_interactions sees the same nested - # TensorDict structure as the Exact (Kernel.forward) path. Without - # this, "source_scalars" would be a flat tensor here but a nested - # TensorDict in Exact, causing concatenate_leaves inside - # _evaluate_interactions to produce different column orderings when - # TensorDict leaf-iteration order changes across library versions. - src_scalar_keys = list( - source_scalars.keys(include_nested=True, leaves_only=True) + # The split-back uses sorted keys matching concatenate_leaves's + # canonical column ordering so position i maps to the correct leaf. + src_scalar_keys = sorted( + source_scalars.keys(include_nested=True, leaves_only=True), + key=str, ) gathered_src_scalars = concatenate_leaves(source_scalars)[src_ids] - gathered_scalar_leaves = { - k: gathered_src_scalars[..., i] - for i, k in enumerate(src_scalar_keys) - } scalars = TensorDict( { "source_scalars": TensorDict( - gathered_scalar_leaves, + {k: gathered_src_scalars[..., i] for i, k in enumerate(src_scalar_keys)}, batch_size=torch.Size([n_pairs]), device=device, ), @@ -1380,18 +1375,16 @@ def _gather_and_evaluate( # each vector leaf separately for magnitude/direction extraction and # rotationally-equivariant basis construction. Integer indexing # along the last dimension creates non-contiguous views (zero copies). - src_vector_keys = list( - source_vectors.keys(include_nested=True, leaves_only=True) + # Sorted keys match concatenate_leaves's canonical column ordering. + src_vector_keys = sorted( + source_vectors.keys(include_nested=True, leaves_only=True), + key=str, ) gathered_src_vectors = concatenate_leaves(source_vectors)[src_ids] - gathered_vector_leaves = { - k: gathered_src_vectors[..., i] - for i, k in enumerate(src_vector_keys) - } vectors = TensorDict( { "source_vectors": TensorDict( - gathered_vector_leaves, + {k: gathered_src_vectors[..., i] for i, k in enumerate(src_vector_keys)}, batch_size=torch.Size([n_pairs, self.n_spatial_dims]), device=device, ), diff --git a/physicsnemo/experimental/models/globe/utilities/tensordict_utils.py b/physicsnemo/experimental/models/globe/utilities/tensordict_utils.py index e3ddb7c15e..8e939ed975 100644 --- a/physicsnemo/experimental/models/globe/utilities/tensordict_utils.py +++ b/physicsnemo/experimental/models/globe/utilities/tensordict_utils.py @@ -73,6 +73,15 @@ def concatenate_leaves(td: TensorDict[str, torch.Tensor]) -> torch.Tensor: :math:`(*, F_{\text{total}})` where :math:`F_{\text{total}}` is the sum of flattened features across all leaf tensors. + Leaves are sorted by ``str(key)`` before concatenation, producing a + canonical column ordering that is independent of TensorDict construction + order. This is necessary because ``TensorDict`` iteration order can + differ depending on how the object was constructed (dict literal vs + sequential ``__setitem__`` vs element-wise ops) and can change across + ``tensordict`` library versions. Sorting eliminates this as a source + of bugs in any code that relies on positional column layout (e.g. the + MLP input assembly in :meth:`Kernel._evaluate_interactions`). + Parameters ---------- td : TensorDict[str, torch.Tensor] @@ -96,14 +105,14 @@ def concatenate_leaves(td: TensorDict[str, torch.Tensor]) -> torch.Tensor: >>> result.shape torch.Size([2, 17]) """ - tensors = tuple(td.values(include_nested=True, leaves_only=True)) - if len(tensors) == 0: + items = list(td.items(include_nested=True, leaves_only=True)) + if len(items) == 0: return torch.empty(td.batch_size + torch.Size([0]), device=td.device) - else: - return torch.cat( - [t.reshape(td.batch_size + torch.Size([-1])) for t in tensors], - dim=-1, - ) + items.sort(key=lambda kv: str(kv[0])) + return torch.cat( + [t.reshape(td.batch_size + torch.Size([-1])) for _, t in items], + dim=-1, + ) class TensorsByRank(dict): From ce7b972ec950a8f7fd974de4048aa1041349e0ea Mon Sep 17 00:00:00 2001 From: Peter Sharpe Date: Thu, 23 Apr 2026 13:21:45 -0400 Subject: [PATCH 7/7] Simplifies diff --- test/models/globe/test_barnes_hut_kernel.py | 103 ++------------------ 1 file changed, 6 insertions(+), 97 deletions(-) diff --git a/test/models/globe/test_barnes_hut_kernel.py b/test/models/globe/test_barnes_hut_kernel.py index 956ac563b2..8ccfa0c61a 100644 --- a/test/models/globe/test_barnes_hut_kernel.py +++ b/test/models/globe/test_barnes_hut_kernel.py @@ -893,10 +893,6 @@ def test_bh_nested_source_data_keys(n_dims: int): The aggregation, split_by_leaf_rank, and TensorDict.cat operations must handle this nesting correctly. - - The ``msg`` passed to :func:`torch.testing.assert_close` is a callable - so the default "Greatest absolute/relative difference" diagnostics are - preserved when a failure occurs in CI. """ torch.manual_seed(DEFAULT_SEED) n_src, n_tgt = 30, 15 @@ -917,16 +913,7 @@ def test_bh_nested_source_data_keys(n_dims: int): bh_kernel = BarnesHutKernel(**common_kwargs, leaf_size=DEFAULT_LEAF_SIZE) exact_kernel = Kernel(**common_kwargs) - - ### Invariant 1: state_dict transfer is complete and bit-exact. exact_kernel.load_state_dict(bh_kernel.state_dict(), strict=True) - bh_sd, ex_sd = bh_kernel.state_dict(), exact_kernel.state_dict() - mismatched = [k for k in bh_sd if not torch.equal(bh_sd[k], ex_sd[k])] - assert not mismatched, ( - f"state_dict value mismatch after load " - f"(torch={torch.__version__}): {mismatched}" - ) - bh_kernel.eval() exact_kernel.eval() @@ -964,97 +951,19 @@ def test_bh_nested_source_data_keys(n_dims: int): "global_data": TensorDict({}, batch_size=[]), } - ### Invariant 2: per-pair pre-aggregation outputs are bit-identical. - # At theta=0.01 all pairs are near-field, so BH and Exact both call - # _evaluate_interactions on the same (target, source) pairs with - # identical weights. Capture the pre-aggregation output from each, - # reindex BH's pair ordering into Exact's row-major (t, s) order, - # and compare tightly. If this fires, there is a genuine algorithmic - # divergence (e.g. tensordict iteration-order change across library - # versions) and the final-sum tolerance is masking a real bug. - captures: dict[str, dict[str, torch.Tensor]] = {} - orig_eval = Kernel._evaluate_interactions - - def _capturing_eval(tag: str): - def _patched(self, *, scalars, vectors, device): - out = orig_eval(self, scalars=scalars, vectors=vectors, device=device) - captures[tag] = {k: v.detach().clone() for k, v in out.items()} - return out - - return _patched - - try: - Kernel._evaluate_interactions = _capturing_eval("exact") - exact_result = exact_kernel(**data) - Kernel._evaluate_interactions = _capturing_eval("bh") - bh_result = bh_kernel(**data, theta=0.01) - finally: - Kernel._evaluate_interactions = orig_eval - - src_tree = ClusterTree.from_points( - data["source_points"], - leaf_size=DEFAULT_LEAF_SIZE, - ) - tgt_tree = ClusterTree.from_points( - data["target_points"], - leaf_size=DEFAULT_LEAF_SIZE, - ) - plan = src_tree.find_dual_interaction_pairs( - target_tree=tgt_tree, - theta=0.01, - ) - assert plan.n_near == n_src * n_tgt, ( - f"Expected all-near at theta=0.01, got n_near={plan.n_near} " - f"of dense={n_src * n_tgt}" - ) - - row_of_pair = plan.near_target_ids * n_src + plan.near_source_ids - inv_perm = torch.empty_like(row_of_pair) - inv_perm[row_of_pair] = torch.arange(plan.n_near) - - for field_name in output_field_ranks: - ex_pp = captures["exact"][field_name] - bh_pp = captures["bh"][field_name] - ex_flat = ex_pp.reshape(n_tgt * n_src, *ex_pp.shape[2:]) - bh_reordered = bh_pp[inv_perm] - - torch.testing.assert_close( - bh_reordered, - ex_flat, - atol=1e-6, - rtol=1e-6, - msg=lambda default, f=field_name: ( - f"BH/Exact per-pair pre-aggregation {f!r} divergence " - f"(torch={torch.__version__}). BH and Exact paths are " - f"not computing identical per-pair tensors despite " - f"identical inputs and weights.\n{default}" - ), - ) + exact_result = exact_kernel(**data) + bh_result = bh_kernel(**data, theta=0.01) - ### Final aggregation comparison. - # The invariant checks above guarantee that BH and Exact computed - # bit-identical per-pair outputs. The only remaining difference is - # aggregation order: einsum vs scatter_add_. Tolerance matches - # test_bh_convergence_to_exact. for field_name in output_field_ranks: - - def _msg( - default: str, - field: str = field_name, - dims: int = n_dims, - ) -> str: - return ( - f"Nested keys: {field!r} not close to exact at theta=0.01 " - f"(n_dims={dims}, num_threads={torch.get_num_threads()}, " - f"torch={torch.__version__}).\n{default}" - ) - torch.testing.assert_close( bh_result[field_name], exact_result[field_name], atol=1e-4, rtol=1e-3, - msg=_msg, + msg=lambda default, f=field_name: ( + f"Nested keys: {f!r} not close to exact at theta=0.01 " + f"(n_dims={n_dims}, torch={torch.__version__}).\n{default}" + ), )