diff --git a/flax/configurations.py b/flax/configurations.py index 183b6c5ba..ad28db4d4 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -75,19 +75,20 @@ def __repr__(self): return f'Config({values_repr}\n)' @contextmanager - def temp_flip_flag(self, var_name: str, var_value: bool): + def temp_flip_flag(self, var_name: str, var_value: bool, prefix='flax'): """Context manager to temporarily flip feature flags for test functions. Args: - var_name: the config variable name (without the 'flax_' prefix) + var_name: the config variable name (without its prefix like 'flax_') var_value: the boolean value to set var_name to temporarily + prefix: the prefix of the config variable name (default: 'flax') """ - old_value = getattr(self, f'flax_{var_name}') + old_value = getattr(self, prefix + '_' + var_name) try: - self.update(f'flax_{var_name}', var_value) + self.update(prefix + '_' + var_name, var_value) yield finally: - self.update(f'flax_{var_name}', old_value) + self.update(prefix + '_' + var_name, old_value) config = Config() @@ -307,4 +308,4 @@ def static_int_env(varname: str, default: int | None) -> int | None: name='nnx_graph_updates', default=True, help='Whether graph-mode uses dynamic (True) or simple (False) graph traversal.', -) \ No newline at end of file +)