Source code for flytekitplugins.sqlalchemy.task

import typing
from dataclasses import dataclass

from flytekit import current_context, kwtypes, lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.configuration.default_images import DefaultImages, PythonVersion
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask
from flytekit.core.shim_task import ShimTaskExecutor
from flytekit.loggers import logger
from flytekit.models import task as task_models
from flytekit.models.security import Secret
from flytekit.types.schema import FlyteSchema

pd = lazy_module("pandas")
pandas_io_sql = lazy_module("pandas.io.sql")
sqlalchemy = lazy_module("sqlalchemy")


class SQLAlchemyDefaultImages(DefaultImages):
    """Container image defaults used by the sqlalchemy flytekit plugin."""

    # One image-name prefix per supported Python version. Every prefix ends in
    # "-sqlalchemy-"; presumably the base class appends a version suffix to
    # build the final image name -- TODO confirm against ``DefaultImages``.
    _DEFAULT_IMAGE_PREFIXES = {
        PythonVersion.PYTHON_3_8: "cr.flyte.org/flyteorg/flytekit:py3.8-sqlalchemy-",
        PythonVersion.PYTHON_3_9: "cr.flyte.org/flyteorg/flytekit:py3.9-sqlalchemy-",
        PythonVersion.PYTHON_3_10: "cr.flyte.org/flyteorg/flytekit:py3.10-sqlalchemy-",
        PythonVersion.PYTHON_3_11: "cr.flyte.org/flyteorg/flytekit:py3.11-sqlalchemy-",
        PythonVersion.PYTHON_3_12: "cr.flyte.org/flyteorg/flytekit:py3.12-sqlalchemy-",
    }
@dataclass
class SQLAlchemyConfig(object):
    """
    Connection settings for a SQLAlchemy task.

    The ``uri`` must follow the standard sqlalchemy database URL format
    (https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls).

    The database may live:
      - inside the task container, or
      - at a publicly accessible location.

    Args:
        uri: default sqlalchemy connector
        connect_args: sqlalchemy kwarg overrides -- ex: host
        secret_connect_args: flyte secrets loaded into sqlalchemy connect args
            -- ex: {"password": flytekit.models.security.Secret(name=SECRET_NAME, group=SECRET_GROUP)}
    """

    uri: str
    connect_args: typing.Optional[typing.Dict[str, typing.Any]] = None
    secret_connect_args: typing.Optional[typing.Dict[str, Secret]] = None

    @staticmethod
    def _secret_to_dict(secret: Secret) -> typing.Dict[str, typing.Optional[str]]:
        """Flatten ``secret`` into a plain dictionary of its identifying fields."""
        # NOTE(review): ``mount_requirement.value`` is an enum payload and may
        # not be a ``str`` as the return annotation suggests -- confirm.
        flattened = {
            "group": secret.group,
            "key": secret.key,
            "group_version": secret.group_version,
            "mount_requirement": secret.mount_requirement.value,
        }
        return flattened

    def secret_connect_args_to_dicts(self) -> typing.Optional[typing.Dict[str, typing.Dict[str, typing.Optional[str]]]]:
        """
        Convert every configured secret into its plain-dictionary form.

        Returns:
            ``None`` when ``secret_connect_args`` is ``None``; otherwise a new
            dictionary with the same keys, each mapped to the output of
            ``_secret_to_dict`` for the corresponding secret.
        """
        configured_secrets = self.secret_connect_args
        if configured_secrets is None:
            return None
        return {
            arg_name: self._secret_to_dict(secret_obj)
            for arg_name, secret_obj in configured_secrets.items()
        }
class SQLAlchemyTask(PythonCustomizedContainerTask[SQLAlchemyConfig], SQLTask[SQLAlchemyConfig]):
    """
    Makes it possible to run client side SQLAlchemy queries that optionally return a FlyteSchema object
    """

    # TODO: How should we use pre-built containers for running portable tasks like this? Should this always be a
    #  referenced task type?

    _SQLALCHEMY_TASK_TYPE = "sqlalchemy"

    def __init__(
        self,
        name: str,
        query_template: str,
        task_config: SQLAlchemyConfig,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = FlyteSchema,
        container_image: str = SQLAlchemyDefaultImages.default_image(),
        **kwargs,
    ):
        """
        Create a SQLAlchemy task.

        Args:
            name: Name of the task.
            query_template: Query text; it is interpolated with the task inputs
                at execution time by the executor.
            task_config: Connection settings for the database.
            inputs: Optional mapping of input names to their Python types.
            output_schema_type: Type of the single ``results`` output. Pass a
                falsy value (e.g. ``None``) to declare a task with no outputs.
            container_image: Image the task runs in. The default is computed
                once, when this class is defined.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # A task only has a "results" output when a schema type was requested.
        if output_schema_type:
            outputs = kwtypes(results=output_schema_type)
        else:
            outputs = None

        super().__init__(
            name=name,
            task_config=task_config,
            executor_type=SQLAlchemyTaskExecutor,
            task_type=self._SQLALCHEMY_TASK_TYPE,
            query_template=query_template,
            container_image=container_image,
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

    @property
    def output_columns(self) -> typing.Optional[typing.List[str]]:
        """
        Column names declared by the ``results`` output schema.

        Returns:
            The list of column names, or ``None`` when the schema declares no
            columns or the task was created without a ``results`` output.
        """
        try:
            schema_type = self.python_interface.outputs["results"]
        except KeyError:
            # The task was built with a falsy ``output_schema_type`` and so has
            # no "results" output; report "no columns" instead of raising.
            return None
        columns = schema_type.column_names()
        return columns if columns else None

    def get_custom(self, settings: SerializationSettings) -> typing.Dict[str, typing.Any]:
        """
        Build the custom payload that the executor reads back at run time.

        Args:
            settings: Serialization settings; not used by this implementation.

        Returns:
            A dictionary with the query template, the database URI, the
            connect args (an empty dict when none were configured) and the
            secrets in plain-dictionary form (``None`` when none were
            configured).
        """
        return {
            "query_template": self.query_template,
            "uri": self.task_config.uri,
            "connect_args": self.task_config.connect_args or {},
            "secret_connect_args": self.task_config.secret_connect_args_to_dicts(),
        }
class SQLAlchemyTaskExecutor(ShimTaskExecutor[SQLAlchemyTask]):
    """Executes the query described by a serialized SQLAlchemy task template."""

    def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.Any:
        """
        Run the templated query against the configured database.

        Args:
            tt: Task template whose ``custom`` mapping holds the keys
                ``"uri"``, ``"connect_args"``, ``"secret_connect_args"`` and
                ``"query_template"``.
            **kwargs: Task inputs, passed to ``SQLAlchemyTask.interpolate_query``
                to fill in the query template.

        Returns:
            The result of ``pandas.read_sql_query`` when the task interface
            declares outputs; otherwise ``None`` (the statement is executed
            only for its effect).
        """
        custom = tt.custom

        # Work on a copy so that resolved secret values are never written back
        # into the task template's own ``custom`` mapping.
        connect_args = dict(custom["connect_args"] or {})
        secret_connect_args = custom["secret_connect_args"]
        if secret_connect_args is not None:
            secrets = current_context().secrets
            for key, secret_dict in secret_connect_args.items():
                connect_args[key] = secrets.get(group=secret_dict["group"], key=secret_dict["key"])

        engine = sqlalchemy.create_engine(custom["uri"], connect_args=connect_args, echo=False)
        try:
            # Render the URL through the engine so that a password embedded in
            # the URI is masked instead of being written to the logs.
            logger.info(f"Connecting to db {engine.url.render_as_string(hide_password=True)}")

            interpolated_query = SQLAlchemyTask.interpolate_query(custom["query_template"], **kwargs)
            logger.info(f"Interpolated query {interpolated_query}")

            statement = sqlalchemy.text(interpolated_query)
            # ``engine.begin()`` commits on success and rolls back on error.
            with engine.begin() as connection:
                df = None
                if tt.interface.outputs:
                    df = pd.read_sql_query(statement, connection)
                else:
                    pandas_io_sql.pandasSQL_builder(connection).execute(statement)
        finally:
            # The engine is created per execution; release its pooled
            # connections rather than leaving them open until garbage collection.
            engine.dispose()
        return df