Machine Learning#
Flyte can handle a full spectrum of machine learning workloads, from training small models to GPU-accelerated deep learning and hyperparameter optimization.
Getting the Data#
In this simple example, we train a binary classification model on the wine dataset that's available through the scikit-learn package:
import pandas as pd
from flytekit import Resources, task, workflow
from sklearn.datasets import load_wine
from sklearn.linear_model import LogisticRegression

import flytekit.extras.sklearn


@task(requests=Resources(mem="500Mi"))
def get_data() -> pd.DataFrame:
    """Get the wine dataset."""
    return load_wine(as_frame=True).frame
Define a Training Workflow#
Then, we define process_data and train_model tasks along with a training_workflow to put all the pieces together for a model-training pipeline.
@task
def process_data(data: pd.DataFrame) -> pd.DataFrame:
    """Simplify the task from a 3-class to a binary classification problem."""
    return data.assign(target=lambda x: x["target"].where(x["target"] == 0, 1))


@task
def train_model(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression:
    """Train a model on the wine dataset."""
    features = data.drop("target", axis="columns")
    target = data["target"]
    return LogisticRegression(max_iter=5000, **hyperparameters).fit(features, target)


@workflow
def training_workflow(hyperparameters: dict) -> LogisticRegression:
    """Put all of the steps together into a single workflow."""
    data = get_data()
    processed_data = process_data(data=data)
    return train_model(
        data=processed_data,
        hyperparameters=hyperparameters,
    )
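The relabeling in process_data relies on the pandas `Series.where` method, which keeps a value where the condition holds and substitutes the second argument everywhere else. As a minimal sketch of the same rule in plain Python (the `binarize` helper and the sample labels are hypothetical, introduced only for illustration):

```python
def binarize(targets: list[int]) -> list[int]:
    """Keep class 0 as-is and collapse every other class into 1."""
    return [t if t == 0 else 1 for t in targets]


# The wine dataset has three classes labeled 0, 1, and 2.
sample_targets = [0, 1, 2, 0, 2]  # hypothetical labels, not real dataset rows
binary_targets = binarize(sample_targets)
print(binary_targets)  # [0, 1, 1, 0, 1]
```

Class 0 stays 0, while classes 1 and 2 both become 1, which is what turns the 3-class problem into a binary one.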
Important
Even though you can use a dict type to represent the model's hyperparameters, we recommend using dataclasses to define a custom Hyperparameter Python object that provides more type information to the Flyte compiler. For example, Flyte uses this type information to auto-generate type-safe launch forms on the Flyte UI. Learn more in the Extending Flyte guide.
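As a rough sketch of the dataclass approach, the example below defines a hypothetical `Hyperparameters` class using only the standard library. The field names `C` and `max_iter` mirror the `LogisticRegression` arguments used above, but the class itself, its defaults, and the `to_kwargs` helper are assumptions made for illustration. Depending on your flytekit version, a dataclass may need additional serialization support before Flyte can pass it between tasks, so treat this as a starting point and check the Flyte documentation.

```python
from dataclasses import asdict, dataclass


@dataclass
class Hyperparameters:
    """Hypothetical typed container for LogisticRegression settings."""

    C: float = 1.0        # inverse regularization strength
    max_iter: int = 5000  # iteration cap, matching the value used in train_model

    def to_kwargs(self) -> dict:
        """Convert to a plain dict suitable for keyword-argument unpacking."""
        return asdict(self)


hp = Hyperparameters(C=0.01)
print(hp.to_kwargs())  # {'C': 0.01, 'max_iter': 5000}

# Assumed usage: annotate the task input as `hyperparameters: Hyperparameters`
# and build the model with LogisticRegression(**hyperparameters.to_kwargs()).
```

Because `max_iter` lives in the dataclass in this sketch, the model would be built from the unpacked fields alone, without also passing `max_iter=5000` explicitly, which would otherwise raise a duplicate keyword error.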
Computing Predictions#
When we execute this workflow locally, it returns the trained model, so we can call the model.predict method to confirm that the newly trained model makes predictions from a feature matrix.
model = training_workflow(hyperparameters={"C": 0.01})
X, _ = load_wine(as_frame=True, return_X_y=True)
model.predict(X.sample(10, random_state=41))
array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1])
Extending your ML Workloads#
There are many ways to extend your workloads:
Vertical Scaling | Use the
Horizontal Scaling | With constructs like
Specialized Tuning Libraries | Use the Ray Integration and leverage tools like Ray Tune for hyperparameter optimization, all orchestrated by Flyte as ephemerally provisioned Ray clusters.
Ephemeral Cluster Resources | Use the MPI Operator, SageMaker, Kubeflow TensorFlow, Kubeflow PyTorch, and more to do distributed training.
Experiment Tracking | Auto-capture training logs with the
Inference Acceleration | Serialize your models in ONNX format using the ONNX plugin, which supports scikit-learn, TensorFlow, and PyTorch.
Learn More
See the Tutorials for more machine learning examples.