Run Distributed TensorFlow Training

When you need to scale up model training with TensorFlow, you can use the tf.distribute.Strategy API to distribute training across multiple devices. The API provides several strategies, and you can pick whichever suits your hardware setup. In this example, we use MirroredStrategy to train a convolutional neural network (CNN) on the MNIST dataset.

MirroredStrategy performs synchronous distributed training across multiple GPUs on a single machine: each GPU holds a copy of the model, and all copies are updated in lockstep. For a deeper understanding, see the "Distributed training with TensorFlow" guide in the TensorFlow documentation.
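To make "synchronous" concrete, here is a pure-Python sketch (no TensorFlow) of what one training step under a mirrored strategy does conceptually: every replica holds an identical copy of the weights, computes a gradient on its own shard of the global batch, the gradients are combined with an all-reduce (modeled here as a simple average), and every replica applies the same update, so the copies never drift apart. The toy linear model, data, and learning rate are made up for illustration; TensorFlow's real all-reduce implementations are not modeled.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # (x, y) pair for a toy 1-D linear model y = w * x


def local_gradient(w: float, shard: List[Example]) -> float:
    """Gradient of the mean of 0.5 * (w*x - y)^2 over one replica's shard."""
    return sum((w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> float:
    """Stand-in for the all-reduce: every replica receives the same averaged value."""
    return sum(values) / len(values)


def synchronous_step(replica_weights: List[float], global_batch: List[Example], lr: float) -> List[float]:
    n = len(replica_weights)
    # Split the global batch evenly across replicas (one shard per device).
    shards = [global_batch[i::n] for i in range(n)]
    grads = [local_gradient(w, s) for w, s in zip(replica_weights, shards)]
    g = all_reduce_mean(grads)
    # Every replica applies the identical update, so the mirrored copies stay in sync.
    return [w - lr * g for w in replica_weights]


if __name__ == "__main__":
    batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # data generated by y = 2x
    weights = [0.0, 0.0]  # two replicas with identical initial copies
    for _ in range(50):
        weights = synchronous_step(weights, batch, lr=0.05)
    print(weights)  # both copies converge toward 2.0 and remain equal
```

Because the update is computed once from the reduced gradient and applied everywhere, no replica ever needs to fetch weights from another; that is the property the all-reduce provides.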

To begin, import the required libraries.

import os
from dataclasses import dataclass
from typing import NamedTuple, Tuple

from dataclasses_json import dataclass_json
from flytekit import ImageSpec, Resources, task, workflow
from flytekit.types.directory import FlyteDirectory

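Create an ImageSpec that includes all the dependencies required by the TensorFlow task.

custom_image = ImageSpec(
    packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],
    # Container registry to push the image to; localhost:30000 is the demo cluster's local registry
    registry="localhost:30000",
)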

Set the ImageSpec's registry to a container registry you can publish to. To upload the image to the local registry in the demo cluster, use localhost:30000.

The following imports are required to configure the TensorFlow cluster in Flyte. They are wrapped in a custom_image.is_container() check so that they are loaded on demand.

if custom_image.is_container():
    import tensorflow as tf
    import tensorflow_datasets as tfds
    from flytekitplugins.kftensorflow import PS, Chief, TfJob, Worker

You can enable GPU support either by using a base image that already includes the necessary GPU dependencies or by setting the CUDA parameters (cuda and cudnn) in the ImageSpec.

For this example, we define the MODEL_FILE_PATH variable to indicate the storage location for the model file.

MODEL_FILE_PATH = "saved_model/"

We define a data class to store the hyperparameters.

@dataclass_json
@dataclass
class Hyperparameters(object):
    batch_size_per_replica: int = 64
    buffer_size: int = 10000
    epochs: int = 10
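Flyte passes the Hyperparameters instance into the task in serialized form, which is what the dataclass_json import is for. The sketch below shows an equivalent JSON round trip using only the standard library (dataclasses.asdict plus json); it is an illustration for checking which fields an override must contain, not how flytekit serializes internally. The to_json and from_json helpers are hypothetical names.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class Hyperparameters:
    # Same fields and defaults as the Hyperparameters class above
    batch_size_per_replica: int = 64
    buffer_size: int = 10000
    epochs: int = 10


def to_json(hp: Hyperparameters) -> str:
    return json.dumps(asdict(hp))


def from_json(payload: str) -> Hyperparameters:
    # Missing keys fall back to the defaults; unknown keys raise TypeError (catches typos).
    return Hyperparameters(**json.loads(payload))


if __name__ == "__main__":
    override = from_json('{"batch_size_per_replica": 128, "epochs": 3}')
    print(override)  # Hyperparameters(batch_size_per_replica=128, buffer_size=10000, epochs=3)
    print(to_json(override))
```

Rejecting unknown keys is a deliberate choice here: a misspelled field such as "batch_size" fails loudly instead of silently training with the default.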

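We train the model on the MNIST dataset. The load_data function loads it with TensorFlow Datasets, creates the MirroredStrategy, scales the batch size by the number of replicas, and normalizes the pixel values to the [0, 1] range.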

def load_data(
    hyperparameters: Hyperparameters,
) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.distribute.Strategy]:
    datasets, _ = tfds.load(name="mnist", with_info=True, as_supervised=True)
    mnist_train, mnist_test = datasets["train"], datasets["test"]

    strategy = tf.distribute.MirroredStrategy()
    print("Number of devices: {}".format(strategy.num_replicas_in_sync))

    # strategy.num_replicas_in_sync returns the number of replicas; helpful to utilize the extra compute power by increasing the batch size
    BATCH_SIZE = hyperparameters.batch_size_per_replica * strategy.num_replicas_in_sync

    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255

        return image, label

    # Fetch train and evaluation datasets
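    # Scale, shuffle, and batch the training data; the test data only needs scaling and batching
    train_dataset = mnist_train.map(scale).shuffle(hyperparameters.buffer_size).batch(BATCH_SIZE)
    eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)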

    return train_dataset, eval_dataset, strategy
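The batch-size scaling in load_data is simple arithmetic, sketched below in plain Python: the global batch is the per-replica batch multiplied by the number of replicas, and the number of steps per epoch follows from the dataset size. The 60,000 figure is the size of the standard MNIST training split; the replica counts are examples. scale_pixel mirrors the scale function's normalization for a single 8-bit pixel value.

```python
import math

MNIST_TRAIN_EXAMPLES = 60_000  # size of the standard MNIST training split


def global_batch_size(batch_size_per_replica: int, num_replicas: int) -> int:
    # Mirrors BATCH_SIZE = batch_size_per_replica * strategy.num_replicas_in_sync
    return batch_size_per_replica * num_replicas


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # .batch() keeps the final partial batch by default (drop_remainder=False), hence ceil
    return math.ceil(num_examples / batch_size)


def scale_pixel(value: int) -> float:
    # Same normalization as scale(): cast to float and divide by 255
    return float(value) / 255


if __name__ == "__main__":
    for replicas in (1, 2, 4):
        gbs = global_batch_size(64, replicas)
        print(replicas, gbs, steps_per_epoch(MNIST_TRAIN_EXAMPLES, gbs))
    # 1 64 938
    # 2 128 469
    # 4 256 235
```

Each replica still processes 64 examples per step; adding GPUs means fewer, larger global steps per epoch.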

We create and compile the model inside strategy.scope(), so that its variables are created as mirrored variables on every replica.

def get_compiled_model(strategy: tf.distribute.Strategy) -> tf.keras.Model:
    with strategy.scope():
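        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(10),  # one logit per digit class
        ])
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=tf.keras.optimizers.Adam(),
            metrics=["accuracy"],
        )
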
    return model
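As a sanity check on the architecture, the trainable-parameter count can be worked out by hand. The sketch below assumes the layer stack Conv2D(32, 3), MaxPooling2D(), Flatten(), Dense(64), Dense(10) with Keras defaults (valid padding and stride 1 for the convolution, 2x2 pooling with stride 2) on 28x28x1 inputs; if your model differs, adjust the arithmetic accordingly.

```python
def conv2d_params(kernel: int, in_channels: int, filters: int) -> int:
    # One kernel x kernel x in_channels weight tensor per filter, plus one bias per filter
    return kernel * kernel * in_channels * filters + filters


def dense_params(inputs: int, units: int) -> int:
    # Weight matrix plus one bias per unit
    return inputs * units + units


def count_params(height: int = 28, width: int = 28, channels: int = 1) -> dict:
    conv = conv2d_params(3, channels, 32)
    # "valid" padding with a 3x3 kernel shrinks each spatial side by 2: 28 -> 26
    h, w = height - 2, width - 2
    # 2x2 max pooling with stride 2 halves each side: 26 -> 13; pooling has no parameters
    h, w = h // 2, w // 2
    flat = h * w * 32  # Flatten: 13 * 13 * 32 = 5408 features
    hidden = dense_params(flat, 64)
    logits = dense_params(64, 10)
    return {"conv2d": conv, "dense_64": hidden, "dense_10": logits, "total": conv + hidden + logits}


if __name__ == "__main__":
    print(count_params())
    # {'conv2d': 320, 'dense_64': 346176, 'dense_10': 650, 'total': 347146}
```

Nearly all parameters sit in the first Dense layer; without the pooling step, Flatten would produce 26 * 26 * 32 = 21,632 features and that layer would be four times larger.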

We define a function for decaying the learning rate.

def decay(epoch: int) -> float:
    # Piecewise-constant schedule: 1e-3 for epochs 0-2, 1e-4 for epochs 3-6, 1e-5 afterwards
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5
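LearningRateScheduler calls the schedule function with the zero-based epoch index at the start of each epoch. The plain-Python check below prints the resulting schedule for the default 10 epochs: three epochs at 1e-3, four at 1e-4, and the remaining three at 1e-5.

```python
def decay(epoch: int) -> float:
    # Same piecewise-constant schedule as above; epoch is zero-based
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5


if __name__ == "__main__":
    schedule = [decay(epoch) for epoch in range(10)]
    print(schedule)
    # [0.001, 0.001, 0.001, 0.0001, 0.0001, 0.0001, 0.0001, 1e-05, 1e-05, 1e-05]
```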

We define the train_model function to train the model with the following callbacks:

  • TensorBoard to log the training metrics

  • ModelCheckpoint to save the model weights after every epoch

  • LearningRateScheduler to decay the learning rate

  • PrintLR, a custom callback that prints the learning rate at the end of each epoch

def train_model(
    model: tf.keras.Model,
    train_dataset: tf.data.Dataset,
    hyperparameters: Hyperparameters,
) -> Tuple[tf.keras.Model, str]:
    # Define the checkpoint directory to store checkpoints
    checkpoint_dir = "./training_checkpoints"

    # Define the name of the checkpoint files
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

    # Define a callback for printing the learning rate at the end of each epoch
    class PrintLR(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs=None):
            print("\nLearning rate for epoch {} is {}".format(epoch + 1, model.optimizer.learning_rate.numpy()))

    # Put all the callbacks together
    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir="./logs"),  # example log directory
        tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True),
        tf.keras.callbacks.LearningRateScheduler(decay),
        PrintLR(),
    ]

    # Train the model
    model.fit(train_dataset, epochs=hyperparameters.epochs, callbacks=callbacks)

    # Save the model in the TensorFlow SavedModel format
    model.save(MODEL_FILE_PATH, save_format="tf")

    return model, checkpoint_dir
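ModelCheckpoint fills the {epoch} placeholder in checkpoint_prefix with a one-based epoch number, so with the default 10 epochs the weights are saved under names ckpt_1 through ckpt_10 inside ./training_checkpoints. The stdlib sketch below reproduces that naming (the individual files TensorFlow writes per checkpoint are not modeled). latest_by_epoch is a hypothetical helper; test_model itself relies on tf.train.latest_checkpoint, which reads TensorFlow's checkpoint bookkeeping file instead of sorting names.

```python
import os
import re

checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")


def checkpoint_names(epochs: int) -> list:
    # Keras formats the prefix with the one-based epoch number
    return [checkpoint_prefix.format(epoch=e + 1) for e in range(epochs)]


def latest_by_epoch(paths: list) -> str:
    # Pick the checkpoint with the highest epoch number (numeric, not lexicographic)
    return max(paths, key=lambda p: int(re.search(r"ckpt_(\d+)$", p).group(1)))


if __name__ == "__main__":
    names = checkpoint_names(10)
    print(names[0], names[-1])     # ./training_checkpoints/ckpt_1 ./training_checkpoints/ckpt_10 (POSIX)
    print(max(names))              # ./training_checkpoints/ckpt_9: plain string comparison picks the wrong one
    print(latest_by_epoch(names))  # ./training_checkpoints/ckpt_10
```

If you ever select a checkpoint by name yourself, compare epoch numbers numerically: a string sort places ckpt_9 after ckpt_10.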

We define the test_model function to evaluate loss and accuracy on the test dataset.

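def test_model(model: tf.keras.Model, checkpoint_dir: str, eval_dataset: tf.data.Dataset) -> Tuple[float, float]:
    # Restore the weights from the most recent checkpoint before evaluating
    model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))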
    eval_loss, eval_acc = model.evaluate(eval_dataset)

    return eval_loss, eval_acc
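model.evaluate returns the compiled loss and metric averaged over the test set. Assuming the model is compiled with SparseCategoricalCrossentropy(from_logits=True) and an "accuracy" metric (the usual pairing for a final Dense(10) logits layer), the two numbers correspond to the plain-Python computation below, shown for two made-up logit vectors of a 3-class problem. This is an illustration of the formulas, not Keras's implementation.

```python
import math
from typing import List, Tuple


def sparse_ce_from_logits(logits: List[float], label: int) -> float:
    # Cross-entropy of softmax(logits) at the true class, computed stably via log-sum-exp
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def evaluate(batch: List[Tuple[List[float], int]]) -> Tuple[float, float]:
    losses = [sparse_ce_from_logits(z, y) for z, y in batch]
    # Accuracy: fraction of examples whose highest logit is the true class
    correct = [max(range(len(z)), key=z.__getitem__) == y for z, y in batch]
    return sum(losses) / len(losses), sum(correct) / len(correct)


if __name__ == "__main__":
    toy = [([2.0, 0.5, 0.1], 0), ([0.2, 0.1, 3.0], 1)]  # made-up logits and labels
    loss, acc = evaluate(toy)
    print(round(loss, 4), acc)
```

The second example is misclassified with high confidence, which is why it dominates the mean loss even though accuracy only drops by half.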

To create a TensorFlow task, pass a TfJob configuration as the task_config of the Flyte task.

training_outputs = NamedTuple("TrainingOutputs", accuracy=float, loss=float, model_state=FlyteDirectory)

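# Use small resources when SANDBOX is set to a non-empty value, larger ones otherwise.
# os.getenv returns None when the variable is unset, so test truthiness rather than != "".
if os.getenv("SANDBOX"):
    resources = Resources(gpu="0", mem="1000Mi", storage="500Mi", ephemeral_storage="500Mi")
else:
    resources = Resources(gpu="2", mem="10Gi", storage="10Gi", ephemeral_storage="500Mi")


@task(
    task_config=TfJob(worker=Worker(replicas=1), ps=PS(replicas=1), chief=Chief(replicas=1)),
    requests=resources,
    limits=resources,
    container_image=custom_image,
)
def mnist_tensorflow_job(hyperparameters: Hyperparameters) -> training_outputs: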
    train_dataset, eval_dataset, strategy = load_data(hyperparameters=hyperparameters)
    model = get_compiled_model(strategy=strategy)
    model, checkpoint_dir = train_model(model=model, train_dataset=train_dataset, hyperparameters=hyperparameters)
    eval_loss, eval_accuracy = test_model(model=model, checkpoint_dir=checkpoint_dir, eval_dataset=eval_dataset)
    return training_outputs(accuracy=eval_accuracy, loss=eval_loss, model_state=MODEL_FILE_PATH)
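A note on the SANDBOX check: a comparison like os.getenv("SANDBOX") != "" does not detect an unset variable, because os.getenv returns None in that case and None != "" is True. Testing the value's truthiness treats both "unset" and "empty" as not-a-sandbox. The hypothetical pick_resources helper below sketches that logic with plain dicts instead of flytekit Resources objects, so it runs without Flyte installed; the values match the two resource settings above.

```python
import os
from typing import Dict, Mapping, Optional

SANDBOX_RESOURCES = {"gpu": "0", "mem": "1000Mi", "storage": "500Mi", "ephemeral_storage": "500Mi"}
CLUSTER_RESOURCES = {"gpu": "2", "mem": "10Gi", "storage": "10Gi", "ephemeral_storage": "500Mi"}


def pick_resources(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    env = os.environ if env is None else env
    # .get returns None when SANDBOX is unset; truthiness treats unset and "" the same way
    return SANDBOX_RESOURCES if env.get("SANDBOX") else CLUSTER_RESOURCES


if __name__ == "__main__":
    print({}.get("SANDBOX") != "")                  # True: the != "" comparison misfires when unset
    print(pick_resources({})["gpu"])                # 2
    print(pick_resources({"SANDBOX": "1"})["gpu"])  # 0
```

Passing the environment as a parameter (defaulting to os.environ) keeps the decision testable without mutating the real process environment.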

The TfJob configuration sets how many replicas of each role are launched:

  • worker (Worker(replicas=...)): the number of worker replicas to launch in the cluster for this job

  • ps (PS(replicas=...)): the number of parameter server replicas to use

  • chief (Chief(replicas=...)): the number of chief replicas to use

Because MirroredStrategy uses an all-reduce algorithm to communicate variable updates across devices, the parameter server (ps) replicas play no role in this example.
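Under the hood, the Kubeflow training operator that runs TfJob tasks describes the cluster to each replica through the TF_CONFIG environment variable: a JSON document listing the addresses for every role, plus the role and index of the current process. MirroredStrategy itself only uses the devices of the machine it runs on; multi-machine strategies such as MultiWorkerMirroredStrategy are the ones that read TF_CONFIG. The sketch below builds such a document for the replica counts used above; build_tf_config is a hypothetical helper, and the "role-index" hostnames and port 2222 are placeholders, not the operator's actual naming scheme.

```python
import json
from typing import Dict, List


def build_tf_config(replicas: Dict[str, int], task_type: str, task_index: int, port: int = 2222) -> str:
    # Hypothetical hostnames of the form "<role>-<index>"; the operator's real names differ
    cluster: Dict[str, List[str]] = {
        role: [f"{role}-{i}:{port}" for i in range(count)]
        for role, count in replicas.items()
        if count > 0
    }
    if task_index >= replicas.get(task_type, 0):
        raise ValueError(f"no replica {task_type}[{task_index}] in this cluster")
    return json.dumps({"cluster": cluster, "task": {"type": task_type, "index": task_index}})


if __name__ == "__main__":
    # Same replica counts as TfJob(worker=Worker(replicas=1), ps=PS(replicas=1), chief=Chief(replicas=1))
    config = build_tf_config({"chief": 1, "worker": 1, "ps": 1}, task_type="worker", task_index=0)
    print(json.dumps(json.loads(config), indent=2))
```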


To explore the other TensorFlow strategies available for distributed training, see the "Types of strategies" section of the TensorFlow documentation.

Lastly, define a workflow to invoke the task.

@workflow
def mnist_tensorflow_workflow(
    hyperparameters: Hyperparameters = Hyperparameters(batch_size_per_replica=64),
) -> training_outputs:
    return mnist_tensorflow_job(hyperparameters=hyperparameters)

You can run the code locally.

if __name__ == "__main__":
    print(mnist_tensorflow_workflow())

In distributed training, each worker produces its own return value, and these values can differ. To control which worker's return value is passed on to downstream tasks in the workflow, raise an IgnoreOutputs exception on all other ranks.
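A minimal sketch of that pattern, assuming each replica can identify itself from TF_CONFIG (absent on a single machine, which is then treated as the chief): only the chief returns its result, and every other replica raises. The IgnoreOutputs class below is a local stand-in defined so the snippet runs without Flyte; in a real task you would raise flytekit's own IgnoreOutputs exception instead. current_role and chief_only are hypothetical helper names.

```python
import json
import os
from typing import Mapping, Optional, Tuple


class IgnoreOutputs(Exception):
    """Local stand-in for flytekit's IgnoreOutputs exception (illustration only)."""


def current_role(env: Optional[Mapping[str, str]] = None) -> Tuple[str, int]:
    env = os.environ if env is None else env
    raw = env.get("TF_CONFIG")
    if not raw:
        return "chief", 0  # single-machine run: no cluster description, so act as the chief
    task = json.loads(raw).get("task", {})
    return task.get("type", "chief"), int(task.get("index", 0))


def chief_only(result, env: Optional[Mapping[str, str]] = None):
    role, index = current_role(env)
    if role == "chief" and index == 0:
        return result
    # Non-chief replicas discard their result so only the chief's value flows downstream
    raise IgnoreOutputs(f"{role}[{index}] output ignored")
```

Inside the task body, the final line would then become something like return chief_only(training_outputs(...)), so that all replicas run the same code but only one publishes outputs.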