Machine learning
Flyte can handle the full spectrum of machine learning workloads, from training small models to GPU-accelerated deep learning and hyperparameter optimization.
Getting the data
In this simple example, we train a binary classification model on the wine dataset that is available through the scikit-learn package:
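```python
import pandas as pd
from flytekit import Resources, task, workflow
from sklearn.datasets import load_wine
from sklearn.linear_model import LogisticRegression

# Registers Flyte's type transformer for scikit-learn estimators,
# so tasks can return a trained LogisticRegression model.
import flytekit.extras.sklearn  # noqa: F401


@task(requests=Resources(mem="500Mi"))  # request 500 MiB of memory for this task
def get_data() -> pd.DataFrame:
    """Get the wine dataset."""
    return load_wine(as_frame=True).frame
```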
Define a training workflow
Then, we define `process_data` and `train_model` tasks, along with a `training_workflow` that puts all the pieces together into a model-training pipeline.
```python
@task
def process_data(data: pd.DataFrame) -> pd.DataFrame:
    """Simplify the task from a 3-class to a binary classification problem."""
    return data.assign(target=lambda x: x["target"].where(x["target"] == 0, 1))


@task
def train_model(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression:
    """Train a model on the wine dataset."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return LogisticRegression(max_iter=5000, **hyperparameters).fit(features, target)


@workflow
def training_workflow(hyperparameters: dict) -> LogisticRegression:
    """Put all of the steps together into a single workflow."""
    data = get_data()
    processed_data = process_data(data=data)
    return train_model(
        data=processed_data,
        hyperparameters=hyperparameters,
    )
```
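The one-liner in `process_data` relies on `pandas.Series.where(cond, other)`, which keeps each value where `cond` is `True` and substitutes `other` everywhere else. Class 0 therefore stays 0, while classes 1 and 2 are both folded into class 1. The plain-Python sketch below spells out the same rule on a list of labels; the helper name `to_binary` is hypothetical and is not part of the workflow above.

```python
def to_binary(labels: list[int]) -> list[int]:
    """Map wine class 0 to 0 and every other class (1 or 2) to 1.

    Mirrors x["target"].where(x["target"] == 0, 1): keep the value where
    the condition holds, otherwise replace it with 1.
    """
    return [label if label == 0 else 1 for label in labels]


if __name__ == "__main__":
    print(to_binary([0, 1, 2, 0, 2]))  # [0, 1, 1, 0, 1]
```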
Important
Even though you can use a `dict` type to represent the model’s hyperparameters, we recommend using dataclasses to define a custom `Hyperparameter` Python object that provides more type information to the Flyte compiler. For example, Flyte uses this type information to auto-generate type-safe launch forms on the Flyte UI. Learn more in the Extending Flyte guide.
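As a rough sketch, such an object could look like the following, using only the standard library's `dataclasses` module. The fields are illustrative assumptions: `C`, `penalty`, and `solver` are real `LogisticRegression` keyword arguments (shown with scikit-learn's defaults), and `dataclasses.asdict` turns an instance back into the keyword dict that `train_model` currently unpacks. Whether a plain dataclass can be used directly as a task input, or needs extra serialization support, depends on your flytekit version, so check the Extending Flyte guide before relying on it.

```python
from dataclasses import asdict, dataclass


@dataclass
class Hyperparameter:
    """Typed container for LogisticRegression keyword arguments (illustrative fields)."""

    C: float = 1.0         # inverse regularization strength
    penalty: str = "l2"    # regularization type
    solver: str = "lbfgs"  # optimization algorithm


def to_kwargs(hp: Hyperparameter) -> dict:
    """Convert to the plain dict that LogisticRegression(**kwargs) accepts."""
    return asdict(hp)


if __name__ == "__main__":
    hp = Hyperparameter(C=0.01)
    print(to_kwargs(hp))  # {'C': 0.01, 'penalty': 'l2', 'solver': 'lbfgs'}
```

With this in place, `train_model` could take `hyperparameters: Hyperparameter` and call `LogisticRegression(max_iter=5000, **asdict(hyperparameters))`, so a typo such as `Cc=0.01` fails when the object is constructed rather than deep inside the training task.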
Computing predictions¶
When executed locally, the workflow returns a trained `LogisticRegression` model, so we can call its `model.predict` method on a feature matrix to make sure the newly trained model can make predictions.
```python
model = training_workflow(hyperparameters={"C": 0.01})
X, _ = load_wine(as_frame=True, return_X_y=True)
model.predict(X.sample(10, random_state=41))
```
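Note that `X` comes from the raw three-class dataset, while the model predicts binary labels (0 or 1). To turn this spot check into a number, you might compare the predictions with the targets that `load_wine(as_frame=True, return_X_y=True)` returns alongside `X`, after applying the same relabeling as `process_data`. The stdlib-only helper `binary_accuracy` below is a hypothetical sketch of that comparison, not part of Flyte or scikit-learn.

```python
def binary_accuracy(raw_targets: list[int], predictions: list[int]) -> float:
    """Fraction of predictions that match the binarized wine labels.

    Raw targets use the original classes 0, 1, 2; they are first mapped to
    0 (class 0) or 1 (classes 1 and 2), mirroring process_data.
    """
    if len(raw_targets) != len(predictions):
        raise ValueError("raw_targets and predictions must have the same length")
    if not raw_targets:
        raise ValueError("need at least one example")
    binary = [0 if t == 0 else 1 for t in raw_targets]
    correct = sum(1 for truth, pred in zip(binary, predictions) if truth == pred)
    return correct / len(binary)


if __name__ == "__main__":
    print(binary_accuracy([0, 1, 2, 0], [0, 1, 1, 1]))  # 0.75
```

Because the helper expects plain lists, you would convert pandas and NumPy objects first, for example `y.loc[sample.index].tolist()` and `model.predict(sample).tolist()`, where `y` is the target series and `sample` is the sampled feature frame. Keep in mind that these rows were also part of the training data, so the score is a sanity check rather than an estimate of generalization.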
Extending your ML workloads¶
There are many ways to extend your workloads:
| Approach | Description |
| --- | --- |
| 🏔 Vertical Scaling | Use the |
| 🗺 Horizontal Scaling | With constructs like |
| 🔧 Specialized Tuning Libraries | Use the Ray integration and leverage tools like Ray Tune for hyperparameter optimization, all orchestrated by Flyte as ephemerally provisioned Ray clusters. |
| 📦 Ephemeral Cluster Resources | Use the MPI Operator, Kubeflow TensorFlow, Kubeflow PyTorch, and more to do distributed training. |
| 🔎 Experiment Tracking | Auto-capture training logs with the |
| ⏩ Inference Acceleration | Serialize your models in ONNX format using the ONNX plugin, which supports scikit-learn, TensorFlow, and PyTorch. |
Learn more
See the Tutorials for more machine learning examples.