Use PyTorch Lightning to Train an MNIST Autoencoder

This notebook demonstrates how to use PyTorch Lightning with Flyte’s Elastic task config, which is exposed by the flytekitplugins-kfpytorch plugin.

First, we import all of the relevant packages.

import os

import lightning as L
from flytekit import ImageSpec, PodTemplate, Resources, task, workflow
from flytekit.extras.accelerators import T4
from flytekit.types.directory import FlyteDirectory
from flytekitplugins.kfpytorch.task import Elastic
from kubernetes.client.models import (
    V1Container,
    V1EmptyDirVolumeSource,
    V1PodSpec,
    V1Volume,
    V1VolumeMount,
)
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

Image and Pod Template Configuration

For this task, we’re going to use a custom image that has all of the necessary dependencies installed.

# Python packages baked into the task image: the training stack plus the
# Flyte PyTorch plugin and the Kubernetes client used for the pod template.
_IMAGE_PACKAGES = [
    "torch",
    "torchvision",
    "flytekitplugins-kfpytorch",
    "kubernetes",
    "lightning",
]

# To build a CUDA-enabled image, additionally pass the `cuda` and
# `python_version` arguments, e.g. cuda="12.1.0", python_version="3.10".
custom_image = ImageSpec(
    registry="ghcr.io/flyteorg",
    packages=_IMAGE_PACKAGES,
)

Important

Replace ghcr.io/flyteorg with a container registry that you have permission to publish to. To upload the image to the local registry in the demo cluster, set the registry to localhost:30000.

Note

You can activate GPU support either by using a base image that includes the necessary GPU dependencies or by specifying the cuda parameter in the ImageSpec, for example:

custom_image = ImageSpec(
    packages=[...],
    cuda="12.1.0",
    python_version="3.10",
    ...
)

We’re also going to define a custom pod template that mounts a shared memory volume to /dev/shm. This is necessary for distributed data parallel (DDP) training so that state can be shared across workers.

# A memory-backed emptyDir volume is mounted at /dev/shm in the primary
# container; the volume and its mount are tied together by a shared name.
_SHM_VOLUME_NAME = "dshm"
_shm_mount = V1VolumeMount(name=_SHM_VOLUME_NAME, mount_path="/dev/shm")

container = V1Container(
    name=custom_image.name,
    volume_mounts=[_shm_mount],
)
volume = V1Volume(
    name=_SHM_VOLUME_NAME,
    empty_dir=V1EmptyDirVolumeSource(medium="Memory"),
)

# The primary container name must match the container defined above so the
# mount is applied to the container that runs the task.
custom_pod_template = PodTemplate(
    pod_spec=V1PodSpec(containers=[container], volumes=[volume]),
    primary_container_name=custom_image.name,
)

Define a LightningModule

Then we create a PyTorch Lightning module, which defines an autoencoder that will learn how to create compressed embeddings of MNIST images.

class MNISTAutoEncoder(L.LightningModule):
    """Lightning module that trains an encoder/decoder pair to reconstruct its input."""

    def __init__(self, encoder, decoder):
        """Store the encoder and decoder sub-modules.

        Args:
            encoder: Module applied to the flattened input batch.
            decoder: Module applied to the encoder's output.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        Args:
            batch: A pair whose first element is the input tensor; the second
                element is ignored.
            batch_idx: Index of the batch (unused).

        Returns:
            The mean-squared error between the flattened input and its
            reconstruction.
        """
        images, _ = batch
        # Flatten everything except the batch dimension.
        flattened = images.view(images.size(0), -1)
        reconstruction = self.decoder(self.encoder(flattened))
        loss = nn.functional.mse_loss(reconstruction, flattened)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=1e-3)."""
        return optim.Adam(self.parameters(), lr=1e-3)

Define a LightningDataModule

Then we define a PyTorch Lightning data module, which defines how to prepare and set up the training data.

class MNISTDataModule(L.LightningDataModule):
    """Lightning data module that serves the MNIST training split."""

    def __init__(self, root_dir, batch_size=64, dataloader_num_workers=0):
        """Record the dataset location and data-loading settings.

        Args:
            root_dir: Directory the MNIST data is downloaded to and read from.
            batch_size: Number of samples per training batch.
            dataloader_num_workers: Number of worker processes for the
                training DataLoader.
        """
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.dataloader_num_workers = dataloader_num_workers

    def prepare_data(self):
        """Download the MNIST training split into ``root_dir``."""
        MNIST(root=self.root_dir, train=True, download=True)

    def setup(self, stage=None):
        """Load the training split from ``root_dir`` as tensors, without downloading."""
        self.dataset = MNIST(
            root=self.root_dir,
            train=True,
            transform=ToTensor(),
            download=False,
        )

    def train_dataloader(self):
        """Return a shuffling DataLoader over the training dataset."""
        worker_count = self.dataloader_num_workers
        loader_options = {
            "batch_size": self.batch_size,
            "shuffle": True,
            "num_workers": worker_count,
            # Persistent workers are only enabled when worker processes exist.
            "persistent_workers": worker_count > 0,
            "pin_memory": True,
        }
        return DataLoader(self.dataset, **loader_options)

Creating the PyTorch Elastic task

With the model architecture defined, we now create a Flyte task that assumes a world size of 16: 2 nodes with 8 devices each. We also set max_restarts to 3 so that the task can be retried up to 3 times if it fails for any reason, and we give rdzv_configs a generous timeout so that the head and worker nodes have enough time to connect to each other.

This task outputs a FlyteDirectory containing the model checkpoint produced by training.

# Topology of the distributed training job. These values are shared by the
# Elastic task config (nnodes / nproc_per_node) and the Lightning Trainer
# (num_nodes / devices); the resulting world size is NUM_NODES * NUM_DEVICES.
NUM_NODES = 2  # number of nodes taking part in training
NUM_DEVICES = 8  # number of devices (processes) per node


@task(
    container_image=custom_image,
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        # Generous rendezvous timeouts so the head and worker nodes have
        # enough time to connect to each other.
        rdzv_configs={"timeout": 36000, "join_timeout": 36000},
        max_restarts=3,
    ),
    accelerator=T4,
    # The GPU request is derived from NUM_DEVICES so that it cannot drift out
    # of sync with nproc_per_node above and the Trainer's `devices` below.
    requests=Resources(
        mem="32Gi",
        cpu="48",
        gpu=str(NUM_DEVICES),
        ephemeral_storage="100Gi",
    ),
    pod_template=custom_pod_template,
)
def train_model(dataloader_num_workers: int) -> FlyteDirectory:
    """Train an autoencoder model on MNIST with distributed data parallel.

    The encoder compresses each flattened 28x28 image to a 3-dimensional
    embedding and the decoder reconstructs the image from it. Training runs
    for 3 epochs with the ``ddp`` strategy across ``NUM_NODES`` nodes and
    ``NUM_DEVICES`` GPU devices per node, using 16-bit mixed precision.

    Args:
        dataloader_num_workers: Number of worker processes used by the
            training DataLoader.

    Returns:
        A FlyteDirectory pointing at ``<cwd>/model``, the directory passed to
        the Trainer as ``default_root_dir``.
    """
    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    autoencoder = MNISTAutoEncoder(encoder, decoder)

    # The dataset is downloaded into, and the model directory is created
    # under, the task's current working directory.
    root_dir = os.getcwd()
    data = MNISTDataModule(
        root_dir,
        batch_size=4,
        dataloader_num_workers=dataloader_num_workers,
    )

    model_dir = os.path.join(root_dir, "model")
    trainer = L.Trainer(
        default_root_dir=model_dir,
        max_epochs=3,
        num_nodes=NUM_NODES,
        devices=NUM_DEVICES,
        accelerator="gpu",
        strategy="ddp",
        precision="16-mixed",
    )
    trainer.fit(model=autoencoder, datamodule=data)
    return FlyteDirectory(path=model_dir)

Finally, we wrap it all up in a workflow.

@workflow
def train_workflow(dataloader_num_workers: int = 1) -> FlyteDirectory:
    """Run the training task and return the directory it produces.

    Args:
        dataloader_num_workers: Forwarded unchanged to ``train_model``.

    Returns:
        The FlyteDirectory output of ``train_model``.
    """
    model_dir = train_model(dataloader_num_workers=dataloader_num_workers)
    return model_dir