Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
from .spmd import get_named_sharding as get_named_sharding
from .spmd import with_partitioning as with_partitioning
from .spmd import get_abstract_model as get_abstract_model
from .spmd import as_abstract as as_abstract
from .spmd import abstract_with_sharding as abstract_with_sharding
from .statelib import FlatState as FlatState
from .statelib import State as State
Expand Down
5 changes: 4 additions & 1 deletion flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
import jax
from jax.sharding import PartitionSpec
from flax.nnx.deprecations import deprecated

A = tp.TypeVar('A')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
Expand Down Expand Up @@ -182,7 +183,7 @@ def get_abstract_model(init_fn, mesh, *, graph: bool | None = None):
return gdef, abs_state


def abstract_with_sharding(
def as_abstract(
tree: A, graph: bool | None = None
) -> A:
"""Add sharding information to abstract Variables.
Expand Down Expand Up @@ -238,3 +239,5 @@ def add_sharding(_path, x):
return abs_var
return x
return graphlib.map(add_sharding, tree, graph=graph)

abstract_with_sharding = deprecated(as_abstract)
6 changes: 3 additions & 3 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def test_get_abstract_with_abstract_mesh(self):
kernel_metadata={'out_sharding': ('a', 'b')},
)
)
abs_model = nnx.abstract_with_sharding(abs_model)
abs_model = nnx.as_abstract(abs_model)

self.assertIsInstance(abs_model.kernel, nnx.Param)
self.assertEqual(abs_model.kernel.sharding.spec, P('a', 'b'))
Expand Down Expand Up @@ -555,7 +555,7 @@ def __init__(self):
)

abs_model = nnx.eval_shape(lambda: Model())
abs_model = nnx.abstract_with_sharding(abs_model)
abs_model = nnx.as_abstract(abs_model)

self.assertEqual(abs_model.p1.kernel.sharding.spec, P('a', 'b'))
self.assertEqual(abs_model.p1.kernel.sharding.mesh, mesh1)
Expand All @@ -564,7 +564,7 @@ def __init__(self):

def test_get_abstract_no_sharding_metadata(self):
abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))
abs_model = nnx.abstract_with_sharding(abs_model)
abs_model = nnx.as_abstract(abs_model)

self.assertIsInstance(abs_model.kernel, nnx.Param)
self.assertIsNone(
Expand Down
Loading