Decorating Workflows

The behavior of workflows can be modified in a lightweight fashion by using the built-in functools.wraps decorator pattern, similar to using decorators to customize task behavior. However, unlike with tasks, we need to do a little extra work to make sure that the DAG underlying the workflow executes tasks in the correct order.
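
To see why functools.wraps matters, the plain-Python sketch below compares a decorator that uses it with one that does not. The helper names (plain_decorator, wrapping_decorator, square) are illustrative only. Flytekit's @workflow builds the workflow interface from the decorated function's signature and type annotations, so a wrapper that hides them would presumably break that step.

```python
import functools
import inspect
import typing


def plain_decorator(fn):
    # No functools.wraps: the wrapper hides fn's name, signature, and annotations.
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def wrapping_decorator(fn):
    # functools.wraps copies __name__, __doc__, __annotations__, etc. and sets
    # __wrapped__, so inspect.signature() can see through to fn's signature.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def square(x: float) -> float:
    """Return x squared."""
    return x ** 2


bare = plain_decorator(square)
wrapped = wrapping_decorator(square)

print(bare.__name__, inspect.signature(bare))        # wrapper (*args, **kwargs)
print(wrapped.__name__, inspect.signature(wrapped))  # square (x: float) -> float
print(typing.get_type_hints(wrapped))                # {'x': <class 'float'>, 'return': <class 'float'>}
```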

Setup-Teardown Pattern

The main use case for decorating @workflow-decorated functions is to establish a setup-teardown pattern that executes some task before and after your main workflow logic. This is useful when integrating with external services like wandb or clearml, which enable you to track metrics of model training runs.

from functools import partial, wraps

import flytekit
from flytekit import task, workflow, FlyteContextManager
from flytekit.core.node_creation import create_node

First, let’s define the tasks that we want for setup and teardown. In this example, we’ll use the unittest.mock.MagicMock class to create a fake external service that we want to initialize at the beginning of our workflow and finish at the end.

from unittest.mock import MagicMock

external_service = MagicMock()

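@task
def setup():
    print("initializing external service")
    external_service.initialize(id=flytekit.current_context().execution_id)

@task
def teardown():
    print("finish external service")
    external_service.complete(id=flytekit.current_context().execution_id)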

As you can see, you can even use Flytekit’s current context to access the execution_id of the current workflow if you need to link Flyte with the external service so that you reference the same unique identifier in both the external service and Flyte.

Workflow Decorator

Next, we create the decorator that we’ll use to wrap our workflow function.

def setup_teardown(fn=None, *, before, after):
    # when used as @setup_teardown(before=..., after=...), fn is None:
    # return a partial that Python then applies to the workflow function
    if fn is None:
        return partial(setup_teardown, before=before, after=after)

    # functools.wraps preserves fn's name, signature, and type annotations,
    # which @workflow needs in order to build the workflow interface
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # get the current flyte context to obtain access to the compilation state of the workflow DAG.
        ctx = FlyteContextManager.current_context()

        # defines before node
        before_node = create_node(before)
        # ctx.compilation_state.nodes == [before_node]

        # under the hood, flytekit compiler defines and threads
        # together nodes within the wrapped workflow function body
        outputs = fn(*args, **kwargs)
        # ctx.compilation_state.nodes == [before_node, *nodes_created_by_fn]

        # defines the after node
        after_node = create_node(after)
        # ctx.compilation_state.nodes == [before_node, *nodes_created_by_fn, after_node]

        # compile the workflow correctly by making sure `before_node`
        # runs before the first workflow node and `after_node`
        # runs after the last workflow node.
        if ctx.compilation_state is not None:
            # ctx.compilation_state.nodes is a list of nodes defined in the
            # order of execution above
            workflow_node0 = ctx.compilation_state.nodes[1]
            workflow_node1 = ctx.compilation_state.nodes[-2]
            before_node >> workflow_node0
            workflow_node1 >> after_node
        return outputs

    return wrapper
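
The fn=None check combined with functools.partial is what lets setup_teardown be applied with keyword arguments only. Python first evaluates setup_teardown(before=..., after=...) with no function, gets back a partial, and then calls that partial with the decorated function. A minimal stdlib-only sketch of the same pattern, using a hypothetical log_calls decorator:

```python
from functools import partial, wraps

calls = []  # records each call made through the decorator


def log_calls(fn=None, *, prefix):
    # Used as @log_calls(prefix=...): fn is None, so return a partial that
    # Python then calls with the decorated function as fn.
    if fn is None:
        return partial(log_calls, prefix=prefix)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        calls.append(f"{prefix}:{fn.__name__}")
        return fn(*args, **kwargs)

    return wrapper


@log_calls(prefix="demo")
def add(a: int, b: int) -> int:
    return a + b


print(add(2, 3))     # 5
print(calls)         # ['demo:add']
print(add.__name__)  # add
```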

There are a few key pieces to note in the setup_teardown decorator above:

  1. It takes before and after keyword arguments, both of which need to be @task-decorated functions. These tasks run before and after the main workflow function body.

  2. It uses the create_node function to create the nodes associated with the before and after tasks.

  3. When fn is called, Flytekit creates, under the hood, all the nodes associated with the workflow function body.

  4. The code within the if ctx.compilation_state is not None: conditional is executed at compile time. Because nodes are recorded in creation order, index 0 holds before_node and index -1 holds after_node, so the first and last nodes of the workflow function body sit at indices 1 and -2.

  5. Finally, we use the >> right shift operator to ensure that before_node executes before the first node and after_node executes after the last node of the main workflow function body.
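
To make the index arithmetic and the >> wiring concrete, here is a toy model in plain Python (3.9+ for graphlib). ToyNode and the nodes list are hypothetical stand-ins for flytekit's nodes and ctx.compilation_state.nodes, not its real classes. Without the two explicit edges, setup and teardown would have no dependencies at all, so a scheduler could run them in any order relative to the body.

```python
from graphlib import TopologicalSorter


class ToyNode:
    """Hypothetical stand-in for a workflow node; not flytekit's Node class."""

    def __init__(self, name, state):
        self.name = name
        self.upstream = set()  # names of nodes that must finish first
        state.append(self)     # record nodes in creation order

    def __rshift__(self, other):
        # a >> b means "b must run after a"
        other.upstream.add(self.name)
        return other


nodes = []  # plays the role of ctx.compilation_state.nodes
before_node = ToyNode("setup", nodes)
n1 = ToyNode("t1", nodes)
n2 = ToyNode("t2", nodes)
n2.upstream.add("t1")  # data dependency: t2 consumes t1's output
after_node = ToyNode("teardown", nodes)

# Wire setup to the first body node and teardown to the last one,
# mirroring nodes[1] and nodes[-2] in setup_teardown.
before_node >> nodes[1]
nodes[-2] >> after_node

graph = {n.name: n.upstream for n in nodes}
order = list(TopologicalSorter(graph).static_order())
print(order)  # ['setup', 't1', 't2', 'teardown']
```

Positional indexing only constrains the first-created and last-created body nodes; if a workflow body contains other independent nodes, these two edges do not order them relative to setup and teardown.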

Defining the DAG

Now let’s define two tasks that will constitute the workflow:

@task
def t1(x: float) -> float:
    return x - 1

@task
def t2(x: float) -> float:
    return x ** 2

And then create our decorated workflow:

@workflow
@setup_teardown(before=setup, after=teardown)
def wf(x: float) -> float:
    return t2(x=t1(x=x))

if __name__ == "__main__":
    print(wf(x=10.0))
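
Since wf computes t2(t1(x)) = (x - 1) ** 2, calling it with x=10.0 should return 81.0. The sketch below is a sequential plain-Python analogue, using hypothetical undecorated stand-ins and a hypothetical wf_local function, that records the execution order the DAG edges are meant to guarantee. It does not use Flyte.

```python
events = []  # execution trace


def setup() -> None:
    events.append("setup")


def teardown() -> None:
    events.append("teardown")


def t1(x: float) -> float:
    events.append("t1")
    return x - 1


def t2(x: float) -> float:
    events.append("t2")
    return x ** 2


def wf_local(x: float) -> float:
    # Sequential analogue of the compiled DAG: setup >> t1 >> t2 >> teardown
    setup()
    result = t2(x=t1(x=x))
    teardown()
    return result


print(wf_local(10.0))  # 81.0
print(events)          # ['setup', 't1', 't2', 'teardown']
```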

In this example, you learned how to modify the behavior of a workflow by defining a setup_teardown decorator that can be applied to any workflow in your project. This is useful when integrating with other external services like wandb or clearml, which enable you to track metrics of model training runs.

If you want to define workflows imperatively, check out the documentation on imperative workflows. To learn more about how to extend Flyte at a deeper level, for example by creating custom types, custom tasks, or backend plugins, see Extending Flyte.
