"""
Querying data in Snowflake

This example shows how to use the `SnowflakeTask` to execute queries in Snowflake.
"""
import pandas as pd
from flytekit import ImageSpec, Secret, StructuredDataset, kwtypes, task, workflow
from flytekitplugins.snowflake import SnowflakeConfig, SnowflakeTask
# Container image used by the @task functions in this example (via
# `container_image=image`). It installs the Snowflake plugin along with
# pandas and pyarrow.
image = ImageSpec(
    registry="ghcr.io/flyteorg",
    packages=["flytekitplugins-snowflake", "pandas", "pyarrow"],
)
"""
Define a Snowflake task that inserts a row into the FLYTEAGENT.PUBLIC.TEST table.
The `inputs` parameter declares the names and types of the task's inputs using `kwtypes`.
The `query_template` refers to those inputs through `%`-style named placeholders:
`%(id)s`, `%(name)s`, and `%(age)s` are replaced by the values passed to the task
when it is executed.
"""
"""
You can find the values to use in `SnowflakeConfig` (user, account, database, schema,
and warehouse) by running the following query in the Snowflake console:

SELECT
    CURRENT_USER() AS "User",
    CONCAT(CURRENT_ORGANIZATION_NAME(), '-', CURRENT_ACCOUNT_NAME()) AS "Account",
    CURRENT_DATABASE() AS "Database",
    CURRENT_SCHEMA() AS "Schema",
    CURRENT_WAREHOUSE() AS "Warehouse";
"""
# Connection settings shared by both Snowflake tasks below. Defining them once
# keeps the two tasks pointed at the same account, database, schema and
# warehouse instead of maintaining two identical copies. Replace these example
# values with the ones for your own Snowflake account.
SNOWFLAKE_CONFIG = SnowflakeConfig(
    user="FLYTE",
    account="FLYTE_SNOFLAKE_ACCOUNT",
    database="FLYTEAGENT",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
)

# Inserts one row into FLYTEAGENT.PUBLIC.TEST. The task's `id`, `name` and `age`
# inputs fill the matching %(...)s placeholders in the query template.
snowflake_task_insert_query = SnowflakeTask(
    name="insert-query",
    inputs=kwtypes(id=int, name=str, age=int),
    task_config=SNOWFLAKE_CONFIG,
    query_template="""
INSERT INTO FLYTEAGENT.PUBLIC.TEST (ID, NAME, AGE)
VALUES (%(id)s, %(name)s, %(age)s);
""",
)

# Selects the three rows with the highest ID from FLYTEAGENT.PUBLIC.TEST and
# exposes the result as a StructuredDataset output.
snowflake_task_templatized_query = SnowflakeTask(
    name="select-query",
    output_schema_type=StructuredDataset,
    task_config=SNOWFLAKE_CONFIG,
    query_template="SELECT * FROM FLYTEAGENT.PUBLIC.TEST ORDER BY ID DESC LIMIT 3;",
)
@task(
    container_image=image,
    # NOTE(review): presumably the Snowflake private key used to authenticate
    # when the structured dataset is read — confirm group/key names.
    secret_requests=[Secret(group="private-key", key="snowflake")],
)
def print_head(input_sd: StructuredDataset) -> pd.DataFrame:
    """Load ``input_sd`` as a pandas DataFrame, print it, and return it.

    ``input_sd`` already carries the URI of the Snowflake data, so no URI
    needs to be supplied when opening it.
    """
    frame = input_sd.open(pd.DataFrame).all()
    print(frame)
    return frame
@task(
    container_image=image,
    secret_requests=[
        Secret(
            group="private-key",
            key="snowflake",
        )
    ],
)
def write_table() -> StructuredDataset:
    """Build a small DataFrame and return it bound to a Snowflake table.

    The DataFrame is uploaded to the Snowflake table named by the returned
    StructuredDataset's URI.

    Returns:
        A StructuredDataset wrapping the DataFrame, addressed by a URI of the
        form ``snowflake://{user}:{account}/{warehouse}/{database}/{schema}/{table}``.
    """
    df = pd.DataFrame({"ID": [1, 2, 3], "NAME": ["flyte", "is", "amazing"], "AGE": [30, 30, 30]})
    print(df)
    # Connection details for the target table; replace these example values
    # with the ones for your own Snowflake account. They must be plain strings:
    # a one-element tuple such as ("FLYTE",) would be rendered as "('FLYTE',)"
    # by the f-string below and corrupt the URI.
    user = "FLYTE"
    account = "FLYTE_SNOFLAKE_ACCOUNT"
    database = "FLYTEAGENT"
    schema = "PUBLIC"
    warehouse = "COMPUTE_WH"
    table = "TEST"
    # Upload the DataFrame to the Snowflake table via StructuredDataset
    uri = f"snowflake://{user}:{account}/{warehouse}/{database}/{schema}/{table}"
    return StructuredDataset(dataframe=df, uri=uri)
@workflow
def wf() -> StructuredDataset:
    """Select, insert, write, and select again on FLYTEAGENT.PUBLIC.TEST.

    Only the first select -> print_head step is linked by data. The ``>>``
    chain enforces the remaining order: select, print, insert, write_table,
    then select again. The final select result is printed and returned.
    """
    first_select = snowflake_task_templatized_query()
    first_print = print_head(input_sd=first_select)
    inserted = snowflake_task_insert_query(id=1, name="Flyte", age=30)
    second_select = snowflake_task_templatized_query()
    written = write_table()
    first_select >> first_print >> inserted >> written >> second_select
    # NOTE(review): print_head is annotated to return pd.DataFrame while this
    # workflow declares StructuredDataset; this relies on Flyte treating both
    # as structured datasets — confirm.
    return print_head(input_sd=second_select)


if __name__ == "__main__":
    # Runs the workflow locally; this needs real Snowflake credentials.
    wf()