diff --git a/docs_nnx/api_reference/flax.nnx/rnglib.rst b/docs_nnx/api_reference/flax.nnx/rnglib.rst index 73ffa3840..f9be57e0e 100644 --- a/docs_nnx/api_reference/flax.nnx/rnglib.rst +++ b/docs_nnx/api_reference/flax.nnx/rnglib.rst @@ -11,3 +11,4 @@ rnglib .. autofunction:: split_rngs .. autofunction:: fork_rngs .. autofunction:: reseed +.. autofunction:: with_rngs diff --git a/flax/configurations.py b/flax/configurations.py index 183b6c5ba..413695568 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -300,11 +300,14 @@ def static_int_env(varname: str, default: int | None) -> int | None: ) nnx_graph_mode = bool_flag( name='nnx_graph_mode', - default=True, + default=False, help='Whether NNX APIs default to graph-mode (True) or tree-mode (False).', ) nnx_graph_updates = bool_flag( - name='nnx_graph_updates', - default=True, - help='Whether graph-mode uses dynamic (True) or simple (False) graph traversal.', -) \ No newline at end of file + name='nnx_graph_updates', + default=False, + help=( + 'Whether graph-mode uses dynamic (True) or simple (False) graph' + ' traversal.' + ), +) diff --git a/flax/nnx/compat.py b/flax/nnx/compat.py index 78a8a3432..9f513ef7b 100644 --- a/flax/nnx/compat.py +++ b/flax/nnx/compat.py @@ -37,6 +37,7 @@ flatten = functools.partial(_graphlib.flatten, graph=True) iter_graph = functools.partial(_graphlib.iter_graph, graph=True) recursive_map = functools.partial(_graphlib.recursive_map, graph=True) +cached_partial = functools.partial(_graphlib.cached_partial, graph=True, graph_updates=True) # module view = functools.partial(_module.view, graph=True) @@ -45,8 +46,8 @@ iter_children = functools.partial(_module.iter_children, graph=True) # type: ignore[has-type] # rnglib -split_rngs = functools.partial(_rnglib.split_rngs, graph=True) -fork_rngs = functools.partial(_rnglib.fork_rngs, graph=True) +split_rngs = functools.partial(_rnglib.split_rngs, graph=True, graph_updates=True) +fork_rngs = functools.partial(_rnglib.fork_rngs, graph=True, graph_updates=True) reseed = functools.partial(_rnglib.reseed, graph=True) backup_keys = functools.partial(_rnglib.backup_keys, graph=True) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 3bf687aae..9d4cda33e 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -123,7 +123,7 @@ def check_consistent_aliasing2( value_id = id(value) node_id_to_variable[value_id] = value # If prefix is a TreeState (e.g. from nnx.prefix(graph=True)), - # extract the actual prefix value for this variable using local_path. + # extract the actual prefix value for this Variable using local_path. if isinstance(prefix, TreeState): prefix_fn = prefix.prefix_fn.value if not callable(prefix_fn): @@ -200,7 +200,7 @@ def broadcast_prefix2( ) -> tuple[list[KeyPath], list[tp.Any]]: paths: list[KeyPath] = [] leaves: list[tp.Any] = [] - num_leaves = lambda t: jax.tree_util.tree_structure(t).num_leaves + num_leaves = lambda t: jax.tree_util.tree_structure(t, is_leaf=is_leaf).num_leaves def add_leaves(path, x, subtree): n = num_leaves(subtree) paths.extend([path] * n) @@ -215,10 +215,14 @@ def broadcast_prefix_map( *rest: tp.Any, is_leaf: tp.Callable[[tp.Any], bool] | None = None, ) -> tp.Any: - paths, prefix_leaves = broadcast_prefix2(prefix_tree, full_tree, is_leaf=is_leaf) - leaves, treedef = jax.tree_util.tree_flatten(full_tree, is_leaf=is_leaf) - full_prefix_tree = treedef.unflatten(prefix_leaves) - return jax.tree.map_with_path(f, full_prefix_tree, full_tree, *rest, is_leaf=is_leaf) + _, prefix_leaves = broadcast_prefix2(prefix_tree, full_tree, is_leaf=is_leaf) + full_leaves_with_path, treedef = jax.tree.flatten_with_path(full_tree, is_leaf=is_leaf) + rest_flat = [treedef.flatten_up_to(r) for r in rest] + out_leaves = [] + for (path, full_leaf), p_leaf, *r_leaves in zip(full_leaves_with_path, prefix_leaves, *rest_flat): + out_leaf = f(path, p_leaf, full_leaf, *r_leaves) + out_leaves.append(out_leaf) + return jax.tree.unflatten(treedef, out_leaves) class GraphDefState(struct.PyTreeNode): @@ -557,6 +561,69 @@ def replace_at(t: tuple, index: int, value: tp.Any) -> tuple: for i, x in enumerate(t) ) + +def slice_at(t: tuple, index: int | None) -> tuple[tp.Any, tuple]: + if index is None: + return None, t + return t[index], t[:index] + t[index + 1 :] + + +def insert_at(t: tuple, index: int | None, value: tp.Any) -> tuple: + if index is None: + return t + xs = list(t) + xs.insert(index, value) + return tuple(xs) + + +def find(t: tuple, value: tp.Any) -> int | None: + return next((i for i, x in enumerate(t) if x == value), None) + + +@jax.tree_util.register_static +@dataclasses.dataclass(frozen=True, slots=True) +class ExtractIndex: + index: int + + +def extract( + f: tp.Callable[[jax.tree_util.KeyPath, tp.Any, tp.Any], bool], + prefix: tp.Any, + tree: tp.Any, + *, + is_leaf: tp.Callable[[tp.Any], bool] | None = None, +) -> tuple[tp.Any, list[tp.Any]]: + extracted: list[tp.Any] = [] + def _leaf_fn(path: jax.tree_util.KeyPath, prefix_leaf: tp.Any, leaf: tp.Any): + if f(path, prefix_leaf, leaf): + idx = len(extracted) + extracted.append(leaf) + return ExtractIndex(idx) + return leaf + + full_prefix = jax.tree.broadcast(prefix, tree, is_leaf=is_leaf) + new_tree = jax.tree.map_with_path(_leaf_fn, full_prefix, tree, is_leaf=is_leaf) + return new_tree, extracted + + +def insert( + tree: tp.Any, + extracted: list[tp.Any], + is_leaf: tp.Callable[[tp.Any], bool] | None = None, +) -> tp.Any: + if is_leaf is None: + _is_leaf = lambda x: isinstance(x, ExtractIndex) + else: + _is_leaf = lambda x: isinstance(x, ExtractIndex) or is_leaf(x) + + def _leaf_fn(leaf: tp.Any): + if isinstance(leaf, ExtractIndex): + return extracted[leaf.index] + return leaf + + return jax.tree.map(_leaf_fn, tree, is_leaf=_is_leaf) + + def updates_and_snapshot(args: A) -> tuple[A, A]: is_leaf = lambda x: isinstance(x, variablelib.Variable) leaves, treedef = jax.tree.flatten(args, is_leaf=is_leaf) @@ -613,7 +680,8 @@ def check_no_aliases( f' - {seen_path_str}\n' f' - {path_str}\n\n' f'nnx.{fn_name} with graph_updates=False does not support ' - 'returning input Variables as outputs. ' + 'Variable aliasing (duplicate inputs, duplicate outputs, or ' + 'input Variables returned as outputs). ' f'Consider the following options:\n\n' f'1. Remove the duplicate Variables.\n' f'2. Create new Variables via nnx.clone() and use those instead.\n' @@ -816,9 +884,9 @@ def forward(m1, m2, x): filters = list(filter_map.keys()) def prefix_fn(path, leaf): - for predicate, value in predicates: + for predicate, _prefix in predicates: if predicate(path, leaf): - return value + return _prefix raise ValueError( f'No filter matched leaf at path {path!r} with value {leaf!r}. ' f'Filters: {filters}' diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index 15e0ecc8a..450c8c2ce 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -100,6 +100,8 @@ def _check_valid_pytree( @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) + + class NoUpdate: ... @@ -108,6 +110,8 @@ class NoUpdate: ... @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, slots=True) + + class Repeated: ... @@ -1576,7 +1580,12 @@ def static_cache(static_cache: tp.MutableMapping[tp.Any, StaticCache]): ) -def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args, graph: bool | None = None): +def _cached_partial( + f: tp.Callable[..., tp.Any], + *cached_args, + graph: bool | None = None, + graph_updates: bool | None = None, +): """Create a partial from a NNX transformed function alog with some cached input arguments and reduces the python overhead by caching the traversal of NNX graph nodes. This is useful for speed up function that are called repeatedly with the same subset of inputs e.g. a @@ -1625,10 +1634,12 @@ def _cached_partial(f: tp.Callable[..., tp.Any], *cached_args, graph: bool | Non """ if graph is None: graph = set_graph_mode.current_value() - if not graph: + if graph_updates is None: + graph_updates = set_graph_updates.current_value() + + if not graph or not graph_updates: raise ValueError( - 'cached_partial is a graph-mode-only API and does not support ' - 'tree-mode (graph=False).' + 'cached_partial is a graph-mode-only API and requires graph_updates=True.' ) cache: tp.MutableMapping[tp.Any, StaticCache] = PythonRefMap() # type: ignore original_ref_index: RefMap = RefMap() diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py index 137b1ee5f..9df3f065e 100644 --- a/flax/nnx/nn/recurrent.py +++ b/flax/nnx/nn/recurrent.py @@ -883,7 +883,7 @@ def __call__( # we use split_rngs with splits=1 and squeeze=True to get unique rngs # every time RNN is called @nnx.split_rngs(splits=1, only=self.broadcast_rngs, squeeze=True) - @nnx.scan( + @nnx.compat.scan( in_axes=(state_axes, iteration.Carry, time_axis), out_axes=(iteration.Carry, (0, time_axis)) if slice_carry diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 1e4576a1f..070b2fa77 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -804,8 +804,76 @@ def __enter__(self): def __exit__(self, *args): restore_rngs(self) -def with_rngs(tree, split=None, fork=None, only=True, graph=False): + +@tp.overload +def with_rngs( + node: A, + /, + *, + split: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...]] + | int + | tuple[int, ...] + | None + ) = None, + fork: filterlib.Filter | tp.Sequence[filterlib.Filter] | None = None, + broadcast: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...]] + | int + | tuple[int, ...] + | None + ) = None, + only: filterlib.Filter = True, + graph: bool | None = None, + graph_updates: bool | None = None, +) -> A: ... + + +@tp.overload +def with_rngs( + *, + split: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...]] + | int + | tuple[int, ...] + | None + ) = None, + fork: filterlib.Filter | tp.Sequence[filterlib.Filter] | None = None, + broadcast: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...]] + | int + | tuple[int, ...] + | None + ) = None, + only: filterlib.Filter = True, + graph: bool | None = None, + graph_updates: bool | None = None, +) -> tp.Callable[[F], F]: ... + + +def with_rngs( + node: tp.Any = MISSING, + /, + *, + split: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...]] + | int + | tuple[int, ...] + | None + ) = None, + fork: filterlib.Filter | tp.Sequence[filterlib.Filter] | None = None, + broadcast: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...]] + | int + | tuple[int, ...] + | None + ) = None, + only: filterlib.Filter = True, + graph: bool | None = None, + graph_updates: bool | None = None, +) -> tp.Any: """Returns a copy of ``tree`` with ``RngStream`` objects replaced according to + ``split`` and ``fork`` rules. ``split`` controls which streams are **split** — after splitting, each call @@ -815,22 +883,27 @@ def with_rngs(tree, split=None, fork=None, only=True, graph=False): the parent counter. Streams that match neither rule are returned unchanged. Args: - tree: A pytree that may contain ``RngStream`` objects (e.g. an ``Rngs`` + node: A pytree that may contain ``RngStream`` objects (e.g. an ``Rngs`` instance, a module, or any nested structure). - split: Specifies which streams to split and into what shape. Can be: - - * An ``int`` or ``tuple[int, ...]`` — split *all* streams into this - shape, equivalent to ``{...: split}``. - * A :class:`~flax.nnx.filterlib.Filter`-keyed mapping where each value - is an ``int`` or ``tuple[int, ...]``. The first matching filter wins. - - fork: A :class:`~flax.nnx.filterlib.Filter` selecting which streams not - already handled by ``split`` should be forked. Pass ``...`` to fork all - remaining streams. - graph: If ``True``, uses graph-mode which supports the full - NNX feature set including shared references. If ``False``, uses - tree-mode which treats Modules as regular JAX pytrees, avoiding - the overhead of the graph protocol. + split: Specifies which streams to split and into what shape. Can be: * An + ``int`` or ``tuple[int, ...]`` — split *all* streams into this shape, + equivalent to ``{...: split}``. * A + :class:`~flax.nnx.filterlib.Filter`-keyed mapping where each value is an + ``int`` or ``tuple[int, ...]``. The first matching filter wins. + fork: A :class:`~flax.nnx.filterlib.Filter`, a sequence of filters, or + ``None`` selecting which streams not already handled by ``split`` should + be forked. Pass ``...`` to fork all remaining streams. + broadcast: Specifies which streams to broadcast and into what shape. Can + be: * An ``int`` or ``tuple[int, ...]`` — broadcast *all* streams into + this shape, equivalent to ``{...: broadcast}``. * A + :class:`~flax.nnx.filterlib.Filter`-keyed mapping where each value is an + ``int`` or ``tuple[int, ...]``. The first matching filter wins. + only: A :class:`~flax.nnx.filterlib.Filter` selecting which streams to + process. Pass ``True`` (default) to process all streams. + graph: If ``True``, uses graph-mode which supports the full NNX feature set + including shared references. If ``False``, uses tree-mode which treats + Modules as regular JAX pytrees, avoiding the overhead of the graph + protocol. Returns: A new tree of the same structure as ``tree`` with ``RngStream`` objects @@ -841,7 +914,7 @@ def with_rngs(tree, split=None, fork=None, only=True, graph=False): >>> from flax import nnx ... >>> rngs = nnx.Rngs(params=0, dropout=1) - >>> new_rngs = nnx.with_rngs(rngs, split=4) + >>> new_rngs = nnx.with_rngs(rngs, split=4, graph=False) >>> new_rngs.params.key.shape (4,) >>> new_rngs.dropout.key.shape @@ -850,7 +923,9 @@ def with_rngs(tree, split=None, fork=None, only=True, graph=False): Example — split some streams, fork the rest:: >>> rngs = nnx.Rngs(params=0, dropout=1) - >>> new_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...) + >>> new_rngs = nnx.with_rngs( + ... rngs, split={'params': 4}, fork=nnx.Not('params'), graph=False + ... ) >>> new_rngs.params.key.shape (4,) >>> new_rngs.dropout.key.shape # forked: scalar key, advanced counter @@ -862,97 +937,173 @@ def with_rngs(tree, split=None, fork=None, only=True, graph=False): >>> new_rngs = nnx.with_rngs(rngs, split={ ... 'params': 4, # split params into 4 keys ... ...: (2, 4), # split anything else into 2×4 keys - ... }) + ... }, graph=False) >>> new_rngs.params.key.shape (4,) >>> new_rngs.noise.key.shape (2, 4) - """ + if graph is None: + graph = graphlib.set_graph_mode.current_value() + if graph_updates is None: + graph_updates = graphlib.set_graph_updates.current_value() + + if graph and graph_updates: + raise NotImplementedError( + 'graph=True and graph_updates=True is not supported for `with_rngs`' + ) + + if isinstance(node, Missing): + + def with_rngs_decorator(f: F) -> F: + @functools.wraps(f) + def with_rngs_wrapper(*args, **kwargs): + args, kwargs = with_rngs( + (args, kwargs), + split=split, + fork=fork, + broadcast=broadcast, + only=only, + graph=graph, + graph_updates=False, + ) + return f(*args, **kwargs) + + return tp.cast(F, with_rngs_wrapper) + + return with_rngs_decorator # type: ignore[bad-return-type] + if split is None: split = {} elif isinstance(split, (int, tuple)): split = {...: split} - if isinstance(fork, str) or not isinstance(fork, tp.Sequence): + + if broadcast is None: + broadcast = {} + elif isinstance(broadcast, (int, tuple)): + broadcast = {...: broadcast} + + if fork is None: + fork = [] + elif isinstance(fork, str) or not isinstance(fork, tp.Sequence): fork = [fork] + split_predicates = [(k, filterlib.to_predicate(k), v) for k, v in split.items()] + broadcast_predicates = [(k, filterlib.to_predicate(k), v) for k, v in broadcast.items()] fork_predicates = [(p, filterlib.to_predicate(p)) for p in fork] only_predicate = filterlib.to_predicate(only) - def f(path, val): + def update_rngs(path, val): if isinstance(val, RngStream) and only_predicate(path, val): results = {} for (filter, predicate, num_splits) in split_predicates: if predicate(path, val): results['split'] = (filter, num_splits) break + for (filter, predicate, num_broadcasts) in broadcast_predicates: + if predicate(path, val): + results['broadcast'] = (filter, num_broadcasts) + break for (filter, predicate) in fork_predicates: if predicate(path, val): results['fork'] = (filter,) break + if len(results) > 1: - fork_filter = results['fork'][0] - if fork_filter not in (..., True): + specific_matches = [r for r, info in results.items() if info[0] not in (..., True)] + if len(specific_matches) > 1: rule_descriptions = '\n'.join(f' - {rule} matches filter {info[0]!r}' for rule, info in results.items()) raise ValueError( f"RngStream at path {path} matches multiple rules:\n{rule_descriptions}" ) + if 'split' in results: return val.split(results['split'][1]) + if 'broadcast' in results: + return val.broadcast(results['broadcast'][1]) if 'fork' in results: return val.fork() return val - return graphlib.recursive_map(f, tree, graph=graph) + return graphlib.recursive_map(update_rngs, node, graph=graph) + @tp.overload def split_rngs( - node: tp.Any, - /, - *, - splits: int | tuple[int, ...], - only: filterlib.Filter = ..., - squeeze: bool = False, - graph: tp.Literal[True] | None = None, -) -> SplitBackups: ... + node: tp.Any, + /, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., + squeeze: bool = False, + graph: tp.Literal[True] | None = None, + graph_updates: tp.Literal[True] | None = None, +) -> SplitBackups: + ... + + @tp.overload def split_rngs( - node: A, - /, - *, - splits: int | tuple[int, ...], - only: filterlib.Filter = ..., - squeeze: bool = False, - graph: tp.Literal[False], -) -> A: ... + node: A, + /, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., + squeeze: bool = False, + graph: tp.Literal[False], + graph_updates: bool | None = None, +) -> A: + ... + + @tp.overload def split_rngs( - *, - splits: int | tuple[int, ...], - only: filterlib.Filter = ..., - squeeze: bool = False, - graph: bool | None = None, -) -> tp.Callable[[F], F]: ... + node: A, + /, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., + squeeze: bool = False, + graph: tp.Literal[True] | None, + graph_updates: tp.Literal[False], +) -> A: + ... + + +@tp.overload def split_rngs( - node: tp.Any = MISSING, - /, - *, - splits: int | tuple[int, ...], - only: filterlib.Filter = ..., - squeeze: bool = False, - graph: bool | None = None, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., + squeeze: bool = False, + graph: bool | None = None, + graph_updates: bool | None = None, +) -> tp.Callable[[F], F]: + ... + +def split_rngs( + node: tp.Any = MISSING, + /, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., + squeeze: bool = False, + graph: bool | None = None, + graph_updates: bool | None = None, ) -> SplitBackups | tp.Any | tp.Callable[[F], F]: """Splits the (nested) Rng states of the given node. Args: node: the base node containing the rng states to split. - splits: an integer or tuple of integers specifying the - shape of the split rng keys. + splits: an integer or tuple of integers specifying the shape of the split + rng keys. only: a Filter selecting which rng states to split. - graph: If ``True`` (default), uses graph-mode which supports the full - NNX feature set including shared references. If ``False``, uses - tree-mode which treats Modules as regular JAX pytrees, avoiding - the overhead of the graph protocol. + graph: If ``True`` (default), uses graph-mode which supports the full NNX + feature set including shared references. If ``False``, uses tree-mode + which treats Modules as regular JAX pytrees, avoiding the overhead of the + graph protocol. + graph_updates: If ``True``, applies the splits in-place on the node. If + ``False``, returns a new node with split rng states. Returns: A SplitBackups iterable if ``node`` is provided, otherwise a @@ -1034,27 +1185,35 @@ def split_rngs( >>> model = create_model(rngs) >>> model.dropout.rngs.key.shape () - - """ if graph is None: graph = graphlib.set_graph_mode.current_value() + if graph_updates is None: + graph_updates = graphlib.set_graph_updates.current_value() if isinstance(node, Missing): def split_rngs_decorator(f: F) -> F: @functools.wraps(f) def split_rngs_wrapper(*args, **kwargs): - if graph: + if graph and graph_updates: with split_rngs( - (args, kwargs), splits=splits, only=only, squeeze=squeeze, - graph=True, + (args, kwargs), + splits=splits, + only=only, + squeeze=squeeze, + graph=True, + graph_updates=True, ): return f(*args, **kwargs) else: args, kwargs = split_rngs( - (args, kwargs), splits=splits, only=only, squeeze=squeeze, - graph=False, + (args, kwargs), + splits=splits, + only=only, + squeeze=squeeze, + graph=graph, + graph_updates=False, ) return f(*args, **kwargs) @@ -1065,23 +1224,30 @@ def split_rngs_wrapper(*args, **kwargs): if squeeze and splits != 1: raise ValueError('squeeze=True is only supported for splits=1') - if graph: - return _graph_split_rngs( - node, splits=splits, only=only, squeeze=squeeze, + if graph and graph_updates: + return _graph_updates_split_rngs( + node, + splits=splits, + only=only, + squeeze=squeeze, ) else: - return _tree_split_rngs( - node, splits=splits, only=only, squeeze=squeeze, + return _simple_split_rngs( + node, + splits=splits, + only=only, + squeeze=squeeze, + graph=graph, ) -def _graph_split_rngs( - node: tp.Any, - /, - *, - splits: int | tuple[int, ...], - only: filterlib.Filter = ..., - squeeze: bool = False, +def _graph_updates_split_rngs( + node: tp.Any, + /, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., + squeeze: bool = False, ) -> SplitBackups: predicate = filterlib.to_predicate(only) backups: list[StreamBackup] = [] @@ -1109,13 +1275,14 @@ def _graph_split_rngs( return SplitBackups(backups) -def _tree_split_rngs( - node: tp.Any, - /, - *, - splits: int | tuple[int, ...], - only: filterlib.Filter = ..., - squeeze: bool = False, +def _simple_split_rngs( + node: tp.Any, + /, + *, + splits: int | tuple[int, ...], + only: filterlib.Filter = ..., + squeeze: bool = False, + graph: bool, ) -> tp.Any: predicate = filterlib.to_predicate(only) @@ -1141,34 +1308,88 @@ def _split_stream(path, node): ) return node - return graphlib.recursive_map(_split_stream, node, graph=False) + return graphlib.recursive_map(_split_stream, node, graph=graph) + + +def _graph_updates_fork_rngs( + node: tp.Any, + /, + *, + predicate_splits: tp.Mapping[tp.Callable, tp.Any], + graph: bool, +) -> SplitBackups: + backups: list[StreamBackup] = [] + for path, stream in graphlib.iter_graph(node, graph=graph): + for predicate, splits in predicate_splits.items(): + if ( + isinstance(stream, RngStream) + and predicate((*path, 'key'), stream.key) + and predicate((*path, 'count'), stream.count) + ): + forked_stream = stream.fork(split=splits) + # backup the original stream state + backups.append((stream, stream.key[...], stream.count[...])) + # apply the forked key and count to the original stream + stream.key.set_value(forked_stream.key.get_value()) + stream.count.set_value(forked_stream.count.get_value()) + + return SplitBackups(backups) + + +def _simple_fork_rngs( + node: tp.Any, + /, + *, + predicate_splits: tp.Mapping[tp.Callable, tp.Any], + graph: bool, +) -> tp.Any: + def _fork_stream(path, node): + if isinstance(node, RngStream): + for predicate, splits in predicate_splits.items(): + if predicate((*path, 'key'), node.key) and predicate( + (*path, 'count'), node.count + ): + return node.fork(split=splits) + return node + + return graphlib.recursive_map(_fork_stream, node, graph=graph) + @tp.overload def fork_rngs( - node: tp.Any, - /, - *, - split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] - | int - | None = None, - graph: bool | None = None, -) -> SplitBackups: ... + node: tp.Any, + /, + *, + split: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] | int | None + ) = None, + graph: bool | None = None, + graph_updates: bool | None = None, +) -> SplitBackups: + ... + + @tp.overload def fork_rngs( - *, - split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] - | int - | None = None, - graph: bool | None = None, -) -> tp.Callable[[F], F]: ... + *, + split: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] | int | None + ) = None, + graph: bool | None = None, + graph_updates: bool | None = None, +) -> tp.Callable[[F], F]: + ... + + def fork_rngs( - node: tp.Any = MISSING, - /, - *, - split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] - | int - | None = None, - graph: bool | None = None, + node: tp.Any = MISSING, + /, + *, + split: ( + tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None] | int | None + ) = None, + graph: bool | None = None, + graph_updates: bool | None = None, ) -> SplitBackups | tp.Callable[[F], F]: """Forks the (nested) Rng states of the given node. @@ -1211,12 +1432,25 @@ def fork_rngs( ... model = nnx.Linear(2, 3, rngs=rngs) """ + if graph is None: + graph = graphlib.set_graph_mode.current_value() + if graph_updates is None: + graph_updates = graphlib.set_graph_updates.current_value() + if isinstance(node, Missing): def fork_rngs_decorator(f: F) -> F: @functools.wraps(f) def fork_rngs_wrapper(*args, **kwargs): - with fork_rngs((args, kwargs), split=split): + if graph and graph_updates: + with fork_rngs( + (args, kwargs), split=split, graph=True, graph_updates=True + ): + return f(*args, **kwargs) + else: + args, kwargs = fork_rngs( + (args, kwargs), split=split, graph=graph, graph_updates=False + ) return f(*args, **kwargs) return tp.cast(F, fork_rngs_wrapper) @@ -1231,22 +1465,15 @@ def fork_rngs_wrapper(*args, **kwargs): predicate_splits = { filterlib.to_predicate(k): v for k, v in split.items() } - backups: list[StreamBackup] = [] - for path, stream in graphlib.iter_graph(node, graph=graph): - for predicate, splits in predicate_splits.items(): - if ( - isinstance(stream, RngStream) - and predicate((*path, 'key'), stream.key) - and predicate((*path, 'count'), stream.count) - ): - forked_stream = stream.fork(split=splits) - # backup the original stream state - backups.append((stream, stream.key[...], stream.count[...])) - # apply the forked key and count to the original stream - stream.key.set_value(forked_stream.key.get_value()) - stream.count.set_value(forked_stream.count.get_value()) - return SplitBackups(backups) + if graph and graph_updates: + return _graph_updates_fork_rngs( + node, predicate_splits=predicate_splits, graph=graph + ) + else: + return _simple_fork_rngs( + node, predicate_splits=predicate_splits, graph=graph + ) def backup_keys(node: tp.Any, /, *, graph: bool | None = None): diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index da110c297..3897f62d7 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -23,7 +23,7 @@ from flax import nnx from flax.nnx import filterlib from flax.nnx.pytreelib import Pytree -from flax.nnx.variablelib import Variable +from flax.nnx.variablelib import Param, Variable M = tp.TypeVar('M', bound=nnx.Module) F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) @@ -229,7 +229,7 @@ class ModelAndOptimizer(Optimizer[M]): Use :class:`Optimizer` instead. """ - def __init__(self, model: M, tx: optax.GradientTransformation, *, wrt: filterlib.Filter = nnx.Param): + def __init__(self, model: M, tx: optax.GradientTransformation, *, wrt: filterlib.Filter = Param): super().__init__(model, tx, wrt=wrt) self.model = model diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 9ea5ff3cc..07c2867de 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -79,7 +79,13 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases('grad', args=updates[0], kwargs=updates[1], out=out) + extract.check_no_aliases( + 'grad', + args=updates[0], + kwargs=updates[1], + out=out, + check_can_update=['out'], + ) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: @@ -572,7 +578,9 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases('vjp', args=updates, out=out) + extract.check_no_aliases( + 'vjp', args=updates, out=out, check_can_update=['out'] + ) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: primals_out, aux = out @@ -738,7 +746,9 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases('jvp', args=updates, out=out) + extract.check_no_aliases( + 'jvp', args=updates, out=out, check_can_update=['out'] + ) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: primals_out, aux = out @@ -915,7 +925,9 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases('custom_vjp', args=updates, out=out) + extract.check_no_aliases( + 'custom_vjp', args=updates, out=out, check_can_update=['out'] + ) diff_prefix = tuple( i not in self.nondiff_argnums for i in range(len(args)) ) @@ -955,7 +967,9 @@ def __call__(self, *args): if self.graph: out = extract.to_tree2(out) residual = extract.to_tree2(residual) - extract.check_no_aliases('custom_vjp', args=updates, out=out) + extract.check_no_aliases( + 'custom_vjp', args=updates, out=out, check_can_update=['out'] + ) updates = extract.mask_variable_updates(updates, snapshot) return (out, updates), residual @@ -974,6 +988,7 @@ def __call__(self, *args): if self.graph: nondiff = extract.from_tree2(nondiff) residual = extract.from_tree2(residual) + out_g = extract.from_tree2(out_g) result = self.bwd(*nondiff, residual, out_g) if self.graph: result = extract.to_tree2(result) diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index d9798098a..5f9cac348 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -146,6 +146,8 @@ def __call__(self, *pure_args, **pure_kwargs): @tp.overload + + def jit( *, in_shardings: tp.Any = None, @@ -162,6 +164,8 @@ def jit( graph_updates: bool | None = None, ) -> tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]: ... @tp.overload + + def jit( fun: tp.Callable[P, R], *, @@ -434,8 +438,6 @@ def _flatten_to_partial_state( return PartialState(treedef=treedef, leaves=leaves) - - @dataclasses.dataclass(eq=False) class SimpleJitFn: f: tp.Callable[..., tp.Any] @@ -457,7 +459,13 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_shardings) - extract.check_no_aliases('jit', args=args_updates, kwargs=kwargs_updates, out=out) + extract.check_no_aliases( + 'jit', + args=args_updates, + kwargs=kwargs_updates, + out=out, + check_can_update=['out'], + ) def donated_arg(jax_path, prefix, c, s): path = graphlib.jax_to_nnx_path(jax_path) return path[0] in self.donate_argnums or extract.variable_changed(c, s) @@ -1281,7 +1289,9 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out, prefix=self.out_specs) - extract.check_no_aliases('shard_map', args=updates, out=out) + extract.check_no_aliases( + 'shard_map', args=updates, out=out, check_can_update=['out'] + ) updates = extract.mask_variable_updates(updates, snapshot) return out, updates diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 67ec87740..dc8bb136c 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -140,7 +140,9 @@ def wrapper(*in_args, **in_kwargs): out = f(*args) if graph: out = extract.to_tree2(out, prefix=out_axes) - extract.check_no_aliases('transform_metadata', args=updates, out=out) + extract.check_no_aliases( + 'transform_metadata', args=updates, out=out, check_can_update=['out'] + ) _apply_axis_fn(args, in_axes, metadata, spmd.add_axis) _apply_axis_fn(out, out_axes, metadata, spmd.add_axis) updates = extract.mask_variable_updates(updates, snapshot) @@ -280,7 +282,13 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) - extract.check_no_aliases('vmap', args=updates[0], kwargs=updates[1], out=out) + extract.check_no_aliases( + 'vmap', + args=updates[0], + kwargs=updates[1], + out=out, + check_can_update=['out'], + ) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -302,7 +310,13 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) - extract.check_no_aliases('pmap', args=updates[0], kwargs=updates[1], out=out) + extract.check_no_aliases( + 'pmap', + args=updates[0], + kwargs=updates[1], + out=out, + check_can_update=['out'], + ) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -907,6 +921,48 @@ def _check_out_axes(out_axes): f'Got StateAxes({{{filter}: Carry}}) at: out_axes' f'{jax.tree_util.keystr(key)}' ) + + +def _validate_scan_axes(in_axes, out_axes): + is_axis_leaf = lambda x: x is None or x is Carry + + in_carry_count = 0 + for path, leaf in jax.tree_util.tree_leaves_with_path( + in_axes, is_leaf=is_axis_leaf + ): + if leaf is Carry: + in_carry_count += 1 + if len(path) > 1: + raise ValueError( + 'Carry must be a top-level argument, it cannot be nested. ' + f'Found Carry inside in_axes at path {jax.tree_util.keystr(path)}' + ) + if in_carry_count > 1: + raise ValueError('Found multiple Carry definitions in in_axes') + + out_carry_count = 0 + for path, leaf in jax.tree_util.tree_leaves_with_path( + out_axes, is_leaf=is_axis_leaf + ): + if leaf is Carry: + out_carry_count += 1 + if len(path) > 1: + raise ValueError( + 'Carry must be a top-level argument, it cannot be nested. ' + f'Found Carry inside out_axes at path {jax.tree_util.keystr(path)}' + ) + if out_carry_count > 1: + raise ValueError('Found multiple Carry definitions in out_axes') + + in_has_carry = in_carry_count > 0 + out_has_carry = out_carry_count > 0 + if in_has_carry != out_has_carry: + raise ValueError( + 'If one of in_axes or out_axes has Carry, the other must also ' + f'have Carry. Got {in_axes=}, {out_axes=}' + ) + + def _check_carry_same_references(carry_arg, carry_arg_out): def check_carry_same_references(key_path, arg, out): if ( @@ -1369,43 +1425,56 @@ def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args - def __call__(self, *args): - updates, snapshot = extract.updates_and_snapshot(args) - updates = extract.mask_at(updates, self.carry_arg_index) - snapshot = extract.mask_at(snapshot, self.carry_arg_index) + def __call__(self, full_carry: tp.Any, x_args: tp.Any): + carry, broadcasts = full_carry + updates, snapshot = extract.updates_and_snapshot(x_args) + x_args = extract.insert( + x_args, + broadcasts, + is_leaf=lambda x: isinstance(x, variablelib.Variable), + ) + + if self.graph: + x_args = extract.from_tree2(x_args) + carry = extract.from_tree2(carry) + + # Reconstruct full args if self.carry_arg_index is not None: - carry_in = args[self.carry_arg_index] + args = extract.insert_at(x_args, self.carry_arg_index, carry) else: - carry_in = None - if self.graph: - args = extract.from_tree2(args) + args = x_args out = self.f(*args) if self.graph: - out = extract.to_tree2(out, prefix=self.out_axes) + # check consistent aliasing, temporarily convert `out` to tree + # to check aliasing, but the real tree convertion is done later + check_out = extract.to_tree2(out, prefix=self.out_axes) + else: + check_out = out - if self.carry_out_index is not None: - carry_out = out[self.carry_out_index] if self.out_is_tuple else out - extract.check_same_variables(carry_in, carry_out, 'scan') + extract.check_no_aliases( + 'scan', args=updates, out=check_out, check_can_update=['out'] + ) + updates = extract.mask_variable_updates(updates, snapshot) + if self.carry_arg_index is not None: # has carry + if self.out_is_tuple: + carry_out, ys = extract.slice_at(out, self.carry_out_index) + else: + carry_out = out + ys = None + extract.check_same_variables(carry, carry_out, 'scan') + else: + ys = out + carry_out = None - def keep_fn(path, prefix, cur, snap): - changed = extract.variable_changed(cur, snap) - if prefix is None and changed: - raise ValueError( - f'Broadcast (None axis) Variable at {jax.tree_util.keystr(path)} ' - 'was mutated during scan. Only Carry and scanned Variables can be ' - 'updated.' - ) - return changed + if self.graph: + # convert the carry to tree separately to ensure a consistent + # graph structure for the carry in and carry out + carry_out = extract.to_tree2(carry_out) + ys = extract.to_tree2(ys) - extract.check_no_aliases('scan', args=updates, out=out) - updates = extract.mask_variable_updates( - updates, snapshot, prefix=self.in_axes, keep_fn=keep_fn, - ) - if self.out_is_tuple: - return (*out, updates) - return (out, updates) + return (carry_out, broadcasts), (ys, updates) @tp.overload @@ -1594,11 +1663,40 @@ def forward(x, model): transform_metadata=transform_metadata, ) +def _move_axis(move_fn, axes, tree): + def move_axis_leaf(path, ax, leaf): + if isinstance(leaf, extract.Mask): + return leaf + if isinstance(leaf, variablelib.Variable): + return extract.broadcast_prefix_map( + move_axis_leaf, ax, leaf + ) + if isinstance(leaf, extract.ExtractIndex): + assert ax is None + return leaf + if leaf is None: + return leaf + assert isinstance(ax, int) + if ax != 0: + return move_fn(leaf, ax) + return leaf + + return extract.broadcast_prefix_map( + move_axis_leaf, + axes, + tree, + is_leaf=lambda x: isinstance(x, extract.Mask) + or isinstance(x, variablelib.Variable) + or x is None + or isinstance(x, extract.ExtractIndex), + ) + def _simple_scan( f, f_unbound, *, graph, in_axes, out_axes, length, reverse, unroll, _split_transpose, ): + # TODO: do this inside check_prefix if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)): raise ValueError( '`in_axes` cannot contain `StateAxes` objects ' @@ -1611,84 +1709,112 @@ def _simple_scan( 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('scan') ) + _check_out_axes(out_axes) if graph: extract.check_prefix(in_axes, 'in_axes', 'scan') extract.check_prefix(out_axes, 'out_axes', 'scan') + _validate_scan_axes(in_axes, out_axes) + out_is_tuple = isinstance(out_axes, tuple) + was_carry = in_axes is Carry if in_axes is Carry: in_axes = (Carry,) if isinstance(in_axes, tuple): - carry_arg_index = next( - (i for i, ax in enumerate(in_axes) if ax is Carry), None - ) - updates_out_axes = extract.mask_at(in_axes, carry_arg_index) + carry_arg_index = extract.find(in_axes, Carry) + _, sliced_in_axes = extract.slice_at(in_axes, carry_arg_index) else: carry_arg_index = None - updates_out_axes = in_axes + sliced_in_axes = in_axes if isinstance(out_axes, tuple): - carry_out_index = next( - (i for i, ax in enumerate(out_axes) if ax is Carry), None - ) + carry_out_index = extract.find(out_axes, Carry) + _, sliced_out_axes = extract.slice_at(out_axes, carry_out_index) else: carry_out_index = None + sliced_out_axes = out_axes simple_scan_fn = SimpleScanFn( - f_unbound, graph=graph, - in_axes=in_axes, out_axes=out_axes, - out_is_tuple=out_is_tuple, - carry_arg_index=carry_arg_index, - carry_out_index=carry_out_index, + f_unbound, graph=graph, + in_axes=in_axes, out_axes=out_axes, + out_is_tuple=out_is_tuple, + carry_arg_index=carry_arg_index, + carry_out_index=carry_out_index, ) - if out_is_tuple: - augmented_out_axes = (*out_axes, updates_out_axes) - else: - augmented_out_axes = (out_axes, updates_out_axes) - @functools.wraps(f) def simple_scan_wrapper(*args): args = resolve_kwargs(f, args, {}) + if was_carry and len(args) != 1: + raise ValueError( + 'When in_axes=Carry, the function must take exactly one argument, ' + f'got {len(args)} arguments.' + ) if graph: - args = extract.to_tree2(args, prefix=in_axes) + # check consistent aliasing, temporarily convert args to tree + # to check aliasing, but the real tree convertion is done later + check_args = extract.to_tree2(args, prefix=in_axes) + else: + check_args = args - extract.check_no_aliases('scan', args=args) + extract.check_no_aliases('scan', args=check_args) + carry, x_args = extract.slice_at(args, carry_arg_index) + + if graph: + # convert the carry to tree separately to ensure a consistent + # graph structure for the carry in and carry out + carry = extract.to_tree2(carry) + x_args = extract.to_tree2(x_args) + + def extract_broadcasts(path, prefix_leaf, leaf): + return leaf is not None and ( + prefix_leaf is None + or ( + isinstance(prefix_leaf, variablelib.Variable) + and prefix_leaf.get_value() is None + ) + ) + x_args, broadcasts = extract.extract( + extract_broadcasts, sliced_in_axes, x_args, + is_leaf=lambda x: x is None or isinstance(x, variablelib.Variable), + ) - result = pure_jax_fancy_scan( - simple_scan_fn, - *args, - length=length, - reverse=reverse, - unroll=unroll, - _split_transpose=_split_transpose, - in_axes=in_axes, - out_axes=augmented_out_axes, + x_args_transposed = _move_axis( + lambda leaf, ax: jnp.moveaxis(leaf, ax, 0), + sliced_in_axes, x_args, ) - if out_is_tuple: - n = len(out_axes) - out = result[:n] - updates = result[n] - else: - out, updates = result + (carry_out, final_broadcasts), (ys, updates) = jax.lax.scan( + simple_scan_fn, + (carry, broadcasts), + x_args_transposed, + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + ) + + ys, updates = _move_axis( + lambda leaf, ax: jnp.moveaxis(leaf, 0, ax), + (sliced_out_axes, sliced_in_axes), + (ys, updates), + ) - masked_args = extract.mask_at(args, carry_arg_index) - extract.apply_variable_updates(masked_args, updates) + extract.apply_variable_updates(x_args, updates) + extract.apply_variable_updates(broadcasts, final_broadcasts) + carry = extract.update_carry_variables(carry, carry_out) + + if graph: + ys = extract.from_tree2(ys) + carry = extract.from_tree2(carry) if carry_arg_index is not None: - carry_in = args[carry_arg_index] - carry_out = ( - out[carry_out_index] if out_is_tuple else out - ) - extract.update_carry_variables(carry_in, carry_out) if out_is_tuple: - out = extract.replace_at(out, carry_out_index, carry_in) + out = extract.insert_at(ys, carry_out_index, carry) else: - out = carry_in - - if graph: - out = extract.from_tree2(out) + out = carry + else: + out = ys return out @@ -1822,174 +1948,6 @@ def scan_wrapper(*args, **kwargs): return scan_wrapper -def pure_jax_fancy_scan( - f, - *args, - length: int | None = None, - reverse: bool = False, - unroll: int | bool = 1, - _split_transpose: bool = False, - in_axes: tp.Any = (Carry, 0), - out_axes: tp.Any = (Carry, 0), -): - if in_axes is Carry: - in_axes = (Carry,) - is_axis_leaf = lambda x: x is None or x is Carry - - if isinstance(in_axes, tuple): - for i, ax in enumerate(in_axes): - if ax is Carry or ax is None or isinstance(ax, int): - continue - for leaf in jax.tree.leaves(ax, is_leaf=is_axis_leaf): - if leaf is Carry: - raise ValueError( - 'Carry must be a top-level argument, it cannot be nested. ' - f'Found Carry inside in_axes[{i}]={ax}' - ) - - if isinstance(out_axes, tuple): - for i, ax in enumerate(out_axes): - if ax is Carry or ax is None or isinstance(ax, int): - continue - for path, leaf in jax.tree_util.tree_leaves_with_path( - ax, is_leaf=is_axis_leaf, - ): - if leaf is Carry: - raise ValueError( - 'Carry must be a top-level argument, it cannot be nested. ' - f'Found Carry at out_axes[{i}]{jax.tree_util.keystr(path)}' - ) - - in_has_carry = in_axes is Carry or ( - isinstance(in_axes, tuple) and Carry in in_axes - ) - out_has_carry = out_axes is Carry or ( - isinstance(out_axes, tuple) and Carry in out_axes - ) - if in_has_carry != out_has_carry: - raise ValueError( - 'If one of in_axes or out_axes has Carry, the other must also ' - f'have Carry. Got {in_axes=}, {out_axes=}' - ) - - - args_flat, args_treedef = jax.tree.flatten(args) - _, in_axes_flat = extract.broadcast_prefix2( - in_axes, args, is_leaf=is_axis_leaf, - ) - - carry_indices: list[int] = [] - broadcast_indices: list[int] = [] - scan_indices: list[int] = [] - scan_in_axes: list[int] = [] - - carry_leaves: list[tp.Any] = [] - broadcast_leaves: list[tp.Any] = [] - scan_leaves: list[tp.Any] = [] - - for i, (leaf, ax) in enumerate(zip(args_flat, in_axes_flat, strict=True)): - if ax is Carry: - carry_indices.append(i) - carry_leaves.append(leaf) - elif ax is None: - broadcast_indices.append(i) - broadcast_leaves.append(leaf) - elif isinstance(ax, int): - scan_indices.append(i) - scan_in_axes.append(ax) - if ax != 0: - leaf = jnp.moveaxis(leaf, ax, 0) - scan_leaves.append(leaf) - else: - raise ValueError(f'Invalid in_axes leaf value: {ax}') - - n_in = len(args_flat) - out_info: list[tuple[ - jax.tree_util.PyTreeDef, list[int], list[int], list[int], - ]] = [] - - in_broadcast = jax.tree.map(lambda x: x, broadcast_leaves) - - def body_fn(carry_state, scan_x): - flat = [None] * n_in - for idx, j in enumerate(carry_indices): - flat[j] = carry_state[idx] - for idx, j in enumerate(broadcast_indices): - flat[j] = in_broadcast[idx] - if scan_x is not None: - for idx, j in enumerate(scan_indices): - flat[j] = scan_x[idx] - - reconstructed = args_treedef.unflatten(flat) - out = f(*reconstructed) - - out_flat, out_treedef = jax.tree.flatten(out) - out_axes_paths, out_axes_flat = extract.broadcast_prefix2( - out_axes, out, is_leaf=is_axis_leaf, - ) - - if not out_info: - out_carry_idx = [] - out_scan_idx = [] - out_scan_axes = [] - out_broadcast_idx = [] - for j, oax in enumerate(out_axes_flat): - if oax is Carry: - out_carry_idx.append(j) - elif oax is None: - out_broadcast_idx.append(j) - elif isinstance(oax, int): - out_scan_idx.append(j) - out_scan_axes.append(oax) - else: - raise ValueError(f'Invalid out_axes leaf value: {oax}') - if out_broadcast_idx: - broadcast_paths = [ - jax.tree_util.keystr(out_axes_paths[j]) for j in out_broadcast_idx - ] - broadcast_str = "\n\n ".join(broadcast_paths) - raise ValueError( - 'Scan does not support broadcast outputs (None axis). The following ' - f'output leaves are broadcast:\n\n {broadcast_str}\n' - ) - out_info.append( - (out_treedef, out_carry_idx, out_scan_idx, out_scan_axes), - ) - - oci = out_info[0][1] - osi = out_info[0][2] - new_carry = [out_flat[j] for j in oci] - new_ys = [out_flat[j] for j in osi] - - return new_carry, new_ys - - final_carry, stacked_ys = jax.lax.scan( - body_fn, - carry_leaves, - scan_leaves if scan_leaves else None, - length=length, - reverse=reverse, - unroll=unroll, - _split_transpose=_split_transpose, - ) - - out_treedef, out_carry_idx, out_scan_idx, out_scan_axes = ( - out_info[0] - ) - n_out = out_treedef.num_leaves - out_flat: list[tp.Any] = [None] * n_out - for idx, j in enumerate(out_carry_idx): - out_flat[j] = final_carry[idx] - for idx, j in enumerate(out_scan_idx): - y = stacked_ys[idx] - ax = out_scan_axes[idx] - if ax != 0: - y = jnp.moveaxis(y, 0, ax) - out_flat[j] = y - - return out_treedef.unflatten(out_flat) - - # ------------------------------- # while_loop # ------------------------------- diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py index 0d79b9f96..30c5ddcad 100644 --- a/tests/nnx/bridge/module_test.py +++ b/tests/nnx/bridge/module_test.py @@ -430,8 +430,8 @@ class MLP(bridge.Module): dim: int num_layers: int def setup(self): - @nnx.split_rngs(splits=self.num_layers) - @nnx.vmap( + @nnx.compat.split_rngs(splits=self.num_layers) + @nnx.compat.vmap( in_axes=(nnx.StateAxes({nnx.RngState: 0, ...: None}),), axis_size=self.num_layers, transform_metadata={nnx.PARTITION_NAME: None}, @@ -442,8 +442,8 @@ def create_block(parent): create_block(self) def __call__(self, x): - @nnx.split_rngs(splits=self.num_layers) - @nnx.scan( + @nnx.compat.split_rngs(splits=self.num_layers) + @nnx.compat.scan( in_axes=(0, nnx.Carry), out_axes=nnx.Carry, transform_metadata={nnx.PARTITION_NAME: None}, @@ -486,7 +486,7 @@ def forward(self, x): return self.aaa(x) def __call__(self, x): - forward = nnx.remat(self.__class__.forward) + forward = nnx.compat.remat(self.__class__.forward) return forward(self, x) model = Top() diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index b01b77127..31fa89f59 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -73,7 +73,7 @@ def test_unflatten(self): a = nnx.Dict(a=1, b=nnx.Param(2)) g = nnx.List([a, 3, a, nnx.Param(4)]) - graphdef, state = nnx.split(g) + graphdef, state = nnx.compat.split(g) g = nnx.merge(graphdef, state) assert g[0] is g[2] @@ -112,7 +112,7 @@ def test_unflatten_pure_dict(self): a = nnx.Dict(a=1, b=nnx.Param(2)) g = nnx.List([a, 3, a, nnx.Param(4)]) - graphdef, state = nnx.split(g) + graphdef, state = nnx.compat.split(g) pure_state = nnx.to_pure_dict(state) g = nnx.merge(graphdef, pure_state) @@ -123,7 +123,7 @@ def test_unflatten_pytree(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - graphdef, state = nnx.split(g) + graphdef, state = nnx.compat.split(g) g = nnx.merge(graphdef, state) assert g[0] is not g[2] @@ -132,7 +132,7 @@ def test_unflatten_empty(self): a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) g = nnx.List([a, 3, a, nnx.Param(4)]) - graphdef, state = nnx.split(g) + graphdef, state = nnx.compat.split(g) with self.assertRaisesRegex(ValueError, 'Incorrect number of leaves'): nnx.graphlib.unflatten(graphdef, nnx.State({})) @@ -154,7 +154,7 @@ def test_update_dynamic(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} g = [a, 3, a, nnx.Param(jnp.array(4))] - graphdef, state = nnx.split(g) + graphdef, state = nnx.compat.split(g) state[0]['b'][...] = 3 nnx.update(g, state) @@ -166,7 +166,7 @@ def test_update_from_pure_dict(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} g = [a, 3, a, nnx.Param(jnp.array(4))] - graphdef, state = nnx.split(g) + graphdef, state = nnx.compat.split(g) pure_state = nnx.to_pure_dict(state) pure_state[0]['b'] = jnp.array(3) @@ -196,7 +196,7 @@ def test_shared_variables(self): v = nnx.Param(1) g = [v, v] - graphdef, state = nnx.split(g) + graphdef, state = nnx.compat.split(g) assert len(nnx.to_flat_state(state)) == 1 @@ -214,7 +214,7 @@ def __init__(self, *, rngs: nnx.Rngs) -> None: self.baz.kernel = self.bar.kernel node = Foo(rngs=nnx.Rngs(0)) - graphdef, state = nnx.split(node) + graphdef, state = nnx.compat.split(node) assert len(nnx.to_flat_state(state)) == 3 # 2 bias + 1 kernel @@ -247,7 +247,7 @@ def __call__(self, x): return self.linear_out(x) model = Encoder(rngs=nnx.Rngs(0)) - graphdef, state = nnx.split(model) + graphdef, state = nnx.compat.split(model) assert len(nnx.to_flat_state(state)) == 1 @@ -284,7 +284,7 @@ def __init__(self): self.b = p m = Foo() - graphdef, state = nnx.split(m) + graphdef, state = nnx.compat.split(m) assert isinstance(m.a, nnx.Param) assert isinstance(m.b, nnx.Param) @@ -541,8 +541,8 @@ def __init__(self, dout: int, rngs: nnx.Rngs): self.rngs = rngs def __call__(self, x): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(0, None), axis_size=5) + @nnx.compat.split_rngs(splits=5) + @nnx.compat.vmap(in_axes=(0, None), axis_size=5) def vmap_fn(inner, x): return inner(x) @@ -835,7 +835,7 @@ def f(*pure_args): y = 0 self.assertIs(args[0], args[2]['b']) - for path, m in nnx.iter_graph(args): + for path, m in nnx.compat.iter_graph(args): if isinstance(m, Foo): self.assertEqual(m.a.shape, ()) self.assertEqual(m.b.shape, ()) @@ -958,7 +958,7 @@ def test_jit_pytree_of_variables(self): v2 = nnx.Param(jnp.array(2)) vs = [v1, v1, v2] - @nnx.jit + @nnx.compat.jit def f(vs): self.assertIs(vs[0], vs[1]) self.assertIsNot(vs[0], vs[2]) @@ -979,7 +979,7 @@ def __init__(self, var): var = nnx.Param(1) foo = Foo(var) - @nnx.jit + @nnx.compat.jit def increment_var(var, foo): self.assertIs(var, foo.var) var[...] += 1 @@ -1041,7 +1041,7 @@ def __init__(self): m = Foo() - @nnx.jit + @nnx.compat.jit def f(m): m.a += 1 self.assertEqual(m.b, 'yes') @@ -1124,7 +1124,7 @@ def test_iter_graph(self): root.f = var0 root.g = arr1 - nodes = [node for _, node in nnx.iter_graph(root)] + nodes = [node for _, node in nnx.compat.iter_graph(root)] count = lambda e: sum(node is e for node in nodes) # All internal nodes must be visited exactly once. @@ -1150,7 +1150,7 @@ def test_cached_partial_docstring_example(self): model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param) - @nnx.jit + @nnx.compat.jit def train_step(model, optimizer, x, y): def loss_fn(model): return jnp.mean((model(x) - y) ** 2) @@ -1159,7 +1159,7 @@ def loss_fn(model): optimizer.update(model, grads) return loss - cached_train_step = nnx.cached_partial(train_step, model, optimizer) + cached_train_step = nnx.compat.cached_partial(train_step, model, optimizer) for step in range(2): x, y = jnp.ones((10, 2)), jnp.ones((10, 3)) @@ -1196,7 +1196,7 @@ def inc_d(path, node): node.d += 1 return node - bar2 = nnx.recursive_map(inc_d, bar) + bar2 = nnx.compat.recursive_map(inc_d, bar) self.assertIs(bar2[0], bar2[2]) self.assertEqual(bar2[0].d, 11) self.assertEqual(bar2[1].d, 21) @@ -1219,7 +1219,7 @@ def swap(path, node): node = Foo(-node.d) return node - bar2 = nnx.recursive_map(swap, bar) + bar2 = nnx.compat.recursive_map(swap, bar) self.assertIs(bar2[0], bar2[2]) self.assertEqual(bar2[0].d, -10) self.assertEqual(bar2[1].d, -20) diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 6f16842cb..e82d37144 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -65,7 +65,7 @@ def loss_fn(model): y_pred = model(x) return jnp.mean((y_pred - y) ** 2) - loss, grads = nnx.value_and_grad(loss_fn)(model) + loss, grads = nnx.compat.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # in-place updates return loss @@ -105,9 +105,9 @@ def __call__(self, x): x = self.block2(x) return x - @nnx.jit + @nnx.compat.jit def train_step(model: Model, x, y): - @nnx.grad + @nnx.compat.grad def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) @@ -116,7 +116,7 @@ def loss_fn(model: Model): nnx.update( model, jax.tree.map( - lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads + lambda w, g: w - 0.1 * g, nnx.compat.state(model, nnx.Param), grads ), ) @@ -155,9 +155,9 @@ def __call__(self, x): x = self.block2(x) return x - @nnx.jit + @nnx.compat.jit def train_step(model: Model, x, y): - @nnx.grad + @nnx.compat.grad def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) @@ -166,7 +166,7 @@ def loss_fn(model: Model): nnx.update( model, jax.tree.map( - lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads + lambda w, g: w - 0.1 * g, nnx.compat.state(model, nnx.Param), grads ), ) @@ -174,7 +174,7 @@ def loss_fn(model: Model): x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) - new_model = nnx.view(model, use_running_average=False) + new_model = nnx.compat.view(model, use_running_average=False) for _i in range(3): train_step(model, x, y) @@ -210,7 +210,7 @@ def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): model = nnx.merge(graphdef, state) model.set_attributes(use_running_average=False, graph=True) - @nnx.grad + @nnx.compat.grad def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) @@ -219,14 +219,14 @@ def loss_fn(model: Model): nnx.update( model, jax.tree.map( - lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads + lambda w, g: w - 0.1 * g, nnx.compat.state(model, nnx.Param), grads ), ) - return nnx.split(model) + return nnx.compat.split(model) graphdef: nnx.GraphDef[Model] - graphdef, state = nnx.split(Model(rngs=nnx.Rngs(0))) + graphdef, state = nnx.compat.split(Model(rngs=nnx.Rngs(0))) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) @@ -267,9 +267,9 @@ def __call__(self, x): @jax.jit def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): model = nnx.merge(graphdef, state) - new_model = nnx.view(model, use_running_average=False, graph=True) + new_model = nnx.compat.view(model, use_running_average=False) - @nnx.grad + @nnx.compat.grad def loss_fn(model: Model): y_pred = model(x) return jnp.mean((y - y_pred) ** 2) @@ -278,14 +278,14 @@ def loss_fn(model: Model): nnx.update( new_model, jax.tree.map( - lambda w, g: w - 0.1 * g, nnx.state(new_model, nnx.Param), grads + lambda w, g: w - 0.1 * g, nnx.compat.state(new_model, nnx.Param), grads ), ) - return nnx.split(new_model) + return nnx.compat.split(new_model) graphdef: nnx.GraphDef[Model] - graphdef, state = nnx.split(Model(rngs=nnx.Rngs(0))) + graphdef, state = nnx.compat.split(Model(rngs=nnx.Rngs(0))) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) @@ -301,8 +301,7 @@ def loss_fn(model: Model): assert model.block1.linear.bias is model.block2.linear.bias assert model.block1.bn is not model.block2.bn - @parameterized.parameters(True, False) - def test_stateful_example(self, graph_mode): + def test_stateful_example(self): class State(nnx.Variable[A]): pass @@ -323,20 +322,64 @@ def __call__(self, x): y = model(x) assert model.count[...] == 1 - @nnx.jit(graph=graph_mode) + @nnx.jit(graph=True, graph_updates=True) def train_step(model, x, y): def loss_fn(model): y_pred = model(x) return jax.numpy.mean((y_pred - y) ** 2) # compute gradient - grads: nnx.State = nnx.grad(loss_fn)(model) + grad_fn = nnx.grad(loss_fn, graph=True, graph_updates=True) + grads: nnx.State = grad_fn(model) # SGD update nnx.update( - model, - jax.tree.map( - lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads - ), + model, + jax.tree.map( + lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads + ), + ) + + # execute the training step + train_step(model, x, y) + assert model.count[...] == 2 + + @parameterized.parameters(True, False) + def test_stateful_example_functional(self, graph): + class State(nnx.Variable[A]): + pass + + class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = State(jnp.array(0)) + + def __call__(self, x): + self.count[...] += 1 + return x @ self.w + self.b[None] + + model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) + # forward pass + x = jnp.ones((8, 12)) + y = model(x) + assert model.count[...] == 1 + + @nnx.jit(graph=graph, graph_updates=False) + def train_step(model, x, y): + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + + def loss_fn(params, nondiff): + model = nnx.merge(graphdef, params, nondiff) + return ((model(x) - y) ** 2).mean() + + # compute gradient + grad_fn = nnx.grad(loss_fn, graph=graph, graph_updates=False) + grads = grad_fn(params, nondiff) + # SGD update + nnx.update( + model, + jax.tree.map(lambda w, g: w - 0.1 * g, params, grads), ) # execute the training step diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 2706fa53a..be350e8a9 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -451,7 +451,7 @@ def test_shared_module(self): m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) - m3 = nnx.merge(*nnx.split(m2)) + m3 = nnx.merge(*nnx.split(m2, graph=True)) assert m3['x'] is m3['y'] assert m3['x']['a'] is m3['y']['a'] @@ -465,7 +465,7 @@ def __init__(self): m = Foo() - graphdef, state = nnx.split(m) + graphdef, state = nnx.split(m, graph=True) assert len(state) == 1 m2 = nnx.merge(graphdef, state) @@ -484,9 +484,9 @@ def f(graphdef: nnx.GraphDef[nnx.Dict], state: nnx.State): assert m['a'][0] is m['b'] assert m['a'][1] is not m['b'] - return nnx.split(m) + return nnx.split(m, graph=True) - graphdef, state = f(*nnx.split(m)) + graphdef, state = f(*nnx.split(m, graph=True)) m = nnx.merge(graphdef, state) assert m['a'][0] is m['b'] @@ -553,7 +553,7 @@ def test_deref_number_of_fields(self): } ) - graphdef, p = nnx.split(m) + graphdef, p = nnx.split(m, graph=True) assert len(nnx.to_flat_state(p)) == 2 assert len(jax.tree_util.tree_leaves(p)) == 2 @@ -615,12 +615,12 @@ def test_cached_partial(self): model = SowMod(nnx.Rngs(42)) x = jnp.ones((2, 4)) - @nnx.jit + @nnx.compat.jit def train_step(model, x): out, intermediates = nnx.capture(model, nnx.Intermediate)(x) return out, intermediates - train_step_fn = nnx.cached_partial(train_step, model) + train_step_fn = nnx.compat.cached_partial(train_step, model) train_step_fn(x) def test_update_static_state_submodules(self): @@ -1247,7 +1247,7 @@ def __call__(self, x, *, rngs: nnx.Rngs): foo = Foo(c=1.0, rngs=nnx.Rngs(0)) - graphdef, state = nnx.split(foo) + graphdef, state = nnx.split(foo, graph=True) assert isinstance(graphdef.nodes[0], nnx.graphlib.NodeDef | nnx.graphlib.NodeRef) assert isinstance(state, nnx.State) diff --git a/tests/nnx/mutable_array_test.py b/tests/nnx/mutable_array_test.py index 570d6910b..5875d31d3 100644 --- a/tests/nnx/mutable_array_test.py +++ b/tests/nnx/mutable_array_test.py @@ -447,7 +447,7 @@ def test_simple_jit(self): m1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) m_out1 = None - @nnx.jit + @nnx.compat.jit def f(m2): nonlocal m_out1 m_out1 = nnx.vars_as(nnx.Linear(1, 1, rngs=nnx.Rngs(0)), ref=True) @@ -466,7 +466,7 @@ class Foo(nnx.Pytree): m1 = Foo(a=jax.new_ref(1)) - @nnx.jit + @nnx.compat.jit def f(m2: Foo): m2.a[...] += 1 return m2 diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py index 93fc0f121..318e4cf8f 100644 --- a/tests/nnx/optimizer_test.py +++ b/tests/nnx/optimizer_test.py @@ -99,7 +99,7 @@ def test_sharding_propagation(self): @parameterized.product( module_cls=[nnx.Linear, Model], - jit_decorator=[lambda f: f, nnx.jit, jax.jit], + jit_decorator=[lambda f: f, nnx.compat.jit, jax.jit], optimizer=[optax.sgd, optax.adam], ) def test_jit(self, module_cls, jit_decorator, optimizer): @@ -146,7 +146,7 @@ def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): @parameterized.product( module_cls=[nnx.Linear, Model], - jit_decorator=[lambda f: f, nnx.jit, jax.jit], + jit_decorator=[lambda f: f, nnx.compat.jit, jax.jit], optimizer=[optax.lbfgs], ) def test_jit_linesearch(self, module_cls, jit_decorator, optimizer): @@ -190,7 +190,7 @@ def jax_jit_train_step(graphdef, state, x, y): initial_loss = loss_fn(state.model, x, y) def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): - grads = nnx.grad(loss_fn)(optimizer.model, x, y) + grads = nnx.compat.grad(loss_fn)(optimizer.model, x, y) optimizer.update( grads, grad=grads, value=initial_loss, value_fn=loss_fn_split ) @@ -247,12 +247,19 @@ def test_wrt_update(self, variable): rngs=nnx.Rngs(1), ) state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable) - prev_variables, prev_other_variables = nnx.clone(nnx.state(model, variable, ...)) + prev_variables, prev_other_variables = nnx.clone( + nnx.state(model, variable, ...), graph=True + ) x = jnp.ones((1, 4)) y = jnp.ones((1, 10)) loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() - grad_fn = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable)) + grad_fn = nnx.grad( + loss_fn, + argnums=nnx.DiffState(0, variable), + graph=True, + graph_updates=True, + ) def step(): grads = grad_fn(model, x, y) @@ -278,7 +285,7 @@ def step(): @parameterized.parameters( {'variable': nnx.Param}, - # {'variable': nnx.LoRAParam}, + {'variable': nnx.LoRAParam}, {'variable': (nnx.Param, nnx.LoRAParam)}, ) def test_wrt_update_linesearch(self, variable): @@ -294,15 +301,22 @@ def test_wrt_update_linesearch(self, variable): rngs=nnx.Rngs(1), ) state = nnx.Optimizer(model, optax.lbfgs(), wrt=variable) - prev_variables, prev_other_variables = nnx.clone(nnx.state(model, variable, ...)) + prev_variables, prev_other_variables = nnx.clone( + nnx.state(model, variable, ...), graph=True + ) x = jnp.ones((1, 4)) y = jnp.ones((1, 10)) loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() - grad_fn = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable)) - graphdef = nnx.graphdef(model) - loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y) + grad_fn = nnx.grad( + loss_fn, + argnums=nnx.DiffState(0, variable), + graph=True, + graph_updates=True, + ) + graphdef, _, other_variables = nnx.compat.split(model, variable, ...) + loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state, other_variables), x, y) def step(): grads = grad_fn(model, x, y) @@ -328,105 +342,99 @@ def step(): assert_equal, prev_other_variables, other_variables ) - def test_update_returns_updates(self): - """Test that Optimizer.update returns the updates PyTree.""" - model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) - optimizer = nnx.Optimizer(model, optax.sgd(0.1), wrt=nnx.Param) - - def loss_fn(model): - params = nnx.state(model) - loss = sum(jnp.sum(x**2) for x in jax.tree.leaves(params)) - return loss - - grads = nnx.grad(loss_fn)(model) - - # Call update and capture return value - updates = optimizer.update(model, grads) - - # Verify updates is not None - self.assertIsNotNone(updates) - - # Verify updates structure matches params structure - params = nnx.pure(nnx.state(model, nnx.Param)) - - def check_structure(path, update_val, param_val): - self.assertEqual(update_val.shape, param_val.shape) - self.assertEqual(update_val.dtype, param_val.dtype) - - jax.tree.map_with_path(check_structure, updates, params) - - def test_updates_match_param_changes(self): - """Test that returned updates equal the actual parameter changes.""" - model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) - # Use SGD with lr=1.0 for simplicity (updates = -grads) - optimizer = nnx.Optimizer(model, optax.sgd(1.0), wrt=nnx.Param) - - # Get initial params as pure arrays - initial_params = nnx.pure(nnx.state(model, nnx.Param)) - - def loss_fn(model): - params = nnx.state(model) - loss = sum(jnp.sum(x**2) for x in jax.tree.leaves(params)) - return loss - - grads = nnx.grad(loss_fn)(model) - - # Get updates - updates = optimizer.update(model, grads) - - # Get new params as pure arrays - new_params = nnx.pure(nnx.state(model, nnx.Param)) - - # Verify: new_params = initial_params + updates (within optax.apply_updates) - def check_update(old, update, new): - expected_new = old + update - np.testing.assert_allclose(new, expected_new, rtol=1e-5) + @parameterized.parameters( + {'variable': nnx.Param}, + # {'variable': nnx.LoRAParam}, + {'variable': (nnx.Param, nnx.LoRAParam)}, + ) + def test_wrt_update_linesearch_functional(self, variable): + in_features = 4 + out_features = 10 + model = nnx.LoRA( + in_features=in_features, + lora_rank=2, + out_features=out_features, + base_module=Model( + in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0) + ), + rngs=nnx.Rngs(1), + ) + optimizer = nnx.Optimizer(model, optax.lbfgs(), wrt=variable) + graphdef, params, nondiff = nnx.split(model, variable, ..., graph=True) + prev_params, prev_nondiff = nnx.clone( + (params, nondiff), graph=True + ) - jax.tree.map(check_update, initial_params, updates, new_params) + x = jnp.ones((1, 4)) + y = jnp.ones((1, 10)) + def loss_fn(params, nondiff): + model = nnx.merge(graphdef, params, nondiff) + return ((model(x) - y) ** 2).mean() - def test_model_and_optimizer_returns_updates(self): - """Test that ModelAndOptimizer.update also returns updates.""" - model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) - state = nnx.ModelAndOptimizer(model, optax.sgd(0.1)) + grad_fn = nnx.grad(loss_fn, graph=True, graph_updates=True) + loss_fn_split = lambda params: loss_fn(params, nondiff) - def loss_fn(model): - params = nnx.state(model) - loss = sum(jnp.sum(x**2) for x in jax.tree.leaves(params)) - return loss + def step(): + grads = grad_fn(params, nondiff) + initial_loss = loss_fn(params, nondiff) + optimizer.update( + model, grads, grad=grads, value_fn=loss_fn_split, value=initial_loss + ) + self.assertTrue(loss_fn(params, nondiff) < initial_loss) - grads = nnx.grad(loss_fn)(model) + # Since lora_b is initialized to zeros by default, the gradient flow to lora_a + # will be zeroed out in first call. Thus, run the step twice to make sure + # lora_a is updated. + for _ in range(2): + step() - # Call update and capture return value - updates = state.update(grads) + jax.tree.map_with_path(assert_not_equal, prev_params, params) + jax.tree.map_with_path(assert_equal, prev_nondiff, nondiff) - # Verify updates is not None - self.assertIsNotNone(updates) - # Verify updates has the expected structure - params = nnx.pure(nnx.state(model, nnx.Param)) - def check_structure(path, update_val, param_val): - self.assertEqual(update_val.shape, param_val.shape) + @parameterized.parameters( + {'variable': nnx.Param}, + {'variable': nnx.LoRAParam}, + {'variable': (nnx.Param, nnx.LoRAParam)}, + ) + def test_wrt_update_functional(self, variable): + in_features = 4 + out_features = 10 + model = nnx.LoRA( + in_features=in_features, + lora_rank=2, + out_features=out_features, + base_module=Model( + in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0) + ), + rngs=nnx.Rngs(1), + ) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable) + graphdef, params, nondiff = nnx.split(model, variable, ..., graph=True) + prev_params, prev_nondiff = nnx.clone((params, nondiff), graph=True) - jax.tree.map_with_path(check_structure, updates, params) + x = jnp.ones((1, 4)) + y = jnp.ones((1, 10)) - def test_update_backward_compatible(self): - """Test that existing code ignoring return value still works.""" - model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) - optimizer = nnx.Optimizer(model, optax.adam(0.1), wrt=nnx.Param) + def loss_fn(params, nondiff): + model = nnx.merge(graphdef, params, nondiff) + return ((model(x) - y) ** 2).mean() - def loss_fn(model): - params = nnx.state(model) - loss = sum(jnp.sum(x**2) for x in jax.tree.leaves(params)) - return loss + grad_fn = nnx.grad(loss_fn, graph=True, graph_updates=True) - grads = nnx.grad(loss_fn)(model) + def step(): + grads = grad_fn(params, nondiff) + initial_loss = loss_fn(params, nondiff) + optimizer.update(model, grads) + self.assertTrue(loss_fn(params, nondiff) < initial_loss) - # Existing code pattern - ignore return value - optimizer.update(model, grads) # Should not error + for _ in range(2): + step() - # Verify update still worked - self.assertEqual(optimizer.step[...], 1) + jax.tree.map_with_path(assert_not_equal, prev_params, params) + jax.tree.map_with_path(assert_equal, prev_nondiff, nondiff) if __name__ == '__main__': absltest.main() + diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index ccde8544e..a7edc2bc2 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -193,11 +193,16 @@ def __call__(self, x): np.testing.assert_allclose(y1, y2) - @parameterized.parameters(True, False) - def test_split_rngs(self, graph): + @parameterized.product( + graph=[True, False], + graph_updates=[True, False], + ) + def test_split_rngs(self, graph: bool, graph_updates: bool): rngs = nnx.Rngs(params=0, dropout=1) - result = nnx.split_rngs(rngs, splits=5, graph=graph) - if graph: + result = nnx.split_rngs( + rngs, splits=5, graph=graph, graph_updates=graph_updates + ) + if graph and graph_updates: self.assertEqual(rngs.params.key.shape, (5,)) self.assertEqual(rngs['dropout'].key.shape, (5,)) nnx.restore_rngs(result) @@ -209,13 +214,43 @@ def test_split_rngs(self, graph): self.assertEqual(result.params.key.shape, (5,)) self.assertEqual(result['dropout'].key.shape, (5,)) - @parameterized.parameters(True, False) - def test_fork_rngs(self, graph): + @parameterized.product( + graph=[True, False], + graph_updates=[True, False], + ) + def test_fork_rngs(self, graph: bool, graph_updates: bool): rngs = nnx.Rngs(params=0, dropout=1) - backups = nnx.fork_rngs(rngs, graph=graph) - new_key = rngs.params.key.copy() - nnx.restore_rngs(backups) - self.assertNotEqual(rngs.params.key, new_key) + original_key = rngs.params.key[...].copy() + + result = nnx.fork_rngs(rngs, graph=graph, graph_updates=graph_updates) + + # jax.random.key_data is needed because np.allclose raises a TypeError + # when implicitly converting PRNGKey dtypes to NumPy arrays. + def _get_data(x): + arr = ( + x[...] + if hasattr(x, '__getitem__') and not isinstance(x, jax.Array) + else x + ) + return jax.random.key_data(arr) + + if graph and graph_updates: + self.assertFalse( + np.allclose(_get_data(rngs.params.key[...]), _get_data(original_key)) + ) + nnx.restore_rngs(result) + np.testing.assert_allclose( + _get_data(rngs.params.key[...]), _get_data(original_key) + ) + else: + np.testing.assert_allclose( + _get_data(rngs.params.key[...]), _get_data(original_key) + ) + self.assertFalse( + np.allclose( + _get_data(result.params.key[...]), _get_data(original_key) + ) + ) def test_random_helpers(self): rngs = nnx.Rngs(0, params=1) @@ -230,98 +265,130 @@ def test_random_helpers(self): x_nnx = rngs.lecun_normal()((2, 3)) x_jax = jax.nn.initializers.lecun_normal()( - jax.random.fold_in(jax.random.key(0), 1), (2, 3) + jax.random.fold_in(jax.random.key(0), 1), (2, 3) ) np.testing.assert_allclose(x_nnx, x_jax) x_nnx = rngs.params.lecun_uniform()((2, 3)) x_jax = jax.nn.initializers.lecun_uniform()( - jax.random.fold_in(jax.random.key(1), 1), (2, 3) + jax.random.fold_in(jax.random.key(1), 1), (2, 3) ) np.testing.assert_allclose(x_nnx, x_jax) - def test_split_int_splits_all_streams(self): + @parameterized.parameters(True, False) + def test_split_int_splits_all_streams(self, graph): rngs = nnx.Rngs(params=0, dropout=1) - new_rngs = nnx.with_rngs(rngs, split=4) + new_rngs = nnx.with_rngs(rngs, split=4, graph=graph, graph_updates=False) self.assertEqual(new_rngs.params.key.shape, (4,)) self.assertEqual(new_rngs['dropout'].key.shape, (4,)) - def test_split_tuple_splits_all_streams(self): + @parameterized.parameters(True, False) + def test_split_tuple_splits_all_streams(self, graph): rngs = nnx.Rngs(params=0, dropout=1) - new_rngs = nnx.with_rngs(rngs, split=(2, 3)) + new_rngs = nnx.with_rngs( + rngs, split=(2, 3), graph=graph, graph_updates=False + ) self.assertEqual(new_rngs.params.key.shape, (2, 3)) self.assertEqual(new_rngs['dropout'].key.shape, (2, 3)) - def test_fork_forks_all_streams(self): + @parameterized.parameters(True, False) + def test_fork_forks_all_streams(self, graph): rngs = nnx.Rngs(params=0, dropout=1) original_params_key = rngs.params.key[...] original_dropout_key = rngs['dropout'].key[...] - new_rngs = nnx.with_rngs(rngs, fork=...) + new_rngs = nnx.with_rngs(rngs, fork=..., graph=graph, graph_updates=False) # Forked keys are scalar and differ from originals self.assertEqual(new_rngs.params.key.shape, ()) self.assertEqual(new_rngs['dropout'].key.shape, ()) - self.assertFalse(jnp.array_equal(new_rngs.params.key[...], original_params_key)) - self.assertFalse(jnp.array_equal(new_rngs['dropout'].key[...], original_dropout_key)) + self.assertFalse( + jnp.array_equal(new_rngs.params.key[...], original_params_key) + ) + self.assertFalse( + jnp.array_equal(new_rngs['dropout'].key[...], original_dropout_key) + ) - def test_split_mapping_applies_per_filter(self): + @parameterized.parameters(True, False) + def test_split_mapping_applies_per_filter(self, graph): rngs = nnx.Rngs(params=0, dropout=1, noise=2) - new_rngs = nnx.with_rngs(rngs, split={'params': 4, ...: (2, 3)}) + new_rngs = nnx.with_rngs( + rngs, split={'params': 4, ...: (2, 3)}, graph=graph, graph_updates=False + ) self.assertEqual(new_rngs.params.key.shape, (4,)) self.assertEqual(new_rngs['dropout'].key.shape, (2, 3)) self.assertEqual(new_rngs.noise.key.shape, (2, 3)) - def test_split_mapping_first_matching_filter_wins(self): + @parameterized.parameters(True, False) + def test_split_mapping_first_matching_filter_wins(self, graph): rngs = nnx.Rngs(params=0, dropout=1) # 'params' filter comes before '...' so it should match first - new_rngs = nnx.with_rngs(rngs, split={'params': 4, ...: 8}) + new_rngs = nnx.with_rngs( + rngs, split={'params': 4, ...: 8}, graph=graph, graph_updates=False + ) self.assertEqual(new_rngs.params.key.shape, (4,)) self.assertEqual(new_rngs['dropout'].key.shape, (8,)) - def test_split_some_fork_rest(self): + @parameterized.parameters(True, False) + def test_split_some_fork_rest(self, graph): rngs = nnx.Rngs(params=0, dropout=1) - new_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...) + new_rngs = nnx.with_rngs( + rngs, split={'params': 4}, fork=..., graph=graph, graph_updates=False + ) self.assertEqual(new_rngs.params.key.shape, (4,)) # dropout not matched by split → forked (scalar) self.assertEqual(new_rngs['dropout'].key.shape, ()) - def test_original_base_key_not_replaced(self): + @parameterized.parameters(True, False) + def test_original_base_key_not_replaced(self, graph): # nnx.with_rngs advances the original stream's counter (consuming one step to # derive the new keys) but does not replace the original's base key. rngs = nnx.Rngs(params=0, dropout=1) original_key_var = rngs.params.key - nnx.with_rngs(rngs, split=4) + nnx.with_rngs(rngs, split=4, graph=graph, graph_updates=False) self.assertIs(rngs.params.key, original_key_var) self.assertEqual(rngs.params.key.shape, ()) - def test_unmatched_streams_returned_unchanged(self): + @parameterized.parameters(True, False) + def test_unmatched_streams_returned_unchanged(self, graph): rngs = nnx.Rngs(params=0, dropout=1) # Only fork 'params'; 'dropout' matches neither split nor fork - new_rngs = nnx.with_rngs(rngs, fork='params') + new_rngs = nnx.with_rngs( + rngs, fork='params', graph=graph, graph_updates=False + ) self.assertIsNot(new_rngs['dropout'], rngs['dropout']) # new tree, but... - self.assertTrue(jnp.array_equal(new_rngs['dropout'].key[...], rngs['dropout'].key[...])) + self.assertTrue( + jnp.array_equal(new_rngs['dropout'].key[...], rngs['dropout'].key[...]) + ) self.assertEqual(new_rngs['dropout'].key.shape, ()) - def test_split_and_fork_same_stream_raises(self): + @parameterized.parameters(True, False) + def test_split_and_fork_same_stream_raises(self, graph): rngs = nnx.Rngs(params=0, dropout=1) - with self.assertRaisesRegex(ValueError, re.compile(r"multiple rules")): - nnx.with_rngs(rngs, split={'params': 4}, fork='params') + with self.assertRaisesRegex(ValueError, re.compile(r'multiple rules')): + nnx.with_rngs( + rngs, + split={'params': 4}, + fork='params', + graph=graph, + graph_updates=False, + ) - def test_works_on_plain_pytree(self): + @parameterized.parameters(True, False) + def test_works_on_plain_pytree(self, graph): params_stream = nnx.RngStream(0, tag='params') dropout_stream = nnx.RngStream(1, tag='dropout') tree = {'a': params_stream, 'b': dropout_stream} - new_tree = nnx.with_rngs(tree, split=4) + new_tree = nnx.with_rngs(tree, split=4, graph=graph, graph_updates=False) self.assertEqual(new_tree['a'].key.shape, (4,)) self.assertEqual(new_tree['b'].key.shape, (4,)) @@ -354,6 +421,36 @@ def test_rngs_broadcast(self): self.assertEqual(broadcasted_rngs_mapped.params.key.shape, (5,)) self.assertEqual(broadcasted_rngs_mapped.dropout.key.shape, ()) + @parameterized.parameters(True, False) + def test_with_rngs_decorator(self, graph): + rngs = nnx.Rngs(params=0, dropout=1) + + @nnx.with_rngs(split=4, graph=graph, graph_updates=False) + def f(r): + return r.params.key.shape, r['dropout'].key.shape + + params_shape, dropout_shape = f(rngs) + self.assertEqual(params_shape, (4,)) + self.assertEqual(dropout_shape, (4,)) + + @parameterized.parameters(True, False) + def test_with_rngs_broadcast(self, graph): + rngs = nnx.Rngs(params=0, dropout=1) + + # Test int broadcast + new_rngs = nnx.with_rngs( + rngs, broadcast=5, graph=graph, graph_updates=False + ) + self.assertEqual(new_rngs.params.key.shape, (5,)) + self.assertEqual(new_rngs.dropout.key.shape, (5,)) + + # Test mapping broadcast + new_rngs_mapped = nnx.with_rngs( + rngs, broadcast={'params': 5}, graph=graph, graph_updates=False + ) + self.assertEqual(new_rngs_mapped.params.key.shape, (5,)) + self.assertEqual(new_rngs_mapped['dropout'].key.shape, ()) + if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 88d091a5c..27474cc1a 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -152,10 +152,11 @@ def test_add_remove_axis_in_transform(self): kadds, kremoves, badds, bremoves = [], [], [], [] class MLP(nnx.Module): - @nnx.split_rngs(splits=5) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) @nnx.vmap( in_axes=(0, 0), transform_metadata={nnx.PARTITION_NAME: 'layers', 'nickname': 'nick'}, + graph=True, graph_updates=True, ) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear( @@ -178,7 +179,8 @@ def __init__(self, rngs: nnx.Rngs): @nnx.scan( in_axes=(0, nnx.Carry), - transform_metadata={nnx.PARTITION_NAME: 'layers'} + transform_metadata={nnx.PARTITION_NAME: 'layers'}, + graph=True, graph_updates=True, ) def __call__(self, x: jax.Array): x = self.linear(x) @@ -396,7 +398,7 @@ def __init__(self, rngs): ('batch', 'model'), axis_types=(jax.sharding.AxisType.Auto,) * len(('batch', 'model')), ) - gdef, abs_state = nnx.get_abstract_model(lambda: Foo(nnx.Rngs(0)), mesh) + gdef, abs_state = nnx.compat.get_abstract_model(lambda: Foo(nnx.Rngs(0)), mesh) assert len(jax.tree.leaves(abs_state)) == 1 assert jax.tree.leaves(abs_state)[0].sharding.is_equivalent_to( NamedSharding(mesh, P(None, 'model')), ndim=2) @@ -453,7 +455,7 @@ def __init__(self): self.p1 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("a", "b"), "mesh": mesh1}) self.p2 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("c", "d"), "mesh": mesh2}) - abs_model = nnx.eval_shape(lambda: Model()) + abs_model = nnx.abstract_with_sharding(nnx.eval_shape(lambda: Model())) assert isinstance(abs_model.p1.kernel.sharding, jax.sharding.NamedSharding) assert abs_model.p1.kernel.sharding.mesh.axis_names == mesh1.axis_names assert abs_model.p1.kernel.sharding.spec == jax.P("a", "b") @@ -468,7 +470,7 @@ def __init__(self): mesh = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)) with jax.set_mesh(mesh): - abs_model = nnx.eval_shape(lambda: Model()) + abs_model = nnx.abstract_with_sharding(nnx.eval_shape(lambda: Model())) assert isinstance(abs_model.linear.kernel.sharding, jax.sharding.NamedSharding) assert abs_model.linear.kernel.sharding.mesh.axis_names == mesh.axis_names assert abs_model.linear.kernel.sharding.spec == jax.P("a", "b") diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index c0f020956..c05e64b94 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -22,7 +22,6 @@ from absl.testing import absltest from absl.testing import parameterized from flax import nnx -from flax.nnx.transforms.iteration import pure_jax_fancy_scan from flax.nnx.transforms import general import jax from jax.experimental import checkify, mesh_utils @@ -33,23 +32,37 @@ class TestJIT(parameterized.TestCase): - def test_jit(self): + + def test_jit_graph_updates(self): m = nnx.Dict(a=nnx.Param(1)) - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def g(m: nnx.Dict): m.a = 2 return 1.0 out = g(m) - assert m.a == 2 - assert out == 1.0 + self.assertEqual(m.a, 2) + self.assertEqual(out, 1.0) + + def test_jit_graph_updates_functional(self): + m = nnx.Dict(a=nnx.Param(1)) + + @nnx.jit(graph=True, graph_updates=False) + def g(m: nnx.Dict): + m.a.set_value(2) + return 1.0 + + out = g(m) + + self.assertEqual(m.a, 2) + self.assertEqual(out, 1.0) def test_mutable_array_input_output(self): m = jax.new_ref(jnp.array(1.0)) - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def f(m: jax.Ref): m[...] += 1.0 m2 = jax.new_ref(jnp.array(10.0)) @@ -89,7 +102,7 @@ def test_jit_on_init(self): n = 0 class Foo(nnx.Module): - @nnx.jit(static_argnums=(1, 2)) + @nnx.jit(static_argnums=(1, 2), graph=True, graph_updates=True) def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): nonlocal n n += 1 @@ -111,6 +124,34 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): m = Foo(2, 3, rngs=nnx.Rngs(0)) assert n == 1 + def test_jit_on_init_functional(self): + n = 0 + + class Foo(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.normal(key, shape=(din, dout))) + self.din = din + self.dout = dout + + @nnx.jit(static_argnums=(0, 1), graph=True, graph_updates=False) + def create_foo(din: int, dout: int, *, rngs: nnx.Rngs): + nonlocal n + n += 1 + return Foo(din, dout, rngs=rngs) + + m = create_foo(2, 3, rngs=nnx.Rngs(0)) + assert n == 1 + assert m.w.shape == (2, 3) + assert m.din == 2 + assert m.dout == 3 + assert isinstance(m.din, int) + assert isinstance(m.dout, int) + assert isinstance(m.w[...], jax.Array) + + m = create_foo(2, 3, rngs=nnx.Rngs(0)) + assert n == 1 + @parameterized.parameters( (True, True), (True, False), (False, False), ) @@ -144,7 +185,7 @@ def __call__(self, x: jax.Array) -> jax.Array: y = m(jnp.ones((1, 2))) assert n == 1 - def test_cached_unflatten(self): + def test_graph_updates_unflatten(self): n = 0 class Foo(nnx.Module): @@ -152,7 +193,7 @@ def __init__(self, *, rngs: nnx.Rngs): self.a = nnx.Linear(2, 2, rngs=rngs) self.b = nnx.BatchNorm(2, rngs=rngs) - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def f(m: Foo): nonlocal n n += 1 @@ -170,9 +211,9 @@ def f(m: Foo): f(m) - assert n == 1 - assert m.a is b - assert m.b is a + self.assertEqual(n, 1) + self.assertIs(m.a, b) + self.assertIs(m.b, a) np.testing.assert_allclose(a_kernel, a.kernel[...]) np.testing.assert_allclose(a_bias, a.bias[...]) np.testing.assert_allclose(b_scale, b.scale[...]) @@ -182,21 +223,62 @@ def f(m: Foo): f(m) - assert n == 2 - assert m.a is a - assert m.b is b + self.assertEqual(n, 2) + self.assertIs(m.a, a) + self.assertIs(m.b, b) + + f(m) + + self.assertEqual(n, 2) + self.assertIs(m.a, b) + self.assertIs(m.b, a) + + f(m) + + self.assertEqual(n, 2) + self.assertIs(m.a, a) + self.assertIs(m.b, b) + + def test_graph_updates_unflatten_functional(self): + n = 0 + + class Foo(nnx.Module): + def __init__(self): + self.a = nnx.Param(jnp.array(1)) + self.b = nnx.Param(jnp.array(2)) + + @nnx.jit(graph=True, graph_updates=False) + def f(m: Foo): + nonlocal n + n += 1 + a_val = m.a.get_value() + b_val = m.b.get_value() + m.a.set_value(b_val) + m.b.set_value(a_val) + + m = Foo() + a = m.a + b = m.b + + f(m) + + self.assertEqual(n, 1) + self.assertIs(m.a, a) + self.assertIs(m.b, b) + self.assertEqual(m.a, 2) + self.assertEqual(m.b, 1) f(m) - assert n == 2 - assert m.a is b - assert m.b is a + self.assertEqual(n, 1) # Should NOT retrace + self.assertEqual(m.a, 1) + self.assertEqual(m.b, 2) f(m) - assert n == 2 - assert m.a is a - assert m.b is b + self.assertEqual(n, 1) + self.assertEqual(m.a, 2) + self.assertEqual(m.b, 1) @parameterized.parameters( (True, True), (True, False), (False, False), @@ -219,7 +301,7 @@ def f_bwd(res, g): jax_out = jax.jit(f)(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])) assert (nnx_out == jax_out).all() - def test_cached_unflatten_same_type(self): + def test_graph_updates_same_type(self): n = 0 class Foo(nnx.Module): @@ -227,7 +309,7 @@ def __init__(self, *, rngs: nnx.Rngs): self.a = nnx.Linear(2, 2, rngs=rngs) self.b = nnx.Linear(2, 2, rngs=rngs) - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def f(m: Foo): nonlocal n n += 1 @@ -239,17 +321,51 @@ def f(m: Foo): f(m) - assert n == 1 - assert m.a is b - assert m.b is a + self.assertEqual(n, 1) + self.assertIs(m.a, b) + self.assertIs(m.b, a) f(m) - assert n == 1 - assert m.a is a - assert m.b is b + self.assertEqual(n, 1) + self.assertIs(m.a, a) + self.assertIs(m.b, b) + + @parameterized.parameters(True, False) + def test_graph_updates_same_type_functional(self, graph): + n = 0 + + class Foo(nnx.Module): + def __init__(self): + self.a = nnx.Variable(1.0) + self.b = nnx.Variable(2.0) + + @nnx.jit(graph=graph, graph_updates=False) + def f(m: Foo): + nonlocal n + n += 1 + # Functional swap using set_value + a_val = m.a.get_value() + b_val = m.b.get_value() + m.a.set_value(b_val) + m.b.set_value(a_val) + + m = Foo() + + f(m) + + self.assertEqual(n, 1) + self.assertEqual(m.a, 2.0) + self.assertEqual(m.b, 1.0) + + f(m) + + self.assertEqual(n, 1) + self.assertEqual(m.a, 1.0) + self.assertEqual(m.b, 2.0) - def test_objects_in_pytree(self): + @parameterized.parameters(True, False) + def test_objects_in_pytree(self, graph_updates): n = 0 class Foo(nnx.Module): @@ -260,7 +376,7 @@ def __init__(self, *, rngs: nnx.Rngs): class FooDict(tp.TypedDict): foo: Foo - @nnx.jit + @nnx.jit(graph=True, graph_updates=graph_updates) def f(tree: tuple[FooDict]): nonlocal n n += 1 @@ -273,23 +389,27 @@ def f(tree: tuple[FooDict]): f(({'foo': m},)) - assert n == 1 - assert m.a is b - assert m.b is a + self.assertEqual(n, 1) + if graph_updates: + self.assertIs(m.a, b) + self.assertIs(m.b, a) + else: + self.assertIs(m.a, a) + self.assertIs(m.b, b) f(({'foo': m},)) - assert n == 1 - assert m.a is a - assert m.b is b + self.assertEqual(n, 1) + self.assertIs(m.a, a) + self.assertIs(m.b, b) - def test_cached_unflatten_swap_variables(self): + def test_graph_updates_swap_variables(self): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(1) self.b = nnx.Param(2) - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def f(m: Foo): m.a, m.b = m.b, m.a @@ -299,17 +419,41 @@ def f(m: Foo): f(m) - assert m.a is b - assert m.b is a + self.assertIs(m.a, b) + self.assertIs(m.b, a) + + def test_graph_updates_swap_variables_functional(self): + class Foo(nnx.Module): + def __init__(self): + self.a = nnx.Param(1) + self.b = nnx.Param(2) + + @nnx.jit(graph=True, graph_updates=False) + def f(m: Foo): + a_val = m.a.get_value() + b_val = m.b.get_value() + m.a.set_value(b_val) + m.b.set_value(a_val) + + m = Foo() + a = m.a + b = m.b + + f(m) + + self.assertIs(m.a, a) # References should NOT change in functional mode + self.assertIs(m.b, b) + self.assertEqual(m.a, 2) # Values should be swapped + self.assertEqual(m.b, 1) - def test_cached_unflatten_add_self_reference(self): + def test_graph_updates_add_self_reference(self): n = 0 class Foo(nnx.Module): def __init__(self): self.ref: tp.Optional[Foo] = nnx.data(None) # type: ignore[name-error] - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def f(m: Foo): nonlocal n n += 1 @@ -319,27 +463,27 @@ def f(m: Foo): f(m) - assert n == 1 - assert m.ref is m + self.assertEqual(n, 1) + self.assertIs(m.ref, m) f(m) - assert n == 2 - assert m.ref is m + self.assertEqual(n, 2) + self.assertIs(m.ref, m) f(m) - assert n == 2 - assert m.ref is m + self.assertEqual(n, 2) + self.assertIs(m.ref, m) - def test_cached_unflatten_ref_in_output(self): + def test_graph_updates_ref_in_output(self): n = 0 class Foo(nnx.Module): def __init__(self): self.ref: tp.Optional[Foo] = nnx.data(None) # type: ignore[name-error] - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def f(m: Foo): nonlocal n n += 1 @@ -350,21 +494,21 @@ def f(m: Foo): m2 = f(m) - assert n == 1 - assert m.ref is m - assert m2 is m + self.assertEqual(n, 1) + self.assertIs(m.ref, m) + self.assertIs(m2, m) m2 = f(m) - assert n == 2 - assert m.ref is m - assert m2 is m + self.assertEqual(n, 2) + self.assertIs(m.ref, m) + self.assertIs(m2, m) m2 = f(m) - assert n == 2 - assert m.ref is m - assert m2 is m + self.assertEqual(n, 2) + self.assertIs(m.ref, m) + self.assertIs(m2, m) def test_apply_shardings(self): n_devices = max(jax.local_device_count() // 2, 1) @@ -387,7 +531,7 @@ def sharding(*args): self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) - @nnx.jit(in_shardings=(state_sharding,)) + @nnx.jit(in_shardings=(state_sharding,), graph=True, graph_updates=True) def constrain_object(m): pass @@ -395,17 +539,48 @@ def constrain_object(m): self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) + @parameterized.parameters(True, False) + def test_apply_shardings_functional(self, graph): + n_devices = max(jax.local_device_count() // 2, 1) + devices = mesh_utils.create_device_mesh( + (n_devices, jax.local_device_count() // n_devices) + ) + mesh = jax.sharding.Mesh(devices, ('a', 'b')) + + def sharding(*args): + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*args)) + + m_dummy = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + + state_sharding = nnx.prefix( + m_dummy, + { + nnx.PathContains('kernel'): sharding('a', 'b'), + nnx.PathContains('bias'): sharding('b'), + }, + graph=graph, + ) + + @nnx.jit(out_shardings=state_sharding, graph=graph, graph_updates=False) + def constrain_object(): + m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + return m + + m = constrain_object() + + self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) + def test_cache_args(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def f(cached_m: nnx.Linear, m: nnx.Linear): self.assertIsNot(cached_m, m) self.assertIs(cached_m.kernel, m.kernel) self.assertIs(cached_m.bias, m.bias) return cached_m - cached_f = nnx.cached_partial(f, m) + cached_f = nnx.compat.cached_partial(f, m) cached_m = cached_f(m) self.assertIsNot(m, cached_m) @@ -416,6 +591,18 @@ def f(cached_m: nnx.Linear, m: nnx.Linear): cached_m2 = cached_f(m) self.assertIs(cached_m, cached_m2) + @parameterized.parameters(True, False) + def test_cache_args_functional(self, graph_mode): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + + def f(m, x): + return m(x) + + f_jit = nnx.jit_partial(f, m, graph=graph_mode, graph_updates=False) + x = jnp.ones((1, 2)) + y = f_jit(x) + self.assertEqual(y.shape, (1, 3)) + @parameterized.parameters( (True, True), (True, False), (False, False), ) @@ -508,8 +695,10 @@ def sharding(*args): self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) @nnx.jit( - in_shardings=(state_sharding, None), - **static_args, + in_shardings=(state_sharding, None), + graph=True, + graph_updates=True, + **static_args, ) def constrain_object(m, scale: float, static_arg1: bool, static_arg2: bool): new_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('b', 'a')) @@ -519,6 +708,55 @@ def constrain_object(m, scale: float, static_arg1: bool, static_arg2: bool): constrain_object(m, 0.5, True, True) self.assertEqual(m.kernel.sharding.spec, jax.sharding.PartitionSpec("a", "b")) + @parameterized.parameters( + { + 'graph': True, + 'static_args': {'static_argnums': (1, 2)}, + }, + { + 'graph': False, + 'static_args': {'static_argnums': (1, 2)}, + }, + { + 'graph': True, + 'static_args': {'static_argnames': ('static_arg1', 'static_arg2')}, + }, + { + 'graph': False, + 'static_args': {'static_argnames': ('static_arg1', 'static_arg2')}, + }, + ) + def test_with_sharding_and_static_args_functional(self, graph, static_args): + n_devices = max(jax.local_device_count() // 2, 1) + devices = mesh_utils.create_device_mesh( + (n_devices, jax.local_device_count() // n_devices) + ) + mesh = jax.sharding.Mesh(devices, ('a', 'b')) + + def sharding(*args): + return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*args)) + + m_dummy = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + + state_sharding = nnx.prefix( + m_dummy, + { + nnx.PathContains('kernel'): sharding('a', 'b'), + nnx.PathContains('bias'): sharding('b'), + }, + graph=graph, + ) + + @nnx.jit(out_shardings=state_sharding, graph=graph, graph_updates=False, **static_args) + def constrain_object(scale: float, static_arg1: bool, static_arg2: bool): + m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + new_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('b', 'a')) + m.kernel.set_value(jax.lax.with_sharding_constraint(m.kernel.value, new_sharding)) + return m + + m = constrain_object(0.5, True, True) + self.assertEqual(m.kernel.sharding.spec, jax.sharding.PartitionSpec("a", "b")) + class TestTreeJIT(parameterized.TestCase): @parameterized.parameters( @@ -616,7 +854,7 @@ def test_tree_jit_no_input_output_aliasing(self): def f(v): return v - with self.assertRaisesRegex(ValueError, 'does not support returning input Variables as outputs'): + with self.assertRaisesRegex(ValueError, 'does not support Variable aliasing'): f(v) def test_tree_jit_no_shared_variable_refs(self): @@ -926,7 +1164,8 @@ def loss_fn(model): loss = train_step_fn(x, y) self.assertIsInstance(loss, jax.Array) - def test_jit_partial_shared_variable(self): + @parameterized.parameters(True, False) + def test_jit_partial_shared_variable(self, graph): v = nnx.Param(jnp.array(1.0)) class Container(nnx.Module): @@ -940,18 +1179,25 @@ def f(c1, c2, x): c1.v[...] += x return c1.v[...] + c2.v[...] - f_jit = nnx.jit_partial(f, c1, c2, graph=True, graph_updates=False) + f_jit = nnx.jit_partial(f, c1, c2, graph=graph, graph_updates=False) + if not graph: + with self.assertRaisesRegex(ValueError, 'Duplicate Param'): + f_jit(jnp.array(1.0)) + return + y = f_jit(jnp.array(1.0)) np.testing.assert_allclose(y, 4.0) np.testing.assert_allclose(v[...], 2.0) - def test_jit_inconsistent_aliasing(self): + @parameterized.parameters(True, False) + def test_jit_inconsistent_aliasing(self, graph_updates): v = nnx.Param(jnp.array(1.0)) P = jax.sharding.PartitionSpec @nnx.jit( in_shardings=(P(), P('x')), - graph=True, graph_updates=False, + graph=True, + graph_updates=graph_updates, ) def f(a, b): return a[...] + b[...] @@ -976,7 +1222,7 @@ def test_eval_shape(self, graph, graph_updates): def test_eval_shape_mutable_array(self): with nnx.var_defaults(hijax=True): - abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0))) + abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0)), graph=True, graph_updates=True) self.assertIsInstance(abs_model, nnx.Linear) self.assertIsInstance(abs_model.kernel.get_value(), jax.ShapeDtypeStruct) self.assertEqual(abs_model.kernel.shape, (1, 2)) @@ -1056,7 +1302,13 @@ def test_basic_shardmap(self): self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) - @nnx.shard_map(mesh=mesh, in_specs=(state_sharding,), out_specs=None) + @nnx.shard_map( + mesh=mesh, + in_specs=(state_sharding,), + out_specs=None, + graph=True, + graph_updates=True, + ) def f(m: nnx.Linear): self.assertEqual( m.kernel.shape, (m.in_features, m.out_features // n_devices) @@ -1067,6 +1319,38 @@ def f(m: nnx.Linear): self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) + @parameterized.parameters(True, False) + def test_basic_shardmap_functional(self, graph): + n_devices = jax.local_device_count() + devices = mesh_utils.create_device_mesh((n_devices,)) + mesh = jax.sharding.Mesh(devices, ('a',)) + PS = jax.sharding.PartitionSpec + + m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + state_sharding = nnx.prefix( + m, + { + nnx.PathContains('kernel'): PS(None, 'a'), + nnx.PathContains('bias'): PS(), + }, + graph=graph, + ) + + @nnx.shard_map( + mesh=mesh, + in_specs=(state_sharding,), + out_specs=None, + graph=graph, + graph_updates=False, + ) + def f(m: nnx.Linear): + self.assertEqual( + m.kernel.shape, (m.in_features, m.out_features // n_devices) + ) + self.assertEqual(m.bias.shape, (m.out_features,)) + + f(m) + @parameterized.parameters( (True, True), (True, False), (False, False), ) @@ -1117,7 +1401,14 @@ def test_from_state(self): self.assertNotIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) - @nnx.shard_map(mesh=mesh, in_specs=(state_sharding,), out_specs=None) + @nnx.shard_map( + mesh=mesh, + in_specs=(state_sharding,), + out_specs=None, + graph=True, + graph_updates=True, + ) + def f(m: nnx.Linear): self.assertEqual( m.kernel.shape, (m.in_features, m.out_features // n_devices) @@ -1129,6 +1420,37 @@ def f(m: nnx.Linear): self.assertIsInstance(m.kernel.sharding, jax.sharding.NamedSharding) self.assertIsInstance(m.bias.sharding, jax.sharding.NamedSharding) + @parameterized.parameters(True, False) + def test_from_state_functional(self, graph): + n_devices = jax.local_device_count() + devices = mesh_utils.create_device_mesh((n_devices,)) + mesh = jax.sharding.Mesh(devices, ('a',)) + PS = jax.sharding.PartitionSpec + + state_spec = nnx.State({ + 'kernel': PS(None, 'a'), + 'bias': PS(), + }) + + m = nnx.Linear(16, 32, rngs=nnx.Rngs(0)) + graphdef, s = nnx.split(m) + + @nnx.shard_map( + mesh=mesh, + in_specs=(state_spec,), + out_specs=None, + graph=graph, + graph_updates=False, + ) + def f(s): + m = nnx.merge(graphdef, s) + self.assertEqual( + m.kernel.shape, (m.in_features, m.out_features // n_devices) + ) + self.assertEqual(m.bias.shape, (m.out_features,)) + + f(s) + @parameterized.parameters( (True, True), (True, False), (False, False), ) @@ -1165,6 +1487,7 @@ def f(m, x): self.assertIsInstance(m.bias.sharding, jax.sharding.NamedSharding) def test_simple_tensor_parallel(self): + n_devices = jax.local_device_count() P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.local_devices(), ('model',)) @@ -1191,13 +1514,69 @@ def path_ends_with(path_suffix): ) @nnx.shard_map( - mesh=mesh, in_specs=(model_sharding, P(None)), out_specs=P(None) + mesh=mesh, + in_specs=(model_sharding, P(None)), + out_specs=P(None), + graph=True, + graph_updates=True, + ) + def f(m, x): + self.assertEqual(m.linear1.kernel.shape, (2, 64 // n_devices)) + self.assertEqual(m.linear2.kernel.shape, (64 // n_devices, 3)) + y = m(x) + return jax.lax.psum(y, 'model') + + y = f(m, x) + self.assertEqual(y.shape, (32, 3)) + self.assertEqual(y.sharding.spec, P(None)) + + @parameterized.parameters(True, False) + def test_simple_tensor_parallel_functional(self, graph): + n_devices = jax.local_device_count() + P = jax.sharding.PartitionSpec + + mesh = jax.sharding.Mesh(jax.local_devices(), ('model',)) + + class MLP(nnx.Module): + + def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs) + self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs) + + def __call__(self, x): + return self.linear2(jax.nn.relu(self.linear1(x))) + + m = MLP(2, 64, 3, rngs=nnx.Rngs(0)) + x = jnp.ones((32, 2)) + + def path_ends_with(path_suffix): + return lambda path, value: path[-len(path_suffix) :] == path_suffix + + model_sharding = nnx.prefix( + m, + { + path_ends_with(('linear1', 'kernel')): P(None, 'model'), + path_ends_with(('linear2', 'kernel')): P('model', None), + }, + graph=graph, + ) + + @nnx.shard_map( + mesh=mesh, + in_specs=(model_sharding, P(None)), + out_specs=P(None), + graph=graph, + graph_updates=False, ) def f(m, x): + self.assertEqual(m.linear1.kernel.shape, (2, 64 // n_devices)) + self.assertEqual(m.linear2.kernel.shape, (64 // n_devices, 3)) y = m(x) return jax.lax.psum(y, 'model') y = f(m, x) + self.assertEqual(y.shape, (32, 3)) + self.assertEqual(y.sharding.spec, P(None)) @parameterized.parameters( (True, True), (True, False), (False, False), @@ -1254,7 +1633,12 @@ def f(w, count): self.assertEqual(y.shape, (8, 4)) np.testing.assert_allclose(w[...], jnp.zeros((8, 4))) - def test_shardmap_shared_variable(self): + @parameterized.parameters( + (True, True), + (True, False), + (False, False), + ) + def test_shardmap_shared_variable(self, graph, graph_updates): P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.local_devices(), ('data',)) @@ -1269,15 +1653,19 @@ def __init__(self, v): @nnx.shard_map( mesh=mesh, in_specs=(P(), P(), P()), out_specs=P(), - graph=True, graph_updates=True, + graph=graph, graph_updates=graph_updates, ) def f(c1, c2, x): c1.v[...] += x return c1.v[...] + c2.v[...] - y = f(c1, c2, jnp.array(1.0)) - np.testing.assert_allclose(y, 4.0) - np.testing.assert_allclose(v[...], 2.0) + if not graph and not graph_updates: + with self.assertRaises(ValueError): + f(c1, c2, jnp.array(1.0)) + else: + y = f(c1, c2, jnp.array(1.0)) + np.testing.assert_allclose(y, 4.0) + np.testing.assert_allclose(v[...], 2.0) @parameterized.parameters( (True, True), (True, False), (False, False), @@ -1306,7 +1694,11 @@ def f(m, x): y = f(m, jnp.array(3.0)) self.assertEqual(m.count[...], 2) - def test_shard_map_inconsistent_aliasing(self): + @parameterized.parameters( + (True, True), + (True, False), + ) + def test_shard_map_inconsistent_aliasing(self, graph, graph_updates): v = nnx.Param(jnp.array(1.0)) P = jax.sharding.PartitionSpec mesh = jax.sharding.Mesh(jax.devices(), ('x',)) @@ -1315,7 +1707,7 @@ def test_shard_map_inconsistent_aliasing(self): mesh=mesh, in_specs=(P(), P('x')), out_specs=P(), - graph=True, graph_updates=False, + graph=graph, graph_updates=graph_updates, ) def f(a, b): return a[...] + b[...] @@ -1325,7 +1717,9 @@ def f(a, b): class TestGrad(parameterized.TestCase): - def test_grad(self): + + @parameterized.parameters(True, False) + def test_grad(self, graph_updates: bool): p1 = nnx.Param(10.0) p2 = nnx.Param(20.0) @@ -1336,12 +1730,14 @@ def test_grad(self): d=5.0, ) - @nnx.grad def f(m: nnx.Dict): # sum all params return m['a'][0][...] + m['a'][1][...] + m['b'][...] - grads = f(m) + f_grad = nnx.grad(f, graph=True, graph_updates=graph_updates) + grads = f_grad(m) + if not graph_updates: + grads = nnx.compat.state(grads) assert m.a[0] is m.b assert isinstance(grads, nnx.State) @@ -1360,7 +1756,8 @@ def f(m: nnx.Dict): assert m['c'] == 7 assert m['d'] == 5.0 - def test_grad_with_multiple_ref_types(self): + @parameterized.parameters(True, False) + def test_grad_with_multiple_ref_types(self, graph_updates: bool): m = nnx.Dict( a=nnx.List([nnx.Param(jnp.array(10.0)), nnx.BatchStat(jnp.array(20.0))]), b=nnx.Param(jnp.array(10.0)), @@ -1368,25 +1765,41 @@ def test_grad_with_multiple_ref_types(self): d=5.0, ) - @nnx.grad def f(m: nnx.Dict): # sum all params return m.a[0] + m.a[1] + m.b - grads = f(m) + f_grad = nnx.grad(f, graph=True, graph_updates=graph_updates) + grads = f_grad(m) - assert isinstance(grads, nnx.State) - assert grads['a'][0][...] == 1.0 - assert issubclass(type(grads['a'][0]), nnx.Param) - assert len(grads) == 2 + self.assertEqual(grads['a'][0][...], 1.0) + self.assertTrue(issubclass(type(grads['a'][0]), nnx.Param)) + + if graph_updates: + self.assertIsInstance(grads, nnx.State) + self.assertEqual(len(grads), 2) + else: + # In tree mode (graph_updates=False), nnx.grad treats the Module as a regular + # pytree and differentiates all array leaves, including non-Param variables. + self.assertIsInstance(grads, nnx.Dict) + self.assertEqual(len(grads), 4) + self.assertEqual(grads['a'][1][...], 1.0) + self.assertTrue(issubclass(type(grads['a'][1]), nnx.BatchStat)) + + if not graph_updates: + grads = nnx.state(grads) nnx.update(m, grads) - assert m.a[0][...] == 1.0 - assert m.a[1][...] == 20.0 - assert m.b[...] == 1.0 - assert m.c == 7 - assert m.d == 5.0 + self.assertEqual(m.a[0][...], 1.0) + self.assertEqual(m.b[...], 1.0) + self.assertEqual(m.c, 7) + self.assertEqual(m.d, 5.0) + + if graph_updates: + self.assertEqual(m.a[1][...], 20.0) + else: + self.assertEqual(m.a[1][...], 1.0) def test_grad_with_type_predicate(self): m = nnx.Dict( @@ -1396,61 +1809,116 @@ def test_grad_with_type_predicate(self): d=5.0, ) - @nnx.grad(argnums=nnx.DiffState(0, nnx.BatchStat)) + @nnx.compat.grad(argnums=nnx.DiffState(0, nnx.BatchStat)) def f(m: nnx.Dict): # sum all params return m.a[0] + m.a[1] + m.b grads = f(m) - assert isinstance(grads, nnx.State) - assert grads['a'][1][...] == 1.0 - assert issubclass(type(grads['a'][1]), nnx.BatchStat) - assert len(grads) == 1 + self.assertIsInstance(grads, nnx.State) + self.assertEqual(grads['a'][1][...], 1.0) + self.assertTrue(issubclass(type(grads['a'][1]), nnx.BatchStat)) + self.assertEqual(len(grads), 1) + + nnx.update(m, grads) + + self.assertEqual(m.a[0][...], 10.0) + self.assertEqual(m.a[1][...], 1.0) + self.assertEqual(m.b[...], 10.0) + self.assertEqual(m.c, 7) + self.assertEqual(m.d, 5.0) + + def test_grad_functional(self): + m = nnx.Dict( + a=nnx.List( + [nnx.Param(jnp.array(10.0)), nnx.BatchStat(jnp.array(20.0))] + ), + b=nnx.Param(jnp.array(10.0)), + c=7, + d=5.0, + ) + + graphdef, batch_stats, rest = nnx.split(m, nnx.BatchStat, ...) + + @nnx.grad + def f(batch_stats: nnx.State, rest: nnx.State): + m_inner = nnx.merge(graphdef, batch_stats, rest) + # sum all params + return m_inner.a[0] + m_inner.a[1] + m_inner.b + + grads = f(batch_stats, rest) + + self.assertIsInstance(grads, nnx.State) + self.assertEqual(grads['a'][1][...], 1.0) + self.assertTrue(issubclass(type(grads['a'][1]), nnx.BatchStat)) + self.assertEqual(len(grads), 1) nnx.update(m, grads) - assert m.a[0][...] == 10.0 - assert m.a[1][...] == 1.0 - assert m.b[...] == 10.0 - assert m.c == 7 - assert m.d == 5.0 + self.assertEqual(m.a[0][...], 10.0) + self.assertEqual(m.a[1][...], 1.0) + self.assertEqual(m.b[...], 10.0) + self.assertEqual(m.c, 7) + self.assertEqual(m.d, 5.0) - def test_multiple_inputs(self): + @parameterized.parameters(True, False) + def test_multiple_inputs(self, graph_updates: bool): rngs = nnx.Rngs(0) m = nnx.Linear(2, 3, rngs=rngs) loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2) - grad_fn = nnx.grad(loss_fn) + grad_fn = nnx.grad(loss_fn, graph=True, graph_updates=graph_updates) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) grads = grad_fn(m, x, y) + if not graph_updates: + grads = nnx.state(grads) assert 'kernel' in grads assert grads['kernel'].shape == (2, 3) assert 'bias' in grads assert grads['bias'].shape == (3,) - @parameterized.parameters( - { - 'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), - 'argnums': (0, 1), - }, - { - 'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), - 'argnums': (1, 3), - }, + @parameterized.named_parameters( + { + 'testcase_name': '0_1_updates_True', + 'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), + 'argnums': (0, 1), + 'graph_updates': True, + }, + { + 'testcase_name': '0_1_updates_False', + 'loss_fn': lambda m1, m2, x, y: jnp.mean((m2(m1(x)) - y) ** 2), + 'argnums': (0, 1), + 'graph_updates': False, + }, + { + 'testcase_name': '1_3_updates_True', + 'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), + 'argnums': (1, 3), + 'graph_updates': True, + }, + { + 'testcase_name': '1_3_updates_False', + 'loss_fn': lambda x, m1, y, m2: jnp.mean((m2(m1(x)) - y) ** 2), + 'argnums': (1, 3), + 'graph_updates': False, + }, ) - def test_multiple_graph_nodes(self, loss_fn, argnums): + def test_multiple_graph_nodes(self, loss_fn, argnums, graph_updates: bool): rngs = nnx.Rngs(0) m1 = nnx.Linear(2, 3, rngs=rngs) m2 = nnx.Linear(3, 3, rngs=rngs) - grad_fn = nnx.grad(loss_fn, argnums=argnums) + grad_fn = nnx.compat.grad(loss_fn, argnums=argnums) if graph_updates else nnx.grad(loss_fn, argnums=argnums) x = jax.random.uniform(rngs(), (1, 2)) y = jnp.ones((1, 3)) inputs = [x, y] inputs.insert(argnums[0], m1) inputs.insert(argnums[1], m2) grads_m1, grads_m2 = grad_fn(*inputs) + if not graph_updates: + grads_m1 = nnx.state(grads_m1) + grads_m2 = nnx.state(grads_m2) assert 'kernel' in grads_m1 assert grads_m1['kernel'].shape == (2, 3) @@ -1468,7 +1936,7 @@ def test_multiple_args(self): m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) - @nnx.grad(argnums=(m1_diffstate, m2_diffstate)) + @nnx.compat.grad(argnums=(m1_diffstate, m2_diffstate)) def loss_fn(m1: nnx.Linear, m2: nnx.Linear): return jnp.mean(m1.kernel * m2.kernel) + jnp.mean(m1.bias * m2.bias) @@ -1486,7 +1954,7 @@ def test_multiple_args_in_pytrees(self): m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) - @nnx.grad(argnums=(m1_diffstate, m2_diffstate)) + @nnx.compat.grad(argnums=(m1_diffstate, m2_diffstate)) def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): return jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( l1[0].bias * l2[0].bias @@ -1506,7 +1974,7 @@ def test_value_and_grad_multiple_args_in_pytrees(self): m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) - @nnx.value_and_grad(argnums=(m1_diffstate, m2_diffstate)) + @nnx.compat.value_and_grad(argnums=(m1_diffstate, m2_diffstate)) def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): return jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( l1[0].bias * l2[0].bias @@ -1527,7 +1995,10 @@ def test_value_and_grad_with_aux(self): m1_diffstate = nnx.DiffState(0, nnx.PathContains('kernel')) m2_diffstate = nnx.DiffState(1, nnx.PathContains('bias')) - @nnx.value_and_grad(argnums=(m1_diffstate, m2_diffstate), has_aux=True) + @nnx.value_and_grad( + argnums=(m1_diffstate, m2_diffstate), has_aux=True, + graph=True, graph_updates=True, + ) def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): loss = jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( l1[0].bias * l2[0].bias @@ -1538,7 +2009,6 @@ def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): (loss, m3), (grads_m1, grads_m2) = loss_fn([m1], [m2]) - self.assertEqual(m1.kernel[...], -1.0) self.assertEqual(loss.shape, ()) self.assertIsInstance(m3, nnx.Linear) self.assertIn('kernel', grads_m1[0]) @@ -1546,33 +2016,75 @@ def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): self.assertNotIn('kernel', grads_m2[0]) self.assertIn('bias', grads_m2[0]) - def test_variables_in_grad(self): + def test_value_and_grad_with_aux_functional(self): + m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1)) + + def diff_filter(path, x): + if path: + first = path[0] + idx = getattr(first, 'idx', first) + if idx == 0: + return nnx.PathContains('kernel')(path, x) + if idx == 1: + return nnx.PathContains('bias')(path, x) + return False + + graphdef, diff, nondiff = nnx.split((m1, m2), diff_filter, ...) + + @nnx.value_and_grad(has_aux=True, graph=False) + def loss_fn(diff: nnx.State, nondiff: nnx.State): + m1, m2 = nnx.merge(graphdef, diff, nondiff) + + loss = jnp.mean(m1.kernel * m2.kernel) + jnp.mean(m1.bias * m2.bias) + m1.kernel.set_value(jnp.array(-1.0)) + m3 = nnx.Linear(2, 3, rngs=nnx.Rngs(2)) + return loss, m3 + + (loss, m3), grads = loss_fn(diff, nondiff) + + self.assertEqual(m1.kernel[...], -1.0) + self.assertEqual(loss.shape, ()) + self.assertIsInstance(m3, nnx.Linear) + self.assertIn(0, grads) + self.assertIn('kernel', grads[0]) + self.assertNotIn('bias', grads[0]) + self.assertIn(1, grads) + self.assertNotIn('kernel', grads[1]) + self.assertIn('bias', grads[1]) + + @parameterized.parameters(True, False) + def test_variables_in_grad(self, graph_updates: bool): p1 = nnx.Param(10.0) p2 = nnx.Param(20.0) m = dict(a=[p1, p2], b=p1) - @nnx.grad + @nnx.grad(graph=True, graph_updates=graph_updates) def f(m: dict): return m['a'][0] + m['a'][1] + m['b'] grads = f(m) - assert m['a'][0] is m['b'] - assert isinstance(grads, dict) - assert issubclass(type(grads['a'][0]), nnx.Variable) - assert grads['a'][1][...] == 1.0 - assert issubclass(type(grads['a'][1]), nnx.Variable) - assert len(jax.tree.leaves(grads)) == 2 + self.assertIs(m['a'][0], m['b']) + self.assertIsInstance(grads, dict) + self.assertIsInstance(grads['a'][0], nnx.Variable) + self.assertEqual(grads['a'][1][...], 1.0) + self.assertIsInstance(grads['a'][1], nnx.Variable) - jax.tree.map( - nnx.update, m, grads, is_leaf=lambda x: isinstance(x, nnx.Variable) - ) + if graph_updates: + self.assertEqual(len(jax.tree.leaves(grads)), 2) + self.assertIsInstance(grads['b'], nnx.State) + else: + self.assertEqual(len(jax.tree.leaves(grads)), 3) + self.assertIsInstance(grads['b'], nnx.Variable) - assert m['a'][0] is m['b'] - assert m['a'][0][...] == 2.0 - assert m['a'][1][...] == 1.0 - assert m['b'][...] == 2.0 + nnx.update(m, nnx.state(grads, graph=True)) + + self.assertIs(m['a'][0], m['b']) + self.assertEqual(m['a'][0][...], 2.0) + self.assertEqual(m['a'][1][...], 1.0) + self.assertEqual(m['b'][...], 2.0) @parameterized.parameters( (True, True), @@ -1723,19 +2235,14 @@ def f_bwd(state, res, g): self.assertEqual(state[...], -1.0) self.assertEqual(y.shape, (1, 1)) - @parameterized.parameters( - (True, True), - (True, False), - (False, False), - ) - def test_jax_example(self, graph, graph_updates): + def test_jax_example_graph_updates(self): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] z: int - @nnx.custom_vjp(graph=graph, graph_updates=graph_updates) + @nnx.custom_vjp(graph=True, graph_updates=True) def f(m: Foo): m.z += 1 return jnp.sin(m.x) * m.y # type: ignore @@ -1747,37 +2254,60 @@ def f_fwd(m: Foo): def f_bwd(res, g): cos_x, sin_x, m = res - if graph and graph_updates: - (m_g,), out_g = g - self.assertIsInstance(m_g, nnx.State) - m_g['x'][...] = cos_x * out_g * m.y - m_g['y'][...] = sin_x * out_g - return (m_g,) - else: - out_g = g - m_g = nnx.clone(m) - m_g.x[...] = cos_x * out_g * m.y - m_g.y[...] = sin_x * out_g - return (m_g,) + (m_g,), out_g = g + self.assertIsInstance(m_g, nnx.State) + m_g['x'][...] = cos_x * out_g * m.y + m_g['y'][...] = sin_x * out_g + return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) - if graph and graph_updates: - grads = nnx.grad(f, argnums=nnx.DiffState(0, ...))(m) - self.assertIsInstance(grads, nnx.State) - else: - grads = nnx.grad( - f, graph=graph, graph_updates=graph_updates, - )(m) - self.assertIsInstance(grads, Foo) + grads = nnx.grad( + f, argnums=nnx.DiffState(0, ...), graph=True, graph_updates=True + )(m) + self.assertIsInstance(grads, nnx.State) + np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) # type: ignore np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) # type: ignore - if graph and graph_updates: - self.assertEqual(m.z, 1) - else: - self.assertEqual(m.z, 0) + self.assertEqual(m.z, 1) + + @parameterized.parameters(True, False) + def test_jax_example_functional(self, graph): + @dataclasses.dataclass + class Foo(nnx.Module): + x: nnx.Param[jax.Array] + y: nnx.Param[jax.Array] + z: int + + @nnx.custom_vjp(graph=graph, graph_updates=False) + def f(m: Foo): + m.z += 1 + return jnp.sin(m.x) * m.y # type: ignore + + def f_fwd(m: Foo): + y = f(m) + res = (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore + return y, res + + def f_bwd(res, g): + cos_x, sin_x, m = res + out_g = g + m_g = nnx.clone(m) + m_g.x[...] = cos_x * out_g * m.y + m_g.y[...] = sin_x * out_g + return (m_g,) + + f.defvjp(f_fwd, f_bwd) + + m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) + + grads = nnx.grad(f, graph=graph, graph_updates=False)(m) + self.assertIsInstance(grads, Foo) + np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) # type: ignore + np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) # type: ignore + self.assertEqual(m.z, 0) def test_diff_state(self): @dataclasses.dataclass @@ -1789,7 +2319,7 @@ class Foo(nnx.Module): x_in_path = nnx.PathContains('x') diff_state = nnx.DiffState(0, x_in_path) - @nnx.custom_vjp(nondiff_argnums=(diff_state,)) + @nnx.custom_vjp(nondiff_argnums=(diff_state,), graph=True, graph_updates=True) def f(m: Foo): m.z += 1 return jnp.sin(m.x) * m.y # type: ignore @@ -1815,25 +2345,113 @@ def f_bwd(res, g): m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) - grad: nnx.State = nnx.grad(f, argnums=nnx.DiffState(0, x_in_path))(m) + grad: nnx.State = nnx.grad(f, argnums=nnx.DiffState(0, x_in_path), graph=True, graph_updates=True)(m) np.testing.assert_allclose(grad['x'][...], jnp.cos(1.0) * 2.0) # type: ignore self.assertEqual(m.z, 1) + @parameterized.parameters(True, False) + def test_diff_state_functional(self, graph): + @dataclasses.dataclass + class Foo(nnx.Module): + x: nnx.Param[jax.Array] + y: nnx.Param[jax.Array] + z: int + + x_in_path = nnx.PathContains('x') + m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) + + graphdef, state_x, state_rest = nnx.split(m, x_in_path, ...) + + @nnx.custom_vjp(nondiff_argnums=(1,), graph=graph, graph_updates=False) + def f(state_x, state_rest): + m = nnx.merge(graphdef, state_x, state_rest) + m.z += 1 + y = jnp.sin(m.x) * m.y + return y + + def f_fwd(state_x, state_rest): + y = f(state_x, state_rest) + m = nnx.merge(graphdef, state_x, state_rest) + res = (jnp.cos(m.x), state_x) + return y, res + + def f_bwd(state_rest, res, g): + cos_x, state_x = res + m = nnx.merge(graphdef, state_x, state_rest) + state_x_g = nnx.clone(state_x) + state_x_g.x[...] = cos_x * g * m.y + return (state_x_g,) + + f.defvjp(f_fwd, f_bwd) + + def loss_fn(state_x, state_rest): + y = f(state_x, state_rest) + return y + + grad_fn = nnx.grad(loss_fn, argnums=0, graph=graph, graph_updates=False) + grad_x = grad_fn(state_x, state_rest) + + self.assertIsInstance(grad_x, nnx.State) + np.testing.assert_allclose(grad_x['x'][...], jnp.cos(1.0) * 2.0) + self.assertEqual(m.z, 0) + + def test_jax_example_with_remat_graph_updates(self): + @dataclasses.dataclass + class Foo(nnx.Module): + x: nnx.Param[jax.Array] + y: nnx.Param[jax.Array] + z: int + + @nnx.custom_vjp(graph=True, graph_updates=True) + @nnx.remat(graph=True, graph_updates=True) + def f(m: Foo): + m.z += 1 + return jnp.sin(m.x) * m.y # type: ignore + + def f_fwd(m: Foo): + y = f(m) + res = (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore + return y, res + + def f_bwd(res, g): + cos_x, sin_x, m = res + (m_g,), out_g = g + self.assertIsInstance(m_g, nnx.State) + m_g['x'][...] = cos_x * out_g * m.y + m_g['y'][...] = sin_x * out_g + return (m_g,) + + f.defvjp(f_fwd, f_bwd) + + m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) + + @nnx.jit(graph=True, graph_updates=True) + def loss_fn(m): + return f(m) + + grads = nnx.grad( + loss_fn, argnums=nnx.DiffState(0, ...), graph=True, graph_updates=True + )(m) + self.assertIsInstance(grads, nnx.State) + + np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) # type: ignore + np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) # type: ignore + self.assertEqual(m.z, 1) + @parameterized.parameters( - (True, True), - (True, False), - (False, False), + (True,), + (False,), ) - def test_jax_example_with_remat(self, graph, graph_updates): + def test_jax_example_with_remat_functional(self, graph): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] z: int - @nnx.custom_vjp(graph=graph, graph_updates=graph_updates) - @nnx.remat(graph=graph, graph_updates=graph_updates) + @nnx.custom_vjp(graph=graph, graph_updates=False) + @nnx.remat(graph=graph, graph_updates=False) def f(m: Foo): m.z += 1 return jnp.sin(m.x) * m.y # type: ignore @@ -1845,41 +2463,25 @@ def f_fwd(m: Foo): def f_bwd(res, g): cos_x, sin_x, m = res - if graph and graph_updates: - (m_g,), out_g = g - self.assertIsInstance(m_g, nnx.State) - m_g['x'][...] = cos_x * out_g * m.y - m_g['y'][...] = sin_x * out_g - return (m_g,) - else: - out_g = g - m_g = jax.tree.map(lambda x: x, m) - m_g.x[...] = cos_x * out_g * m.y - m_g.y[...] = sin_x * out_g - return (m_g,) + out_g = g + m_g = jax.tree.map(lambda x: x, m) + m_g.x[...] = cos_x * out_g * m.y + m_g.y[...] = sin_x * out_g + return (m_g,) f.defvjp(f_fwd, f_bwd) m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) - @nnx.jit(graph=graph, graph_updates=graph_updates) + @nnx.jit(graph=graph, graph_updates=False) def loss_fn(m): return f(m) - if graph and graph_updates: - grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m) - self.assertIsInstance(grads, nnx.State) - else: - grads = nnx.grad( - loss_fn, graph=graph, graph_updates=graph_updates, - )(m) - self.assertIsInstance(grads, Foo) + grads = nnx.grad(loss_fn, graph=graph, graph_updates=False)(m) + self.assertIsInstance(grads, Foo) np.testing.assert_allclose(grads.x[...], jnp.cos(1.0) * 2.0) # type: ignore np.testing.assert_allclose(grads.y[...], jnp.sin(1.0)) # type: ignore - if graph and graph_updates: - self.assertEqual(m.z, 1) - else: - self.assertEqual(m.z, 0) + self.assertEqual(m.z, 0) def test_two_args(self): @dataclasses.dataclass @@ -1888,7 +2490,7 @@ class Foo(nnx.Module): y: nnx.Param[jax.Array] z: int - @nnx.custom_vjp + @nnx.custom_vjp(graph=True, graph_updates=True) def f(m1: Foo, m2: Foo): m1.z += 1 y = jnp.sin(m1.x) * m1.y # type: ignore @@ -1925,7 +2527,10 @@ def loss_fn(m1, m2): m1_grad: nnx.State m2_grad: nnx.State m1_grad, m2_grad = nnx.grad( - loss_fn, argnums=(nnx.DiffState(0, ...), nnx.DiffState(1, ...)) + loss_fn, + argnums=(nnx.DiffState(0, ...), nnx.DiffState(1, ...)), + graph=True, + graph_updates=True, )(m1, m2) np.testing.assert_allclose(m1_grad['x'][...], jnp.cos(1.0) * 2.0) # type: ignore @@ -1934,29 +2539,82 @@ def loss_fn(m1, m2): np.testing.assert_allclose(m2_grad['x'][...], 4.0) # type: ignore np.testing.assert_allclose(m2_grad['y'][...], 3.0) # type: ignore - @parameterized.parameters( - (True, True), - (True, False), - (False, False), - ) - def test_non_diff_args(self, graph, graph_updates): + @parameterized.parameters(True, False) + def test_two_args_functional(self, graph): @dataclasses.dataclass class Foo(nnx.Module): x: nnx.Param[jax.Array] y: nnx.Param[jax.Array] z: int - @nnx.custom_vjp( - nondiff_argnums=(0, 2), graph=graph, graph_updates=graph_updates, - ) - def f(a, m: Foo, b): - self.assertEqual(a, 1) - self.assertEqual(b, 2) - m.z += 1 - return jnp.sin(m.x) * m.y # type: ignore + @nnx.custom_vjp(graph=graph, graph_updates=False) + def f(m1: Foo, m2: Foo): + m1.z += 1 + y = jnp.sin(m1.x) * m1.y # type: ignore + return y, nnx.clone(m2) - def f_fwd(a, m: Foo, b): - self.assertEqual(a, 1) + def f_fwd(m1: Foo, m2: Foo): + y, m2_out = f(m1, m2) + res = (jnp.cos(m1.x), jnp.sin(m1.x), m1, m2) + return (y, m2_out), res + + def f_bwd(res, g): + y_g, m2_g = g + cos_x, sin_x, m1, m2 = res + + m1_g = nnx.clone(m1) + m1_g.x[...] = cos_x * y_g * m1.y + m1_g.y[...] = sin_x * y_g + + return m1_g, m2_g + + f.defvjp(f_fwd, f_bwd) + + m1 = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0) + m2 = Foo(nnx.Param(jnp.array(3.0)), nnx.Param(jnp.array(4.0)), 0) + + def loss_fn(m1, m2): + y, m2 = f(m1, m2) + return y + m2.x * m2.y + + m1_grad, m2_grad = nnx.grad( + loss_fn, + argnums=(0, 1), + graph=graph, + graph_updates=False, + )(m1, m2) + + self.assertIsInstance(m1_grad, Foo) + self.assertIsInstance(m2_grad, Foo) + np.testing.assert_allclose(m1_grad.x[...], jnp.cos(1.0) * 2.0) # type: ignore + np.testing.assert_allclose(m1_grad.y[...], jnp.sin(1.0)) # type: ignore + self.assertEqual(m1.z, 0) + np.testing.assert_allclose(m2_grad.x[...], 4.0) # type: ignore + np.testing.assert_allclose(m2_grad.y[...], 3.0) # type: ignore + + @parameterized.parameters( + (True, True), + (True, False), + (False, False), + ) + def test_non_diff_args(self, graph, graph_updates): + @dataclasses.dataclass + class Foo(nnx.Module): + x: nnx.Param[jax.Array] + y: nnx.Param[jax.Array] + z: int + + @nnx.custom_vjp( + nondiff_argnums=(0, 2), graph=graph, graph_updates=graph_updates, + ) + def f(a, m: Foo, b): + self.assertEqual(a, 1) + self.assertEqual(b, 2) + m.z += 1 + return jnp.sin(m.x) * m.y # type: ignore + + def f_fwd(a, m: Foo, b): + self.assertEqual(a, 1) self.assertEqual(b, 2) y = f(a, m, b) res = (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore @@ -1989,7 +2647,9 @@ def loss_fn(m): return f(a, m, b) if graph and graph_updates: - grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m) + grads = nnx.grad( + loss_fn, argnums=nnx.DiffState(0, ...), graph=True, graph_updates=True + )(m) self.assertIsInstance(grads, nnx.State) else: grads = nnx.grad( @@ -2012,7 +2672,7 @@ def __init__(self, x, y): self.x = nnx.Param(x) self.y = nnx.Param(y) - @nnx.custom_vjp + @nnx.custom_vjp(graph=True, graph_updates=True) def f(m: Foo): return jnp.sin(m.x) * m.y # type: ignore @@ -2028,7 +2688,38 @@ def f_bwd(res, g): f.defvjp(f_fwd, f_bwd) m = Foo(x=jnp.array(1.0), y=jnp.array(2.0)) - grads = nnx.grad(f)(m) + grads = nnx.grad(f, graph=True, graph_updates=True)(m) + + @parameterized.parameters(True, False) + def test_docs_example_functional(self, graph): + import jax.numpy as jnp + from flax import nnx + + class Foo(nnx.Module): + def __init__(self, x, y): + self.x = nnx.Param(x) + self.y = nnx.Param(y) + + @nnx.custom_vjp(graph=graph, graph_updates=False) + def f(m: Foo): + return jnp.sin(m.x) * m.y # type: ignore + + def f_fwd(m: Foo): + return f(m), (jnp.cos(m.x), jnp.sin(m.x), m) # type: ignore + + def f_bwd(res, g): + cos_x, sin_x, m = res + out_g = g + m_g = nnx.clone(m) + m_g.x[...] = cos_x * out_g * m.y + m_g.y[...] = sin_x * out_g + return (m_g,) + + f.defvjp(f_fwd, f_bwd) + + m = Foo(x=jnp.array(1.0), y=jnp.array(2.0)) + grads = nnx.grad(f, graph=graph, graph_updates=False)(m) + self.assertIsInstance(grads, Foo) @parameterized.parameters( {'use_custom_vjp': False}, @@ -2066,10 +2757,10 @@ def linear_bwd(res, g): return (m_g, x_grad) if use_custom_vjp: - linear = nnx.custom_vjp(linear) + linear = nnx.custom_vjp(linear, graph=True, graph_updates=True) linear.defvjp(linear_fwd, linear_bwd) - @nnx.jit + @nnx.jit(graph=True, graph_updates=True) def loss_fn(x, mod): y = linear(mod, x) return y.mean() @@ -2077,7 +2768,7 @@ def loss_fn(x, mod): mod = MyLinear(10, 5, rngs=nnx.Rngs(0)) self.assertEqual(mod.n[...], 0) x = jax.random.normal(jax.random.key(0), (10,)) - loss, grad = nnx.value_and_grad(loss_fn)(x, mod) + loss, grad = nnx.value_and_grad(loss_fn, graph=True, graph_updates=True)(x, mod) self.assertEqual(loss.shape, ()) self.assertEqual(grad.shape, (10,)) self.assertEqual(mod.n[...], 1) @@ -2275,7 +2966,6 @@ def f_bwd(v_nondiff, res, g): with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing'): f(v, v) - def test_custom_vjp_diff_arg_mutation_error(self): @nnx.custom_vjp(graph=True, graph_updates=False) def f(m): @@ -2590,6 +3280,141 @@ def forward_block(_, block: Block, x: jax.Array): assert y.shape == (5, 1, 3) assert out is None + def test_broadcast_args(self): + def scale_cumsum(carry, scale, x): + carry = carry + x * scale + return carry, carry + + final_carry, _ = nnx.scan( + scale_cumsum, + in_axes=(nnx.Carry, None, 0), + out_axes=(nnx.Carry, 0), + graph=False, + )(jnp.array(0.0), jnp.array(2.0), jnp.arange(5.0)) + np.testing.assert_allclose(final_carry, 20.0) + + def test_no_carry_all_scanned(self): + def double(x): + return x * 2 + + ys = nnx.scan( + double, in_axes=0, out_axes=0, graph=False + )(jnp.arange(5.0)) + np.testing.assert_allclose(ys, jnp.arange(5.0) * 2) + + def test_pytree_prefix_in_axes(self): + def fn(carry, x): + carry = carry + x['a'] + x['b'] + return carry, carry + + xs = {'a': jnp.arange(3.0), 'b': jnp.array(1.0)} + final_carry, _ = nnx.scan( + fn, + in_axes=(nnx.Carry, {'a': 0, 'b': None}), + out_axes=(nnx.Carry, 0), + graph=False, + )(jnp.array(0.0), xs) + np.testing.assert_allclose(final_carry, 6.0) + + def test_nested_carry_rejected(self): + with self.assertRaises(ValueError): + nnx.scan( + lambda x: x, + in_axes=({'a': nnx.Carry},), + out_axes=nnx.Carry, + graph=False, + )({'a': jnp.array(1.0)}) + + @parameterized.parameters(True, False) + def test_broadcast_out_axes_rejected(self, graph): + with self.assertRaises(ValueError): + nnx.scan( + lambda c, x: (c, x), + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, None), + graph=graph, + graph_updates=False, + )(jnp.array(0.0), jnp.arange(3.0)) + + def test_none_broadcast_input(self): + def fn(carry, _unused, x): + carry = carry + x + return carry, carry + + final_carry, _ = nnx.scan( + fn, + in_axes=(nnx.Carry, None, 0), + out_axes=(nnx.Carry, 0), + graph=False, + )(jnp.array(0.0), None, jnp.arange(3.0)) + np.testing.assert_allclose(final_carry, 3.0) + + def test_none_nested_in_arg(self): + def fn(carry, x): + carry = carry + x['a'] + return carry, carry + + xs = {'a': jnp.arange(3.0), 'b': None} + final_carry, _ = nnx.scan( + fn, + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, 0), + graph=False, + )(jnp.array(0.0), xs) + np.testing.assert_allclose(final_carry, 3.0) + + def test_nested_carry_in_out_axes_rejected(self): + with self.assertRaises(ValueError): + nnx.scan( + lambda c, x: (c, x), + in_axes=(nnx.Carry, 0), + out_axes=({'a': nnx.Carry},), + graph=False, + )(jnp.array(0.0), jnp.arange(3.0)) + + def test_carry_in_in_axes_only_rejected(self): + with self.assertRaises(ValueError): + nnx.scan( + lambda c, x: (c + x,), + in_axes=(nnx.Carry, 0), + out_axes=(0,), + graph=False, + )(jnp.array(0.0), jnp.arange(3.0)) + + def test_carry_in_out_axes_only_rejected(self): + with self.assertRaises(ValueError): + nnx.scan( + lambda x: x, + in_axes=(0,), + out_axes=nnx.Carry, + graph=False, + )(jnp.arange(3.0)) + + def test_non_tuple_carry_only(self): + def f(carry): + return carry + 1.0 + + result = nnx.scan( + f, + in_axes=nnx.Carry, + out_axes=nnx.Carry, + length=5, + graph=False, + )(jnp.array(0.0)) + np.testing.assert_allclose(result, 5.0) + + def test_non_tuple_scan_only(self): + def f(x): + return x * 2 + + ys = nnx.scan( + f, + in_axes=0, + out_axes=0, + graph=False, + )(jnp.arange(5.0)) + np.testing.assert_allclose(ys, jnp.arange(5.0) * 2) + @parameterized.parameters(True, False) def test_variables_in_scan(self, graph_updates): def block_init(din, dout, rngs): @@ -2707,7 +3532,8 @@ class Foo(nnx.Module): foo = Foo(n=nnx.BatchStat(0)) @nnx.scan(in_axes=nnx.Carry, out_axes=nnx.Carry, length=3) - def loop(foo: Foo, x): ... + def loop(foo: Foo, x): + ... with self.assertRaisesRegex( ValueError, @@ -2715,7 +3541,8 @@ def loop(foo: Foo, x): ... ): loop(foo, 0) - def test_all_carry_new_reference_error(self): + @parameterized.parameters(True, False) + def test_all_carry_new_reference_error(self, graph_updates): class Foo(nnx.Module): def __init__(self, n: nnx.BatchStat[int]): self.n = n @@ -2723,16 +3550,24 @@ def __init__(self, n: nnx.BatchStat[int]): xs = jnp.arange(3) foo = Foo(n=nnx.BatchStat(0)) - @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0)) + @nnx.scan( + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, 0), + graph=True, + graph_updates=graph_updates, + ) def loop(foo: Foo, x): x = x + 1 foo = Foo(nnx.BatchStat(foo.n[...] + 1)) # new reference return foo, x - with self.assertRaisesRegex( - ValueError, - 'Carry references must be the same between iterations', - ): + msg = ( + 'Carry references must be the same between iterations' + if graph_updates + else 'scan Variable identity must be preserved across iterations' + ) + + with self.assertRaisesRegex(ValueError, msg): loop(foo, xs) @parameterized.parameters(True, False) @@ -2756,9 +3591,9 @@ def loop(foo: Foo, x): np.testing.assert_allclose(foo.n[...], jnp.arange(1, 4)) def test_all_broadcast(self): + @nnx.dataclass class Foo(nnx.Module): - def __init__(self, n: nnx.BatchStat[int]): - self.n = n + n: nnx.BatchStat[int] xs = jnp.array(1) foo = Foo(n=nnx.BatchStat(2)) @@ -2772,50 +3607,69 @@ def loop(foo: Foo, x): np.testing.assert_allclose(ys, 3) self.assertEqual(ys.shape, (4,)) - def test_input_output_carry_mismatch_error(self): - with self.assertRaisesRegex( - ValueError, - 'If one of in_axes or out_axes has Carry, the other must also have Carry', - ): + @parameterized.parameters(True, False) + def test_input_output_carry_mismatch_error(self, graph_updates): + with nnx.set_graph_updates(graph_updates): + with self.assertRaisesRegex( + ValueError, + 'If one of in_axes or out_axes has Carry, the other must also have' + ' Carry', + ): - @nnx.scan(in_axes=0, out_axes=(nnx.Carry, 0)) - def loop(a, b): ... + @nnx.scan(in_axes=0, out_axes=(nnx.Carry, 0)) + def loop(a, b): + ... - with self.assertRaisesRegex( - ValueError, - 'If one of in_axes or out_axes has Carry, the other must also have Carry', - ): + with self.assertRaisesRegex( + ValueError, + 'If one of in_axes or out_axes has Carry, the other must also have' + ' Carry', + ): - @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0) - def loop(a, b): ... + @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=0) + def loop(a, b): + ... - def test_double_carry_error(self): + @parameterized.parameters(True, False) + def test_double_carry_error(self, graph_updates): with self.assertRaisesRegex( ValueError, 'Found multiple Carry definitions', ): - @nnx.scan(in_axes=(nnx.Carry, nnx.Carry)) - def loop(a, b): ... + @nnx.scan(in_axes=(nnx.Carry, nnx.Carry), graph_updates=graph_updates) + def loop(a, b): + ... - def test_broadcast_in_output_error(self): + @parameterized.parameters(True, False) + def test_broadcast_in_output_error(self, graph_updates): with self.assertRaisesRegex( ValueError, 'Cannot broadcast output state', ): - @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, None)) - def loop(a, b): ... + @nnx.scan( + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, None), + graph_updates=graph_updates, + ) + def loop(a, b): + ... + def test_broadcast_in_output_state_axes_error(self): with self.assertRaisesRegex( ValueError, 'Cannot broadcast output state. Got StateAxes', ): @nnx.scan( - in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, nnx.StateAxes({...: None})) + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, nnx.StateAxes({...: None})), + graph=True, + graph_updates=True, ) - def loop(a, b): ... + def loop(a, b): + ... @parameterized.parameters( (True, False), (False, False), @@ -2880,13 +3734,13 @@ def test_out_axes(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5, graph=True, graph_updates=True) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.node = nnx.BatchStat(jnp.ones((2,))) - @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=(nnx.Carry, 1, 2)) + @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=(nnx.Carry, 1, 2), graph=True, graph_updates=True) def __call__(self, x: jax.Array): x = self.linear(x) x = nnx.gelu(x) @@ -2905,17 +3759,47 @@ def __call__(self, x: jax.Array): assert y1.shape == (1, 5, 3) assert y2.shape == (1, 3, 5) + @parameterized.parameters(True, False) + def test_out_axes_functional(self, graph_mode): + class MLP(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.BatchStat(jnp.ones((2,))) + + def __call__(self, x: jax.Array): + x = self.linear(x) + x = nnx.gelu(x) + return x, x, x + + @nnx.split_rngs(splits=5, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs): + return MLP(rngs=rngs) + + @nnx.scan(in_axes=(0, nnx.Carry), out_axes=(nnx.Carry, 1, 2), graph=graph_mode, graph_updates=False) + def forward(module, x): + return module(x) + + module = create(nnx.Rngs(0)) + + x = jnp.ones((1, 3)) + c, y1, y2 = forward(module, x) + + assert c.shape == (1, 3) + assert y1.shape == (1, 5, 3) + assert y2.shape == (1, 3, 5) + def test_in_axes_simple(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): - @nnx.vmap(in_axes=(state_axes, 0)) + @nnx.vmap(in_axes=(state_axes, 0), graph=True, graph_updates=True) def __init__(self, key: jax.Array): rngs = nnx.Rngs(key) self.linear = nnx.Linear(3, 3, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=nnx.Carry) + @nnx.scan(in_axes=(state_axes, nnx.Carry), out_axes=nnx.Carry, graph=True, graph_updates=True) def __call__(self, x: jax.Array): x = self.linear(x) x = nnx.gelu(x) @@ -2929,17 +3813,46 @@ def __call__(self, x: jax.Array): assert y.shape == (1, 3) + @parameterized.parameters(True, False) + def test_in_axes_simple_functional(self, graph_mode): + class MLP(nnx.Module): + def __init__(self, key: jax.Array): + rngs = nnx.Rngs(key) + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array): + x = self.linear(x) + x = nnx.gelu(x) + return x + + @nnx.vmap(in_axes=0, graph=graph_mode, graph_updates=False) + def create(key): + return MLP(key=key) + + @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry, graph=graph_mode, graph_updates=False) + def forward(module, x): + return module(x) + + key = jax.random.split(jax.random.key(0), 5) + module = create(key) + + x = jnp.ones((1, 3)) + y = forward(module, x) + + assert y.shape == (1, 3) + def test_in_axes(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState, nnx.Intermediate): 0, ...: None}) class MLP(nnx.Module): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes, state_axes)) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), graph=True, graph_updates=True) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - @nnx.scan(in_axes=(state_axes, nnx.Carry, 0)) + @nnx.scan(in_axes=(state_axes, nnx.Carry, 0), graph=True, graph_updates=True) def __call__( self, x: jax.Array, a: jax.Array ) -> tp.Tuple[jax.Array, None]: @@ -2965,18 +3878,51 @@ def __call__( assert intermediates['data'][0].shape == (5, 1, 3) + @parameterized.parameters(True, False) + def test_in_axes_functional(self, graph_mode): + class MLP(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array, a: jax.Array) -> tp.Tuple[jax.Array, None]: + assert x.shape == a.shape + x = x + a + x = self.linear(x) + x = nnx.gelu(x) + self.sow(nnx.Intermediate, "data", x) + return x, None + + @nnx.split_rngs(splits=5, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs): + return MLP(rngs=rngs) + + @nnx.scan(in_axes=(0, nnx.Carry, 0), graph=graph_mode, graph_updates=False) + def forward(module, x, a): + return module(x, a) + + module = create(nnx.Rngs(0)) + + x = jnp.ones((1, 3)) + a = jnp.ones((5, 1, 3)) + y, out = forward(module, x, a) + + assert y.shape == (1, 3) + assert out is None + def test_in_axes_broadcast(self): test = self state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes, state_axes)) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), graph=True, graph_updates=True) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.node = nnx.BatchStat(jnp.ones((2,))) - @nnx.scan(in_axes=(state_axes, nnx.Carry, 0, None)) + @nnx.scan(in_axes=(state_axes, nnx.Carry, 0, None), graph=True, graph_updates=True) def __call__( self, x: jax.Array, a: jax.Array, b: jax.Array ) -> tp.Tuple[jax.Array, None]: @@ -3001,51 +3947,121 @@ def __call__( self.assertEqual(y.shape, (1, 3)) self.assertIsNone(out) - def test_complex(self): - state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + @parameterized.parameters(True, False) + def test_in_axes_broadcast_functional(self, graph_mode): + test = self class MLP(nnx.Module): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes, state_axes)) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) - self.bn = nnx.BatchNorm(3, rngs=rngs) - self.dropout = nnx.Dropout(0.5, rngs=rngs) - self.node = nnx.Variable(jnp.ones((2,))) + self.node = nnx.BatchStat(jnp.ones((2,))) - @nnx.scan(in_axes=(state_axes, nnx.Carry)) - def __call__(self, x: jax.Array): + def __call__(self, x: jax.Array, a: jax.Array, b: jax.Array) -> tp.Tuple[jax.Array, None]: + test.assertEqual(x.shape, a.shape) + test.assertEqual(x.shape, b.shape) + x = x + a + b x = self.linear(x) - x = self.bn(x) - x = self.dropout(x) x = nnx.gelu(x) return x, None - module = MLP(rngs=nnx.Rngs(0)) - module.set_attributes(deterministic=False, use_running_average=False) + @nnx.split_rngs(splits=5, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs): + return MLP(rngs=rngs) - assert module.linear.kernel.shape == (5, 3, 3) - assert module.linear.bias.shape == (5, 3) - assert module.node.shape == (2,) + @nnx.scan(in_axes=(0, nnx.Carry, 0, None), graph=graph_mode, graph_updates=False) + def forward(module, x, a, b): + return module(x, a, b) + + module = create(nnx.Rngs(0)) x = jnp.ones((1, 3)) - y, _ = module(x) + a = jnp.ones((5, 1, 3)) + b = jnp.ones((1, 3)) + y, out = forward(module, x, a, b) - assert y.shape == (1, 3) + self.assertEqual(y.shape, (1, 3)) + self.assertIsNone(out) + + def test_complex(self): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + class MLP(nnx.Module): + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), graph=True, graph_updates=True) + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + @nnx.scan(in_axes=(state_axes, nnx.Carry), graph=True, graph_updates=True) + def __call__(self, x: jax.Array): + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + module = MLP(rngs=nnx.Rngs(0)) + module.set_attributes(deterministic=False, use_running_average=False) + + assert module.linear.kernel.shape == (5, 3, 3) + assert module.linear.bias.shape == (5, 3) + assert module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y, _ = module(x) + + assert y.shape == (1, 3) + + @parameterized.parameters(True, False) + def test_complex_functional(self, graph_mode): + class MLP(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.BatchStat(jnp.ones((2,))) + + def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + @nnx.split_rngs(splits=5, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs): + m = MLP(rngs=rngs) + m.set_attributes(deterministic=False, use_running_average=False) + return m + + @nnx.scan(in_axes=(0, nnx.Carry), graph=graph_mode, graph_updates=False) + def forward(module, x): + return module(x) + + module = create(nnx.Rngs(0)) + + x = jnp.ones((1, 3)) + y, _ = forward(module, x) + + assert y.shape == (1, 3) def test_complex_view(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class MLP(nnx.Module): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes, state_axes)) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), graph=True, graph_updates=True) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - @nnx.scan(in_axes=(state_axes, nnx.Carry)) + @nnx.scan(in_axes=(state_axes, nnx.Carry), graph=True, graph_updates=True) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) @@ -3065,20 +4081,53 @@ def __call__(self, x: jax.Array): assert y.shape == (1, 3) + @parameterized.parameters(True, False) + def test_complex_view_functional(self, graph_mode): + class MLP(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + @nnx.split_rngs(splits=5, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs): + return MLP(rngs=rngs) + + @nnx.scan(in_axes=(0, nnx.Carry), graph=graph_mode, graph_updates=False) + def forward(module, x): + return module(x) + + module = create(nnx.Rngs(0)) + new_module = nnx.view(module, deterministic=False, use_running_average=False) + + x = jnp.ones((1, 3)) + y, _ = forward(new_module, x) + + assert y.shape == (1, 3) + def test_complex_broadcast_dropout(self): state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) class MLP(nnx.Module): - @nnx.split_rngs(splits=5, only='params') - @nnx.vmap(in_axes=(state_axes, state_axes)) + @nnx.split_rngs(splits=5, only='params', graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), graph=True, graph_updates=True) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - @nnx.split_rngs(splits=5, only='params') - @nnx.scan(in_axes=(state_axes, nnx.Carry)) + @nnx.split_rngs(splits=5, only='params', graph=True, graph_updates=True) + @nnx.scan(in_axes=(state_axes, nnx.Carry), graph=True, graph_updates=True) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) @@ -3098,20 +4147,59 @@ def __call__(self, x: jax.Array): assert y.shape == (1, 3) + @parameterized.parameters((False,)) + def test_complex_broadcast_dropout_functional(self, graph_mode): + + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: + x = nnx.gelu(self.dropout(self.bn(self.linear(x)))) + return x, None + + rngs = nnx.Rngs(params=0, dropout=1) + dummy_module = Block(rngs) + (module_axes, rngs_axes) = nnx.prefix((dummy_module, rngs), {(nnx.Param, 'params', 'dropout'): 0, ...: None}, graph=graph_mode) + + @nnx.with_rngs(split={'params': 5}, replicate={'dropout': 5}, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=(rngs_axes,), out_axes=module_axes, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs: nnx.Rngs): + return Block(rngs=rngs) + + module = create(rngs) + + @nnx.scan(in_axes=(module_axes, nnx.Carry), graph=graph_mode, graph_updates=False) + def forward(module, x): + module = nnx.with_attributes(module, deterministic=False, use_running_average=False) + return module(x) + + assert module.linear.kernel.shape == (5, 3, 3) + assert module.linear.bias.shape == (5, 3) + assert module.node.shape == (2,) + + x = jnp.ones((1, 3)) + y, _ = forward(module, x) + + assert y.shape == (1, 3) + def test_complex_broadcast_dropout_view(self): state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None}) class MLP(nnx.Module): - @nnx.split_rngs(splits=5, only='params') - @nnx.vmap(in_axes=(state_axes, state_axes)) + @nnx.split_rngs(splits=5, only='params', graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), graph=True, graph_updates=True) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - @nnx.split_rngs(splits=5, only='params') - @nnx.scan(in_axes=(state_axes, nnx.Carry)) + @nnx.split_rngs(splits=5, only='params', graph=True, graph_updates=True) + @nnx.scan(in_axes=(state_axes, nnx.Carry), graph=True, graph_updates=True) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) @@ -3135,8 +4223,8 @@ def test_complex_decorator(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class Block(nnx.Module): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5, graph=True, graph_updates=True) def __init__(self, rngs: nnx.Rngs): self.d = 3 self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -3144,7 +4232,7 @@ def __init__(self, rngs: nnx.Rngs): self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - @nnx.scan(in_axes=(state_axes, nnx.Carry)) + @nnx.scan(in_axes=(state_axes, nnx.Carry), graph=True, graph_updates=True) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) @@ -3166,12 +4254,82 @@ def __call__(self, x: jax.Array): assert y.shape == (1, 3) assert out is None + @parameterized.parameters(True, False) + def test_complex_decorator_functional(self, graph_mode): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.d = 3 + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + @nnx.split_rngs(splits=5, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs): + m = Block(rngs=rngs) + m.set_attributes(deterministic=False, use_running_average=False) + return m + + @nnx.scan(in_axes=(0, nnx.Carry), graph=graph_mode, graph_updates=False) + def forward(module, x): + return module(x) + + module = create(nnx.Rngs(0)) + + x = jnp.ones((1, 3)) + y, _ = forward(module, x) + + assert y.shape == (1, 3) + + @parameterized.parameters(True, False) + def test_complex_broadcast_dropout_view_functional(self, graph_mode): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.d = 3 + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + @nnx.split_rngs(splits=5, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs): + m = Block(rngs=rngs) + m.set_attributes(deterministic=False, use_running_average=False) + return m + + @nnx.scan(in_axes=(0, nnx.Carry), graph=graph_mode, graph_updates=False) + def forward(module, x): + return module(x) + + module = create(nnx.Rngs(0)) + + x = jnp.ones((1, 3)) + y, _ = forward(module, x) + + assert y.shape == (1, 3) + def test_complex_decorator_view(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class Block(nnx.Module): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5, graph=True, graph_updates=True) def __init__(self, rngs: nnx.Rngs): self.d = 3 self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -3179,7 +4337,7 @@ def __init__(self, rngs: nnx.Rngs): self.dropout = nnx.Dropout(0.5, rngs=rngs) self.node = nnx.Variable(jnp.ones((2,))) - @nnx.scan(in_axes=(state_axes, nnx.Carry)) + @nnx.scan(in_axes=(state_axes, nnx.Carry), graph=True, graph_updates=True) def __call__(self, x: jax.Array): x = self.linear(x) x = self.bn(x) @@ -3201,6 +4359,41 @@ def __call__(self, x: jax.Array): assert y.shape == (1, 3) assert out is None + @parameterized.parameters(True, False) + def test_complex_decorator_view_functional(self, graph_mode): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.d = 3 + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) + self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: + x = self.linear(x) + x = self.bn(x) + x = self.dropout(x) + x = nnx.gelu(x) + return x, None + + @nnx.split_rngs(splits=5, graph=graph_mode, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph_mode, graph_updates=False) + def create(rngs): + return Block(rngs=rngs) + + @nnx.scan(in_axes=(0, nnx.Carry), graph=graph_mode, graph_updates=False) + def forward(module, x): + return module(x) + + module = create(nnx.Rngs(0)) + new_module = nnx.view(module, deterministic=False, use_running_average=False) + + assert new_module.d == 3 + + x = jnp.ones((1, 3)) + y, _ = forward(new_module, x) + + assert y.shape == (1, 3) def test_scan_with_sharding(self): test = self @@ -3208,10 +4401,13 @@ def test_scan_with_sharding(self): transform_metadata = {nnx.PARTITION_NAME: 'layers'} class MLP(nnx.Module): - @nnx.split_rngs(splits=5) + + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) @nnx.vmap( - in_axes=(state_axes, state_axes), - transform_metadata=transform_metadata, + in_axes=(state_axes, state_axes), + transform_metadata=transform_metadata, + graph=True, + graph_updates=True, ) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear( @@ -3227,7 +4423,10 @@ def __init__(self, rngs: nnx.Rngs): ) @nnx.scan( - in_axes=(state_axes, nnx.Carry), transform_metadata=transform_metadata + in_axes=(state_axes, nnx.Carry), + transform_metadata=transform_metadata, + graph=True, + graph_updates=True, ) def __call__(self, x: jax.Array): x = self.linear(x) @@ -3258,62 +4457,157 @@ def __call__(self, x: jax.Array): self.assertEqual(m.linear.bias.shape, (5, 3)) self.assertEqual(m.linear.bias.out_sharding, ('layers', 'dout')) - def test_cache_tracing_simple(self): + @parameterized.parameters(True, False) + def test_scan_with_sharding_functional(self, graph): + test = self + + class MLP(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear( + 3, + 3, + kernel_init=nnx.with_metadata( + nnx.initializers.lecun_normal(), out_sharding=('din', 'dout') + ), + bias_init=nnx.with_metadata( + nnx.initializers.zeros_init(), out_sharding=('dout',) + ), + rngs=rngs, + ) + + def __call__(self, x: jax.Array): + x = self.linear(x) + return x + + mesh = jax.make_mesh( + (1, 1, 1), + ('layers', 'din', 'dout'), + axis_types=(jax.sharding.AxisType.Auto,) * 3, + ) + + @nnx.vmap( + in_axes=0, + out_axes=0, + graph=graph, + graph_updates=False, + ) + @nnx.transform_metadata( + in_axes=0, + out_axes=0, + partition='layers', + graph=graph, + ) + def init_fn(rngs): + return MLP(rngs) + + with jax.set_mesh(mesh): + m = init_fn(nnx.Rngs(0).split(5)) + + # verify shapes outside + self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(m.linear.kernel.out_sharding, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.bias.shape, (5, 3)) + self.assertEqual(m.linear.bias.out_sharding, ('layers', 'dout')) + + x = jnp.ones((1, 3)) + + @nnx.scan( + in_axes=(0, nnx.Carry), + out_axes=nnx.Carry, + graph=graph, + graph_updates=False, + ) + @nnx.transform_metadata( + in_axes=(0, nnx.Carry), + out_axes=nnx.Carry, + partition='layers', + graph=graph, + ) + def call_fn(m, x): + y = m(x) + # test sharding layer axes is not present inside scan + test.assertEqual(m.linear.kernel.shape, (3, 3)) + test.assertEqual(m.linear.kernel.out_sharding, ('din', 'dout')) + test.assertEqual(m.linear.bias.shape, (3,)) + test.assertEqual(m.linear.bias.out_sharding, ('dout',)) + return y + + with jax.set_mesh(mesh): + y = call_fn(m, x) + + # verify shapes after call + self.assertEqual(m.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(m.linear.kernel.out_sharding, ('layers', 'din', 'dout')) + self.assertEqual(m.linear.bias.shape, (5, 3)) + self.assertEqual(m.linear.bias.out_sharding, ('layers', 'dout')) + + @parameterized.parameters(True, False) + def test_cache_tracing_simple(self, graph_updates): n = 0 x = jnp.arange(5) count = jnp.array(0) - @nnx.scan + @nnx.scan(graph_updates=graph_updates) def f(count, x): nonlocal n n += 1 return count + 1, x**2 count, y = f(count, x) - assert n == 1 - assert count == 5 + self.assertEqual(n, 1) + self.assertEqual(count, 5) np.testing.assert_allclose(y, x**2) count, y = f(count, x) - assert n == 1 - assert count == 10 + self.assertEqual(n, 1) + self.assertEqual(count, 10) - def test_cache_tracing_object(self): + @parameterized.parameters(True, False) + def test_cache_tracing_object(self, graph_updates): n = 0 x = jnp.arange(5) count = jnp.array(0) class Foo(nnx.Pytree): - @nnx.split_rngs(splits=5) - @nnx.vmap(axis_size=5) def __init__(self, rngs: nnx.Rngs): - self.x = nnx.Param(jax.random.normal(rngs(), shape=(3,))) + self.x = nnx.Param(rngs.normal((3,))) + + @nnx.split_rngs(splits=5) + @nnx.vmap(graph_updates=graph_updates) + def create_foo(rngs: nnx.Rngs): + return Foo(rngs) - foo = Foo(rngs=nnx.Rngs(0)) - assert foo.x.shape == (5, 3) + foo = create_foo(nnx.Rngs(0)) + self.assertEqual(foo.x.shape, (5, 3)) - @nnx.scan(in_axes=(nnx.Carry, 0, 0)) + @nnx.scan(in_axes=(nnx.Carry, 0, 0), graph_updates=graph_updates) def f(count, x, foo): nonlocal n n += 1 - assert foo.x.shape == (3,) + self.assertEqual(foo.x.shape, (3,)) return count + 1, x**2 count, y = f(count, x, foo) - assert n == 1 - assert count == 5 + self.assertEqual(n, 1) + self.assertEqual(count, 5) np.testing.assert_allclose(y, x**2) count, y = f(count, x, foo) - assert n == 1 - assert count == 10 + self.assertEqual(n, 1) + self.assertEqual(count, 10) def test_scan_broadcast_keys(self): params_key = jax.random.split(jax.random.key(0), 3) rngs = nnx.Rngs(params=params_key, dropout=1) state_axes = nnx.StateAxes({'params': 0, ...: None}) - @nnx.scan(in_axes=(nnx.Carry, state_axes), length=3) + @nnx.scan( + in_axes=(nnx.Carry, state_axes), + length=3, + graph=True, + graph_updates=True, + ) def f(_, rngs: nnx.Rngs): param_key = rngs.params() dropout_key = rngs.dropout() @@ -3326,6 +4620,26 @@ def f(_, rngs: nnx.Rngs): assert jnp.equal(dropout_keys[0], dropout_keys[1]) assert jnp.equal(dropout_keys[1], dropout_keys[2]) + @parameterized.parameters(True, False) + def test_scan_broadcast_keys_functional(self, graph): + rngs = nnx.Rngs(params=0, dropout=1).split({'params': 3}).replicate({'dropout': 3}) + rngs_axes = nnx.prefix(rngs, {...: 0}, graph=graph) + + @nnx.scan( + in_axes=(rngs_axes,), out_axes=0, graph=graph, graph_updates=False + ) + def f(rngs): + param_key = rngs.params() + dropout_key = rngs.dropout() + return param_key, dropout_key + + param_keys, dropout_keys = f(rngs) + + assert jnp.not_equal(param_keys[0], param_keys[1]) + assert jnp.not_equal(param_keys[1], param_keys[2]) + assert jnp.equal(dropout_keys[0], dropout_keys[1]) + assert jnp.equal(dropout_keys[1], dropout_keys[2]) + def test_rnn_example(self): class RNNCell(nnx.Module): def __init__(self, input_size, hidden_size, rngs): @@ -3350,7 +4664,54 @@ def initial_state(self, batch_size: int): def rnn_forward(cell: RNNCell, x: jax.Array): carry = cell.initial_state(x.shape[0]) - @nnx.scan(in_axes=(state_axes, nnx.Carry, 1), out_axes=(nnx.Carry, 1)) + @nnx.scan( + in_axes=(state_axes, nnx.Carry, 1), + out_axes=(nnx.Carry, 1), + graph=True, + graph_updates=True, + ) + def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]: + return cell(carry, x) + + _, y = unroll(cell, carry, x) + return y + + x = jnp.ones((16, 10, 20)) + y = rnn_forward(cell, x) + + @parameterized.parameters(True, False) + def test_rnn_example_functional(self, graph): + class RNNCell(nnx.Module): + + def __init__(self, input_size, hidden_size, rngs): + self.linear = nnx.Linear( + hidden_size + input_size, hidden_size, rngs=rngs + ) + self.drop = nnx.Dropout(0.1, rngs=rngs) + self.hidden_size = hidden_size + + def __call__(self, carry, x) -> tuple[jax.Array, jax.Array]: + carry = self.drop(carry) # recurrent dropout + x = nnx.relu(self.linear(jnp.concatenate([carry, x], axis=-1))) + return x, x + + def initial_state(self, batch_size: int): + return jnp.zeros((batch_size, self.hidden_size)) + + cell = RNNCell(20, 20, nnx.Rngs(params=0, dropout=1)) + + cell_axes = nnx.prefix(cell, {'dropout': 0, ...: None}, graph=graph) + + def rnn_forward(cell: RNNCell, x: jax.Array): + carry = cell.initial_state(x.shape[0]) + + @nnx.with_rngs(replicate={'dropout': 10}) + @nnx.scan( + in_axes=(cell_axes, nnx.Carry, 1), + out_axes=(nnx.Carry, 1), + graph=graph, + graph_updates=False, + ) def unroll(cell: RNNCell, carry, x) -> tuple[jax.Array, jax.Array]: return cell(carry, x) @@ -3389,7 +4750,7 @@ def _step2(self, state: tuple[CarryAsPytree, jax.Array, CarryAsPytree]): state[2].data = new_data2 return (state[0], out, state[2]) - @nnx.jit(static_argnames=("method")) + @nnx.jit(static_argnames=("method"), graph=True, graph_updates=True) def __call__(self, state, method): state_axes = nnx.StateAxes({nnx.Intermediate: 0, ...: nnx.Carry}) state_final = nnx.scan( @@ -3397,6 +4758,8 @@ def __call__(self, state, method): in_axes=(state_axes, nnx.Carry), out_axes=nnx.Carry, length=self.num_steps, + graph=True, + graph_updates=True, )(self, state) return state_final @@ -3427,28 +4790,106 @@ def __call__(self, state, method): intermediates['data2'][0], 11.0 + jnp.arange(num_steps) ) - def test_broadcast_variable_mutation_rejected(self): - v = nnx.Variable(jnp.array(1.0)) + @parameterized.parameters(True, False) + def test_carry_pytree_sow_functional(self, graph): + class CarryAsPytree(nnx.Pytree): + def __init__(self, data: jax.Array): + self.data = data + + class Model(nnx.Module): + def __init__(self, num_steps): + self.fc = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + self.num_steps = num_steps + + def _step(self, state): + new_data = state.data + 1 + self.sow(nnx.Intermediate, "data", new_data) + state.data = new_data + return state + + def _step2(self, state: tuple[CarryAsPytree, jax.Array, CarryAsPytree]): + out = self.fc(state[1]) + + new_data1 = state[0].data + 1 + new_data2 = state[2].data + 1 + + self.sow(nnx.Intermediate, "data1", new_data1) + self.sow(nnx.Intermediate, "data2", new_data2) + + state[0].data = new_data1 + state[2].data = new_data2 + return (state[0], out, state[2]) + + @nnx.jit(static_argnames=("method"), graph=graph, graph_updates=False) + def __call__(self, state, method): + state_axes = nnx.prefix( + self, {nnx.Intermediate: 0, ...: None}, graph=graph + ) + state_final = nnx.scan( + method, + in_axes=(state_axes, nnx.Carry), + out_axes=nnx.Carry, + length=self.num_steps, + graph=graph, + graph_updates=False, + )(self, state) + + return state_final + + num_steps = 5 + model = Model(num_steps=num_steps) + carry = CarryAsPytree(data=jnp.array(0.0)) + carry_final, intermediates = nnx.capture(model, nnx.Intermediate)(carry, method=Model._step) + self.assertEqual(carry_final.data, num_steps) + np.testing.assert_array_equal( + intermediates['data'][0], 1.0 + jnp.arange(num_steps) + ) + + carry = ( + CarryAsPytree(data=jnp.array(0.0)), + jnp.ones((num_steps, 10)), + CarryAsPytree(data=jnp.array(10.0)) + ) + + carry_final, intermediates = nnx.capture(model, nnx.Intermediate)(carry, method=Model._step2) + + self.assertEqual(carry_final[0].data, num_steps) + self.assertEqual(carry_final[2].data, 10 + num_steps) + np.testing.assert_array_equal( + intermediates['data1'][0], 1.0 + jnp.arange(num_steps) + ) + np.testing.assert_array_equal( + intermediates['data2'][0], 11.0 + jnp.arange(num_steps) + ) + + def test_broadcast_variable_mutation(self): + v = nnx.Variable(jnp.array(1)) @nnx.scan( in_axes=(None, nnx.Carry, 0), graph=False, graph_updates=False, ) def fn(v, carry, x): - v[...] = v[...] + 1.0 + v[...] = v[...] + 1 return carry + x, carry - with self.assertRaisesRegex(ValueError, 'Broadcast.*mutated'): - fn(v, jnp.array(0.0), jnp.arange(3.0)) + carry, ys = fn(v, jnp.array(0), jnp.arange(3)) + # v is broadcast (None axis), mutated each iteration: 1 -> 2 -> 3 -> 4 + self.assertEqual(v[...], 4) + # step 0: carry=0, x=0 -> carry=0, y=0 + # step 1: carry=0, x=1 -> carry=1, y=0 + # step 2: carry=1, x=2 -> carry=3, y=1 + np.testing.assert_allclose(carry, 3) + np.testing.assert_allclose(ys, jnp.array([0, 0, 1])) def test_broadcast_out_axes_rejected(self): - @nnx.scan( - in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, None), - graph=False, graph_updates=False, - ) - def fn(carry, x): - return carry + x, jnp.zeros(3) - with self.assertRaisesRegex(ValueError, 'broadcast'): + @nnx.scan( + in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, None), + graph=False, graph_updates=False, + ) + def fn(carry, x): + return carry + x, jnp.zeros(3) + fn(jnp.array(0.0), jnp.arange(3.0)) def test_scan_inconsistent_aliasing(self): @@ -3473,9 +4914,113 @@ def test_scan_input_output_aliasing(self): def f(carry): return carry - with self.assertRaisesRegex(ValueError, 'does not support returning input Variables as outputs'): + with self.assertRaisesRegex(ValueError, 'does not support Variable aliasing'): f(v) + def test_scan_carry_and_scan(self): + def cumsum(carry, x): + carry = carry + x + return carry, carry + + final_carry, ys = nnx.scan( + cumsum, + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, 0), + graph=False, + )(jnp.array(0.0), jnp.arange(5.0)) + np.testing.assert_allclose(final_carry, 10.0) + np.testing.assert_allclose(ys, jnp.array([0., 1., 3., 6., 10.])) + + def test_scan_pytree_carry(self): + def dict_scan(carry, x): + carry = {'a': carry['a'] + x['a'], 'b': carry['b'] + x['b']} + return carry, carry + + xs = {'a': jnp.arange(3.0), 'b': jnp.ones(3)} + init = {'a': jnp.array(0.0), 'b': jnp.array(0.0)} + final_carry, _ = nnx.scan( + dict_scan, + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, 0), + graph=False, + )(init, xs) + np.testing.assert_allclose(final_carry['a'], 3.0) + np.testing.assert_allclose(final_carry['b'], 3.0) + + def test_scan_reverse(self): + def cumsum(carry, x): + carry = carry + x + return carry, carry + + final_carry, _ = nnx.scan( + cumsum, + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, 0), + reverse=True, + graph=False, + )(jnp.array(0.0), jnp.arange(5.0)) + np.testing.assert_allclose(final_carry, 10.0) + + def test_scan_axis_1(self): + def cumsum(carry, x): + carry = carry + x + return carry, carry + + x = jnp.arange(10.0).reshape((2, 5)) + final_carry, ys = nnx.scan( + cumsum, + in_axes=(nnx.Carry, 1), + out_axes=(nnx.Carry, 1), + graph=False, + )(jnp.zeros(2), x) + np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) + expected_ys = jnp.array([ + [0., 1., 3., 6., 10.], + [5., 11., 18., 26., 35.] + ]) + np.testing.assert_allclose(ys, expected_ys) + + def test_scan_axis_negative_1(self): + def cumsum(carry, x): + carry = carry + x + return carry, carry + + x = jnp.arange(10.0).reshape((2, 5)) + final_carry, ys = nnx.scan( + cumsum, + in_axes=(nnx.Carry, -1), + out_axes=(nnx.Carry, -1), + graph=False, + )(jnp.zeros(2), x) + np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) + expected_ys = jnp.array([ + [0., 1., 3., 6., 10.], + [5., 11., 18., 26., 35.] + ]) + np.testing.assert_allclose(ys, expected_ys) + + def test_scan_different_in_out_axes(self): + def cumsum(carry, x): + carry = carry + x + return carry, carry + + x = jnp.arange(10.0).reshape((2, 5)) + final_carry, ys = nnx.scan( + cumsum, + in_axes=(nnx.Carry, 1), + out_axes=(nnx.Carry, 0), + graph=False, + )(jnp.zeros(2), x) + np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) + expected_ys = jnp.array([ + [0., 5.], + [1., 11.], + [3., 18.], + [6., 26.], + [10., 35.] + ]) + np.testing.assert_allclose(ys, expected_ys) + class TestRemat(parameterized.TestCase): @parameterized.parameters( @@ -3539,25 +5084,71 @@ def test_remat_with_scan_decorator(self): state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) class ScanLinear(nnx.Module): - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes, state_axes), axis_size=5) + + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap( + in_axes=(state_axes, state_axes), + axis_size=5, + graph=True, + graph_updates=True, + ) def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) - @nnx.scan(in_axes=(state_axes, nnx.Carry)) - @nnx.remat - def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: - x = self.linear(x) - return x, None + @nnx.scan( + in_axes=(state_axes, nnx.Carry), + out_axes=nnx.Carry, + graph=True, + graph_updates=True, + ) + @nnx.remat(graph=True, graph_updates=True) + def __call__(self, x: jax.Array) -> jax.Array: + return self.linear(x) m = ScanLinear(nnx.Rngs(0)) assert m.linear.kernel.shape == (5, 3, 3) assert m.linear.bias.shape == (5, 3) - y, _ = m(jnp.ones((1, 3))) + y = m(jnp.ones((1, 3))) assert y.shape == (1, 3) + @parameterized.parameters(True, False) + def test_remat_with_scan_decorator_functional(self, graph): + state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) + + class ScanLinear(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + @nnx.split_rngs(splits=5, graph=graph, graph_updates=False) + @nnx.vmap(in_axes=0, axis_size=5, graph=graph, graph_updates=False) + def create(rngs): + return nnx.Linear(3, 3, rngs=rngs) + + self.linear = create(rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + @nnx.scan( + in_axes=(0, nnx.Carry), + out_axes=nnx.Carry, + graph=graph, + graph_updates=False, + ) + @nnx.remat(graph=graph, graph_updates=False) + def forward(linear, x): + return linear(x) + + return forward(self.linear, x) + + m = ScanLinear(nnx.Rngs(0)) + + assert m.linear.kernel.shape == (5, 3, 3) + + x = jnp.ones((2, 3)) + y = m(x) + + assert y.shape == (2, 3) + @parameterized.parameters( (True, False), (False, False), @@ -3717,9 +5308,12 @@ def forward(model, x): y = forward(model, x) assert y.shape == (5, 1, 3) - def test_basic(self): + @parameterized.parameters( + (True, True), (True, False), (False, False), + ) + def test_basic(self, graph, graph_updates): @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=0, out_axes=0, axis_size=5) + @nnx.vmap(in_axes=0, out_axes=0, axis_size=5, graph=graph, graph_updates=graph_updates) def create_block(rngs: nnx.Rngs): return nnx.Linear(2, 3, rngs=rngs) @@ -3730,7 +5324,7 @@ def create_block(rngs: nnx.Rngs): self.assertEqual(block.kernel.shape, (5, 2, 3)) self.assertEqual(rngs.default.count[...], 1) - @nnx.vmap(in_axes=(0, 1), out_axes=1) + @nnx.vmap(in_axes=(0, 1), out_axes=1, graph=graph, graph_updates=graph_updates) def forward(block: nnx.Linear, x): self.assertEqual(block.kernel.shape, (2, 3)) self.assertEqual(block.bias.shape, (3,)) @@ -3742,9 +5336,12 @@ def forward(block: nnx.Linear, x): self.assertEqual(y.shape, (3, 5)) - def test_basic_variables(self): + @parameterized.parameters( + (True, True), (True, False), (False, False), + ) + def test_basic_variables(self, graph, graph_updates): @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=0, out_axes=0, axis_size=5) + @nnx.vmap(in_axes=0, out_axes=0, axis_size=5, graph=graph, graph_updates=graph_updates) def create_block(rngs: nnx.Rngs): w = nnx.Param(jax.random.normal(rngs(), (2, 3))) b = nnx.Param(jax.random.normal(rngs(), (3,))) @@ -3757,7 +5354,7 @@ def create_block(rngs: nnx.Rngs): self.assertEqual(b.shape, (5, 3)) self.assertEqual(rngs.default.count[...], 1) - @nnx.vmap(in_axes=(0, 0, 1), out_axes=1) + @nnx.vmap(in_axes=(0, 0, 1), out_axes=1, graph=graph, graph_updates=graph_updates) def forward(w, b, x): self.assertEqual(w.shape, (2, 3)) self.assertEqual(b.shape, (3,)) @@ -3782,48 +5379,122 @@ def __call__(self, x: jax.Array) -> jax.Array: return x @nnx.vmap( - in_axes=0, - out_axes=nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), + in_axes=0, + out_axes=nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), + graph=True, + graph_updates=True, ) def create_block(rngs: nnx.Rngs): - rngs = nnx.clone(rngs) return Block(rngs) rngs = nnx.Rngs(0) initial_key = rngs.default.key[...] - backups = nnx.split_rngs(rngs, splits=5) + backups = nnx.split_rngs(rngs, splits=5, graph=True, graph_updates=True) module = create_block(rngs) nnx.restore_rngs(backups) - assert rngs.default.count[...] == 1 - assert rngs.default.key[...] == initial_key - assert not jnp.allclose( - module.linear.kernel[0], - module.linear.kernel[1], + self.assertEqual(rngs.default.count[...], 1) + self.assertEqual(rngs.default.key[...], initial_key) + self.assertFalse( + jnp.allclose( + module.linear.kernel[0], + module.linear.kernel[1], + ) ) - assert module.linear.kernel.shape == (5, 3, 3) - assert module.linear.bias.shape == (5, 3) + self.assertEqual(module.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(module.linear.bias.shape, (5, 3)) x = jnp.ones((5, 1, 3)) @nnx.vmap( - in_axes=(nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), 0), + in_axes=(nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}), 0), + graph=True, + graph_updates=True, ) def forward_block(module, x): return module(x) - backups = nnx.split_rngs(rngs, splits=5) + backups = nnx.split_rngs(rngs, splits=5, graph=True, graph_updates=True) y = forward_block(module, x) nnx.restore_rngs(backups) - assert y.shape == (5, 1, 3) - assert rngs.default.count[...] == 2 - assert rngs.default.key[...] == initial_key + self.assertEqual(y.shape, (5, 1, 3)) + self.assertEqual(rngs.default.count[...], 2) + self.assertEqual(rngs.default.key[...], initial_key) y2 = forward_block(module, x) - assert not jnp.allclose(y, y2) + self.assertFalse(jnp.allclose(y, y2)) + + @parameterized.parameters(True, False) + def test_state_axes_functional(self, graph): + + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.relu(x) + x = self.dropout(x) + return x + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key[...] + + vec_filter = (nnx.Param, nnx.RngState) + unb_filter = ... + + abs_block = nnx.eval_shape(lambda: Block(nnx.Rngs(0))) + model_axes = nnx.prefix( + abs_block, {vec_filter: 0, unb_filter: None}, graph=graph + ) + + @nnx.with_rngs(split=5, graph=graph, graph_updates=False) + @nnx.vmap( + in_axes=0, + out_axes=model_axes, + graph=graph, + graph_updates=False, + ) + def create_block_functional(rngs): + return Block(rngs) + + module = create_block_functional(rngs) + + self.assertEqual(rngs.default.count[...], 1) + self.assertTrue(np.all(rngs.default.key[...] == initial_key)) + + self.assertFalse( + jnp.allclose( + module.linear.kernel[0], + module.linear.kernel[1], + ) + ) + self.assertEqual(module.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(module.linear.bias.shape, (5, 3)) + + x = jnp.ones((5, 1, 3)) + + @nnx.vmap( + in_axes=(model_axes, 0), + out_axes=0, + graph=graph, + graph_updates=False, + ) + def forward_block_functional(module, x): + return module(x) + + y = forward_block_functional(module, x) + + self.assertEqual(y.shape, (5, 1, 3)) + self.assertEqual(rngs.default.count[...], 1) + self.assertTrue(np.all(rngs.default.key[...] == initial_key)) + + y2 = forward_block_functional(module, x) + self.assertFalse(jnp.allclose(y, y2)) def test_split_rngs_context_manager(self): class Block(nnx.Module): @@ -3839,7 +5510,12 @@ def __call__(self, x: jax.Array) -> jax.Array: state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) - @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) + @nnx.vmap( + in_axes=(state_axes,), + out_axes=state_axes, + graph=True, + graph_updates=True, + ) def create_block(rngs: nnx.Rngs): return Block(rngs) @@ -3859,7 +5535,7 @@ def create_block(rngs: nnx.Rngs): x = jnp.ones((5, 1, 3)) - @nnx.vmap(in_axes=(state_axes, 0)) + @nnx.vmap(in_axes=(state_axes, 0), graph=True, graph_updates=True) def forward_block(module, x): return module(x) @@ -3872,6 +5548,67 @@ def forward_block(module, x): assert not jnp.allclose(y, y2) + @parameterized.parameters(True, False) + def test_split_rngs_context_manager_functional(self, graph): + class Block(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.relu(x) + x = self.dropout(x) + return x + + vec_filter = (nnx.Param, nnx.RngState) + unb_filter = ... + + abs_block = nnx.eval_shape(lambda: Block(nnx.Rngs(0))) + model_axes = nnx.prefix( + abs_block, {vec_filter: 0, unb_filter: None}, graph=graph + ) + + @nnx.vmap( + in_axes=(0,), out_axes=model_axes, graph=graph, graph_updates=False + ) + def create_block_functional(rngs): + return Block(rngs) + + rngs = nnx.Rngs(0) + + module = create_block_functional(rngs.split(5)) + + self.assertEqual(module.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(module.linear.bias.shape, (5, 3)) + self.assertFalse( + jnp.allclose(module.linear.kernel[0], module.linear.kernel[1]) + ) + + x = jnp.ones((5, 1, 3)) + + model_axes = nnx.prefix( + module, {vec_filter: 0, unb_filter: None}, graph=graph + ) + + @nnx.vmap( + in_axes=(model_axes, 0), + out_axes=0, + graph=graph, + graph_updates=False, + ) + def forward_block_functional(module, x): + return module(x) + + y = forward_block_functional(module, x) + + self.assertEqual(y.shape, (5, 1, 3)) + + y2 = forward_block_functional(module, x) + + self.assertFalse(jnp.allclose(y, y2)) + def test_split_rngs_decorator(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): @@ -3886,8 +5623,13 @@ def __call__(self, x: jax.Array) -> jax.Array: state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) - @nnx.split_rngs(splits=5) - @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap( + in_axes=(state_axes,), + out_axes=state_axes, + graph=True, + graph_updates=True, + ) def create_block(rngs: nnx.Rngs): return Block(rngs) @@ -3905,23 +5647,188 @@ def create_block(rngs: nnx.Rngs): assert module.linear.kernel.shape == (5, 3, 3) assert module.linear.bias.shape == (5, 3) - x = jnp.ones((5, 1, 3)) + x = jnp.ones((5, 1, 3)) + + @nnx.vmap(in_axes=(state_axes, 0), graph=True, graph_updates=True) + def forward_block(module, x): + self.assertEqual(x.shape, (1, 3)) + return module(x) + + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + assert rngs.default.key[...] == initial_key + + y2 = forward_block(module, x) + + assert not jnp.allclose(y, y2) + + @parameterized.parameters(True, False) + def test_split_rngs_decorator_functional(self, graph): + graph_updates = False + + class Block(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.relu(x) + x = self.dropout(x) + return x + + vec_filter = (nnx.Param, nnx.RngState) + unb_filter = ... + + abs_block = nnx.eval_shape(lambda: Block(nnx.Rngs(0))) + model_axes = nnx.prefix( + abs_block, {vec_filter: 0, unb_filter: None}, graph=graph + ) + + @nnx.split_rngs(splits=5, graph=graph, graph_updates=graph_updates) + @nnx.vmap( + in_axes=(0,), + out_axes=model_axes, + graph=graph, + graph_updates=graph_updates, + ) + def create_block_functional(rngs): + return Block(rngs) + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key[...] + + module = create_block_functional(rngs) + + if graph and graph_updates: + self.assertEqual(rngs.default.count[...], 1) + self.assertTrue(jnp.allclose(rngs.default.key[...], initial_key)) + + self.assertEqual(module.linear.kernel.shape, (5, 3, 3)) + self.assertEqual(module.linear.bias.shape, (5, 3)) + self.assertFalse( + jnp.allclose(module.linear.kernel[0], module.linear.kernel[1]) + ) + + x = jnp.ones((5, 1, 3)) + + @nnx.vmap( + in_axes=(model_axes, 0), + out_axes=0, + graph=graph, + graph_updates=graph_updates, + ) + def forward_block_functional(module, x): + return module(x) + + y = forward_block_functional(module, x) + + self.assertEqual(y.shape, (5, 1, 3)) + + y2 = forward_block_functional(module, x) + + self.assertFalse(jnp.allclose(y, y2)) + + def test_state_axes_simple(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + state_axes = nnx.StateAxes({(nnx.BatchStat, 'dropout'): 0, ...: None}) + + @nnx.split_rngs(splits=5, only='dropout') + @nnx.vmap( + in_axes=(state_axes,), + out_axes=state_axes, + graph=True, + graph_updates=True, + ) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + rngs = nnx.Rngs(params=0, dropout=1) + module = create_block(rngs) + + assert module.linear.kernel.shape == (2, 3) + assert module.bn.scale.shape == (3,) + assert module.bn.mean.shape == (5, 3) + + @nnx.vmap( + in_axes=(state_axes, 0), out_axes=0, graph=True, graph_updates=True + ) + def forward_block(module, x): + return module(x) + + x = jnp.ones((5, 1, 2)) + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + + @parameterized.parameters(True, False) + def test_state_axes_simple_functional(self, graph): + graph_updates = False + + class Block(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.bn = nnx.BatchNorm(3, rngs=rngs) + self.dropout = nnx.Dropout(0.1, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return nnx.relu(self.dropout(self.bn(self.linear(x)))) + + rngs = nnx.Rngs(params=0, dropout=1).split({'dropout': 5}) + + vec_filter = (nnx.BatchStat, 'dropout') + unb_filter = ... + + abs_block = nnx.eval_shape(lambda: Block(nnx.Rngs(params=0, dropout=1))) + model_axes = nnx.prefix( + abs_block, {vec_filter: 0, unb_filter: None}, graph=graph + ) + + @nnx.vmap( + in_axes=(None, 0), + out_axes=model_axes, + graph=graph, + graph_updates=graph_updates, + ) + def create_block_functional(params, dropout): + return Block(nnx.Rngs(params=params, dropout=dropout)) + + module = create_block_functional(rngs.params, rngs.dropout) - @nnx.vmap(in_axes=(state_axes, 0)) - def forward_block(module, x): - self.assertEqual(x.shape, (1, 3)) - return module(x) + self.assertEqual(module.linear.kernel.shape, (2, 3)) + self.assertEqual(module.bn.scale.shape, (3,)) + self.assertEqual(module.bn.mean.shape, (5, 3)) - y = forward_block(module, x) + initial_mean = module.bn.mean[...] - assert y.shape == (5, 1, 3) - assert rngs.default.key[...] == initial_key + @nnx.vmap( + in_axes=(model_axes, 0), + out_axes=0, + graph=graph, + graph_updates=graph_updates, + ) + def forward_block_functional(module, x): + return module(x) - y2 = forward_block(module, x) + x = jnp.ones((5, 1, 2)) + y = forward_block_functional(module, x) - assert not jnp.allclose(y, y2) + self.assertEqual(y.shape, (5, 1, 3)) + # Verify that updates were tracked and applied + self.assertFalse(jnp.allclose(initial_mean, module.bn.mean[...])) - def test_state_axes_simple(self): + def test_split_rngs_decorator_simple(self): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) @@ -3933,29 +5840,46 @@ def __call__(self, x: jax.Array) -> jax.Array: state_axes = nnx.StateAxes({(nnx.BatchStat, 'dropout'): 0, ...: None}) - @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) + @nnx.split_rngs(splits=5, only='dropout', graph=True, graph_updates=True) + @nnx.vmap( + in_axes=(state_axes,), + out_axes=state_axes, + graph=True, + graph_updates=True, + ) def create_block(rngs: nnx.Rngs): return Block(rngs) rngs = nnx.Rngs(params=0, dropout=1) - nnx.split_rngs(rngs, splits=5, only='dropout') module = create_block(rngs) assert module.linear.kernel.shape == (2, 3) assert module.bn.scale.shape == (3,) assert module.bn.mean.shape == (5, 3) + assert module.dropout.rngs is not None + self.assertEqual(module.dropout.rngs.key.shape, (5,)) - @nnx.vmap(in_axes=(state_axes, 0), out_axes=0) - def forward_block(module, x): + @nnx.vmap( + in_axes=(state_axes, 0), + out_axes=0, + graph=True, + graph_updates=True, + ) + def forward_block(module: Block, x): + assert module.dropout.rngs is not None + self.assertEqual(module.dropout.rngs.key.shape, ()) return module(x) x = jnp.ones((5, 1, 2)) y = forward_block(module, x) + assert module.dropout.rngs is not None + self.assertEqual(module.dropout.rngs.key.shape, (5,)) assert y.shape == (5, 1, 3) - def test_split_rngs_decorator_simple(self): + @parameterized.parameters(True, False) + def test_split_rngs_decorator_simple_functional(self, graph): class Block(nnx.Module): def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) @@ -3965,38 +5889,60 @@ def __init__(self, rngs: nnx.Rngs): def __call__(self, x: jax.Array) -> jax.Array: return nnx.relu(self.dropout(self.bn(self.linear(x)))) - state_axes = nnx.StateAxes({(nnx.BatchStat, 'dropout'): 0, ...: None}) + vec_filter = (nnx.BatchStat, 'dropout') + unb_filter = ... + + rngs = nnx.Rngs(params=0, dropout=1) + abs_block = nnx.eval_shape(lambda: Block(nnx.clone(rngs))) + model_axes = nnx.prefix( + abs_block, {vec_filter: 0, unb_filter: None}, graph=graph + ) + rngs_axes = nnx.prefix(rngs, {'dropout': 0, ...: None}, graph=graph) @nnx.split_rngs(splits=5, only='dropout') - @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes) - def create_block(rngs: nnx.Rngs): + @nnx.vmap( + in_axes=(rngs_axes,), + out_axes=model_axes, + graph=graph, + graph_updates=False, + ) + def create_block_functional(rngs): return Block(rngs) - rngs = nnx.Rngs(params=0, dropout=1) - - module = create_block(rngs) + module = create_block_functional(rngs) - assert module.linear.kernel.shape == (2, 3) - assert module.bn.scale.shape == (3,) - assert module.bn.mean.shape == (5, 3) - assert module.dropout.rngs is not None + self.assertEqual(module.linear.kernel.shape, (2, 3)) + self.assertEqual(module.bn.scale.shape, (3,)) + self.assertEqual(module.bn.mean.shape, (5, 3)) + self.assertIsNotNone(module.dropout.rngs) self.assertEqual(module.dropout.rngs.key.shape, (5,)) - @nnx.vmap(in_axes=(state_axes, 0), out_axes=0) - def forward_block(module: Block, x): - assert module.dropout.rngs is not None + @nnx.vmap( + in_axes=(model_axes, 0), + out_axes=0, + graph=graph, + graph_updates=False, + ) + def forward_block_functional(module, x): + self.assertIsNotNone(module.dropout.rngs) self.assertEqual(module.dropout.rngs.key.shape, ()) return module(x) x = jnp.ones((5, 1, 2)) - y = forward_block(module, x) + y = forward_block_functional(module, x) - assert module.dropout.rngs is not None + self.assertIsNotNone(module.dropout.rngs) self.assertEqual(module.dropout.rngs.key.shape, (5,)) - assert y.shape == (5, 1, 3) + self.assertEqual(y.shape, (5, 1, 3)) - def test_state_axes_super_simple(self): + @parameterized.parameters( + (True, True), + (True, False), + (False, False), + ) + def test_state_axes_super_simple(self, graph, graph_updates): class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) self.bn = nnx.BatchNorm(3, rngs=rngs) @@ -4005,27 +5951,36 @@ def __init__(self, rngs: nnx.Rngs): def __call__(self, x: jax.Array) -> jax.Array: return nnx.relu(self.dropout(self.bn(self.linear(x)))) - @nnx.vmap(in_axes=0, out_axes=0) + @nnx.split_rngs(splits=5, graph=graph, graph_updates=graph_updates) + @nnx.vmap( + in_axes=0, + out_axes=0, + graph=graph, + graph_updates=graph_updates, + ) def create_block(rngs: nnx.Rngs): return Block(rngs) rngs = nnx.Rngs(0) - nnx.split_rngs(rngs, splits=5) - module = create_block(rngs) - assert module.linear.kernel.shape == (5, 2, 3) - assert module.bn.scale.shape == (5, 3) - assert module.bn.mean.shape == (5, 3) + self.assertEqual(module.linear.kernel.shape, (5, 2, 3)) + self.assertEqual(module.bn.scale.shape, (5, 3)) + self.assertEqual(module.bn.mean.shape, (5, 3)) - @nnx.vmap(in_axes=(0, 0), out_axes=0) + @nnx.vmap( + in_axes=(0, 0), + out_axes=0, + graph=graph, + graph_updates=graph_updates, + ) def forward_block(module, x): return module(x) x = jnp.ones((5, 1, 2)) y = forward_block(module, x) - assert y.shape == (5, 1, 3) + self.assertEqual(y.shape, (5, 1, 3)) def test_replicate(self): din = 3 @@ -4044,8 +5999,10 @@ def create_block(rngs: nnx.Rngs): state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None}) - @nnx.split_rngs(splits=5) - @partial(nnx.vmap, in_axes=(state_axes, 0), out_axes=0) + @nnx.split_rngs(splits=5, graph=True, graph_updates=True) + @nnx.vmap( + in_axes=(state_axes, 0), out_axes=0, graph=True, graph_updates=True + ) def forward_block(module: Block, x): return module(x) @@ -4053,55 +6010,120 @@ def forward_block(module: Block, x): module = create_block(rngs) initial_key = module.dropout.rngs.key[...] - assert module.dropout.rngs.count[...] == 0 - assert module.linear.kernel.shape == (din, dout) - assert module.linear.bias.shape == (dout,) + self.assertEqual(module.dropout.rngs.count[...], 0) + self.assertEqual(module.linear.kernel.shape, (din, dout)) + self.assertEqual(module.linear.bias.shape, (dout,)) x = jnp.ones((5, 1, din)) y = forward_block(module, x) - assert y.shape == (5, 1, dout) - assert module.dropout.rngs.count[...] == 1 + self.assertEqual(y.shape, (5, 1, dout)) + self.assertEqual(module.dropout.rngs.count[...], 1) - assert not jnp.allclose(y[0], y[1]) + self.assertFalse(jnp.allclose(y[0], y[1])) y2 = forward_block(module, x) - # dropout is working! - assert not jnp.allclose(y, y2) + self.assertFalse(jnp.allclose(y, y2)) - assert module.dropout.rngs.key[...] == initial_key + self.assertTrue(np.all(module.dropout.rngs.key[...] == initial_key)) + + @parameterized.parameters(True, False) + def test_replicate_functional(self, graph): + din = 3 + dout = 10 + + class Block(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.dropout(nnx.relu(self.linear(x))) + + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + vec_filter = nnx.RngState + unb_filter = ... + + rngs = nnx.Rngs(0) + module = create_block(rngs) - def test_consistent_aliasing_inputs(self): + self.assertEqual(module.dropout.rngs.count[...], 0) + self.assertEqual(module.linear.kernel.shape, (din, dout)) + + x = jnp.ones((5, 1, din)) + + module = nnx.split_rngs(module, splits=5, graph=graph, graph_updates=False) + model_axes = nnx.prefix( + module, {vec_filter: 0, unb_filter: None}, graph=graph + ) + + @nnx.vmap( + in_axes=(model_axes, 0), + out_axes=0, + graph=graph, + graph_updates=False, + ) + def forward_block_functional(module, x): + return module(x) + + y = forward_block_functional(module, x) + + self.assertEqual(y.shape, (5, 1, dout)) + self.assertFalse(jnp.allclose(y[0], y[1])) + + y2 = forward_block_functional(module, x) + + self.assertFalse(jnp.allclose(y, y2)) + + @parameterized.parameters( + (True, True), (True, False), (False, False), + ) + def test_consistent_aliasing_inputs(self, graph, graph_updates): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(jnp.zeros((5, 5))) m = Foo() - @nnx.vmap(in_axes=(0, 1)) + @nnx.vmap(in_axes=(0, 1), graph=graph, graph_updates=graph_updates) def f(m1, m2): pass - with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): + error_msg = ( + 'Inconsistent aliasing detected' if graph else 'Duplicate Param' + ) + with self.assertRaisesRegex(ValueError, error_msg): f(m, m) - def test_consistent_aliasing_input_output(self): + @parameterized.parameters( + (True, True), (True, False), (False, False), + ) + def test_consistent_aliasing_input_output(self, graph, graph_updates): class Foo(nnx.Module): def __init__(self): self.a = nnx.Param(jnp.zeros((2, 3))) m = Foo() - @partial(nnx.vmap, in_axes=0, out_axes=1) + @nnx.vmap(in_axes=0, out_axes=1, graph=graph, graph_updates=graph_updates) def f(m): return m - with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'): + error_msg = ( + 'Inconsistent aliasing detected' if graph_updates else 'Duplicate Param' + ) + with self.assertRaisesRegex(ValueError, error_msg): m2 = f(m) - def test_consistent_aliasing_shared(self): + @parameterized.parameters( + (True, True), (True, False), (False, False), + ) + def test_consistent_aliasing_shared(self, graph, graph_updates): class Shared(nnx.Module): def __init__(self): self.a = nnx.Param(jnp.zeros((3, 3))) @@ -4114,15 +6136,16 @@ def __init__(self, shared: Shared): m1 = Foo(shared) m2 = Foo(shared) - @nnx.vmap(in_axes=(0, 1)) + @nnx.vmap(in_axes=(0, 1), graph=graph, graph_updates=graph_updates) def f(m1, m2): pass - with self.assertRaisesRegex( - ValueError, - r'Inconsistent aliasing detected([\s\S]*)Param([\s\S]*)a:' - r' 0([\s\S]*)a: 1', - ): + error_msg = ( + r'Inconsistent aliasing detected([\s\S]*)Param([\s\S]*)a:' + r' 0([\s\S]*)a: 1' + if graph else 'Duplicate Param' + ) + with self.assertRaisesRegex(ValueError, error_msg): f(m1, m2) def test_equivalent_state_axes_mapping(self): @@ -4131,12 +6154,27 @@ def test_equivalent_state_axes_mapping(self): sa1 = nnx.StateAxes({...: 0}) sa2 = nnx.StateAxes({nnx.Param: 0}) - @nnx.vmap(in_axes=(0, sa1, sa2)) + @nnx.vmap(in_axes=(0, sa1, sa2), graph=True, graph_updates=True) def f(m1, m2, m3): pass f(m, m, m) + def test_equivalent_state_axes_mapping_functional(self): + m = nnx.eval_shape(lambda: nnx.Linear(3, 3, rngs=nnx.Rngs(0))) + + sa1, sa2, sa3 = nnx.prefix((m, m, m), {...: 0}, graph=True) + + @nnx.vmap(out_axes=(sa1, sa2, sa3), axis_size=2, graph=True) + def f(): + m = nnx.Linear(3, 3, rngs=nnx.Rngs(0)) + return m, m, m + + m1, m2, m3 = f() + + assert m1 is m2 and m2 is m3 + self.assertEqual(m1.kernel.shape, (2, 3, 3)) + def test_equivalent_state_sharding_mapping(self): m = nnx.Linear(4, 4, rngs=nnx.Rngs(0)) @@ -4148,29 +6186,55 @@ def test_equivalent_state_sharding_mapping(self): sa1 = nnx.StateSharding({...: sharding}) sa2 = nnx.StateSharding({nnx.Param: sharding}) - @nnx.jit(in_shardings=(sharding, sa1, sa2)) + @nnx.jit(in_shardings=(sharding, sa1, sa2), graph=True, graph_updates=True) def f(m1, m2, m3): pass f(m, m, m) - def test_captured_module_in_return_error(self): + assert m.kernel.sharding == sharding + + def test_equivalent_state_sharding_mapping_functional(self): + + mesh = jax.sharding.Mesh(jax.devices(), ('mp',)) + m = nnx.eval_shape(lambda: nnx.Linear(4, 4, rngs=nnx.Rngs(0))) + + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('mp') + ) + + sa1, sa2, sa3 = nnx.prefix((m, m, m), {...: sharding}, graph=True) + + @nnx.jit(out_shardings=(sa1, sa2, sa3), graph=True, graph_updates=False) + def f(): + m = nnx.Linear(4, 4, rngs=nnx.Rngs(0)) + return m, m, m + + with jax.set_mesh(mesh): + m1, m2, m3 = f() + + assert m1 is m2 and m2 is m3 + assert m1.kernel.sharding == sharding + + @parameterized.parameters(True, False) + def test_captured_module_in_return_error(self, graph_updates): class Foo(nnx.Module): def __init__(self): - self.a = jnp.zeros((4, 4)) + self.a = nnx.Variable(jnp.arange(4)) m = Foo() - @nnx.vmap(in_axes=0, out_axes=0) + @nnx.vmap(in_axes=0, out_axes=0, graph=True, graph_updates=graph_updates) def f(x): - return x, m + return m - with self.assertRaisesRegex( - errors.TraceContextError, - r'Trying to extract graph node from different trace level.*Foo', - ): - x = jnp.zeros((4,)) - f(x) + if graph_updates: + error_regex = 'Cannot extract graph node from different trace level' + else: + error_regex = 'Cannot return captured Variable' + + with self.assertRaisesRegex(ValueError, error_regex): + f(jnp.zeros((4,))) def test_vmap_and_cond_passthrough(self): class Broadcast(nnx.Variable[nnx.A]): ... @@ -4185,7 +6249,11 @@ def __init__(self): env = Env() - @nnx.vmap(in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),)) + @nnx.vmap( + in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),), + graph=True, + graph_updates=True, + ) def f(env: Env): self.assertEqual(env.step.shape, ()) @@ -4196,12 +6264,56 @@ def no_nothing(env: Env): pass is_even = env.index % 2 == 0 - nnx.cond(is_even, increment, no_nothing, env) + nnx.cond( + is_even, increment, no_nothing, env, graph=True, graph_updates=True + ) f(env) np.testing.assert_array_equal(env.step[...], [1, 0, 1, 0, 1, 0, 1, 0]) + @parameterized.parameters(True, False) + def test_vmap_and_cond_passthrough_functional_success(self, graph): + class Broadcast(nnx.Variable[nnx.A]): + ... + + class Vectorized(nnx.Variable[nnx.A]): + ... + + class Env(nnx.Module): + + def __init__(self): + self.broadcast = Broadcast(jnp.array(1)) + self.index = Vectorized(jnp.arange(8)) + self.step = Vectorized(jnp.zeros((8,), jnp.uint32)) + + env = Env() + + env_axes = nnx.prefix(env, {Broadcast: None, Vectorized: 0}, graph=graph) + + @nnx.vmap(in_axes=(env_axes,), graph=graph, graph_updates=False) + def f(env): + self.assertEqual(env.step.shape, ()) + + def increment(env: Env): + env.step[...] += 1 + + def no_nothing(env: Env): + pass + + is_even = env.index % 2 == 0 + nnx.cond( + is_even, + increment, + no_nothing, + env, + graph=graph, + graph_updates=False, + ) + + f(env) + np.testing.assert_array_equal(env.step[...], [1, 0, 1, 0, 1, 0, 1, 0]) + def test_vmap_and_cond_passthrough_error(self): class Broadcast(nnx.Variable[nnx.A]): ... @@ -4215,7 +6327,11 @@ def __init__(self): env = Env() - @nnx.vmap(in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),)) + @nnx.vmap( + in_axes=(nnx.StateAxes({Broadcast: None, Vectorized: 0}),), + graph=True, + graph_updates=True, + ) def f(env: Env): self.assertEqual(env.step.shape, ()) @@ -4227,7 +6343,9 @@ def no_nothing(env: Env): pass is_even = env.index % 2 == 0 - nnx.cond(is_even, increment, no_nothing, env) + nnx.cond( + is_even, increment, no_nothing, env, graph=True, graph_updates=True + ) with self.assertRaisesRegex( ValueError, @@ -4236,7 +6354,61 @@ def no_nothing(env: Env): ): f(env) - def test_example(self): + @parameterized.parameters(True, False) + def test_vmap_and_cond_passthrough_functional(self, graph): + class Broadcast(nnx.Variable[nnx.A]): + ... + + class Vectorized(nnx.Variable[nnx.A]): + ... + + class Env(nnx.Module): + + def __init__(self): + self.broadcast = Broadcast(jnp.array(1)) + self.index = Vectorized(jnp.arange(8)) + self.step = Vectorized(jnp.zeros((8,), jnp.uint32)) + + env = Env() + + model_axes = nnx.prefix(env, {Broadcast: None, Vectorized: 0}, graph=graph) + + @nnx.vmap( + in_axes=(model_axes,), + out_axes=model_axes, + graph=graph, + graph_updates=False, + ) + def f(env): + self.assertEqual(env.step.shape, ()) + + def increment(env: Env): + env.step[...] += 1 + env.broadcast[...] += 1 + + def no_nothing(env: Env): + pass + + is_even = env.index % 2 == 0 + nnx.cond( + is_even, + increment, + no_nothing, + env, + graph=graph, + graph_updates=False, + ) + # Returning the module outputs the unmodified graph state (aliasing the input) + # which triggers the ValueError when graph_updates=False + return env + + with self.assertRaisesRegex(ValueError, r'Duplicate Broadcast'): + f(env) + + @parameterized.parameters( + (True, True), (True, False), (False, False), + ) + def test_example(self, graph, graph_updates): class Model(nnx.Module): def __init__(self, din, dout, *, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) @@ -4246,7 +6418,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): def __call__(self, x): return nnx.relu(self.dropout(self.bn(self.linear(x)))) - @nnx.vmap(in_axes=0, out_axes=0) + @nnx.vmap(in_axes=0, out_axes=0, graph=graph, graph_updates=graph_updates) def initialize_ensemble(key): rngs = nnx.Rngs(key) return Model(2, 3, rngs=rngs) @@ -4256,7 +6428,7 @@ def initialize_ensemble(key): self.assertEqual(ensemble.linear.kernel.shape, (5, 2, 3)) - @nnx.vmap(in_axes=(0, None), out_axes=0) + @nnx.vmap(in_axes=(0, None), out_axes=0, graph=graph, graph_updates=graph_updates) def forward(model, x): return model(x) @@ -4264,14 +6436,17 @@ def forward(model, x): y = forward(ensemble, x) self.assertEqual(y.shape, (5, 4, 3)) - def test_example_with_vectorization(self): + @parameterized.parameters( + (True, True), (True, False), (False, False), + ) + def test_example_with_vectorization(self, graph, graph_updates): class LinearEnsemble(nnx.Module): def __init__(self, num, rngs): self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3))) model = LinearEnsemble(5, rngs=nnx.Rngs(0)) - - @nnx.vmap(in_axes=(0, None), out_axes=0) + + @nnx.vmap(in_axes=(0, None), out_axes=0, graph=graph, graph_updates=graph_updates) def forward(model, x): self.assertEqual(model.w.shape, (2, 3)) return jnp.dot(x, model.w) @@ -4281,12 +6456,13 @@ def forward(model, x): self.assertEqual(y.shape, (5, 4, 3)) - def test_metadata(self): - @nnx.vmap( - in_axes=(None,), - out_axes=0, - axis_size=5, - transform_metadata={nnx.spmd.PARTITION_NAME: 'c'}, + def test_metadata_graph_updates(self): + @nnx.compat.vmap( + in_axes=(None,), + out_axes=0, + axis_size=5, + transform_metadata={nnx.spmd.PARTITION_NAME: 'c'}, + graph_updates=True, ) def create_block(rngs: nnx.Rngs): return nnx.Linear( @@ -4298,7 +6474,73 @@ def create_block(rngs: nnx.Rngs): ), ) - + mesh = jax.make_mesh( + (1, 1, 1), + ('a', 'b', 'c'), + axis_types=(jax.sharding.AxisType.Auto,) * len(('a', 'b', 'c')), + ) + with jax.set_mesh(mesh): + m = create_block(nnx.Rngs(0)) + self.assertEqual(m.kernel.shape, (5, 16, 32)) + self.assertEqual(m.kernel.out_sharding, ('c', 'a', 'b')) + + @parameterized.parameters(True, False) + def test_metadata_graph_updates_functional(self, graph): + @nnx.vmap( + in_axes=(None,), + out_axes=0, + axis_size=5, + graph=graph, + graph_updates=False, + ) + @nnx.transform_metadata( + in_axes=(None,), + out_axes=0, + partition='c', + graph=graph, + ) + def create_block(rngs: nnx.Rngs): + return nnx.Linear( + 16, + 32, + rngs=rngs, + kernel_init=nnx.with_partitioning( + nnx.initializers.lecun_normal(), ('a', 'b') + ), + ) + + mesh = jax.make_mesh( + (1, 1, 1), + ('a', 'b', 'c'), + axis_types=(jax.sharding.AxisType.Auto,) * len(('a', 'b', 'c')), + ) + with jax.set_mesh(mesh): + m = create_block(nnx.Rngs(0)) + self.assertEqual(m.kernel.shape, (5, 16, 32)) + self.assertEqual(m.kernel.out_sharding, ('c', 'a', 'b')) + + def test_metadata_transform_metadata(self): + @nnx.vmap( + in_axes=(None,), + out_axes=0, + axis_size=5, + graph_updates=False, + ) + @nnx.transform_metadata( + in_axes=(None,), + out_axes=0, + partition='c', + ) + def create_block(rngs: nnx.Rngs): + return nnx.Linear( + 16, + 32, + rngs=rngs, + kernel_init=nnx.with_partitioning( + nnx.initializers.lecun_normal(), ('a', 'b') + ), + ) + mesh = jax.make_mesh((1, 1, 1), ('a', 'b', 'c'), axis_types=(jax.sharding.AxisType.Auto,) * len(('a', 'b', 'c'))) with jax.set_mesh(mesh): m = create_block(nnx.Rngs(0)) @@ -4330,7 +6572,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): self.assertEqual(state_axes.map_prefix(('bn', 'var'), None), 0) self.assertEqual(state_axes.map_prefix(('bn', 'bias'), None), None) - @nnx.vmap(out_axes=state_axes, axis_size=5) + @nnx.vmap(out_axes=state_axes, axis_size=5, graph=True, graph_updates=True) def create_block(): return Model(2, 3, rngs=nnx.Rngs(0)) @@ -4343,11 +6585,51 @@ def create_block(): self.assertEqual(model.bn.var.shape, (5, 3)) self.assertEqual(model.bn.bias.shape, (3,)) + @parameterized.parameters(True, False) + def test_state_axes_from_state_functional(self, graph): + class Model(nnx.Module): + + def __init__(self, din, dout, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.bn = nnx.BatchNorm(dout, rngs=rngs) + + filter_0 = lambda path, var: ( + (path[0] == 'linear' and path[1] == 'kernel') + or (path[0] == 'bn' and path[1] in ('scale', 'var')) + ) + filter_1 = lambda path, var: ( + (path[0] == 'linear' and path[1] == 'bias') + or (path[0] == 'bn' and path[1] == 'mean') + ) + + abs_model = nnx.eval_shape(lambda: Model(2, 3, rngs=nnx.Rngs(0))) + model_axes = nnx.prefix( + abs_model, {filter_0: 0, filter_1: 1, ...: None}, graph=graph + ) + + @nnx.vmap( + out_axes=model_axes, + axis_size=5, + graph=graph, + graph_updates=False, + ) + def create_block_functional(): + return Model(2, 3, rngs=nnx.Rngs(0)) + + model = create_block_functional() + + self.assertEqual(model.linear.kernel.shape, (5, 2, 3)) + self.assertEqual(model.linear.bias.shape, (3, 5)) + self.assertEqual(model.bn.scale.shape, (5, 3)) + self.assertEqual(model.bn.mean.shape, (3, 5)) + self.assertEqual(model.bn.var.shape, (5, 3)) + self.assertEqual(model.bn.bias.shape, (3,)) - def test_vmap_inconsistent_aliasing(self): + @parameterized.parameters(True, False) + def test_vmap_inconsistent_aliasing(self, graph_updates): v = nnx.Param(jnp.arange(3.0)) - @nnx.vmap(in_axes=(0, None), graph=True, graph_updates=False) + @nnx.vmap(in_axes=(0, None), graph=True, graph_updates=graph_updates) def f(v_mapped, v_broadcast): return v_mapped[...] + v_broadcast[...] @@ -4371,8 +6653,63 @@ def __call__(self, x: jax.Array) -> jax.Array: state_axes = nnx.StateAxes({(nnx.Param, nnx.RngState): 0, ...: None}) @nnx.split_rngs(splits=1) - @nnx.pmap(in_axes=(state_axes,), out_axes=state_axes, axis_size=1, - graph=True) + @nnx.pmap( + in_axes=(state_axes,), + out_axes=state_axes, + axis_size=1, + graph=True, + graph_updates=True, + ) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + rngs = nnx.Rngs(0) + module = create_block(rngs) + initial_key = module.dropout.rngs.key[...] + + assert module.dropout.rngs.count[0] == 0 + assert module.linear.kernel.shape == (1, 3, 10) + assert module.linear.bias.shape == (1, 10) + + x = jnp.ones((1, 1, 3)) + + @nnx.pmap( + in_axes=(state_axes, 0), axis_size=1, graph=True, graph_updates=True + ) + def forward_block(module, x): + return module(x) + + y = forward_block(module, x) + + assert y.shape == (1, 1, 10) + assert module.dropout.rngs.count[0] == 1 + assert module.dropout.rngs.key[...] == initial_key + + y2 = forward_block(module, x) + + assert not jnp.allclose(y, y2) + + @parameterized.parameters(True, False) + def test_basic_single_functional(self, graph): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 10, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.elu(x) + x = self.dropout(x) + return x + + abs_block = nnx.eval_shape(lambda: Block(nnx.Rngs(0))) + state_axes = nnx.prefix(abs_block, {(nnx.Param, nnx.RngState): 0, ...: None}, graph=graph) + + abs_rngs = nnx.eval_shape(lambda: nnx.Rngs(0)) + rngs_axes = nnx.prefix(abs_rngs, {nnx.RngState: 0, ...: None}, graph=graph) + + @nnx.with_rngs(split=1, graph=graph) + @nnx.pmap(in_axes=(rngs_axes,), out_axes=state_axes, axis_size=1, graph=graph, graph_updates=False) def create_block(rngs: nnx.Rngs): return Block(rngs) @@ -4386,7 +6723,7 @@ def create_block(rngs: nnx.Rngs): x = jnp.ones((1, 1, 3)) - @nnx.pmap(in_axes=(state_axes, 0), axis_size=1, graph=True) + @nnx.pmap(in_axes=(state_axes, 0), axis_size=1, graph=graph, graph_updates=False) def forward_block(module, x): return module(x) @@ -4440,7 +6777,8 @@ def forward_block(module: Block, x): # dropout is working! assert not jnp.allclose(y, y2) - def test_replicate_single(self): + @parameterized.parameters(True, False) + def test_replicate_single_functional(self, graph): din = 3 dout = 10 @@ -4452,19 +6790,25 @@ def __init__(self, rngs: nnx.Rngs): def __call__(self, x: jax.Array) -> jax.Array: return self.dropout(nnx.relu(self.linear(x))) - def create_block(rngs: nnx.Rngs): - return Block(rngs) - - state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None}) + state_axes = nnx.prefix( + nnx.eval_shape(lambda: Block(nnx.Rngs(0))), + {nnx.RngState: 0, ...: None}, + graph=graph, + ) @nnx.split_rngs(splits=1) - @partial(nnx.pmap, in_axes=(state_axes, 0), out_axes=0, axis_size=1, - graph=True) + @nnx.pmap( + in_axes=(state_axes, 0), + out_axes=0, + axis_size=1, + graph=graph, + graph_updates=False, + ) def forward_block(module: Block, x): return module(x) rngs = nnx.Rngs(0) - module = create_block(rngs) + module = Block(rngs) initial_key = module.dropout.rngs.key[...] assert module.dropout.rngs.count[...] == 0 @@ -4487,7 +6831,9 @@ def forward_block(module: Block, x): class TestCond(parameterized.TestCase): - def test_basic(self): + + @parameterized.parameters(True, False) + def test_basic(self, graph_updates: bool): class TimeStep(tp.NamedTuple): step: nnx.Variable[jax.Array] reward: nnx.Variable[jax.Array] @@ -4504,18 +6850,20 @@ class Foo(nnx.Pytree): def update(self): def reward_2(self: Foo): - self.timestep = TimeStep( - step=nnx.Variable(self.timestep.step + 1), - reward=nnx.Variable(jnp.array(2.0)), - ) + self.timestep.step[...] += 1 + self.timestep.reward[...] = 2.0 def reward_0(self: Foo): - self.timestep = TimeStep( - step=nnx.Variable(self.timestep.step + 1), - reward=nnx.Variable(jnp.array(0.0)), - ) - - nnx.cond(self.timestep.step % 2 == 0, reward_2, reward_0, self) + self.timestep.step[...] += 1 + self.timestep.reward[...] = 0.0 + + nnx.cond( + self.timestep.step % 2 == 0, + reward_2, + reward_0, + self, + graph_updates=graph_updates, + ) foo = Foo(timestep=TimeStep.zero()) foo.update() @@ -4524,12 +6872,6 @@ def reward_0(self: Foo): foo.update() self.assertEqual(foo.timestep.step[...], 2) self.assertEqual(foo.timestep.reward[...], 0.0) - foo.update() - self.assertEqual(foo.timestep.step[...], 3) - self.assertEqual(foo.timestep.reward[...], 2.0) - foo.update() - self.assertEqual(foo.timestep.step[...], 4) - self.assertEqual(foo.timestep.reward[...], 0.0) @parameterized.parameters( (True, False), @@ -4620,7 +6962,8 @@ def update_b(a, b): self.assertEqual(a[...], 1) self.assertEqual(b[...], 10) - def test_cond_shared_references(self): + @parameterized.parameters(True, False) + def test_cond_shared_references(self, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): @@ -4636,10 +6979,10 @@ def true_fn(m): def false_fn(m): m.b[...] += 2 - nnx.cond(True, true_fn, false_fn, m, graph=True, graph_updates=False) + nnx.cond(True, true_fn, false_fn, m, graph=True, graph_updates=graph_updates) np.testing.assert_allclose(m.a[...], 1) np.testing.assert_allclose(m.b[...], 1) - nnx.cond(False, true_fn, false_fn, m, graph=True, graph_updates=False) + nnx.cond(False, true_fn, false_fn, m, graph=True, graph_updates=graph_updates) np.testing.assert_allclose(m.a[...], 3) np.testing.assert_allclose(m.b[...], 3) @@ -4716,7 +7059,8 @@ def add_100(x): graph=graph, graph_updates=graph_updates) self.assertEqual(x[...], 111) - def test_switch_shared_references(self): + @parameterized.parameters(True, False) + def test_switch_shared_references(self, graph_updates): @dataclasses.dataclass class Foo(nnx.Module): a: nnx.Variable @@ -4731,11 +7075,11 @@ def add_a(m): def add_b(m): m.b[...] += 10 - nnx.switch(0, (add_a, add_b), m, graph=True, graph_updates=False) + nnx.switch(0, (add_a, add_b), m, graph=True, graph_updates=graph_updates) np.testing.assert_allclose(m.a[...], 1) np.testing.assert_allclose(m.b[...], 1) - nnx.switch(1, (add_a, add_b), m, graph=True, graph_updates=False) + nnx.switch(1, (add_a, add_b), m, graph=True, graph_updates=graph_updates) np.testing.assert_allclose(m.a[...], 11) np.testing.assert_allclose(m.b[...], 11) @@ -4799,12 +7143,13 @@ def fwd_fn(input): graph=graph, graph_updates=graph_updates) np.testing.assert_array_equal(y, x * 8) - def test_shared_module(self): + @parameterized.parameters(True, False) + def test_shared_module(self, graph_updates): m1 = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) m2 = nnx.Linear(10, 10, use_bias=False, rngs=nnx.Rngs(0)) m2.kernel = m1.kernel module = nnx.Sequential(m1, m2) - self.assertLen(jax.tree.leaves(nnx.state(module)), 2) # only m1 params + self.assertLen(jax.tree.leaves(nnx.compat.state(module)), 2) # only m1 params def fwd_fn(input): m, x, c = input @@ -4814,8 +7159,9 @@ def fwd_fn(input): x = 1e1 * jax.random.normal(jax.random.key(0), (10,)) _, y, _ = nnx.while_loop( - lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0)) - self.assertLen(jax.tree.leaves(nnx.state(module)), 2) # only m1 params + lambda input: input[-1] > 0, fwd_fn, (module, x, 2.0), + graph=True, graph_updates=graph_updates) + self.assertLen(jax.tree.leaves(nnx.compat.state(module)), 2) # only m1 params np.testing.assert_array_equal( m1.kernel[...], jnp.zeros((10, 10)), @@ -4885,7 +7231,8 @@ def fwd_fn(input): lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0), graph=graph, graph_updates=graph_updates) - def test_repeated_object(self): + @parameterized.parameters(True, False) + def test_repeated_object(self, graph_updates): m = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) def body_fn(val): @@ -4896,18 +7243,40 @@ def body_fn(val): lambda val: val[0] < 2, body_fn, (0, m, m), + graph=True, graph_updates=graph_updates, ) - def test_immut_fori_loop(self): + @parameterized.parameters( + (True, True), + (True, False), + (False, False), + ) + def test_immut_fori_loop(self, graph: bool, graph_updates: bool): def immut_fn(i, carry): - g_accum = carry - grads = jax.tree.map(jnp.ones_like, g_accum) - g_accum = jax.tree.map(lambda gm, g: gm + g, g_accum, grads) - return g_accum - model = nnx.Linear(10, 10, rngs=nnx.Rngs(0), use_bias=False) - g_accum = jax.tree.map(jnp.zeros_like, nnx.state(model)) - nnx.fori_loop(0, 2, immut_fn, g_accum) + def update_fn(p, c): + c[...] += 1.0 + return c + + nnx.map(update_fn, carry, graph=graph) + return carry + + model = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + zeros = nnx.map( + lambda p, v: nnx.Variable(jnp.zeros_like(v)), model, graph=graph + ) + res = nnx.fori_loop( + 0, 2, immut_fn, zeros, graph=graph, graph_updates=graph_updates + ) + + def assert_zeros(path, c): + np.testing.assert_array_equal(c[...], jnp.full(c.shape, 2.0)) + return c + + nnx.map(assert_zeros, res, graph=graph) + + self.assertIs(zeros.kernel, res.kernel) + self.assertIs(zeros.bias, res.bias) @parameterized.parameters( (True, True), (True, False), (False, False), @@ -4939,7 +7308,8 @@ def fwd_fn(i, input): graph=graph, graph_updates=graph_updates) np.testing.assert_array_equal(y, x * 2 * 3) - def test_fori_loop_with_sharing(self): + @parameterized.parameters(True, False) + def test_fori_loop_with_sharing(self, graph_updates): class A(nnx.Pytree): def __init__(self): self.params = nnx.Param(jnp.zeros((10,), dtype=int)) @@ -4962,9 +7332,9 @@ def increment(_, d: D) -> D: d.a.params[...] += 1 return d - @nnx.jit + @nnx.jit(graph=True, graph_updates=graph_updates) def rollout(d: D): - nnx.fori_loop(0, 10, increment, d) + nnx.fori_loop(0, 10, increment, d, graph=True, graph_updates=graph_updates) d = D() rollout(d) @@ -5515,222 +7885,5 @@ def forward_block(rng_state, rest_state, x): assert not jnp.allclose(y, y2) -class TestPureJaxFancyScan(absltest.TestCase): - - def test_carry_and_scan(self): - def cumsum(carry, x): - carry = carry + x - return carry, carry - - final_carry, ys = pure_jax_fancy_scan( - cumsum, jnp.array(0.0), jnp.arange(5.0), - in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), - ) - np.testing.assert_allclose(final_carry, 10.0) - np.testing.assert_allclose(ys, jnp.array([0., 1., 3., 6., 10.])) - - def test_carry_only_output(self): - def sum_fn(carry, x): - return carry + x - - result = pure_jax_fancy_scan( - sum_fn, jnp.array(0.0), jnp.arange(5.0), - in_axes=(nnx.Carry, 0), out_axes=nnx.Carry, - ) - np.testing.assert_allclose(result, 10.0) - - def test_broadcast_args(self): - def scale_cumsum(carry, scale, x): - carry = carry + x * scale - return carry, carry - - final_carry, _ = pure_jax_fancy_scan( - scale_cumsum, jnp.array(0.0), jnp.array(2.0), jnp.arange(5.0), - in_axes=(nnx.Carry, None, 0), out_axes=(nnx.Carry, 0), - ) - np.testing.assert_allclose(final_carry, 20.0) - - def test_pytree_carry(self): - def dict_scan(carry, x): - carry = {'a': carry['a'] + x['a'], 'b': carry['b'] + x['b']} - return carry, carry - - xs = {'a': jnp.arange(3.0), 'b': jnp.ones(3)} - init = {'a': jnp.array(0.0), 'b': jnp.array(0.0)} - final_carry, _ = pure_jax_fancy_scan( - dict_scan, init, xs, - in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), - ) - np.testing.assert_allclose(final_carry['a'], 3.0) - np.testing.assert_allclose(final_carry['b'], 3.0) - - def test_no_carry_all_scanned(self): - def double(x): - return x * 2 - - ys = pure_jax_fancy_scan(double, jnp.arange(5.0), in_axes=0, out_axes=0) - np.testing.assert_allclose(ys, jnp.arange(5.0) * 2) - - def test_reverse(self): - def cumsum(carry, x): - carry = carry + x - return carry, carry - - final_carry, _ = pure_jax_fancy_scan( - cumsum, jnp.array(0.0), jnp.arange(5.0), - in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), reverse=True, - ) - np.testing.assert_allclose(final_carry, 10.0) - - def test_pytree_prefix_in_axes(self): - def fn(carry, x): - carry = carry + x['a'] + x['b'] - return carry, carry - - xs = {'a': jnp.arange(3.0), 'b': jnp.array(1.0)} - final_carry, _ = pure_jax_fancy_scan( - fn, jnp.array(0.0), xs, - in_axes=(nnx.Carry, {'a': 0, 'b': None}), out_axes=(nnx.Carry, 0), - ) - np.testing.assert_allclose(final_carry, 6.0) - - def test_nested_carry_rejected(self): - with self.assertRaises(ValueError): - pure_jax_fancy_scan( - lambda x: x, - {'a': jnp.array(1.0)}, - in_axes=({'a': nnx.Carry},), out_axes=nnx.Carry, - ) - - def test_broadcast_out_axes_rejected(self): - with self.assertRaises(ValueError): - pure_jax_fancy_scan( - lambda c, x: (c, x), - jnp.array(0.0), jnp.arange(3.0), - in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, None), - ) - - def test_none_broadcast_input(self): - def fn(carry, _unused, x): - carry = carry + x - return carry, carry - - final_carry, _ = pure_jax_fancy_scan( - fn, jnp.array(0.0), None, jnp.arange(3.0), - in_axes=(nnx.Carry, None, 0), out_axes=(nnx.Carry, 0), - ) - np.testing.assert_allclose(final_carry, 3.0) - - def test_none_nested_in_arg(self): - def fn(carry, x): - carry = carry + x['a'] - return carry, carry - - xs = {'a': jnp.arange(3.0), 'b': None} - final_carry, _ = pure_jax_fancy_scan( - fn, jnp.array(0.0), xs, - in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0), - ) - np.testing.assert_allclose(final_carry, 3.0) - - def test_nested_carry_in_out_axes_rejected(self): - with self.assertRaises(ValueError): - pure_jax_fancy_scan( - lambda c, x: (c, x), - jnp.array(0.0), jnp.arange(3.0), - in_axes=(nnx.Carry, 0), out_axes=({'a': nnx.Carry},), - ) - - def test_carry_in_in_axes_only_rejected(self): - with self.assertRaises(ValueError): - pure_jax_fancy_scan( - lambda c, x: (c + x,), - jnp.array(0.0), jnp.arange(3.0), - in_axes=(nnx.Carry, 0), out_axes=(0,), - ) - - def test_carry_in_out_axes_only_rejected(self): - with self.assertRaises(ValueError): - pure_jax_fancy_scan( - lambda x: x, - jnp.arange(3.0), - in_axes=(0,), out_axes=nnx.Carry, - ) - - def test_non_tuple_carry_only(self): - def f(carry): - return carry + 1.0 - - result = pure_jax_fancy_scan( - f, jnp.array(0.0), - in_axes=nnx.Carry, out_axes=nnx.Carry, length=5, - ) - np.testing.assert_allclose(result, 5.0) - - def test_non_tuple_scan_only(self): - def f(x): - return x * 2 - - ys = pure_jax_fancy_scan( - f, jnp.arange(5.0), - in_axes=0, out_axes=0, - ) - np.testing.assert_allclose(ys, jnp.arange(5.0) * 2) - - def test_scan_axis_1(self): - def cumsum(carry, x): - carry = carry + x - return carry, carry - - x = jnp.arange(10.0).reshape((2, 5)) - final_carry, ys = pure_jax_fancy_scan( - cumsum, jnp.zeros(2), x, - in_axes=(nnx.Carry, 1), out_axes=(nnx.Carry, 1), - ) - np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) - expected_ys = jnp.array([ - [0., 1., 3., 6., 10.], - [5., 11., 18., 26., 35.] - ]) - np.testing.assert_allclose(ys, expected_ys) - - def test_scan_axis_negative_1(self): - def cumsum(carry, x): - carry = carry + x - return carry, carry - - x = jnp.arange(10.0).reshape((2, 5)) - final_carry, ys = pure_jax_fancy_scan( - cumsum, jnp.zeros(2), x, - in_axes=(nnx.Carry, -1), out_axes=(nnx.Carry, -1), - ) - np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) - expected_ys = jnp.array([ - [0., 1., 3., 6., 10.], - [5., 11., 18., 26., 35.] - ]) - np.testing.assert_allclose(ys, expected_ys) - - def test_scan_different_in_out_axes(self): - def cumsum(carry, x): - carry = carry + x - return carry, carry - - x = jnp.arange(10.0).reshape((2, 5)) - final_carry, ys = pure_jax_fancy_scan( - cumsum, jnp.zeros(2), x, - in_axes=(nnx.Carry, 1), out_axes=(nnx.Carry, 0), - ) - np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) - expected_ys = jnp.array([ - [0., 5.], - [1., 11.], - [3., 18.], - [6., 26.], - [10., 35.] - ]) - np.testing.assert_allclose(ys, expected_ys) - - if __name__ == '__main__': absltest.main()