Source code for flytekit.tools.translator

import sys
import typing
from collections import OrderedDict
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

from flyteidl.admin import schedule_pb2

from flytekit import PythonFunctionTask, SourceCode
from flytekit.configuration import SerializationSettings
from flytekit.core import constants as _common_constants
from flytekit.core.array_node_map_task import ArrayNodeMapTask
from flytekit.core.base_task import PythonTask
from flytekit.core.condition import BranchNode
from flytekit.core.container_task import ContainerTask
from flytekit.core.gate import Gate
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.legacy_map_task import MapPythonTask
from flytekit.core.node import Node
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate
from flytekit.core.task import ReferenceTask
from flytekit.core.utils import ClassDecorator, _dnsify
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase
from flytekit.models import common as _common_models
from flytekit.models import common as common_models
from flytekit.models import interface as interface_models
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import security
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.core import identifier as _identifier_model
from flytekit.models.core import workflow as _core_wf
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.workflow import ApproveCondition, GateNode, SignalCondition, SleepCondition, TaskNodeOverrides
from flytekit.models.core.workflow import ArrayNode as ArrayNodeModel
from flytekit.models.core.workflow import BranchNode as BranchNodeModel
from flytekit.models.task import TaskSpec, TaskTemplate

FlyteLocalEntity = Union[
    PythonTask,
    BranchNode,
    Node,
    LaunchPlan,
    WorkflowBase,
    ReferenceWorkflow,
    ReferenceTask,
    ReferenceLaunchPlan,
    ReferenceEntity,
]
FlyteControlPlaneEntity = Union[
    TaskSpec,
    _launch_plan_models.LaunchPlan,
    admin_workflow_models.WorkflowSpec,
    workflow_model.Node,
    BranchNodeModel,
    ArrayNodeModel,
]


@dataclass
class Options(object):
    """
    These are options that can be configured for a launch plan during registration or overridden during an execution.
    For instance, two people may want to run the same workflow but have the offloaded data stored in two different
    buckets, or you may want the labels or annotations to differ.

    This object is used when launching an execution in a Flyte backend, and also when registering launch plans.

    Args:
        labels: Custom labels to be applied to the execution resource
        annotations: Custom annotations to be applied to the execution resource
        security_context: Indicates the security context for permissions triggered with this launch plan
        raw_output_data_config: Optional location of offloaded data for things like S3, etc. A remote prefix for the
            storage location of the form ``s3://<bucket>/key...`` or ``gcs://...`` or ``file://...``. If not
            specified, the platform-configured default is used. This is where the data for offloaded types is stored.
        max_parallelism: Controls the maximum number of task nodes that can be run in parallel for the entire
            workflow.
        notifications: List of notifications for this execution.
        disable_notifications: Set this to true if all notifications are intended to be disabled for this execution.
    """

    labels: typing.Optional[common_models.Labels] = None
    annotations: typing.Optional[common_models.Annotations] = None
    raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None
    security_context: typing.Optional[security.SecurityContext] = None
    max_parallelism: typing.Optional[int] = None
    notifications: typing.Optional[typing.List[common_models.Notification]] = None
    disable_notifications: typing.Optional[bool] = None
    overwrite_cache: typing.Optional[bool] = None

    @classmethod
    def default_from(
        cls, k8s_service_account: typing.Optional[str] = None, raw_data_prefix: typing.Optional[str] = None
    ) -> "Options":
        return cls(
            security_context=security.SecurityContext(
                run_as=security.Identity(k8s_service_account=k8s_service_account)
            )
            if k8s_service_account
            else None,
            raw_output_data_config=common_models.RawOutputDataConfig(output_location_prefix=raw_data_prefix)
            if raw_data_prefix
            else None,
        )
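

# Illustrative sketch (added for this listing, not part of the module's API): one way a caller
# might assemble ``Options`` when registering a launch plan. The bucket URI, label, and
# annotation values below are assumptions chosen for the example.
def _example_options() -> Options:
    return Options(
        labels=common_models.Labels({"team": "data-platform"}),
        annotations=common_models.Annotations({"ticket": "ABC-123"}),
        raw_output_data_config=common_models.RawOutputDataConfig("s3://my-bucket/offloaded"),
        max_parallelism=10,
        disable_notifications=True,
    )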


def to_serializable_case(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    c: _core_wf.IfBlock,
    options: Optional[Options] = None,
) -> _core_wf.IfBlock:
    if c is None:
        raise ValueError("Cannot convert a None case to a registrable entity")
    then_node = get_serializable(entity_mapping, settings, c.then_node, options=options)
    return _core_wf.IfBlock(condition=c.condition, then_node=then_node)


def to_serializable_cases(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    cases: List[_core_wf.IfBlock],
    options: Optional[Options] = None,
) -> Optional[List[_core_wf.IfBlock]]:
    if cases is None:
        return None
    ret_cases = []
    for c in cases:
        ret_cases.append(to_serializable_case(entity_mapping, settings, c, options))
    return ret_cases


def get_command_prefix_for_fast_execute(settings: SerializationSettings) -> List[str]:
    return [
        "pyflyte-fast-execute",
        "--additional-distribution",
        settings.fast_serialization_settings.distribution_location
        if settings.fast_serialization_settings and settings.fast_serialization_settings.distribution_location
        else "{{ .remote_package_path }}",
        "--dest-dir",
        settings.fast_serialization_settings.destination_dir
        if settings.fast_serialization_settings and settings.fast_serialization_settings.destination_dir
        else "{{ .dest_dir }}",
        "--",
    ]


def prefix_with_fast_execute(settings: SerializationSettings, cmd: typing.List[str]) -> List[str]:
    return get_command_prefix_for_fast_execute(settings) + cmd


def _fast_serialize_command_fn(
    settings: SerializationSettings, task: PythonAutoContainerTask
) -> Callable[[SerializationSettings], List[str]]:
    """
    This function is only applicable for Pod tasks.
    """
    default_command = task.get_default_command(settings)

    def fn(settings: SerializationSettings) -> List[str]:
        return prefix_with_fast_execute(settings, default_command)

    return fn
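

# Illustrative sketch (added for this listing): what the fast-execute wrapping produces. The
# distribution location and destination directory are assumptions; when no fast serialization
# settings are present, the template placeholders ``{{ .remote_package_path }}`` and
# ``{{ .dest_dir }}`` are emitted instead.
def _example_fast_execute_command() -> List[str]:
    from flytekit.configuration import FastSerializationSettings, ImageConfig

    ss = SerializationSettings(
        image_config=ImageConfig.auto_default_image(),
        project="flytesnacks",
        domain="development",
        version="v1",
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir=".",
            distribution_location="s3://my-bucket/fast/pkg.tar.gz",
        ),
    )
    # Returns: ["pyflyte-fast-execute", "--additional-distribution",
    #           "s3://my-bucket/fast/pkg.tar.gz", "--dest-dir", ".", "--", "pyflyte-execute"]
    return prefix_with_fast_execute(ss, ["pyflyte-execute"])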


def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
    options: Optional[Options] = None,
) -> TaskSpec:
    task_id = _identifier_model.Identifier(
        _identifier_model.ResourceType.TASK,
        settings.project,
        settings.domain,
        entity.name,
        settings.version,
    )

    if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
        # In the case of dynamic tasks, we want to pass the serialization context so that the task can reconstruct
        # its state from it. It is passed through an environment variable that is read during dynamic serialization.
        settings = settings.with_serialized_context()

    if entity.node_dependency_hints is not None:
        for entity_hint in entity.node_dependency_hints:
            get_serializable(entity_mapping, settings, entity_hint, options)

    container = entity.get_container(settings)
    # This pod will be incorrect when doing fast serialize
    pod = entity.get_k8s_pod(settings)

    if settings.should_fast_serialize():
        # This handles container tasks.
        if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask, ArrayNodeMapTask)):
            # For fast registration, we'll need to muck with the command, but only for certain kinds of tasks.
            # Specifically, tasks that rely on user code defined in the container. This should be encapsulated by
            # the auto container parent class.
            container._args = prefix_with_fast_execute(settings, container.args)

        # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
        # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
        # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file.
        elif pod and not isinstance(entity, ContainerTask):
            if isinstance(entity, (MapPythonTask, ArrayNodeMapTask)):
                entity.set_command_prefix(get_command_prefix_for_fast_execute(settings))
                pod = entity.get_k8s_pod(settings)
            else:
                entity.set_command_fn(_fast_serialize_command_fn(settings, entity))
                pod = entity.get_k8s_pod(settings)
                entity.reset_command_fn()

    entity_config = entity.get_config(settings) or {}

    extra_config = {}

    if hasattr(entity, "task_function") and isinstance(entity.task_function, ClassDecorator):
        extra_config = entity.task_function.get_extra_config()

    merged_config = {**entity_config, **extra_config}

    tt = TaskTemplate(
        id=task_id,
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=container,
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=merged_config,
        k8s_pod=pod,
        sql=entity.get_sql(settings),
        extended_resources=entity.get_extended_resources(settings),
    )
    if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask):
        entity.reset_command_fn()

    return TaskSpec(template=tt, docs=entity.docs)
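

# Illustrative sketch (added for this listing): serializing an existing ``@task``-decorated
# function into a ``TaskSpec``. The registration coordinates (project/domain/version) below are
# assumptions chosen for the example.
def _example_serialize_task(t: PythonTask) -> TaskSpec:
    from flytekit.configuration import ImageConfig

    ss = SerializationSettings(
        image_config=ImageConfig.auto_default_image(),
        project="flytesnacks",
        domain="development",
        version="abc123",
    )
    return get_serializable_task(OrderedDict(), ss, t)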


def get_serializable_workflow(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: WorkflowBase,
    options: Optional[Options] = None,
) -> admin_workflow_models.WorkflowSpec:
    # Serialize all nodes
    serialized_nodes = []
    sub_wfs = []
    for n in entity.nodes:
        # Ignore start nodes
        if n.id == _common_constants.GLOBAL_INPUT_NODE_ID:
            continue

        # Recursively serialize the node
        serialized_nodes.append(get_serializable(entity_mapping, settings, n, options))

        # If the node is a workflow node or a branch node, we need to handle it specially to extract all
        # subworkflows, so that they can be added to the workflow being serialized.
        if isinstance(n.flyte_entity, WorkflowBase):
            # We are currently not supporting reference workflows since these will
            # require a network call to flyteadmin to populate the WorkflowTemplate
            # object
            if isinstance(n.flyte_entity, ReferenceWorkflow):
                raise Exception(
                    "Reference sub-workflows are currently unsupported. Use reference launch plans instead."
                )
            sub_wf_spec = get_serializable(entity_mapping, settings, n.flyte_entity, options)
            if not isinstance(sub_wf_spec, admin_workflow_models.WorkflowSpec):
                raise TypeError(
                    f"Unexpected type for serialized form of workflow. Expected {admin_workflow_models.WorkflowSpec}, but got {type(sub_wf_spec)}"
                )
            sub_wfs.append(sub_wf_spec.template)
            sub_wfs.extend(sub_wf_spec.sub_workflows)

        from flytekit.remote import FlyteWorkflow

        if isinstance(n.flyte_entity, FlyteWorkflow):
            for swf in n.flyte_entity.flyte_sub_workflows:
                sub_wf = get_serializable(entity_mapping, settings, swf, options)
                sub_wfs.append(sub_wf.template)
            main_wf = get_serializable(entity_mapping, settings, n.flyte_entity, options)
            sub_wfs.append(main_wf.template)

        if isinstance(n.flyte_entity, BranchNode):
            if_else: workflow_model.IfElseBlock = n.flyte_entity._ifelse_block
            # See the comment in get_serializable_branch_node as well. Again, this is a List[Node] even though
            # it's supposed to be a List[workflow_model.Node].
            leaf_nodes: List[Node] = filter(  # noqa
                None,
                [
                    if_else.case.then_node,
                    *([] if if_else.other is None else [x.then_node for x in if_else.other]),
                    if_else.else_node,
                ],
            )
            for leaf_node in leaf_nodes:
                if isinstance(leaf_node.flyte_entity, WorkflowBase):
                    sub_wf_spec = get_serializable(entity_mapping, settings, leaf_node.flyte_entity, options)
                    sub_wfs.append(sub_wf_spec.template)
                    sub_wfs.extend(sub_wf_spec.sub_workflows)
                elif isinstance(leaf_node.flyte_entity, FlyteWorkflow):
                    get_serializable(entity_mapping, settings, leaf_node.flyte_entity, options)
                    sub_wfs.append(leaf_node.flyte_entity)
                    sub_wfs.extend([s for s in leaf_node.flyte_entity.sub_workflows.values()])

    serialized_failure_node = None
    if entity.failure_node:
        serialized_failure_node = get_serializable(entity_mapping, settings, entity.failure_node, options)
        if isinstance(entity.failure_node.flyte_entity, WorkflowBase):
            sub_wf_spec = get_serializable(entity_mapping, settings, entity.failure_node.flyte_entity, options)
            sub_wfs.append(sub_wf_spec.template)
            sub_wfs.extend(sub_wf_spec.sub_workflows)

    wf_id = _identifier_model.Identifier(
        resource_type=_identifier_model.ResourceType.WORKFLOW,
        project=settings.project,
        domain=settings.domain,
        name=entity.name,
        version=settings.version,
    )
    wf_t = workflow_model.WorkflowTemplate(
        id=wf_id,
        metadata=entity.workflow_metadata.to_flyte_model(),
        metadata_defaults=entity.workflow_metadata_defaults.to_flyte_model(),
        interface=entity.interface,
        nodes=serialized_nodes,
        outputs=entity.output_bindings,
        failure_node=serialized_failure_node,
    )

    return admin_workflow_models.WorkflowSpec(
        template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()), docs=entity.docs
    )
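

# Illustrative sketch (added for this listing): serializing a workflow. If ``wf`` calls other
# ``@workflow``-decorated functions, their templates are collected into ``sub_workflows`` on the
# returned spec, alongside the parent template.
def _example_serialize_workflow(wf: WorkflowBase, ss: SerializationSettings) -> admin_workflow_models.WorkflowSpec:
    wf_spec = get_serializable_workflow(OrderedDict(), ss, wf)
    # wf_spec.sub_workflows holds one template per distinct subworkflow reachable from ``wf``.
    return wf_spec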


def get_serializable_launch_plan(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: LaunchPlan,
    recurse_downstream: bool = True,
    options: Optional[Options] = None,
) -> _launch_plan_models.LaunchPlan:
    """
    :param entity_mapping:
    :param settings:
    :param entity:
    :param options:
    :param recurse_downstream: This boolean indicates whether the workflow underlying the entity should also be
        recursively serialized.
    :return:
    """
    if recurse_downstream:
        wf_spec = get_serializable(entity_mapping, settings, entity.workflow, options)
        wf_id = wf_spec.template.id
    else:
        wf_id = _identifier_model.Identifier(
            resource_type=_identifier_model.ResourceType.WORKFLOW,
            project=settings.project,
            domain=settings.domain,
            name=entity.workflow.name,
            version=settings.version,
        )

    if not options:
        options = Options()

    if options and options.raw_output_data_config:
        raw_prefix_config = options.raw_output_data_config
    else:
        raw_prefix_config = entity.raw_output_data_config or _common_models.RawOutputDataConfig("")

    if entity.trigger:
        lc = entity.trigger.to_flyte_idl(entity)
        if isinstance(lc, schedule_pb2.Schedule):
            raise ValueError("Please continue to use the schedule arg; the trigger arg is not implemented yet")
    else:
        lc = None

    lps = _launch_plan_models.LaunchPlanSpec(
        workflow_id=wf_id,
        entity_metadata=_launch_plan_models.LaunchPlanMetadata(
            schedule=entity.schedule,
            notifications=options.notifications or entity.notifications,
            launch_conditions=lc,
        ),
        default_inputs=entity.parameters,
        fixed_inputs=entity.fixed_inputs,
        labels=options.labels or entity.labels or _common_models.Labels({}),
        annotations=options.annotations or entity.annotations or _common_models.Annotations({}),
        auth_role=None,
        raw_output_data_config=raw_prefix_config,
        max_parallelism=options.max_parallelism or entity.max_parallelism,
        security_context=options.security_context or entity.security_context,
        overwrite_cache=options.overwrite_cache or entity.overwrite_cache,
    )

    lp_id = _identifier_model.Identifier(
        resource_type=_identifier_model.ResourceType.LAUNCH_PLAN,
        project=settings.project,
        domain=settings.domain,
        name=entity.name,
        version=settings.version,
    )
    lp_model = _launch_plan_models.LaunchPlan(
        id=lp_id,
        spec=lps,
        closure=_launch_plan_models.LaunchPlanClosure(
            state=None,
            expected_inputs=interface_models.ParameterMap({}),
            expected_outputs=interface_models.VariableMap({}),
        ),
    )

    return lp_model
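

# Illustrative sketch (added for this listing): registration-time ``Options`` take precedence
# over values set on the ``LaunchPlan`` itself, per the ``or`` chains above. The label value is
# an assumption for the example.
def _example_launch_plan_options(lp: LaunchPlan, ss: SerializationSettings) -> _launch_plan_models.LaunchPlan:
    opts = Options(labels=common_models.Labels({"env": "staging"}))
    lp_model = get_serializable_launch_plan(OrderedDict(), ss, lp, options=opts)
    # lp_model.spec.labels now carries {"env": "staging"}, regardless of any labels on ``lp``.
    return lp_model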


def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
    options: Optional[Options] = None,
) -> workflow_model.Node:
    if entity.flyte_entity is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    upstream_nodes = [
        get_serializable(entity_mapping, settings, n, options=options)
        for n in entity.upstream_nodes
        if n.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    # Reference entities also inherit from the classes in the second if statement, so address them first.
    if isinstance(entity.flyte_entity, ReferenceEntity):
        ref_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
        ref_template = ref_spec.template
        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
        )
        if ref_template.resource_type == _identifier_model.ResourceType.TASK:
            node_model._task_node = workflow_model.TaskNode(reference_id=ref_template.id)
        elif ref_template.resource_type == _identifier_model.ResourceType.WORKFLOW:
            node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=ref_template.id)
        elif ref_template.resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=ref_template.id)
        else:
            raise Exception(
                f"Unexpected resource type for reference entity {entity.flyte_entity}: {ref_template.resource_type}"
            )
        return node_model

    from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow

    if isinstance(entity.flyte_entity, ArrayNodeMapTask):
        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            array_node=get_serializable_array_node(entity_mapping, settings, entity, options=options),
        )
        # TODO: do I need this?
        # if entity._aliases:
        #     node_model._output_aliases = entity._aliases
    elif isinstance(entity.flyte_entity, PythonTask):
        task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            task_node=workflow_model.TaskNode(
                reference_id=task_spec.template.id,
                overrides=TaskNodeOverrides(
                    resources=entity._resources,
                    extended_resources=entity._extended_resources,
                    container_image=entity._container_image,
                ),
            ),
        )
        if entity._aliases:
            node_model._output_aliases = entity._aliases
    elif isinstance(entity.flyte_entity, WorkflowBase):
        wf_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id),
        )
    elif isinstance(entity.flyte_entity, BranchNode):
        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            branch_node=get_serializable(entity_mapping, settings, entity.flyte_entity, options=options),
        )
    elif isinstance(entity.flyte_entity, LaunchPlan):
        lp_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)

        # The node's inputs should not contain fixed inputs.
        node_input = []
        for b in entity.bindings:
            if b.var not in entity.flyte_entity.fixed_inputs.literals:
                node_input.append(b)

        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=node_input,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id),
        )
    elif isinstance(entity.flyte_entity, Gate):
        if entity.flyte_entity.sleep_duration:
            gn = GateNode(sleep=SleepCondition(duration=entity.flyte_entity.sleep_duration))
        elif entity.flyte_entity.input_type:
            output_name = list(entity.flyte_entity.python_interface.outputs.keys())[0]  # should be o0
            gn = GateNode(
                signal=SignalCondition(
                    entity.flyte_entity.name, type=entity.flyte_entity.literal_type, output_variable_name=output_name
                )
            )
        else:
            gn = GateNode(approve=ApproveCondition(entity.flyte_entity.name))
        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            gate_node=gn,
        )
    elif isinstance(entity.flyte_entity, FlyteTask):
        # The recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            task_node=workflow_model.TaskNode(
                reference_id=entity.flyte_entity.id,
                overrides=TaskNodeOverrides(
                    resources=entity._resources,
                    extended_resources=entity._extended_resources,
                    container_image=entity._container_image,
                ),
            ),
        )
    elif isinstance(entity.flyte_entity, FlyteWorkflow):
        wf_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
        for sub_wf in entity.flyte_entity.flyte_sub_workflows:
            get_serializable(entity_mapping, settings, sub_wf, options=options)
        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.id),
        )
    elif isinstance(entity.flyte_entity, FlyteLaunchPlan):
        # The recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)

        # The node's inputs should not contain fixed inputs.
        node_input = []
        for b in entity.bindings:
            if b.var not in entity.flyte_entity.fixed_inputs.literals:
                node_input.append(b)

        node_model = workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=node_input,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=entity.flyte_entity.id),
        )
    else:
        raise Exception(f"Node contained non-serializable entity {entity._flyte_entity}")

    return node_model


def get_serializable_array_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    node: Node,
    options: Optional[Options] = None,
) -> ArrayNodeModel:
    # TODO: Add support for other flyte entities
    entity = node.flyte_entity
    task_spec = get_serializable(entity_mapping, settings, entity, options)
    task_node = workflow_model.TaskNode(
        reference_id=task_spec.template.id,
        overrides=TaskNodeOverrides(
            resources=node._resources,
            extended_resources=node._extended_resources,
            container_image=node._container_image,
        ),
    )
    node = workflow_model.Node(
        id=entity.name,
        metadata=entity.construct_node_metadata(),
        inputs=node.bindings,
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )
    return ArrayNodeModel(
        node=node,
        parallelism=entity.concurrency,
        min_successes=entity.min_successes,
        min_success_ratio=entity.min_success_ratio,
    )


def get_serializable_branch_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
    options: Optional[Options] = None,
) -> BranchNodeModel:
    # We have to iterate through the blocks to convert the nodes from the internal Node type to the Node model type.
    # This was done to avoid having to create our own IfElseBlock object (i.e. condition.py just uses the model
    # directly), even though the node there is of the wrong type (our type instead of the model type).
    # TODO: this should be cleaned up; instead of mutation, we should probably just create a new object.
    first = to_serializable_case(entity_mapping, settings, entity._ifelse_block.case, options)
    other = to_serializable_cases(entity_mapping, settings, entity._ifelse_block.other, options)
    else_node_model = None
    if entity._ifelse_block.else_node:
        else_node_model = get_serializable(entity_mapping, settings, entity._ifelse_block.else_node, options=options)

    return BranchNodeModel(
        if_else=_core_wf.IfElseBlock(
            case=first, other=other, else_node=else_node_model, error=entity._ifelse_block.error
        )
    )


def get_reference_spec(
    entity_mapping: OrderedDict, settings: SerializationSettings, entity: ReferenceEntity
) -> ReferenceSpec:
    template = ReferenceTemplate(entity.id, entity.reference.resource_type)
    return ReferenceSpec(template)


def get_serializable_flyte_workflow(
    entity: "FlyteWorkflow", settings: SerializationSettings
) -> FlyteControlPlaneEntity:
    """
    TODO: replace with a deep copy
    """

    def _mutate_task_node(tn: workflow_model.TaskNode):
        tn.reference_id._project = settings.project
        tn.reference_id._domain = settings.domain

    def _mutate_branch_node_task_ids(bn: workflow_model.BranchNode):
        _mutate_node(bn.if_else.case.then_node)
        for c in bn.if_else.other:
            _mutate_node(c.then_node)
        if bn.if_else.else_node:
            _mutate_node(bn.if_else.else_node)

    def _mutate_workflow_node(wn: workflow_model.WorkflowNode):
        wn.sub_workflow_ref._project = settings.project
        wn.sub_workflow_ref._domain = settings.domain

    def _mutate_node(n: workflow_model.Node):
        if n.task_node:
            _mutate_task_node(n.task_node)
        elif n.branch_node:
            _mutate_branch_node_task_ids(n.branch_node)
        elif n.workflow_node:
            _mutate_workflow_node(n.workflow_node)

    for n in entity.flyte_nodes:
        _mutate_node(n)

    entity.id._project = settings.project
    entity.id._domain = settings.domain

    return entity


def get_serializable_flyte_task(entity: "FlyteTask", settings: SerializationSettings) -> FlyteControlPlaneEntity:
    """
    TODO: replace with a deep copy
    """
    entity.id._project = settings.project
    entity.id._domain = settings.domain
    return entity


def get_serializable(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
    options: Optional[Options] = None,
) -> FlyteControlPlaneEntity:
    """
    The flytekit authoring code produces objects representing Flyte entities (tasks, workflows, etc.). In order to
    register these, they need to be converted into objects that Flyte Admin understands (the IDL objects, basically,
    but this function currently translates to the layer above (e.g. SdkTask) - this will be changed to the IDL
    objects directly in the future).

    :param entity_mapping: This is an ordered dict that will be mutated in place. The reason this argument exists is
      because there is a natural ordering to the entities at registration time. That is, underlying tasks have to be
      registered before the workflows that use them. The recursive search done by this function and the functions
      above form a natural topological sort, finding the dependent entities and adding them to this parameter before
      the parent entity this function is called with.
    :param settings: used to pick up project/domain/name - to be deprecated.
    :param entity: The local flyte entity to try to convert (along with its dependencies)
    :param options: Optionally pass in a set of options that can be used to add additional metadata for launch plans
    :return: The resulting control plane entity. In addition to being added to the mutable entity_mapping parameter,
      it is also returned.
    """
    if entity in entity_mapping:
        return entity_mapping[entity]

    from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow

    if isinstance(entity, ReferenceEntity):
        cp_entity = get_reference_spec(entity_mapping, settings, entity)

    elif isinstance(entity, PythonTask):
        cp_entity = get_serializable_task(entity_mapping, settings, entity)

    elif isinstance(entity, WorkflowBase):
        cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options)

    elif isinstance(entity, Node):
        cp_entity = get_serializable_node(entity_mapping, settings, entity, options)

    elif isinstance(entity, LaunchPlan):
        cp_entity = get_serializable_launch_plan(entity_mapping, settings, entity, options=options)

    elif isinstance(entity, BranchNode):
        cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options)

    elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow):
        if entity.should_register:
            if isinstance(entity, FlyteTask):
                cp_entity = get_serializable_flyte_task(entity, settings)
            else:
                if entity.should_register:
                    # We only add the tasks if the should_register flag is set. This is to avoid adding
                    # unnecessary tasks to the registrable list.
                    for t in entity.flyte_tasks:
                        get_serializable(entity_mapping, settings, t, options)
                cp_entity = get_serializable_flyte_workflow(entity, settings)
        else:
            cp_entity = entity

    elif isinstance(entity, FlyteLaunchPlan):
        cp_entity = entity

    else:
        raise Exception(f"Non-serializable type found {type(entity)} Entity {entity}")

    if isinstance(entity, TaskSpec) or isinstance(entity, WorkflowSpec):
        # 1. Check if the size of the long description exceeds the limit.
        # 2. Extract the repo URL from the git config, and assign it to the source code link of the description
        #    entity.
        if entity.docs and entity.docs.long_description:
            if entity.docs.long_description.value:
                if sys.getsizeof(entity.docs.long_description.value) > 16 * 1024 * 1024:
                    raise ValueError(
                        "Long Description of the flyte entity exceeds the 16KB size limit. Please specify the uri in the long description instead."
                    )
            entity.docs.source_code = SourceCode(link=settings.git_repo)

    # This needs to be at the bottom, not the top - i.e. dependent tasks get added before the workflow containing them
    entity_mapping[entity] = cp_entity
    return cp_entity
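

# Illustrative sketch (added for this listing): because dependencies are appended to
# ``entity_mapping`` before their parents, iterating the mapping yields task specs before the
# workflow spec that uses them.
def _example_registration_order(wf: WorkflowBase, ss: SerializationSettings) -> List[FlyteControlPlaneEntity]:
    entity_mapping: OrderedDict = OrderedDict()
    get_serializable(entity_mapping, ss, wf)
    return list(entity_mapping.values())  # task specs first, the containing workflow spec after them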


def gather_dependent_entities(
    serialized: OrderedDict,
) -> Tuple[
    Dict[_identifier_model.Identifier, TaskTemplate],
    Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec],
    Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec],
]:
    """
    The ``get_serializable`` function above takes in an ``OrderedDict`` that helps keep track of dependent entities.
    For example, when serializing a workflow, all of its tasks are also serialized. The ordered dict will also
    contain serialized entities that aren't as useful, though, like nodes and branches. This is just a small helper
    function that will pull out the serialized tasks, workflows, and launch plans. This function is primarily used
    for testing.

    :param serialized: This should be the filled-in OrderedDict used in the get_serializable function above.
    :return:
    """
    task_templates: Dict[_identifier_model.Identifier, TaskTemplate] = {}
    workflow_specs: Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec] = {}
    launch_plan_specs: Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec] = {}

    for cp_entity in serialized.values():
        if isinstance(cp_entity, TaskSpec):
            task_templates[cp_entity.template.id] = cp_entity.template
        elif isinstance(cp_entity, _launch_plan_models.LaunchPlan):
            launch_plan_specs[cp_entity.id] = cp_entity.spec
        elif isinstance(cp_entity, admin_workflow_models.WorkflowSpec):
            workflow_specs[cp_entity.template.id] = cp_entity

    return task_templates, workflow_specs, launch_plan_specs
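

# Illustrative sketch (added for this listing): how a test might use the helper above to split
# the filled-in mapping into its three entity dictionaries.
def _example_gather(wf: WorkflowBase, ss: SerializationSettings) -> None:
    entity_mapping: OrderedDict = OrderedDict()
    get_serializable(entity_mapping, ss, wf)
    task_templates, workflow_specs, launch_plan_specs = gather_dependent_entities(entity_mapping)
    for task_id in task_templates:
        print(task_id.name)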