Source code for flytekitplugins.onnxpytorch.schema

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union

import torch
from dataclasses_json import DataClassJsonMixin
from torch.onnx import OperatorExportTypes, TrainingMode
from typing_extensions import Annotated, get_args, get_origin

from flytekit import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.file import ONNXFile

@dataclass
class PyTorch2ONNXConfig(DataClassJsonMixin):
    """
    PyTorch2ONNXConfig is the config used during the pytorch to ONNX conversion.

    Every field of this dataclass is forwarded as a keyword argument to
    ``torch.onnx.export`` when a ``PyTorch2ONNX`` value is converted to a
    literal, so the field names have to match that function's parameter names.

    Args:
      args: The input to the model.
      export_params: Whether to export all the parameters.
      verbose: Whether to print description of the ONNX model.
      training: Whether to export the model in training mode or inference mode.
      opset_version: The ONNX version to export the model to.
      input_names: Names to assign to the input nodes of the graph.
      output_names: Names to assign to the output nodes of the graph.
      operator_export_type: How to export the ops.
      do_constant_folding: Whether to apply constant folding for optimization.
      dynamic_axes: Specify axes of tensors as dynamic.
      keep_initializers_as_inputs: Whether to add the initializers as inputs to the graph.
      custom_opsets: A dictionary of opset domain name and version.
      export_modules_as_functions: Whether to export modules as functions.
    """

    # The only required field: there is no sensible default model input.
    args: Union[Tuple, torch.Tensor]
    export_params: bool = True
    verbose: bool = False
    training: TrainingMode = TrainingMode.EVAL
    opset_version: int = 9
    # Mutable defaults go through default_factory so instances never share them.
    input_names: List[str] = field(default_factory=list)
    output_names: List[str] = field(default_factory=list)
    operator_export_type: Optional[OperatorExportTypes] = None
    do_constant_folding: bool = False
    dynamic_axes: Union[Dict[str, Dict[int, str]], Dict[str, List[int]]] = field(default_factory=dict)
    keep_initializers_as_inputs: Optional[bool] = None
    custom_opsets: Dict[str, int] = field(default_factory=dict)
    export_modules_as_functions: Union[bool, set[Type]] = False
@dataclass
class PyTorch2ONNX(DataClassJsonMixin):
    """
    Container for the PyTorch model that is to be exported to ONNX.

    Args:
      model: The module, script module or script function to convert. It defaults
        to ``None``, which is why the annotation is ``Optional``.
    """

    model: Optional[Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction]] = field(default=None)
def extract_config(t: Type[PyTorch2ONNX]) -> Tuple[Type[PyTorch2ONNX], Optional[PyTorch2ONNXConfig]]:
    """
    Split a (possibly ``Annotated``) type into its base type and its ONNX config.

    Args:
      t: Either a plain type or ``Annotated[<type>, PyTorch2ONNXConfig(...), ...]``.

    Returns:
      ``(base_type, config)``. ``config`` is ``None`` when ``t`` is not ``Annotated``.

    Raises:
      TypeTransformerFailedError: ``t`` is ``Annotated`` but none of its metadata
        entries is a ``PyTorch2ONNXConfig`` instance.
    """
    if get_origin(t) is not Annotated:
        return t, None

    # get_args() on an Annotated type yields (base_type, *metadata). Scan the
    # metadata instead of unpacking exactly two values, so that additional
    # annotations next to the config do not blow up with a bare ValueError.
    base_type, *metadata = get_args(t)
    for candidate in metadata:
        if isinstance(candidate, PyTorch2ONNXConfig):
            return base_type, candidate

    raise TypeTransformerFailedError(f"{t}'s config isn't of type PyTorch2ONNXConfig")


def to_onnx(ctx: FlyteContext, model, config: Dict[str, object]):
    """
    Export ``model`` to an ONNX file at a fresh local path.

    Args:
      ctx: Context whose ``file_access`` provides the local path to write to.
      model: The model handed to ``torch.onnx.export``.
      config: Keyword arguments for ``torch.onnx.export`` (everything except the
        model itself and the output file ``f``).

    Returns:
      The local path the ONNX model was written to.
    """
    local_path = ctx.file_access.get_random_local_path()

    torch.onnx.export(
        model,
        **config,
        f=local_path,
    )

    return local_path


class PyTorch2ONNXTransformer(TypeTransformer[PyTorch2ONNX]):
    """
    Type transformer that serializes a ``PyTorch2ONNX`` value as a single ONNX blob.

    Going to a literal exports the wrapped model to ONNX and uploads the file;
    coming back from a literal yields an ``ONNXFile`` pointing at the blob URI
    (not a ``PyTorch2ONNX`` instance).
    """

    ONNX_FORMAT = "onnx"

    def __init__(self):
        super().__init__(name="PyTorch ONNX", t=PyTorch2ONNX)

    def get_literal_type(self, t: Type[PyTorch2ONNX]) -> LiteralType:
        """Return the literal type used for every ``PyTorch2ONNX``: a single blob in ONNX format."""
        return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE))

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: PyTorch2ONNX,
        python_type: Type[PyTorch2ONNX],
        expected: LiteralType,
    ) -> Literal:
        """
        Export the wrapped model to ONNX, upload it and return a blob literal.

        Args:
          ctx: Context providing ``file_access`` for local and remote storage.
          python_val: The value holding the model to export.
          python_type: Must be ``Annotated`` with a ``PyTorch2ONNXConfig``.
          expected: The expected literal type (unused).

        Raises:
          TypeTransformerFailedError: no config is attached to ``python_type``,
            or ``python_val.model`` is ``None``.
        """
        python_type, config = extract_config(python_type)

        if not config:
            raise TypeTransformerFailedError(f"{python_type}'s config is None")

        # ``model`` defaults to None on the dataclass; fail with a clear message
        # here rather than with an opaque error from inside torch.onnx.export.
        if python_val.model is None:
            raise TypeTransformerFailedError(f"{python_type}'s model is None")

        # Copy the attribute dict so the export call cannot mutate the config object.
        local_path = to_onnx(ctx, python_val.model, config.__dict__.copy())
        remote_path = ctx.file_access.put_raw_data(local_path)

        return Literal(
            scalar=Scalar(
                blob=Blob(
                    uri=remote_path,
                    metadata=BlobMetadata(
                        type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)
                    ),
                )
            )
        )

    def to_python_value(
        self,
        ctx: FlyteContext,
        lv: Literal,
        expected_python_type: Type[ONNXFile],
    ) -> ONNXFile:
        """
        Turn an ONNX blob literal into an ``ONNXFile`` pointing at the blob's URI.

        Raises:
          TypeTransformerFailedError: the blob has no URI or is not in ONNX format.
        """
        if not (lv.scalar.blob.uri and lv.scalar.blob.metadata.format == self.ONNX_FORMAT):
            raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}")

        return ONNXFile(path=lv.scalar.blob.uri)

    def guess_python_type(self, literal_type: LiteralType) -> Type[PyTorch2ONNX]:
        """
        Map a literal type back to ``PyTorch2ONNX`` when it is a single ONNX blob.

        Raises:
          TypeTransformerFailedError: ``literal_type`` is anything else.
        """
        if (
            literal_type.blob is not None
            and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE
            and literal_type.blob.format == self.ONNX_FORMAT
        ):
            return PyTorch2ONNX

        raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}")


# Importing this module is what makes the transformer available to the type engine.
TypeEngine.register(PyTorch2ONNXTransformer())