# Source code for flytekitplugins.kftensorflow.task

"""
This Plugin adds the capability of running distributed tensorflow training to Flyte using backend plugins, natively on
Kubernetes. It leverages `TF Job <https://github.com/kubeflow/tf-operator>`_ Plugin from kubeflow.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union

from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task
from google.protobuf.json_format import MessageToDict

from flytekit import PythonFunctionTask, Resources
from flytekit.configuration import SerializationSettings
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.extend import TaskPlugins


class RestartPolicy(Enum):
    """
    RestartPolicy describes how the replicas should be restarted.

    Each member's value is the matching ``RESTART_POLICY_*`` constant from the flyteidl kubeflow
    ``common_pb2`` module, so ``member.value`` can be handed directly to the protobuf messages.
    """

    # NOTE: deliberately not decorated with ``@dataclass``. The Python enum documentation states that
    # applying ``dataclass()`` to an Enum is unsupported: the generated ``__eq__`` makes distinct
    # members compare equal to each other.
    ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS
    FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE
    NEVER = kubeflow_common.RESTART_POLICY_NEVER


class CleanPodPolicy(Enum):
    """
    CleanPodPolicy describes how to deal with pods when the job is finished.

    Each member's value is the matching ``CLEANPOD_POLICY_*`` constant from the flyteidl kubeflow
    ``common_pb2`` module, so ``member.value`` can be handed directly to the protobuf messages.
    """

    # NOTE: deliberately not decorated with ``@dataclass``. The Python enum documentation states that
    # applying ``dataclass()`` to an Enum is unsupported: the generated ``__eq__`` makes distinct
    # members compare equal to each other.
    NONE = kubeflow_common.CLEANPOD_POLICY_NONE
    ALL = kubeflow_common.CLEANPOD_POLICY_ALL
    RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING


@dataclass
class RunPolicy:
    """
    RunPolicy describes a set of policies to apply to the execution of a Kubeflow job.

    Every field is optional and defaults to ``None``.

    Args:
        clean_pod_policy: The policy for cleaning up pods after the job completes. Defaults to None.
        ttl_seconds_after_finished: The time-to-live (TTL) in seconds for cleaning up finished jobs.
        active_deadline_seconds: The duration (in seconds) since startTime during which the job can remain
            active before it is terminated. Must be a positive integer. This setting applies only to pods
            where restartPolicy is OnFailure or Always.
        backoff_limit: The number of retries before marking this job as failed.
    """

    # Annotated Optional because the default (and documented) value is None.
    clean_pod_policy: Optional[CleanPodPolicy] = None
    ttl_seconds_after_finished: Optional[int] = None
    active_deadline_seconds: Optional[int] = None
    backoff_limit: Optional[int] = None


@dataclass
class Chief:
    """
    Configuration for the chief replica group of a TensorFlow job.

    Every field is optional and defaults to ``None``.

    Args:
        image: Image to use for the replicas of this group.
        requests: Resource requests for this group; passed to ``convert_resources_to_resource_model``
            when the replica spec is built.
        limits: Resource limits for this group; passed to ``convert_resources_to_resource_model``
            when the replica spec is built.
        replicas: Number of replicas in this group.
        restart_policy: How the replicas of this group should be restarted.
    """

    image: Optional[str] = None
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    replicas: Optional[int] = None
    restart_policy: Optional[RestartPolicy] = None


@dataclass
class PS:
    """
    Configuration for the parameter server (PS) replica group of a TensorFlow job.

    Every field is optional and defaults to ``None``.

    Args:
        image: Image to use for the replicas of this group.
        requests: Resource requests for this group; passed to ``convert_resources_to_resource_model``
            when the replica spec is built.
        limits: Resource limits for this group; passed to ``convert_resources_to_resource_model``
            when the replica spec is built.
        replicas: Number of replicas in this group.
        restart_policy: How the replicas of this group should be restarted.
    """

    image: Optional[str] = None
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    replicas: Optional[int] = None
    restart_policy: Optional[RestartPolicy] = None


@dataclass
class Worker:
    """
    Configuration for the worker replica group of a TensorFlow job.

    Every field is optional and defaults to ``None``.

    Args:
        image: Image to use for the replicas of this group.
        requests: Resource requests for this group; passed to ``convert_resources_to_resource_model``
            when the replica spec is built.
        limits: Resource limits for this group; passed to ``convert_resources_to_resource_model``
            when the replica spec is built.
        replicas: Number of replicas in this group.
        restart_policy: How the replicas of this group should be restarted.
    """

    image: Optional[str] = None
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    replicas: Optional[int] = None
    restart_policy: Optional[RestartPolicy] = None


@dataclass
class Evaluator:
    """
    Configuration for the evaluator replica group of a TensorFlow job.

    Unlike the other replica group configurations in this module, ``replicas`` defaults to ``0``
    rather than ``None``; the remaining fields default to ``None``.

    Args:
        image: Image to use for the replicas of this group.
        requests: Resource requests for this group; passed to ``convert_resources_to_resource_model``
            when the replica spec is built.
        limits: Resource limits for this group; passed to ``convert_resources_to_resource_model``
            when the replica spec is built.
        replicas: Number of replicas in this group. Defaults to 0.
        restart_policy: How the replicas of this group should be restarted.
    """

    image: Optional[str] = None
    requests: Optional[Resources] = None
    limits: Optional[Resources] = None
    replicas: int = 0
    restart_policy: Optional[RestartPolicy] = None


@dataclass
class TfJob:
    """
    Configuration for an executable `TensorFlow Job <https://github.com/kubeflow/tf-operator>`_. Use this
    to run distributed TensorFlow training on Kubernetes.

    Args:
        chief: Configuration for the chief replica group.
        ps: Configuration for the parameter server (PS) replica group.
        worker: Configuration for the worker replica group.
        evaluator: Configuration for the evaluator replica group.
        run_policy: Configuration for the run policy.
        num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead.
        num_ps_replicas: [DEPRECATED] This argument is deprecated. Use `ps.replicas` instead.
        num_chief_replicas: [DEPRECATED] This argument is deprecated. Use `chief.replicas` instead.
        num_evaluator_replicas: [DEPRECATED] This argument is deprecated. Use `evaluator.replicas` instead.
    """

    # Each instance gets its own fresh replica-group config object.
    chief: Chief = field(default_factory=Chief)
    ps: PS = field(default_factory=PS)
    worker: Worker = field(default_factory=Worker)
    evaluator: Evaluator = field(default_factory=Evaluator)
    run_policy: Optional[RunPolicy] = None
    # Support v0 config for backwards compatibility
    num_workers: Optional[int] = None
    num_ps_replicas: Optional[int] = None
    num_chief_replicas: Optional[int] = None
    num_evaluator_replicas: Optional[int] = None
class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
    """
    Plugin that submits a TFJob (see https://github.com/kubeflow/tf-operator)
    defined by the code within the _task_function to k8s cluster.
    """

    _TF_JOB_TASK_TYPE = "tensorflow"

    def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
        """
        Validate the replica counts in ``task_config`` and initialise the underlying task.

        Args:
            task_config: The ``TfJob`` configuration for this task.
            task_function: The function wrapped by this task.
            **kwargs: Forwarded unchanged to the parent class constructor.

        Raises:
            ValueError: If a deprecated ``num_*`` field and the matching ``<group>.replicas`` are both
                set, or if neither is set for the worker, chief or ps group.
        """
        # worker, chief and ps share the same two rules; they are checked in this order so the
        # first error reported is the same as with the original hand-written checks.
        for legacy_field, group in (
            ("num_workers", "worker"),
            ("num_chief_replicas", "chief"),
            ("num_ps_replicas", "ps"),
        ):
            legacy_value = getattr(task_config, legacy_field)
            replicas = getattr(task_config, group).replicas
            if legacy_value and replicas:
                raise ValueError(
                    f"Cannot specify both `{legacy_field}` and `{group}.replicas`. "
                    f"Please use `{group}.replicas` as `{legacy_field}` is deprecated."
                )
            if legacy_value is None and replicas is None:
                raise ValueError(
                    f"Must specify either `{legacy_field}` or `{group}.replicas`. "
                    f"Please use `{group}.replicas` as `{legacy_field}` is deprecated."
                )

        # The evaluator group has no "must specify" rule: its replicas field defaults to 0.
        if task_config.num_evaluator_replicas and task_config.evaluator.replicas > 0:
            raise ValueError(
                "Cannot specify both `num_evaluator_replicas` and `evaluator.replicas`. "
                "Please use `evaluator.replicas` as `num_evaluator_replicas` is deprecated."
            )

        super().__init__(
            task_type=self._TF_JOB_TASK_TYPE,
            task_config=task_config,
            task_function=task_function,
            task_type_version=1,
            **kwargs,
        )

    def _convert_replica_spec(
        self, replica_config: Union[Chief, PS, Worker, Evaluator]
    ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
        """
        Build the protobuf replica spec for one replica group configuration.

        Args:
            replica_config: The replica group configuration to convert.

        Returns:
            A ``DistributedTensorflowTrainingReplicaSpec`` carrying the group's replicas, image,
            resources and restart policy; resources and restart policy are ``None`` when not set.
        """
        resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
        return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
            replicas=replica_config.replicas,
            image=replica_config.image,
            resources=resources.to_flyte_idl() if resources else None,
            restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
        )

    def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
        """
        Build the protobuf run policy from a ``RunPolicy`` configuration.

        Args:
            run_policy: The run policy configuration to convert.

        Returns:
            A ``kubeflow_common.RunPolicy`` message populated from ``run_policy``.
        """
        policy = run_policy.clean_pod_policy
        # ``clean_pod_policy`` defaults to None, so it must be checked before reading ``.value``;
        # the original code dereferenced it unconditionally and raised AttributeError.
        policy_value = policy.value if policy is not None else None
        return kubeflow_common.RunPolicy(
            # A falsy enum value is passed as None, exactly as the original expression did.
            clean_pod_policy=policy_value or None,
            ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
            active_deadline_seconds=run_policy.active_deadline_seconds,
            backoff_limit=run_policy.backoff_limit,
        )

    def _replica_spec_with_legacy_count(
        self, replica_config: Union[Chief, PS, Worker, Evaluator], legacy_replicas: Optional[int]
    ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
        """Convert ``replica_config`` and, when the deprecated count is truthy, use it as the replica count."""
        spec = self._convert_replica_spec(replica_config)
        if legacy_replicas:
            spec.replicas = legacy_replicas
        return spec

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """
        Build the plugin-specific configuration for this task.

        Args:
            settings: Serialization settings; not used by this implementation.

        Returns:
            The ``DistributedTensorflowTrainingTask`` message converted to a dict with ``MessageToDict``.
        """
        config = self.task_config
        chief = self._replica_spec_with_legacy_count(config.chief, config.num_chief_replicas)
        worker = self._replica_spec_with_legacy_count(config.worker, config.num_workers)
        ps = self._replica_spec_with_legacy_count(config.ps, config.num_ps_replicas)
        evaluator = self._replica_spec_with_legacy_count(config.evaluator, config.num_evaluator_replicas)
        run_policy = self._convert_run_policy(config.run_policy) if config.run_policy else None

        training_task = tensorflow_task.DistributedTensorflowTrainingTask(
            chief_replicas=chief,
            worker_replicas=worker,
            ps_replicas=ps,
            evaluator_replicas=evaluator,
            run_policy=run_policy,
        )
        return MessageToDict(training_task)


# Register the Tensorflow Plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(TfJob, TensorflowFunctionTask)