Source code for flytekit.core.container_task

import typing
from enum import Enum
from typing import Any, Dict, List, Optional, OrderedDict, Type

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.context_manager import FlyteContext
from flytekit.core.interface import Interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext

_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"


[docs]class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. This is the vast majority of tasks - the typical ``@task`` decorated tasks for instance all run a container. An example of something that doesn't run a container would be something like the Athena SQL task. """ class MetadataFormat(Enum): JSON = _task_model.DataLoadingConfig.LITERALMAP_FORMAT_JSON YAML = _task_model.DataLoadingConfig.LITERALMAP_FORMAT_YAML PROTO = _task_model.DataLoadingConfig.LITERALMAP_FORMAT_PROTO class IOStrategy(Enum): DOWNLOAD_EAGER = _task_model.IOStrategy.DOWNLOAD_MODE_EAGER DOWNLOAD_STREAM = _task_model.IOStrategy.DOWNLOAD_MODE_STREAM DO_NOT_DOWNLOAD = _task_model.IOStrategy.DOWNLOAD_MODE_NO_DOWNLOAD UPLOAD_EAGER = _task_model.IOStrategy.UPLOAD_MODE_EAGER UPLOAD_ON_EXIT = _task_model.IOStrategy.UPLOAD_MODE_ON_EXIT DO_NOT_UPLOAD = _task_model.IOStrategy.UPLOAD_MODE_NO_UPLOAD def __init__( self, name: str, image: typing.Union[str, ImageSpec], command: List[str], inputs: Optional[OrderedDict[str, Type]] = None, metadata: Optional[TaskMetadata] = None, arguments: Optional[List[str]] = None, outputs: Optional[Dict[str, Type]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, input_data_dir: Optional[str] = None, output_data_dir: Optional[str] = None, metadata_format: MetadataFormat = MetadataFormat.JSON, io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, **kwargs, ): sec_ctx = None if secret_requests: for s in secret_requests: if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) # pod_template_name overwrites the metadata.pod_template_name metadata = metadata or TaskMetadata() metadata.pod_template_name = pod_template_name super().__init__( task_type="raw-container", name=name, interface=Interface(inputs, outputs), metadata=metadata, task_config=None, security_ctx=sec_ctx, **kwargs, ) self._image = image self._cmd = command self._args = arguments self._input_data_dir = input_data_dir self._output_data_dir = output_data_dir self._md_format = metadata_format self._io_strategy = io_strategy self._resources = ResourceSpec( requests=requests if requests else Resources(), limits=limits if limits else Resources() ) self.pod_template = pod_template @property def resources(self) -> ResourceSpec: return self._resources
[docs] def local_execute(self, ctx: FlyteContext, **kwargs) -> Any: raise RuntimeError("ContainerTask is not supported in local executions.")
[docs] def get_container(self, settings: SerializationSettings) -> _task_model.Container: # if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container if self.pod_template is not None: return None return self._get_container(settings)
def _get_data_loading_config(self) -> _task_model.DataLoadingConfig: return _task_model.DataLoadingConfig( input_path=self._input_data_dir, output_path=self._output_data_dir, format=self._md_format.value, enabled=True, io_strategy=self._io_strategy.value if self._io_strategy else None, ) def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = settings.env or {} env = {**env, **self.environment} if self.environment else env if isinstance(self._image, ImageSpec): if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: self._image.source_root = settings.source_root return _get_container_definition( image=get_registerable_container_image(self._image, settings.image_config), command=self._cmd, args=self._args, data_loading_config=self._get_data_loading_config(), environment=env, ephemeral_storage_request=self.resources.requests.ephemeral_storage, cpu_request=self.resources.requests.cpu, gpu_request=self.resources.requests.gpu, memory_request=self.resources.requests.mem, ephemeral_storage_limit=self.resources.limits.ephemeral_storage, cpu_limit=self.resources.limits.cpu, gpu_limit=self.resources.limits.gpu, memory_limit=self.resources.limits.mem, )
[docs] def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: return None return _task_model.K8sPod( pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings), settings), metadata=_task_model.K8sObjectMetadata( labels=self.pod_template.labels, annotations=self.pod_template.annotations, ), data_config=self._get_data_loading_config(), )
[docs] def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: if self.pod_template is None: return {} return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name}