diff --git a/ignite/base/mixins.py b/ignite/base/mixins.py index 03b477dc3a6c..661fe566c6e9 100644 --- a/ignite/base/mixins.py +++ b/ignite/base/mixins.py @@ -29,7 +29,7 @@ def attach(self, engine: "Engine", *args: Any, **kwargs: Any) -> None: class Serializable: _state_dict_all_req_keys: tuple = () - _state_dict_one_of_opt_keys: tuple = () + _state_dict_one_of_opt_keys: tuple[tuple[str, ...], ...] = () def state_dict(self) -> OrderedDict: raise NotImplementedError @@ -43,6 +43,17 @@ def load_state_dict(self, state_dict: Mapping) -> None: raise ValueError( f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" ) - opts = [k in state_dict for k in self._state_dict_one_of_opt_keys] - if len(opts) > 0 and ((not any(opts)) or (all(opts))): - raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys") + + # Handle groups of one-of optional keys + for one_of_opt_keys in self._state_dict_one_of_opt_keys: + if len(one_of_opt_keys) == 0: + raise ValueError( + f"Empty group found in '{self.__class__.__name__}._state_dict_one_of_opt_keys'. " + "Each group must contain at least one state attribute key." 
+ ) + opts = [k in state_dict for k in one_of_opt_keys] + num_present = sum(opts) + if num_present == 0: + raise ValueError(f"state_dict should contain at least one of '{one_of_opt_keys}' keys") + if num_present > 1: + raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys") diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 5f97d54de4a6..9cd055f7791a 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -128,8 +128,8 @@ def compute_mean_std(engine, batch): """ - _state_dict_all_req_keys = ("epoch_length", "max_epochs") - _state_dict_one_of_opt_keys = ("iteration", "epoch") + _state_dict_all_req_keys = ("epoch_length",) + _state_dict_one_of_opt_keys = (("iteration", "epoch"), ("max_epochs", "max_iters")) # Flag to disable engine._internal_run as generator feature for BC interrupt_resume_enabled = True @@ -707,15 +707,25 @@ def save_engine(_): OrderedDict: a dictionary containing engine's state + .. versionchanged:: 0.5.5 + Added support for serializing ``max_iters``. + """ - keys: tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) + keys: tuple[str, ...] = self._state_dict_all_req_keys + # We add iteration by default to get exact measure of progress + keys += ("iteration",) + # Include either max_epochs or max_iters based on which was originally set + if self.state.max_iters is not None: + keys += ("max_iters",) + else: + keys += ("max_epochs",) keys += tuple(self._state_dict_user_keys) return OrderedDict([(k, getattr(self.state, k)) for k in keys]) def load_state_dict(self, state_dict: Mapping) -> None: """Setups engine from `state_dict`. - State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`. + State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters`, and `epoch_length`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary. 
Iteration and epoch values are 0-based: the first iteration or epoch is zero. @@ -726,15 +736,20 @@ def load_state_dict(self, state_dict: Mapping) -> None: .. code-block:: python - # Restore from the 4rd epoch + # Restore from the 4th epoch state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)} # or 500th iteration # state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)} + # or with max_iters + # state_dict = {"iteration": 499, "max_iters": 1000, "epoch_length": len(data_loader)} trainer = Engine(...) trainer.load_state_dict(state_dict) trainer.run(data) + .. versionchanged:: 0.5.5 + Added support for restoring from a state dict containing ``max_iters`` instead of ``max_epochs``. + """ super().load_state_dict(state_dict) @@ -743,17 +758,15 @@ def load_state_dict(self, state_dict: Mapping) -> None: raise ValueError( f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'" ) - self.state.max_epochs = state_dict["max_epochs"] - self.state.epoch_length = state_dict["epoch_length"] - for k in self._state_dict_user_keys: setattr(self.state, k, state_dict[k]) + self.state.epoch_length = state_dict["epoch_length"] if "iteration" in state_dict: self.state.iteration = state_dict["iteration"] self.state.epoch = 0 if self.state.epoch_length is not None: self.state.epoch = self.state.iteration // self.state.epoch_length - elif "epoch" in state_dict: + else: # epoch is in state_dict self.state.epoch = state_dict["epoch"] if self.state.epoch_length is None: raise ValueError( @@ -762,6 +775,36 @@ def load_state_dict(self, state_dict: Mapping) -> None: ) self.state.iteration = self.state.epoch_length * self.state.epoch + # Set max_epochs or max_iters with validation + max_epochs_value = state_dict.get("max_epochs", None) + max_iters_value = state_dict.get("max_iters", None) + + # Validate max_epochs if present + if max_epochs_value is not None: + if max_epochs_value < 1: + raise 
ValueError("max_epochs in state_dict is invalid. Please, set a correct max_epochs positive value") + if max_epochs_value < self.state.epoch: + raise ValueError( + "max_epochs in state_dict should be larger than or equal to the current epoch " + f"defined in the state: {max_epochs_value} vs {self.state.epoch}. " + ) + self.state.max_epochs = max_epochs_value + else: + self.state.max_epochs = None + + # Validate max_iters if present + if max_iters_value is not None: + if max_iters_value < 1: + raise ValueError("max_iters in state_dict is invalid. Please, set a correct max_iters positive value") + if max_iters_value < self.state.iteration: + raise ValueError( + "max_iters in state_dict should be larger than or equal to the current iteration " + f"defined in the state: {max_iters_value} vs {self.state.iteration}. " + ) + self.state.max_iters = max_iters_value + else: + self.state.max_iters = None + @staticmethod def _is_done(state: State) -> bool: is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters @@ -773,6 +816,31 @@ def _is_done(state: State) -> bool: is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs return is_done_iters or is_done_count or is_done_epochs + def _check_and_set_termination_param( + self, name: str, value: int | None, progress_name: str, progress_value: int + ) -> None: + """Validate and set the passed parameter (max_epochs or max_iters).""" + if value is not None: + if value < 1: + raise ValueError(f"Argument {name} is invalid. Please, set a correct {name} positive value") + + if getattr(self.state, name) is not None and value < progress_value: + raise ValueError( + f"Argument {name} should be greater than or equal to the start " + f"{progress_name} defined in the state: {value} vs {progress_value}. " + f"Please, set engine.state.{name} = None " + "before calling engine.run() in order to restart the training from the beginning." 
+ ) + setattr(self.state, name, value) + + def _check_and_set_max_epochs(self, max_epochs: int | None = None) -> None: + """Validate and set max_epochs with proper checks.""" + self._check_and_set_termination_param("max_epochs", max_epochs, "epoch", self.state.epoch) + + def _check_and_set_max_iters(self, max_iters: int | None = None) -> None: + """Validate and set max_iters with proper checks.""" + self._check_and_set_termination_param("max_iters", max_iters, "iteration", self.state.iteration) + def set_data(self, data: Iterable | DataLoader) -> None: """Method to set data. After calling the method the next batch passed to `processing_function` is from newly provided data. Please, note that epoch length is not modified. @@ -871,43 +939,44 @@ def switch_batch(engine): if data is not None and not isinstance(data, Iterable): raise TypeError("Argument data should be iterable") + if max_epochs is not None and max_iters is not None: + raise ValueError( + "Arguments max_iters and max_epochs are mutually exclusive.Please provide only max_epochs or max_iters." + ) + if self.state.max_epochs is not None: - # Check and apply overridden parameters - if max_epochs is not None: - if max_epochs < self.state.epoch: - raise ValueError( - "Argument max_epochs should be greater than or equal to the start " - f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. " - "Please, set engine.state.max_epochs = None " - "before calling engine.run() in order to restart the training from the beginning." 
- ) - self.state.max_epochs = max_epochs - if epoch_length is not None: - if epoch_length != self.state.epoch_length: - raise ValueError( - "Argument epoch_length should be same as in the state, " - f"but given {epoch_length} vs {self.state.epoch_length}" - ) + self._check_and_set_max_epochs(max_epochs) + + if self.state.max_iters is not None: + self._check_and_set_max_iters(max_iters) - if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None): + # Check if we need to create new state or resume + # Create new state if: + # 1. No termination params set (first run), OR + # 2. Training is done AND generator is None + should_create_new_state = (self.state.max_epochs is None and self.state.max_iters is None) or ( + self._is_done(self.state) and self._internal_run_generator is None + ) + + if should_create_new_state: # Create new state - if epoch_length is None: - if data is None: - raise ValueError("epoch_length should be provided if data is None") + if data is None and epoch_length is None and self.state.epoch_length is None: + raise ValueError("epoch_length should be provided if data is None") - epoch_length = self._get_data_length(data) - if epoch_length is not None and epoch_length < 1: - raise ValueError("Input data has zero size. Please provide non-empty data") + # Set epoch_length for new state + if epoch_length is None: + # Try to get from data first, then fall back to existing state + if data is not None: + epoch_length = self._get_data_length(data) + if epoch_length is None and self.state.epoch_length is not None: + epoch_length = self.state.epoch_length + if epoch_length is not None and epoch_length < 1: + raise ValueError("Input data has zero size. Please provide non-empty data") if max_iters is None: if max_epochs is None: max_epochs = 1 else: - if max_epochs is not None: - raise ValueError( - "Arguments max_iters and max_epochs are mutually exclusive." - "Please provide only max_epochs or max_iters." 
- ) if epoch_length is not None: max_epochs = math.ceil(max_iters / epoch_length) @@ -918,18 +987,41 @@ def switch_batch(engine): self.state.epoch_length = epoch_length # Reset generator if previously used self._internal_run_generator = None - self.logger.info(f"Engine run starting with max_epochs={max_epochs}.") + + if self.state.max_epochs is not None: + self.logger.info(f"Engine run starting with max_epochs={self.state.max_epochs}.") + else: + self.logger.info(f"Engine run starting with max_iters={self.state.max_iters}.") else: - self.logger.info( - f"Engine run resuming from iteration {self.state.iteration}, " - f"epoch {self.state.epoch} until {self.state.max_epochs} epochs" - ) + if self.state.epoch_length is not None: + if epoch_length is not None and epoch_length != self.state.epoch_length: + raise ValueError( + "Argument epoch_length should be same as in the state, " + f"but given {epoch_length} vs {self.state.epoch_length}" + ) + else: + if epoch_length is None and data is not None: + epoch_length = self._get_data_length(data) + if epoch_length is not None: + if epoch_length < 1: + raise ValueError("Input data has zero size. 
Please provide non-empty data") + self.state.epoch_length = epoch_length + + if self.state.max_epochs is not None: + self.logger.info( + f"Engine run resuming from iteration {self.state.iteration}, " + f"epoch {self.state.epoch} until {self.state.max_epochs} epochs" + ) + else: + self.logger.info( + f"Engine run resuming from iteration {self.state.iteration}, " + f"epoch {self.state.epoch} until {self.state.max_iters} iterations" + ) + if self.state.epoch_length is None and data is None: raise ValueError("epoch_length should be provided if data is None") if self.should_terminate: - # If engine was terminated and now is resuming from terminated state - # we need to initialize iter_counter as 0 self._init_iter = 0 if self._dataloader_iter is None: diff --git a/tests/ignite/base/test_mixins.py b/tests/ignite/base/test_mixins.py index 0f3a39811fbb..21cbbbc60f16 100644 --- a/tests/ignite/base/test_mixins.py +++ b/tests/ignite/base/test_mixins.py @@ -1,5 +1,7 @@ import pytest +from collections import OrderedDict + from ignite.base import Serializable @@ -12,3 +14,262 @@ def test_state_dict(): def test_load_state_dict(): s = Serializable() s.load_state_dict({}) + + +class ExampleSerializable(Serializable): + _state_dict_all_req_keys = ("a", "b") + _state_dict_one_of_opt_keys = (("c", "d"), ("e", "f")) + + def __init__(self): + super().__init__() + self.data = {} + + def state_dict(self): + return {"a": 1, "b": 2, "c": 3, "e": 5} + + +class EngineStyleSerializable(Serializable): + """Serializable that mimics Engine's key structure.""" + + _state_dict_all_req_keys = ("epoch_length",) + _state_dict_one_of_opt_keys = (("iteration", "epoch"), ("max_epochs", "max_iters")) + + def __init__(self): + super().__init__() + self.data = {} + + def state_dict(self): + result = OrderedDict() + for key in self._state_dict_all_req_keys: + if key in self.data: + result[key] = self.data[key] + + # Add user keys + for key in self._state_dict_user_keys: + if key in self.data: + result[key] = 
self.data[key] + + return result + + +def test_load_state_dict_validation(): + """Test the updated load_state_dict validation.""" + s = ExampleSerializable() + + # Test type check + with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"): + s.load_state_dict("not a dict") + + # Test missing required keys + with pytest.raises(ValueError, match=r"Required state attribute 'a' is absent"): + s.load_state_dict({}) + + with pytest.raises(ValueError, match=r"Required state attribute 'b' is absent"): + s.load_state_dict({"a": 1}) + + # Test one-of optional keys - missing all + with pytest.raises(ValueError, match=r"should contain at least one of"): + s.load_state_dict({"a": 1, "b": 2}) + + # Test one-of optional keys - having all from one group + with pytest.raises(ValueError, match=r"should contain only one of '\('c', 'd'\)'"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + + # Test one-of optional keys - having all from another group + with pytest.raises(ValueError, match=r"should contain only one of '\('e', 'f'\)'"): + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5, "f": 6}) + + # Valid state dict + s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5}) + print("Valid state dict loaded successfully") + + +@pytest.mark.parametrize( + "all_req, opt_groups, state_dict, expected_error_match", + [ + # Case 1: Pure empty - should pass + (("required",), (), {"required": "value"}, None), + # Case 2: Nested empty group - should raise descriptive ValueError + ( + ("required",), + ((),), + {"required": "value"}, + r"Empty group found in 'ParametrizedSerializable._state_dict_one_of_opt_keys'", + ), + # Case 3: Mixed empty and filled - should raise descriptive ValueError + ( + ("required",), + ((), ("opt1", "opt2")), + {"required": "value", "opt1": "val"}, + r"Empty group found in 'ParametrizedSerializable._state_dict_one_of_opt_keys'", + ), + # Case 4: Standard one-of group - missing keys + ( + ("required",), + (("opt1", "opt2"),), + 
{"required": "value"}, + r"should contain at least one of '\('opt1', 'opt2'\)'", + ), + # Case 5: Standard one-of group - too many keys + ( + ("required",), + (("opt1", "opt2"),), + {"required": "value", "opt1": "v1", "opt2": "v2"}, + r"should contain only one of '\('opt1', 'opt2'\)'", + ), + # Case 6: Standard one-of group - valid + (("required",), (("opt1", "opt2"),), {"required": "value", "opt1": "v1"}, None), + ], +) +def test_optional_groups_logic(all_req, opt_groups, state_dict, expected_error_match): + """Test various combinations of optional groups using parametrization.""" + + class ParametrizedSerializable(Serializable): + _state_dict_all_req_keys = all_req + _state_dict_one_of_opt_keys = opt_groups + + def state_dict(self): + return {} + + s = ParametrizedSerializable() + if expected_error_match: + # for negative cases where we expect case to fail + with pytest.raises(ValueError, match=expected_error_match): + s.load_state_dict(state_dict) + else: + # for positive cases + s.load_state_dict(state_dict) + + +def test_engine_style_validation(): + """Test validation that mimics Engine usage.""" + s = EngineStyleSerializable() + + # Valid: iteration + max_iters + s.load_state_dict({"epoch_length": 100, "iteration": 150, "max_iters": 500}) + + # Valid: epoch + max_epochs + s2 = EngineStyleSerializable() + s2.load_state_dict({"epoch_length": 100, "epoch": 3, "max_epochs": 10}) + + # Invalid: both iteration and epoch + s3 = EngineStyleSerializable() + with pytest.raises(ValueError, match="should contain only one of.*iteration.*epoch"): + s3.load_state_dict({"epoch_length": 100, "iteration": 150, "epoch": 3, "max_epochs": 10}) + + # Invalid: both max_epochs and max_iters + s4 = EngineStyleSerializable() + with pytest.raises(ValueError, match="should contain only one of.*max_epochs.*max_iters"): + s4.load_state_dict({"epoch_length": 100, "iteration": 150, "max_epochs": 10, "max_iters": 500}) + + +def test_single_option_group(): + """Test group with single 
option.""" + + class SingleOptionSerializable(Serializable): + _state_dict_all_req_keys = ("base",) + _state_dict_one_of_opt_keys = (("single",),) + + def state_dict(self): + return {} + + s = SingleOptionSerializable() + + # Should require the single option + with pytest.raises(ValueError, match="should contain at least one of"): + s.load_state_dict({"base": "value"}) + + # Should pass with single option + s.load_state_dict({"base": "value", "single": "option"}) + + +def test_inheritance_overrides(): + """Test that subclasses can override validation rules.""" + + class BaseSerializable(Serializable): + _state_dict_all_req_keys = ("base_req",) + _state_dict_one_of_opt_keys = (("base_opt1", "base_opt2"),) + + def state_dict(self): + return {} + + class DerivedSerializable(BaseSerializable): + _state_dict_all_req_keys = ("derived_req1", "derived_req2") + _state_dict_one_of_opt_keys = (("derived_opt1", "derived_opt2"),) + + # Base class uses its own rules + base = BaseSerializable() + base.load_state_dict({"base_req": "value", "base_opt1": "opt"}) + + # Derived class uses overridden rules + derived = DerivedSerializable() + with pytest.raises(ValueError, match="Required state attribute.*derived_req1"): + derived.load_state_dict({"base_req": "value", "base_opt1": "opt"}) + + # Valid for derived class + derived.load_state_dict({"derived_req1": "d1", "derived_req2": "d2", "derived_opt2": "opt"}) + + +def test_complex_grouped_keys(): + """Test grouped optional keys.""" + s = EngineStyleSerializable() + + # Valid with all requirements + s.load_state_dict({"epoch_length": 100, "iteration": 250, "max_iters": 500}) + + # Missing from a group should fail + s2 = EngineStyleSerializable() + with pytest.raises(ValueError, match="should contain at least one of.*iteration.*epoch"): + s2.load_state_dict({"epoch_length": 100, "max_iters": 500}) + + +def test_backwards_compatibility(): + """Test that old style validation still works.""" + + class OldStyleSerializable(Serializable): + 
_state_dict_all_req_keys = ("req1", "req2") + # No _state_dict_one_of_opt_keys defined - should default to empty + + def state_dict(self): + return {} + + s = OldStyleSerializable() + + # Should work with just required keys + s.load_state_dict({"req1": "r1", "req2": "r2"}) + + # Should fail without required keys + with pytest.raises(ValueError, match="Required state attribute"): + s.load_state_dict({"req1": "r1"}) + + +def test_complex_scenario(): + """Test complex scenario with multiple groups and user keys.""" + + class ComplexSerializable(Serializable): + _state_dict_all_req_keys = ("base1", "base2") + _state_dict_one_of_opt_keys = ( + ("pos1", "pos2", "pos3"), + ("term1", "term2"), + ("opt1", "opt2", "opt3", "opt4"), + ) + + def state_dict(self): + return {} + + s = ComplexSerializable() + # Valid complex state + s.load_state_dict( + { + "base1": "b1", + "base2": "b2", + "pos2": "position", + "term1": "termination", + "opt3": "option", + } + ) + + # Missing from one group should fail + s2 = ComplexSerializable() + with pytest.raises(ValueError, match="should contain at least one of.*term1.*term2"): + s2.load_state_dict({"base1": "b1", "base2": "b2", "pos1": "pos", "opt4": "opt"}) diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py index 4ccfb7ea7720..db4a81029032 100644 --- a/tests/ignite/engine/test_engine_state_dict.py +++ b/tests/ignite/engine/test_engine_state_dict.py @@ -18,7 +18,9 @@ def test_state_dict(): def _test(state): engine.state = state sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 + # Total keys = required keys + 1 from each optional group (e.g. 
iteration & max_epochs) + expected_len = len(engine._state_dict_all_req_keys) + len(engine._state_dict_one_of_opt_keys) + assert isinstance(sd, Mapping) and len(sd) == expected_len assert sd["iteration"] == engine.state.iteration assert sd["epoch_length"] == engine.state.epoch_length assert sd["max_epochs"] == engine.state.max_epochs @@ -35,9 +37,10 @@ def test_state_dict_with_user_keys(): def _test(state): engine.state = state sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 + len( - engine.state_dict_user_keys - ) + # Total keys = required keys + 1 from each optional group + user keys + num_opt_keys = len(engine._state_dict_one_of_opt_keys) + expected_len = len(engine._state_dict_all_req_keys) + num_opt_keys + len(engine.state_dict_user_keys) + assert isinstance(sd, Mapping) and len(sd) == expected_len assert sd["iteration"] == engine.state.iteration assert sd["epoch_length"] == engine.state.epoch_length assert sd["max_epochs"] == engine.state.max_epochs @@ -52,31 +55,49 @@ def test_state_dict_integration(): data = range(100) engine.run(data, max_epochs=10) sd = engine.state_dict() - assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1 + # Total keys = required keys + 1 from each optional group (e.g. 
iteration & max_epochs) + expected_len = len(engine._state_dict_all_req_keys) + len(engine._state_dict_one_of_opt_keys) + assert isinstance(sd, Mapping) and len(sd) == expected_len assert sd["iteration"] == engine.state.iteration == 10 * 100 assert sd["epoch_length"] == engine.state.epoch_length == 100 assert sd["max_epochs"] == engine.state.max_epochs == 10 -def test_load_state_dict_asserts(): +@pytest.mark.parametrize( + "state_dict, error_type, match", + [ + ("not a dict", TypeError, r"Argument state_dict should be a dictionary"), + ({}, ValueError, r"Required state attribute 'epoch_length' is absent"), + ({"epoch_length": 100}, ValueError, r"state_dict should contain at least one of"), + ( + {"epoch_length": 100, "iteration": 10, "epoch": 1, "max_epochs": 5}, + ValueError, + r"should contain only one of '\('iteration', 'epoch'\)'", + ), + ( + {"epoch_length": 100, "iteration": 10, "max_epochs": 5, "max_iters": 500}, + ValueError, + r"should contain only one of '\('max_epochs', 'max_iters'\)'", + ), + ({"epoch": 5, "max_epochs": 3, "epoch_length": 10}, ValueError, r"larger than or equal to the current epoch"), + ( + {"iteration": 100, "max_iters": 50, "epoch_length": 10}, + ValueError, + r"larger than or equal to the current iteration", + ), + ( + {"iteration": 12, "epoch_length": 120, "max_epochs": 100}, + ValueError, + r"Required user state attribute 'alpha'", + ), + ], +) +def test_load_state_dict_errors(state_dict, error_type, match): engine = Engine(lambda e, b: 1) - - with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"): - engine.load_state_dict("123") - - with pytest.raises(ValueError, match=r"is absent in provided state_dict"): - engine.load_state_dict({}) - - with pytest.raises(ValueError, match=r"state_dict should contain only one of"): - engine.load_state_dict({"max_epochs": 100, "epoch_length": 120}) - - with pytest.raises(ValueError, match=r"state_dict should contain only one of"): - 
engine.load_state_dict({"max_epochs": 100, "epoch_length": 120, "iteration": 12, "epoch": 123}) - - engine = Engine(lambda e, b: 1) - engine.state_dict_user_keys.append("alpha") - with pytest.raises(ValueError, match=r"Required user state attribute"): - engine.load_state_dict({"max_epochs": 100, "epoch_length": 120, "iteration": 12}) + if "alpha" in str(match): + engine.state_dict_user_keys.append("alpha") + with pytest.raises(error_type, match=match): + engine.load_state_dict(state_dict) engine = Engine(lambda e, b: 1) with pytest.raises(ValueError, match=r"If epoch is provided in the state dict, epoch_length should not be None"): @@ -117,6 +138,12 @@ def _test(sd): _test({"max_epochs": 100, "epoch_length": 120, "iteration": 123, "alpha": 0.1, "beta": "abc"}) + # Test missing user key + engine2 = Engine(lambda e, b: 1) + engine2.state_dict_user_keys.append("alpha") + with pytest.raises(ValueError, match="Required user state attribute 'alpha' is absent in provided state_dict"): + engine2.load_state_dict({"max_epochs": 100, "epoch_length": 120, "iteration": 123}) + def test_load_state_dict_integration(): engine = Engine(lambda e, b: 1) @@ -274,3 +301,418 @@ def test_restart_training(): state = engine.run(data, max_epochs=2) state.max_epochs = None engine.run(data, max_epochs=2) + + +@pytest.mark.parametrize( + "termination_param, value, expected_iters", + [ + ("max_epochs", 5, 500), + ("max_iters", 250, 250), + ], +) +def test_state_dict_termination_variants(termination_param, value, expected_iters): + """Test state_dict with different termination parameters.""" + engine = Engine(lambda e, b: 1) + data = range(100) + engine.run(data, **{termination_param: value}) + + sd = engine.state_dict() + assert "iteration" in sd + assert "epoch_length" in sd + assert termination_param in sd + other_param = "max_iters" if termination_param == "max_epochs" else "max_epochs" + assert other_param not in sd + assert sd[termination_param] == value + assert sd["epoch_length"] == 100 
+ assert sd["iteration"] == expected_iters + + +@pytest.mark.parametrize( + "state_dict, expected_state", + [ + ( + {"epoch": 2, "max_epochs": 5, "epoch_length": 100}, + {"epoch": 2, "max_epochs": 5, "epoch_length": 100, "iteration": 200}, + ), + ( + {"iteration": 150, "max_iters": 250, "epoch_length": 100}, + {"iteration": 150, "max_iters": 250, "epoch_length": 100, "epoch": 1}, + ), + ( + {"iteration": 150, "max_epochs": 3, "epoch_length": 100}, + {"iteration": 150, "max_epochs": 3, "epoch_length": 100, "epoch": 1}, + ), + ( + {"epoch": 2, "max_iters": 500, "epoch_length": 100}, + {"epoch": 2, "max_iters": 500, "epoch_length": 100, "iteration": 200}, + ), + ], +) +def test_load_state_dict_termination_variants(state_dict, expected_state): + """Test load_state_dict with different combinations of progress and termination params.""" + engine = Engine(lambda e, b: 1) + engine.load_state_dict(state_dict) + + for attr, expected_value in expected_state.items(): + assert getattr(engine.state, attr) == expected_value + + +def test_save_and_load_with_max_iters(): + """Test saving and loading engine state with max_iters.""" + # Create and run engine with max_iters + engine1 = Engine(lambda e, b: b) + data = list(range(20)) + engine1.run(data, max_iters=50, epoch_length=10) + + # Save state + state_dict = engine1.state_dict() + assert state_dict["iteration"] == 50 + assert state_dict["max_iters"] == 50 + assert state_dict["epoch_length"] == 10 + assert "max_epochs" not in state_dict + + # Load state in new engine + engine2 = Engine(lambda e, b: b) + engine2.load_state_dict(state_dict) + + assert engine2.state.iteration == 50 + assert engine2.state.max_iters == 50 + assert engine2.state.epoch_length == 10 + assert engine2.state.epoch == 5 # 50 // 10 + + +def test_resume_with_max_iters(): + """Test resuming engine run with max_iters using early termination.""" + counter = [0] + + def update_fn(engine, batch): + counter[0] += 1 + return batch + + engine = Engine(update_fn) + 
    # NOTE(review): this is the tail of a test whose `def` lies above this chunk;
    # `engine`, `update_fn` and the `counter` list are defined earlier in that test.
    data = list(range(10))

    # Set up early termination at iteration 15
    @engine.on(Events.ITERATION_COMPLETED(once=15))
    def stop_early(engine):
        engine.terminate()

    # Run with max_iters=25 but terminate early at 15
    engine.run(data, max_iters=25, epoch_length=10)
    assert counter[0] == 15
    assert engine.state.iteration == 15
    assert engine.state.max_iters == 25  # Still has the original max_iters

    # Save and reload state
    state_dict = engine.state_dict()
    counter[0] = 0  # Reset counter

    engine2 = Engine(update_fn)
    engine2.load_state_dict(state_dict)

    # Resume running - should continue from iteration 15 to 25
    engine2.run(data)
    assert counter[0] == 10  # 25 - 15
    assert engine2.state.iteration == 25


def test_mutually_exclusive_max_epochs_max_iters():
    """Test that max_epochs and max_iters are mutually exclusive.

    Passing both to ``Engine.run`` must raise a ``ValueError`` whose message
    mentions "mutually exclusive".
    """
    engine = Engine(lambda e, b: 1)
    data = range(10)

    with pytest.raises(ValueError, match="mutually exclusive"):
        engine.run(data, max_epochs=5, max_iters=50)


def test_unknown_epoch_length_with_max_iters():
    """Test handling unknown epoch_length with max_iters.

    Uses a finite generator (no ``len``) so the engine cannot know the epoch
    length up front.
    """
    counter = [0]

    def update_fn(engine, batch):
        counter[0] += 1
        return batch

    def data_iter():
        # Finite generator of 15 items; has no __len__, so epoch_length is unknown.
        for i in range(15):
            yield i

    engine = Engine(update_fn)

    # Run with unknown epoch length and max_iters that completes before first epoch ends
    engine.run(data_iter(), max_iters=10)
    assert counter[0] == 10
    assert engine.state.iteration == 10
    # epoch_length remains None since we stopped before completing an epoch
    assert engine.state.epoch_length is None

    # State dict should have max_iters
    sd = engine.state_dict()
    assert "max_iters" in sd
    assert sd["max_iters"] == 10

    # Test case where we complete a full epoch
    engine2 = Engine(update_fn)
    counter[0] = 0
    engine2.run(data_iter(), max_iters=20)
    assert counter[0] == 15  # Iterator exhausted after 15
    assert engine2.state.iteration == 15
    # epoch_length should be determined when iterator is exhausted
    assert engine2.state.epoch_length == 15


def test_engine_attributes():
    """Test basic engine attributes and initial (pre-run) state values."""
    engine = Engine(lambda e, b: 1)

    # Test basic attributes exist
    assert hasattr(engine, "state")
    assert hasattr(engine, "logger")
    assert hasattr(engine, "state_dict_user_keys")

    # Test initial state
    assert engine.state.iteration == 0
    assert engine.state.epoch == 0
    assert engine.state.max_epochs is None
    assert engine.state.max_iters is None
    assert engine.state.epoch_length is None


@pytest.mark.parametrize(
    "param_name, current_val, low_val, high_val",
    [
        ("max_epochs", 3, 2, 5),
        ("max_iters", 30, 25, 40),
    ],
)
def test_helper_methods(param_name, current_val, low_val, high_val):
    """Test the helper validation methods.

    NOTE(review): relies on the private ``Engine._check_and_set_max_epochs`` /
    ``_check_and_set_max_iters`` helpers and their error wording — confirm these
    names/messages against the Engine implementation if this test breaks.
    """
    engine = Engine(lambda e, b: 1)
    data = range(10)

    # Initialize engine state by completing a run with the given stopping param.
    engine.run(data, **{param_name: current_val})

    helper_method = getattr(engine, f"_check_and_set_{param_name}")

    # Test too low value: must be rejected relative to the already-reached progress.
    with pytest.raises(ValueError, match="greater than or equal to the start"):
        helper_method(low_val)

    # Test valid higher value
    helper_method(high_val)
    assert getattr(engine.state, param_name) == high_val


def test_backward_compatibility():
    """Test backward compatibility with old state dicts.

    Old-format dicts carry ``max_epochs`` (never ``max_iters``); they must still
    load and derive ``epoch`` from ``iteration // epoch_length``.
    """
    engine = Engine(lambda e, b: 1)

    # Old state dict format (with max_epochs)
    old_state_dict = {"iteration": 200, "max_epochs": 5, "epoch_length": 100}

    engine.load_state_dict(old_state_dict)
    assert engine.state.iteration == 200
    assert engine.state.max_epochs == 5
    assert engine.state.epoch_length == 100
    assert engine.state.epoch == 2  # 200 // 100


def test_user_keys_with_max_iters():
    """Test user-defined keys work with max_iters.

    ``state_dict_user_keys`` entries must round-trip through
    ``state_dict``/``load_state_dict`` alongside ``max_iters``.
    """
    engine = Engine(lambda e, b: b)
    data = list(range(10))

    # Add user keys
    engine.state_dict_user_keys.append("custom_value")
    engine.state_dict_user_keys.append("another_value")

    @engine.on(Events.STARTED)
    def init_custom_values(engine):
        engine.state.custom_value = 42
        engine.state.another_value = "test"

    engine.run(data, max_iters=5)

    # Check state dict contains user keys
    sd = engine.state_dict()
    assert "custom_value" in sd
    assert "another_value" in sd
    assert sd["custom_value"] == 42
    assert sd["another_value"] == "test"
    assert "max_iters" in sd
    # max_epochs and max_iters are mutually exclusive in the serialized form.
    assert "max_epochs" not in sd

    # Load into new engine (user keys must be registered before loading)
    engine2 = Engine(lambda e, b: b)
    engine2.state_dict_user_keys.append("custom_value")
    engine2.state_dict_user_keys.append("another_value")

    engine2.load_state_dict(sd)
    assert engine2.state.custom_value == 42
    assert engine2.state.another_value == "test"
    assert engine2.state.max_iters == 5


def test_is_done_method_with_max_iters():
    """Test the _is_done static method with max_iters.

    Drives ``Engine._is_done`` directly with hand-built ``State`` objects;
    no run loop involved.
    """
    # Test with max_iters only (max_epochs is None)
    state = State()
    state.iteration = 100
    state.max_iters = 100
    state.epoch_length = 25
    state.epoch = 4
    state.max_epochs = None

    assert Engine._is_done(state) is True

    state.iteration = 99
    assert Engine._is_done(state) is False

    state.iteration = 101
    assert Engine._is_done(state) is True

    # Test with both set (shouldn't happen but test logic):
    # done when either limit is reached — max_iters=100 or max_epochs*epoch_length=75.
    state.iteration = 50
    state.max_iters = 100
    state.max_epochs = 3
    state.epoch = 2
    state.epoch_length = 25
    assert Engine._is_done(state) is False

    state.iteration = 100
    assert Engine._is_done(state) is True

    state.iteration = 75
    state.epoch = 3
    assert Engine._is_done(state) is True


def test_none_data_with_max_iters():
    """Test running with None data and max_iters.

    With ``data=None`` the engine passes ``None`` as the batch; an explicit
    ``epoch_length`` is required.
    """
    counter = [0]

    def update_fn(engine, batch):
        # data=None means every batch delivered to the update function is None.
        assert batch is None
        counter[0] += 1
        return 1

    engine = Engine(update_fn)

    # Should work with None data if epoch_length provided
    engine.run(data=None, max_iters=30, epoch_length=10)

    assert counter[0] == 30
    assert engine.state.iteration == 30
    assert engine.state.max_iters == 30
    assert engine.state.epoch_length == 10
    assert engine.state.epoch == 3  # ceil(30/10) = 3


def test_epoch_calculation_with_max_iters():
    """Test epoch calculation when using max_iters.

    ``max_iters`` need not be a multiple of ``epoch_length``; the final epoch
    counter is the ceiling of iterations over epoch length.
    """
    engine = Engine(lambda e, b: b)
    data = list(range(25))

    # Run with max_iters that doesn't divide evenly into the epoch length (25)
    engine.run(data, max_iters=60)

    assert engine.state.iteration == 60
    assert engine.state.max_iters == 60
    assert engine.state.epoch_length == 25
    assert engine.state.epoch == 3  # ceil(60/25) = 3

    # Save and verify state dict
    sd = engine.state_dict()
    assert sd["iteration"] == 60
    assert sd["max_iters"] == 60
    assert sd["epoch_length"] == 25


def test_resume_with_higher_max_iters():
    """Test loading state and running with higher max_iters value.

    A resumed engine may be given a larger ``max_iters`` than the serialized
    one and must continue from the saved iteration.
    """
    counter = [0]

    def update_fn(engine, batch):
        counter[0] += 1
        return batch

    # First engine: run until 15 then save state
    engine1 = Engine(update_fn)
    data = list(range(10))

    # Use early termination to simulate partial run
    @engine1.on(Events.ITERATION_COMPLETED(once=15))
    def stop_early(engine):
        engine.terminate()

    engine1.run(data, max_iters=20)
    assert counter[0] == 15
    assert engine1.state.iteration == 15
    assert engine1.state.max_iters == 20

    # Save state
    sd = engine1.state_dict()
    counter[0] = 0

    # Second engine: load state and increase max_iters
    engine2 = Engine(update_fn)
    engine2.load_state_dict(sd)

    # Directly set higher max_iters and run
    engine2.run(data, max_iters=25)
    assert counter[0] == 10  # 25 - 15
    assert engine2.state.iteration == 25
    assert engine2.state.max_iters == 25

    # Final state dict
    final_sd = engine2.state_dict()
    assert final_sd["iteration"] == 25
    assert final_sd["max_iters"] == 25


def test_checkpoint_with_max_iters():
    """Test that Checkpoint/DiskSaver round-trips an engine run with max_iters.

    Saves the engine mid-run (iteration 15 of max_iters=25), reloads it into a
    fresh engine via ``Checkpoint.load_objects``, and resumes to completion.
    """
    import tempfile
    import os
    from ignite.handlers import Checkpoint, DiskSaver
    from ignite.engine import Engine, Events

    with tempfile.TemporaryDirectory() as tmpdir:
        data = list(range(10))

        def update_fn(engine, batch):
            return 1

        engine1 = Engine(update_fn)

        # Save after 15 iterations (mid 2nd epoch)
        to_save = {"engine": engine1}
        handler = Checkpoint(to_save, DiskSaver(tmpdir, require_empty=False), n_saved=1)
        engine1.add_event_handler(Events.ITERATION_COMPLETED(once=15), handler)

        # Terminate right after the checkpoint fires so state freezes at 15.
        @engine1.on(Events.ITERATION_COMPLETED(once=15))
        def stop_early():
            engine1.terminate()

        engine1.run(data, max_iters=25)

        assert engine1.state.iteration == 15
        assert engine1.state.max_iters == 25

        # Reload checkpoint: n_saved=1, so the single file in tmpdir is ours.
        engine2 = Engine(update_fn)
        checkpoint_path = os.path.join(tmpdir, os.listdir(tmpdir)[0])
        import torch

        checkpoint = torch.load(checkpoint_path)
        Checkpoint.load_objects(to_load={"engine": engine2}, checkpoint=checkpoint)

        assert engine2.state.iteration == 15
        assert engine2.state.max_iters == 25
        assert engine2.state.epoch_length == 10

        # Resume: runs the remaining 10 iterations (15 -> 25).
        engine2.run(data, max_iters=25)
        assert engine2.state.iteration == 25
        assert engine2.state.epoch == 3
        # max_iters runs never populate max_epochs.
        assert getattr(engine2.state, "max_epochs", None) is None