Source code for flytekit.core.testing
import typing
from contextlib import contextmanager
from typing import Union
from unittest.mock import MagicMock
from flytekit.core.base_task import PythonTask
from flytekit.core.reference_entity import ReferenceEntity
from flytekit.core.workflow import WorkflowBase
from flytekit.loggers import logger
[docs]
@contextmanager
def task_mock(
    t: Union[PythonTask, WorkflowBase, ReferenceEntity],
) -> typing.Generator[MagicMock, None, None]:
    """
    Use this method to mock a task declaration. It can mock any Task in Flytekit as long as it has a python native
    interface associated with it.

    The yielded object is a MagicMock and allows to perform all such methods. While the ``with`` block is active,
    ``t.execute`` is replaced by a shim that logs a warning and forwards every call (args and kwargs) to the mock.

    Usage:

    .. code-block:: python

        @task
        def t1(i: int) -> int:
            pass

        with task_mock(t1) as m:
            m.side_effect = lambda x: x
            t1(10)
            # The mock is valid only within this context

    :param t: The PythonTask, workflow (WorkflowBase) or ReferenceEntity whose ``execute`` method is replaced.
    :raises Exception: If ``t`` is not a PythonTask, WorkflowBase or ReferenceEntity.
    :return: A generator yielding the MagicMock that receives the forwarded ``execute`` calls.
    """
    if not isinstance(t, (PythonTask, WorkflowBase, ReferenceEntity)):
        raise Exception("Can only be used for tasks")

    m = MagicMock()

    def _log(*args, **kwargs):
        logger.warning(f"Invoking mock method for task: '{t.name}'")
        return m(*args, **kwargs)

    _captured_fn = t.execute
    t.execute = _log  # type: ignore
    try:
        yield m
    finally:
        # Restore even when the with-body raises (e.g. a failing assertion). Otherwise the exception thrown
        # into this generator would skip the restore and leave the entity mocked for the rest of the process.
        t.execute = _captured_fn  # type: ignore
[docs]
def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]):
    """
    This is a decorator used for testing.

    Each time the decorated test function is called, ``target.execute`` is replaced with a fresh MagicMock. That mock
    is passed to the test function as its first positional argument, followed by any arguments the wrapper received.
    The original ``execute`` is restored when the test function returns or raises.

    .. code-block:: python

        @patch(t1)
        def test_something(mock_t1):
            ...

    NOTE: because the mock is prepended to the positional arguments, decorating a method that takes ``self`` would
    receive the mock in the ``self`` position.

    :param target: The PythonTask, workflow (WorkflowBase) or ReferenceEntity whose ``execute`` method is mocked.
    :raises Exception: If ``target`` is not a PythonTask, WorkflowBase or ReferenceEntity.
    :return: A decorator that wraps a test function.
    """
    if not isinstance(target, (PythonTask, WorkflowBase, ReferenceEntity)):
        raise Exception("Can only use mocks on tasks/workflows declared in Python.")
    logger.info(
        "When using this patch function on Flyte entities, please be aware weird issues may arise if also "
        "using mock.patch on internal Flyte classes like PythonFunctionWorkflow. See "
        "https://github.com/flyteorg/flyte/issues/854 for more information"
    )

    def wrapper(test_fn):
        # functools.wraps is intentionally not used: exposing __wrapped__ would make pytest treat the mock
        # parameter of test_fn as a fixture request.
        def new_test(*args, **kwargs):
            logger.warning(f"Invoking mock method for target: '{target.name}'")
            m = MagicMock()
            saved = target.execute
            target.execute = m
            try:
                return test_fn(m, *args, **kwargs)
            finally:
                # Always restore, even when the test fails or raises, so the mock does not leak into later tests.
                target.execute = saved

        return new_test

    return wrapper