"""
This plugin adds the capability of running distributed MPI training to Flyte using backend plugins, natively on
Kubernetes. It leverages the `MPI Job <https://github.com/kubeflow/mpi-operator>`_ plugin from Kubeflow.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Union
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task
from google.protobuf.json_format import MessageToDict
from flytekit import PythonFunctionTask, Resources
from flytekit.configuration import SerializationSettings
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.extend import TaskPlugins
class RestartPolicy(Enum):
"""
RestartPolicy describes how the replicas should be restarted
"""
ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS
FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE
NEVER = kubeflow_common.RESTART_POLICY_NEVER
class CleanPodPolicy(Enum):
"""
CleanPodPolicy describes how to deal with pods when the job is finished.
"""
NONE = kubeflow_common.CLEANPOD_POLICY_NONE
ALL = kubeflow_common.CLEANPOD_POLICY_ALL
RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING
@dataclass
class RunPolicy:
"""
RunPolicy describes some policy to apply to the execution of a kubeflow job.
Args:
        clean_pod_policy: Defines the policy for cleaning up pods after the job completes. Defaults to None.
        ttl_seconds_after_finished (int): Defines the TTL (in seconds) for cleaning up finished jobs.
        active_deadline_seconds (int): Specifies the duration (in seconds) since startTime during which the job
            can remain active before it is terminated. Must be a positive integer. This setting applies only to pods
            where restartPolicy is OnFailure or Always.
backoff_limit (int): Number of retries before marking this job as failed.
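
    Example:
        A minimal sketch (values are illustrative)::

            run_policy = RunPolicy(
                clean_pod_policy=CleanPodPolicy.ALL,
                ttl_seconds_after_finished=600,
                backoff_limit=3,
            )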
"""
    clean_pod_policy: Optional[CleanPodPolicy] = None
ttl_seconds_after_finished: Optional[int] = None
active_deadline_seconds: Optional[int] = None
backoff_limit: Optional[int] = None
@dataclass
class Worker:
"""
    Worker replica configuration. The worker command can be customized; if not specified, the worker uses the
    default command generated by the MPI operator.
"""
command: Optional[List[str]] = None
image: Optional[str] = None
requests: Optional[Resources] = None
limits: Optional[Resources] = None
replicas: Optional[int] = None
restart_policy: Optional[RestartPolicy] = None
@dataclass
class Launcher:
"""
    Launcher replica configuration. The launcher command can be customized; if not specified, the launcher uses
    the command specified in the task signature.
"""
command: Optional[List[str]] = None
image: Optional[str] = None
requests: Optional[Resources] = None
limits: Optional[Resources] = None
replicas: Optional[int] = None
restart_policy: Optional[RestartPolicy] = None
@dataclass
class MPIJob(object):
"""
    Configuration for an executable `MPI Job <https://github.com/kubeflow/mpi-operator>`_. Use this
    to run distributed training on k8s with MPI.
Args:
launcher: Configuration for the launcher replica group.
worker: Configuration for the worker replica group.
run_policy: Configuration for the run policy.
slots: The number of slots per worker used in the hostfile.
        num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. Please use launcher.replicas instead.
        num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job. Please use worker.replicas instead.
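
    Example:
        A minimal sketch of attaching this config to a Flyte task (replica counts, slots,
        and the task body are illustrative)::

            from flytekit import task

            @task(
                task_config=MPIJob(
                    launcher=Launcher(replicas=1),
                    worker=Worker(replicas=2),
                    slots=2,
                ),
            )
            def train() -> None:
                ...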
"""
    launcher: Launcher = field(default_factory=Launcher)
    worker: Worker = field(default_factory=Worker)
    run_policy: Optional[RunPolicy] = None
slots: int = 1
# Support v0 config for backwards compatibility
num_launcher_replicas: Optional[int] = None
num_workers: Optional[int] = None
class MPIFunctionTask(PythonFunctionTask[MPIJob]):
"""
    Plugin that submits an MPIJob (see https://github.com/kubeflow/mpi-operator)
    defined by the code within the _task_function to a k8s cluster.
"""
_MPI_JOB_TASK_TYPE = "mpi"
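    # Flags prepended to every generated `mpirun` command: allow running as root inside the
    # container, disable CPU binding, map ranks by slot, export PATH/LD_LIBRARY_PATH and
    # NCCL debug logging to remote ranks, and select the ob1 PML while excluding the legacy
    # openib BTL for portability.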
_MPI_BASE_COMMAND = [
"mpirun",
"--allow-run-as-root",
"-bind-to",
"none",
"-map-by",
"slot",
"-x",
"LD_LIBRARY_PATH",
"-x",
"PATH",
"-x",
"NCCL_DEBUG=INFO",
"-mca",
"pml",
"ob1",
"-mca",
"btl",
"^openib",
]
def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs):
        if task_config.num_workers and task_config.worker.replicas:
            raise ValueError(
                "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is deprecated."
            )
        if task_config.num_workers is None and task_config.worker.replicas is None:
            raise ValueError(
                "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is deprecated."
            )
        if task_config.num_launcher_replicas and task_config.launcher.replicas:
            raise ValueError(
                "Cannot specify both `num_launcher_replicas` and `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is deprecated."
            )
        if task_config.num_launcher_replicas is None and task_config.launcher.replicas is None:
            raise ValueError(
                "Must specify either `num_launcher_replicas` or `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is deprecated."
            )
super().__init__(
task_config=task_config,
task_function=task_function,
task_type=self._MPI_JOB_TASK_TYPE,
# task_type_version controls the version of the task template, do not change
task_type_version=1,
**kwargs,
)
def _convert_replica_spec(
self, replica_config: Union[Launcher, Worker]
) -> mpi_task.DistributedMPITrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
return mpi_task.DistributedMPITrainingReplicaSpec(
command=replica_config.command,
replicas=replica_config.replicas,
image=replica_config.image,
resources=resources.to_flyte_idl() if resources else None,
restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None,
)
def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
return kubeflow_common.RunPolicy(
clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None,
ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished,
active_deadline_seconds=run_policy.active_deadline_seconds,
backoff_limit=run_policy.backoff_limit,
)
def _get_base_command(self, settings: SerializationSettings) -> List[str]:
return super().get_command(settings)
def get_command(self, settings: SerializationSettings) -> List[str]:
cmd = self._get_base_command(settings)
if self.task_config.num_workers:
num_workers = self.task_config.num_workers
else:
num_workers = self.task_config.worker.replicas
num_procs = num_workers * self.task_config.slots
mpi_cmd = self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] + ["python", settings.entrypoint_settings.path] + cmd
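        # Illustrative result for worker.replicas=2 and slots=2: the mpirun flags above,
        # followed by `-np 4 python <entrypoint path> <base command>`, where the base
        # command is whatever the parent class generates for this task.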
# the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile
return mpi_cmd
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
worker = self._convert_replica_spec(self.task_config.worker)
if self.task_config.num_workers:
worker.replicas = self.task_config.num_workers
launcher = self._convert_replica_spec(self.task_config.launcher)
if self.task_config.num_launcher_replicas:
launcher.replicas = self.task_config.num_launcher_replicas
run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
mpi_job = mpi_task.DistributedMPITrainingTask(
worker_replicas=worker,
launcher_replicas=launcher,
slots=self.task_config.slots,
run_policy=run_policy,
)
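        # MessageToDict applies protobuf's JSON mapping, so the returned dict uses camelCase
        # keys, e.g. {"workerReplicas": {...}, "launcherReplicas": {...}, "slots": 2}.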
return MessageToDict(mpi_job)
@dataclass
class HorovodJob(object):
"""
    Configuration for an executable `Horovod Job using MPI operator <https://github.com/kubeflow/mpi-operator>`_.
    Use this to run distributed training on k8s with MPI. For more info, check out
    `Running Horovod <https://horovod.readthedocs.io/en/stable/summary_include.html#running-horovod>`_.
Args:
worker: Worker configuration for the job.
launcher: Launcher configuration for the job.
run_policy: Configuration for the run policy.
slots: Number of slots per worker used in hostfile (default: 1).
verbose: Optional flag indicating whether to enable verbose logging (default: False).
log_level: Optional string specifying the log level (default: "INFO").
discovery_script_path: Path to the discovery script used for host discovery (default: "/etc/mpi/discover_hosts.sh").
        num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. Please use launcher.replicas instead.
        num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job. Please use worker.replicas instead.
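
    Example:
        A minimal sketch (replica counts and slots are illustrative)::

            from flytekit import task

            @task(
                task_config=HorovodJob(
                    launcher=Launcher(replicas=1),
                    worker=Worker(replicas=2),
                    slots=2,
                ),
            )
            def train() -> None:
                ...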
"""
    worker: Worker = field(default_factory=Worker)
    launcher: Launcher = field(default_factory=Launcher)
    run_policy: Optional[RunPolicy] = None
slots: int = 1
verbose: Optional[bool] = False
log_level: Optional[str] = "INFO"
discovery_script_path: Optional[str] = "/etc/mpi/discover_hosts.sh"
# Support v0 config for backwards compatibility
num_launcher_replicas: Optional[int] = None
num_workers: Optional[int] = None
class HorovodFunctionTask(MPIFunctionTask):
"""
For more info, check out https://github.com/horovod/horovod
"""
    # Customize your setup here. Please ensure the command, paths, volumes, etc. are available in the pod.
def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs):
super().__init__(
task_config=task_config,
task_function=task_function,
**kwargs,
)
def get_command(self, settings: SerializationSettings) -> List[str]:
cmd = self._get_base_command(settings)
mpi_cmd = self._get_horovod_prefix() + cmd
return mpi_cmd
def _get_horovod_prefix(self) -> List[str]:
        # Fall back to the deprecated `num_workers` when `worker.replicas` is unset;
        # the base class validation guarantees that exactly one of them is provided.
        num_workers = self.task_config.worker.replicas or self.task_config.num_workers
        np = num_workers * self.task_config.slots
log_level = self.task_config.log_level
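        # Illustrative prefix for worker.replicas=2, slots=2 and default settings:
        #   horovodrun -np 4 --log-level INFO --network-interface eth0 --min-np 4
        #     --max-np 4 --slots-per-host 2 --host-discovery-script /etc/mpi/discover_hosts.sh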
base_cmd = [
"horovodrun",
"-np",
f"{np}",
"--log-level",
f"{log_level}",
"--network-interface",
"eth0",
"--min-np",
f"{np}",
"--max-np",
f"{np}",
"--slots-per-host",
f"{self.task_config.slots}",
"--host-discovery-script",
self.task_config.discovery_script_path,
]
if self.task_config.verbose:
base_cmd.append("--verbose")
return base_cmd
# Register the MPI Plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask)
TaskPlugins.register_pythontask_plugin(HorovodJob, HorovodFunctionTask)
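# flytekit's @task decorator dispatches on the type of `task_config`, so tasks configured
# with an MPIJob or HorovodJob instance are executed via the corresponding task classes above.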