Use PyTorch Lightning to Train an MNIST Autoencoder
This notebook demonstrates how to use PyTorch Lightning with Flyte's Elastic
task config, which is exposed by the flytekitplugins-kfpytorch plugin.
First, we import all of the relevant packages.
import os
import lightning as L
from flytekit import ImageSpec, Resources, task, workflow
from flytekit.extras.accelerators import T4
from flytekit.types.directory import FlyteDirectory
from flytekitplugins.kfpytorch.task import Elastic
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
Image and Pod Template Configuration
For this task, we’re going to use a custom image that has all of the necessary dependencies installed.
custom_image = ImageSpec(
    packages=[
        "torch",
        "torchvision",
        "flytekitplugins-kfpytorch",
        "kubernetes",
        "lightning",
    ],
    # use the cuda and python_version arguments to build a CUDA image
    # cuda="12.1.0",
    # python_version="3.10",
    registry="ghcr.io/flyteorg",
)
Important
Replace ghcr.io/flyteorg with a container registry that you have access to
publish to. To upload the image to the local registry in the demo cluster,
set the registry to localhost:30000.
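For example, a variant of the image spec above that pushes to the demo
cluster's local registry might look like the following sketch (same package
list, only the registry changes):

custom_image = ImageSpec(
    packages=[
        "torch",
        "torchvision",
        "flytekitplugins-kfpytorch",
        "kubernetes",
        "lightning",
    ],
    # Push to the demo cluster's local registry instead of ghcr.io/flyteorg
    registry="localhost:30000",
)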
Note
You can activate GPU support by either using the base image that includes
the necessary GPU dependencies or by specifying the cuda parameter in the
ImageSpec, for example:

custom_image = ImageSpec(
    packages=[...],
    cuda="12.1.0",
    python_version="3.10",
    ...
)
Define a LightningModule
Then we create a PyTorch Lightning module, which defines an autoencoder that
will learn to produce compressed embeddings of MNIST images.
class MNISTAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
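Before wiring this module into a Flyte task, it can help to sanity-check it
locally. The following is a minimal sketch (not part of the original example)
that passes a random batch through an untrained autoencoder, using the same
encoder and decoder shapes as the training task below, to confirm that the
tensor shapes line up:

import torch

# Illustrative local check: a fake batch of flattened 28x28 "images"
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
model = MNISTAutoEncoder(encoder, decoder)

x = torch.rand(8, 28 * 28)   # batch of 8 flattened images
z = model.encoder(x)         # 3-dimensional embeddings
x_hat = model.decoder(z)     # reconstructions
assert z.shape == (8, 3) and x_hat.shape == (8, 28 * 28)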
Define a LightningDataModule
Then we define a PyTorch Lightning data module, which specifies how to prepare
and set up the training data.
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, root_dir, batch_size=64, dataloader_num_workers=0):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.dataloader_num_workers = dataloader_num_workers

    def prepare_data(self):
        MNIST(self.root_dir, train=True, download=True)

    def setup(self, stage=None):
        self.dataset = MNIST(
            self.root_dir,
            train=True,
            download=False,
            transform=ToTensor(),
        )

    def train_dataloader(self):
        persistent_workers = self.dataloader_num_workers > 0
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.dataloader_num_workers,
            persistent_workers=persistent_workers,
            pin_memory=True,
            shuffle=True,
        )
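As an optional local check (not part of the Flyte task itself), the data
module can be exercised on its own. This sketch downloads MNIST into a
temporary directory and pulls a single batch:

import tempfile

# Illustrative only: download MNIST and inspect one batch locally.
with tempfile.TemporaryDirectory() as tmp_dir:
    dm = MNISTDataModule(tmp_dir, batch_size=4)
    dm.prepare_data()
    dm.setup()
    images, labels = next(iter(dm.train_dataloader()))
    print(images.shape, labels.shape)  # torch.Size([4, 1, 28, 28]) torch.Size([4])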
Creating the PyTorch Elastic task
With the model architecture defined, we now create a Flyte task that assumes
a world size of 16: 2 nodes with 8 devices each. We also set max_restarts to 3
so that the task can be retried up to 3 times if it fails for whatever reason,
and we give rdzv_configs a generous timeout so that the head and worker nodes
have enough time to connect to each other.
This task outputs a FlyteDirectory, which will contain the model checkpoint
produced by training.
NUM_NODES = 2
NUM_DEVICES = 8


@task(
    container_image=custom_image,
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        rdzv_configs={"timeout": 36000, "join_timeout": 36000},
        max_restarts=3,
    ),
    accelerator=T4,
    requests=Resources(mem="32Gi", cpu="48", gpu="8", ephemeral_storage="100Gi"),
)
def train_model(dataloader_num_workers: int) -> FlyteDirectory:
    """Train an autoencoder model on the MNIST dataset."""
    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    autoencoder = MNISTAutoEncoder(encoder, decoder)

    root_dir = os.getcwd()
    data = MNISTDataModule(
        root_dir,
        batch_size=4,
        dataloader_num_workers=dataloader_num_workers,
    )

    model_dir = os.path.join(root_dir, "model")
    trainer = L.Trainer(
        default_root_dir=model_dir,
        max_epochs=3,
        num_nodes=NUM_NODES,
        devices=NUM_DEVICES,
        accelerator="gpu",
        strategy="ddp",
        precision="16-mixed",
    )
    trainer.fit(model=autoencoder, datamodule=data)
    return FlyteDirectory(path=str(model_dir))
Finally, we wrap it all up in a workflow.
@workflow
def train_workflow(dataloader_num_workers: int = 1) -> FlyteDirectory:
    return train_model(dataloader_num_workers=dataloader_num_workers)
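Once an execution finishes, the returned FlyteDirectory can be downloaded and
the checkpoint loaded back into the model. The helper below is only a sketch
of one way to do this: the exact checkpoint filename depends on the Trainer's
checkpointing settings, so it simply globs for the first .ckpt file under the
downloaded directory.

from pathlib import Path


def load_trained_autoencoder(model_dir: FlyteDirectory) -> MNISTAutoEncoder:
    """Illustrative helper, not part of the original example."""
    local_path = model_dir.download()  # materialize the directory locally
    ckpt_path = next(Path(local_path).rglob("*.ckpt"))  # assumes Lightning wrote a .ckpt file
    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
    decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
    # load_from_checkpoint forwards extra keyword arguments to __init__
    return MNISTAutoEncoder.load_from_checkpoint(
        str(ckpt_path), encoder=encoder, decoder=decoder
    )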