Source code for flytekit.extras.pytorch.checkpoint

import pathlib
import typing
from dataclasses import asdict, dataclass, fields, is_dataclass
from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union

import torch
from dataclasses_json import DataClassJsonMixin
from typing_extensions import Protocol

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


class IsDataclass(Protocol):
    """
    Structural (duck-typed) description of "an instance of some dataclass".

    In this module it is only used as a type annotation; the runtime check for
    dataclass instances is done separately with ``dataclasses.is_dataclass``.
    """

    # Set on every class by the ``@dataclass`` decorator.
    __dataclass_fields__: Dict
    __dataclass_params__: Dict
    # Optional user-defined hook that dataclasses call after ``__init__``.
    __post_init__: Optional[Callable]


@dataclass
class PyTorchCheckpoint(DataClassJsonMixin):
    """
    Bundle of the objects that make up a PyTorch checkpoint.

    Each of the three fields may be left as ``None``, but at least one of them
    must be truthy.

    Attributes:
        module: a ``torch.nn.Module``.
        hyperparameters: a ``dict``, a dataclass *instance*, or a ``NamedTuple`` instance.
        optimizer: a ``torch.optim.Optimizer``.

    Raises:
        TypeTransformerFailedError: from ``__post_init__`` when ``hyperparameters`` has an
            unsupported type, or when none of the three fields is truthy.
    """

    module: Optional[torch.nn.Module] = None
    hyperparameters: Optional[Union[Dict[str, Any], NamedTuple, IsDataclass]] = None
    optimizer: Optional[torch.optim.Optimizer] = None

    def __post_init__(self):
        hparams = self.hyperparameters
        hparams_ok = (
            hparams is None
            or isinstance(hparams, dict)
            # is_dataclass() is also True for a dataclass *class*; only instances are accepted.
            or (is_dataclass(hparams) and not isinstance(hparams, type))
            # NamedTuple instances are tuples that expose a ``_fields`` attribute.
            or (isinstance(hparams, tuple) and hasattr(hparams, "_fields"))
        )
        if not hparams_ok:
            raise TypeTransformerFailedError(
                f"hyperparameters must be a dict, dataclass, or NamedTuple. Got {type(self.hyperparameters)}"
            )

        # Truthiness (not ``is None``) decides whether a field counts as present.
        if not any((self.module, self.hyperparameters, self.optimizer)):
            raise TypeTransformerFailedError("Must have at least one of module, hyperparameters, or optimizer")
class PyTorchCheckpointTransformer(TypeTransformer[PyTorchCheckpoint]):
    """
    TypeTransformer that serializes a :class:`PyTorchCheckpoint` to a single ``.pt``
    blob (written with ``torch.save``) and deserializes it with ``torch.load``.
    """

    PYTORCH_CHECKPOINT_FORMAT = "PyTorchCheckpoint"

    def __init__(self):
        super().__init__(name="PyTorch Checkpoint", t=PyTorchCheckpoint)

    def _blob_type(self) -> _core_types.BlobType:
        """Return the single-file blob type tagged with ``PYTORCH_CHECKPOINT_FORMAT``."""
        return _core_types.BlobType(
            format=self.PYTORCH_CHECKPOINT_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )

    def get_literal_type(self, t: Type[PyTorchCheckpoint]) -> LiteralType:
        """Return a blob literal type (single file, ``PYTORCH_CHECKPOINT_FORMAT``)."""
        return LiteralType(blob=self._blob_type())

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: PyTorchCheckpoint,
        python_type: Type[PyTorchCheckpoint],
        expected: LiteralType,
    ) -> Literal:
        """
        Save ``python_val`` to a local ``.pt`` file with ``torch.save`` and upload it.

        The saved object is a flat dict:
          * ``"module_state_dict"`` / ``"optimizer_state_dict"`` hold ``state_dict()`` of
            ``module`` / ``optimizer`` when those fields are truthy;
          * hyperparameters (dict, NamedTuple or dataclass instance) are merged in as
            top-level keys.

        Returns:
            A blob ``Literal`` pointing at the uploaded file.

        Raises:
            TypeTransformerFailedError: if there is nothing to save.
        """
        meta = BlobMetadata(type=self._blob_type())

        local_path = ctx.file_access.get_random_local_path() + ".pt"
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        to_save: Dict[str, Any] = {}
        for fld in fields(python_val):
            value = getattr(python_val, fld.name)
            if value and fld.name in ("module", "optimizer"):
                to_save[fld.name + "_state_dict"] = value.state_dict()
            elif value and fld.name == "hyperparameters":
                # NOTE: hyperparameter keys share the top-level namespace with the
                # "*_state_dict" entries, so a hyperparameter with one of those names
                # would overwrite (or be overwritten by) the corresponding state dict.
                if isinstance(value, dict):
                    to_save.update(value)
                elif isinstance(value, tuple):
                    # NamedTuple (validated by PyTorchCheckpoint.__post_init__).
                    to_save.update(value._asdict())
                elif is_dataclass(value):
                    to_save.update(asdict(value))

        if not to_save:
            raise TypeTransformerFailedError(f"Cannot save empty {python_val}")

        # save checkpoint to a file
        torch.save(to_save, local_path)
        remote_path = ctx.file_access.put_raw_data(local_path)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PyTorchCheckpoint]
    ) -> PyTorchCheckpoint:
        """
        Download the checkpoint blob referenced by ``lv`` and load it with ``torch.load``.

        Note: the returned value is whatever ``torch.load`` produces. For blobs written
        by :meth:`to_literal` that is the flat dict described there, not a
        ``PyTorchCheckpoint`` instance; the ``typing.cast`` only satisfies the
        declared return type.

        Raises:
            TypeTransformerFailedError: if ``lv`` does not carry a blob URI.
        """
        try:
            uri = lv.scalar.blob.uri
        except AttributeError as err:
            # ``lv.scalar`` or ``lv.scalar.blob`` is missing: this literal is not a blob.
            raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") from err

        local_path = ctx.file_access.get_random_local_path()
        ctx.file_access.get_data(uri, local_path, is_multipart=False)

        # cpu <-> gpu conversion: map tensors onto the first GPU if one exists, else CPU.
        map_location = "cuda:0" if torch.cuda.is_available() else torch.device("cpu")

        # SECURITY NOTE(review): torch.load is pickle-based; only load checkpoints
        # from trusted locations.
        return typing.cast(PyTorchCheckpoint, torch.load(local_path, map_location=map_location))

    def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorchCheckpoint]:
        """
        Map a single-file blob literal type with ``PYTORCH_CHECKPOINT_FORMAT`` back to
        :class:`PyTorchCheckpoint`.

        Raises:
            ValueError: for any other literal type.
        """
        if (
            literal_type.blob is not None
            and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
            and literal_type.blob.format == self.PYTORCH_CHECKPOINT_FORMAT
        ):
            return PyTorchCheckpoint

        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
# Import-time side effect: register one PyTorchCheckpointTransformer instance with
# TypeEngine. The transformer is constructed for the PyTorchCheckpoint type
# (t=PyTorchCheckpoint in its __init__).
TypeEngine.register(PyTorchCheckpointTransformer())