SQLAlchemy
SQLAlchemy is the Python SQL toolkit and Object Relational Mapper that gives application developers the full power and flexibility of SQL.
Flyte provides an easy-to-use interface for connecting to SQL databases through SQLAlchemy. This example uses a Postgres database to show how SQLAlchemy can be used within Flyte.
This task will run with a pre-built container, and thus users needn’t build one. You can simply implement the task, then register and execute it immediately.
First, install the Flyte SQLAlchemy plugin:
pip install flytekitplugins-sqlalchemy
Let’s first import the libraries.
from flytekit import kwtypes, task, workflow
from flytekit.types.schema import FlyteSchema
from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask
First, we define a SQLAlchemyTask, which returns the first n records from the rna table of the RNAcentral database. Since this database is public, we can hard-code the database URI, including the user and password, in a string.
Note
The output of SQLAlchemyTask is a FlyteSchema by default.
Caution
Never store passwords for proprietary or sensitive databases! If you need to store and access secrets in a task, Flyte provides a convenient API. See Secrets for more details.
DATABASE_URI = "postgresql://reader:NWDMCE5xdipIjRrp@hh-pgsql-public.ebi.ac.uk:5432/pfmegrnargs"
# Here we define the schema of the expected output of the query, which we then re-use in the `get_mean_length` task.
DataSchema = FlyteSchema[kwtypes(sequence_length=int)]
sql_task = SQLAlchemyTask(
    "rna",
    query_template="""
        select len as sequence_length from rna
        where len >= {{ .inputs.min_length }}
        and len <= {{ .inputs.max_length }}
        limit {{ .inputs.limit }}
    """,
    inputs=kwtypes(min_length=int, max_length=int, limit=int),
    output_schema_type=DataSchema,
    task_config=SQLAlchemyConfig(uri=DATABASE_URI),
)
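The placeholders in query_template are filled in from the task inputs before the query is sent to the database. The sketch below imitates that substitution with a regular expression so the rendered SQL can be inspected locally. Note that render_query is a hypothetical helper written only for this illustration; it is not part of flytekit, and the plugin's actual interpolation may differ in detail.

```python
import re

# Same template text as the task above, repeated so this sketch is self-contained.
QUERY_TEMPLATE = """
select len as sequence_length from rna
where len >= {{ .inputs.min_length }}
and len <= {{ .inputs.max_length }}
limit {{ .inputs.limit }}
"""

# Matches placeholders of the form {{ .inputs.<name> }}.
_PLACEHOLDER = re.compile(r"\{\{\s*\.inputs\.(\w+)\s*\}\}")


def render_query(template: str, **inputs) -> str:
    """Hypothetical helper: replace each placeholder with str(input value)."""

    def substitute(match: re.Match) -> str:
        name = match.group(1)
        if name not in inputs:
            raise KeyError(f"no value supplied for input '{name}'")
        return str(inputs[name])

    return _PLACEHOLDER.sub(substitute, template)


rendered = render_query(QUERY_TEMPLATE, min_length=50, max_length=200, limit=5)
print(rendered)
```

Because the values are written directly into the SQL text, typed inputs such as the int inputs declared above help keep the rendered query well-formed.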
Next, we define a task that computes the mean length of sequences in the subset of RNA sequences that our query
returned.
Note for those running this on a live Flyte backend via pyflyte run: run will use the default flytekit image if one is not specified. The default flytekit image does not have the SQLAlchemy flytekit plugin installed. To correctly kick off an execution of this task, you’ll need to use the following command.
pyflyte --config ~/.flyte/your-config.yaml run --destination-dir /app --remote --image ghcr.io/flyteorg/flytekit:py3.8-sqlalchemy-latest integrations/flytekit_plugins/sql/sql_alchemy.py my_wf --min_length 3 --max_length 100 --limit 50
Note also that we added the --destination-dir argument, since by default pyflyte run copies code into /root, which is not what that image’s workdir is set to.
@task
def get_mean_length(data: DataSchema) -> float:
    dataframe = data.open().all()
    return dataframe["sequence_length"].mean().item()
Finally, we put everything together into a workflow:
@workflow
def my_wf(min_length: int, max_length: int, limit: int) -> float:
    return get_mean_length(data=sql_task(min_length=min_length, max_length=max_length, limit=limit))


if __name__ == "__main__":
    print(f"Running {__file__} main...")
    print(my_wf(min_length=50, max_length=200, limit=5))
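To check the query-then-average logic without network access or a Flyte installation, the same steps can be approximated with Python's built-in sqlite3 module. This is only a rough sketch under stated assumptions: the in-memory rna table and its sample lengths are made up for illustration, mean_length_locally is a hypothetical name, and the query uses bound parameters rather than the template placeholders shown earlier.

```python
import sqlite3
from statistics import mean

# Made-up sequence lengths, used only so the sketch runs offline.
SAMPLE_LENGTHS = [10, 60, 120, 180, 250]


def mean_length_locally(min_length: int, max_length: int, limit: int) -> float:
    """Hypothetical stand-in for the workflow: filter rows, then average them."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("create table rna (len integer)")
        conn.executemany(
            "insert into rna (len) values (?)",
            [(n,) for n in SAMPLE_LENGTHS],
        )
        rows = conn.execute(
            "select len as sequence_length from rna "
            "where len >= ? and len <= ? limit ?",
            (min_length, max_length, limit),
        ).fetchall()
    finally:
        conn.close()
    # statistics.mean raises StatisticsError if no rows matched the filter.
    return float(mean(row[0] for row in rows))


local_result = mean_length_locally(min_length=50, max_length=200, limit=5)
print(local_result)
```

With the sample data above, only the lengths 60, 120, and 180 fall inside the range, so the function returns their average. Results from the real database will differ, since they depend on which rows the query happens to return.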