Dataclass

Tags: Basic

When you have multiple values that you want to pass between Flyte entities, you can group them in a dataclass.

Flytekit uses the Mashumaro library to serialize and deserialize dataclasses.

Important

If you’re using a Flytekit version below v1.10, you’ll need to decorate the dataclass with @dataclass_json (imported with from dataclasses_json import dataclass_json) instead of inheriting from Mashumaro’s DataClassJSONMixin.

If you’re using a Flytekit version >= v1.11.1, you don’t need to decorate the dataclass with @dataclass_json or inherit from Mashumaro’s DataClassJSONMixin.

Note

To clone and run the example code on this page, see the Flytesnacks repo.

To begin, import the necessary dependencies:

data_types_and_io/dataclass.py
import os
import tempfile
from dataclasses import dataclass

import pandas as pd
from flytekit import task, workflow
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset
from mashumaro.mixins.json import DataClassJSONMixin

Python types

We define a dataclass whose fields are of type int, str and dict.

data_types_and_io/dataclass.py
@dataclass
class Datum(DataClassJSONMixin):
    x: int
    y: str
    z: dict[int, str]

You can send a dataclass between different tasks written in various languages, and input it through the Flyte console as raw JSON.
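To get a feel for the raw JSON that a Datum corresponds to, here is a sketch that uses only the standard library (dataclasses.asdict and json) instead of Mashumaro, so the exact JSON Flytekit produces may differ. It shows that integer dictionary keys become strings in JSON, and that a payload with the wrong keys cannot be turned back into the dataclass. The helper datum_from_json is hypothetical.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class Datum:
    x: int
    y: str
    z: dict[int, str]


def datum_from_json(raw: str) -> Datum:
    """Hypothetical helper: rebuild a Datum from JSON, restoring int keys."""
    data = json.loads(raw)
    if "z" in data:
        # JSON object keys are always strings, so convert them back to int.
        data["z"] = {int(k): v for k, v in data["z"].items()}
    # Unexpected or missing fields make the Datum constructor raise TypeError.
    return Datum(**data)


payload = json.dumps(asdict(Datum(x=1, y="1", z={1: "1"})))
print(payload)  # {"x": 1, "y": "1", "z": {"1": "1"}}

restored = datum_from_json(payload)
print(restored == Datum(x=1, y="1", z={1: "1"}))  # True

try:
    datum_from_json('{"x": 1, "y": "1", "z": {}, "extra": true}')
except TypeError as err:
    print("structure mismatch:", err)
```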

Note

All variables in a dataclass should be annotated with their type. Failure to do so will result in an error.
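Why annotations matter can be seen with the standard library alone: a class attribute without a type annotation is not a dataclass field, so it is left out of fields() and asdict() and would not be part of any serialized payload. The class names below are hypothetical.

```python
from dataclasses import asdict, dataclass, fields


@dataclass
class Annotated:
    x: int
    y: str = "default"


@dataclass
class PartlyAnnotated:
    x: int
    y = "default"  # no annotation: a plain class attribute, not a field


print([f.name for f in fields(Annotated)])        # ['x', 'y']
print([f.name for f in fields(PartlyAnnotated)])  # ['x']
print(asdict(PartlyAnnotated(x=1)))               # {'x': 1}
```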

Once declared, a dataclass can be returned as an output or accepted as an input.

data_types_and_io/dataclass.py
@task
def stringify(s: int) -> Datum:
    """
    A dataclass return will be treated as a single complex JSON return.
    """
    return Datum(x=s, y=str(s), z={s: str(s)})


@task
def add(x: Datum, y: Datum) -> Datum:
    """
    Flytekit automatically converts the provided JSON into a data class.
    If the structures don't match, it triggers a runtime failure.
    """
    x.z.update(y.z)
    return Datum(x=x.x + y.x, y=x.y + y.y, z=x.z)
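Because the task bodies are ordinary Python, their logic can be checked without Flyte. The sketch below drops the @task decorators and the Mashumaro mixin (only to keep it dependency-free) and swaps x.z.update(y.z) for a non-mutating merge, so calling add in plain Python leaves its inputs unchanged.

```python
from dataclasses import dataclass


@dataclass
class Datum:
    x: int
    y: str
    z: dict[int, str]


def stringify(s: int) -> Datum:
    # Same body as the task above, without the @task decorator.
    return Datum(x=s, y=str(s), z={s: str(s)})


def add(x: Datum, y: Datum) -> Datum:
    # Build a new dict instead of calling x.z.update(y.z),
    # so the caller's Datum objects are not modified.
    return Datum(x=x.x + y.x, y=x.y + y.y, z={**x.z, **y.z})


a, b = stringify(10), stringify(20)
result = add(a, b)
print(result)  # Datum(x=30, y='1020', z={10: '10', 20: '20'})
print(a.z)     # {10: '10'}  (input left unchanged)
```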

Flyte types

We also define a dataclass whose fields are a StructuredDataset, a FlyteFile and a FlyteDirectory.

data_types_and_io/dataclass.py
@dataclass
class FlyteTypes(DataClassJSONMixin):
    dataframe: StructuredDataset
    file: FlyteFile
    directory: FlyteDirectory


@task
def upload_data() -> FlyteTypes:
    """
    Flytekit will upload FlyteFile, FlyteDirectory and StructuredDataset to the blob store,
    such as GCP or S3.
    """
    # 1. StructuredDataset
    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    # 2. FlyteDirectory
    temp_dir = tempfile.mkdtemp(prefix="flyte-")
    df.to_parquet(os.path.join(temp_dir, "df.parquet"))

    # 3. FlyteFile
    # Closing the file (end of the with-block) flushes the bytes to disk before upload.
    with tempfile.NamedTemporaryFile(delete=False) as file_path:
        file_path.write(b"Hello, World!")

    fs = FlyteTypes(
        dataframe=StructuredDataset(dataframe=df),
        file=FlyteFile(file_path.name),
        directory=FlyteDirectory(temp_dir),
    )
    return fs


@task
def download_data(res: FlyteTypes):
    expected = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    assert expected.equals(res.dataframe.open(pd.DataFrame).all())
    # Opening the FlyteFile and listing the FlyteDirectory downloads them locally.
    with open(res.file, "r") as f:
        assert f.read() == "Hello, World!"
    assert os.listdir(res.directory) == ["df.parquet"]
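When upload_data hands a local path to FlyteFile, the bytes must already be on disk; data written to a NamedTemporaryFile may still sit in a buffer until the file is flushed or closed. Below is a standard-library sketch of preparing a file and a directory safely. The parquet step is replaced by a plain text file so it needs no pandas, and the helper prepare_local_artifacts is hypothetical.

```python
import os
import tempfile


def prepare_local_artifacts() -> tuple[str, str]:
    """Hypothetical helper: create a file and a directory to wrap in FlyteFile/FlyteDirectory."""
    # Closing the file via the with-block flushes the bytes to disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
        f.write(b"Hello, World!")
        file_path = f.name

    temp_dir = tempfile.mkdtemp(prefix="flyte-")
    # Stand-in for df.to_parquet(...): any file placed inside the directory.
    with open(os.path.join(temp_dir, "data.txt"), "w") as out:
        out.write("Name,Age\nTom,20\nJoseph,22\n")
    return file_path, temp_dir


file_path, temp_dir = prepare_local_artifacts()
with open(file_path) as f:
    print(f.read())          # Hello, World!
print(os.listdir(temp_dir))  # ['data.txt']
```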

A dataclass can hold fields of plain Python types, other dataclasses, FlyteFile, FlyteDirectory and StructuredDataset.
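Nested dataclasses map to nested JSON objects. The sketch below uses only the standard library (asdict plus json) rather than Mashumaro, and Point and Shape are hypothetical names, so treat it as an illustration of the shape of the data rather than Flytekit's exact encoding.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class Point:
    x: int
    y: int


@dataclass
class Shape:
    name: str
    vertices: list[Point]


shape = Shape(name="triangle", vertices=[Point(0, 0), Point(1, 0), Point(0, 1)])
# asdict recurses into nested dataclasses, including those inside lists.
print(json.dumps(asdict(shape)))
# {"name": "triangle", "vertices": [{"x": 0, "y": 0}, {"x": 1, "y": 0}, {"x": 0, "y": 1}]}
```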

We define a workflow that calls the tasks created above.

data_types_and_io/dataclass.py
@workflow
def dataclass_wf(x: int, y: int) -> (Datum, FlyteTypes):
    o1 = add(x=stringify(s=x), y=stringify(s=y))
    o2 = upload_data()
    download_data(res=o2)
    return o1, o2

You can run the workflow locally as follows:

data_types_and_io/dataclass.py
if __name__ == "__main__":
    dataclass_wf(x=10, y=20)

To trigger a task that accepts a dataclass as an input with pyflyte run, you can provide a JSON file as an input:

pyflyte run \
  https://raw.githubusercontent.com/flyteorg/flytesnacks/master/examples/data_types_and_io/data_types_and_io/dataclass.py \
  add --x dataclass_input.json --y dataclass_input.json
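The command above expects dataclass_input.json to exist locally. One way to produce it is to serialize a Datum instance. The sketch below uses the standard library and assumes that pyflyte run accepts a JSON object whose keys match the dataclass fields. Integer dictionary keys are written as strings, since JSON only allows string keys.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class Datum:
    x: int
    y: str
    z: dict[int, str]


# Write a sample input for the add task's --x and --y parameters.
with open("dataclass_input.json", "w") as f:
    json.dump(asdict(Datum(x=10, y="10", z={10: "10"})), f)

with open("dataclass_input.json") as f:
    print(f.read())  # {"x": 10, "y": "10", "z": {"10": "10"}}
```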