Map Tasks

A map task lets you run a pod task or a regular task over a list of inputs within a single workflow node. This means you can run thousands of instances of the task without creating a node for every instance, providing valuable performance gains!

Some use cases of map tasks include:

  • Several inputs must run through the same code logic

  • Multiple data batches need to be processed in parallel

  • Hyperparameter optimization

Let’s look at an example now!

First, import the libraries.

from typing import List

from flytekit import Resources, map_task, task, workflow

Next, define a task to use in the map task.

Note

A map task can only accept one input and produce one output.

@task
def a_mappable_task(a: int) -> str:
    inc = a + 2
    stringified = str(inc)
    return stringified

Also define a task to reduce the mapped output to a string.

@task
def coalesce(b: List[str]) -> str:
    coalesced = "".join(b)
    return coalesced

We pass a_mappable_task to the map_task() function, which repeats it across a collection of inputs. In the example, a of type List[int] is the input, and a_mappable_task runs once for each element in the list.

with_overrides is useful for setting resources for an individual map task.

@workflow
def my_map_workflow(a: List[int]) -> str:
    mapped_out = map_task(a_mappable_task)(a=a).with_overrides(
        requests=Resources(mem="300Mi"),
        limits=Resources(mem="500Mi"),
        retries=1,
    )
    coalesced = coalesce(b=mapped_out)
    return coalesced

Lastly, we can run the workflow locally!

if __name__ == "__main__":
    result = my_map_workflow(a=[1, 2, 3, 4, 5])
    print(f"{result}")
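
As a sanity check on what the workflow above computes, here is a plain-Python sketch of the same map-and-reduce logic. It does not use flytekit at all, so it says nothing about how Flyte schedules the work. The function names mirror the tasks above, and local_map_workflow is a hypothetical stand-in for my_map_workflow.

```python
from typing import List


def a_mappable_task(a: int) -> str:
    # Same body as the Flyte task above, without the @task decorator.
    return str(a + 2)


def coalesce(b: List[str]) -> str:
    # Join the mapped outputs into a single string.
    return "".join(b)


def local_map_workflow(a: List[int]) -> str:
    # A list comprehension plays the role of map_task: one call per element.
    mapped_out = [a_mappable_task(a=x) for x in a]
    return coalesce(b=mapped_out)


result = local_map_workflow(a=[1, 2, 3, 4, 5])
print(result)  # 34567
```

Each input is incremented by 2 and stringified, so the joined result for this input list is the string 34567.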

When defining a map task, avoid calling other tasks in it. Flyte can’t accurately register tasks that call other tasks. While Flyte will correctly execute a task that calls other tasks, it will not be able to give full performance advantages. This is especially true for map tasks.

In this example, the map task suboptimal_mappable_task would not give you the best performance.

@task
def upperhalf(a: int) -> int:
    # Floor division keeps the result an int, matching the return annotation.
    return a // 2 + 1

@task
def suboptimal_mappable_task(a: int) -> str:
    inc = upperhalf(a=a)
    stringified = str(inc)
    return stringified
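
One way to avoid the nested task call, sketched here under the assumption that the helper logic does not need to be a task of its own, is to keep the helper as an ordinary Python function and call it from the task body. The sketch leaves out the @task decorator so that it runs without flytekit; in a Flyte project only the outer function would be decorated. The names upperhalf_helper and flat_mappable_task are hypothetical.

```python
def upperhalf_helper(a: int) -> int:
    # Plain function, not a task; floor division keeps the result an int.
    return a // 2 + 1


def flat_mappable_task(a: int) -> str:
    # In a Flyte project, this function alone would carry the @task decorator.
    inc = upperhalf_helper(a)
    return str(inc)


print(flat_mappable_task(a=4))  # 3
```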

By default, the map task uses the K8s Array plugin. Map tasks can also run on alternate execution backends such as AWS Batch, a provisioned service that can scale to very large workloads.

Map a Task with Multiple Inputs

You might need to map a task with multiple inputs.

For example, we have a task that takes 3 inputs.

@task
def multi_input_task(quantity: int, price: float, shipping: float) -> float:
    return quantity * price * shipping

But we only want to map this task over the quantity input while the other inputs stay the same. Since a map task accepts only one input, we can do this by creating a new task that prepares the map task’s inputs.

We start by putting the inputs in a dataclass decorated with dataclass_json. We also define a helper task that prepares the map task’s inputs.

from dataclasses import dataclass
from dataclasses_json import dataclass_json

@dataclass_json
@dataclass
class MapInput:
    quantity: float
    price: float
    shipping: float

@task
def prepare_map_inputs(list_q: List[float], p: float, s: float) -> List[MapInput]:
    return [MapInput(q, p, s) for q in list_q]

Then we refactor multi_input_task. Instead of 3 inputs, mappable_task has a single input.

@task
def mappable_task(input: MapInput) -> float:
    return input.quantity * input.price * input.shipping

Our workflow prepares a new list of inputs for the map task.

@workflow
def multiple_workflow(list_q: List[float], p: float, s: float) -> List[float]:
    prepared = prepare_map_inputs(list_q=list_q, p=p, s=s)
    return map_task(mappable_task)(input=prepared)

We can run our multi-input map task locally.

if __name__ == "__main__":
    result = multiple_workflow(list_q=[1.0, 2.0, 3.0, 4.0, 5.0], p=6.0, s=7.0)
    print(f"{result}")
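
To see what values the workflow above should return, here is a plain-Python sketch of the same data flow using only the standard library. It leaves out flytekit and dataclasses_json, so it illustrates the input-packing pattern rather than Flyte's execution. The name local_multiple_workflow is hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MapInput:
    quantity: float
    price: float
    shipping: float


def prepare_map_inputs(list_q: List[float], p: float, s: float) -> List[MapInput]:
    # Pair every quantity with the same price and shipping values.
    return [MapInput(q, p, s) for q in list_q]


def mappable_task(input: MapInput) -> float:
    return input.quantity * input.price * input.shipping


def local_multiple_workflow(list_q: List[float], p: float, s: float) -> List[float]:
    prepared = prepare_map_inputs(list_q=list_q, p=p, s=s)
    # One call per prepared element, as map_task would do.
    return [mappable_task(input=item) for item in prepared]


result = local_multiple_workflow(list_q=[1.0, 2.0, 3.0, 4.0, 5.0], p=6.0, s=7.0)
print(result)  # [42.0, 84.0, 126.0, 168.0, 210.0]
```

Only quantity varies across the prepared elements; price and shipping are copied into every MapInput, which is what lets a single-input task stand in for the three-input one.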
