SQLAlchemy
SQLAlchemy is the Python SQL toolkit and Object Relational Mapper that gives application developers the full power and flexibility of SQL.
Flyte provides an easy-to-use interface for using SQLAlchemy to connect to various SQL databases. In this example, we query a PostgreSQL database to show how SQLAlchemy can be used within Flyte.
The SQLAlchemy task runs in a pre-built container, so you don't need to build one yourself. Implement the task, then register and execute it immediately.
First, install the Flyte SQLAlchemy plugin:
```shell
pip install flytekitplugins-sqlalchemy
```
Next, import the libraries. We then define a `SQLAlchemyTask`, which returns up to `limit` records from the `rna` table of the RNAcentral database, keeping only sequences whose length lies between `min_length` and `max_length`. Because this database is public, we can hard-code the database URI, including the username and password, as a string.
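A SQLAlchemy database URI packs the backend, credentials, host, port, and database name into one string of the form `dialect://user:password@host:port/database`. The sketch below uses only the standard library's `urllib.parse.urlsplit` to show how the URI used in this example breaks down; it is an illustration for reading the URI, not part of the Flyte plugin.

```python
from urllib.parse import urlsplit

# The public RNAcentral URI used in this example.
DATABASE_URI = (
    "postgresql://reader:NWDMCE5xdipIjRrp@hh-pgsql-public.ebi.ac.uk:5432/pfmegrnargs"
)

parts = urlsplit(DATABASE_URI)
components = {
    "dialect": parts.scheme,         # database backend
    "username": parts.username,      # login name
    "password": parts.password,      # hard-coded secret: acceptable only for public data
    "host": parts.hostname,
    "port": parts.port,              # parsed as an int
    "database": parts.path.lstrip("/"),
}

for name, value in components.items():
    print(f"{name:>8}: {value}")
```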
Note
The output of `SQLAlchemyTask` is a `FlyteSchema` by default.
Caution
Never hard-code passwords for proprietary or sensitive databases. If you need to store and access secrets in a task, Flyte provides a convenient API; see Using Secrets in a Task for more details.
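Flyte's secrets API is the recommended route for real credentials. As a rough, framework-agnostic illustration of the underlying idea (keeping credentials out of source code), the sketch below assembles the URI from environment variables at runtime. The variable names `DB_USER`, `DB_PASSWORD`, `DB_HOST`, `DB_PORT`, `DB_NAME` and the helper `build_database_uri` are hypothetical, and this is not the Flyte secrets API.

```python
import os
from urllib.parse import quote


def build_database_uri(env=os.environ) -> str:
    """Assemble a Postgres URI from hypothetical environment variables.

    User and password are percent-encoded so characters such as '@' or '/'
    cannot break the URI syntax.
    """
    user = quote(env["DB_USER"], safe="")
    password = quote(env["DB_PASSWORD"], safe="")
    host = env["DB_HOST"]
    port = env.get("DB_PORT", "5432")  # fall back to the default Postgres port
    name = env["DB_NAME"]
    return f"postgresql://{user}:{password}@{host}:{port}/{name}"


if __name__ == "__main__":
    # Example values only; in practice these come from the runtime environment.
    demo_env = {
        "DB_USER": "analyst",
        "DB_PASSWORD": "p@ss/word",
        "DB_HOST": "db.example.com",
        "DB_NAME": "analytics",
    }
    print(build_database_uri(demo_env))
```

Percent-encoding matters because an unescaped `@` or `/` in a password would otherwise be read as part of the URI structure.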
```python
from flytekit import kwtypes, task, workflow
from flytekit.types.schema import FlyteSchema
from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask

# Public, read-only credentials for the RNAcentral Postgres database.
DATABASE_URI = (
    "postgresql://reader:NWDMCE5xdipIjRrp@hh-pgsql-public.ebi.ac.uk:5432/pfmegrnargs"
)

# Schema of the query's expected output, reused in the `get_mean_length` task.
DataSchema = FlyteSchema[kwtypes(sequence_length=int)]

sql_task = SQLAlchemyTask(
    "rna",
    query_template="""
        select len as sequence_length from rna
        where len >= {{ .inputs.min_length }}
        and len <= {{ .inputs.max_length }}
        limit {{ .inputs.limit }}
    """,
    inputs=kwtypes(min_length=int, max_length=int, limit=int),
    output_schema_type=DataSchema,
    task_config=SQLAlchemyConfig(uri=DATABASE_URI),
)
```
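Each `{{ .inputs.<name> }}` placeholder in `query_template` stands for the task input of the same name. The sketch below mimics that substitution with a regular expression so you can preview the rendered SQL for a given set of inputs; `render_query` is a hypothetical helper and a simplified stand-in, not flytekit's own implementation.

```python
import re

QUERY_TEMPLATE = """
select len as sequence_length from rna
where len >= {{ .inputs.min_length }}
and len <= {{ .inputs.max_length }}
limit {{ .inputs.limit }}
"""

# Matches placeholders such as "{{ .inputs.limit }}" and captures the input name.
PLACEHOLDER = re.compile(r"\{\{\s*\.inputs\.(\w+)\s*\}\}")


def render_query(template: str, **inputs) -> str:
    """Replace each placeholder with the string form of the named input."""

    def substitute(match: re.Match) -> str:
        name = match.group(1)
        if name not in inputs:
            raise KeyError(f"missing input for placeholder: {name}")
        return str(inputs[name])

    return PLACEHOLDER.sub(substitute, template)


print(render_query(QUERY_TEMPLATE, min_length=50, max_length=200, limit=5))
```

Plain text substitution like this is only safe for trusted, typed values such as the integers used here; arbitrary strings pasted into SQL invite injection.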
Next, we define a task that computes the mean length of the RNA sequences returned by the query.
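```python
@task
def get_mean_length(data: DataSchema) -> float:
    # Load the query result into a pandas DataFrame.
    dataframe = data.open().all()
    # `.item()` converts the NumPy scalar returned by `.mean()` into a Python float.
    return dataframe["sequence_length"].mean().item()
```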
Finally, we put everything together into a workflow:
```python
@workflow
def my_wf(min_length: int, max_length: int, limit: int) -> float:
    return get_mean_length(
        data=sql_task(min_length=min_length, max_length=max_length, limit=limit)
    )


if __name__ == "__main__":
    print(f"Running {__file__} main...")
    print(my_wf(min_length=50, max_length=200, limit=5))
```
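Running `my_wf` locally still queries the remote RNAcentral database. To experiment with the same query-then-average logic offline, the sketch below substitutes Python's built-in `sqlite3` module and a tiny, made-up `rna` table for the real Postgres database. The sample lengths and the function `mean_length_offline` are hypothetical, and the code does not use Flyte at all.

```python
import sqlite3
from statistics import mean

# Hypothetical sample rows standing in for RNAcentral's `rna` table.
SAMPLE_LENGTHS = [30, 75, 120, 180, 250, 90, 60]

QUERY = """
select len as sequence_length from rna
where len >= ? and len <= ?
limit ?
"""


def mean_length_offline(min_length: int, max_length: int, limit: int) -> float:
    """Run the example's query on an in-memory SQLite table and average the result."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("create table rna (len integer)")
        conn.executemany(
            "insert into rna (len) values (?)", [(n,) for n in SAMPLE_LENGTHS]
        )
        # Bound parameters (?) let the driver handle quoting instead of string pasting.
        rows = conn.execute(QUERY, (min_length, max_length, limit)).fetchall()
    finally:
        conn.close()
    if not rows:
        raise ValueError("query returned no rows")
    return float(mean(row[0] for row in rows))


print(mean_length_offline(min_length=50, max_length=200, limit=5))
```

Without an `order by` clause, which rows `limit` keeps is up to the database, so results for a small `limit` can vary between engines.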