From dc218b408e74c3477b7bd40a3342e9c307ac7416 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Wed, 1 Apr 2026 11:41:14 -0500 Subject: [PATCH] Rename abstract_with_sharding to as_abstract --- flax/nnx/__init__.py | 1 + flax/nnx/spmd.py | 5 ++++- tests/nnx/spmd_test.py | 6 +++--- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 84a994dd1..2c4dff148 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -153,6 +153,7 @@ from .spmd import get_named_sharding as get_named_sharding from .spmd import with_partitioning as with_partitioning from .spmd import get_abstract_model as get_abstract_model +from .spmd import as_abstract as as_abstract from .spmd import abstract_with_sharding as abstract_with_sharding from .statelib import FlatState as FlatState from .statelib import State as State diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index c7914a4f6..ce91a2624 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -22,6 +22,7 @@ ) import jax from jax.sharding import PartitionSpec +from flax.nnx.deprecations import deprecated A = tp.TypeVar('A') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) @@ -182,7 +183,7 @@ def get_abstract_model(init_fn, mesh, *, graph: bool | None = None): return gdef, abs_state -def abstract_with_sharding( +def as_abstract( tree: A, graph: bool | None = None ) -> A: """Add sharding information to abstract Variables. @@ -238,3 +239,5 @@ def add_sharding(_path, x): return abs_var return x return graphlib.map(add_sharding, tree, graph=graph) + +abstract_with_sharding = deprecated(as_abstract) diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 88d091a5c..8902ea6d6 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -518,7 +518,7 @@ def test_get_abstract_with_abstract_mesh(self): kernel_metadata={'out_sharding': ('a', 'b')}, ) ) - abs_model = nnx.abstract_with_sharding(abs_model) + abs_model = nnx.as_abstract(abs_model) self.assertIsInstance(abs_model.kernel, nnx.Param) self.assertEqual(abs_model.kernel.sharding.spec, P('a', 'b')) @@ -555,7 +555,7 @@ def __init__(self): ) abs_model = nnx.eval_shape(lambda: Model()) - abs_model = nnx.abstract_with_sharding(abs_model) + abs_model = nnx.as_abstract(abs_model) self.assertEqual(abs_model.p1.kernel.sharding.spec, P('a', 'b')) self.assertEqual(abs_model.p1.kernel.sharding.mesh, mesh1) @@ -564,7 +564,7 @@ def __init__(self): def test_get_abstract_no_sharding_metadata(self): abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0))) - abs_model = nnx.abstract_with_sharding(abs_model) + abs_model = nnx.as_abstract(abs_model) self.assertIsInstance(abs_model.kernel, nnx.Param) self.assertIsNone(