"""Source code for flytekit.extras.tasks.shell."""

import datetime
import os
import platform
import string
import subprocess
import typing
from dataclasses import dataclass
from typing import List

import flytekit
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.interface import Interface
from flytekit.core.python_function_task import PythonInstanceTask
from flytekit.core.task import TaskPlugins
from flytekit.exceptions.user import FlyteRecoverableException
from flytekit.loggers import logger
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile


@dataclass
class ProcessResult:
    """Result of a completed sub-process: return code, standard output and standard error.

    Attributes:
        returncode: The sub-process return code (0 indicates success).
        output: The sub-process standard output, captured as a string.
        error: The sub-process standard error, captured as a string.
    """

    returncode: int
    output: str
    error: str


@dataclass
class OutputLocation:
    """Declares where one task output will be found on the file-system after the script runs.

    Args:
        var: str The name of the output variable
        var_type: typing.Type The type of output variable
        location: os.PathLike The location where this output variable will be written to or a regex that accepts input
            vars and generates the path. Of the form ``"{{ .inputs.v }}.tmp.md"``. This example for a given input v,
            at path `/tmp/abc.csv` will resolve to `/tmp/abc.csv.tmp.md`
    """

    var: str
    var_type: typing.Type
    location: typing.Union[os.PathLike, str]
def subproc_execute(command: typing.Union[List[str], str], **kwargs) -> ProcessResult:
    """Run a command in a sub-process and capture its stdout and stderr.

    Useful for executing shell commands from within a python task.

    Args:
        command: The command to execute, as an argument list or a single string.
        **kwargs: Extra keyword arguments forwarded to ``subprocess.run``; they
            override the defaults (piped text output, ``check=True``).

    Returns:
        ProcessResult: Structure containing output of the command.

    Raises:
        Exception: If the command exits with a non-zero return code, with details
            about the command, return code, and stderr output.
        Exception: If the executable is not found, with guidance on specifying a
            container image in the task definition when using custom dependencies.
    """
    run_kwargs = dict(
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=True,
    )
    # Caller-supplied kwargs win over the defaults above.
    run_kwargs.update(kwargs)

    try:
        completed = subprocess.run(command, **run_kwargs)
        return ProcessResult(completed.returncode, completed.stdout, completed.stderr)
    except subprocess.CalledProcessError as e:
        raise Exception(f"Command: {e.cmd}\nFailed with return code {e.returncode}:\n{e.stderr}")
    except FileNotFoundError as e:
        raise Exception(
            f"""Process failed because the executable could not be found. Did you specify a container image in the task definition if using custom dependencies?\n{e}"""
        )


def _dummy_task_func():
    """Placeholder function used only to satisfy the inner PythonTask requirements."""
    return None


class AttrDict(dict):
    """Dictionary with attribute-style lookup.

    Do not use this in regular places; it exists only to namespace interpolation
    inputs and outputs (``inputs.x`` / ``outputs.y``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the attribute namespace at the mapping itself so that
        # ``d.key`` and ``d["key"]`` address the same storage.
        self.__dict__ = self


class _PythonFStringInterpolizer:
    """Interpolates script templates written with python ``string.format`` syntax."""

    class _Formatter(string.Formatter):
        def format_field(self, value, format_spec):
            """Render a template value as a string.

            FlyteFile/FlyteDirectory values are downloaded first and replaced by
            their local filepath; datetimes are rendered in ISO-8601 form.
            """
            if isinstance(value, (FlyteFile, FlyteDirectory)):
                value.download()
                return value.path
            if isinstance(value, datetime.datetime):
                return value.isoformat()
            return super().format_field(value, format_spec)

    def interpolate(
        self,
        tmpl: str,
        inputs: typing.Optional[typing.Dict[str, str]] = None,
        outputs: typing.Optional[typing.Dict[str, str]] = None,
    ) -> str:
        """Fill the template with variables from the input and output dicts.

        The given template string itself is not modified.
        """
        namespace = {
            "inputs": AttrDict(inputs or {}),
            "outputs": AttrDict(outputs or {}),
            "ctx": flytekit.current_context(),
        }
        try:
            return self._Formatter().format(tmpl, **namespace)
        except KeyError as e:
            raise ValueError(f"Variable {e} in Query not found in inputs {namespace.keys()}")


T = typing.TypeVar("T")


def _run_script(script: str, shell: str) -> ProcessResult:
    """Run ``script`` in a sub-shell, echoing its stdout line by line.

    Printing the child's stdout to the current process's stdout keeps the
    subprocess execution from appearing unresponsive.

    :param script: script to be executed
    :param shell: shell executable used to run the script
    :return: structure containing the process returncode, stdout (stripped from
        carriage returns), and stderr
    :rtype: ProcessResult
    """
    proc = subprocess.Popen(
        script,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        bufsize=0,
        shell=True,
        text=True,
        executable=shell,
    )
    stdout_text, stderr_text = proc.communicate()

    echoed = []
    for ln in stdout_text.splitlines():
        print(ln)
        echoed.append(ln)

    # NOTE: joining without separators mirrors the historical behaviour of
    # accumulating splitlines() output directly.
    return ProcessResult(proc.wait(), "".join(echoed), stderr_text)
class ShellTask(PythonInstanceTask[T]):
    """A task that runs a shell script (templated with task inputs/outputs) as a sub-process."""

    def __init__(
        self,
        name: str,
        debug: bool = False,
        script: typing.Optional[str] = None,
        script_file: typing.Optional[str] = None,
        task_config: T = None,
        shell: str = "/bin/sh",
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        output_locs: typing.Optional[typing.List[OutputLocation]] = None,
        **kwargs,
    ):
        """
        Args:
            name: str Name of the Task. Should be unique in the project
            debug: bool Print the generated script and other debugging information
            script: The actual script specified as a string
            script_file: A path to the file that contains the script (Only script or script_file) can be provided
            task_config: Configuration for the task, can be either a Pod (or coming soon, BatchJob) config
            shell: Shell to use to run the script
            inputs: A Dictionary of input names to types
            output_locs: A list of :py:class:`OutputLocations`
            **kwargs: Other arguments that can be passed to
                :py:class:`~flytekit.core.python_function_task.PythonInstanceTask`
        """
        # Exactly one of script / script_file must be supplied.
        if script and script_file:
            raise ValueError("Only either of script or script_file can be provided")
        if not script and not script_file:
            raise ValueError("Either a script or script_file is needed")
        if script_file:
            if not os.path.exists(script_file):
                raise ValueError(f"FileNotFound: the specified Script file at path {script_file} cannot be loaded")
            script_file = os.path.abspath(script_file)

        if task_config is not None:
            # Only the flytekitplugins Pod config is accepted; checked by fully-qualified
            # class name so the plugin does not have to be imported here.
            fully_qualified_class_name = task_config.__module__ + "." + task_config.__class__.__name__
            if not fully_qualified_class_name == "flytekitplugins.pod.task.Pod":
                raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.")

        # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
        # to run pre- and post- execute functions using the corresponding task plugin.
        # We rename the function name here to ensure the generated task has a unique name and avoid duplicate task name
        # errors.
        # This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work.
        plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
        self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
        # Rename the internal task so that there are no conflicts at serialization time. Technically these internal
        # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
        # at serialization time.
        self._config_task_instance._name = f"_bash.{name}"
        self._script = script
        self._script_file = script_file
        self._debug = debug
        self._shell = shell
        self._output_locs = output_locs if output_locs else []
        self._interpolizer = _PythonFStringInterpolizer()
        outputs = self._validate_output_locs()
        # Result of the most recent sub-process run; populated by execute().
        self._process_result: typing.Optional[ProcessResult] = None
        super().__init__(
            name,
            task_config,
            task_type=self._config_task_instance.task_type,
            interface=Interface(inputs=inputs, outputs=outputs),
            **kwargs,
        )

    def _validate_output_locs(self) -> typing.Dict[str, typing.Type]:
        """Validate the configured output locations and derive the task's output interface from them.

        Raises:
            ValueError: If any location is None, not an OutputLocation, has no path,
                or declares a type other than FlyteFile/FlyteDirectory (or subclasses).
        """
        outputs = {}
        for v in self._output_locs:
            if v is None:
                raise ValueError("OutputLocation cannot be none")
            if not isinstance(v, OutputLocation):
                raise ValueError("Every output type should be an output location on the file-system")
            if v.location is None:
                raise ValueError(f"Output Location not provided for output var {v.var}")
            if not issubclass(v.var_type, FlyteFile) and not issubclass(v.var_type, FlyteDirectory):
                raise ValueError(
                    "Currently only outputs of type FlyteFile/FlyteDirectory and their derived types are supported"
                )
            outputs[v.var] = v.var_type
        return outputs

    @property
    def result(self) -> typing.Optional[ProcessResult]:
        # ProcessResult of the last run, or None if the task has not executed yet.
        return self._process_result

    @property
    def script(self) -> typing.Optional[str]:
        return self._script

    @property
    def script_file(self) -> typing.Optional[os.PathLike]:
        return self._script_file

    def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
        # Delegate to the wrapped plugin task (e.g. Pod) so plugin-specific setup runs.
        return self._config_task_instance.pre_execute(user_params)

    def execute(self, **kwargs) -> typing.Any:
        """
        Executes the given script by substituting the inputs and outputs and extracts the outputs from the filesystem
        """
        logger.info(f"Running shell script as type {self.task_type}")
        if self.script_file:
            with open(self.script_file) as f:
                self._script = f.read()

        # Resolve each declared output-location template against the inputs first,
        # so the script template can reference them as {outputs.<var>}.
        outputs: typing.Dict[str, str] = {}
        if self._output_locs:
            for v in self._output_locs:
                outputs[v.var] = self._interpolizer.interpolate(v.location, inputs=kwargs)

        if os.name == "nt":
            # cmd.exe cannot run multi-line scripts; chain the commands with && instead.
            self._script = self._script.lstrip().rstrip().replace("\n", "&&")

        gen_script = self._interpolizer.interpolate(self._script, inputs=kwargs, outputs=outputs)
        if self._debug:
            print("\n==============================================\n")
            print(gen_script)
            print("\n==============================================\n")

        if platform.system() == "Windows":
            if os.environ.get("ComSpec") is None:
                # https://github.com/python/cpython/issues/101283
                os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe"
            self._shell = os.environ["ComSpec"]

        self._process_result = _run_script(gen_script, self._shell)
        if self._process_result.returncode != 0:
            files = os.listdir(".")
            fstr = "\n-".join(files)
            error = (
                f"Failed to Execute Script, return-code {self._process_result.returncode} \n"
                f"Current directory contents: .\n-{fstr}\n"
                f"StdOut: {self._process_result.output}\n"
                f"StdErr: {self._process_result.error}\n"
            )
            logger.error(error)
            # raise FlyteRecoverableException so that it's classified as user error and will be retried
            raise FlyteRecoverableException(error)

        # Wrap each declared output path in its FlyteFile/FlyteDirectory type.
        final_outputs = []
        for v in self._output_locs:
            if issubclass(v.var_type, FlyteFile):
                final_outputs.append(FlyteFile(outputs[v.var]))
            if issubclass(v.var_type, FlyteDirectory):
                final_outputs.append(FlyteDirectory(outputs[v.var]))
        if len(final_outputs) == 1:
            return final_outputs[0]
        if len(final_outputs) > 1:
            return tuple(final_outputs)
        return None

    def post_execute(self, user_params: ExecutionParameters, rval: typing.Any) -> typing.Any:
        # Delegate to the wrapped plugin task so plugin-specific teardown runs.
        return self._config_task_instance.post_execute(user_params, rval)
class RawShellTask(ShellTask):
    """A ShellTask variant whose template wraps a "raw" script plus env vars/args supplied at execution time."""

    def __init__(
        self,
        name: str,
        debug: bool = False,
        script: typing.Optional[str] = None,
        script_file: typing.Optional[str] = None,
        task_config: T = None,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        output_locs: typing.Optional[typing.List[OutputLocation]] = None,
        **kwargs,
    ):
        """
        The `RawShellTask` is a minimal extension of the existing `ShellTask`. It's purpose is to support wrapping a
        "raw" or "pure" shell script which needs to be executed with some environment variables set, and some arguments,
        which may not be known until execution time.

        This class is not meant to be instantiated into tasks by users, but used with the factory function
        `get_raw_shell_task()`. An instance of this class will be returned with either user-specified or default
        template. The template itself will export the desired environment variables, and subsequently execute the
        desired "raw" script with the specified arguments.

        .. note::
            This means that within your workflow, you can dynamically control the env variables, arguments, and even
            the actual script you want to run.

        .. note::
            The downside is that a dynamic workflow will be required. The "raw" script passed in at execution time must
            be at the specified location.

        These args are forwarded directly to the parent `ShellTask` constructor as behavior does not diverge
        """
        super().__init__(
            name=name,
            debug=debug,
            script=script,
            script_file=script_file,
            task_config=task_config,
            inputs=inputs,
            output_locs=output_locs,
            **kwargs,
        )

    def make_export_string_from_env_dict(self, d: typing.Dict[str, str]) -> str:
        """
        Utility function to convert a dictionary of desired environment variable key: value pairs into a string of
        ```
        export k1=v1
        export k2=v2
        ...
        ```
        """
        items = []
        for k, v in d.items():
            items.append(f"export {k}={v}")
        return "\n".join(items)

    def execute(self, **kwargs) -> typing.Any:
        """
        Executes the given script by substituting the inputs and outputs and extracts the outputs from the filesystem
        """
        logger.info(f"Running shell script as type {self.task_type}")
        if self.script_file:
            with open(self.script_file) as f:
                self._script = f.read()

        # Resolve each declared output-location template against the inputs first.
        outputs: typing.Dict[str, str] = {}
        if self._output_locs:
            for v in self._output_locs:
                outputs[v.var] = self._interpolizer.interpolate(v.location, inputs=kwargs)

        if os.name == "nt":
            # cmd.exe cannot run multi-line scripts; chain the commands with && instead.
            self._script = self._script.lstrip().rstrip().replace("\n", "&&")

        # Turn the supplied env dict into "export k=v" lines the template splices in
        # via {inputs.export_env}.
        if "env" in kwargs and isinstance(kwargs["env"], dict):
            kwargs["export_env"] = self.make_export_string_from_env_dict(kwargs["env"])

        gen_script = self._interpolizer.interpolate(self._script, inputs=kwargs, outputs=outputs)
        if self._debug:
            print("\n==============================================\n")
            print(gen_script)
            print("\n==============================================\n")

        try:
            subprocess.check_call(gen_script, shell=True)
        except subprocess.CalledProcessError as e:
            # NOTE(review): check_call does not capture output, so e.stderr/e.stdout
            # are presumably None here — verify against caller expectations.
            files = os.listdir(".")
            fstr = "\n-".join(files)
            logger.error(
                f"Failed to Execute Script, return-code {e.returncode} \n"
                f"StdErr: {e.stderr}\n"
                f"StdOut: {e.stdout}\n"
                f" Current directory contents: .\n-{fstr}"
            )
            raise

        # Wrap each declared output path in its FlyteFile/FlyteDirectory type.
        final_outputs = []
        for v in self._output_locs:
            if issubclass(v.var_type, FlyteFile):
                final_outputs.append(FlyteFile(outputs[v.var]))
            if issubclass(v.var_type, FlyteDirectory):
                final_outputs.append(FlyteDirectory(outputs[v.var]))
        if len(final_outputs) == 1:
            return final_outputs[0]
        if len(final_outputs) > 1:
            return tuple(final_outputs)
        return None


# The raw_shell_task is an instance of RawShellTask and wraps a 'pure' shell script
# This utility function allows for the specification of env variables, arguments, and the actual script within the
# workflow definition rather than at `RawShellTask` instantiation
def get_raw_shell_task(name: str) -> RawShellTask:
    # Default template: export the supplied env vars, then run the user's script with
    # its args from the working directory, which is also captured as the "out" output.
    return RawShellTask(
        name=name,
        debug=True,
        inputs=flytekit.kwtypes(env=typing.Dict[str, str], script_args=str, script_file=str),
        output_locs=[
            OutputLocation(
                var="out",
                var_type=FlyteDirectory,
                location="{ctx.working_directory}",
            )
        ],
        script="""
#!/bin/bash
set -uex
cd {ctx.working_directory}
{inputs.export_env}
bash {inputs.script_file} {inputs.script_args}
""",
    )