diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 070b2fa77..833fbb9eb 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -433,6 +433,8 @@ def __getitem__(self, name: str): return self._get_stream(name, KeyError) def __getattr__(self, name: str): + if name.startswith('__') and name.endswith('__'): + raise AttributeError(name) return self._get_stream(name, AttributeError) def __call__(self): diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index 1652b6231..6f680796f 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -421,5 +421,15 @@ def test_with_rngs_broadcast(self, graph): self.assertEqual(new_rngs_mapped['dropout'].key.shape, ()) + def test_dunder_raises_attribute_error(self): + rngs = nnx.Rngs(default=42) + with self.assertRaises(AttributeError): + _ = rngs.__pydantic_serializer__ + with self.assertRaises(AttributeError): + _ = rngs.__totally_fake__ + with self.assertRaises(AttributeError): + _ = rngs.__json__ + + if __name__ == '__main__': absltest.main()