AWS SageMaker PyTorch

Tags: Integration, MachineLearning, AWS, Advanced

This example shows how to run SageMaker custom training with PyTorch distributed training.

Installation

To use the Flytekit AWS SageMaker plugin, run the following:

pip install flytekitplugins-awssagemaker
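
After installing, it can help to confirm that the plugin is importable in the environment you intend to use. The following is a minimal sketch, assuming the package exposes the module `flytekitplugins.awssagemaker`; the helper name `plugin_installed` is hypothetical and only for illustration.

```python
import importlib.util


def plugin_installed(module_name: str = "flytekitplugins.awssagemaker") -> bool:
    """Return True if the given module can be located without fully importing it.

    The default module name is an assumption about how the plugin is laid out.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except ImportError:
        # Raised when a parent package (here, flytekitplugins) is missing or broken.
        return False


if __name__ == "__main__":
    status = "available" if plugin_installed() else "missing"
    print(f"flytekitplugins.awssagemaker is {status}")
```

If the check reports the module as missing, the `pip install` above most likely ran against a different Python environment.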

Creating a Dockerfile for SageMaker Custom Training [Required]

The Dockerfile for SageMaker custom training is similar to any regular Dockerfile, except that it uses an NVIDIA CUDA base image so that training can use GPUs.

Note

If you train on CPU, this special Dockerfile is NOT REQUIRED. If GPUs or TPUs are required, the Dockerfile differs only in the driver setup. The Dockerfile below is set up for GPU-accelerated training using CUDA. The checked-in version of the Dockerfile uses python:3.8-slim-buster for faster CI, but you can use the Dockerfile pasted below, which uses a CUDA base image. Additionally, requirements.in uses the CPU build of PyTorch. Remove the +cpu suffix for torch and torchvision in requirements.in, then regenerate the requirements with the following make command:

make -C integrations/aws/sagemaker_pytorch requirements
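
Before running the make target, the `+cpu` removal described in the note can be scripted rather than done by hand. This is a rough sketch, assuming the entries in requirements.in look like `torch==<version>+cpu`; the function name `strip_cpu_suffix` and the version numbers in the usage example are hypothetical.

```python
import re

# Matches a "+cpu" local version label, e.g. in "torch==1.7.0+cpu".
# The exact pin format in requirements.in is an assumption.
_CPU_SUFFIX = re.compile(r"\+cpu\b")
_TARGETS = ("torch", "torchvision")


def strip_cpu_suffix(requirements_text: str) -> str:
    """Remove the +cpu label from torch and torchvision lines only."""
    cleaned = []
    for line in requirements_text.splitlines():
        # The package name is everything before the first specifier character.
        name = re.split(r"[=<>!~\[ ;]", line.strip(), maxsplit=1)[0].lower()
        if name in _TARGETS:
            line = _CPU_SUFFIX.sub("", line)
        cleaned.append(line)
    result = "\n".join(cleaned)
    # Preserve a trailing newline if the input had one.
    if requirements_text.endswith("\n"):
        result += "\n"
    return result


if __name__ == "__main__":
    # Hypothetical file content, for illustration only.
    sample = "torch==1.7.0+cpu\ntorchvision==0.8.1+cpu\nflytekit>=0.16.0\n"
    print(strip_cpu_suffix(sample), end="")
```

Lines for other packages pass through unchanged, so the function can be applied to the whole file and the result written back before regenerating requirements.txt.
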
# GPU-enabled image for SageMaker custom training (CUDA base)
FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel
LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks

WORKDIR /root
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
ENV PYTHONPATH /root

# Install the AWS cli separately to prevent issues with boto being written over
RUN pip install awscli

ENV VENV /opt/venv
# Virtual environment
RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

# Install Python dependencies
COPY sagemaker_pytorch/requirements.txt /root/.
RUN pip install -r /root/requirements.txt

# Setup Sagemaker entrypoints
ENV SAGEMAKER_PROGRAM /opt/venv/bin/flytekit_sagemaker_runner.py

# Copy the makefile targets to expose on the container. This makes it easier to register.
COPY in_container.mk /root/Makefile
COPY sagemaker_pytorch/sandbox.config /root

# Copy the actual code
COPY sagemaker_pytorch/ /root/sagemaker_pytorch

# This tag is supplied by the build script and will be used to determine the version
# when registering tasks, workflows, and launch plans
ARG tag
ENV FLYTE_INTERNAL_IMAGE $tag
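
The Dockerfile expects a `tag` build argument, and its `COPY` paths imply that the build context is the directory containing both `sagemaker_pytorch/` and `in_container.mk`. The sketch below shows one way a build script might assemble the matching `docker build` invocation. The helper name `docker_build_command`, the image name, and the Dockerfile path are hypothetical; only the standard `--build-arg`, `-t`, and `-f` flags of `docker build` are relied on.

```python
import shlex
from typing import List


def docker_build_command(tag: str, dockerfile: str, context: str = ".") -> List[str]:
    """Build the argument list for `docker build`, passing `tag` to the ARG in the Dockerfile."""
    return [
        "docker", "build",
        "--build-arg", f"tag={tag}",  # consumed by `ARG tag` / FLYTE_INTERNAL_IMAGE
        "-t", tag,                    # name the resulting image with the same tag
        "-f", dockerfile,
        context,                      # directory holding sagemaker_pytorch/ and in_container.mk
    ]


if __name__ == "__main__":
    # Hypothetical image name and Dockerfile location, for illustration only.
    cmd = docker_build_command(
        tag="example-registry/sagemaker-pytorch:dev",
        dockerfile="sagemaker_pytorch/Dockerfile",
    )
    print(shlex.join(cmd))
```

To actually run the build, the list can be passed to `subprocess.run(cmd, check=True)` from the build context directory. Reusing one value for both the image name and the `tag` argument keeps `FLYTE_INTERNAL_IMAGE` aligned with the image being built, though the project's own build script may do this differently.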