Source code for flytekitplugins.onnxscikitlearn.schema

from __future__ import annotations

import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

from dataclasses_json import DataClassJsonMixin
from typing_extensions import Annotated, get_args, get_origin

from flytekit import FlyteContext, lazy_module
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.file import ONNXFile

# Optional heavy dependencies are bound through flytekit's ``lazy_module`` helper
# rather than plain ``import`` statements.
# NOTE(review): presumably the real import is deferred until an attribute is first
# accessed -- confirm against flytekit's ``lazy_module`` implementation.
sklearn = lazy_module("sklearn")
skl2onnx = lazy_module("skl2onnx")
skl2onnx_data_types = lazy_module("skl2onnx.common.data_types")


@dataclass
class ScikitLearn2ONNXConfig(DataClassJsonMixin):
    """
    ScikitLearn2ONNXConfig is the config used during the scikitlearn to ONNX conversion.

    Args:
        initial_types: The types of the inputs to the model.
        name: The name of the graph in the produced ONNX model.
        doc_string: A string attached onto the produced ONNX model.
        target_opset: The ONNX opset number.
        custom_conversion_functions: A dictionary for specifying the user customized conversion function.
        custom_shape_calculators: A dictionary for specifying the user customized shape calculator.
        custom_parsers: Parsers determine which outputs are expected for which particular task.
        options: Specific options given to converters.
        intermediate: If True, the function returns the converted model and the instance of Topology used,
            else, it returns the converted model.
        naming: Change the way intermediates are named.
        white_op: White list of ONNX nodes allowed while converting a pipeline.
        black_op: Black list of ONNX nodes disallowed while converting a pipeline.
        verbose: Display progress while converting a model.
        final_types: Used to overwrite the type (if type is not None) and the name of every output.

    Raises:
        ValueError: If an entry of ``initial_types`` or ``final_types`` does not use a type from
            ``skl2onnx.common.data_types``.
    """

    initial_types: List[Tuple[str, Type]]
    name: Optional[str] = None
    doc_string: str = ""
    target_opset: Optional[int] = None
    custom_conversion_functions: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict)
    custom_shape_calculators: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict)
    custom_parsers: Dict[Callable[..., Any], Callable[..., None]] = field(default_factory=dict)
    options: Dict[Any, Any] = field(default_factory=dict)
    intermediate: bool = False
    naming: Optional[Union[str, Callable[..., Any]]] = None
    white_op: Optional[Set[str]] = None
    black_op: Optional[Set[str]] = None
    verbose: int = 0
    final_types: Optional[List[Tuple[str, Type]]] = None

    def __post_init__(self):
        """Check that the declared input and output types come from ``skl2onnx.common.data_types``.

        Raises:
            ValueError: If any ``(name, type)`` entry of ``initial_types`` -- or of ``final_types``
                when it is set -- carries a type that is not one of the data-type classes exposed
                by ``skl2onnx.common.data_types``.
        """
        if not self._uses_known_data_types(self.initial_types):
            raise ValueError("All types in initial_types must be in skl2onnx.common.data_types")

        # final_types may leave a type as None, meaning only the output name is overridden.
        if self.final_types and not self._uses_known_data_types(self.final_types, allow_none=True):
            raise ValueError("All types in final_types must be in skl2onnx.common.data_types")

    @staticmethod
    def _uses_known_data_types(entries, allow_none: bool = False) -> bool:
        """Return True when every ``(name, type)`` entry uses a skl2onnx data type.

        The type may be given either as an instance of a data-type class or as the class itself.
        Entries that are not two-item pairs are treated as invalid.
        """
        # Classes exposed by the data_types module (defined there or re-exported by it).
        known_classes = tuple(cls for _, cls in inspect.getmembers(skl2onnx_data_types, inspect.isclass))

        for entry in entries:
            try:
                _, data_type = entry
            except (TypeError, ValueError):
                return False

            if data_type is None:
                if allow_none:
                    continue
                return False

            if inspect.isclass(data_type):
                if not issubclass(data_type, known_classes):
                    return False
            elif not isinstance(data_type, known_classes):
                return False

        return True
@dataclass
class ScikitLearn2ONNX(DataClassJsonMixin):
    """Container for the scikit-learn estimator that is to be converted to ONNX.

    Args:
        model: The scikit-learn estimator to convert. Defaults to ``None``.
    """

    model: sklearn.base.BaseEstimator = None
def extract_config(t: Type[ScikitLearn2ONNX]) -> Tuple[Type[ScikitLearn2ONNX], ScikitLearn2ONNXConfig]:
    """Split a (possibly ``Annotated``) type into its base type and its conversion config.

    Args:
        t: Either a plain type or ``Annotated[base_type, ..., ScikitLearn2ONNXConfig(...), ...]``.

    Returns:
        ``(base_type, config)``. For a type that is not ``Annotated`` the config is ``None``.

    Raises:
        TypeTransformerFailedError: If ``t`` is ``Annotated`` but none of its metadata items is a
            ``ScikitLearn2ONNXConfig``.
    """
    if get_origin(t) is not Annotated:
        return t, None

    # Annotated may carry several metadata items; pick the first conversion config rather
    # than assuming there is exactly one item.
    base_type, *metadata = get_args(t)
    for annotation in metadata:
        if isinstance(annotation, ScikitLearn2ONNXConfig):
            return base_type, annotation

    raise TypeTransformerFailedError(f"{t}'s config isn't of type ScikitLearn2ONNXConfig")


def to_onnx(ctx, model, config):
    """Convert ``model`` with ``skl2onnx.convert_sklearn`` and write the result to a local file.

    Args:
        ctx: Context whose ``file_access`` provides the local path to write to.
        model: The scikit-learn model to convert.
        config: Mapping of keyword arguments forwarded to ``skl2onnx.convert_sklearn``.

    Returns:
        The local path the serialized ONNX model was written to.
    """
    local_path = ctx.file_access.get_random_local_path()

    converted = skl2onnx.convert_sklearn(model, **config)
    # With intermediate=True the converter returns the model together with the Topology
    # instance; only the model can be serialized.
    onx = converted[0] if isinstance(converted, tuple) else converted

    with open(local_path, "wb") as f:
        f.write(onx.SerializeToString())

    return local_path


class ScikitLearn2ONNXTransformer(TypeTransformer[ScikitLearn2ONNX]):
    """Type transformer that turns a ``ScikitLearn2ONNX`` value into an ONNX blob literal."""

    ONNX_FORMAT = "onnx"

    def __init__(self):
        super().__init__(name="ScikitLearn ONNX", t=ScikitLearn2ONNX)

    def get_literal_type(self, t: Type[ScikitLearn2ONNX]) -> LiteralType:
        """Return the literal type: a single-dimensionality blob in the ONNX format."""
        return LiteralType(blob=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE))

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: ScikitLearn2ONNX,
        python_type: Type[ScikitLearn2ONNX],
        expected: LiteralType,
    ) -> Literal:
        """Convert the wrapped model to ONNX, upload it and return a blob literal for it.

        Raises:
            TypeTransformerFailedError: If ``python_type`` carries no conversion config.
        """
        python_type, config = extract_config(python_type)
        if not config:
            raise TypeTransformerFailedError(f"{python_type}'s config is None")

        # The config's fields are forwarded as keyword arguments to the converter.
        local_path = to_onnx(ctx, python_val.model, config.__dict__.copy())
        remote_path = ctx.file_access.put_raw_data(local_path)

        return Literal(
            scalar=Scalar(
                blob=Blob(
                    uri=remote_path,
                    metadata=BlobMetadata(
                        type=BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)
                    ),
                )
            )
        )

    def to_python_value(
        self,
        ctx: FlyteContext,
        lv: Literal,
        expected_python_type: Type[ONNXFile],
    ) -> ONNXFile:
        """Return an ``ONNXFile`` pointing at the blob URI stored in ``lv``.

        Raises:
            TypeTransformerFailedError: If the literal's blob URI is empty.
        """
        if not lv.scalar.blob.uri:
            raise TypeTransformerFailedError(f"ONNX format isn't of the expected type {expected_python_type}")

        return ONNXFile(path=lv.scalar.blob.uri)

    def guess_python_type(self, literal_type: LiteralType) -> Type[ScikitLearn2ONNX]:
        """Map a single-dimensionality ONNX blob literal type back to ``ScikitLearn2ONNX``.

        Raises:
            TypeTransformerFailedError: If ``literal_type`` is not such a blob type.
        """
        if (
            literal_type.blob is not None
            and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE
            and literal_type.blob.format == self.ONNX_FORMAT
        ):
            return ScikitLearn2ONNX

        raise TypeTransformerFailedError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(ScikitLearn2ONNXTransformer())