# Source code for lightning.fabric.plugins.io.checkpoint_io

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from lightning.fabric.utilities.types import _PATH


class CheckpointIO(ABC):
    """Interface to save/load checkpoints as they are saved through the ``Strategy``.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

    Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may require particular
    handling depending on the plugin.

    In addition, you can pass a custom ``CheckpointIO`` by extending this class and passing it to the Trainer,
    i.e. ``Trainer(plugins=[MyCustomCheckpointIO()])``.

    .. note::

        For some plugins, it is not possible to use a custom checkpoint plugin as checkpointing logic is not
        modifiable.

    """

    @abstractmethod
    def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: Optional parameters when saving the model/training states.

        """

    @abstractmethod
    def load_checkpoint(
        self,
        path: _PATH,
        *,
        state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
        map_location: Optional[Any] = None,
        weights_only: Optional[bool] = None,
        # NOTE: annotated Optional because, per the documented contract below, implementations may
        # restore fully in-place into ``state`` and return ``None``.
    ) -> Optional[dict[str, Any]]:
        """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.

        Args:
            path: Path to checkpoint
            state: Optional dict to load the checkpoint into.
            map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                locations.
            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that
                contains an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted
                source, we recommend using ``weights_only=True``. For more information, please refer to the
                `PyTorch Developer Notes on Serialization Semantics
                <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

        Returns:
            - A dictionary containing checkpoint contents that still need to be applied by the caller, or
            - ``None`` if the checkpoint was fully restored in-place into ``state``.

        """

    @property
    def requires_cpu_collectives(self) -> bool:
        """Whether this plugin requires CPU-based collective communication during checkpointing.

        Defaults to ``False``; distributed checkpoint IO implementations may override this.

        """
        return False

    @property
    def _restore_after_setup(self) -> bool:
        """Whether checkpoint restoration should be delayed until after the Strategy setup phase.

        Some checkpoint implementations require the distributed environment, device placement, or wrapped
        modules to be fully initialized before loading state. When this returns ``True``, the Trainer/Strategy
        will restore the checkpoint only after setup has completed. This is primarily used by distributed
        checkpointing backends that depend on collective communication during load.

        """
        return False

    @property
    def requires_state_on_load(self) -> bool:
        """Whether the ``state`` argument of ``load_checkpoint`` is required for loading the checkpoint.

        If ``True``, the Trainer will always pass a state dict containing the current model and optimizer to
        the ``load_checkpoint`` method. This is for plugins that need to do in-place restoration of the
        checkpoint into the provided state objects instead of returning a new checkpoint dict.

        """
        # Plugins that defer restoration until after setup also restore in-place, hence the coupling.
        return self._restore_after_setup

    @abstractmethod
    def remove_checkpoint(self, path: _PATH) -> None:
        """Remove checkpoint file from the filesystem.

        Args:
            path: Path to checkpoint

        """

    def teardown(self) -> None:
        """This method is called to teardown the process."""