Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 93 additions & 3 deletions docs_nnx/guides/view.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
"metadata": {},
"source": [
"# Model Views\n",
"This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:"
"This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.view`, the canonical NNX View that sets module modes.\n",
"\n",
"For other views, NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original."
]
},
{
Expand Down Expand Up @@ -320,11 +322,73 @@
},
{
"cell_type": "markdown",
"id": "1acbcc09",
"id": "74740224",
"metadata": {},
"source": [
"The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.\n",
"\n",
"## Using `with_vars`\n",
"\n",
"{func}`nnx.with_vars <flax.nnx.with_vars>` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `view` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.\n",
"\n",
"The flags it controls are:\n",
"\n",
"- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.\n",
"- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.\n",
"- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.\n",
"\n",
"The `only` argument accepts a {doc}`Filter <filters_guide>` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34f4d760",
"metadata": {},
"outputs": [],
"source": [
"from flax import nnx\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"class SimpleModel(nnx.Module):\n",
" def __init__(self, rngs):\n",
" self.linear = nnx.Linear(2, 3, rngs=rngs)\n",
"\n",
"model = SimpleModel(nnx.Rngs(0))\n",
"\n",
"# ref=True: expose Variable values as JAX refs so jax.tree.map can update them\n",
"ref_model = nnx.with_vars(model, ref=True)\n",
"ref_model = jax.tree.map(lambda x: x * 2, ref_model)\n",
"\n",
"# The original model's kernel is unchanged; ref_model has doubled values\n",
"assert model.linear.kernel is not ref_model.linear.kernel"
]
},
{
"cell_type": "markdown",
"id": "16057b6d",
"metadata": {},
"source": [
"Use the `only` filter to convert only a subset of Variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11fa1e81",
"metadata": {},
"outputs": [],
"source": [
"# only convert Param variables, leave BatchStat variables unchanged\n",
"ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)"
]
},
{
"cell_type": "markdown",
"id": "1acbcc09",
"metadata": {},
"source": [
"## Using `with_attributes`\n",
"\n",
"If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged."
Expand Down Expand Up @@ -412,7 +476,33 @@
"id": "bf521e45",
"metadata": {},
"source": [
"Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`."
"Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.\n",
"\n",
"## Other NNX views\n",
"\n",
"Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:\n",
"\n",
"- {func}`nnx.as_pure <flax.nnx.as_pure>` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.\n",
"\n",
" ```python\n",
" _, state = nnx.split(model)\n",
" pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain\n",
" ```\n",
"\n",
"- {func}`nnx.as_abstract <flax.nnx.as_abstract>` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.\n",
"\n",
" ```python\n",
" with jax.set_mesh(mesh):\n",
" abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))\n",
" abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars\n",
" ```\n",
"\n",
"- {func}`nnx.with_rngs <flax.nnx.rnglib.with_rngs>` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.\n",
"\n",
" ```python\n",
" # Split params stream into 4 keys (one per vmap replica); fork the rest\n",
" vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)\n",
" ```"
]
}
],
Expand Down
68 changes: 67 additions & 1 deletion docs_nnx/guides/view.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ jupytext:
---

# Model Views
This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:
This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.view`, the canonical NNX View that sets module modes.

For other views, NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original.

```{code-cell}
from flax import nnx
Expand Down Expand Up @@ -216,6 +218,44 @@ print(nnx.view_info(model))

The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.

## Using `with_vars`

{func}`nnx.with_vars <flax.nnx.with_vars>` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `view` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.

The flags it controls are:

- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.
- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.
- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.

The `only` argument accepts a {doc}`Filter <filters_guide>` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original).

```{code-cell}
from flax import nnx
import jax
import jax.numpy as jnp

class SimpleModel(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(2, 3, rngs=rngs)

model = SimpleModel(nnx.Rngs(0))

# ref=True: expose Variable values as JAX refs so jax.tree.map can update them
ref_model = nnx.with_vars(model, ref=True)
ref_model = jax.tree.map(lambda x: x * 2, ref_model)

# The original model's kernel is unchanged; ref_model has doubled values
assert model.linear.kernel is not ref_model.linear.kernel
```

Use the `only` filter to convert only a subset of Variables:

```{code-cell}
# only convert Param variables, leave BatchStat variables unchanged
ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)
```

## Using `with_attributes`

If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged.
Expand Down Expand Up @@ -280,3 +320,29 @@ print(noisy_model)s
```

Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.

## Other NNX views

Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:

- {func}`nnx.as_pure <flax.nnx.as_pure>` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.

```python
_, state = nnx.split(model)
pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain
```

- {func}`nnx.as_abstract <flax.nnx.as_abstract>` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.

```python
with jax.set_mesh(mesh):
abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))
abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars
```

- {func}`nnx.with_rngs <flax.nnx.rnglib.with_rngs>` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.

```python
# Split params stream into 4 keys (one per vmap replica); fork the rest
vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)
```
Loading