Distributed PyTorch on SageMaker
This example is adapted from the Amazon SageMaker PyTorch MNIST example (linked in the code below).
It shows how distributed training can be performed entirely on the user side, with minimal changes, when using Flyte.
Note
Flytekit will add further simplifications to make writing distributed training algorithms even easier, but this example already spells out the full details.
import logging
import os
import typing
from dataclasses import dataclass
import flytekit
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as functional
import torch.optim as optim
from dataclasses_json import dataclass_json
from flytekit import task, workflow
from flytekit.types.file import PythonPickledFile
from flytekitplugins.awssagemaker import (
AlgorithmName,
AlgorithmSpecification,
InputContentType,
InputMode,
SagemakerTrainingJobConfig,
TrainingJobResourceConfig,
)
from torchvision import datasets, transforms
@dataclass_json
@dataclass
class Hyperparameters(object):
"""
Args:
    backend: the distributed backend to use (default: gloo)
    batch_size: input batch size for training (default: 64)
    test_batch_size: input batch size for testing (default: 1000)
    epochs: number of epochs to train (default: 10)
    learning_rate: learning rate (default: 0.01)
    sgd_momentum: SGD momentum (default: 0.5)
    seed: random seed (default: 1)
    log_interval: how many batches to wait before logging training status
"""
backend: str = dist.Backend.GLOO
sgd_momentum: float = 0.5
seed: int = 1
log_interval: int = 10
batch_size: int = 64
test_batch_size: int = 1000
epochs: int = 10
learning_rate: float = 0.01
@dataclass
class TrainingArgs(Hyperparameters):
"""
Training arguments that carry additional metadata beyond the hyperparameters; this metadata is
especially useful in distributed training.
"""
hosts: typing.List[str] = None
current_host: str = ""
num_gpus: int = 0
data_dir: str = "/tmp"
model_dir: str = "/tmp"
def is_distributed(self) -> bool:
return len(self.hosts) > 1 and self.backend is not None
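For illustration (not part of the original example), here is a minimal sketch of how these arguments determine the global rank used by the trainer below: each host contributes num_gpus worker processes, and the process for local GPU g on host h receives rank h * num_gpus + g. The host names here are hypothetical.
example_args = TrainingArgs(
    hosts=["algo-1", "algo-2"],  # hypothetical SageMaker host names
    current_host="algo-2",
    num_gpus=4,
)
assert example_args.is_distributed()
world_size = len(example_args.hosts) * example_args.num_gpus  # 2 hosts * 4 GPUs = 8
local_gpu = 3  # the index that mp.spawn passes to each spawned worker
rank = example_args.hosts.index(example_args.current_host) * example_args.num_gpus + local_gpu  # 1 * 4 + 3 = 7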
# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = functional.relu(functional.max_pool2d(self.conv1(x), 2))
x = functional.relu(functional.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = functional.relu(self.fc1(x))
x = functional.dropout(x, training=self.training)
x = self.fc2(x)
return functional.log_softmax(x, dim=1)
def _get_train_data_loader(batch_size, training_dir, is_distributed, **kwargs):
logging.info("Get train data loader")
dataset = datasets.MNIST(
training_dir,
train=True,
download=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
)
logging.info("Dataset is downloaded. Creating a train_sampler")
train_sampler = (
torch.utils.data.distributed.DistributedSampler(dataset)
if is_distributed
else None
)
logging.info("Train_sampler is successfully created. Creating a DataLoader")
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=train_sampler is None,
sampler=train_sampler,
**kwargs,
)
def _get_test_data_loader(test_batch_size, training_dir, **kwargs):
logging.info("Get test data loader")
return torch.utils.data.DataLoader(
datasets.MNIST(
training_dir,
train=False,
download=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=test_batch_size,
shuffle=True,
**kwargs,
)
def _average_gradients(model):
# Gradient averaging.
size = float(dist.get_world_size())
for param in model.parameters():
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size
def configure_model(model, is_distributed, gpu):
if is_distributed:
# multi-machine multi-gpu case
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[gpu], output_device=gpu
)
else:
# single-machine multi-gpu case or single-machine or multi-machine cpu case
model = torch.nn.DataParallel(model)
return model
The Actual Trainer
def train(gpu: int, args: TrainingArgs):
logging.basicConfig(level="INFO")
is_distributed = args.is_distributed()
logging.warning("Distributed training - {}".format(is_distributed))
use_cuda = args.num_gpus > 0
logging.warning("Number of gpus available - {}".format(args.num_gpus))
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
device = torch.device("cuda" if use_cuda else "cpu")
rank = 0
if is_distributed:
# Initialize the distributed environment
world_size = len(args.hosts) * args.num_gpus
os.environ["WORLD_SIZE"] = str(world_size)
rank = args.hosts.index(args.current_host) * args.num_gpus + gpu
os.environ["RANK"] = str(rank)
dist.init_process_group(
backend=args.backend, init_method="env://", rank=rank, world_size=world_size
)
logging.info(
"Initialized the distributed environment: '{}' backend on {} nodes. ".format(
args.backend, dist.get_world_size()
)
+ "Current host rank is {}. Number of gpus: {}".format(
dist.get_rank(), args.num_gpus
)
)
torch.cuda.set_device(gpu)
# set the seed for generating random numbers
torch.manual_seed(args.seed)
if use_cuda:
torch.cuda.manual_seed(args.seed)
train_loader = _get_train_data_loader(
args.batch_size, args.data_dir, is_distributed, **kwargs
)
test_loader = _get_test_data_loader(args.test_batch_size, args.data_dir, **kwargs)
logging.info(
"Processes {}/{} ({:.0f}%) of train data".format(
len(train_loader.sampler),
len(train_loader.dataset),
100.0 * len(train_loader.sampler) / len(train_loader.dataset),
)
)
logging.info(
"Processes {}/{} ({:.0f}%) of test data".format(
len(test_loader.sampler),
len(test_loader.dataset),
100.0 * len(test_loader.sampler) / len(test_loader.dataset),
)
)
model = Net().to(device)
model = configure_model(model, is_distributed, gpu)
optimizer = optim.SGD(
model.parameters(), lr=args.learning_rate, momentum=args.sgd_momentum
)
logging.info(
"[rank {}|local-rank {}] Totally {} epochs".format(rank, gpu, args.epochs + 1)
)
for epoch in range(1, args.epochs + 1):
model.train()
for batch_idx, (data, target) in enumerate(train_loader, 1):
if use_cuda:
data, target = (
data.cuda(non_blocking=True),
target.cuda(non_blocking=True),
)
optimizer.zero_grad()
output = model(data)
loss = functional.nll_loss(output, target)
loss.backward()
if is_distributed and not use_cuda:
# average gradients manually for multi-machine cpu case only
_average_gradients(model)
optimizer.step()
if batch_idx % args.log_interval == 0:
if not is_distributed or (is_distributed and rank == 0):
logging.info(
"[rank {}|local-rank {}] Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}".format(
rank,
gpu,
epoch,
batch_idx * len(data),
len(train_loader.sampler),
100.0 * batch_idx / len(train_loader),
loss.item(),
)
)
test(model, test_loader, device)
if not is_distributed or (is_distributed and rank == 0):
save_model(model, args.model_dir)
Let's test the trained model
def test(model, test_loader, device):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
if device.type == "cuda":
data, target = (
data.cuda(non_blocking=True),
target.cuda(non_blocking=True),
)
output = model(data)
test_loss += functional.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
pred = output.max(1, keepdim=True)[
1
] # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
logging.info(
"Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss,
correct,
len(test_loader.dataset),
100.0 * correct / len(test_loader.dataset),
)
)
def model_fn(model_dir):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.DataParallel(Net())
with open(os.path.join(model_dir, "model.pth"), "rb") as f:
model.load_state_dict(torch.load(f))
return model.to(device)
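model_fn follows the SageMaker PyTorch serving convention for loading a trained model from a directory. As a minimal sketch (the path and the input tensor are hypothetical, not part of the original example), the returned model can be used for local inference like this:
loaded = model_fn("/tmp/model")  # hypothetical directory containing model.pth
loaded.eval()
with torch.no_grad():
    dummy_digit = torch.zeros(1, 1, 28, 28)  # one blank 28x28 grayscale image
    log_probs = loaded(dummy_digit)  # Net outputs log-probabilities over the 10 digit classes
    print(log_probs.argmax(dim=1))  # predicted class index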
Save the model to a local path
def save_model(model, model_dir) -> PythonPickledFile:
logging.info("Saving the model.")
path = os.path.join(model_dir, "model.pth")
# recommended way from http://pytorch.org/docs/master/notes/serialization.html
torch.save(model.cpu().state_dict(), path)
print(f"Model saved to {path}")
return path
def download_training_data(training_dir):
logging.info("Downloading train data")
datasets.MNIST(
training_dir,
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
)
def download_test_data(training_dir):
logging.info("Downloading test data")
datasets.MNIST(
training_dir,
train=False,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
)
# https://github.com/aws/amazon-sagemaker-examples/blob/89831fcf99ea3110f52794db0f6433a4013a5bca/sagemaker-python-sdk/pytorch_mnist/mnist.py
@task(
task_config=SagemakerTrainingJobConfig(
algorithm_specification=AlgorithmSpecification(
input_mode=InputMode.FILE,
algorithm_name=AlgorithmName.CUSTOM,
algorithm_version="",
input_content_type=InputContentType.TEXT_CSV,
),
training_job_resource_config=TrainingJobResourceConfig(
instance_type="ml.p3.8xlarge",
instance_count=2,
volume_size_in_gb=25,
),
),
cache_version="1.0",
cache=True,
)
def mnist_pytorch_job(hp: Hyperparameters) -> PythonPickledFile:
# PyTorch's save() does not create missing directories,
# so we must pass a path that already exists.
ctx = flytekit.current_context()
data_dir = os.path.join(ctx.working_directory, "data")
model_dir = os.path.join(ctx.working_directory, "model")
os.makedirs(data_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)
args = TrainingArgs(
hosts=ctx.distributed_training_context.hosts,
current_host=ctx.distributed_training_context.current_host,
num_gpus=torch.cuda.device_count(),
batch_size=hp.batch_size,
test_batch_size=hp.test_batch_size,
epochs=hp.epochs,
learning_rate=hp.learning_rate,
sgd_momentum=hp.sgd_momentum,
seed=hp.seed,
log_interval=hp.log_interval,
backend=hp.backend,
data_dir=data_dir,
model_dir=model_dir,
)
# The data shouldn't be downloaded inside the functions launched by mp.spawn, to avoid
# race conditions between processes. These downloads could also be replaced by Flyte's
# blob-type inputs (a sketch of that alternative follows this task). Note that the data
# here is assumed to be accessible via a local path:
download_training_data(args.data_dir)
download_test_data(args.data_dir)
if len(args.hosts) > 1:
# Config MASTER_ADDR and MASTER_PORT for PyTorch Distributed Training
os.environ["MASTER_ADDR"] = args.hosts[0]
os.environ["MASTER_PORT"] = "29500"
os.environ[
"NCCL_SOCKET_IFNAME"
] = ctx.distributed_training_context.network_interface_name
os.environ["NCCL_DEBUG"] = "INFO"
# The function is called as fn(i, *args), where i is the process index and args is the passed
# through tuple of arguments.
# https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
mp.spawn(train, nprocs=args.num_gpus, args=(args,))
else:
# Config for Multi GPU with a single instance training
if args.num_gpus > 1:
gpu_devices = ",".join([str(gpu_id) for gpu_id in range(args.num_gpus)])
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
train(-1, args)
pth = os.path.join(model_dir, "model.pth")
print(f"Returning model @ {pth}")
return pth
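As noted in the comment inside the task, the in-task MNIST download could instead be modelled as a blob-type input, letting Flyte stage the data for the task. Below is a minimal sketch under that assumption; the task variant, its name, and the use of FlyteDirectory are illustrative and not part of the original example (it also trains single-process, for brevity):
from flytekit.types.directory import FlyteDirectory


@task
def mnist_pytorch_job_with_data(
    hp: Hyperparameters, dataset: FlyteDirectory
) -> PythonPickledFile:
    ctx = flytekit.current_context()
    model_dir = os.path.join(ctx.working_directory, "model")
    os.makedirs(model_dir, exist_ok=True)
    # Flyte materializes the blob-store directory at a local path before the body runs;
    # the directory is assumed to already contain a downloaded MNIST dataset.
    data_dir = dataset.download()
    args = TrainingArgs(
        hosts=["localhost"],  # single instance, so no distributed setup
        current_host="localhost",
        data_dir=data_dir,
        model_dir=model_dir,
        **hp.to_dict(),
    )
    train(-1, args)
    return os.path.join(model_dir, "model.pth")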
Create a Pipeline
Now the training task can be wrapped into a workflow. This example has only one step: the workflow invokes the training job and outputs the serialized model:
@workflow
def pytorch_training_wf(hp: Hyperparameters) -> PythonPickledFile:
return mnist_pytorch_job(hp=hp)
if __name__ == "__main__":
model = pytorch_training_wf(hp=Hyperparameters(epochs=2, batch_size=128))
print(f"Model: {model}")