Skip to content
Open
23 changes: 18 additions & 5 deletions ignite/base/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def attach(self, engine: "Engine", *args: Any, **kwargs: Any) -> None:


class Serializable:
_state_dict_all_req_keys: tuple = ()
_state_dict_one_of_opt_keys: tuple = ()
_state_dict_all_req_keys: tuple[str, ...] = ()
_state_dict_one_of_opt_keys: tuple[tuple[str, ...], ...] = ()
Comment thread
TahaZahid05 marked this conversation as resolved.

def state_dict(self) -> OrderedDict:
raise NotImplementedError
Expand All @@ -43,6 +43,19 @@ def load_state_dict(self, state_dict: Mapping) -> None:
raise ValueError(
f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
if len(opts) > 0 and ((not any(opts)) or (all(opts))):
raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")

opt_groups = self._state_dict_one_of_opt_keys
if len(opt_groups) > 0 and isinstance(opt_groups[0], str):
opt_groups = (opt_groups,)

# Handle groups of one-of optional keys
for one_of_opt_keys in opt_groups:
if len(one_of_opt_keys) == 0:
raise ValueError(
f"Empty group found in '{self.__class__.__name__}._state_dict_one_of_opt_keys'. "
"Each group must contain at least one state attribute key."
)
opts = [k in state_dict for k in one_of_opt_keys]
num_present = sum(opts)
if num_present != 1:
raise ValueError(f"state_dict should contain exactly one of '{one_of_opt_keys}' keys")
184 changes: 138 additions & 46 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def compute_mean_std(engine, batch):

"""

_state_dict_all_req_keys = ("epoch_length", "max_epochs")
_state_dict_one_of_opt_keys = ("iteration", "epoch")
_state_dict_all_req_keys: tuple[str, ...] = ("epoch_length", "iteration")
_state_dict_one_of_opt_keys: tuple[tuple[str, ...], ...] = (("iteration", "epoch"), ("max_epochs", "max_iters"))

# Flag to disable engine._internal_run as generator feature for BC
interrupt_resume_enabled = True
Expand Down Expand Up @@ -682,8 +682,9 @@ def state_dict_user_keys(self) -> list:
return self._state_dict_user_keys

def state_dict(self) -> OrderedDict:
"""Returns a dictionary containing engine's state: "epoch_length", "max_epochs" and "iteration" and
other state values defined by `engine.state_dict_user_keys`
"""Returns a dictionary containing engine's state: ``"epoch_length"``, ``"iteration"``,
one of ``"max_epochs"`` or ``"max_iters"``, and other state values defined by
``engine.state_dict_user_keys``.

.. code-block:: python

Expand All @@ -707,15 +708,23 @@ def save_engine(_):
OrderedDict:
a dictionary containing engine's state

.. versionchanged:: 0.5.5
Added support for serializing ``max_iters``.

Comment thread
vfdev-5 marked this conversation as resolved.
"""
keys: tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
keys: tuple[str, ...] = self._state_dict_all_req_keys
# Include either max_epochs or max_iters based on which was originally set
if self.state.max_iters is not None:
keys += ("max_iters",)
else:
keys += ("max_epochs",)
Comment thread
TahaZahid05 marked this conversation as resolved.
keys += tuple(self._state_dict_user_keys)
return OrderedDict([(k, getattr(self.state, k)) for k in keys])

def load_state_dict(self, state_dict: Mapping) -> None:
"""Setups engine from `state_dict`.

State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`.
State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters`, and `epoch_length`.
If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
Iteration and epoch values are 0-based: the first iteration or epoch is zero.

Expand All @@ -726,15 +735,20 @@ def load_state_dict(self, state_dict: Mapping) -> None:

.. code-block:: python

# Restore from the 4rd epoch
# Restore from the 4th epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or 500th iteration
# state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)}
# or with max_iters
# state_dict = {"iteration": 499, "max_iters": 1000, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
trainer.run(data)

.. versionchanged:: 0.5.5
Added support for restoring from a state dict containing ``max_iters`` instead of ``max_epochs``.

Comment thread
vfdev-5 marked this conversation as resolved.
"""
super().load_state_dict(state_dict)

Expand All @@ -743,17 +757,15 @@ def load_state_dict(self, state_dict: Mapping) -> None:
raise ValueError(
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
self.state.max_epochs = state_dict["max_epochs"]
self.state.epoch_length = state_dict["epoch_length"]
for k in self._state_dict_user_keys:
setattr(self.state, k, state_dict[k])
self.state.epoch_length = state_dict["epoch_length"]

if "iteration" in state_dict:
self.state.iteration = state_dict["iteration"]
self.state.epoch = 0
if self.state.epoch_length is not None:
self.state.epoch = self.state.iteration // self.state.epoch_length
elif "epoch" in state_dict:
else: # epoch is in state_dict
self.state.epoch = state_dict["epoch"]
if self.state.epoch_length is None:
raise ValueError(
Expand All @@ -762,6 +774,22 @@ def load_state_dict(self, state_dict: Mapping) -> None:
)
self.state.iteration = self.state.epoch_length * self.state.epoch

if "max_epochs" in state_dict:
self.state.max_iters = None
max_epochs = state_dict.get("max_epochs")
if max_epochs is None:
self.state.max_epochs = None
else:
self._check_and_set_max_epochs(max_epochs)

elif "max_iters" in state_dict:
self.state.max_epochs = None
max_iters = state_dict.get("max_iters")
if max_iters is None:
self.state.max_iters = None
else:
self._check_and_set_max_iters(max_iters)

@staticmethod
def _is_done(state: State) -> bool:
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
Expand All @@ -773,6 +801,31 @@ def _is_done(state: State) -> bool:
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
return is_done_iters or is_done_count or is_done_epochs

def _check_and_set_termination_param(
self, name: str, value: int | None, progress_name: str, progress_value: int
) -> None:
"""Validate and set the passed parameter (max_epochs or max_iters)."""
if value is not None:
if value < 1:
raise ValueError(f"Argument {name} is invalid. Please, set a correct {name} positive value")

if value < progress_value:
raise ValueError(
f"Argument {name} should be greater than or equal to the start "
f"{progress_name} defined in the state: {value} vs {progress_value}. "
f"Please, set engine.state.{name} = None "
"before calling engine.run() in order to restart the training from the beginning."
)
setattr(self.state, name, value)

def _check_and_set_max_epochs(self, max_epochs: int | None = None) -> None:
    """Validate ``max_epochs`` against the current epoch and store it on the state."""
    self._check_and_set_termination_param(
        name="max_epochs",
        value=max_epochs,
        progress_name="epoch",
        progress_value=self.state.epoch,
    )

def _check_and_set_max_iters(self, max_iters: int | None = None) -> None:
    """Validate ``max_iters`` against the current iteration and store it on the state."""
    self._check_and_set_termination_param(
        name="max_iters",
        value=max_iters,
        progress_name="iteration",
        progress_value=self.state.iteration,
    )

def set_data(self, data: Iterable | DataLoader) -> None:
"""Method to set data. After calling the method the next batch passed to `processing_function` is
from newly provided data. Please, note that epoch length is not modified.
Expand Down Expand Up @@ -871,45 +924,58 @@ def switch_batch(engine):
if data is not None and not isinstance(data, Iterable):
raise TypeError("Argument data should be iterable")

if max_epochs is not None and max_iters is not None:
raise ValueError(
"Arguments max_iters and max_epochs are mutually exclusive. Please provide only max_epochs or max_iters."
)

# Check for mode switching during resume
if max_iters is not None and self.state.max_epochs is not None:
raise ValueError(
"To switch from max_epochs to max_iters mode during resume, "
"you must first reset by setting 'engine.state.max_epochs = None' before calling run()."
)
if max_epochs is not None and self.state.max_iters is not None:
raise ValueError(
"To switch from max_iters to max_epochs mode during resume, "
"you must first reset by setting 'engine.state.max_iters = None' before calling run()."
)

if self.state.max_epochs is not None:
# Check and apply overridden parameters
if max_epochs is not None:
if max_epochs < self.state.epoch:
raise ValueError(
"Argument max_epochs should be greater than or equal to the start "
f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. "
"Please, set engine.state.max_epochs = None "
"before calling engine.run() in order to restart the training from the beginning."
)
self.state.max_epochs = max_epochs
if epoch_length is not None:
if epoch_length != self.state.epoch_length:
raise ValueError(
"Argument epoch_length should be same as in the state, "
f"but given {epoch_length} vs {self.state.epoch_length}"
)
self._check_and_set_max_epochs(max_epochs)

if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
if self.state.max_iters is not None:
self._check_and_set_max_iters(max_iters)
Comment thread
TahaZahid05 marked this conversation as resolved.

# Check if we need to create new state or resume
# Create new state if:
# 1. No termination params set (first run), OR
# 2. Training is done AND generator is None
should_create_new_state = (self.state.max_epochs is None and self.state.max_iters is None) or (
self._is_done(self.state) and self._internal_run_generator is None
)

if should_create_new_state:
# Create new state
if epoch_length is None:
if data is None:
raise ValueError("epoch_length should be provided if data is None")
if data is None and epoch_length is None and self.state.epoch_length is None:
raise ValueError("epoch_length should be provided if data is None")

epoch_length = self._get_data_length(data)
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")
# Set epoch_length for new state
if epoch_length is None:
# Try to get from data first, then fall back to existing state
if data is not None:
epoch_length = self._get_data_length(data)
if epoch_length is None:
epoch_length = self.state.epoch_length
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")

if max_iters is None:
if max_epochs is None:
max_epochs = 1
else:
if max_epochs is not None:
raise ValueError(
"Arguments max_iters and max_epochs are mutually exclusive."
"Please provide only max_epochs or max_iters."
)
if epoch_length is not None:
max_epochs = math.ceil(max_iters / epoch_length)
# If max_iters is provided, we stay in max_iters mode
max_epochs = None

self.state.iteration = 0
self.state.epoch = 0
Expand All @@ -918,12 +984,38 @@ def switch_batch(engine):
self.state.epoch_length = epoch_length
# Reset generator if previously used
self._internal_run_generator = None
self.logger.info(f"Engine run starting with max_epochs={max_epochs}.")

if self.state.max_iters is not None:
self.logger.info(f"Engine run starting with max_iters={self.state.max_iters}.")
else:
self.logger.info(f"Engine run starting with max_epochs={self.state.max_epochs}.")

else:
self.logger.info(
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
)
if self.state.epoch_length is not None:
if epoch_length is not None and epoch_length != self.state.epoch_length:
raise ValueError(
"Argument epoch_length should be same as in the state, "
f"but given {epoch_length} vs {self.state.epoch_length}"
)
else:
if epoch_length is None and data is not None:
epoch_length = self._get_data_length(data)
if epoch_length is not None:
if epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")
self.state.epoch_length = epoch_length

if self.state.max_iters is not None:
self.logger.info(
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_iters} iterations"
)
else:
self.logger.info(
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
)

Comment thread
TahaZahid05 marked this conversation as resolved.
if self.state.epoch_length is None and data is None:
raise ValueError("epoch_length should be provided if data is None")

Expand Down
Loading
Loading