Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions ignite/handlers/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class EarlyStopping(Serializable, ResettableHandler):

Possible values are "abs" and "rel". Default value is "abs".
mode: Whether to maximize ('max') or minimize ('min') the score. Default is 'max'.
min_evals: Minimum number of evaluations before early stopping can trigger. The handler will
not stop training until it has been called at least ``min_evals`` times, regardless of
the score. Default value is 0 (no warmup period).

Examples:
.. code-block:: python
Expand All @@ -52,15 +55,26 @@ def score_function(engine):
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
evaluator.add_event_handler(Events.COMPLETED, handler)

# With min_evals: don't allow early stopping for the first 5 evaluations
handler = EarlyStopping(
patience=10,
score_function=score_function,
trainer=trainer,
min_evals=5,
)
evaluator.add_event_handler(Events.COMPLETED, handler)

.. versionchanged:: 0.6.0
Added `mode` parameter to support minimization in addition to maximization.
Added `min_delta_mode` parameter to support both absolute and relative improvements.
Added `min_evals` parameter to support a warmup period before early stopping.

"""

_state_dict_all_req_keys = (
"counter",
"best_score",
"_eval_counter",
)

def __init__(
Expand All @@ -72,6 +86,7 @@ def __init__(
cumulative_delta: bool = False,
min_delta_mode: Literal["abs", "rel"] = "abs",
mode: Literal["min", "max"] = "max",
min_evals: int = 0,
):
if not callable(score_function):
raise TypeError("Argument score_function should be a function.")
Expand All @@ -91,6 +106,9 @@ def __init__(
if mode not in ("min", "max"):
raise ValueError("Argument mode should be either 'min' or 'max'.")

if not isinstance(min_evals, int) or min_evals < 0:
raise ValueError("Argument min_evals should be a non-negative integer.")

self.score_function = score_function
self.patience = patience
self.min_delta = min_delta
Expand All @@ -101,12 +119,22 @@ def __init__(
self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
self.min_delta_mode = min_delta_mode
self.mode = mode
self.min_evals = min_evals
self._eval_counter = 0

def __call__(self, engine: Engine) -> None:
self._eval_counter += 1
score = self.score_function(engine)

if self.best_score is None:
self.best_score = score

if self._eval_counter <= self.min_evals:
# Warmup period: track best score but don't enforce patience
if self.mode == "max":
self.best_score = max(score, self.best_score)
else:
self.best_score = min(score, self.best_score)
return

min_delta = -self.min_delta if self.mode == "min" else self.min_delta
Expand All @@ -130,12 +158,13 @@ def __call__(self, engine: Engine) -> None:
self.counter = 0

def reset(self) -> None:
    """Reset the early stopping state, including the counter, eval counter and best score.

    .. versionadded:: 0.6.0
    """
    # Restore every piece of mutable state to its initial value so the handler
    # (including the ``min_evals`` warmup window) starts from scratch.
    self.counter, self.best_score, self._eval_counter = 0, None, 0

def attach( # type: ignore[override]
self,
Expand Down Expand Up @@ -169,17 +198,24 @@ def attach( # type: ignore[override]
target_reset_engine.add_event_handler(reset_event, self.reset)

def state_dict(self) -> "OrderedDict[str, float]":
    """Method returns state dict with ``counter``, ``best_score`` and ``_eval_counter``.
    Can be used to save internal state of the class.
    """
    # Build the mapping incrementally; insertion order defines the key order.
    state: "OrderedDict[str, float]" = OrderedDict()
    state["counter"] = self.counter
    state["best_score"] = cast(float, self.best_score)
    state["_eval_counter"] = self._eval_counter
    return state

def load_state_dict(self, state_dict: Mapping) -> None:
    """Method replace internal state of the class with provided state dict data.

    Args:
        state_dict: a dict with "counter", "best_score" and "_eval_counter" keys/values.
    """
    super().load_state_dict(state_dict)
    # Required keys: missing ones should raise KeyError, so index directly.
    for key in ("counter", "best_score"):
        setattr(self, key, state_dict[key])
    # Checkpoints written before min_evals existed lack this key; default to 0.
    self._eval_counter = state_dict.get("_eval_counter", 0)
165 changes: 165 additions & 0 deletions tests/ignite/handlers/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,3 +582,168 @@ def test_early_stopping_attach_cross_engine():

assert trainer.has_event_handler(handler.reset, Events.STARTED)
assert not evaluator.has_event_handler(handler.reset, Events.STARTED)


def test_args_validation_min_evals():
    trainer = Engine(do_nothing_update_fn)

    # Both a negative int and a non-int must be rejected with the same message.
    for bad_value in (-1, 1.5):
        with pytest.raises(ValueError, match=r"Argument min_evals should be a non-negative integer."):
            EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=trainer, min_evals=bad_value)


def test_min_evals_no_early_stopping_during_warmup():
    """Handler must not terminate training during the min_evals warmup period."""
    # These scores would normally exhaust patience=2 by the third call.
    values = iter([1.0, 0.5, 0.4])

    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(patience=2, score_function=lambda _: next(values), trainer=trainer, min_evals=3)

    assert not trainer.should_terminate
    # All three calls fall inside the warmup window, so no patience check runs.
    for _ in range(3):
        handler(None)
        assert not trainer.should_terminate


def test_min_evals_early_stopping_after_warmup():
    """Handler should enforce patience only after min_evals warmup period."""
    # Three warmup scores followed by a decline that should trigger stopping.
    values = iter([1.0, 1.1, 1.2, 0.5, 0.4])

    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(patience=2, score_function=lambda _: next(values), trainer=trainer, min_evals=3)

    # Warmup evaluations: best_score ends up at 1.2, no patience bookkeeping.
    for _ in range(3):
        handler(None)
    assert not trainer.should_terminate

    handler(None)  # 0.5 is worse than best 1.2, counter -> 1
    assert not trainer.should_terminate
    handler(None)  # 0.4 is worse again, counter -> 2 == patience -> terminate
    assert trainer.should_terminate


def test_min_evals_zero_is_default_behaviour():
    """min_evals=0 (default) should behave identically to not specifying min_evals."""
    it_default = iter([1.0, 0.8, 0.7])
    it_explicit = iter([1.0, 0.8, 0.7])

    trainer_default = Engine(do_nothing_update_fn)
    trainer_explicit = Engine(do_nothing_update_fn)

    handler_default = EarlyStopping(patience=2, score_function=lambda _: next(it_default), trainer=trainer_default)
    handler_explicit = EarlyStopping(
        patience=2, score_function=lambda _: next(it_explicit), trainer=trainer_explicit, min_evals=0
    )

    # Drive both handlers through identical score sequences.
    for _ in range(3):
        handler_default(None)
        handler_explicit(None)

    assert trainer_default.should_terminate == trainer_explicit.should_terminate


def test_min_evals_tracks_best_score_during_warmup():
    """During warmup, handler should still track the best score."""
    values = iter([0.5, 1.5, 1.0, 0.9, 0.8])

    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(patience=2, score_function=lambda _: next(values), trainer=trainer, min_evals=2)

    handler(None)  # warmup call 1: best_score starts at 0.5
    handler(None)  # warmup call 2: best_score raised to the max, 1.5
    assert handler.best_score == 1.5

    handler(None)  # 1.0 is no improvement over 1.5, counter -> 1
    assert not trainer.should_terminate
    handler(None)  # 0.9 is no improvement either, counter -> 2 -> terminate
    assert trainer.should_terminate


def test_min_evals_reset_clears_eval_counter():
    """reset() should clear _eval_counter so warmup restarts."""
    values = iter([1.0, 0.5, 0.4, 1.0, 0.3, 0.2])

    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(patience=2, score_function=lambda _: next(values), trainer=trainer, min_evals=2)

    # Consume the whole warmup window.
    handler(None)
    handler(None)
    assert handler._eval_counter == 2

    handler.reset()
    assert handler._eval_counter == 0
    assert handler.best_score is None
    assert handler.counter == 0

    # After the reset the warmup window applies again, so neither call may stop training.
    for _ in range(2):
        handler(None)
        assert not trainer.should_terminate


def test_min_evals_state_dict():
    """state_dict and load_state_dict should preserve _eval_counter."""
    values = iter([1.0, 0.8, 0.6, 0.4, 0.2])

    trainer = Engine(do_nothing_update_fn)
    source_handler = EarlyStopping(patience=3, score_function=lambda _: next(values), trainer=trainer, min_evals=2)

    # Consume the warmup window on the source handler.
    source_handler(None)
    source_handler(None)

    saved = source_handler.state_dict()
    assert "_eval_counter" in saved
    assert saved["_eval_counter"] == 2

    # Restore the captured state into a fresh handler.
    restored = EarlyStopping(patience=3, score_function=lambda _: next(values), trainer=trainer, min_evals=2)
    restored.load_state_dict(saved)
    assert restored._eval_counter == 2
    assert restored.best_score == source_handler.best_score

    # The warmup is already consumed, so patience bookkeeping resumes immediately.
    restored(None)  # 0.6: no improvement, counter -> 1
    assert not trainer.should_terminate
    restored(None)  # 0.4: counter -> 2
    assert not trainer.should_terminate
    restored(None)  # 0.2: counter -> 3 == patience -> terminate
    assert trainer.should_terminate


def test_with_engine_min_evals():
    """Integration test: min_evals delays early stopping inside a full engine run."""
    # 10 candidate scores: three warmup evaluations, then a decline that stops training.
    values = iter([1.0, 1.1, 1.2, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3])

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)

    early_stopping = EarlyStopping(
        patience=2,
        score_function=lambda _: next(values),
        trainer=trainer,
        min_evals=3,
    )

    n_epochs_seen = [0]

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        n_epochs_seen[0] += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)

    # Epochs 1-3 are warmup (no patience bookkeeping).
    # Epoch 4: 0.9 < best 1.2 -> counter=1; epoch 5: 0.8 -> counter=2 -> stop.
    assert n_epochs_seen[0] == 5
    assert trainer.state.epoch == 5