from __future__ import annotations
import importlib
import re
from abc import ABC
from typing import Callable, Dict, List, Optional, TypeVar, Union
from flyteidl.core import tasks_pb2
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.security import Secret, SecurityContext
# Type variable for the task-specific configuration object passed as ``task_config``.
T = TypeVar("T")
# Key under which ``get_config`` reports the pod template's primary container name.
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
[docs]
class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC):
    """
    A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the
    container and the container information to be automatically captured.
    This base will auto configure the image and image version to be used for all its derivatives.

    If you are looking to extend, you might prefer to use ``PythonFunctionTask`` or ``PythonInstanceTask``
    """

    def __init__(
        self,
        name: str,
        task_config: T,
        task_type="python-task",
        container_image: Optional[Union[str, ImageSpec]] = None,
        requests: Optional[Resources] = None,
        limits: Optional[Resources] = None,
        environment: Optional[Dict[str, str]] = None,
        task_resolver: Optional[TaskResolverMixin] = None,
        secret_requests: Optional[List[Secret]] = None,
        pod_template: Optional[PodTemplate] = None,
        pod_template_name: Optional[str] = None,
        accelerator: Optional[BaseAccelerator] = None,
        **kwargs,
    ):
        """
        :param name: unique name for the task, usually the function's module and name.
        :param task_config: Configuration object for Task. Should be a unique type for that specific Task.
        :param task_type: String task type to be associated with this Task
        :param container_image: String FQN for the image (placeholders such as
            ``{{.image.default.fqn}}`` are resolved at serialization time), or an ``ImageSpec``.
        :param requests: custom resource request settings.
        :param limits: custom resource limit settings.
        :param environment: Environment variables you want the task to have when run.
        :param task_resolver: Custom resolver - will pick up the default resolver if empty, or the resolver set
            in the compilation context if one is set.
        :param List[Secret] secret_requests: Secrets that are requested by this container execution. These secrets will
            be mounted based on the configuration in the Secret and available through
            the SecretManager using the name of the secret as the group
            Ideally the secret keys should also be semi-descriptive.
            The key values will be available from runtime, if the backend is configured
            to provide secrets and if secrets are available in the configured secrets store.
            Possible options for secret stores are

            - `Vault <https://www.vaultproject.io/>`__
            - `Confidant <https://lyft.github.io/confidant/>`__
            - `Kube secrets <https://kubernetes.io/docs/concepts/configuration/secret/>`__
            - `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
        :param pod_template: Custom PodTemplate for this task.
        :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
        :param accelerator: The accelerator to use for this task.
        :param kwargs: Forwarded to the parent constructor. A ``metadata`` entry, if present, has its
            ``pod_template_name`` overwritten with the ``pod_template_name`` argument.
        :raises AssertionError: if any entry of ``secret_requests`` is not a ``Secret``.
        """
        sec_ctx = None
        if secret_requests:
            for s in secret_requests:
                if not isinstance(s, Secret):
                    raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}")
            sec_ctx = SecurityContext(secrets=secret_requests)

        # pod_template_name overwrites the metadata.pod_template_name.
        # An explicit ``metadata=None`` is treated the same as an absent one instead of crashing below.
        metadata = kwargs.get("metadata")
        if metadata is None:
            metadata = TaskMetadata()
        metadata.pod_template_name = pod_template_name
        kwargs["metadata"] = metadata

        super().__init__(
            task_type=task_type,
            name=name,
            task_config=task_config,
            security_ctx=sec_ctx,
            **kwargs,
        )
        self._container_image = container_image
        # TODO(katrogan): Implement resource overrides
        self._resources = ResourceSpec(
            requests=requests if requests else Resources(), limits=limits if limits else Resources()
        )
        self._environment = environment or {}

        # A resolver set on the active compilation context takes precedence over the one passed in.
        compilation_state = FlyteContextManager.current_context().compilation_state
        if compilation_state and compilation_state.task_resolver:
            if task_resolver:
                logger.info(
                    f"Not using the passed in task resolver {task_resolver} because one found in compilation context"
                )
            self._task_resolver = compilation_state.task_resolver
            # Ask the resolver once; a non-None answer replaces the task name.
            resolved_name = self._task_resolver.task_name(self)
            if resolved_name is not None:
                self._name = resolved_name or ""
        else:
            self._task_resolver = task_resolver or default_task_resolver

        self._get_command_fn = self.get_default_command
        self.pod_template = pod_template
        self.accelerator = accelerator

    @property
    def task_resolver(self) -> TaskResolverMixin:
        """The resolver chosen for this task during construction."""
        return self._task_resolver

    @property
    def container_image(self) -> Optional[Union[str, ImageSpec]]:
        """The image (string or ``ImageSpec``) given at construction, or None."""
        return self._container_image

    @property
    def resources(self) -> ResourceSpec:
        """The request/limit pair built from the constructor arguments."""
        return self._resources

    def get_default_command(self, settings: SerializationSettings) -> List[str]:
        """
        Returns the default pyflyte-execute command used to run this on hosted Flyte platforms.
        """
        container_args = [
            "pyflyte-execute",
            "--inputs",
            "{{.input}}",
            "--output-prefix",
            "{{.outputPrefix}}",
            "--raw-output-data-prefix",
            "{{.rawOutputDataPrefix}}",
            "--checkpoint-path",
            "{{.checkpointOutputPrefix}}",
            "--prev-checkpoint",
            "{{.prevCheckpointPrefix}}",
            "--resolver",
            self.task_resolver.location,
            "--",
            *self.task_resolver.loader_args(settings, self),
        ]
        return container_args

    def set_command_fn(self, get_command_fn: Optional[Callable[[SerializationSettings], List[str]]] = None):
        """
        By default, the task will run on the Flyte platform using the pyflyte-execute command.
        However, it can be useful to update the command with which the task is serialized for specific cases like
        running map tasks ("pyflyte-map-execute") or for fast-executed tasks.

        :param get_command_fn: Callable producing the command for given serialization settings. Passing None
            (the default) restores the default command rather than leaving the task without a command function.
        """
        if get_command_fn is None:
            self._get_command_fn = self.get_default_command
        else:
            self._get_command_fn = get_command_fn

    def reset_command_fn(self):
        """
        Resets the command which should be used in the container definition of this task to the default arguments.
        This is useful when the command line is overridden at serialization time.
        """
        self._get_command_fn = self.get_default_command

    def get_command(self, settings: SerializationSettings) -> List[str]:
        """
        Returns the command which should be used in the container definition for the serialized version of this task
        registered on a hosted Flyte platform.
        """
        return self._get_command_fn(settings)

    def get_image(self, settings: SerializationSettings) -> str:
        """
        Resolve the image name this task should be registered with.

        When fast serialization is absent or disabled and the configured image is an ``ImageSpec``, the spec's
        ``source_root`` is set from ``settings.source_root`` before resolution.
        """
        if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled:
            if isinstance(self.container_image, ImageSpec):
                # Set the source root for the image spec if it's non-fast registration
                self.container_image.source_root = settings.source_root
        return get_registerable_container_image(self.container_image, settings.image_config)

    def get_container(self, settings: SerializationSettings) -> Optional[_task_model.Container]:
        """Return the container definition, or None when a pod template is configured."""
        # if pod_template is not None, return None here but in get_k8s_pod, return pod_template merged with container
        if self.pod_template is not None:
            return None
        else:
            return self._get_container(settings)

    def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
        """Build the container definition from the image, command, environment and resources."""
        # Task-level variables are applied last, so they win over the serialization settings on key clashes.
        env = {}
        for elem in (settings.env, self.environment):
            if elem:
                env.update(elem)
        return _get_container_definition(
            image=self.get_image(settings),
            command=[],
            args=self.get_command(settings=settings),
            data_loading_config=None,
            environment=env,
            ephemeral_storage_request=self.resources.requests.ephemeral_storage,
            cpu_request=self.resources.requests.cpu,
            gpu_request=self.resources.requests.gpu,
            memory_request=self.resources.requests.mem,
            ephemeral_storage_limit=self.resources.limits.ephemeral_storage,
            cpu_limit=self.resources.limits.cpu,
            gpu_limit=self.resources.limits.gpu,
            memory_limit=self.resources.limits.mem,
        )

    def get_k8s_pod(self, settings: SerializationSettings) -> Optional[_task_model.K8sPod]:
        """Return the pod definition built from the pod template, or None when no template is configured."""
        if self.pod_template is None:
            return None
        return _task_model.K8sPod(
            pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings), settings),
            metadata=_task_model.K8sObjectMetadata(
                labels=self.pod_template.labels,
                annotations=self.pod_template.annotations,
            ),
        )

    # need to call super in all its children tasks
    def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]:
        """Return the primary container name entry when a pod template is configured, else an empty dict."""
        if self.pod_template is None:
            return {}
        return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name}

    def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
        """
        Returns the extended resources to allocate to the task on hosted Flyte.
        """
        if self.accelerator is None:
            return None
        return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl())
class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
    """
    Resolver that locates a task by importing its module and reading the task off it by name.

    Please see the notes in the TaskResolverMixin as it describes this default behavior.
    """

    def name(self) -> str:
        """Return the fixed display name of this resolver."""
        return "DefaultTaskResolver"

    @timeit("Load task")
    def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
        """
        Import the module named in ``loader_args`` and return the attribute named there.

        ``loader_args`` is read positionally as ``[<label>, module, <label>, task name, ...]``; anything after
        the fourth element is ignored.
        """
        _, module_path, _, attribute, *_ = loader_args
        loaded_module = importlib.import_module(name=module_path)  # type: ignore
        return getattr(loaded_module, attribute)

    def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]:  # type:ignore
        """Return the labelled module / task-name pair that ``load_task`` consumes."""
        _, module_path, task_name, _ = extract_task_module(task)
        return ["task-module", module_path, "task-name", task_name]

    def get_all_tasks(self) -> List[PythonAutoContainerTask]:  # type: ignore
        """Not supported by the default resolver; always raises."""
        raise Exception("should not be needed")
# Shared resolver instance that PythonAutoContainerTask falls back to when neither the compilation
# context nor the caller supplies a task resolver.
default_task_resolver = DefaultTaskResolver()
def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg: ImageConfig) -> str:
    """
    Resolve the image to the real image name that should be used for registration.

    1. If img is a ImageSpec, it will be built and the image name will be returned
    2. If img is a placeholder string (e.g. {{.image.default.fqn}}:{{.image.default.version}}),
       it will be resolved using the cfg and the image name will be returned
    3. If img is a plain string without placeholders, it is returned unchanged
    4. If img is None or empty, the full name of the default image in cfg is returned

    :param img: Configured image or image spec
    :param cfg: Registration configuration
    :return: The image name to register with
    :raises AssertionError: if a placeholder is malformed, names an image missing from cfg, or uses an
        attribute other than ``fqn`` / ``version``
    :raises ValueError: if a default image is needed (no img given, or a version fallback) but cfg has none
    """
    if isinstance(img, ImageSpec):
        ImageBuildEngine.build(img)
        return img.image_name()

    # Nothing configured on the task: fall back to the default image.
    if img is None or img == "":
        if cfg.default_image is None:
            raise ValueError("An image is required for PythonAutoContainer tasks")
        return cfg.default_image.full

    matches = _IMAGE_REPLACE_REGEX.findall(img)
    if not matches:
        # No placeholders, so the string is already a concrete image name.
        return img

    for m in matches:
        if len(m) < 3:
            # Braces are doubled twice in the f-string part so the message shows literal ``{{...}}``
            # placeholders, consistent with the plain-string part above it.
            raise AssertionError(
                "Image specification should be of the form <fqn>:<tag> OR <fqn>:{{.image.default.version}} OR "
                f"{{{{.image.xyz.fqn}}}}:{{{{.image.xyz.version}}}} OR {{{{.image.xyz}}}} - Received {m}"
            )
        replace_group, name, attr = m
        if name is None or name == "":
            raise AssertionError(f"Image format is incorrect {m}")
        img_cfg = cfg.find_image(name)
        if img_cfg is None:
            raise AssertionError(f"Image Config with name {name} not found in the configuration")
        if attr == "version":
            if img_cfg.version is not None:
                img = img.replace(replace_group, img_cfg.version)
            else:
                # The named image carries no version: borrow the default image's, if there is one.
                if cfg.default_image is None:
                    raise ValueError(
                        f"Image Config with name {name} has no version and no default image is configured"
                    )
                img = img.replace(replace_group, cfg.default_image.version)
        elif attr == "fqn":
            img = img.replace(replace_group, img_cfg.fqn)
        elif attr == "":
            img = img.replace(replace_group, img_cfg.full)
        else:
            raise AssertionError(f"Only fqn and version are supported replacements, {attr} is not supported")
    return img
# Placeholder pattern of the form {{.image.<name>.<attr>}}.
#   <name>  'default' refers to the default image passed during serialization; any other value is a custom
#           image name that must be defined in the Images section of the config.
#   <attr>  optional; one of:
#             fqn      -> fully qualified name  (registry/imagename:version -> registry/imagename)
#             version  -> version part          (registry/imagename:version -> version)
#             (absent) -> full image path       (registry/imagename:version -> registry/imagename:version)
# Capture groups, in order: the whole placeholder, <name>, <attr>.
_IMAGE_REPLACE_REGEX = re.compile(
    r"({{\s*\.image[s]?(?:\.([a-zA-Z0-9_]+))(?:\.([a-zA-Z0-9_]+))?\s*}})",
    re.IGNORECASE,
)