ScikitLearn Example

In this example, we convert a scikit-learn model to an ONNX model with the Flytekit ONNX scikit-learn plugin.

First, import the necessary libraries.

from typing import List, NamedTuple

import numpy
import onnxruntime as rt
import pandas as pd
from flytekit import task, workflow
from flytekit.types.file import ONNXFile
from flytekitplugins.onnxscikitlearn import ScikitLearn2ONNX, ScikitLearn2ONNXConfig
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from typing_extensions import Annotated

Define a NamedTuple to hold the task's outputs. Note the Annotated type on the model field: it tells Flytekit to convert this output to an ONNX model, using the metadata in ScikitLearn2ONNXConfig (here, the name, element type, and shape of the model's input).

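# Export the model with a float32 input named "float_input": any number of rows, 4 features.
TrainOutput = NamedTuple(
    "TrainOutput",
    [
        ("model", Annotated[ScikitLearn2ONNX, ScikitLearn2ONNXConfig(
            initial_types=[("float_input", FloatTensorType([None, 4]))],
        )]),
        ("test", pd.DataFrame),
    ],
)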
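Annotated is standard Python (PEP 593): it attaches arbitrary metadata objects to a type without changing the type itself, and a framework can read that metadata back at runtime. The sketch below shows the mechanism with the standard library only. FakeONNXConfig is a hypothetical stand-in for ScikitLearn2ONNXConfig, and this is an illustration of how such metadata can be retrieved, not Flytekit's internal code.

```python
from typing import Annotated, NamedTuple, get_args, get_origin, get_type_hints


class FakeONNXConfig:
    """Hypothetical stand-in for ScikitLearn2ONNXConfig; it only stores metadata."""

    def __init__(self, initial_types):
        self.initial_types = initial_types


class Output(NamedTuple):
    # The config object rides along with the type of the "model" field.
    model: Annotated[object, FakeONNXConfig(initial_types=[("float_input", [None, 4])])]
    test: list


def field_metadata(cls):
    """Map each field name to the metadata attached via Annotated (empty tuple if none)."""
    hints = get_type_hints(cls, include_extras=True)  # keep the Annotated wrappers
    result = {}
    for name, hint in hints.items():
        if get_origin(hint) is Annotated:
            # get_args returns (base_type, *metadata); keep only the metadata.
            result[name] = get_args(hint)[1:]
        else:
            result[name] = ()
    return result


meta = field_metadata(Output)
print(meta["model"][0].initial_types)  # [('float_input', [None, 4])]
print(meta["test"])  # ()
```

Because the metadata lives in the type annotation rather than in the returned value, the task body can stay free of export settings.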

Define a train task that fits a scikit-learn random forest classifier on the Iris dataset and returns the model, to be converted to ONNX, together with the held-out test features.

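@task
def train() -> TrainOutput:
    iris = load_iris(as_frame=True)
    X, y = iris.data, iris.target
    X_train, X_test, y_train, _ = train_test_split(X, y)
    model = RandomForestClassifier().fit(X_train, y_train)

    # Wrapping the fitted model tells the plugin to convert it to ONNX on output.
    return TrainOutput(test=X_test, model=ScikitLearn2ONNX(model))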
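In FloatTensorType([None, 4]), None marks a dynamic batch dimension and 4 is the number of Iris features, so the exported model should accept X_test with any number of rows as long as each row has exactly four values. The hypothetical helper below (plain Python, not part of skl2onnx or Flytekit) is a sketch of that shape contract for the 2-D case.

```python
from typing import Optional, Sequence


def matches_declared_shape(
    rows: Sequence[Sequence[float]], declared: Sequence[Optional[int]]
) -> bool:
    """Return True if a 2-D batch fits a declared shape such as [None, 4].

    A None entry means "any size" (a dynamic dimension); an int must match exactly.
    """
    if len(declared) != 2:
        raise ValueError("this sketch only handles 2-D shapes")
    n_rows = len(rows)
    # All rows must share one width for the batch to form a proper 2-D tensor.
    widths = {len(r) for r in rows}
    if len(widths) > 1:
        return False
    n_cols = widths.pop() if widths else 0
    return all(d is None or d == actual for d, actual in zip(declared, (n_rows, n_cols)))


iris_shape = [None, 4]  # dynamic batch size, four features
print(matches_declared_shape([[5.1, 3.5, 1.4, 0.2]], iris_shape))  # True
print(matches_declared_shape([[5.1, 3.5, 1.4, 0.2]] * 38, iris_shape))  # True
print(matches_declared_shape([[5.1, 3.5, 1.4]], iris_shape))  # False
```

Leaving the batch dimension as None is what lets the same exported model score a single sample or a whole test split.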

Define a predict task that loads the ONNX model with ONNX Runtime and predicts labels for the test data.

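@task
def predict(
    model: ONNXFile,
    X_test: pd.DataFrame,
) -> List[int]:
    # Download the ONNX file locally and load it into an ONNX Runtime session.
    sess = rt.InferenceSession(model.download())
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    # The model expects float32 input; run() returns one array per requested output.
    pred_onx = sess.run(
        [label_name], {input_name: X_test.to_numpy(dtype=numpy.float32)}
    )[0]
    return pred_onx.tolist()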
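Two conversions in predict are easy to miss. The model was exported with a FloatTensorType input, i.e. 32-bit floats, while pandas holds the Iris features as 64-bit floats, and ONNX Runtime generally expects inputs to match the declared element type; hence to_numpy(dtype=numpy.float32). Also, sess.run returns a list with one array per requested output, so [0] selects the label array, and .tolist() turns its integer elements into plain Python ints that match the List[int] return type. The NumPy-only sketch below illustrates both steps; fake_outputs is a hypothetical stand-in for the value sess.run returns, not a real ONNX Runtime call.

```python
import numpy as np

# Iris-like features as pandas would store them: float64 by default.
features = np.array([[5.1, 3.5, 1.4, 0.2], [6.7, 3.0, 5.2, 2.3]])
print(features.dtype)  # float64

# Cast to float32 so the feed matches a FloatTensorType input declaration.
feed = {"float_input": features.astype(np.float32)}
print(feed["float_input"].dtype)  # float32

# Hypothetical stand-in for sess.run([label_name], feed): a list holding one
# array per requested output, here an int64 array of class labels.
fake_outputs = [np.array([0, 2], dtype=np.int64)]
labels = fake_outputs[0].tolist()
print(labels, type(labels[0]))  # [0, 2] <class 'int'>
```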

Lastly, define a workflow that passes the outputs of train to predict.

@workflow
def wf() -> List[int]:
    train_output = train()
    return predict(model=train_output.model, X_test=train_output.test)
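Roughly speaking, inside a workflow the value returned by train() stands for the task's named outputs, and attribute access such as train_output.model selects one of them to wire into predict. The data flow is the same as ordinary function composition over a NamedTuple, which the sketch below shows without any Flyte or ONNX dependency. fake_train, fake_predict, fake_wf and the threshold "model" are all hypothetical stand-ins, not the tasks above.

```python
from typing import List, NamedTuple


class FakeTrainOutput(NamedTuple):
    model: dict  # hypothetical stand-in for the ONNX model file
    test: List[List[float]]


def fake_train() -> FakeTrainOutput:
    # Hypothetical "model": class 0 if petal length (column 2) < 2.5 cm, else class 1.
    return FakeTrainOutput(
        model={"threshold": 2.5},
        test=[[5.1, 3.5, 1.4, 0.2], [6.7, 3.0, 5.2, 2.3]],
    )


def fake_predict(model: dict, X_test: List[List[float]]) -> List[int]:
    return [0 if row[2] < model["threshold"] else 1 for row in X_test]


def fake_wf() -> List[int]:
    train_output = fake_train()
    # Same wiring as wf(): each named output feeds one predict argument.
    return fake_predict(model=train_output.model, X_test=train_output.test)


print(fake_wf())  # [0, 1]
```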

Run the workflow locally.

if __name__ == "__main__":
    print(f"Predictions: {wf()}")
