diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py index 09a00ba0f..f693a9214 100644 --- a/flax/core/axes_scan.py +++ b/flax/core/axes_scan.py @@ -52,7 +52,8 @@ def build_shaped_array(x, batch_dim: bool = False) -> core.ShapedArray: shape=shape, dtype=jnp.result_type(x), sharding=sharding, - **{k: getattr(x, k) for k in ["weak_type", "vma"] if hasattr(x, k)}, + **{k: getattr(x, k) for k in ["weak_type", "manual_type"] + if hasattr(x, k)}, )