Decorating workflows

Tags: Intermediate

The behavior of workflows can be modified in a lightweight fashion by using the built-in functools.wraps() decorator pattern, similar to using decorators to customize task behavior. However, unlike tasks, workflows require a little extra work to ensure that the underlying DAG executes tasks in the correct order.

Setup-teardown pattern

The main use case for decorating @workflow-decorated functions is to establish a setup-teardown pattern that executes tasks before and after your main workflow logic. This is useful when integrating with external services like wandb or clearml, which enable you to track metrics of model training runs.

Note

To clone and run the example code on this page, see the Flytesnacks repo.

To begin, import the necessary libraries.

advanced_composition/decorating_workflows.py
from functools import partial, wraps
from unittest.mock import MagicMock

import flytekit
from flytekit import FlyteContextManager, task, workflow
from flytekit.core.node_creation import create_node

Let’s define the tasks we need for setup and teardown. In this example, we use the unittest.mock.MagicMock class to create a fake external service that we want to initialize at the beginning of our workflow and finish at the end.

advanced_composition/decorating_workflows.py
external_service = MagicMock()


@task
def setup():
    print("initializing external service")
    external_service.initialize(id=flytekit.current_context().execution_id)


@task
def teardown():
    print("finish external service")
    external_service.complete(id=flytekit.current_context().execution_id)

As you can see, you can use Flytekit’s current context to access the execution_id of the current workflow. This is useful for linking Flyte with the external service, so that both systems reference the same unique identifier.
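As an aside, MagicMock records every call made on it, which makes it easy to verify locally that the fake service was initialized and completed with the same identifier. Here is a minimal, standard-library-only sketch (the execution id string is a made-up placeholder, not a real Flyte id):

```python
from unittest.mock import MagicMock

external_service = MagicMock()

# Simulate what setup() and teardown() do, with a placeholder execution id.
execution_id = "f-abc123"
external_service.initialize(id=execution_id)
external_service.complete(id=execution_id)

# MagicMock records each call on itself and its children, so we can check
# that both tasks hit the service with the same identifier.
external_service.initialize.assert_called_once_with(id="f-abc123")
external_service.complete.assert_called_once_with(id="f-abc123")
print(len(external_service.mock_calls))  # 2
```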

Workflow decorator

We create a decorator that we want to use to wrap our workflow function.

advanced_composition/decorating_workflows.py
def setup_teardown(fn=None, *, before, after):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # get the current flyte context to obtain access to the compilation state of the workflow DAG.
        ctx = FlyteContextManager.current_context()

        # define the before node
        before_node = create_node(before)
        # ctx.compilation_state.nodes == [before_node]

        # under the hood, flytekit compiler defines and threads
        # together nodes within the `my_workflow` function body
        outputs = fn(*args, **kwargs)
        # ctx.compilation_state.nodes == [before_node, *nodes_created_by_fn]

        # define the after node
        after_node = create_node(after)
        # ctx.compilation_state.nodes == [before_node, *nodes_created_by_fn, after_node]

        # compile the workflow correctly by making sure `before_node`
        # runs before the first workflow node and `after_node`
        # runs after the last workflow node.
        if ctx.compilation_state is not None:
            # ctx.compilation_state.nodes is a list of nodes defined in the
            # order of execution above
            workflow_node0 = ctx.compilation_state.nodes[1]
            workflow_node1 = ctx.compilation_state.nodes[-2]
            before_node >> workflow_node0
            workflow_node1 >> after_node
        return outputs

    if fn is None:
        return partial(setup_teardown, before=before, after=after)

    return wrapper

There are a few key pieces to note in the setup_teardown decorator above:

  1. It takes a before and after argument, both of which need to be @task-decorated functions. These tasks will run before and after the main workflow function body.

  2. It uses the create_node function to create nodes associated with the before and after tasks.

  3. When fn is called, Flytekit creates all the nodes associated with the workflow function body under the hood.

  4. The code within the if ctx.compilation_state is not None: conditional is executed at compile time, which is where we extract the first and last nodes associated with the workflow function body at indices 1 and -2.

  5. The >> right shift operator ensures that before_node executes before the first node and after_node executes after the last node of the main workflow function body.
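One more detail worth calling out: the `fn=None` parameter combined with functools.partial is what allows the decorator to be applied with keyword arguments, as in `@setup_teardown(before=setup, after=teardown)`. Stripped of the Flyte-specific parts, the idiom looks like this (a generic sketch with made-up before/after callables, not part of the Flytekit API):

```python
from functools import partial, wraps


def with_hooks(fn=None, *, before, after):
    # When used as @with_hooks(before=..., after=...), the decorator is first
    # called with fn=None, so we return a partial that waits for the function.
    if fn is None:
        return partial(with_hooks, before=before, after=after)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        before()
        result = fn(*args, **kwargs)
        after()
        return result

    return wrapper


events = []


@with_hooks(before=lambda: events.append("setup"), after=lambda: events.append("teardown"))
def main():
    events.append("body")


main()
print(events)  # ['setup', 'body', 'teardown']
```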

Defining the DAG

We define two tasks that will constitute the workflow.

advanced_composition/decorating_workflows.py
@task
def t1(x: float) -> float:
    return x - 1


@task
def t2(x: float) -> float:
    return x**2

And then create our decorated workflow:

advanced_composition/decorating_workflows.py
@workflow
@setup_teardown(before=setup, after=teardown)
def decorating_workflow(x: float) -> float:
    return t2(x=t1(x=x))


# Run the example locally
if __name__ == "__main__":
    print(decorating_workflow(x=10.0))

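For reference, the workflow body simply composes the two tasks, so the computation for x = 10.0 can be checked with plain Python:

```python
# t1 subtracts 1, then t2 squares the result: (10.0 - 1) ** 2
x = 10.0
result = (x - 1) ** 2
print(result)  # 81.0
```

Running the example locally should therefore print 81.0, preceded by the print statements from the setup and teardown tasks.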
Run the example on the Flyte cluster

To run the provided workflow on the Flyte cluster, use the following command:

pyflyte run --remote \
  https://raw.githubusercontent.com/flyteorg/flytesnacks/master/examples/advanced_composition/advanced_composition/decorating_workflows.py \
  decorating_workflow --x 10.0

To define workflows imperatively, refer to this example, and to learn more about how to extend Flyte at a deeper level, for example creating custom types, custom tasks or backend plugins, see Extending Flyte.