# Source code for lightning.fabric.plugins.io.checkpoint_io
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Optional, Union
from torch.nn import Module
from torch.optim import Optimizer
from lightning.fabric.utilities.types import _PATH
class CheckpointIO(ABC):
    """Interface to save/load checkpoints as they are saved through the ``Strategy``.

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

    Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may
    require particular handling depending on the plugin.

    In addition, you can pass a custom ``CheckpointIO`` by extending this class and passing it
    to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIO()])``.

    .. note::
        For some plugins, it is not possible to use a custom checkpoint plugin as checkpointing logic is not
        modifiable.

    """

    @abstractmethod
    def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: Optional parameters when saving the model/training states.

        """

    @abstractmethod
    def load_checkpoint(
        self,
        path: _PATH,
        *,
        state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
        map_location: Optional[Any] = None,
        weights_only: Optional[bool] = None,
    ) -> Optional[dict[str, Any]]:
        # NOTE: the return annotation is ``Optional`` because, per the contract below,
        # implementations that restore fully in-place into ``state`` return ``None``.
        """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.

        Args:
            path: Path to checkpoint
            state: Optional dict to load the checkpoint into.
            map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                locations.
            weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
                ``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
                an ``nn.Module``, use ``weights_only=False``. If loading checkpoint from an untrusted source, we
                recommend using ``weights_only=True``. For more information, please refer to the
                `PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.

        Returns:
            - A dictionary containing checkpoint contents that still need to be
              applied by the caller, or
            - ``None`` if the checkpoint was fully restored in-place into ``state``.

        """

    @property
    def requires_cpu_collectives(self) -> bool:
        """Whether this plugin requires CPU-based collective communication.

        Defaults to ``False``; subclasses may override.

        """
        return False

    @property
    def _restore_after_setup(self) -> bool:
        """Whether checkpoint restoration should be delayed until after the Strategy setup phase.

        Some checkpoint implementations require the distributed environment, device placement,
        or wrapped modules to be fully initialized before loading state. When this returns
        ``True``, the Trainer/Strategy will restore the checkpoint only after setup has completed.

        This is primarily used by distributed checkpointing backends that depend on collective
        communication during load.

        """
        return False

    @property
    def requires_state_on_load(self) -> bool:
        """Whether the ``state`` argument of ``load_checkpoint`` is required for loading the checkpoint.

        If ``True``, the Trainer will always pass a state dict containing the current model and optimizer
        to the ``load_checkpoint`` method. This is for plugins that need to do in-place restoration of the
        checkpoint into the provided state objects instead of returning a new checkpoint dict.

        """
        # Plugins that restore after setup also restore in-place, hence the shared default.
        return self._restore_after_setup

    @abstractmethod
    def remove_checkpoint(self, path: _PATH) -> None:
        """Remove checkpoint file from the filesystem.

        Args:
            path: Path to checkpoint

        """

    def teardown(self) -> None:
        """This method is called to teardown the process."""