diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index 5dde1343f..443e368da 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -188,34 +188,6 @@ class AttrPriority(enum.IntEnum): LOW = 100 -class PriorityStr(str): - _priority: AttrPriority - - def __new__(cls, priority: AttrPriority, value: str): - obj = super().__new__(cls, value) - obj._priority = priority - return obj - - def _check_and_get_priority(self, other) -> AttrPriority: - if not isinstance(other, (str, PriorityStr)): - raise NotImplementedError( - f'Cannot compare {type(self)} with {type(other)}' - ) - if isinstance(other, PriorityStr): - return other._priority - return AttrPriority.DEFAULT - - def __lt__(self, other) -> bool: - other_priority = self._check_and_get_priority(other) - if self._priority == other_priority: - return super().__lt__(other) - return self._priority < other_priority - - def __gt__(self, other) -> bool: - other_priority = self._check_and_get_priority(other) - if self._priority == other_priority: - return super().__gt__(other) - return self._priority > other_priority class ModuleBase: if tp.TYPE_CHECKING: @@ -241,7 +213,7 @@ def _getattr(self, name: str) -> tp.Any: return value def _setattr(self, name: str, value: tp.Any) -> None: - if self.scope is not None: + if getattr(self, 'scope', None) is not None: if name in vars(self) and isinstance( state := vars(self)[name], ModuleState ): @@ -254,11 +226,13 @@ def _setattr(self, name: str, value: tp.Any) -> None: def _graph_node_flatten(self): nodes = vars(self).copy() - keys = ( - PriorityStr(self.attr_priorities.get(k, AttrPriority.DEFAULT), k) - for k in nodes.keys() - ) - sorted_nodes = list((k, nodes[k]) for k in sorted(keys)) + def get_priority(k): + if k in ('scope', '_pytree__state', 'attr_priorities'): + return AttrPriority.HIGH + return self.attr_priorities.get(k, AttrPriority.DEFAULT) + + sorted_keys = sorted(nodes.keys(), key=lambda k: (get_priority(k), k)) + sorted_nodes = list((k, nodes[k]) for k in sorted_keys) return sorted_nodes, type(self) def set_attr_priority(self, name: str, value: AttrPriority):