-
-
Notifications
You must be signed in to change notification settings - Fork 695
feat(early_stopping): rename min_delta→threshold, min_delta_mode→thre… #3619
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ramyars466
wants to merge
12
commits into
pytorch:master
Choose a base branch
from
ramyars466:fix/early-stopping-param-rename
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
82ea602
feat(early_stopping): rename min_delta→threshold, min_delta_mode→thre…
rashpana a08c5bd
fix(early_stopping): ensure deprecated args override before validatio…
rashpana e22856c
feat(early_stopping): rename min_delta→threshold with backward compat…
rashpana c48adc1
docs(early_stopping): document parameter renaming in versionchanged
rashpana 71726fa
test(early_stopping): fully replace deprecated params with new API
rashpana d789996
fix(early_stopping): update error messages to new parameter names
rashpana 79024f3
fix imports after merge conflict
rashpana 2a8d6df
remove deprecated attributes as requested by reviewer
rashpana aea191a
fix: add property aliases for deprecated attributes
rashpana 1a5393e
fix: add property aliases for deprecated attributes
rashpana 17500d4
fix: add backward compatibility with deprecation warnings
rashpana ceb4b02
test: add legacy API tests and complete backward compatibility
rashpana File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||
| from collections import OrderedDict | ||||||||
| from typing import Any, Callable, cast, Mapping, Literal | ||||||||
| from typing import Any, Callable, cast, Mapping, Literal, Optional | ||||||||
| import warnings | ||||||||
|
|
||||||||
| from ignite.base import Serializable, ResettableHandler | ||||||||
| from ignite.engine import Engine, Events | ||||||||
|
|
@@ -17,26 +18,28 @@ class EarlyStopping(Serializable, ResettableHandler): | |||||||
| object, and return a score `float`. An improvement is considered if the score is higher (for ``mode='max'``) | ||||||||
| or lower (for ``mode='min'``). | ||||||||
| trainer: Trainer engine to stop the run if no improvement. | ||||||||
| min_delta: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum | ||||||||
| increase; for ``mode='min'``, it's a minimum decrease. An improvement is only considered if the change | ||||||||
| exceeds the threshold determined by `min_delta` and `min_delta_mode`. | ||||||||
| cumulative_delta: If True, `min_delta` defines the change since the last `patience` reset, otherwise, | ||||||||
| it defines the change after the last event. Default value is False. | ||||||||
| min_delta_mode: Determines whether `min_delta` is an absolute change or a relative change. | ||||||||
| threshold: A minimum change in the score to qualify as an improvement. For ``mode='max'``, it's a minimum | ||||||||
| increase; for ``mode='min'``, it's a minimum decrease. Default is 0.0. | ||||||||
| threshold_mode: Determines whether `threshold` is an absolute change or a relative change. | ||||||||
ramyars466 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||
|
|
||||||||
| - In 'abs' mode: | ||||||||
|
|
||||||||
| - For ``mode='max'``: improvement if score > best_score + min_delta | ||||||||
| - For ``mode='min'``: improvement if score < best_score - min_delta | ||||||||
| - For ``mode='max'``: improvement if score > best_score + threshold | ||||||||
| - For ``mode='min'``: improvement if score < best_score - threshold | ||||||||
|
|
||||||||
| - In 'rel' mode: | ||||||||
|
|
||||||||
| - For ``mode='max'``: improvement if score > best_score * (1 + min_delta) | ||||||||
| - For ``mode='min'``: improvement if score < best_score * (1 - min_delta) | ||||||||
| - For ``mode='max'``: improvement if score > best_score * (1 + threshold) | ||||||||
| - For ``mode='min'``: improvement if score < best_score * (1 - threshold) | ||||||||
|
|
||||||||
| Possible values are "abs" and "rel". Default value is "abs". | ||||||||
| cumulative: If True, `threshold` defines the change since the last `patience` reset, otherwise, | ||||||||
| it defines the change after the last event. Default value is False. | ||||||||
| mode: Whether to maximize ('max') or minimize ('min') the score. Default is 'max'. | ||||||||
|
|
||||||||
| # Deprecated args for backward compatibility (will be removed in future) | ||||||||
| min_delta: *Deprecated: use `threshold` instead* | ||||||||
| min_delta_mode: *Deprecated: use `threshold_mode` instead* | ||||||||
| cumulative_delta: *Deprecated: use `cumulative` instead* | ||||||||
|
|
||||||||
| Examples: | ||||||||
| .. code-block:: python | ||||||||
|
|
||||||||
|
|
@@ -53,8 +56,7 @@ def score_function(engine): | |||||||
|
|
||||||||
| .. versionchanged:: 0.6.0 | ||||||||
| Added `mode` parameter to support minimization in addition to maximization. | ||||||||
| Added `min_delta_mode` parameter to support both absolute and relative improvements. | ||||||||
|
|
||||||||
| Added `threshold`/`threshold_mode` parameters to support both absolute and relative improvements. | ||||||||
ramyars466 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
ramyars466 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||
| """ | ||||||||
|
|
||||||||
| _state_dict_all_req_keys = ( | ||||||||
|
|
@@ -67,38 +69,72 @@ def __init__( | |||||||
| patience: int, | ||||||||
| score_function: Callable, | ||||||||
| trainer: Engine, | ||||||||
| min_delta: float = 0.0, | ||||||||
| cumulative_delta: bool = False, | ||||||||
| min_delta_mode: Literal["abs", "rel"] = "abs", | ||||||||
| threshold: float = 0.0, | ||||||||
| threshold_mode: Literal["abs", "rel"] = "abs", | ||||||||
| cumulative: bool = False, | ||||||||
| mode: Literal["min", "max"] = "max", | ||||||||
| # Deprecated args for BC | ||||||||
| min_delta: Optional[float] = None, | ||||||||
| min_delta_mode: Optional[Literal["abs", "rel"]] = None, | ||||||||
ramyars466 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||
| cumulative_delta: Optional[bool] = None, | ||||||||
| ): | ||||||||
| if not callable(score_function): | ||||||||
| raise TypeError("Argument score_function should be a function.") | ||||||||
|
|
||||||||
| if patience < 1: | ||||||||
| raise ValueError("Argument patience should be positive integer.") | ||||||||
|
|
||||||||
| if min_delta < 0.0: | ||||||||
| raise ValueError("Argument min_delta should not be a negative number.") | ||||||||
|
|
||||||||
| if not isinstance(trainer, Engine): | ||||||||
| raise TypeError("Argument trainer should be an instance of Engine.") | ||||||||
|
|
||||||||
| if min_delta_mode not in ("abs", "rel"): | ||||||||
| # Backward compatibility for deprecated args | ||||||||
| if min_delta is not None: | ||||||||
| warnings.warn( | ||||||||
| "'min_delta' is deprecated and will be removed in a future version. " "Please use 'threshold' instead.", | ||||||||
| DeprecationWarning, | ||||||||
| stacklevel=2, | ||||||||
| ) | ||||||||
| threshold = min_delta | ||||||||
|
|
||||||||
| if min_delta_mode is not None: | ||||||||
| warnings.warn( | ||||||||
| "'min_delta_mode' is deprecated and will be removed in a future version. " | ||||||||
| "Please use 'threshold_mode' instead.", | ||||||||
| DeprecationWarning, | ||||||||
| stacklevel=2, | ||||||||
| ) | ||||||||
| threshold_mode = min_delta_mode | ||||||||
|
|
||||||||
| if cumulative_delta is not None: | ||||||||
| warnings.warn( | ||||||||
| "'cumulative_delta' is deprecated and will be removed in a future version. " | ||||||||
| "Please use 'cumulative' instead.", | ||||||||
| DeprecationWarning, | ||||||||
| stacklevel=2, | ||||||||
| ) | ||||||||
| cumulative = cumulative_delta | ||||||||
|
|
||||||||
| if threshold < 0.0: | ||||||||
| raise ValueError("Argument min_delta should not be a negative number.") | ||||||||
ramyars466 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||
|
|
||||||||
| if threshold_mode not in ("abs", "rel"): | ||||||||
| raise ValueError("Argument min_delta_mode should be either 'abs' or 'rel'.") | ||||||||
ramyars466 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||
|
|
||||||||
| if mode not in ("min", "max"): | ||||||||
| raise ValueError("Argument mode should be either 'min' or 'max'.") | ||||||||
|
|
||||||||
| self.score_function = score_function | ||||||||
| self.patience = patience | ||||||||
| self.min_delta = min_delta | ||||||||
| self.cumulative_delta = cumulative_delta | ||||||||
| self.threshold = threshold | ||||||||
| self.threshold_mode = threshold_mode | ||||||||
| self.cumulative = cumulative | ||||||||
ramyars466 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
| self.trainer = trainer | ||||||||
| self.counter = 0 | ||||||||
| self.best_score: float | None = None | ||||||||
| self.logger = setup_logger(__name__ + "." + self.__class__.__name__) | ||||||||
| self.min_delta_mode = min_delta_mode | ||||||||
| self.min_delta = threshold | ||||||||
| self.min_delta_mode = threshold_mode | ||||||||
| self.cumulative_delta = cumulative | ||||||||
|
||||||||
| self.min_delta = threshold | |
| self.min_delta_mode = threshold_mode | |
| self.cumulative_delta = cumulative |
Collaborator
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ramyars466 please address this comment
ramyars466 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
ramyars466 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.