Machine learning

Flyte can handle the full spectrum of machine learning workloads, from training small models to GPU-accelerated deep learning and hyperparameter optimization.

Getting the data

In this simple example, we train a binary classification model on the wine dataset that is available through the scikit-learn package:

import pandas as pd
from flytekit import Resources, task, workflow
from sklearn.datasets import load_wine
from sklearn.linear_model import LogisticRegression

# Registers Flyte's type transformer so scikit-learn models can be passed between tasks
import flytekit.extras.sklearn


@task
def get_data() -> pd.DataFrame:
    """Get the wine dataset."""
    return load_wine(as_frame=True).frame

Define a training workflow

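Next, we define the process_data and train_model tasks, along with a training_workflow that chains all the steps into a model-training pipeline. process_data keeps class 0 unchanged and relabels classes 1 and 2 as 1, which turns the three-class problem into a binary one.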

@task
def process_data(data: pd.DataFrame) -> pd.DataFrame:
    """Simplify the task from a 3-class to a binary classification problem."""
    # Keep class 0 as-is and relabel every other class as 1
    return data.assign(target=lambda x: x["target"].where(x["target"] == 0, 1))

@task
def train_model(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression:
    """Train a model on the wine dataset."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return LogisticRegression(max_iter=5000, **hyperparameters).fit(features, target)

@workflow
def training_workflow(hyperparameters: dict) -> LogisticRegression:
    """Put all of the steps together into a single workflow."""
    data = get_data()
    processed_data = process_data(data=data)
    return train_model(
        data=processed_data,
        hyperparameters=hyperparameters,
    )

Although you can use a plain dict to represent the model's hyperparameters, we recommend using a dataclass to define a custom Hyperparameter Python object, which gives the Flyte compiler more type information. For example, Flyte uses this type information to auto-generate type-safe launch forms in the Flyte UI. Learn more in the Extending Flyte guide.
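A minimal sketch of such a dataclass is shown below. The class name Hyperparameters, its fields, and the validation rule are hypothetical choices for illustration; the fields are chosen to match keyword arguments of scikit-learn's LogisticRegression, so dataclasses.asdict turns an instance into the same kind of keyword arguments that train_model unpacks from a dict. In a Flyte task you would annotate the input as hyperparameters: Hyperparameters instead of dict, assuming a flytekit version that accepts dataclass inputs.

```python
from dataclasses import asdict, dataclass


@dataclass
class Hyperparameters:
    """Hypothetical typed container for LogisticRegression hyperparameters."""

    C: float = 1.0  # inverse of regularization strength; smaller means stronger regularization
    penalty: str = "l2"  # regularization norm
    fit_intercept: bool = True  # whether to learn a bias term

    def __post_init__(self) -> None:
        # Reject invalid values early, before any training starts
        if self.C <= 0:
            raise ValueError(f"C must be positive, got {self.C}")


# Convert the typed object into keyword arguments for LogisticRegression(**kwargs)
hyperparameters = Hyperparameters(C=0.01)
kwargs = asdict(hyperparameters)
```

Inside train_model, calling LogisticRegression(max_iter=5000, **asdict(hyperparameters)) would then behave like the dict-based version, while the typed fields give Flyte a concrete schema to work with.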

Computing predictions

Flyte workflows can be executed locally like regular Python functions. Here we run training_workflow locally and then call model.predict on a sample of the wine feature matrix to confirm that the newly trained model can make predictions.

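# Run the workflow locally; it returns the trained LogisticRegression model
model = training_workflow(hyperparameters={"C": 0.01})
# Reload the feature matrix and predict on 10 randomly sampled rows
X, _ = load_wine(as_frame=True, return_X_y=True)
model.predict(X.sample(10, random_state=41))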

Extending your ML workloads

There are many ways to extend your workloads:

🏔 Vertical Scaling

Pass Resources objects to the requests and limits arguments of the @task decorator to request additional CPUs, GPUs, and/or memory.

🗺 Horizontal Scaling

With constructs like dynamic() workflows and map_task(), you can implement grid search, random search, and even Bayesian optimization.

🔧 Specialized Tuning Libraries

Use the Ray integration to leverage tools like Ray Tune for hyperparameter optimization, all orchestrated by Flyte on ephemerally provisioned Ray clusters.

📦 Ephemeral Cluster Resources

Use the MPI Operator, Kubeflow TensorFlow, Kubeflow PyTorch, and other integrations for distributed training.

🔎 Experiment Tracking

Auto-capture training logs with the mlflow_autolog() decorator; the captured logs can be viewed as Flyte Decks when the task is declared with @task(disable_deck=False).

⏩ Inference Acceleration

Serialize your models in ONNX format using the ONNX plugin, which supports scikit-learn, TensorFlow, and PyTorch.
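To make the Horizontal Scaling idea more concrete, the sketch below builds the candidate hyperparameter sets for a grid search and for a random search in plain Python. The helper names grid_search_space and random_search_space are hypothetical, and the log-uniform sampling range for C is an assumption. In a Flyte pipeline, each candidate could then be passed to its own train_model run, for example from a dynamic() workflow, so that Flyte can schedule the independent runs in parallel.

```python
import itertools
import math
import random
from typing import Any, Dict, List, Sequence


def grid_search_space(grid: Dict[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """Expand a parameter grid into one dict per combination (Cartesian product)."""
    keys = list(grid)
    return [
        dict(zip(keys, values))
        for values in itertools.product(*(grid[key] for key in keys))
    ]


def random_search_space(
    n_candidates: int, c_low: float, c_high: float, seed: int = 0
) -> List[Dict[str, float]]:
    """Sample C log-uniformly between c_low and c_high, n_candidates times."""
    rng = random.Random(seed)  # seeded so the same candidates are produced on every run
    log_low, log_high = math.log10(c_low), math.log10(c_high)
    # Sampling the exponent uniformly gives every order of magnitude equal weight
    return [{"C": 10 ** rng.uniform(log_low, log_high)} for _ in range(n_candidates)]


grid_candidates = grid_search_space({"C": [0.01, 0.1, 1.0], "penalty": ["l2"]})
random_candidates = random_search_space(5, c_low=1e-3, c_high=1e2, seed=42)
```

The grid grows multiplicatively with every parameter you add, which is one reason random search over a log scale is a common choice for regularization strengths such as C.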

Learn more

See the Tutorials for more machine learning examples.