Rename SDS(vma: frozenset = ...) to SDS(manual_type: jax.sharding.ManualAxisType = ...)
#5351
+2
−1