# Source code for flytekitplugins.ray.task

import base64
import json
import typing
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import yaml
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
from google.protobuf.json_format import MessageToDict

from flytekit import lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.extend import TaskPlugins

# ``ray`` is obtained through flytekit's ``lazy_module`` helper rather than a plain
# import; presumably this defers loading the package until an attribute such as
# ``ray.init`` is first accessed — TODO confirm against lazy_module's docs.
ray = lazy_module("ray")


@dataclass
class HeadNodeConfig:
    """Settings for the head node of the Ray cluster.

    Attributes:
        ray_start_params: Optional string-to-string mapping of start parameters
            for the head node. ``None`` (the default) means no parameters are set.
    """

    ray_start_params: Optional[Dict[str, str]] = None
@dataclass
class WorkerNodeConfig:
    """Settings for one group of Ray worker nodes.

    Attributes:
        group_name: Name identifying this worker group.
        replicas: Requested number of replicas for the group.
        min_replicas: Optional lower replica bound; ``None`` when unset.
        max_replicas: Optional upper replica bound; ``None`` when unset.
        ray_start_params: Optional string-to-string mapping of start parameters
            for the workers in this group; ``None`` when unset.
    """

    group_name: str
    replicas: int
    # The remaining fields are optional and default to ``None``.
    min_replicas: Optional[int] = None
    max_replicas: Optional[int] = None
    ray_start_params: Optional[Dict[str, str]] = None
@dataclass
class RayJobConfig:
    """Top-level task configuration for a Ray job.

    Attributes:
        worker_node_config: One ``WorkerNodeConfig`` per worker group.
        head_node_config: Optional head-node settings; ``None`` when unset.
        enable_autoscaling: Whether cluster autoscaling is requested.
        runtime_env: Optional Ray runtime environment as a plain dict.
        address: Optional Ray cluster address; ``None`` when unset.
        shutdown_after_job_finishes: Whether the cluster should be torn down
            once the job completes.
        ttl_seconds_after_finished: Optional time-to-live, in seconds, after
            the job has finished.
    """

    worker_node_config: typing.List[WorkerNodeConfig]
    head_node_config: Optional[HeadNodeConfig] = None
    enable_autoscaling: bool = False
    runtime_env: Optional[dict] = None
    address: Optional[str] = None
    shutdown_after_job_finishes: bool = False
    ttl_seconds_after_finished: Optional[int] = None
class RayFunctionTask(PythonFunctionTask):
    """
    Actual Plugin that transforms the local python code for execution within Ray job.
    """

    _RAY_TASK_TYPE = "ray"

    def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
        """Create a Ray task wrapping ``task_function``.

        :param task_config: Ray job configuration attached to this task.
        :param task_function: The user's Python function to run.
        :param kwargs: Forwarded unchanged to ``PythonFunctionTask``.
        """
        super().__init__(
            task_config=task_config,
            task_type=self._RAY_TASK_TYPE,
            task_function=task_function,
            **kwargs,
        )
        self._task_config = task_config

    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        """Connect to Ray (using the configured address) before the task body runs.

        :param user_params: Execution parameters, returned unchanged.
        :return: ``user_params``.
        """
        # ray.init raises RuntimeError when called a second time in the same
        # process (e.g. two Ray tasks executed locally in one interpreter), so
        # only initialize when no Ray session is active yet.
        if not ray.is_initialized():
            ray.init(address=self._task_config.address)
        return user_params

    def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
        """Serialize the task config into a ``RayJob`` IDL message, returned as a dict.

        :param settings: Serialization settings (not used by this implementation).
        :return: ``MessageToDict`` of the ``RayJob`` IDL message.
        """
        cfg = self._task_config

        # Deprecated: runtime_env was removed in KubeRay >= 1.1.0 and replaced by
        # runtime_env_yaml. Both are emitted for compatibility; an empty or
        # missing runtime_env yields None for both.
        runtime_env = base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode() if cfg.runtime_env else None
        runtime_env_yaml = yaml.dump(cfg.runtime_env) if cfg.runtime_env else None

        head_group_spec = HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None
        worker_group_spec = [
            WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
            for c in cfg.worker_node_config
        ]

        ray_job = RayJob(
            ray_cluster=RayCluster(
                head_group_spec=head_group_spec,
                worker_group_spec=worker_group_spec,
                enable_autoscaling=bool(cfg.enable_autoscaling),
            ),
            runtime_env=runtime_env,
            runtime_env_yaml=runtime_env_yaml,
            ttl_seconds_after_finished=cfg.ttl_seconds_after_finished,
            shutdown_after_job_finishes=cfg.shutdown_after_job_finishes,
        )
        return MessageToDict(ray_job.to_flyte_idl())


# Inject the Ray plugin into flytekit's dynamic plugin loading system
TaskPlugins.register_pythontask_plugin(RayJobConfig, RayFunctionTask)