PyTorch Example

In this example, we will see how to convert a PyTorch model to an ONNX model.

First, import the necessary libraries.

import io
from pathlib import Path

import flytekit
import numpy as np
import onnxruntime
import requests
import torch.nn.init as init
import torch.onnx
import torch.utils.model_zoo as model_zoo
import torchvision.transforms as transforms
from flytekit import task, workflow
from flytekit.types.file import JPEGImageFile, ONNXFile
from flytekitplugins.onnxpytorch import PyTorch2ONNX, PyTorch2ONNXConfig
from PIL import Image
from torch import nn
from typing_extensions import Annotated

Define a convolutional super-resolution model.

class SuperResolutionNet(nn.Module):
    """Sub-pixel convolutional network for single-channel super-resolution.

    Three ReLU-activated convolutions extract features, and a fourth produces
    ``upscale_factor ** 2`` channels. ``nn.PixelShuffle`` then rearranges those
    channels into a single-channel image that is ``upscale_factor`` times larger
    in both height and width.

    Args:
        upscale_factor: Spatial upscaling factor applied by the pixel shuffle.
        inplace: Forwarded to ``nn.ReLU``.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()

        # Attribute assignment order fixes the submodule registration order
        # (and the order in which default weight init consumes the RNG).
        self.relu = nn.ReLU(inplace=inplace)
        # Stride 1 with "same" padding (2 for 5x5, 1 for 3x3) keeps H and W unchanged.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor**2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Apply the conv stack, then pixel-shuffle into the upscaled image."""
        for layer in (self.conv1, self.conv2, self.conv3):
            x = self.relu(layer(x))
        # conv4 is deliberately not followed by a ReLU.
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise conv weights (ReLU gain for the activated layers)."""
        relu_gain = init.calculate_gain("relu")
        for layer in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(layer.weight, relu_gain)
        # The last conv feeds no ReLU, so it keeps the default gain of 1.
        init.orthogonal_(self.conv4.weight)

Define a train task that builds the model and loads pretrained weights into it (no actual training is performed here). Note the annotated output type: this special annotation tells Flytekit that the returned model is to be converted to an ONNX model with the given config.

@task
def train() -> (
    Annotated[
        PyTorch2ONNX,
        PyTorch2ONNXConfig(
            args=torch.randn(1, 1, 224, 224, requires_grad=True),
            export_params=True,  # store the trained parameter weights inside the model file
            opset_version=10,  # the ONNX opset version to export the model to
            do_constant_folding=True,  # whether to execute constant folding for optimization
            input_names=["input"],  # the model's input names
            output_names=["output"],  # the model's output names
            dynamic_axes={
                "input": {0: "batch_size"},
                "output": {0: "batch_size"},
            },  # variable length axes: only the batch dimension may vary
        ),
    ]
):
    """Build a SuperResolutionNet and load pretrained weights for ONNX export.

    Despite its name, this task does not train anything. It downloads a
    pretrained checkpoint and loads it into the model.

    The ``Annotated`` return type tells Flytekit's ONNX PyTorch plugin to export
    the returned model to ONNX using the attached ``PyTorch2ONNXConfig``. The
    dummy ``args`` tensor of shape (1, 1, 224, 224) fixes the traced input size;
    only dimension 0 is declared dynamic via ``dynamic_axes``.

    Returns:
        PyTorch2ONNX wrapping the CPU-resident model in eval mode.
    """
    # Create the super-resolution model by using the above model definition.
    torch_model = SuperResolutionNet(upscale_factor=3)

    # Pretrained model weights.
    model_url = "https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth"

    # The model is never moved off the CPU, and the export dummy input above is
    # a CPU tensor. So always deserialize the checkpoint onto the CPU.
    # map_location=None could restore tensors onto a GPU, only for
    # load_state_dict to copy them straight back into the CPU parameters.
    state_dict = model_zoo.load_url(model_url, map_location="cpu")
    torch_model.load_state_dict(state_dict)

    # Put the model in inference mode before handing it off for export.
    torch_model.eval()

    return PyTorch2ONNX(model=torch_model)

Define an onnx_predict task that runs the ONNX model on an input image to generate a super-resolution image.

@task
def onnx_predict(model_file: ONNXFile) -> JPEGImageFile:
    """Run a super-resolution ONNX model on a sample cat image.

    The image is downloaded, resized to 224x224 and converted to YCbCr.
    Only the Y (luminance) channel is fed to the model. The predicted Y
    channel is then merged with bicubically-resized Cb/Cr channels and the
    result is saved as a JPEG.

    Args:
        model_file: ONNX model taking a (N, 1, 224, 224) float input.

    Returns:
        JPEGImageFile pointing at the super-resolved image in the task's
        working directory.

    Raises:
        requests.HTTPError: if the sample image cannot be downloaded.
        requests.Timeout: if the download does not respond in time.
    """
    image_url = "https://raw.githubusercontent.com/flyteorg/static-resources/main/flytekit/onnx/cat.jpg"
    # Bounded wait so a stalled connection cannot hang the task indefinitely.
    request_timeout_s = 60

    ort_session = onnxruntime.InferenceSession(model_file.download())

    # Read the full (transparently decoded) body and fail loudly on HTTP errors.
    # The previous approach handed PIL the un-decoded raw stream of an unchecked,
    # never-closed streaming response.
    response = requests.get(image_url, timeout=request_timeout_s)
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))

    # train() exports the model with a 224x224 dummy input, and only the batch
    # axis is dynamic, so the spatial size must match.
    img = transforms.Resize([224, 224])(img)

    # The model only processes luminance: split into Y, Cb, Cr and feed Y.
    img_y, img_cb, img_cr = img.convert("YCbCr").split()

    # ToTensor yields a (1, H, W) float tensor; add a batch axis -> (1, 1, H, W).
    img_y = transforms.ToTensor()(img_y).unsqueeze(0)

    # Compute the ONNX Runtime prediction. ToTensor output lives on the CPU and
    # never requires grad, so .numpy() can be called directly.
    ort_inputs = {ort_session.get_inputs()[0].name: img_y.numpy()}
    ort_outs = ort_session.run(None, ort_inputs)

    # Take the first output, first batch item, then the first channel
    # (assumes an (N, 1, H', W') output). Scale to 8-bit grayscale.
    out_y = (ort_outs[0][0] * 255.0).clip(0, 255)[0]
    img_out_y = Image.fromarray(np.uint8(out_y), mode="L")

    # Post-processing as in the PyTorch implementation: upsample chroma to the
    # super-resolved size, merge with the predicted Y channel, convert to RGB.
    final_img = Image.merge(
        "YCbCr",
        [
            img_out_y,
            img_cb.resize(img_out_y.size, Image.BICUBIC),
            img_cr.resize(img_out_y.size, Image.BICUBIC),
        ],
    ).convert("RGB")

    img_path = Path(flytekit.current_context().working_directory) / "cat_superres_with_ort.jpg"
    final_img.save(img_path)

    return JPEGImageFile(path=str(img_path))

Define a workflow to run the above tasks.

@workflow
def wf() -> JPEGImageFile:
    """Load the pretrained model as ONNX, then super-resolve a sample image with it."""
    # train()'s ONNX output feeds straight into the prediction task.
    return onnx_predict(model_file=train())

Run the workflow locally.

if __name__ == "__main__":
    # Calling the workflow function directly executes it locally.
    prediction = wf()
    print(f"Prediction: {prediction}")