Machine Learning

Flyte can handle the full spectrum of machine learning workloads, from training small models to GPU-accelerated deep learning and hyperparameter optimization.

Getting the Data

In this simple example, we train a binary classification model on the wine dataset that’s available through the scikit-learn package:

import pandas as pd
from flytekit import Resources, task, workflow
from sklearn.datasets import load_wine
from sklearn.linear_model import LogisticRegression

# Importing this module registers flytekit's type transformer for scikit-learn estimators.
import flytekit.extras.sklearn


@task
def get_data() -> pd.DataFrame:
    """Get the wine dataset."""
    return load_wine(as_frame=True).frame

Define a Training Workflow

Then, we define process_data and train_model tasks along with a training_workflow to put all the pieces together for a model-training pipeline.

@task
def process_data(data: pd.DataFrame) -> pd.DataFrame:
    """Simplify the task from a 3-class to a binary classification problem."""
    # Keep class 0 as-is and relabel every other class as 1.
    return data.assign(target=lambda x: x["target"].where(x["target"] == 0, 1))


@task
def train_model(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression:
    """Train a model on the wine dataset."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return LogisticRegression(max_iter=5000, **hyperparameters).fit(features, target)


@workflow
def training_workflow(hyperparameters: dict) -> LogisticRegression:
    """Put all of the steps together into a single workflow."""
    data = get_data()
    processed_data = process_data(data=data)
    return train_model(
        data=processed_data,
        hyperparameters=hyperparameters,
    )


Even though you can use a dict type to represent the model’s hyperparameters, we recommend using dataclasses to define a custom Hyperparameter Python object that provides more type information to the Flyte compiler. For example, Flyte uses this type information to auto-generate type-safe launch forms on the Flyte UI. Learn more in the Extending Flyte guide.
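
As a rough illustration of the dataclass recommendation, the sketch below defines a hypothetical Hyperparameters dataclass using only the standard library. The class name, its fields (C, max_iter), and their defaults are assumptions chosen for this example, not part of the workflow above. In a Flyte task you would annotate the hyperparameters argument with such a type instead of dict; depending on your flytekit version, additional serialization support may be required.

```python
from dataclasses import asdict, dataclass


@dataclass
class Hyperparameters:
    """Hypothetical typed container for LogisticRegression settings."""

    C: float = 1.0        # inverse regularization strength
    max_iter: int = 5000  # iteration cap, matching the value used in train_model


# Typed fields document which inputs are expected, unlike a free-form dict.
hp = Hyperparameters(C=0.01)

# asdict() converts the dataclass back into keyword arguments,
# e.g. for LogisticRegression(**asdict(hp)).
kwargs = asdict(hp)
print(kwargs)
```

Because every field carries a type and a default, a caller who omits max_iter still gets a complete, well-defined set of hyperparameters.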

Computing Predictions

Executing this workflow locally, we can call the model.predict method to make sure we can use our newly trained model to make predictions based on some feature matrix.

model = training_workflow(hyperparameters={"C": 0.01})
X, _ = load_wine(as_frame=True, return_X_y=True)
model.predict(X.sample(10, random_state=41))
array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1])

Extending your ML Workloads

There are many ways to extend your workloads:

🏔 Vertical Scaling

Pass Resources objects to the requests and limits task keyword arguments to request additional CPUs, GPUs, and/or memory.

🗺 Horizontal Scaling

With constructs like dynamic() workflows and map_task(), you can implement grid search, random search, and even Bayesian optimization.

🔧 Specialized Tuning Libraries

Use the Ray Integration to leverage tools like Ray Tune for hyperparameter optimization, with Flyte orchestrating ephemerally provisioned Ray clusters.

📦 Ephemeral Cluster Resources

Use the MPI Operator, SageMaker, Kubeflow TensorFlow, Kubeflow PyTorch, and more for distributed training.

🔎 Experiment Tracking

Auto-capture training logs with the mlflow_autolog() decorator, and view them as Flyte Decks by setting @task(disable_decks=False).

⏩ Inference Acceleration

Serialize your models in ONNX format using the ONNX plugin, which supports scikit-learn, TensorFlow, and PyTorch.
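
To make the horizontal scaling idea a little more concrete, here is a minimal standard-library sketch of the first step of a grid search: expanding a search space into one hyperparameter dict per trial. The helper name build_grid and the search space values are hypothetical, and the Flyte fan-out itself (for example, handing the resulting list to a mapped task) is deliberately not shown.

```python
from itertools import product
from typing import Any, Dict, List


def build_grid(space: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Expand a search space into one hyperparameter dict per combination."""
    names = list(space)
    combos = product(*(space[name] for name in names))
    return [dict(zip(names, values)) for values in combos]


# Hypothetical search space for LogisticRegression.
grid = build_grid({"C": [0.01, 0.1, 1.0], "fit_intercept": [True, False]})
print(len(grid))
print(grid[0])
```

Each dict in grid has the same shape as the hyperparameters input of train_model, so every entry could be trained as an independent unit of work.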

Learn More

See the Tutorials for more machine learning examples.