diff --git a/docs_nnx/guides/view.ipynb b/docs_nnx/guides/view.ipynb index c864f9e56..367f22373 100644 --- a/docs_nnx/guides/view.ipynb +++ b/docs_nnx/guides/view.ipynb @@ -6,7 +6,9 @@ "metadata": {}, "source": [ "# Model Views\n", - "This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:" + "This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.view`, the canonical NNX View that sets module modes.\n", + "\n", + "For other views, NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original." ] }, { @@ -320,11 +322,73 @@ }, { "cell_type": "markdown", - "id": "1acbcc09", + "id": "74740224", "metadata": {}, "source": [ "The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.\n", "\n", + "## Using `with_vars`\n", + "\n", + "{func}`nnx.with_vars ` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `view` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.\n", + "\n", + "The flags it controls are:\n", + "\n", + "- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.\n", + "- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.\n", + "- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.\n", + "\n", + "The `only` argument accepts a {doc}`Filter ` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34f4d760", + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "class SimpleModel(nnx.Module):\n", + " def __init__(self, rngs):\n", + " self.linear = nnx.Linear(2, 3, rngs=rngs)\n", + "\n", + "model = SimpleModel(nnx.Rngs(0))\n", + "\n", + "# ref=True: expose Variable values as JAX refs so jax.tree.map can update them\n", + "ref_model = nnx.with_vars(model, ref=True)\n", + "ref_model = jax.tree.map(lambda x: x * 2, ref_model)\n", + "\n", + "# The original model's kernel is unchanged; ref_model has doubled values\n", + "assert model.linear.kernel is not ref_model.linear.kernel" + ] + }, + { + "cell_type": "markdown", + "id": "16057b6d", + "metadata": {}, + "source": [ + "Use the `only` filter to convert only a subset of Variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11fa1e81", + "metadata": {}, + "outputs": [], + "source": [ + "# only convert Param variables, leave BatchStat variables unchanged\n", + "ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)" + ] + }, + { + "cell_type": "markdown", + "id": "1acbcc09", + "metadata": {}, + "source": [ "## Using `with_attributes`\n", "\n", "If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes ` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged." @@ -412,7 +476,33 @@ "id": "bf521e45", "metadata": {}, "source": [ - "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`." + "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.\n", + "\n", + "## Other NNX views\n", + "\n", + "Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:\n", + "\n", + "- {func}`nnx.as_pure ` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.\n", + "\n", + " ```python\n", + " _, state = nnx.split(model)\n", + " pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain\n", + " ```\n", + "\n", + "- {func}`nnx.as_abstract ` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.\n", + "\n", + " ```python\n", + " with jax.set_mesh(mesh):\n", + " abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))\n", + " abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars\n", + " ```\n", + "\n", + "- {func}`nnx.with_rngs ` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.\n", + "\n", + " ```python\n", + " # Split params stream into 4 keys (one per vmap replica); fork the rest\n", + " vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)\n", + " ```" ] } ], diff --git a/docs_nnx/guides/view.md b/docs_nnx/guides/view.md index d2fb3f8d4..ad886814b 100644 --- a/docs_nnx/guides/view.md +++ b/docs_nnx/guides/view.md @@ -9,7 +9,9 @@ jupytext: --- # Model Views -This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example: +This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.view`, the canonical NNX View that sets module modes. + +For other views, NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original. ```{code-cell} from flax import nnx @@ -216,6 +218,44 @@ print(nnx.view_info(model)) The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree. +## Using `with_vars` + +{func}`nnx.with_vars ` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `view` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX. + +The flags it controls are: + +- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state. +- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step. +- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform. + +The `only` argument accepts a {doc}`Filter ` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original). + +```{code-cell} +from flax import nnx +import jax +import jax.numpy as jnp + +class SimpleModel(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + +model = SimpleModel(nnx.Rngs(0)) + +# ref=True: expose Variable values as JAX refs so jax.tree.map can update them +ref_model = nnx.with_vars(model, ref=True) +ref_model = jax.tree.map(lambda x: x * 2, ref_model) + +# The original model's kernel is unchanged; ref_model has doubled values +assert model.linear.kernel is not ref_model.linear.kernel +``` + +Use the `only` filter to convert only a subset of Variables: + +```{code-cell} +# only convert Param variables, leave BatchStat variables unchanged +ref_params = nnx.with_vars(model, ref=True, only=nnx.Param) +``` + ## Using `with_attributes` If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes ` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged. @@ -280,3 +320,29 @@ print(noisy_model)s ``` Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`. + +## Other NNX views + +Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees: + +- {func}`nnx.as_pure ` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed. + + ```python + _, state = nnx.split(model) + pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain + ``` + +- {func}`nnx.as_abstract ` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes. + + ```python + with jax.set_mesh(mesh): + abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0))) + abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars + ``` + +- {func}`nnx.with_rngs ` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys. + + ```python + # Split params stream into 4 keys (one per vmap replica); fork the rest + vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...) + ```