Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 64 additions & 28 deletions ignite/handlers/early_stopping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import OrderedDict
from typing import Any, Callable, cast, Mapping, Literal
from typing import Any, Callable, cast, Mapping, Literal, Optional
import warnings

from ignite.base import Serializable, ResettableHandler
from ignite.engine import Engine, Events
Expand All @@ -17,26 +18,28 @@ class EarlyStopping(Serializable, ResettableHandler):
object, and return a score `float`. An improvement is considered if the score is higher (for ``mode='max'``)
or lower (for ``mode='min'``).
trainer: Trainer engine to stop the run if no improvement.
min_delta: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum
increase; for ``mode='min'``, it's a minimum decrease. An improvement is only considered if the change
exceeds the threshold determined by `min_delta` and `min_delta_mode`.
cumulative_delta: If True, `min_delta` defines the change since the last `patience` reset, otherwise,
it defines the change after the last event. Default value is False.
min_delta_mode: Determines whether `min_delta` is an absolute change or a relative change.
threshold: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum
increase; for ``mode='min'``, it's a minimum decrease. Default is 0.0.
threshold_mode: Determines whether `threshold` is an absolute change or a relative change.

- In 'abs' mode:

- For ``mode='max'``: improvement if score > best_score + min_delta
- For ``mode='min'``: improvement if score < best_score - min_delta
- For ``mode='max'``: improvement if score > best_score + threshold
- For ``mode='min'``: improvement if score < best_score - threshold

- In 'rel' mode:

- For ``mode='max'``: improvement if score > best_score * (1 + min_delta)
- For ``mode='min'``: improvement if score < best_score * (1 - min_delta)
- For ``mode='max'``: improvement if score > best_score * (1 + threshold)
- For ``mode='min'``: improvement if score < best_score * (1 - threshold)

Possible values are "abs" and "rel". Default value is "abs".
cumulative: If True, `threshold` defines the change since the last `patience` reset, otherwise,
it defines the change after the last event. Default value is False.
mode: Whether to maximize ('max') or minimize ('min') the score. Default is 'max'.

# Deprecated args for backward compatibility (will be removed in future)
min_delta: *Deprecated: use `threshold` instead*
min_delta_mode: *Deprecated: use `threshold_mode` instead*
cumulative_delta: *Deprecated: use `cumulative` instead*

Examples:
.. code-block:: python

Expand All @@ -53,8 +56,7 @@ def score_function(engine):

.. versionchanged:: 0.6.0
Added `mode` parameter to support minimization in addition to maximization.
Added `min_delta_mode` parameter to support both absolute and relative improvements.

Added `threshold`/`threshold_mode` parameters to support both absolute and relative improvements.
"""

_state_dict_all_req_keys = (
Expand All @@ -67,38 +69,72 @@ def __init__(
patience: int,
score_function: Callable,
trainer: Engine,
min_delta: float = 0.0,
cumulative_delta: bool = False,
min_delta_mode: Literal["abs", "rel"] = "abs",
threshold: float = 0.0,
threshold_mode: Literal["abs", "rel"] = "abs",
cumulative: bool = False,
mode: Literal["min", "max"] = "max",
# Deprecated args for BC
min_delta: Optional[float] = None,
min_delta_mode: Optional[Literal["abs", "rel"]] = None,
cumulative_delta: Optional[bool] = None,
):
if not callable(score_function):
raise TypeError("Argument score_function should be a function.")

if patience < 1:
raise ValueError("Argument patience should be positive integer.")

if min_delta < 0.0:
raise ValueError("Argument min_delta should not be a negative number.")

if not isinstance(trainer, Engine):
raise TypeError("Argument trainer should be an instance of Engine.")

if min_delta_mode not in ("abs", "rel"):
# Backward compatibility for deprecated args
if min_delta is not None:
warnings.warn(
"'min_delta' is deprecated and will be removed in a future version. " "Please use 'threshold' instead.",
DeprecationWarning,
stacklevel=2,
)
threshold = min_delta

if min_delta_mode is not None:
warnings.warn(
"'min_delta_mode' is deprecated and will be removed in a future version. "
"Please use 'threshold_mode' instead.",
DeprecationWarning,
stacklevel=2,
)
threshold_mode = min_delta_mode

if cumulative_delta is not None:
warnings.warn(
"'cumulative_delta' is deprecated and will be removed in a future version. "
"Please use 'cumulative' instead.",
DeprecationWarning,
stacklevel=2,
)
cumulative = cumulative_delta

if threshold < 0.0:
raise ValueError("Argument min_delta should not be a negative number.")

if threshold_mode not in ("abs", "rel"):
raise ValueError("Argument min_delta_mode should be either 'abs' or 'rel'.")

if mode not in ("min", "max"):
raise ValueError("Argument mode should be either 'min' or 'max'.")

self.score_function = score_function
self.patience = patience
self.min_delta = min_delta
self.cumulative_delta = cumulative_delta
self.threshold = threshold
self.threshold_mode = threshold_mode
self.cumulative = cumulative
self.trainer = trainer
self.counter = 0
self.best_score: float | None = None
self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
self.min_delta_mode = min_delta_mode
self.min_delta = threshold
self.min_delta_mode = threshold_mode
self.cumulative_delta = cumulative
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.min_delta = threshold
self.min_delta_mode = threshold_mode
self.cumulative_delta = cumulative

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ramyars466 please address this comment

self.mode = mode

def __call__(self, engine: Engine) -> None:
Expand All @@ -108,16 +144,16 @@ def __call__(self, engine: Engine) -> None:
self.best_score = score
return

min_delta = -self.min_delta if self.mode == "min" else self.min_delta
if self.min_delta_mode == "abs":
min_delta = -self.threshold if self.mode == "min" else self.threshold
if self.threshold_mode == "abs":
improvement_threshold = self.best_score + min_delta
else:
improvement_threshold = self.best_score * (1 + min_delta)

no_improvement = score <= improvement_threshold if self.mode == "max" else score >= improvement_threshold

if no_improvement:
if not self.cumulative_delta:
if not self.cumulative:
self.best_score = max(score, self.best_score) if self.mode == "max" else min(score, self.best_score)
self.counter += 1
self.logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience))
Expand Down