from __future__ import annotations
import collections
import copy
import inspect
import typing
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast
from flyteidl.core import artifact_id_pb2 as art_id
from typing_extensions import get_args, get_type_hints
from flytekit.core import context_manager
from flytekit.core.artifact import Artifact, ArtifactIDSpecification, ArtifactQuery
from flytekit.core.docstring import Docstring
from flytekit.core.sentinel import DYNAMIC_INPUT_BINDING
from flytekit.core.type_engine import TypeEngine, UnionTransformer
from flytekit.exceptions.user import FlyteValidationException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models.literals import Literal, Scalar, Void
T = typing.TypeVar("T")
def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str:
    """
    Render a single interface entry as ``name: type`` or ``name: type=default``.

    :param k: Name of the input or output variable.
    :param v: Either a bare type, or a ``(type, default)`` tuple.
    :return: Human readable representation of the entry.
    """
    if isinstance(v, tuple):
        # ``None`` is the marker for "no default". Compare against it explicitly so that
        # legitimate falsy defaults (0, False, "", ...) are still rendered.
        if v[1] is not None:
            return f"{k}: {v[0]}={v[1]}"
        return f"{k}: {v[0]}"
    return f"{k}: {v}"
def repr_type_signature(io: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]]) -> str:
    """
    Converts a mapping of inputs or outputs to a type signature string.

    Every entry is rendered with ``repr_kv``; the rendered entries are joined with ``", "`` and
    wrapped in parentheses, so an empty mapping yields ``"()"``.

    :param io: Map of variable name to a type, or to a ``(type, default)`` tuple.
    :return: The parenthesized signature string.
    """
    rendered_entries = [repr_kv(name, spec) for name, spec in io.items()]
    return "(" + ", ".join(rendered_entries) + ")"
class Interface(object):
    """
    A Python native interface object, like inspect.signature but simpler.

    Inputs are stored as ``name -> (type, default)`` tuples, outputs as ``name -> type``.
    """

    def __init__(
        self,
        inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None,
        outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None,
        output_tuple_name: Optional[str] = None,
        docstring: Optional[Docstring] = None,
    ):
        """
        :param outputs: Output variables and their types as a dictionary
        :param inputs: Map of input name to either a tuple where the first element is the python type, and the second
            value is the default, or just a single value which is the python type. The latter case is used by tasks
            for which perhaps a default value does not make sense. For consistency, we turn it into a tuple.
        :param output_tuple_name: This is used to store the name of a typing.NamedTuple when the task or workflow
            returns one. This is also used as a proxy for better or for worse for the presence of a tuple return type,
            primarily used when handling one-element NamedTuples.
        :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from.
        """
        self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {}  # type: ignore
        if inputs:
            for k, v in inputs.items():
                # Only tuples with at least two elements are treated as (type, default); anything else is
                # considered to be the type itself and gets a ``None`` default.
                if type(v) is tuple and len(cast(Tuple, v)) > 1:
                    self._inputs[k] = v  # type: ignore
                else:
                    self._inputs[k] = (v, None)  # type: ignore
        self._outputs = outputs if outputs else {}  # type: ignore
        self._output_tuple_name = output_tuple_name
        self._docstring = docstring
        # Always define the attribute so that ``output_tuple`` is safe to read on interfaces without outputs
        # (it used to raise AttributeError in that case).
        self._output_tuple_class: Optional[Type] = None

        if outputs:
            variables = list(outputs.keys())

            # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one.
            class Output(  # type: ignore
                collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)  # type: ignore
            ):  # type: ignore
                """
                This class can be used in two different places. For multivariate-return entities this class is used
                to rewrap the outputs so that our with_overrides function can work.
                For manual node creation, it's used during local execution as something that can be dereferenced.
                See the create_node function for more information.
                """

                def with_overrides(self, *args, **kwargs):
                    # Overrides are forwarded to the first field only; the tuple itself is returned for chaining.
                    val = getattr(self, self._fields[0])
                    val.with_overrides(*args, **kwargs)
                    return self

                @property
                def ref(self):
                    # First truthy ``ref`` among the output fields, in declaration order.
                    for var_name in variables:
                        field_ref = getattr(self, var_name).ref
                        if field_ref:
                            return field_ref
                    return None

                def runs_before(self, *args, **kwargs):
                    """
                    This is a placeholder and should do nothing. It is only here to enable local execution of workflows
                    where runs_before is manually called.
                    """

                def __rshift__(self, *args, **kwargs):
                    ...  # See runs_before

            self._output_tuple_class = Output

    @property
    def output_tuple(self) -> Optional[Type[collections.namedtuple]]:  # type: ignore
        """
        The namedtuple subclass generated from the output names, or None when the interface has no outputs.
        """
        return self._output_tuple_class

    @property
    def output_tuple_name(self) -> Optional[str]:
        """Name given to the output tuple at construction time, if any."""
        return self._output_tuple_name

    @property
    def inputs(self) -> Dict[str, type]:
        """Map of input name to python type, with the defaults stripped."""
        return {k: v[0] for k, v in self._inputs.items()}

    @property
    def output_names(self) -> Optional[List[str]]:
        """List of output names in declaration order, or None when there are no outputs."""
        if self.outputs:
            return list(self.outputs.keys())
        return None

    @property
    def inputs_with_defaults(self) -> Dict[str, Tuple[Type, Any]]:
        """Map of input name to ``(python type, default)``; the default is None when none was given."""
        return cast(Dict[str, Tuple[Type, Any]], self._inputs)

    @property
    def default_inputs_as_kwargs(self) -> Dict[str, Any]:
        """Map of input name to default value, restricted to inputs whose default is not None."""
        return {k: v[1] for k, v in self._inputs.items() if v[1] is not None}

    @property
    def outputs(self) -> typing.Dict[str, type]:
        """Map of output name to python type."""
        return self._outputs  # type: ignore

    @property
    def docstring(self) -> Optional[Docstring]:
        """Docstring object this interface was created with, if any."""
        return self._docstring

    def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface:
        """
        This method allows addition of extra outputs that are expected from a task specification.

        The current interface is left untouched. The returned interface keeps the inputs, the output tuple name
        and the docstring of this one.

        :param extra_outputs: Map of additional output name to python type.
        :return: ``self`` when ``extra_outputs`` is empty, otherwise a new Interface with the merged outputs.
        :raises ValueError: if one of the extra output names already exists in the interface.
        """
        if not extra_outputs:
            return self
        # Shallow copy so that the outputs of this interface are not mutated.
        new_outputs = copy.copy(self._outputs)
        for k, v in extra_outputs.items():
            if k in new_outputs:
                raise ValueError(f"Output {k} cannot be added as it already exists in the interface")
            new_outputs[k] = v
        return Interface(
            self._inputs,
            new_outputs,
            output_tuple_name=self._output_tuple_name,
            docstring=self._docstring,
        )

    def __str__(self):
        return f"{repr_type_signature(self._inputs)} -> {repr_type_signature(self._outputs)}"

    def __repr__(self):
        return str(self)
def transform_inputs_to_parameters(
    ctx: context_manager.FlyteContext, interface: Interface
) -> _interface_models.ParameterMap:
    """
    Transforms the given interface (with inputs) to a Parameter Map with defaults set

    For each input one of the following parameters is produced:

    * no default and an optional type: not required, defaulting to a ``Void`` literal;
    * an ``ArtifactQuery`` default: not required, carrying the query;
    * an ``Artifact`` default: the artifact's bound query when it has no version, its concrete id otherwise;
    * anything else: required when the default is None, otherwise the default converted to a literal.

    :param ctx: context
    :param interface: the interface object
    :raises FlyteValidationException: when an unversioned artifact default does not produce a bound query
    """
    if interface is None or interface.inputs_with_defaults is None:
        return _interface_models.ParameterMap({})

    doc = interface.docstring
    input_descriptions = None if doc is None else doc.input_descriptions
    inputs_vars = transform_variable_map(interface.inputs, input_descriptions)

    inputs_with_def = interface.inputs_with_defaults
    params = {}
    for k, v in inputs_vars.items():
        val, _default = inputs_with_def[k]

        if _default is None and UnionTransformer.is_optional_type(val):
            none_literal = Literal(scalar=Scalar(none_type=Void()))
            parameter = _interface_models.Parameter(var=v, default=none_literal, required=False)
        elif isinstance(_default, ArtifactQuery):
            parameter = _interface_models.Parameter(var=v, required=False, artifact_query=_default.to_flyte_idl())
        elif isinstance(_default, Artifact):
            if _default.version:
                # If it is versioned, assume it's intentionally hard-coded
                artifact_id = _default.concrete_artifact_id  # may raise
                parameter = _interface_models.Parameter(var=v, required=False, artifact_id=artifact_id)
            else:
                # If the artifact is not versioned, assume it's meant to be a query.
                q = _default.query()
                if not q.bound:
                    raise FlyteValidationException(f"Cannot use default query with artifact {_default.name}")
                parameter = _interface_models.Parameter(var=v, required=False, artifact_query=q.to_flyte_idl())
        elif _default is None:
            # No default at all: the caller has to provide a value.
            parameter = _interface_models.Parameter(var=v, default=None, required=True)
        else:
            default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type)
            parameter = _interface_models.Parameter(var=v, default=default_lv, required=False)

        params[k] = parameter
    return _interface_models.ParameterMap(params)
def transform_interface_to_typed_interface(
    interface: typing.Optional[Interface],
    allow_partial_artifact_id_binding: bool = False,
) -> typing.Optional[_interface_models.TypedInterface]:
    """
    Transform the given simple python native interface to FlyteIDL's interface

    :param interface: the python native interface; ``None`` is passed through as ``None``
    :param allow_partial_artifact_id_binding: forwarded to ``verify_outputs_artifact_bindings``
    :return: the typed interface built from the transformed input and output variable maps
    """
    if interface is None:
        return None

    docstring = interface.docstring
    if docstring is not None:
        input_descriptions = docstring.input_descriptions
        output_descriptions = remap_shared_output_descriptions(docstring.output_descriptions, interface.outputs)
    else:
        input_descriptions = output_descriptions = {}

    inputs_map = transform_variable_map(interface.inputs, input_descriptions)
    outputs_map = transform_variable_map(interface.outputs, output_descriptions)

    # Validate artifact bindings declared on the outputs before handing the interface back.
    verify_outputs_artifact_bindings(interface.inputs, outputs_map, allow_partial_artifact_id_binding)
    return _interface_models.TypedInterface(inputs_map, outputs_map)
def verify_outputs_artifact_bindings(
    inputs: Dict[str, type],
    outputs: Dict[str, _interface_models.Variable],
    allow_partial_artifact_id_binding: bool = False,
) -> None:
    """
    Validate the artifact partition bindings declared on the output variables.

    :param inputs: Map of input name to python type; only the names are used, to check input bindings.
    :param outputs: Map of output name to Variable whose ``artifact_partial_id`` is inspected.
    :param allow_partial_artifact_id_binding: When False, a partition or time partition value equal to
        ``DYNAMIC_INPUT_BINDING`` is rejected.
    :raises FlyteValidationException: on a disallowed dynamic binding, or when a partition is bound to an
        input name that is not present in ``inputs``.
    """
    # collect Artifacts
    for k, v in outputs.items():
        # Iterate through output partition values if any and verify that if they're bound to an input, that that input
        # actually exists in the interface.
        if (
            v.artifact_partial_id
            and v.artifact_partial_id.HasField("partitions")
            and v.artifact_partial_id.partitions.value
        ):
            for pk, pv in v.artifact_partial_id.partitions.value.items():
                if pv == DYNAMIC_INPUT_BINDING:
                    if not allow_partial_artifact_id_binding:
                        raise FlyteValidationException(
                            f"Binding a partition {pk}'s value dynamically is not allowed for workflows"
                        )
                    else:
                        # Dynamic bindings carry no input name to verify.
                        continue
                if pv.HasField("input_binding"):
                    input_name = pv.input_binding.var
                    if input_name not in inputs:
                        raise FlyteValidationException(
                            f"Output partition {k} is bound to input {input_name} which does not exist in the interface"
                        )
            # NOTE(review): as nested here, the time partition is only checked for outputs that also declare
            # non-empty regular partitions - confirm this is intended.
            if v.artifact_partial_id.HasField("time_partition"):
                if (
                    v.artifact_partial_id.time_partition.value == DYNAMIC_INPUT_BINDING
                    and not allow_partial_artifact_id_binding
                ):
                    raise FlyteValidationException(
                        "Binding a time partition's value dynamically is not allowed for workflows"
                    )
                if v.artifact_partial_id.time_partition.value.HasField("input_binding"):
                    input_name = v.artifact_partial_id.time_partition.value.input_binding.var
                    if input_name not in inputs:
                        raise FlyteValidationException(
                            f"Output time partition is bound to input {input_name} which does not exist in the interface"
                        )
def transform_types_to_list_of_type(
    m: Dict[str, type], bound_inputs: typing.Set[str], list_as_optional: bool = False
) -> Dict[str, type]:
    """
    Converts unbound inputs into the equivalent (optional) collections. This is useful for array jobs / map style code.
    It will create a collection of types even if any one these types is not a collection type.

    :param m: Map of variable name to python type; ``None`` yields an empty dict.
    :param bound_inputs: Names that are kept as they are instead of being wrapped in a list.
    :param list_as_optional: When True, the element type of every generated list is wrapped in ``Optional``.
    :return: ``m`` itself when nothing has to be wrapped, otherwise a new dict with the unbound entries wrapped.
    """
    if m is None:
        return {}

    # Bound entries never take part in the decision, it does not matter whether they are collections or
    # singletons.
    needs_wrapping = any(
        type(py_type) != typing.List and type(py_type) != list
        for name, py_type in m.items()
        if name not in bound_inputs
    )
    if not needs_wrapping:
        return m

    return {
        name: py_type
        if name in bound_inputs
        else typing.List[typing.Optional[py_type] if list_as_optional else py_type]  # type: ignore
        for name, py_type in m.items()
    }
def transform_interface_to_list_interface(
    interface: Interface, bound_inputs: typing.Set[str], optional_outputs: bool = False
) -> Interface:
    """
    Takes a single task interface and interpolates it to an array interface - to allow performing distributed python map
    like functions

    :param interface: Interface to be upgraded to a list interface
    :param bound_inputs: fixed inputs that should not upgraded to a list and will be maintained as scalars.
    :param optional_outputs: when True, the element type of every output list is wrapped in ``Optional``.
    :return: a new Interface built from the converted inputs and outputs only.
    """
    # Outputs have no bound names, so an empty set is passed for them.
    list_outputs = transform_types_to_list_of_type(interface.outputs, set(), optional_outputs)
    list_inputs = transform_types_to_list_of_type(interface.inputs, bound_inputs)
    return Interface(inputs=list_inputs, outputs=list_outputs)
def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Docstring] = None) -> Interface:
    """
    From the annotations on a task function that the user should have provided, and the output names they want to use
    for each output parameter, construct the TypedInterface object
    For now the fancy object, maybe in the future a dumb object.

    :param fn: the function whose type hints and signature are inspected
    :param docstring: docstring object stored on the resulting interface
    :return: Interface holding ``(annotation, default)`` per parameter and the extracted outputs
    """
    type_hints = get_type_hints(fn, include_extras=True)
    signature = inspect.signature(fn)
    return_annotation = type_hints.get("return", None)
    outputs = extract_return_annotation(return_annotation)

    inputs: Dict[str, Tuple[Type, Any]] = OrderedDict()
    for name, parameter in signature.parameters.items():  # type: ignore
        # A parameter without a default is recorded with None as its default.
        has_default = parameter.default is not inspect.Parameter.empty
        inputs[name] = (type_hints.get(name, None), parameter.default if has_default else None)  # type: ignore

    # This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We
    # would like to preserve that name in our custom collections.namedtuple.
    looks_like_named_tuple = (
        hasattr(return_annotation, "__bases__")
        and len(return_annotation.__bases__) == 1
        and return_annotation.__bases__[0] == tuple
        and hasattr(return_annotation, "_fields")
    )
    custom_name = None
    if looks_like_named_tuple and hasattr(return_annotation, "__name__") and return_annotation.__name__ != "":
        custom_name = return_annotation.__name__

    return Interface(inputs, outputs, output_tuple_name=custom_name, docstring=docstring)
def transform_variable_map(
    variable_map: Dict[str, type],
    descriptions: Optional[Dict[str, str]] = None,
) -> Dict[str, _interface_models.Variable]:
    """
    Given a map of str (names of inputs for instance) to their Python native types, return a map of the name to a
    Flyte Variable object with that type.

    :param variable_map: Map of variable name to python type; an empty or missing map gives an empty result.
    :param descriptions: Optional map of variable name to description. A variable without an entry is described
        by its own name.
    :return: Ordered map of variable name to the transformed Variable.
    """
    if not variable_map:
        return OrderedDict()

    known_descriptions = descriptions or {}
    return OrderedDict(
        (name, transform_type(py_type, known_descriptions.get(name, name)))
        for name, py_type in variable_map.items()
    )
def detect_artifact(
    ts: typing.Tuple[typing.Any, ...],
) -> Optional[art_id.ArtifactID]:
    """
    If the user wishes to control how Artifacts are created (i.e. naming them, etc.) this is where we pick it up and
    store it in the interface.

    :param ts: the candidate objects to scan, in order
    :return: the partial artifact id of the first ``Artifact`` or ``ArtifactIDSpecification`` found, else None
    """
    for candidate in ts:
        if isinstance(candidate, Artifact):
            # Calling the Artifact yields the specification that the partial id is derived from.
            return candidate().to_partial_artifact_id()
        if isinstance(candidate, ArtifactIDSpecification):
            return candidate.to_partial_artifact_id()
    return None
def transform_type(x: type, description: Optional[str] = None) -> _interface_models.Variable:
    """
    Build the Flyte Variable for a python type.

    :param x: python type; its type arguments are scanned for an artifact declaration
    :param description: description stored on the variable
    :return: Variable carrying the literal type, the description and the detected partial artifact id (if any)
    """
    type_args = get_args(x)
    artifact_id = detect_artifact(type_args)
    if artifact_id:
        logger.debug(f"Found artifact id spec: {artifact_id}")

    literal_type = TypeEngine.to_literal_type(x)
    return _interface_models.Variable(
        type=literal_type,
        description=description,
        artifact_partial_id=artifact_id,
    )
def default_output_name(index: int = 0) -> str:
    """
    Return the automatically generated name of the output at position ``index``, i.e. ``o<index>``.
    """
    return "o{}".format(index)
def output_name_generator(length: int) -> Generator[str, None, None]:
    """
    Lazily yield the default output names for positions ``0`` to ``length - 1``.
    """
    yield from (default_output_name(position) for position in range(length))
def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]:
    """
    The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to
    name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple.

        # Option 1
        nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int)
        def t(a: int, b: str) -> nt1: ...

        # Option 2
        def t(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): ...

        # Option 3
        def t(a: int, b: str) -> typing.Tuple[int, str]: ...

        # Option 4
        def t(a: int, b: str) -> (int, str): ...

        # Option 5
        def t(a: int, b: str) -> str: ...

        # Option 6
        def t(a: int, b: str) -> None: ...

        # Options 7/8
        def t(a: int, b: str) -> List[int]: ...
        def t(a: int, b: str) -> Dict[str, int]: ...

    Note that Options 1 and 2 are identical, just syntactic sugar. In the NamedTuple case, we'll use the names in the
    definition. In all other cases, we'll automatically generate output names, indexed starting at 0.

    :param return_annotation: the return annotation of the function, ``None`` or the empty marker when absent
    :return: map of output name to python type; empty for void functions
    :raises FlyteValidationException: when a tuple annotation declares exactly one element
    """
    # Handle Option 6
    # We can think about whether we should add a default output name with type None in the future.
    if return_annotation in (None, type(None), inspect.Signature.empty):
        return {}

    # This statement results in true for typing.Namedtuple, single and void return types, so this
    # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python
    if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar):  # type: ignore
        # isinstance / issubclass does not work for Namedtuple.
        # Options 1 and 2
        # TypeVar instances have no __bases__ attribute, so fall back to an empty tuple instead of raising
        # AttributeError; such annotations are then handled as a single return value below.
        bases = getattr(return_annotation, "__bases__", ())
        if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"):
            logger.debug(f"Task returns named tuple {return_annotation}")
            return dict(get_type_hints(cast(Type, return_annotation), include_extras=True))

    if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple:  # type: ignore
        # Handle option 3
        logger.debug(f"Task returns unnamed typing.Tuple {return_annotation}")
        tuple_args = return_annotation.__args__  # type: ignore
        if len(tuple_args) == 1:
            raise FlyteValidationException(
                "Tuples should be used to indicate multiple return values, found only one return variable."
            )
        return OrderedDict(zip(output_name_generator(len(tuple_args)), tuple_args))

    if isinstance(return_annotation, tuple):
        # Handle option 4
        logger.debug(f"Task returns unnamed native tuple {return_annotation}")
        if len(return_annotation) == 1:
            raise FlyteValidationException("Please don't use a tuple if you're just returning one thing.")
        return OrderedDict(zip(output_name_generator(len(return_annotation)), return_annotation))

    # Handle all other single return types
    logger.debug(f"Task returns single value {return_annotation}")
    return {default_output_name(): cast(Type, return_annotation)}
def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]:
    """
    Deals with mixed styles of return value descriptions used in docstrings. If the docstring contains a single
    entry of return value description, that output description is shared by each output variable.

    :param output_descriptions: Dict of output variable names mapping to output description
    :param outputs: Interface outputs
    :return: ``output_descriptions`` unchanged unless it has exactly one entry; in that case a dict mapping every
        output variable name to that single description
    """
    if len(output_descriptions) == 1:
        (shared_description,) = output_descriptions.values()
        return dict.fromkeys(outputs, shared_description)

    # Zero or several descriptions: no need to remap
    return output_descriptions