"""
==============================
:mod:`flytekit.core.base_task`
==============================
.. currentmodule:: flytekit.core.base_task
.. autosummary::
:nosignatures:
:template: custom.rst
:toctree: generated/
kwtypes
PythonTask
Task
TaskResolverMixin
IgnoreOutputs
"""
import asyncio
import collections
import datetime
import inspect
import warnings
from abc import abstractmethod
from base64 import b64encode
from dataclasses import dataclass
from typing import (
Any,
Coroutine,
Dict,
Generic,
List,
Optional,
OrderedDict,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from flyteidl.core import artifact_id_pb2 as art_id
from flyteidl.core import tasks_pb2
from flytekit.configuration import LocalConfig, SerializationSettings
from flytekit.core.artifact_utils import (
idl_partitions_from_dict,
idl_time_partition_from_datetime,
)
from flytekit.core.context_manager import (
ExecutionParameters,
ExecutionState,
FlyteContext,
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.local_cache import LocalTaskCache
from flytekit.core.promise import (
Promise,
VoidPromise,
create_and_link_node,
create_task_output,
extract_obj_name,
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.tracker import TrackedInstance
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError
from flytekit.core.utils import timeit
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import task as _task_model
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.documentation import Description, Documentation
from flytekit.models.interface import Variable
from flytekit.models.security import SecurityContext
# Key under which the base64-encoded, serialized ArtifactID (dynamic partitions and/or time partition)
# is stored in an output literal's metadata.
DYNAMIC_PARTITIONS = "_uap"
# NOTE(review): the three card keys below are not referenced in the visible part of this file;
# presumably they tag model-card / data-card metadata - confirm against their consumers.
MODEL_CARD = "_ucm"
DATA_CARD = "_ucd"
UNSET_CARD = "_uc"
def kwtypes(**kwargs) -> OrderedDict[str, Type]:
    """
    Convert keyword arguments into an ``OrderedDict`` that maps each argument name to the given type.

    .. code-block:: python

        kwtypes(a=int, b=str)

    Args:
        **kwargs: name/type pairs, e.g. ``a=int``.

    Returns:
        A ``collections.OrderedDict`` with one entry per keyword argument, in the order the arguments were passed.
    """
    # Keyword arguments already preserve call order, so the constructor can consume them directly.
    return collections.OrderedDict(kwargs)
class IgnoreOutputs(Exception):
    """
    Raise this exception to indicate that the outputs generated by the raising code can be safely ignored.

    This is useful in case of distributed training or peer-to-peer parallel algorithms.
    """
class Task(object):
    """
    The base of all Tasks in flytekit. This task is closest to the FlyteIDL TaskTemplate and captures information in
    FlyteIDL specification and does not have python native interfaces associated. Refer to the derived classes for
    examples of how to extend this class.

    This class does not itself declare an ABC metaclass, so the ``@abstractmethod`` markers below document which
    methods subclasses are expected to provide rather than enforcing it when ``Task`` is instantiated.
    """

    def __init__(
        self,
        task_type: str,
        name: str,
        interface: _interface_models.TypedInterface,
        metadata: Optional[TaskMetadata] = None,
        task_type_version: int = 0,
        security_ctx: Optional[SecurityContext] = None,
        docs: Optional[Documentation] = None,
        **kwargs,
    ):
        """
        Args:
            task_type (str): the task type identifier.
            name (str): the name of the task.
            interface: the Flyte typed interface (inputs/outputs) of the task.
            metadata (Optional[TaskMetadata]): task metadata; a default ``TaskMetadata()`` is used when not given.
            task_type_version (int): version of the task type, ``0`` by default.
            security_ctx (Optional[SecurityContext]): security context, may be ``None``.
            docs (Optional[Documentation]): documentation entity, may be ``None``.
            **kwargs: accepted but not used by this constructor.
        """
        self._task_type = task_type
        self._name = name
        self._interface = interface
        self._metadata = metadata if metadata else TaskMetadata()
        self._task_type_version = task_type_version
        self._security_ctx = security_ctx
        self._docs = docs

        # Every constructed task registers itself in the global entity list.
        FlyteEntities.entities.append(self)

    @property
    def interface(self) -> _interface_models.TypedInterface:
        return self._interface

    @property
    def metadata(self) -> TaskMetadata:
        return self._metadata

    @property
    def name(self) -> str:
        return self._name

    @property
    def task_type(self) -> str:
        return self._task_type

    @property
    def python_interface(self) -> Optional[Interface]:
        # The base task has no python native interface; subclasses may override.
        return None

    @property
    def task_type_version(self) -> int:
        return self._task_type_version

    @property
    def security_context(self) -> Optional[SecurityContext]:
        # ``None`` unless a security context was passed to the constructor.
        return self._security_ctx

    @property
    def docs(self) -> Optional[Documentation]:
        # ``None`` unless documentation was passed to the constructor (or set by a subclass).
        return self._docs

    def get_type_for_output_var(self, k: str, v: Any) -> type:
        """
        Returns the python native type for the given output variable. The base implementation simply returns
        ``type(v)``.
        # TODO we could use literal type to determine this
        """
        return type(v)

    def local_execute(
        self, ctx: FlyteContext, **kwargs
    ) -> Union[Tuple[Promise], Promise, VoidPromise, Coroutine, None]:
        """
        This function is used only in the local execution path and is responsible for calling dispatch execute.
        Use this function when calling a task with native values (or Promises containing Flyte literals derived from
        Python native values).

        Returns:
            * the coroutine itself, when the execution produced a coroutine,
            * a ``VoidPromise`` when the interface declares no outputs,
            * otherwise the result of ``create_task_output`` over one ``Promise`` per declared output.

        Raises:
            TypeError: if the inputs cannot be converted to literals.
            AssertionError: if the number of produced literals differs from the number of declared outputs.

        NOTE(review): ``get_input_types`` is called here but is not defined in the visible part of this class -
        presumably it is provided elsewhere; confirm.
        """
        # Unwrap the kwargs values. After this, we essentially have a LiteralMap
        # The reason why we need to do this is because the inputs during local execute can be of 2 types
        #  - Promises or native constants
        #  Promises as essentially inputs from previous task executions
        #  native constants are just bound to this specific task (default values for a task input)
        #  Also along with promises and constants, there could be dictionary or list of promises or constants
        try:
            kwargs = translate_inputs_to_literals(
                ctx,
                incoming_values=kwargs,
                flyte_interface_types=self.interface.inputs,
                native_types=self.get_input_types(),  # type: ignore
            )
        except TypeTransformerFailedError as exc:
            msg = f"Failed to convert inputs of task '{self.name}':\n {exc}"
            logger.error(msg)
            raise TypeError(msg) from exc
        input_literal_map = _literal_models.LiteralMap(literals=kwargs)

        # if metadata.cache is set, check memoized version
        local_config = LocalConfig.auto()
        if self.metadata.cache and local_config.cache_enabled:
            # TODO: how to get a nice `native_inputs` here?
            logger.info(
                f"Checking cache for task named {self.name}, cache version {self.metadata.cache_version} "
                f", inputs: {input_literal_map}, and ignore input vars: {self.metadata.cache_ignore_input_vars}"
            )
            if local_config.cache_overwrite:
                outputs_literal_map = None
                logger.info("Cache overwrite, task will be executed now")
            else:
                outputs_literal_map = LocalTaskCache.get(
                    self.name, self.metadata.cache_version, input_literal_map, self.metadata.cache_ignore_input_vars
                )
                # The cache returns None iff the key does not exist in the cache
                if outputs_literal_map is None:
                    logger.info("Cache miss, task will be executed now")
                else:
                    logger.info("Cache hit")
            if outputs_literal_map is None:
                outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)
                # TODO: need `native_inputs`
                LocalTaskCache.set(
                    self.name,
                    self.metadata.cache_version,
                    input_literal_map,
                    self.metadata.cache_ignore_input_vars,
                    outputs_literal_map,
                )
                logger.info(
                    f"Cache set for task named {self.name}, cache version {self.metadata.cache_version} "
                    f", inputs: {input_literal_map}, and ignore input vars: {self.metadata.cache_ignore_input_vars}"
                )
        else:
            # This code should mirror the call to `sandbox_execute` in the above cache case.
            # Code is simpler with duplication and less metaprogramming, but introduces regressions
            # if one is changed and not the other.
            outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)

        # A coroutine result is handed back untouched; the caller is responsible for awaiting it.
        if inspect.iscoroutine(outputs_literal_map):
            return outputs_literal_map

        outputs_literals = outputs_literal_map.literals

        # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
        # location, otherwise we dont really need to right? The higher level execute could just handle literalMap
        # After running, we again have to wrap the outputs, if any, back into Promise objects
        output_names = list(self.interface.outputs.keys())  # type: ignore
        if len(output_names) != len(outputs_literals):
            # Length check, clean up exception
            raise AssertionError(f"Length difference {len(output_names)} {len(outputs_literals)}")

        # Tasks that don't return anything still return a VoidPromise
        if len(output_names) == 0:
            return VoidPromise(self.name)

        vals = [Promise(var, outputs_literals[var]) for var in output_names]
        return create_task_output(vals, self.python_interface)

    def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
        # All call handling is delegated to the shared entity call handler.
        return flyte_entity_call_handler(self, *args, **kwargs)  # type: ignore

    def compile(self, ctx: FlyteContext, *args, **kwargs):
        """
        Not implemented on the base task; subclasses that can be compiled override this.

        Raises:
            NotImplementedError: always. ``NotImplementedError`` is a subclass of ``Exception``, so callers that
                catch ``Exception`` continue to work.
        """
        raise NotImplementedError("not implemented")

    def get_container(self, settings: SerializationSettings) -> Optional[_task_model.Container]:
        """
        Returns the container definition (if any) that is used to run the task on hosted Flyte.
        """
        return None

    def get_k8s_pod(self, settings: SerializationSettings) -> Optional[_task_model.K8sPod]:
        """
        Returns the kubernetes pod definition (if any) that is used to run the task on hosted Flyte.
        """
        return None

    def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]:
        """
        Returns the Sql definition (if any) that is used to run the task on hosted Flyte.
        """
        return None

    def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
        """
        Return additional plugin-specific custom data (if any) as a serializable dictionary.
        """
        return None

    def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
        """
        Returns the task config as a serializable dictionary. This task config consists of metadata about the custom
        defined for this task.
        """
        return None

    def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
        """
        Returns the extended resources to allocate to the task on hosted Flyte.
        """
        return None

    def local_execution_mode(self) -> ExecutionState.Mode:
        """
        Returns the execution mode associated with this entity for local runs, which for the base task is
        ``ExecutionState.Mode.LOCAL_TASK_EXECUTION``.
        """
        return ExecutionState.Mode.LOCAL_TASK_EXECUTION

    def sandbox_execute(
        self,
        ctx: FlyteContext,
        input_literal_map: _literal_models.LiteralMap,
    ) -> _literal_models.LiteralMap:
        """
        Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime.
        """
        es = cast(ExecutionState, ctx.execution_state)
        # Derive user-space params configured for a task sandbox and build a new context around them.
        b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox()
        ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
        return self.dispatch_execute(ctx, input_literal_map)

    @abstractmethod
    def dispatch_execute(
        self,
        ctx: FlyteContext,
        input_literal_map: _literal_models.LiteralMap,
    ) -> _literal_models.LiteralMap:
        """
        This method translates Flyte's Type system based input values and invokes the actual call to the executor
        This method is also invoked during runtime.
        """
        pass

    @abstractmethod
    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        """
        This is the method that will be invoked directly before executing the task method and before all the inputs
        are converted. One particular case where this is useful is if the context is to be modified for the user process
        to get some user space parameters. This also ensures that things like SparkSession are already correctly
        setup before the type transformers are called

        This should return either the same context or the mutated context
        """
        pass

    @abstractmethod
    def execute(self, **kwargs) -> Any:
        """
        This method will be invoked to execute the task.
        """
        pass
T = TypeVar("T")


class PythonTask(TrackedInstance, Task, Generic[T]):
    """
    Base Class for all Tasks with a Python native ``Interface``. This should be directly used for task types, that do
    not have a python function to be executed. Otherwise refer to :py:class:`flytekit.PythonFunctionTask`.
    """

    def __init__(
        self,
        task_type: str,
        name: str,
        task_config: Optional[T] = None,
        interface: Optional[Interface] = None,
        environment: Optional[Dict[str, str]] = None,
        disable_deck: Optional[bool] = None,
        enable_deck: Optional[bool] = None,
        **kwargs,
    ):
        """
        Args:
            task_type (str): defines a unique task-type for every new extension. If a backend plugin is required then
                this has to be done in-concert with the backend plugin identifier
            name (str): A unique name for the task instantiation. This is unique for every instance of task.
            task_config (T): Configuration for the task. This is used to configure the specific plugin that handles this
                task
            interface (Optional[Interface]): A python native typed interface ``(inputs) -> outputs`` that declares the
                signature of the task
            environment (Optional[Dict[str, str]]): Any environment variables that should be supplied during the
                execution of the task. Supplied as a dictionary of key/value pairs
            disable_deck (bool): (deprecated) If true, this task will not output deck html file
            enable_deck (bool): If true, this task will output deck html file

        Raises:
            ValueError: if both ``disable_deck`` and ``enable_deck`` are given.
        """
        super().__init__(
            task_type=task_type,
            name=name,
            interface=transform_interface_to_typed_interface(interface, allow_partial_artifact_id_binding=True),
            **kwargs,
        )
        self._python_interface = interface if interface else Interface()
        self._environment = environment if environment else {}
        self._task_config = task_config

        if disable_deck is not None:
            warnings.warn(
                "disable_deck was deprecated in 1.10.0, please use enable_deck instead",
                FutureWarning,
            )

        # Confirm that disable_deck and enable_deck do not contradict each other
        if disable_deck is not None and enable_deck is not None:
            raise ValueError("disable_deck and enable_deck cannot both be set at the same time")

        # enable_deck wins when given; otherwise fall back to the deprecated flag; decks are off by default.
        if enable_deck is not None:
            self._disable_deck = not enable_deck
        elif disable_deck is not None:
            self._disable_deck = disable_deck
        else:
            self._disable_deck = True

        # Fill in documentation from the python interface's docstring. Explicitly supplied docs are kept, but
        # any non-empty short/long description from the docstring overrides the corresponding field.
        if self._python_interface.docstring:
            if self.docs is None:
                self._docs = Documentation(
                    short_description=self._python_interface.docstring.short_description,
                    long_description=Description(value=self._python_interface.docstring.long_description),
                )
            else:
                if self._python_interface.docstring.short_description:
                    cast(
                        Documentation, self._docs
                    ).short_description = self._python_interface.docstring.short_description
                if self._python_interface.docstring.long_description:
                    cast(Documentation, self._docs).long_description = Description(
                        value=self._python_interface.docstring.long_description
                    )

    # TODO lets call this interface and the other as flyte_interface?
    @property
    def python_interface(self) -> Interface:
        """
        Returns this task's python interface.
        """
        return self._python_interface

    @property
    def task_config(self) -> Optional[T]:
        """
        Returns the user-specified task config which is used for plugin-specific handling of the task.
        """
        return self._task_config

    def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]:
        """
        Returns the python type for the specified output variable by name.
        """
        return self._python_interface.outputs[k]

    def compile(self, ctx: FlyteContext, *args, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]:
        """
        Generates a node that encapsulates this task in a workflow definition.
        """
        return create_and_link_node(ctx, entity=self, **kwargs)

    @property
    def _outputs_interface(self) -> Dict[Any, Variable]:
        return self.interface.outputs  # type: ignore

    def _literal_map_to_python_input(
        self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext
    ) -> Dict[str, Any]:
        """
        Converts the given literal map to python keyword arguments using the python interface's input types.
        """
        return TypeEngine.literal_map_to_kwargs(ctx, literal_map, self.python_interface.inputs)

    def _output_to_literal_map(
        self, native_outputs: Dict[int, Any], ctx: FlyteContext
    ) -> Tuple[_literal_models.LiteralMap, Dict[str, Any]]:
        """
        Converts the native outputs of ``execute`` into a ``LiteralMap`` keyed by the declared output names.

        Args:
            native_outputs: the value(s) produced by the task. With exactly one declared output the value is used
                as-is (or unwrapped from a one-element tuple when the interface has an output tuple name); with
                several declared outputs it is indexed positionally.
            ctx: the context used for the type conversion and for looking up output metadata.

        Returns:
            A tuple ``(literal_map, native_outputs_as_map)`` where the second element maps output names to the
            native values.

        Raises:
            TypeError: if an individual output value is a tuple, or if its conversion to a literal fails.
        """
        expected_output_names = list(self._outputs_interface.keys())
        if len(expected_output_names) == 1:
            # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
            # length one. That convention is used for naming outputs - and single-length-NamedTuples are
            # particularly troublesome, but elegant handling of them is not a high priority
            # Again, we're using the output_tuple_name as a proxy.
            if self.python_interface.output_tuple_name and isinstance(native_outputs, tuple):
                native_outputs_as_map = {expected_output_names[0]: native_outputs[0]}
            else:
                native_outputs_as_map = {expected_output_names[0]: native_outputs}
        elif len(expected_output_names) == 0:
            native_outputs_as_map = {}
        else:
            native_outputs_as_map = {expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)}

        # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
        # built into the IDL that all the values of a literal map are of the same type.
        with timeit("Translate the output to literals"):
            literals = {}
            omt = ctx.output_metadata_tracker
            for i, (k, v) in enumerate(native_outputs_as_map.items()):
                literal_type = self._outputs_interface[k].type
                py_type = self.get_type_for_output_var(k, v)

                if isinstance(v, tuple):
                    raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}")
                try:
                    lit = TypeEngine.to_literal(ctx, v, py_type, literal_type)
                    literals[k] = lit
                except Exception as e:
                    # only show the name of output key if it's user-defined (by default Flyte names these as "o<n>")
                    key = k if k != f"o{i}" else i
                    msg = f"Failed to convert outputs of task '{self.name}' at position {key}:\n {e}"
                    logger.error(msg)
                    raise TypeError(msg) from e
                # Now check if there is any output metadata associated with this output variable and attach it to the
                # literal
                if omt is not None:
                    om = omt.get(v)
                    if om:
                        metadata = {}
                        if om.additional_items:
                            for ii in om.additional_items:
                                md_key, md_val = ii.serialize_to_string(ctx, k)
                                metadata[md_key] = md_val
                            logger.info(f"Adding {om.additional_items} additional metadata items {metadata} for {k}")
                        if om.dynamic_partitions or om.time_partition:
                            # The partition information travels as a base64-encoded, serialized ArtifactID.
                            a = art_id.ArtifactID(
                                partitions=idl_partitions_from_dict(om.dynamic_partitions),
                                time_partition=idl_time_partition_from_datetime(
                                    om.time_partition, om.artifact.time_partition_granularity
                                ),
                            )
                            s = a.SerializeToString()
                            encoded = b64encode(s).decode("utf-8")
                            metadata[DYNAMIC_PARTITIONS] = encoded
                        if metadata:
                            lit.set_metadata(metadata)

        return _literal_models.LiteralMap(literals=literals), native_outputs_as_map

    def _write_decks(self, native_inputs, native_outputs_as_map, ctx, new_user_params):
        """
        When decks are enabled, renders every input into an "Inputs" deck and every output into an "Outputs"
        deck; for local executions the decks are additionally written out right away.

        NOTE(review): ``get_type_for_input_var`` is not defined in the visible part of this file - presumably it
        is provided elsewhere on the class; confirm.
        """
        if self._disable_deck is False:
            from flytekit.deck.deck import Deck, _output_deck

            INPUT = "Inputs"
            OUTPUT = "Outputs"

            input_deck = Deck(INPUT)
            for k, v in native_inputs.items():
                input_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_input_var(k, v)))

            output_deck = Deck(OUTPUT)
            for k, v in native_outputs_as_map.items():
                output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v)))

            if ctx.execution_state and ctx.execution_state.is_local_execution():
                # When we run the workflow remotely, flytekit outputs decks at the end of _dispatch_execute
                _output_deck(self.name.split(".")[-1], new_user_params)

    async def _async_execute(self, native_inputs, native_outputs, ctx, exec_ctx, new_user_params):
        """
        Awaits the coroutine returned by ``execute`` and then runs the same post-processing as the synchronous
        path: ``post_execute``, conversion to a literal map, and deck writing. Returns the literal map.
        """
        native_outputs = await native_outputs
        native_outputs = self.post_execute(new_user_params, native_outputs)
        literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
        self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
        return literals_map

    @staticmethod
    def _exception_with_message(exc: Exception, msg: str) -> Exception:
        """
        Builds an exception of the same type as ``exc`` carrying ``msg``.

        Some exception classes cannot be constructed from a single string (for example ``UnicodeDecodeError``
        requires five arguments). For those, a ``TypeError`` carrying ``msg`` is returned instead, which is the
        exception type the other conversion-failure paths in this module raise.
        """
        try:
            return type(exc)(msg)
        except Exception:
            return TypeError(msg)

    def dispatch_execute(
        self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
    ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec, Coroutine]:
        """
        This method translates Flyte's Type system based input values and invokes the actual call to the executor
        This method is also invoked during runtime.

        * ``Literal Map`` is returned when the native outputs were converted to literals, or when ``post_execute``
          already returned a literal map
        * ``DynamicJobSpec`` is returned when ``post_execute`` returns one (a dynamic workflow)
        * a coroutine is returned when ``execute`` produced a coroutine, unless the execution mode is
          ``LOCAL_WORKFLOW_EXECUTION``, in which case the coroutine is run to completion and its literal map returned

        Raises:
            Exception: if the inputs cannot be converted to python native values. The raised exception has the same
                type as the underlying failure whenever that type can be constructed from a message, and is a
                ``TypeError`` otherwise; the underlying failure is chained as the cause.
        """
        # Invoked before the task is executed
        new_user_params = self.pre_execute(ctx.user_space_params)

        # Create another execution context with the new user params, but let's keep the same working dir
        with FlyteContextManager.with_context(
            ctx.with_execution_state(
                cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params)
            )
            # type: ignore
        ) as exec_ctx:
            # TODO We could support default values here too - but not part of the plan right now
            # Translate the input literals to Python native
            try:
                native_inputs = self._literal_map_to_python_input(input_literal_map, exec_ctx)
            except Exception as exc:
                msg = f"Failed to convert inputs of task '{self.name}':\n {exc}"
                logger.error(msg)
                # Re-raise with the enriched message; guarded so that an exception type with a non-standard
                # constructor cannot mask the conversion error with an unrelated failure.
                raise self._exception_with_message(exc, msg) from exc

            # TODO: Logger should auto inject the current context information to indicate if the task is running within
            #   a workflow or a subworkflow etc
            logger.info(f"Invoking {self.name} with inputs: {native_inputs}")
            with timeit("Execute user level code"):
                native_outputs = self.execute(**native_inputs)

            if inspect.iscoroutine(native_outputs):
                # If native outputs is a coroutine, then this is an eager workflow.
                if exec_ctx.execution_state:
                    if exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
                        # Just return task outputs as a coroutine if the eager workflow is being executed locally,
                        # outside of a workflow. This preserves the expectation that the eager workflow is an async
                        # function.
                        return native_outputs
                    elif exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
                        # If executed inside of a workflow being executed locally, then run the coroutine to get the
                        # actual results.
                        return asyncio.run(
                            self._async_execute(
                                native_inputs,
                                native_outputs,
                                ctx,
                                exec_ctx,
                                new_user_params,
                            )
                        )

                return self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params)

            logger.debug("Task executed successfully in user level")
            # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is
            # bubbled up to be handled at the callee layer.
            native_outputs = self.post_execute(new_user_params, native_outputs)

            # Short circuit the translation to literal map because what's returned may be a dj spec (or an
            # already-constructed LiteralMap if the dynamic task was a no-op), not python native values
            # dynamic_execute returns a literal map in local execute so this also gets triggered.
            if isinstance(
                native_outputs,
                (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec),
            ):
                return native_outputs

            literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
            self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
            # After the execute has been successfully completed
            return literals_map

    def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]:  # type: ignore
        """
        This is the method that will be invoked directly before executing the task method and before all the inputs
        are converted. One particular case where this is useful is if the context is to be modified for the user process
        to get some user space parameters. This also ensures that things like SparkSession are already correctly
        setup before the type transformers are called

        This should return either the same context or the mutated context. The default implementation returns
        ``user_params`` unchanged.
        """
        return user_params

    @abstractmethod
    def execute(self, **kwargs) -> Any:
        """
        This method will be invoked to execute the task.
        """
        pass

    def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any:
        """
        Post execute is called after the execution has completed, with the user_params and can be used to clean-up,
        or alter the outputs to match the intended tasks outputs. If not overridden, then this function is a No-op

        Args:
            rval is returned value from call to execute
            user_params: are the modified user params as created during the pre_execute step
        """
        return rval

    @property
    def environment(self) -> Dict[str, str]:
        """
        Any environment variables that supplied during the execution of the task.
        """
        return self._environment

    @property
    def disable_deck(self) -> bool:
        """
        If true, this task will not output deck html file
        """
        return self._disable_deck
class TaskResolverMixin(object):
    """
    Flytekit tasks interact with the Flyte platform very, very broadly in two steps. They need to be uploaded to Admin,
    and then they are run by the user upon request (either as a single task execution or as part of a workflow). In any
    case, at execution time, for most tasks (that is those that generate a container target) the container image
    containing the task needs to be spun up again at which point the container needs to know which task it's supposed
    to run and how to rehydrate the task object.

    For example, the serialization of a simple task ::

        # in repo_root/workflows/example.py
        @task
        def t1(...) -> ...: ...

    might result in a container with arguments like ::

        pyflyte-execute --inputs s3://path/inputs.pb --output-prefix s3://outputs/location \
        --raw-output-data-prefix /tmp/data \
        --resolver flytekit.core.python_auto_container.default_task_resolver \
        -- \
        task-module repo_root.workflows.example task-name t1

    At serialization time, the container created for the task will start out automatically with the ``pyflyte-execute``
    bit, along with the requisite input/output args and the offloaded data prefix. Appended to that will be two things,

    #. the ``location`` of the task's task resolver, followed by two dashes, followed by
    #. the arguments provided by calling the ``loader_args`` function below.

    The ``default_task_resolver`` declared below knows that

    * When ``loader_args`` is called on a task, to look up the module the task is in, and the name of the task (the
      key of the task in the module, either the function name, or the variable it was assigned to).
    * When ``load_task`` is called, it interprets the first part of the command as the module to call
      ``importlib.import_module`` on, and then looks for a key ``t1``.

    This is just the default behavior. Users should feel free to implement their own resolvers.
    """

    @property
    @abstractmethod
    def location(self) -> str:
        """
        The location of this task resolver, as appended to the container arguments at serialization time.
        """

    @abstractmethod
    def name(self) -> str:
        """
        The name of this resolver.
        """

    @abstractmethod
    def load_task(self, loader_args: List[str]) -> Task:
        """
        Given the set of identifier keys, should return one Python Task or raise an error if not found
        """

    @abstractmethod
    def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]:
        """
        Return a list of strings that can help identify the parameter Task
        """

    @abstractmethod
    def get_all_tasks(self) -> List[Task]:
        """
        Future proof method. Just making it easy to access all tasks (Not required today as we auto register them)
        """

    def task_name(self, t: Task) -> Optional[str]:
        """
        Overridable function that can optionally return a custom name for a given task. The default
        implementation returns ``None``.
        """
        return None