diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index b1e4c1404..bd98c3603 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -47,6 +47,8 @@ "from flax import nnx\n", "\n", "# Ignore this if you are already running on a TPU or GPU\n", + "nnx.set_graph_mode(False)\n", + "nnx.set_graph_updates(False)\n", "if not jax._src.xla_bridge.backends_are_initialized():\n", " jax.config.update('jax_num_cpu_devices', 8)\n", "print(f'You have 8 “fake” JAX devices now: {jax.devices()}')" @@ -85,7 +87,6 @@ "metadata": {}, "outputs": [], "source": [ - "nnx.use_eager_sharding(True)\n", "assert nnx.using_eager_sharding()" ] }, @@ -99,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "2d849e2e", "metadata": {}, "outputs": [], @@ -118,12 +119,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "67bbd440", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Param( # 16 (64 B)\n", + " value=Array([[1., 1., 1., 1.],\n", + " [1., 1., 1., 1.],\n", + " [1., 1., 1., 1.],\n", + " [1., 1., 1., 1.]], dtype=float32),\n", + " out_sharding=(None, 'model'),\n", + " mesh=Mesh(axis_sizes=(2, 4), axis_names=('data', 'model'), axis_types=(Explicit, Explicit))\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "nnx.Param(jnp.ones(4,4), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh)" + "nnx.Param(jnp.ones((4, 4)), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh)" ] }, { @@ -141,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -210,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -242,7 +261,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, 
"metadata": {}, "outputs": [ { @@ -319,21 +338,20 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "class DotReluDot(nnx.Module):\n", " def __init__(self, depth: int, rngs: nnx.Rngs):\n", - " init_fn = nnx.initializers.lecun_normal()\n", " self.dot1 = nnx.Linear(\n", " depth, depth,\n", - " kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),\n", + " kernel_metadata={'out_sharding': (None, 'model')},\n", " use_bias=False, # or use `bias_init` to give it annotation too\n", " rngs=rngs)\n", " self.w2 = nnx.Param(\n", - " init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n", - " sharding=('model', None),\n", + " rngs.params.lecun_normal()((depth, depth)), # RNG key and shape for W2 creation\n", + " out_sharding=('model', None),\n", " )\n", "\n", " def __call__(self, x: jax.Array):\n", @@ -347,7 +365,8 @@ " def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs):\n", " # Annotate the additional axis with sharding=None, meaning it will be\n", " # replicated across all devices.\n", - " @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None})\n", + " @nnx.vmap\n", + " @nnx.transform_metadata(partition=None)\n", " def create_sublayers(r):\n", " return DotReluDot(depth, r)\n", " self.layers = create_sublayers(rngs.fork(split=num_layers))\n", @@ -368,18 +387,100 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 19, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_21221/1601373863.py:28: DeprecationWarning: The 'split' argument of 'fork' is deprecated; use the 'split' method instead.\n", + " self.layers = create_sublayers(rngs.fork(split=num_layers))\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "1.251457\n", - "0.8495563\n", - "0.6590716\n", - "0.5399748\n", - "0.39150265\n" + 
"\u001b[38;2;79;201;177mMultiDotReluDot\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 4,194,304 (16.8 MB)\u001b[0m\n", + " \u001b[38;2;156;220;254mlayers\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mDotReluDot\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 4,194,304 (16.8 MB)\u001b[0m\n", + " \u001b[38;2;156;220;254mdot1\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 2,097,152 (8.4 MB)\u001b[0m\n", + " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mNone\u001b[0m,\n", + " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 2,097,152 (8.4 MB)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m1024\u001b[0m, \u001b[38;2;182;207;169m1024\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mout_sharding\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;86;156;214mNone\u001b[0m, \u001b[38;2;86;156;214mNone\u001b[0m, \u001b[38;2;207;144;120m'model'\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mdot_general\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m,\n", + " \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mNone\u001b[0m,\n", + " 
\u001b[38;2;156;220;254min_features\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m1024\u001b[0m,\n", + " \u001b[38;2;156;220;254mout_features\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m1024\u001b[0m,\n", + " \u001b[38;2;156;220;254mparam_dtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mfloat32\u001b[0m,\n", + " \u001b[38;2;156;220;254mprecision\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mNone\u001b[0m,\n", + " \u001b[38;2;156;220;254mpreferred_element_type\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mNone\u001b[0m,\n", + " \u001b[38;2;156;220;254mpromote_dtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m,\n", + " \u001b[38;2;156;220;254muse_bias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mFalse\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mw2\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 2,097,152 (8.4 MB)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m1024\u001b[0m, \u001b[38;2;182;207;169m1024\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mout_sharding\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;207;144;120m'model'\u001b[0m, \u001b[38;2;86;156;214mNone\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + }, + { + "ename": "ValueError", + 
"evalue": "Sharding spec ('model',) implies that array axis 0 is partitioned 4 times, but does not evenly divide the dimension size 2. Got shape: (2, 1024, 1024) and sharding NamedSharding(mesh=AbstractMesh('data': 2, 'model': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None), spec=PartitionSpec('model', None, None))", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mValueError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[19]\u001b[39m\u001b[32m, line 19\u001b[39m\n\u001b[32m 17\u001b[39m model = MultiDotReluDot(\u001b[32m1024\u001b[39m, \u001b[32m2\u001b[39m, rngs=nnx.Rngs(\u001b[32m0\u001b[39m))\n\u001b[32m 18\u001b[39m \u001b[38;5;28mprint\u001b[39m(model)\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m optimizer = \u001b[43mnnx\u001b[49m\u001b[43m.\u001b[49m\u001b[43mOptimizer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptax\u001b[49m\u001b[43m.\u001b[49m\u001b[43madam\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m1e-3\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwrt\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnnx\u001b[49m\u001b[43m.\u001b[49m\u001b[43mParam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[38;5;66;03m# The loop\u001b[39;00m\n\u001b[32m 22\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m5\u001b[39m):\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/pytreelib.py:420\u001b[39m, in \u001b[36mPytreeMeta.__call__\u001b[39m\u001b[34m(cls, *args, **kwargs)\u001b[39m\n\u001b[32m 419\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, *args: Any, **kwargs: Any) -> Any:\n\u001b[32m--> 
\u001b[39m\u001b[32m420\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_graph_node_meta_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/pytreelib.py:431\u001b[39m, in \u001b[36m_graph_node_meta_call\u001b[39m\u001b[34m(cls, *args, **kwargs)\u001b[39m\n\u001b[32m 429\u001b[39m \u001b[38;5;28mobject\u001b[39m.\u001b[34m__setattr__\u001b[39m(node, \u001b[33m'\u001b[39m\u001b[33m_pytree__state\u001b[39m\u001b[33m'\u001b[39m, PytreeState())\n\u001b[32m 430\u001b[39m \u001b[38;5;28mobject\u001b[39m.\u001b[34m__setattr__\u001b[39m(node, \u001b[33m'\u001b[39m\u001b[33m_pytree__nodes\u001b[39m\u001b[33m'\u001b[39m, \u001b[38;5;28mcls\u001b[39m._pytree__nodes)\n\u001b[32m--> \u001b[39m\u001b[32m431\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_pytree_meta_construct\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 432\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m._pytree__is_pytree:\n\u001b[32m 433\u001b[39m missing: \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mbool\u001b[39m] = {}\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/pytreelib.py:423\u001b[39m, in \u001b[36mPytreeMeta._pytree_meta_construct\u001b[39m\u001b[34m(cls, self, *args, **kwargs)\u001b[39m\n\u001b[32m 422\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_pytree_meta_construct\u001b[39m(\u001b[38;5;28mcls\u001b[39m, 
\u001b[38;5;28mself\u001b[39m, *args, **kwargs):\n\u001b[32m--> \u001b[39m\u001b[32m423\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[34;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/training/optimizer.py:88\u001b[39m, in \u001b[36m_check_wrt_arg_passed.._check_wrt_wrapper\u001b[39m\u001b[34m(wrt, *args, **kwargs)\u001b[39m\n\u001b[32m 83\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(wrt, _Missing):\n\u001b[32m 84\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 85\u001b[39m \u001b[33m'\u001b[39m\u001b[33mMissing required argument `wrt`. As of Flax 0.11.0 the `wrt` argument is required, \u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 86\u001b[39m \u001b[33m'\u001b[39m\u001b[33mif you want to keep the previous use nnx.ModelAndOptimizer instead of nnx.Optimizer.\u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 87\u001b[39m )\n\u001b[32m---> \u001b[39m\u001b[32m88\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwrt\u001b[49m\u001b[43m=\u001b[49m\u001b[43mwrt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/training/optimizer.py:154\u001b[39m, in \u001b[36mOptimizer.__init__\u001b[39m\u001b[34m(self, model, tx, wrt)\u001b[39m\n\u001b[32m 151\u001b[39m \u001b[38;5;28mself\u001b[39m.step = OptState(jnp.array(\u001b[32m0\u001b[39m, dtype=jnp.uint32))\n\u001b[32m 152\u001b[39m \u001b[38;5;28mself\u001b[39m.tx = tx\n\u001b[32m 153\u001b[39m 
\u001b[38;5;28mself\u001b[39m.opt_state = nnx.data(\n\u001b[32m--> \u001b[39m\u001b[32m154\u001b[39m \u001b[43mto_opt_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtx\u001b[49m\u001b[43m.\u001b[49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnnx\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwrt\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 155\u001b[39m )\n\u001b[32m 156\u001b[39m \u001b[38;5;28mself\u001b[39m.wrt = wrt\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/training/optimizer.py:57\u001b[39m, in \u001b[36mto_opt_state\u001b[39m\u001b[34m(tree)\u001b[39m\n\u001b[32m 54\u001b[39m opt_state = OptArray(x)\n\u001b[32m 55\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m opt_state\n\u001b[32m---> \u001b[39m\u001b[32m57\u001b[39m tree = \u001b[43mjax\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtree\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 58\u001b[39m \u001b[43m \u001b[49m\u001b[43m_to_opt_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 59\u001b[39m \u001b[43m \u001b[49m\u001b[43mtree\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 60\u001b[39m \u001b[43m \u001b[49m\u001b[43mis_leaf\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mlambda\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mVariable\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 61\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 62\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m tree\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/tree.py:155\u001b[39m, in \u001b[36mmap\u001b[39m\u001b[34m(f, tree, is_leaf, 
*rest)\u001b[39m\n\u001b[32m 115\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mmap\u001b[39m(f: Callable[..., Any],\n\u001b[32m 116\u001b[39m tree: Any,\n\u001b[32m 117\u001b[39m *rest: Any,\n\u001b[32m 118\u001b[39m is_leaf: Callable[[Any], \u001b[38;5;28mbool\u001b[39m] | \u001b[38;5;28;01mNone\u001b[39;00m = \u001b[38;5;28;01mNone\u001b[39;00m) -> Any:\n\u001b[32m 119\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Maps a multi-input function over pytree args to produce a new pytree.\u001b[39;00m\n\u001b[32m 120\u001b[39m \n\u001b[32m 121\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 153\u001b[39m \u001b[33;03m - :func:`jax.tree.reduce`\u001b[39;00m\n\u001b[32m 154\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m155\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtree_util\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtree_map\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtree\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43mrest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_leaf\u001b[49m\u001b[43m=\u001b[49m\u001b[43mis_leaf\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/tree_util.py:369\u001b[39m, in \u001b[36mtree_map\u001b[39m\u001b[34m(f, tree, is_leaf, *rest)\u001b[39m\n\u001b[32m 367\u001b[39m leaves, treedef = tree_flatten(tree, is_leaf)\n\u001b[32m 368\u001b[39m all_leaves = [leaves] + [treedef.flatten_up_to(r) \u001b[38;5;28;01mfor\u001b[39;00m r \u001b[38;5;129;01min\u001b[39;00m rest]\n\u001b[32m--> \u001b[39m\u001b[32m369\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mtreedef\u001b[49m\u001b[43m.\u001b[49m\u001b[43munflatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mxs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mxs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mzip\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mall_leaves\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/tree_util.py:369\u001b[39m, in \u001b[36m\u001b[39m\u001b[34m(.0)\u001b[39m\n\u001b[32m 367\u001b[39m leaves, treedef = tree_flatten(tree, is_leaf)\n\u001b[32m 368\u001b[39m all_leaves = [leaves] + [treedef.flatten_up_to(r) \u001b[38;5;28;01mfor\u001b[39;00m r \u001b[38;5;129;01min\u001b[39;00m rest]\n\u001b[32m--> \u001b[39m\u001b[32m369\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m treedef.unflatten(\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mxs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m xs \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(*all_leaves))\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/training/optimizer.py:52\u001b[39m, in \u001b[36mto_opt_state.._to_opt_state\u001b[39m\u001b[34m(x)\u001b[39m\n\u001b[32m 50\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_to_opt_state\u001b[39m(x):\n\u001b[32m 51\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(x, Variable):\n\u001b[32m---> \u001b[39m\u001b[32m52\u001b[39m opt_state = \u001b[43mOptVariable\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mx\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_metadata\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[32m 53\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 54\u001b[39m opt_state = OptArray(x)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/variablelib.py:1165\u001b[39m, in \u001b[36mVariableMeta.__call__\u001b[39m\u001b[34m(cls, *args, **kwargs)\u001b[39m\n\u001b[32m 1164\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, *args, **kwargs):\n\u001b[32m-> \u001b[39m\u001b[32m1165\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_variable_meta_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/variablelib.py:1168\u001b[39m, in \u001b[36mVariableMeta._variable_meta_call\u001b[39m\u001b[34m(cls, *args, **kwargs)\u001b[39m\n\u001b[32m 1167\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_variable_meta_call\u001b[39m(\u001b[38;5;28mcls\u001b[39m, *args, **kwargs):\n\u001b[32m-> \u001b[39m\u001b[32m1168\u001b[39m variable = \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1169\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m variable.hijax:\n\u001b[32m 1170\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
_new_hijax_from_variable(variable)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/nnx/variablelib.py:1352\u001b[39m, in \u001b[36mVariable.__init__\u001b[39m\u001b[34m(self, value, hijax, ref, eager_sharding, **metadata)\u001b[39m\n\u001b[32m 1350\u001b[39m \u001b[38;5;66;03m# shard the _value if applicable\u001b[39;00m\n\u001b[32m 1351\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m eager_sharding \u001b[38;5;129;01mand\u001b[39;00m \u001b[33m'\u001b[39m\u001b[33mout_sharding\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m metadata:\n\u001b[32m-> \u001b[39m\u001b[32m1352\u001b[39m value = \u001b[43mcore_spmd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mshard_value\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1353\u001b[39m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1354\u001b[39m \u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mout_sharding\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1355\u001b[39m \u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43msharding_rules\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1356\u001b[39m \u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mmesh\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1357\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1358\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m ref:\n\u001b[32m 1359\u001b[39m value = jax.new_ref(value) \u001b[38;5;66;03m# type: 
ignore\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/core/spmd.py:61\u001b[39m, in \u001b[36mshard_value\u001b[39m\u001b[34m(value, sharding, sharding_rules, mesh)\u001b[39m\n\u001b[32m 56\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 57\u001b[39m \u001b[33m'\u001b[39m\u001b[33mAn auto mesh context or metadata is required if creating a variable\u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 58\u001b[39m \u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33m with annotation \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msharding\u001b[38;5;132;01m=}\u001b[39;00m\u001b[33m. \u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 59\u001b[39m \u001b[33m'\u001b[39m\u001b[33mFor more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 60\u001b[39m pspec = get_pspec(sharding, sharding_rules)\n\u001b[32m---> \u001b[39m\u001b[32m61\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_apply_sharding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mNamedSharding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmesh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpspec\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmesh\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/flax/core/spmd.py:37\u001b[39m, in \u001b[36m_apply_sharding\u001b[39m\u001b[34m(value, sharding, mesh)\u001b[39m\n\u001b[32m 35\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_apply_sharding\u001b[39m(value, sharding, mesh):\n\u001b[32m 36\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m mesh.are_all_axes_explicit:\n\u001b[32m---> \u001b[39m\u001b[32m37\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mjax\u001b[49m\u001b[43m.\u001b[49m\u001b[43msharding\u001b[49m\u001b[43m.\u001b[49m\u001b[43mreshard\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msharding\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 38\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m mesh.are_all_axes_auto:\n\u001b[32m 39\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m jax.lax.with_sharding_constraint(value, sharding)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:2256\u001b[39m, in \u001b[36mreshard\u001b[39m\u001b[34m(xs, out_shardings)\u001b[39m\n\u001b[32m 2252\u001b[39m ds = ds.update(spec=ds.spec._normalized_spec_for_aval(x_aval.ndim)) \u001b[38;5;66;03m# pytype: disable=attribute-error\u001b[39;00m\n\u001b[32m 2253\u001b[39m cmesh = (s.mesh \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(s, NamedSharding) \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[32m 2254\u001b[39m \u001b[38;5;28misinstance\u001b[39m(s.mesh, mesh_lib.Mesh))\n\u001b[32m 2255\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m-> \u001b[39m\u001b[32m2256\u001b[39m out_flat.append(\u001b[43mreshard_p\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdst_sharding\u001b[49m\u001b[43m=\u001b[49m\u001b[43mds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconcrete_mesh\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcmesh\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 2257\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m tree_unflatten(treedef, out_flat)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:632\u001b[39m, in \u001b[36mPrimitive.bind\u001b[39m\u001b[34m(self, *args, **params)\u001b[39m\n\u001b[32m 630\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m 
\u001b[39m\u001b[34mbind\u001b[39m(\u001b[38;5;28mself\u001b[39m, *args, **params):\n\u001b[32m 631\u001b[39m args = args \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.skip_canonicalization \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(canonicalize_value, args)\n\u001b[32m--> \u001b[39m\u001b[32m632\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_true_bind\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:648\u001b[39m, in \u001b[36mPrimitive._true_bind\u001b[39m\u001b[34m(self, *args, **params)\u001b[39m\n\u001b[32m 646\u001b[39m trace_ctx.set_trace(eval_trace)\n\u001b[32m 647\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m648\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprev_trace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 649\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 650\u001b[39m trace_ctx.set_trace(prev_trace)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:660\u001b[39m, in \u001b[36mPrimitive.bind_with_trace\u001b[39m\u001b[34m(self, trace, args, params)\u001b[39m\n\u001b[32m 658\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m set_current_trace(trace):\n\u001b[32m 659\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.to_lojax(*args, **params) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[32m--> 
\u001b[39m\u001b[32m660\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrace\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprocess_primitive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 661\u001b[39m trace.process_primitive(\u001b[38;5;28mself\u001b[39m, args, params) \u001b[38;5;66;03m# may raise lojax error\u001b[39;00m\n\u001b[32m 662\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mcouldn\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt apply typeof to args: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00margs\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:1205\u001b[39m, in \u001b[36mEvalTrace.process_primitive\u001b[39m\u001b[34m(self, primitive, args, params)\u001b[39m\n\u001b[32m 1203\u001b[39m args = \u001b[38;5;28mmap\u001b[39m(full_lower, args)\n\u001b[32m 1204\u001b[39m check_eval_args(args)\n\u001b[32m-> \u001b[39m\u001b[32m1205\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprimitive\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimpl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:2273\u001b[39m, in \u001b[36m_reshard_impl\u001b[39m\u001b[34m(x, dst_sharding, concrete_mesh)\u001b[39m\n\u001b[32m 2270\u001b[39m thunk = \u001b[38;5;28;01mlambda\u001b[39;00m: dispatch.apply_primitive(\n\u001b[32m 2271\u001b[39m reshard_p, x, dst_sharding=dst_sharding, concrete_mesh=concrete_mesh)\n\u001b[32m 2272\u001b[39m 
\u001b[38;5;28;01mif\u001b[39;00m concrete_mesh \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2273\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mthunk\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2274\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 2275\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m sharding_impls.set_mesh(concrete_mesh):\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:2270\u001b[39m, in \u001b[36m_reshard_impl..\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 2269\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_reshard_impl\u001b[39m(x, *, dst_sharding, concrete_mesh):\n\u001b[32m-> \u001b[39m\u001b[32m2270\u001b[39m thunk = \u001b[38;5;28;01mlambda\u001b[39;00m: \u001b[43mdispatch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mapply_primitive\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2271\u001b[39m \u001b[43m \u001b[49m\u001b[43mreshard_p\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdst_sharding\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdst_sharding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconcrete_mesh\u001b[49m\u001b[43m=\u001b[49m\u001b[43mconcrete_mesh\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2272\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m concrete_mesh \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 2273\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m thunk()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py:91\u001b[39m, in \u001b[36mapply_primitive\u001b[39m\u001b[34m(prim, *args, **params)\u001b[39m\n\u001b[32m 89\u001b[39m prev = config.disable_jit.swap_local(\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 90\u001b[39m 
\u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m91\u001b[39m outs = \u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 92\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 93\u001b[39m config.disable_jit.set_local(prev)\n", + " \u001b[31m[... skipping hidden 1 frame]\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:257\u001b[39m, in \u001b[36m_cpp_pjit..cache_miss\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 254\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m config.no_tracing.value:\n\u001b[32m 255\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mre-tracing function \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjit_info.fun_sourceinfo\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m for \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 256\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m`jit`, but \u001b[39m\u001b[33m'\u001b[39m\u001b[33mno_tracing\u001b[39m\u001b[33m'\u001b[39m\u001b[33m is set\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m257\u001b[39m p, args_flat = \u001b[43m_trace_for_jit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjit_info\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 258\u001b[39m (outs, out_flat, out_tree, args_flat, jaxpr,\n\u001b[32m 259\u001b[39m executable, pgle_profiler, const_args) = _run_python_pjit(\n\u001b[32m 260\u001b[39m p, args_flat, fun, jit_info, args, kwargs)\n\u001b[32m 262\u001b[39m maybe_fastpath_data = _get_fastpath_data(\n\u001b[32m 263\u001b[39m executable, out_tree, args_flat, out_flat, jaxpr.effects, jaxpr.consts,\n\u001b[32m 264\u001b[39m 
pgle_profiler, const_args)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:544\u001b[39m, in \u001b[36m_trace_for_jit\u001b[39m\u001b[34m(fun, ji, args, kwargs)\u001b[39m\n\u001b[32m 542\u001b[39m jaxpr, out_avals = pe.trace_to_jaxpr(fun, in_type, dbg, qdd_token)\n\u001b[32m 543\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m544\u001b[39m jaxpr, out_avals = \u001b[43mpe\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrace_to_jaxpr\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdbg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mqdd_token\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 546\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m config.debug_key_reuse.value:\n\u001b[32m 547\u001b[39m \u001b[38;5;66;03m# Import here to avoid circular imports\u001b[39;00m\n\u001b[32m 548\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mjax\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mexperimental\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mkey_reuse\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01m_core\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m check_key_reuse_jaxpr \u001b[38;5;66;03m# pytype: disable=import-error\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2412\u001b[39m, in \u001b[36mtrace_to_jaxpr\u001b[39m\u001b[34m(***failed resolving arguments***)\u001b[39m\n\u001b[32m 2410\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m core.set_current_trace(trace):\n\u001b[32m 2411\u001b[39m args, kwargs = in_tracers.unflatten()\n\u001b[32m-> \u001b[39m\u001b[32m2412\u001b[39m ans_pytree = 
\u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2413\u001b[39m debug_info = debug_info.set_result_paths(ans_pytree)\n\u001b[32m 2414\u001b[39m ans = FlatTree.flatten(ans_pytree)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py:106\u001b[39m, in \u001b[36mxla_primitive_callable..prim_fun\u001b[39m\u001b[34m(*args)\u001b[39m\n\u001b[32m 104\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mprim_fun\u001b[39m(*args):\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m config.eager_constant_folding(\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[32m--> \u001b[39m\u001b[32m106\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprim\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:632\u001b[39m, in \u001b[36mPrimitive.bind\u001b[39m\u001b[34m(self, *args, **params)\u001b[39m\n\u001b[32m 630\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mbind\u001b[39m(\u001b[38;5;28mself\u001b[39m, *args, **params):\n\u001b[32m 631\u001b[39m args = args \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.skip_canonicalization \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(canonicalize_value, args)\n\u001b[32m--> \u001b[39m\u001b[32m632\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_true_bind\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:648\u001b[39m, in \u001b[36mPrimitive._true_bind\u001b[39m\u001b[34m(self, *args, **params)\u001b[39m\n\u001b[32m 646\u001b[39m trace_ctx.set_trace(eval_trace)\n\u001b[32m 647\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m648\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprev_trace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 649\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 650\u001b[39m trace_ctx.set_trace(prev_trace)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:660\u001b[39m, in \u001b[36mPrimitive.bind_with_trace\u001b[39m\u001b[34m(self, trace, args, params)\u001b[39m\n\u001b[32m 658\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m set_current_trace(trace):\n\u001b[32m 659\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.to_lojax(*args, **params) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m660\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrace\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprocess_primitive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 
661\u001b[39m trace.process_primitive(\u001b[38;5;28mself\u001b[39m, args, params) \u001b[38;5;66;03m# may raise lojax error\u001b[39;00m\n\u001b[32m 662\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mcouldn\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt apply typeof to args: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00margs\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2031\u001b[39m, in \u001b[36mDynamicJaxprTrace.process_primitive\u001b[39m\u001b[34m(self, primitive, tracers, params)\u001b[39m\n\u001b[32m 2028\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m primitive \u001b[38;5;129;01min\u001b[39;00m custom_staging_rules:\n\u001b[32m 2029\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m custom_staging_rules[primitive](\u001b[38;5;28mself\u001b[39m, source_info, *jaxpr_tracers,\n\u001b[32m 2030\u001b[39m **params)\n\u001b[32m-> \u001b[39m\u001b[32m2031\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mdefault_process_primitive\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2032\u001b[39m \u001b[43m \u001b[49m\u001b[43mprimitive\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjaxpr_tracers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msource_info\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2049\u001b[39m, in \u001b[36mDynamicJaxprTrace.default_process_primitive\u001b[39m\u001b[34m(self, primitive, tracers, params, source_info)\u001b[39m\n\u001b[32m 2047\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 2048\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> 
\u001b[39m\u001b[32m2049\u001b[39m out_avals, effs = \u001b[43m_cached_abstract_eval\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprimitive\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43maval_qdds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2050\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m 2051\u001b[39m \u001b[38;5;66;03m# TODO(phawkins): remove this 3 months after the release of JAX v0.7.\u001b[39;00m\n\u001b[32m 2052\u001b[39m _verify_params_are_hashable(primitive, params)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/util.py:466\u001b[39m, in \u001b[36mmulti_weakref_lru_cache..wrapper\u001b[39m\u001b[34m(*orig_args, **orig_kwargs)\u001b[39m\n\u001b[32m 464\u001b[39m nr_weakrefs = \u001b[38;5;28mlen\u001b[39m(acc_weakrefs)\n\u001b[32m 465\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m nr_weakrefs == \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m466\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcached_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_multi_weakref_placeholder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 467\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43morig_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43morig_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 468\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m nr_weakrefs == \u001b[32m1\u001b[39m:\n\u001b[32m 469\u001b[39m \u001b[38;5;66;03m# Put the single weakref first, and skip the MultiWeakRefCacheKey\u001b[39;00m\n\u001b[32m 470\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m cached_call(acc_weakrefs[\u001b[32m0\u001b[39m],\n\u001b[32m 471\u001b[39m *args, **kwargs)\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/util.py:450\u001b[39m, in \u001b[36mmulti_weakref_lru_cache..cache_miss\u001b[39m\u001b[34m(key, *args, **kwargs)\u001b[39m\n\u001b[32m 447\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m: \u001b[38;5;66;03m# had 1 weakref, we had put it first as the `key`\u001b[39;00m\n\u001b[32m 448\u001b[39m orig_args, orig_kwargs = sentinel_to_referrents(\n\u001b[32m 449\u001b[39m (args, kwargs), \u001b[38;5;28miter\u001b[39m([weakref.ref(key)]), \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m--> \u001b[39m\u001b[32m450\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43morig_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43morig_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1860\u001b[39m, in \u001b[36m_cached_abstract_eval\u001b[39m\u001b[34m(primitive, *aval_qdds, **params)\u001b[39m\n\u001b[32m 1858\u001b[39m \u001b[38;5;129m@multi_weakref_lru_cache\u001b[39m\n\u001b[32m 1859\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_cached_abstract_eval\u001b[39m(primitive: core.Primitive, *aval_qdds, **params):\n\u001b[32m-> \u001b[39m\u001b[32m1860\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprimitive\u001b[49m\u001b[43m.\u001b[49m\u001b[43mabstract_eval\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43maval_qdds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:702\u001b[39m, in \u001b[36m_effect_free_abstract_eval..abstract_eval_\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 701\u001b[39m 
\u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mabstract_eval_\u001b[39m(*args, **kwargs):\n\u001b[32m--> \u001b[39m\u001b[32m702\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mabstract_eval\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m, no_effects\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:2266\u001b[39m, in \u001b[36m_reshard_abstract_eval\u001b[39m\u001b[34m(aval, dst_sharding, concrete_mesh)\u001b[39m\n\u001b[32m 2264\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m aval.sharding == dst_sharding:\n\u001b[32m 2265\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m aval\n\u001b[32m-> \u001b[39m\u001b[32m2266\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43maval\u001b[49m\u001b[43m.\u001b[49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43msharding\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdst_sharding\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:2217\u001b[39m, in \u001b[36mShapedArray.update\u001b[39m\u001b[34m(self, shape, dtype, weak_type, **kwargs)\u001b[39m\n\u001b[32m 2215\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m'\u001b[39m\u001b[33mmemory_space\u001b[39m\u001b[33m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m kwargs:\n\u001b[32m 2216\u001b[39m kwargs[\u001b[33m'\u001b[39m\u001b[33mmemory_space\u001b[39m\u001b[33m'\u001b[39m] = \u001b[38;5;28mself\u001b[39m.memory_space\n\u001b[32m-> \u001b[39m\u001b[32m2217\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mShapedArray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mweak_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:2193\u001b[39m, in \u001b[36mShapedArray.__init__\u001b[39m\u001b[34m(self, shape, dtype, weak_type, sharding, vma, memory_space)\u001b[39m\n\u001b[32m 2191\u001b[39m \u001b[38;5;28mself\u001b[39m.weak_type = weak_type\n\u001b[32m 2192\u001b[39m \u001b[38;5;66;03m# The ShapedArray.sharding.memory_kind is always None; use memory_space.\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m2193\u001b[39m \u001b[38;5;28mself\u001b[39m.sharding = \u001b[43mget_sharding\u001b[49m\u001b[43m(\u001b[49m\u001b[43msharding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2194\u001b[39m \u001b[38;5;66;03m# short for varying_manual_axes. 
See docs at\u001b[39;00m\n\u001b[32m 2195\u001b[39m \u001b[38;5;66;03m# https://docs.jax.dev/en/latest/notebooks/shard_map.html#tracking-how-values-vary-over-manual-mesh-axes-and-check-vma-true\u001b[39;00m\n\u001b[32m 2196\u001b[39m \u001b[38;5;28mself\u001b[39m.vma = get_vma(vma, \u001b[38;5;28mself\u001b[39m.sharding)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/util.py:301\u001b[39m, in \u001b[36mcache..wrap..wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 299\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m config.check_tracer_leaks.value:\n\u001b[32m 300\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m f(*args, **kwargs)\n\u001b[32m--> \u001b[39m\u001b[32m301\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcached\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrace_context\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/util.py:295\u001b[39m, in \u001b[36mcache..wrap..cached\u001b[39m\u001b[34m(_, *args, **kwargs)\u001b[39m\n\u001b[32m 293\u001b[39m \u001b[38;5;129m@functools\u001b[39m.lru_cache(max_size)\n\u001b[32m 294\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcached\u001b[39m(_, *args, **kwargs):\n\u001b[32m--> \u001b[39m\u001b[32m295\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:2139\u001b[39m, in 
\u001b[36mget_sharding\u001b[39m\u001b[34m(sharding, shape)\u001b[39m\n\u001b[32m 2136\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out_s.mesh, mesh_lib.AbstractMesh):\n\u001b[32m 2137\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mMesh of an aval must be an AbstractMesh. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 2138\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mGot \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mout_s.mesh\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(out_s.mesh)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m-> \u001b[39m\u001b[32m2139\u001b[39m \u001b[43m_check_divisibility\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout_s\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2140\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m out_s.memory_kind \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 2141\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m out_s\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:2111\u001b[39m, in \u001b[36m_check_divisibility\u001b[39m\u001b[34m(sharding, shape)\u001b[39m\n\u001b[32m 2109\u001b[39m _, remainder = \u001b[38;5;28mdivmod\u001b[39m(sh, size)\n\u001b[32m 2110\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m remainder != \u001b[32m0\u001b[39m:\n\u001b[32m-> \u001b[39m\u001b[32m2111\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 2112\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mSharding spec \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mspec\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m implies that array axis 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdim\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m is partitioned\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 2113\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msize\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m times, but does not evenly divide the dimension size \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msh\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 2114\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m Got shape: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m and sharding \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msharding\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[31mValueError\u001b[39m: Sharding spec ('model',) implies that array axis 0 is partitioned 4 times, but does not evenly divide the dimension size 2. Got shape: (2, 1024, 1024) and sharding NamedSharding(mesh=AbstractMesh('data': 2, 'model': 4, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None), spec=PartitionSpec('model', None, None))" ] } ], @@ -392,7 +493,7 @@ "\n", " loss, grads = jax.value_and_grad(loss_fn)(model)\n", " optimizer.update(model, grads)\n", - " return model, loss\n", + " return model, optimizer, loss\n", "\n", "\n", "with jax.set_mesh(auto_mesh):\n", @@ -401,11 +502,12 @@ " label = jax.device_put(rngs.normal((8, 1024)), P('data', None))\n", " # Model and optimizer\n", " model = MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0))\n", + " print(model)\n", " optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", "\n", " # The loop\n", " for i in range(5):\n", - " model, loss = train_step(model, optimizer, input, label)\n", + " model, optimizer, loss = train_step(model, optimizer, input, label)\n", " print(loss) # Model (over-)fitting to the labels quickly." 
] }, @@ -420,7 +522,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -455,15 +557,18 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "PartitionSpec(None, None, 'model')\n", - "(2, 1024, 1024)\n" + "ename": "NameError", + "evalue": "name 'model' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01morbax\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mcheckpoint\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mocp\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;66;03m# Save the sharded state.\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m sharded_state = nnx.state(\u001b[43mmodel\u001b[49m)\n\u001b[32m 5\u001b[39m path = ocp.test_utils.erase_and_create_empty(\u001b[33m'\u001b[39m\u001b[33m/tmp/my-checkpoints/\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 6\u001b[39m checkpointer = ocp.StandardCheckpointer()\n", + "\u001b[31mNameError\u001b[39m: name 'model' is not defined" ] } ], @@ -477,8 +582,9 @@ "checkpointer.save(path / 'checkpoint_name', sharded_state)\n", "\n", "# Load a sharded state from the checkpoint.\n", - "graphdef, abs_state = nnx.get_abstract_model(\n", + "abs_model = nnx.eval_shape(\n", " lambda: MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)), auto_mesh)\n", + "graphdef, abs_state = nnx.split(abs_model)\n", "restored_state = checkpointer.restore(path / 'checkpoint_name',\n", " target=abs_state)\n", "restored_model 
= nnx.merge(graphdef, abs_state)\n", @@ -500,7 +606,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -551,7 +657,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -683,7 +789,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index 90a7fade5..f7fecc07f 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -33,6 +33,8 @@ import flax from flax import nnx # Ignore this if you are already running on a TPU or GPU +nnx.set_graph_mode(False) +nnx.set_graph_updates(False) if not jax._src.xla_bridge.backends_are_initialized(): jax.config.update('jax_num_cpu_devices', 8) print(f'You have 8 “fake” JAX devices now: {jax.devices()}') @@ -50,7 +52,6 @@ auto_mesh = jax.make_mesh((2, 4), ('data', 'model')) > Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function. ```{code-cell} ipython3 -nnx.use_eager_sharding(True) assert nnx.using_eager_sharding() ``` @@ -64,7 +65,7 @@ with nnx.use_eager_sharding(False): You can also enable eager sharding on a per-variable basis by passing `eager_sharding=False` during variable initialization. The mesh can also be passed this way. 
```{code-cell} ipython3 -nnx.Param(jnp.ones(4,4), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh) +nnx.Param(jnp.ones((4, 4)), out_sharding=(None, 'model'), eager_sharding=True, mesh=auto_mesh) ``` ## Shard a single-array model @@ -141,15 +142,14 @@ Make note of the following: ```{code-cell} ipython3 class DotReluDot(nnx.Module): def __init__(self, depth: int, rngs: nnx.Rngs): - init_fn = nnx.initializers.lecun_normal() self.dot1 = nnx.Linear( depth, depth, - kernel_init=nnx.with_partitioning(init_fn, (None, 'model')), + kernel_metadata={'out_sharding': (None, 'model')}, use_bias=False, # or use `bias_init` to give it annotation too rngs=rngs) self.w2 = nnx.Param( - init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation - sharding=('model', None), + rngs.params.lecun_normal()((depth, depth)), # RNG key and shape for W2 creation + out_sharding=('model', None), ) def __call__(self, x: jax.Array): @@ -163,7 +163,8 @@ class MultiDotReluDot(nnx.Module): def __init__(self, depth: int, num_layers: int, rngs: nnx.Rngs): # Annotate the additional axis with sharding=None, meaning it will be # replicated across all devices. 
- @nnx.vmap(transform_metadata={nnx.PARTITION_NAME: None}) + @nnx.vmap + @nnx.transform_metadata(partition=None) def create_sublayers(r): return DotReluDot(depth, r) self.layers = create_sublayers(rngs.fork(split=num_layers)) @@ -186,7 +187,7 @@ def train_step(model, optimizer, x, y): loss, grads = jax.value_and_grad(loss_fn)(model) optimizer.update(model, grads) - return model, loss + return model, optimizer, loss with jax.set_mesh(auto_mesh): @@ -195,11 +196,12 @@ with jax.set_mesh(auto_mesh): label = jax.device_put(rngs.normal((8, 1024)), P('data', None)) # Model and optimizer model = MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)) + print(model) optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) # The loop for i in range(5): - model, loss = train_step(model, optimizer, input, label) + model, optimizer, loss = train_step(model, optimizer, input, label) print(loss) # Model (over-)fitting to the labels quickly. ``` @@ -234,8 +236,9 @@ checkpointer = ocp.StandardCheckpointer() checkpointer.save(path / 'checkpoint_name', sharded_state) # Load a sharded state from the checkpoint. -graphdef, abs_state = nnx.get_abstract_model( +abs_model = nnx.eval_shape( lambda: MultiDotReluDot(1024, 2, rngs=nnx.Rngs(0)), auto_mesh) +graphdef, abs_state = nnx.split(abs_model) restored_state = checkpointer.restore(path / 'checkpoint_name', target=abs_state) restored_model = nnx.merge(graphdef, abs_state) diff --git a/docs_nnx/nnx_basics_tree.ipynb b/docs_nnx/nnx_basics_tree.ipynb new file mode 100644 index 000000000..ee63ff443 --- /dev/null +++ b/docs_nnx/nnx_basics_tree.ipynb @@ -0,0 +1,675 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NNX Basics\n", + "\n", + "NNX is a Neural Networks library for JAX. NNX provides the tools to structure modeling code as [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) so it can work with transforms, `jax.tree.*` utilities, and all standard JAX APIs. 
This guide covers the core concepts you need to get started." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from flax import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "nnx.graphlib.set_graph_mode(False)\n", + "nnx.graphlib.set_graph_updates(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "NNX's main build blocks are:\n", + "\n", + "- **`nnx.Pytree`**: Base class for pytree-compatible objects. Defines the tree structure of your model.\n", + "- **`nnx.Variable`**: Wraps array data and tracks mutable state. Subclasses like `nnx.Param` categorize different kinds of state.\n", + "- **State APIs** (`nnx.{state, split, merge, update}`): Extract, partition, reconstruct, and apply state updates.\n", + "- **NNX Transforms** (`nnx.{jit, grad, scan, ...}`): Thin wrappers over JAX transforms that automate state propagation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pytrees and Variables\n", + "\n", + "`nnx.Pytree` and `nnx.Variable` are two orthogonal systems. **Pytrees** define the structure of your model as a JAX-compatible tree. **Variables** wrap array data and enable expressing state updates via in-place mutation. \n", + "\n", + "`Pytree`s are python objects that define its tree structure dynamically through its attributes, these are split into two categories: **Static attributes** (e.g. `int`, `str`) are embedded in the tree structure definition and are not traced by JAX. **Data attributes** (e.g. `nnx.Variable`, `jax.Array`) are the leaves of the tree and are traced by JAX. 
For more details see the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html).\n", + "\n", + "Here's a typical layer definition:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class Count(nnx.Variable): pass # custom Variable types\n", + "\n", + "class Linear(nnx.Pytree):\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.din, self.dout = din, dout # static attributes\n", + " self.w = nnx.Param(rngs.uniform((din, dout))) # data attribute\n", + " self.count = Count(jnp.array(0)) # data attribute\n", + "\n", + " def __call__(self, x: jax.Array):\n", + " self.count[...] += 1 # inplace state updates\n", + " return x @ self.w # Variable are Array-like\n", + "\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "nnx.display(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> **Note:** Most user code uses `nnx.Module`, which is a subclass of `nnx.Pytree` with additional features such as sopport for metric reporting.\n", + "\n", + "As we can see above, Variables are array-like; they support arithmetic operators, indexing, and can be used directly in JAX expressions. You can update their value in-place using `variable[...] = new_value`. Since NNX Pytrees are standard JAX pytrees, you can use `jax.tree.*` functions directly on them:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 5), model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "\n", + "model.w sum: 4.1854\n", + "doubled.w sum: 8.3709\n", + "\n", + "Pytree leaves:\n", + ".count.value: Array(1, dtype=int32, weak_type=True)\n", + ".w.value: Array([[0.8423141 , 0.18237865, 0.2271781 , 0.12072563, 0.19181347],\n", + " [0.722015 , 0.7654456 , 0.15254045, 0.9517063 , 0.02931046]], dtype=float32)\n" + ] + } + ], + "source": [ + "x = jnp.ones((3, 2))\n", + "y = model(x)\n", + "print(f'{y.shape = }, {model.count[...] 
= }')\n", + "\n", + "# jax.tree.map works directly on NNX Pytrees\n", + "doubled_model = jax.tree.map(lambda x: x * 2, model)\n", + "print(f'\\nmodel.w sum: {model.w.sum():.4f}')\n", + "print(f'doubled.w sum: {doubled_model.w.sum():.4f}')\n", + "\n", + "# jax.tree.leaves_with_path shows the full tree structure\n", + "print('\\nPytree leaves:')\n", + "for path, value in jax.tree.leaves_with_path(model):\n", + " print(f'{jax.tree_util.keystr(path)}: {value!r}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here `jax.tree.map` was first used create a new model with each leaf Array doubled, and then `jax.tree.flatten_with_path` was used to show how JAX sees the tree structure. Notice that because Variables are also JAX pytrees containing a single element (their inner value) we see `value` as part of the leaf path." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rngs\n", + "`nnx.Rngs` simplify managing [JAX PRNG state](https://jax.readthedocs.io/en/latest/random-numbers.html). It is itself an `nnx.Pytree` that stores a seed `key` and an incrementing `counter` in `Variable`s internally. 
By calling it, `Rngs` can produce new PRNG keys:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "key1 = Array((), dtype=key) overlaying:\n", + "[1797259609 2579123966]\n", + "key2 = Array((), dtype=key) overlaying:\n", + "[ 928981903 3453687069]\n", + "arr = Array([[ 1.2956359 , 1.3550105 , -0.40960556],\n", + " [-0.77188545, 0.38094172, 0.01888919]], dtype=float32)\n", + "\u001b[38;2;79;201;177mRngs\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # RngState: 2 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mdefault\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngStream\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # RngState: 2 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m,\n", + " \u001b[38;2;156;220;254mkey\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", + " [0 0],\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mcount\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(3, dtype=uint32),\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ 
+ "rngs = nnx.Rngs(0) # seeded with 0\n", + "\n", + "key1 = rngs() # get a raw key\n", + "key2 = rngs() # different key (counter incremented)\n", + "arr = rngs.normal((2, 3)) # draw samples directly\n", + "\n", + "print(f'{key1 = }')\n", + "print(f'{key2 = }')\n", + "print(f'{arr = }')\n", + "print(rngs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we've seen so far, `Rngs` conveniently exposes every `jax.random.*` distribution as a method (e.g. `rngs.uniform(...)`, `rngs.normal(...)`) without requiring the `key` argument and returning different random values every time they are called, this highly simplifies the user experience. In general `Rngs` can hold multiple keys and counters in structures called `RngStream`s, above we see that the `default` stream is being used. For more information check out the [Randomness guide](https://flax.readthedocs.io/en/latest/guides/randomness.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Nested Modules\n", + "\n", + "Pytree subclasses compose naturally, you can assign one as an attribute of another to build nested models. The example below builds a simple `MLP` from two `Linear` layers:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 5)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class MLP(nnx.Pytree):\n", + " def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.din, self.dmid, self.dout = din, dmid, dout # static attributes\n", + " self.linear1 = Linear(din, dmid, rngs=rngs) # data attribute\n", + " self.linear2 = Linear(dmid, dout, rngs=rngs) # data attribute\n", + "\n", + " def __call__(self, x: jax.Array):\n", + " x = nnx.relu(self.linear1(x))\n", + " return self.linear2(x)\n", + "\n", + "mlp = MLP(2, 16, 5, rngs=nnx.Rngs(0))\n", + "y = mlp(jnp.ones((3, 2)))\n", + "print(f'{y.shape = }')\n", + "\n", + "nnx.display(mlp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because the entire model is a single pytree, all the `jax.tree.*` functions, JAX transforms, and NNX state APIs work on the full nested structure at once. For more info check out the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## JAX Transforms\n", + "\n", + "NNX models can be passed directly to JAX transforms like `jax.jit`. 
However, JAX transforms create pure functions, meaning that they won't propagate side effects such as Variable state updates back to the caller:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + } + ], + "source": [ + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "@jax.jit\n", + "def forward(model, x): # pure function\n", + " y = model(x)\n", + " return y\n", + "\n", + "y = forward(model, x)\n", + "\n", + "print(model.count[...]) # no state update" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here `count` was not updated because inside `jax.jit` new Variable copies are created so any updates inside will not be reflected outside. To propagate updates we can use two NNX helpers. `nnx.state(obj, *filters)` extracts the current state of all Variables in `obj` as a nested `State` dict; you can pass **filters** to select specific Variable types, for example `nnx.state(model, Count)` extracts only `Count` Variables (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for details). `nnx.update(obj, state)` writes a `State` back into the corresponding Variables of `obj`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + } + ], + "source": [ + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "@jax.jit\n", + "def forward(model, x):\n", + " y = model(x)\n", + " return y, nnx.state(model, Count) # propagate state\n", + "\n", + "y, updates = forward(model, x)\n", + "nnx.update(model, updates) # apply state updates\n", + "\n", + "print(model.count[...]) # updated successfully" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example we could've also chosen to return the entire `model` and replace its reference outside, however the use `nnx.state/update` is preferred as NNX promotes preserving existing Variable references." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training step with JAX transforms\n", + "\n", + "For a full training step we also need to differentiate with respect to some parameters while keeping the rest non-differentiable. `nnx.split` and `nnx.merge` let us partition and reconstruct the model. `nnx.split(obj, *filters)` returns a structure definition (`GraphDef`) followed by one `State` group per filter, where the catch-all filter `...` matches everything not yet matched by a previous filter (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for the full filter language). `nnx.merge(graphdef, *states)` reconstructs a copy of the object from its definition and state groups. We will use these to select the differentiable parameters when passing them to `jax.grad`.\n", + "\n", + "The example below shows a complete training step using raw JAX transforms. 
`nnx.Optimizer` wraps an [Optax](https://optax.readthedocs.io/) optimizer and stores its internal state as Variables, providing a simple `update(model, grads)` method that performs in-place updates to both the optimizer state and model parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "import optax\n", + "\n", + "x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "@jax.jit\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params, nondiff):\n", + " nondiff = nnx.clone(nondiff) # refresh trace state\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss, nnx.state(model, Count) # propagate state\n", + "\n", + " grads, updates = jax.grad(loss_fn, has_aux=True)(params, nondiff)\n", + " nnx.update(model, updates)\n", + " optimizer.update(model, grads)\n", + "\n", + " return nnx.state((model, optimizer))\n", + "\n", + "updates = train_step(model, optimizer, x, y)\n", + "nnx.update((model, optimizer), updates)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] = }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A few things to note. The state of the `model` and `optimizer` is extracted at once by packing them in a tuple (or any pytree), and `nnx.update` accepts the same structure. By default `jax.grad` differentiates with respect to the first positional argument only, `params` in our case. 
Finally, `nnx.clone` is needed because `jax.grad` passes non differentiable inputs (here `nondiff`) directly without tracing them, so we must manually clone them to refresh the trace state of their Variables - preventing tracer leakage. Omitting `nnx.clone` raises an error." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NNX Transforms\n", + "\n", + "NNX transforms (`nnx.jit`, `nnx.grad`, ...) are thin wrappers over JAX transforms that provide the exact same APIs. Their main feature is **automatic state propagation**: the state of all input Variables is tracked and updated automatically behind the scenes. This removes the need for the `nnx.state/update` boilerplate and the use of `nnx.clone`:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "@nnx.jit # automatic state propagation\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params, nondiff):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss\n", + "\n", + " grads = nnx.grad(loss_fn)(params, nondiff)\n", + " optimizer.update(model, grads)\n", + "\n", + "train_step(model, optimizer, x, y)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] 
= }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that `train_step` doesn't need to return anything as `nnx.jit` propagates all Variable updates (model parameters, optimizer state, counts) automatically." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Graph Mode\n", + "\n", + "Certain programs are easier to express by sharing references between objects on different parts of a structure, however this is not compatible with JAX's pytree model. If we create a simple model that shares a reference to the same Variable in two different attributes, NNX transforms and most other APIs will raise an error as sharing can result in inconsistencies:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error: Variable at [0][0].b was already seen at [0][0].a. tree-mode jit does not support shared Variable references.\n" + ] + } + ], + "source": [ + "@nnx.dataclass\n", + "class Foo(nnx.Module):\n", + " a: nnx.Param\n", + " b: nnx.Param\n", + "\n", + "p = nnx.Param(jnp.array(1.0))\n", + "model = Foo(p, p) # shared Param\n", + "\n", + "@nnx.jit\n", + "def forward(model, x):\n", + " model.a[...] += 1.0\n", + " return model.a * x + model.b\n", + "\n", + "try:\n", + " forward(model, jnp.array(1.0))\n", + "except ValueError as e:\n", + " print(f'Error: {e}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, at the cost of some python overhead, `graph=True` can be passed to NNX APIs to enable **graph mode**. In graph mode, general graph structures are allowed as long as the Variables are transformed consistently. We can fix the above example by enabling graph mode in `nnx.jit`:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y = 6.0, model.a[...] = 3.0, model.b[...] 
= 3.0\n" + ] + } + ], + "source": [ + "@nnx.jit(graph=True)\n", + "def forward(model, x):\n", + " model.a[...] += 1.0\n", + " return model.a * x + model.b\n", + "\n", + "y = forward(model, jnp.array(1.0))\n", + "\n", + "print(f'{y = !s}, {model.a[...] = !s}, {model.b[...] = !s}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hijax (experimental)\n", + "\n", + "JAX's experimental **Hijax** API allows custom mutable types whose state updates propagate automatically through JAX transforms. When enabled via `nnx.var_default(hijax=True)`, plain JAX transforms like `jax.jit` handle state propagation of `Variable`s without any manual `nnx.state` / `nnx.update` calls. As a bonus, in hijax mode Variables can also be passed as captures, further simplifying the loss function:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Count: 1 (4 B), Param: 10 (40 B), Total: 11 (44 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mdin\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m,\n", + " \u001b[38;2;156;220;254mdout\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m5\u001b[0m,\n", + " \u001b[38;2;156;220;254mw\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 10 (40 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m5\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, 
\u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mcount\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mCount\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "with nnx.var_defaults(hijax=True): # enables Hijax Variables\n", + " x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + " model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + " optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "print(model) # display Hijax Variables\n", + "\n", + "@jax.jit # automatic state propagation\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss\n", + "\n", + " grads = jax.grad(loss_fn)(nnx.vars_as(params, hijax=False)) # disable hijax for param grads\n", + " optimizer.update(model, grads)\n", + "\n", + "train_step(model, optimizer, x, y)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] 
= }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a temporary limitation, `jax.grad` does not yet handle mutable Hijax types. We work around this by converting `params` to regular Variables via `nnx.vars_as(params, hijax=False)` before passing them to `grad`. Hijax can also be enabled on a per-Variable basis by passing `hijax=True` to the constructor:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "v[...] = 1\n", + "v[...] = 2\n" + ] + } + ], + "source": [ + "v = nnx.Variable(jnp.array(1), hijax=True)\n", + "\n", + "@jax.jit\n", + "def inc(v):\n", + " v[...] += 1\n", + "\n", + "print(f'{v[...] = !s}')\n", + "inc(v)\n", + "print(f'{v[...] = !s}')" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs_nnx/nnx_basics_tree.md b/docs_nnx/nnx_basics_tree.md new file mode 100644 index 000000000..2bdc6f577 --- /dev/null +++ b/docs_nnx/nnx_basics_tree.md @@ -0,0 +1,319 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# NNX Basics + +NNX is a Neural Networks library for JAX. NNX provides the tools to structure modeling code as [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) so it can work with transforms, `jax.tree.*` utilities, and all standard JAX APIs. This guide covers the core concepts you need to get started. 
+ +```{code-cell} ipython3 +from flax import nnx +import jax +import jax.numpy as jnp + +nnx.graphlib.set_graph_mode(False) +nnx.graphlib.set_graph_updates(False) +``` + +NNX's main building blocks are: + +- **`nnx.Pytree`**: Base class for pytree-compatible objects. Defines the tree structure of your model. +- **`nnx.Variable`**: Wraps array data and tracks mutable state. Subclasses like `nnx.Param` categorize different kinds of state. +- **State APIs** (`nnx.{state, split, merge, update}`): Extract, partition, reconstruct, and apply state updates. +- **NNX Transforms** (`nnx.{jit, grad, scan, ...}`): Thin wrappers over JAX transforms that automate state propagation. + ++++ + +## Pytrees and Variables + +`nnx.Pytree` and `nnx.Variable` are two orthogonal systems. **Pytrees** define the structure of your model as a JAX-compatible tree. **Variables** wrap array data and enable expressing state updates via in-place mutation. + +`Pytree`s are python objects that define their tree structure dynamically through their attributes; these are split into two categories: **Static attributes** (e.g. `int`, `str`) are embedded in the tree structure definition and are not traced by JAX. **Data attributes** (e.g. `nnx.Variable`, `jax.Array`) are the leaves of the tree and are traced by JAX. For more details see the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html). + +Here's a typical layer definition: + +```{code-cell} ipython3 +class Count(nnx.Variable): pass # custom Variable types + +class Linear(nnx.Pytree): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.din, self.dout = din, dout # static attributes + self.w = nnx.Param(rngs.uniform((din, dout))) # data attribute + self.count = Count(jnp.array(0)) # data attribute + + def __call__(self, x: jax.Array): + self.count[...] 
+= 1 # inplace state updates + return x @ self.w # Variables are Array-like + +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +nnx.display(model) +``` + +> **Note:** Most user code uses `nnx.Module`, which is a subclass of `nnx.Pytree` with additional features such as support for metric reporting. + +As we can see above, Variables are array-like; they support arithmetic operators, indexing, and can be used directly in JAX expressions. You can update their value in-place using `variable[...] = new_value`. Since NNX Pytrees are standard JAX pytrees, you can use `jax.tree.*` functions directly on them: + +```{code-cell} ipython3 +x = jnp.ones((3, 2)) +y = model(x) +print(f'{y.shape = }, {model.count[...] = }') + +# jax.tree.map works directly on NNX Pytrees +doubled_model = jax.tree.map(lambda x: x * 2, model) +print(f'\nmodel.w sum: {model.w.sum():.4f}') +print(f'doubled.w sum: {doubled_model.w.sum():.4f}') + +# jax.tree.leaves_with_path shows the full tree structure +print('\nPytree leaves:') +for path, value in jax.tree.leaves_with_path(model): + print(f'{jax.tree_util.keystr(path)}: {value!r}') +``` + +Here `jax.tree.map` was first used to create a new model with each leaf Array doubled, and then `jax.tree.flatten_with_path` was used to show how JAX sees the tree structure. Notice that because Variables are also JAX pytrees containing a single element (their inner value) we see `value` as part of the leaf path. + ++++ + +## Rngs +`nnx.Rngs` simplifies managing [JAX PRNG state](https://jax.readthedocs.io/en/latest/random-numbers.html). It is itself an `nnx.Pytree` that stores a seed `key` and an incrementing `counter` in `Variable`s internally. 
By calling it, `Rngs` can produce new PRNG keys: + +```{code-cell} ipython3 +rngs = nnx.Rngs(0) # seeded with 0 + +key1 = rngs() # get a raw key +key2 = rngs() # different key (counter incremented) +arr = rngs.normal((2, 3)) # draw samples directly + +print(f'{key1 = }') +print(f'{key2 = }') +print(f'{arr = }') +print(rngs) +``` + +As we've seen so far, `Rngs` conveniently exposes every `jax.random.*` distribution as a method (e.g. `rngs.uniform(...)`, `rngs.normal(...)`) without requiring the `key` argument and returning different random values every time they are called, this highly simplifies the user experience. In general `Rngs` can hold multiple keys and counters in structures called `RngStream`s, above we see that the `default` stream is being used. For more information check out the [Randomness guide](https://flax.readthedocs.io/en/latest/guides/randomness.html). + ++++ + +## Nested Modules + +Pytree subclasses compose naturally, you can assign one as an attribute of another to build nested models. The example below builds a simple `MLP` from two `Linear` layers: + +```{code-cell} ipython3 +class MLP(nnx.Pytree): + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.din, self.dmid, self.dout = din, dmid, dout # static attributes + self.linear1 = Linear(din, dmid, rngs=rngs) # data attribute + self.linear2 = Linear(dmid, dout, rngs=rngs) # data attribute + + def __call__(self, x: jax.Array): + x = nnx.relu(self.linear1(x)) + return self.linear2(x) + +mlp = MLP(2, 16, 5, rngs=nnx.Rngs(0)) +y = mlp(jnp.ones((3, 2))) +print(f'{y.shape = }') + +nnx.display(mlp) +``` + +Because the entire model is a single pytree, all the `jax.tree.*` functions, JAX transforms, and NNX state APIs work on the full nested structure at once. For more info check out the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html). + ++++ + +## JAX Transforms + +NNX models can be passed directly to JAX transforms like `jax.jit`. 
However, JAX transforms create pure functions, meaning that they won't propagate side effects such as Variable state updates back to the caller: + +```{code-cell} ipython3 +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +@jax.jit +def forward(model, x): # pure function + y = model(x) + return y + +y = forward(model, x) + +print(model.count[...]) # no state update +``` + +Here `count` was not updated because inside `jax.jit` new Variable copies are created so any updates inside will not be reflected outside. To propagate updates we can use two NNX helpers. `nnx.state(obj, *filters)` extracts the current state of all Variables in `obj` as a nested `State` dict; you can pass **filters** to select specific Variable types, for example `nnx.state(model, Count)` extracts only `Count` Variables (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for details). `nnx.update(obj, state)` writes a `State` back into the corresponding Variables of `obj`. + +```{code-cell} ipython3 +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +@jax.jit +def forward(model, x): + y = model(x) + return y, nnx.state(model, Count) # propagate state + +y, updates = forward(model, x) +nnx.update(model, updates) # apply state updates + +print(model.count[...]) # updated successfully +``` + +In this example we could've also chosen to return the entire `model` and replace its reference outside, however the use `nnx.state/update` is preferred as NNX promotes preserving existing Variable references. + ++++ + +### Training step with JAX transforms + +For a full training step we also need to differentiate with respect to some parameters while keeping the rest non-differentiable. `nnx.split` and `nnx.merge` let us partition and reconstruct the model. 
`nnx.split(obj, *filters)` returns a structure definition (`GraphDef`) followed by one `State` group per filter, where the catch-all filter `...` matches everything not yet matched by a previous filter (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for the full filter language). `nnx.merge(graphdef, *states)` reconstructs a copy of the object from its definition and state groups. We will use these to select the differentiable parameters when passing them to `jax.grad`. + +The example below shows a complete training step using raw JAX transforms. `nnx.Optimizer` wraps an [Optax](https://optax.readthedocs.io/) optimizer and stores its internal state as Variables, providing a simple `update(model, grads)` method that performs in-place updates to both the optimizer state and model parameters: + +```{code-cell} ipython3 +import optax + +x, y = jnp.ones((3, 2)), jnp.ones((3, 5)) +model = Linear(2, 5, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + +@jax.jit +def train_step(model, optimizer, x, y): + # use same filter as Optimizer's `wrt` + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + + def loss_fn(params, nondiff): + nondiff = nnx.clone(nondiff) # refresh trace state + model = nnx.merge(graphdef, params, nondiff) + loss = jnp.mean((model(x) - y) ** 2) + return loss, nnx.state(model, Count) # propagate state + + grads, updates = jax.grad(loss_fn, has_aux=True)(params, nondiff) + nnx.update(model, updates) + optimizer.update(model, grads) + + return nnx.state((model, optimizer)) + +updates = train_step(model, optimizer, x, y) +nnx.update((model, optimizer), updates) + +print(f'{model.count[...] = }') +print(f'{optimizer.step[...] = }') +``` + +A few things to note. The state of the `model` and `optimizer` is extracted at once by packing them in a tuple (or any pytree), and `nnx.update` accepts the same structure. 
By default `jax.grad` differentiates with respect to the first positional argument only, `params` in our case. Finally, `nnx.clone` is needed because `jax.grad` passes non-differentiable inputs (here `nondiff`) directly without tracing them, so we must manually clone them to refresh the trace state of their Variables - preventing tracer leakage. Omitting `nnx.clone` raises an error. + ++++ + +## NNX Transforms + +NNX transforms (`nnx.jit`, `nnx.grad`, ...) are thin wrappers over JAX transforms that provide the exact same APIs. Their main feature is **automatic state propagation**: the state of all input Variables is tracked and updated automatically behind the scenes. This removes the need for the `nnx.state/update` boilerplate and the use of `nnx.clone`: + +```{code-cell} ipython3 +x, y = jnp.ones((3, 2)), jnp.ones((3, 5)) +model = Linear(2, 5, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + +@nnx.jit # automatic state propagation +def train_step(model, optimizer, x, y): + # use same filter as Optimizer's `wrt` + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + + def loss_fn(params, nondiff): + model = nnx.merge(graphdef, params, nondiff) + loss = jnp.mean((model(x) - y) ** 2) + return loss + + grads = nnx.grad(loss_fn)(params, nondiff) + optimizer.update(model, grads) + +train_step(model, optimizer, x, y) + +print(f'{model.count[...] = }') +print(f'{optimizer.step[...] = }') +``` + +Notice that `train_step` doesn't need to return anything as `nnx.jit` propagates all Variable updates (model parameters, optimizer state, counts) automatically. + ++++ + +## Graph Mode + +Certain programs are easier to express by sharing references between objects on different parts of a structure, however this is not compatible with JAX's pytree model. 
If we create a simple model that shares a reference to the same Variable in two different attributes, NNX transforms and most other APIs will raise an error as sharing can result in inconsistencies: + +```{code-cell} ipython3 +@nnx.dataclass +class Foo(nnx.Module): + a: nnx.Param + b: nnx.Param + +p = nnx.Param(jnp.array(1.0)) +model = Foo(p, p) # shared Param + +@nnx.jit +def forward(model, x): + model.a[...] += 1.0 + return model.a * x + model.b + +try: + forward(model, jnp.array(1.0)) +except ValueError as e: + print(f'Error: {e}') +``` + +However, at the cost of some python overhead, `graph=True` can be passed to NNX APIs to enable **graph mode**. In graph mode, general graph structures are allowed as long as the Variables are transformed consistently. We can fix the above example by enabling graph mode in `nnx.jit`: + +```{code-cell} ipython3 +@nnx.jit(graph=True) +def forward(model, x): + model.a[...] += 1.0 + return model.a * x + model.b + +y = forward(model, jnp.array(1.0)) + +print(f'{y = !s}, {model.a[...] = !s}, {model.b[...] = !s}') +``` + +## Hijax (experimental) + +JAX's experimental **Hijax** API allows custom mutable types whose state updates propagate automatically through JAX transforms. When enabled via `nnx.var_defaults(hijax=True)`, plain JAX transforms like `jax.jit` handle state propagation of `Variable`s without any manual `nnx.state` / `nnx.update` calls. As a bonus, in hijax mode Variables can also be passed as captures, further simplifying the loss function: + +```{code-cell} ipython3 +with nnx.var_defaults(hijax=True): # enables Hijax Variables + x, y = jnp.ones((3, 2)), jnp.ones((3, 5)) + model = Linear(2, 5, rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + +print(model) # display Hijax Variables + +@jax.jit # automatic state propagation +def train_step(model, optimizer, x, y): + # use same filter as Optimizer's `wrt` + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) 
+ + def loss_fn(params): + model = nnx.merge(graphdef, params, nondiff) + loss = jnp.mean((model(x) - y) ** 2) + return loss + + grads = jax.grad(loss_fn)(nnx.vars_as(params, hijax=False)) # disable hijax for param grads + optimizer.update(model, grads) + +train_step(model, optimizer, x, y) + +print(f'{model.count[...] = }') +print(f'{optimizer.step[...] = }') +``` + +As a temporary limitation, `jax.grad` does not yet handle mutable Hijax types. We work around this by converting `params` to regular Variables via `nnx.vars_as(params, hijax=False)` before passing them to `grad`. Hijax can also be enabled on a per-Variable basis by passing `hijax=True` to the constructor: + +```{code-cell} ipython3 +v = nnx.Variable(jnp.array(1), hijax=True) + +@jax.jit +def inc(v): + v[...] += 1 + +print(f'{v[...] = !s}') +inc(v) +print(f'{v[...] = !s}') +``` diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 4004b02c0..37e2247af 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -25,7 +25,7 @@ from flax.nnx.pytreelib import Pytree from flax.nnx.variablelib import Variable -M = tp.TypeVar('M', bound=nnx.Module) +M = tp.TypeVar('M') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) class OptState(Variable):