"""
Flytekit map tasks specify how to run a single task across a list of inputs. Map tasks themselves are constructed with
a reference task as well as run-time parameters that limit execution concurrency and failure tolerations.
"""
import functools
import hashlib
import logging
import math
import os
import typing
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set
from flytekit.configuration import SerializationSettings
from flytekit.core import tracker
from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin
from flytekit.core.constants import CONTAINER_ARRAY_TASK
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.tracker import TrackedInstance
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
from flytekit.loggers import logger
from flytekit.models.array_job import ArrayJob
from flytekit.models.interface import Variable
from flytekit.models.task import Container, K8sPod, Sql
from flytekit.tools.module_loader import load_object_from_module
class MapPythonTask(PythonTask):
    """
    A MapPythonTask defines a :py:class:`flytekit.PythonTask` which specifies how to run
    an inner :py:class:`flytekit.PythonFunctionTask` across a range of inputs in parallel.
    """

    def __init__(
        self,
        python_function_task: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
        concurrency: Optional[int] = None,
        min_success_ratio: Optional[float] = None,
        bound_inputs: Optional[Set[str]] = None,
        **kwargs,
    ):
        """
        Wrapper that creates a MapPythonTask

        :param python_function_task: This argument is implicitly passed and represents the repeatable function
        :param concurrency: If specified, this limits the number of mapped tasks that can run in parallel to the given
            batch size
        :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete
            successfully before terminating this task and marking it successful
        :param bound_inputs: List[str] specifies a list of variable names within the interface of python_function_task,
            that are already bound and should not be considered as list inputs, but scalar values. This is mostly
            useful at runtime and is passed in by MapTaskResolver. This field is not required when a `partial` method
            is specified. The bound_vars will be auto-deduced from the `partial.keywords`.
        """
        self._partial = None
        if isinstance(python_function_task, functools.partial):
            # TODO: We should be able to support partial tasks with lists as inputs
            for arg in python_function_task.keywords.values():
                if isinstance(arg, list):
                    raise ValueError("Map tasks do not support partial tasks with lists as inputs. ")
            self._partial = python_function_task
            actual_task = self._partial.func
        else:
            actual_task = python_function_task

        if not isinstance(actual_task, (PythonFunctionTask, PythonInstanceTask)):
            raise ValueError("Map tasks can only compose of PythonFunction and PythonInstanceTasks currently")

        n_outputs = len(actual_task.python_interface.outputs.keys())
        if n_outputs > 1:
            raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs")

        self._bound_inputs: typing.Set[str] = set(bound_inputs) if bound_inputs else set()
        if self._partial:
            # Keywords pre-bound through functools.partial are scalars, never mapped collections.
            self._bound_inputs = set(self._partial.keywords.keys())

        # Transform the interface to List[Optional[T]] in case `min_success_ratio` is set
        output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1
        collection_interface = transform_interface_to_list_interface(
            actual_task.python_interface, self._bound_inputs, output_as_list_of_optionals
        )

        self._run_task: typing.Union[PythonFunctionTask, PythonInstanceTask] = actual_task  # type: ignore
        if isinstance(actual_task, PythonInstanceTask):
            mod = actual_task.task_type
            f = actual_task.lhs
        else:
            _, mod, f, _ = tracker.extract_task_module(typing.cast(PythonFunctionTask, actual_task).task_function)
        # Hash the sorted bound inputs into the name so that the same subtask mapped with different
        # bound variables yields distinct map task identities.
        sorted_bounded_inputs = ",".join(sorted(self._bound_inputs))
        h = hashlib.md5(sorted_bounded_inputs.encode("utf-8")).hexdigest()
        name = f"{mod}.map_{f}_{h}"

        self._cmd_prefix: typing.Optional[typing.List[str]] = None
        self._max_concurrency: typing.Optional[int] = concurrency
        self._min_success_ratio: typing.Optional[float] = min_success_ratio
        self._array_task_interface = actual_task.python_interface
        # Inherit metadata/security context from the wrapped task unless explicitly overridden.
        if "metadata" not in kwargs and actual_task.metadata:
            kwargs["metadata"] = actual_task.metadata
        if "security_ctx" not in kwargs and actual_task.security_context:
            kwargs["security_ctx"] = actual_task.security_context
        super().__init__(
            name=name,
            interface=collection_interface,
            task_type=CONTAINER_ARRAY_TASK,
            task_config=None,
            task_type_version=1,
            **kwargs,
        )

    @property
    def bound_inputs(self) -> Set[str]:
        """Names of inputs that are bound to scalar values and therefore not mapped over."""
        return self._bound_inputs

    def get_command(self, settings: SerializationSettings) -> List[str]:
        """
        Returns the container command used to execute one instance of this map task at runtime.

        TODO ADD bound variables to the resolver. Maybe we need a different resolver?
        """
        mt = MapTaskResolver()
        container_args = [
            "pyflyte-map-execute",
            "--inputs",
            "{{.input}}",
            "--output-prefix",
            "{{.outputPrefix}}",
            "--raw-output-data-prefix",
            "{{.rawOutputDataPrefix}}",
            "--checkpoint-path",
            "{{.checkpointOutputPrefix}}",
            "--prev-checkpoint",
            "{{.prevCheckpointPrefix}}",
            "--resolver",
            mt.name(),
            "--",
            *mt.loader_args(settings, self),
        ]
        if self._cmd_prefix:
            return self._cmd_prefix + container_args
        return container_args

    def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]):
        """Sets an optional list of arguments to prepend to the generated container command."""
        self._cmd_prefix = cmd

    @contextmanager
    def prepare_target(self):
        """
        TODO: why do we do this?
        Alters the underlying run_task command to modify it for map task execution and then resets it after.
        """
        self._run_task.set_command_fn(self.get_command)
        try:
            yield
        finally:
            self._run_task.reset_command_fn()

    def get_container(self, settings: SerializationSettings) -> Container:
        """Serializes the underlying task's container with the map-task command swapped in."""
        with self.prepare_target():
            return self._run_task.get_container(settings)

    def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod:
        """Serializes the underlying task's pod spec with the map-task command swapped in."""
        with self.prepare_target():
            return self._run_task.get_k8s_pod(settings)

    def get_sql(self, settings: SerializationSettings) -> Sql:
        """Serializes the underlying task's SQL spec with the map-task command swapped in."""
        with self.prepare_target():
            return self._run_task.get_sql(settings)

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Returns the ArrayJob parameters (parallelism, min success ratio) as the task's custom config."""
        return ArrayJob(parallelism=self._max_concurrency, min_success_ratio=self._min_success_ratio).to_dict()

    def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
        """Delegates config resolution to the underlying run task."""
        return self._run_task.get_config(settings)

    @property
    def run_task(self) -> typing.Union[PythonFunctionTask, PythonInstanceTask]:
        """The wrapped task that is executed for each element of the mapped input."""
        return self._run_task

    def __call__(self, *args, **kwargs):
        """
        This call method modifies the kwargs and adds kwargs from partial.
        This is mostly done in the local_execute and compilation only.
        At runtime, the map_task is created with all the inputs filled in. to support this, we have modified
        the map_task interface in the constructor.
        """
        if self._partial:
            """If partial exists, then mix-in all partial values"""
            kwargs = {**self._partial.keywords, **kwargs}
        return super().__call__(*args, **kwargs)

    def execute(self, **kwargs) -> Any:
        """Dispatches to single-instance execution on the platform, or full local map execution otherwise."""
        ctx = FlyteContextManager.current_context()
        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            return self._execute_map_task(ctx, **kwargs)
        return self._raw_execute(**kwargs)

    @staticmethod
    def _compute_array_job_index() -> int:
        """
        Computes the absolute index of the current array job. This is determined by summing the compute-environment-specific
        environment variable and the offset (if one's set). The offset will be set and used when the user request that the
        job runs in a number of slots less than the size of the input.
        """
        return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", "0")) + int(
            os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "0"), "0")
        )

    @property
    def _outputs_interface(self) -> Dict[Any, Variable]:
        """
        We override this method from PythonTask because the dispatch_execute method uses this
        interface to construct outputs. Each instance of an container_array task will however produce outputs
        according to the underlying run_task interface and the array plugin handler will actually create a collection
        from these individual outputs as the final output value.
        """
        ctx = FlyteContextManager.current_context()
        if ctx.execution_state and ctx.execution_state.is_local_execution():
            # In workflow execution mode we actually need to use the parent (mapper) task output interface.
            return self.interface.outputs
        return self._run_task.interface.outputs

    def get_type_for_output_var(self, k: str, v: Any) -> type:
        """
        We override this method from flytekit.core.base_task Task because the dispatch_execute method uses this
        interface to construct outputs. Each instance of an container_array task will however produce outputs
        according to the underlying run_task interface and the array plugin handler will actually create a collection
        from these individual outputs as the final output value.
        """
        ctx = FlyteContextManager.current_context()
        if ctx.execution_state and ctx.execution_state.is_local_execution():
            # In workflow execution mode we actually need to use the parent (mapper) task output interface.
            return self._python_interface.outputs[k]
        return self._run_task._python_interface.outputs[k]

    def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any:
        """
        This is called during ExecutionState.Mode.TASK_EXECUTION executions, that is executions orchestrated by the
        Flyte platform. Individual instances of the map task, aka array task jobs are passed the full set of inputs but
        only produce a single output based on the map task (array task) instance. The array plugin handler will actually
        create a collection from these individual outputs as the final map task output value.
        """
        task_index = self._compute_array_job_index()
        map_task_inputs = {}
        for k in self.interface.inputs.keys():
            v = kwargs[k]
            # Mapped (list) inputs are indexed by this instance's position; bound/scalar inputs pass through.
            if isinstance(v, list) and k not in self.bound_inputs:
                map_task_inputs[k] = v[task_index]
            else:
                map_task_inputs[k] = v
        return exception_scopes.user_entry_point(self._run_task.execute)(**map_task_inputs)

    def _raw_execute(self, **kwargs) -> Any:
        """
        This is called during locally run executions. Unlike array task execution on the Flyte platform, _raw_execute
        produces the full output collection.
        """
        outputs_expected = True
        if not self.interface.outputs:
            outputs_expected = False
        outputs = []

        # Determine how many instances to run from the first unbound list input.
        mapped_tasks_count = 0
        if self._run_task.interface.inputs.items():
            for k in self._run_task.interface.inputs.keys():
                v = kwargs[k]
                if isinstance(v, list) and k not in self.bound_inputs:
                    mapped_tasks_count = len(v)
                    break

        failed_count = 0
        min_successes = mapped_tasks_count
        if self._min_success_ratio:
            min_successes = math.ceil(min_successes * self._min_success_ratio)

        for i in range(mapped_tasks_count):
            single_instance_inputs = {}
            for k in self.interface.inputs.keys():
                v = kwargs[k]
                if isinstance(v, list) and k not in self.bound_inputs:
                    single_instance_inputs[k] = kwargs[k][i]
                else:
                    single_instance_inputs[k] = kwargs[k]
            try:
                o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)
                if outputs_expected:
                    outputs.append(o)
            except Exception as exc:
                # Record a None placeholder for the failed instance; only fail the whole map
                # once the remaining possible successes drop below the required minimum.
                outputs.append(None)
                failed_count += 1
                if mapped_tasks_count - failed_count < min_successes:
                    logger.error("The number of successful tasks is lower than the minimum ratio")
                    raise exc

        return outputs
def map_task(
    task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
    concurrency: int = 0,
    min_success_ratio: float = 1.0,
    **kwargs,
):
    """
    Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of
    any individual :py:class:`flytekit.PythonFunctionTask`.

    Invoke a map task with arguments using the :py:class:`list` version of the expected input.

    Usage:

    .. literalinclude:: ../../../tests/flytekit/unit/core/test_map_task.py
       :start-after: # test_map_task_start
       :end-before: # test_map_task_end
       :language: python
       :dedent: 4

    At run time, the underlying map task will be run for every value in the input collection. Attributes
    such as :py:class:`flytekit.TaskMetadata` and ``with_overrides`` are applied to individual instances
    of the mapped task.

    **Map Task Plugins**

    There are two plugins to run maptasks that ship as part of flyteplugins:

    1. K8s Array
    2. `AWS batch <https://docs.flyte.org/en/latest/deployment/plugin_setup/aws/batch.html>`_

    Enabling a plugin is controlled in the plugin configuration at `values-sandbox.yaml <https://github.com/flyteorg/flyte/blob/10cee9f139824512b6c5be1667d321bdbc8835fa/charts/flyte/values-sandbox.yaml#L152-L162>`_.

    **K8s Array**

    By default, the map task uses the ``K8s Array`` plugin. It executes array tasks by launching a pod for every instance in the array. It’s simple to use, has a straightforward implementation, and works out of the box.

    **AWS batch**

    Learn more about ``AWS batch`` setup configuration `here <https://docs.flyte.org/en/latest/deployment/plugin_setup/aws/batch.html#deployment-plugin-setup-aws-array>`_.

    A custom plugin can also be implemented to handle the task type.

    :param task_function: This argument is implicitly passed and represents the repeatable function
    :param concurrency: If specified, this limits the number of mapped tasks that can run in parallel to the given batch
        size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
        all inputs are processed. If left unspecified, this means unbounded concurrency.
    :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete
        successfully before terminating this task and marking it successful.
    """
    return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs)
class MapTaskResolver(TrackedInstance, TaskResolverMixin):
    """
    Special resolver that is used for MapTasks.
    This exists because it is possible that MapTasks are created using nested "partial" subtasks.
    When a maptask is created its interface is interpolated from the interface of the subtask - the interpolation,
    simply converts every input into a list/collection input.

    For example:
      interface -> (i: int, j: str) -> str  =>  map_task interface -> (i: List[int], j: List[str]) -> List[str]

    But in cases in which `j` is bound to a fixed value by using `functools.partial` we need a way to ensure that
    the interface is not simply interpolated, but only the unbound inputs are interpolated.

    .. code-block:: python

        def foo((i: int, j: str) -> str:
            ...

        mt = map_task(functools.partial(foo, j=10))

        print(mt.interface)

    output:

        (i: List[int], j: str) -> List[str]

    But, at runtime this information is lost. To reconstruct this, we use MapTaskResolver that records the "bound vars"
    and then at runtime reconstructs the interface with this knowledge
    """

    def name(self) -> str:
        """Unique name of this resolver, referenced by the generated ``pyflyte-map-execute`` command."""
        return "MapTaskResolver"

    @timeit("Load map task")
    def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> MapPythonTask:
        """
        Reconstructs a MapPythonTask at runtime by delegating to the inner task's resolver.

        Loader args should be of the form
        vars "var1,var2,.." resolver "resolver" [resolver_args]
        """
        _, bound_vars, _, resolver, *resolver_args = loader_args
        logger.info(f"MapTask found task resolver {resolver} and arguments {resolver_args}")
        resolver_obj = load_object_from_module(resolver)
        # Use the resolver to load the actual task object
        _task_def = resolver_obj.load_task(loader_args=resolver_args)
        bound_inputs = set(bound_vars.split(","))
        # NOTE: MapPythonTask's parameter is named `concurrency`; passing `max_concurrency` as a raw
        # keyword would be silently swallowed by **kwargs and the setting lost.
        return MapPythonTask(python_function_task=_task_def, concurrency=max_concurrency, bound_inputs=bound_inputs)

    def loader_args(self, settings: SerializationSettings, t: MapPythonTask) -> List[str]:  # type:ignore
        """Inverse of :py:meth:`load_task` — serializes the bound vars and the inner task's resolver args."""
        return [
            "vars",
            f'{",".join(sorted(t.bound_inputs))}',
            "resolver",
            t.run_task.task_resolver.location,
            *t.run_task.task_resolver.loader_args(settings, t.run_task),
        ]

    def get_all_tasks(self) -> List[Task]:
        """Map tasks are synthesized per use-site, so enumerating them is not supported."""
        raise NotImplementedError("MapTask resolver cannot return every instance of the map task")