MLflow Example

MLflow is a platform to streamline machine learning development, including tracking experiments, packaging code into reproducible runs, and sharing and deploying models.

Flyte provides an easy-to-use interface to log a task’s metrics and parameters to either the Flyte Deck or an MLflow server.

Let’s first import the libraries.

import mlflow.keras
import tensorflow as tf

from flytekit import ImageSpec, Resources, task, workflow
from flytekitplugins.mlflow import mlflow_autolog

custom_image = ImageSpec(registry="ghcr.io/flyteorg", packages=["flytekitplugins-mlflow", "tensorflow"])

Next, train a model to generate metrics and parameters. Adding the mlflow_autolog decorator to the task makes Flyte log those metrics and parameters to the Flyte Deck automatically.

@task(enable_deck=True, container_image=custom_image, requests=Resources(mem="3000Mi"))
@mlflow_autolog(framework=mlflow.keras)
def train_model(epochs: int):
    # Refer to https://www.tensorflow.org/tutorials/keras/classification
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (train_images, train_labels), (_, _) = fashion_mnist.load_data()
    train_images = train_images / 255.0

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
    )

    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    model.fit(train_images, train_labels, epochs=epochs)

[Figures: Model Metrics; Model Parameters]
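
In the task above, `@mlflow_autolog` sits closest to the function, so Python applies it first, and `@task` then wraps the result. The snippet below is a rough, library-free sketch of that stacking pattern, not the plugin's actual implementation. The names `autolog_sketch`, `task_sketch`, `train_sketch`, and `RUN_LOG` are hypothetical and exist only to illustrate how an inner decorator could record parameters and metrics around a call while an outer decorator handles execution.

```python
import functools
from typing import Any, Callable, Dict, List

# Hypothetical in-memory stand-in for a tracking backend (not an MLflow API).
RUN_LOG: List[Dict[str, Any]] = []


def autolog_sketch(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Inner decorator: record the call's keyword arguments and its result."""

    @functools.wraps(fn)
    def wrapper(**kwargs: Any) -> Any:
        record: Dict[str, Any] = {"params": dict(kwargs)}
        result = fn(**kwargs)
        record["metrics"] = result
        RUN_LOG.append(record)
        return result

    return wrapper


def task_sketch(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Outer decorator: stands in for the layer that executes the function."""

    @functools.wraps(fn)
    def wrapper(**kwargs: Any) -> Any:
        return fn(**kwargs)

    return wrapper


@task_sketch      # applied second, wraps the logging wrapper
@autolog_sketch   # applied first, closest to the function
def train_sketch(epochs: int) -> Dict[str, float]:
    # Placeholder "training": returns a made-up metric derived from epochs.
    return {"loss": 1.0 / epochs}
```

Calling `train_sketch(epochs=4)` goes through the outer wrapper, then the logging wrapper, then the function body, and leaves one record in `RUN_LOG` holding the parameters and the returned metrics. Swapping the two decorators would change which layer sees the call first, which is why the order in the real task is worth keeping as written.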

Finally, we put everything together into a workflow:

@workflow
def ml_pipeline(epochs: int):
    train_model(epochs=epochs)


if __name__ == "__main__":
    print(f"Running {__file__} main...")
    ml_pipeline(epochs=5)