diff --git a/ignite/handlers/early_stopping.py b/ignite/handlers/early_stopping.py
index 2e7ab19edcec..701269cb3472 100644
--- a/ignite/handlers/early_stopping.py
+++ b/ignite/handlers/early_stopping.py
@@ -37,6 +37,9 @@ class EarlyStopping(Serializable, ResettableHandler):
             Possible values are "abs" and "rel". Default value is "abs".
         mode: Whether to maximize ('max') or minimize ('min') the score.
             Default is 'max'.
+        min_evals: Minimum number of evaluations before early stopping can trigger. The handler
+            will not stop training during the first ``min_evals`` calls, regardless of the score.
+            Default value is 0 (no warmup period).
 
     Examples:
         .. code-block:: python
@@ -52,15 +55,26 @@ def score_function(engine):
             # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
             evaluator.add_event_handler(Events.COMPLETED, handler)
 
+            # With min_evals: don't allow early stopping for the first 5 evaluations
+            handler = EarlyStopping(
+                patience=10,
+                score_function=score_function,
+                trainer=trainer,
+                min_evals=5,
+            )
+            evaluator.add_event_handler(Events.COMPLETED, handler)
+
 
     .. versionchanged:: 0.6.0
         Added `mode` parameter to support minimization in addition to maximization.
         Added `min_delta_mode` parameter to support both absolute and relative improvements.
+        Added `min_evals` parameter to support a warmup period before early stopping.
     """
 
     _state_dict_all_req_keys = (
         "counter",
         "best_score",
+        "_eval_counter",
     )
 
     def __init__(
@@ -72,6 +86,7 @@ def __init__(
         cumulative_delta: bool = False,
         min_delta_mode: Literal["abs", "rel"] = "abs",
         mode: Literal["min", "max"] = "max",
+        min_evals: int = 0,
     ):
         if not callable(score_function):
             raise TypeError("Argument score_function should be a function.")
@@ -91,6 +106,9 @@ def __init__(
         if mode not in ("min", "max"):
             raise ValueError("Argument mode should be either 'min' or 'max'.")
 
+        if not isinstance(min_evals, int) or min_evals < 0:
+            raise ValueError("Argument min_evals should be a non-negative integer.")
+
         self.score_function = score_function
         self.patience = patience
         self.min_delta = min_delta
@@ -101,12 +119,23 @@ def __init__(
         self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
         self.min_delta_mode = min_delta_mode
         self.mode = mode
+        self.min_evals = min_evals
+        self._eval_counter = 0
 
     def __call__(self, engine: Engine) -> None:
+        self._eval_counter += 1
         score = self.score_function(engine)
 
         if self.best_score is None:
             self.best_score = score
             return
+
+        if self._eval_counter <= self.min_evals:
+            # Warmup period: track the best score but don't enforce patience
+            if self.mode == "max":
+                self.best_score = max(score, self.best_score)
+            else:
+                self.best_score = min(score, self.best_score)
+            return
 
         min_delta = -self.min_delta if self.mode == "min" else self.min_delta
@@ -130,12 +159,13 @@ def __call__(self, engine: Engine) -> None:
             self.counter = 0
 
     def reset(self) -> None:
-        """Reset the early stopping state, including the counter and best score.
+        """Reset the early stopping state, including the counter, eval counter and best score.
 
         .. versionadded:: 0.6.0
         """
         self.counter = 0
         self.best_score = None
+        self._eval_counter = 0
 
     def attach(  # type: ignore[override]
         self,
@@ -169,17 +199,24 @@ def attach(  # type: ignore[override]
         target_reset_engine.add_event_handler(reset_event, self.reset)
 
     def state_dict(self) -> "OrderedDict[str, float]":
-        """Method returns state dict with ``counter`` and ``best_score``.
+        """Method returns state dict with ``counter``, ``best_score`` and ``_eval_counter``.
         Can be used to save internal state of the class.
         """
-        return OrderedDict([("counter", self.counter), ("best_score", cast(float, self.best_score))])
+        return OrderedDict(
+            [
+                ("counter", self.counter),
+                ("best_score", cast(float, self.best_score)),
+                ("_eval_counter", self._eval_counter),
+            ]
+        )
 
     def load_state_dict(self, state_dict: Mapping) -> None:
         """Method replace internal state of the class with provided state dict data.
 
         Args:
-            state_dict: a dict with "counter" and "best_score" keys/values.
+            state_dict: a dict with "counter", "best_score" and "_eval_counter" keys/values.
         """
         super().load_state_dict(state_dict)
         self.counter = state_dict["counter"]
         self.best_score = state_dict["best_score"]
+        self._eval_counter = state_dict.get("_eval_counter", 0)
diff --git a/tests/ignite/handlers/test_early_stopping.py b/tests/ignite/handlers/test_early_stopping.py
index b43fa75d7ad6..fa4dab663d4c 100644
--- a/tests/ignite/handlers/test_early_stopping.py
+++ b/tests/ignite/handlers/test_early_stopping.py
@@ -582,3 +582,168 @@ def test_early_stopping_attach_cross_engine():
 
     assert trainer.has_event_handler(handler.reset, Events.STARTED)
     assert not evaluator.has_event_handler(handler.reset, Events.STARTED)
+
+
+def test_args_validation_min_evals():
+    trainer = Engine(do_nothing_update_fn)
+
+    with pytest.raises(ValueError, match=r"Argument min_evals should be a non-negative integer."):
+        EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=trainer, min_evals=-1)
+
+    with pytest.raises(ValueError, match=r"Argument min_evals should be a non-negative integer."):
+        EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=trainer, min_evals=1.5)
+
+
+def test_min_evals_no_early_stopping_during_warmup():
+    """Handler must not terminate training during the min_evals warmup period."""
+    # Scores that would normally trigger early stopping after 2 calls
+    scores = iter([1.0, 0.5, 0.4])
+
+    trainer = Engine(do_nothing_update_fn)
+    h = EarlyStopping(patience=2, score_function=lambda _: next(scores), trainer=trainer, min_evals=3)
+
+    assert not trainer.should_terminate
+    h(None)  # eval 1 of 3: warmup, no patience check
+    assert not trainer.should_terminate
+    h(None)  # eval 2 of 3: warmup, no patience check
+    assert not trainer.should_terminate
+    h(None)  # eval 3 of 3: still in warmup, no patience check
+    assert not trainer.should_terminate
+
+
+def test_min_evals_early_stopping_after_warmup():
+    """Handler should enforce patience only after the min_evals warmup period."""
+    # Scores: warmup evals first, then declining scores that trigger stopping
+    scores = iter([1.0, 1.1, 1.2, 0.5, 0.4])
+
+    trainer = Engine(do_nothing_update_fn)
+    h = EarlyStopping(patience=2, score_function=lambda _: next(scores), trainer=trainer, min_evals=3)
+
+    h(None)  # warmup 1
+    h(None)  # warmup 2
+    h(None)  # warmup 3: end of warmup, best_score=1.2
+    assert not trainer.should_terminate
+
+    h(None)  # score=0.5, no improvement, counter=1
+    assert not trainer.should_terminate
+    h(None)  # score=0.4, no improvement, counter=2 → terminate
+    assert trainer.should_terminate
+
+
+def test_min_evals_zero_is_default_behaviour():
+    """min_evals=0 (default) should behave identically to not specifying min_evals."""
+    scores_a = iter([1.0, 0.8, 0.7])
+    scores_b = iter([1.0, 0.8, 0.7])
+
+    trainer_a = Engine(do_nothing_update_fn)
+    trainer_b = Engine(do_nothing_update_fn)
+
+    h_default = EarlyStopping(patience=2, score_function=lambda _: next(scores_a), trainer=trainer_a)
+    h_zero = EarlyStopping(patience=2, score_function=lambda _: next(scores_b), trainer=trainer_b, min_evals=0)
+
+    for _ in range(3):
+        h_default(None)
+        h_zero(None)
+
+    assert trainer_a.should_terminate == trainer_b.should_terminate
+
+
+def test_min_evals_tracks_best_score_during_warmup():
+    """During warmup, the handler should still track the best score."""
+    scores = iter([0.5, 1.5, 1.0, 0.9, 0.8])
+
+    trainer = Engine(do_nothing_update_fn)
+    h = EarlyStopping(patience=2, score_function=lambda _: next(scores), trainer=trainer, min_evals=2)
+
+    h(None)  # warmup 1, best_score=0.5
+    h(None)  # warmup 2, best_score=1.5 (updated to max)
+    assert h.best_score == 1.5
+
+    h(None)  # score=1.0, no improvement vs 1.5, counter=1
+    assert not trainer.should_terminate
+    h(None)  # score=0.9, counter=2 → terminate
+    assert trainer.should_terminate
+
+
+def test_min_evals_reset_clears_eval_counter():
+    """reset() should clear _eval_counter so the warmup restarts."""
+    scores = iter([1.0, 0.5, 0.4, 1.0, 0.3, 0.2])
+
+    trainer = Engine(do_nothing_update_fn)
+    h = EarlyStopping(patience=2, score_function=lambda _: next(scores), trainer=trainer, min_evals=2)
+
+    h(None)  # warmup 1
+    h(None)  # warmup 2: end of warmup
+    assert h._eval_counter == 2
+
+    h.reset()
+    assert h._eval_counter == 0
+    assert h.best_score is None
+    assert h.counter == 0
+
+    # After reset, the warmup should restart
+    h(None)  # warmup 1 again
+    assert not trainer.should_terminate
+    h(None)  # warmup 2 again: still in warmup
+    assert not trainer.should_terminate
+
+
+def test_min_evals_state_dict():
+    """state_dict and load_state_dict should preserve _eval_counter."""
+    scores = iter([1.0, 0.8, 0.6, 0.4, 0.2])
+
+    trainer = Engine(do_nothing_update_fn)
+    h = EarlyStopping(patience=3, score_function=lambda _: next(scores), trainer=trainer, min_evals=2)
+
+    h(None)  # warmup 1
+    h(None)  # warmup 2
+
+    state = h.state_dict()
+    assert "_eval_counter" in state
+    assert state["_eval_counter"] == 2
+
+    # Restore to a new handler
+    h2 = EarlyStopping(patience=3, score_function=lambda _: next(scores), trainer=trainer, min_evals=2)
+    h2.load_state_dict(state)
+    assert h2._eval_counter == 2
+    assert h2.best_score == h.best_score
+
+    # Should now enforce patience (warmup is over)
+    h2(None)  # score=0.6, no improvement, counter=1
+    assert not trainer.should_terminate
+    h2(None)  # score=0.4, counter=2
+    assert not trainer.should_terminate
+    h2(None)  # score=0.2, counter=3 → terminate
+    assert trainer.should_terminate
+
+
+def test_with_engine_min_evals():
+    """Integration test: min_evals delays early stopping inside a full engine run."""
+    # 10 epochs: warmup for 3, then declining scores stop the run at epoch 5
+    scores = iter([1.0, 1.1, 1.2, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3])
+
+    trainer = Engine(do_nothing_update_fn)
+    evaluator = Engine(do_nothing_update_fn)
+
+    early_stopping = EarlyStopping(
+        patience=2,
+        score_function=lambda _: next(scores),
+        trainer=trainer,
+        min_evals=3,
+    )
+
+    epoch_count = [0]
+
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def evaluation(engine):
+        evaluator.run([0])
+        epoch_count[0] += 1
+
+    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
+    trainer.run([0], max_epochs=10)
+
+    # Warmup: epochs 1-3 (no patience enforced)
+    # Epoch 4: score=0.9 < best=1.2, counter=1
+    # Epoch 5: score=0.8, counter=2 → stop
+    assert epoch_count[0] == 5
+    assert trainer.state.epoch == 5
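
Reviewer note: the sketch below is a minimal, self-contained rendering of the warmup semantics this PR adds, for sanity-checking outside ignite. MiniEarlyStopper is a hypothetical stand-in, not part of the ignite API; it mirrors only the _eval_counter/min_evals logic under review, assuming a plain best-score comparison with min_delta fixed at 0.

# Minimal, self-contained sketch of the min_evals warmup semantics.
# MiniEarlyStopper is a hypothetical stand-in for the patched EarlyStopping:
# it reproduces only the counter/warmup logic, with min_delta fixed at 0.
class MiniEarlyStopper:
    def __init__(self, patience: int, min_evals: int = 0, mode: str = "max"):
        self.patience = patience
        self.min_evals = min_evals
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self._eval_counter = 0
        self.should_stop = False

    def __call__(self, score: float) -> None:
        self._eval_counter += 1
        if self.best_score is None:
            # First evaluation only initializes the best score.
            self.best_score = score
            return
        if self._eval_counter <= self.min_evals:
            # Warmup: keep tracking the best score, never trip patience.
            if self.mode == "max":
                self.best_score = max(score, self.best_score)
            else:
                self.best_score = min(score, self.best_score)
            return
        improved = score > self.best_score if self.mode == "max" else score < self.best_score
        if improved:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

if __name__ == "__main__":
    # Declining scores during evals 2-3 are absorbed by the warmup.
    stopper = MiniEarlyStopper(patience=2, min_evals=3)
    for score in [1.0, 0.5, 0.4, 0.45, 0.35]:
        stopper(score)
        print(stopper._eval_counter, stopper.counter, stopper.should_stop)
    # -> 1 0 False / 2 0 False / 3 0 False / 4 1 False / 5 2 True

Note the `<=` comparison in the warmup check: with min_evals=3, the patience counter can first move on evaluation 4, which is exactly what the tests above assert.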