Machine Learning#

Flyte can handle a full spectrum of machine learning workloads, from training small models to GPU-accelerated deep learning and hyperparameter optimization.

Getting the Data#

In this simple example, we train a binary classification model on the wine dataset that's available through the scikit-learn package:

import pandas as pd
from flytekit import Resources, task, workflow
from sklearn.datasets import load_wine
from sklearn.linear_model import LogisticRegression

import flytekit.extras.sklearn


@task(requests=Resources(mem="500Mi"))
def get_data() -> pd.DataFrame:
    """Get the wine dataset."""
    return load_wine(as_frame=True).frame

Define a Training Workflow#

Then, we define process_data and train_model tasks along with a training_workflow to put all the pieces together for a model-training pipeline.

@task
def process_data(data: pd.DataFrame) -> pd.DataFrame:
    """Simplify the task from a 3-class to a binary classification problem."""
    return data.assign(target=lambda x: x["target"].where(x["target"] == 0, 1))


@task
def train_model(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression:
    """Train a model on the wine dataset."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return LogisticRegression(max_iter=5000, **hyperparameters).fit(features, target)


@workflow
def training_workflow(hyperparameters: dict) -> LogisticRegression:
    """Put all of the steps together into a single workflow."""
    data = get_data()
    processed_data = process_data(data=data)
    return train_model(
        data=processed_data,
        hyperparameters=hyperparameters,
    )

Important

Even though you can use a dict type to represent the model's hyperparameters, we recommend using dataclasses to define a custom Hyperparameter Python object that provides more type information to the Flyte compiler. For example, Flyte uses this type information to auto-generate type-safe launch forms on the Flyte UI. Learn more in the Extending Flyte guide.
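
For instance, here is a minimal sketch of such a dataclass (the class name, fields, and defaults are illustrative; depending on your flytekit version you may also need the dataclasses_json decorator shown here):

from dataclasses import dataclass

from dataclasses_json import dataclass_json  # may be unnecessary on newer flytekit versions


@dataclass_json
@dataclass
class Hyperparameters:
    """Typed hyperparameters for LogisticRegression (illustrative fields)."""

    C: float = 0.1
    max_iter: int = 5000


@task
def train_model_typed(data: pd.DataFrame, hyperparameters: Hyperparameters) -> LogisticRegression:
    """Same as train_model, but with a typed hyperparameter object."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return LogisticRegression(C=hyperparameters.C, max_iter=hyperparameters.max_iter).fit(features, target)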

Computing Predictions#

After executing this workflow locally, we can call the model.predict method to make sure our newly trained model can make predictions on a feature matrix.

model = training_workflow(hyperparameters={"C": 0.01})
X, _ = load_wine(as_frame=True, return_X_y=True)
model.predict(X.sample(10, random_state=41))
Getting /tmp/flyte-mo_j6hoa/raw/a9a27735e442054d6f44d389928f612a/407b0e101d5aea050d4849f1e2e3f608.joblib to /tmp/flyte-mo_j6hoa/sandbox/local_flytekit/1cfe21517753a70bb7d4b8cf7b9cca4d
array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1])

Extending your ML Workloads#

There are many ways to extend your workloads:

๐Ÿ” Vertical Scaling

Use the Resources task keyword argument to request additional CPUs, GPUs, and/or memory.
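
For example, here is a sketch of a more generously resourced training task (the CPU, memory, and GPU values are placeholders; GPU requests only schedule on clusters with GPU nodes):

@task(
    requests=Resources(cpu="2", mem="4Gi", gpu="1"),
    limits=Resources(cpu="4", mem="8Gi", gpu="1"),
)
def train_model_gpu(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression:
    """Same training logic as train_model, scheduled with more CPU, memory, and a GPU."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return LogisticRegression(max_iter=5000, **hyperparameters).fit(features, target)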

🗺 Horizontal Scaling

With constructs like dynamic() workflows and map_task()s, implement grid search, random search, and even Bayesian optimization.
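
As a rough sketch, a dynamic() workflow can fan out one train_model call per hyperparameter setting (the grid and the tune_model name below are illustrative):

from typing import List

from flytekit import dynamic


@dynamic
def tune_model(hyperparameter_grid: List[dict]) -> List[LogisticRegression]:
    """Fan out one train_model task per hyperparameter setting; Flyte runs them in parallel."""
    processed_data = process_data(data=get_data())
    return [
        train_model(data=processed_data, hyperparameters=hp)
        for hp in hyperparameter_grid
    ]

Calling tune_model(hyperparameter_grid=[{"C": 0.01}, {"C": 0.1}, {"C": 1.0}]) would then train and return one model per setting.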

🔧 Specialized Tuning Libraries

Use the Ray Integration and leverage tools like Ray Tune for hyperparameter optimization, all orchestrated by Flyte as ephemerally-provisioned Ray clusters.
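
Here is a rough sketch, assuming the flytekitplugins-ray package is installed (the cluster configuration and the toy objective are illustrative; a real sweep would use Ray Tune, and exact config fields can vary across plugin versions):

from typing import List

import ray
from flytekitplugins.ray import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(),
    worker_node_config=[WorkerNodeConfig(group_name="tune-group", replicas=2)],
)


@ray.remote
def evaluate(c: float) -> float:
    """Placeholder objective: swap in real train/validation scoring here."""
    return c


@task(task_config=ray_config, requests=Resources(cpu="2", mem="4Gi"))
def tune_with_ray(candidates: List[float]) -> float:
    """Score candidate C values on an ephemeral Ray cluster and return the best one."""
    scores = ray.get([evaluate.remote(c) for c in candidates])
    return candidates[scores.index(max(scores))]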

📦 Ephemeral Cluster Resources

Use the MPI Operator, SageMaker, Kubeflow TensorFlow, Kubeflow PyTorch, and more to run distributed training.
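
As a structural sketch only, assuming the flytekitplugins-kfpytorch package (the worker count is illustrative, the exact config fields differ across plugin versions, and the body is a placeholder where torch.distributed training code would go):

from flytekitplugins.kfpytorch import PyTorch


@task(
    task_config=PyTorch(num_workers=2),
    requests=Resources(cpu="4", mem="8Gi", gpu="1"),
)
def distributed_train(hyperparameters: dict) -> None:
    """Placeholder: each launched worker would run torch.distributed training here."""
    ...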

🔎 Experiment Tracking

Auto-capture training logs with the mlflow_autolog() decorator, which can be viewed as Flyte Decks with @task(disable_deck=False).
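
A hedged sketch, assuming the flytekitplugins-mlflow package is installed (newer flytekit versions enable decks with @task(enable_deck=True) instead of disable_deck=False):

import mlflow
from flytekitplugins.mlflow import mlflow_autolog


@task(disable_deck=False)
@mlflow_autolog(framework=mlflow.sklearn)
def train_model_tracked(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression:
    """Train as before; MLflow autologging captures params and metrics into a Flyte Deck."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return LogisticRegression(max_iter=5000, **hyperparameters).fit(features, target)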

โฉ Inference Acceleration

Serialize your models in ONNX format using the ONNX plugin, which supports ScikitLearn, TensorFlow, and PyTorch.
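
A rough sketch, assuming the flytekitplugins-onnxscikitlearn package is installed (the input name and tensor shape are illustrative; the wine feature matrix has 13 columns):

from typing import Annotated

from skl2onnx.common.data_types import FloatTensorType

from flytekitplugins.onnxscikitlearn import ScikitLearn2ONNX, ScikitLearn2ONNXConfig


@task
def train_onnx_model(data: pd.DataFrame) -> Annotated[
    ScikitLearn2ONNX,
    ScikitLearn2ONNXConfig(initial_types=[("input", FloatTensorType([None, 13]))]),
]:
    """Train the classifier and hand it to Flyte's ONNX type engine for conversion."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return ScikitLearn2ONNX(LogisticRegression(max_iter=5000).fit(features, target))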

Learn More

See the Tutorials for more machine learning examples.