ScikitLearn Example

In this example, we will see how to convert a scikitlearn model to an ONNX model.

First import the necessary libraries.

from typing import List, NamedTuple

import numpy
import onnxruntime as rt
import pandas as pd
from flytekit import task, workflow
from flytekit.types.file import ONNXFile
from flytekitplugins.onnxscikitlearn import ScikitLearn2ONNX, ScikitLearn2ONNXConfig
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from typing_extensions import Annotated

Define a NamedTuple to hold the output schema. Note the annotation on the model field. This is a special annotation that tells Flytekit that this parameter is to be converted to an ONNX model with the given metadata.

TrainOutput = NamedTuple(
    "TrainOutput",
    [
        (
            "model",
            Annotated[
                ScikitLearn2ONNX,
                ScikitLearn2ONNXConfig(
                    initial_types=[("float_input", FloatTensorType([None, 4]))],
                    target_opset=12,
                ),
            ],
        ),
        ("test", pd.DataFrame),
    ],
)

Define a train task that will train a scikitlearn model and return the model and test data.

@task
def train() -> TrainOutput:
    iris = load_iris(as_frame=True)
    X, y = iris.data, iris.target
    X_train, X_test, y_train, _ = train_test_split(X, y)
    model = RandomForestClassifier()
    model.fit(X_train, y_train)

    return TrainOutput(test=X_test, model=ScikitLearn2ONNX(model))

Define a predict task that will use the model to predict the labels for the test data.

@task
def predict(
    model: ONNXFile,
    X_test: pd.DataFrame,
) -> List[int]:
    sess = rt.InferenceSession(model.download())
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    pred_onx = sess.run([label_name], {input_name: X_test.to_numpy(dtype=numpy.float32)})[0]
    return pred_onx.tolist()

Lastly define a workflow to run the above tasks.

@workflow
def wf() -> List[int]:
    train_output = train()
    return predict(model=train_output.model, X_test=train_output.test)

Run the workflow locally.

if __name__ == "__main__":
    print(f"Predictions: {wf()}")