"""Source code for flytekitplugins.duckdb.task."""

import json
from typing import Dict, List, NamedTuple, Optional, Union

import duckdb
import pandas as pd
import pyarrow as pa

from flytekit import PythonInstanceTask
from flytekit.extend import Interface
from flytekit.types.structured.structured_dataset import StructuredDataset

class QueryOutput(NamedTuple):
    """Result of running one query through ``DuckDBQuery._execute_query``.

    Attributes:
        counter: Index of the last parameter set taken from the user-supplied params;
            ``-1`` means no parameter set has been consumed yet.
        output: Object returned by DuckDB's ``execute``/``executemany`` (``.arrow()`` is
            called on it to fetch the result), or ``None`` before any query has run.
    """

    counter: int = -1
    # forward reference so the annotation does not require evaluating ``duckdb`` at class creation
    output: Optional["duckdb.DuckDBPyConnection"] = None

class DuckDBQuery(PythonInstanceTask):
    """Flyte task that runs one or more DuckDB queries against an in-memory database.

    Tabular inputs (``StructuredDataset``, ``pandas.DataFrame``, ``pyarrow.Table``) are registered
    with the DuckDB connection under their input names; ``list`` or JSON-``str`` inputs supply the
    query parameters. The result of the last query is returned as a ``StructuredDataset``.
    """

    _TASK_TYPE = "duckdb"

    def __init__(
        self,
        name: str,
        query: Union[str, List[str]],
        inputs: Optional[Dict[str, Union[StructuredDataset, list]]] = None,
        **kwargs,
    ):
        """
        This method initializes the DuckDBQuery.

        Args:
            name: Name of the task
            query: DuckDB query to execute, or a list of queries executed in order; the result of
                the last query becomes the task output
            inputs: Mapping of input names to their declared types, passed to ``Interface``. Tabular
                inputs are registered as tables; ``list``/``str`` inputs carry the query parameters.
        """
        self._query = query
        # create an in-memory database that's non-persistent
        self._con = duckdb.connect(":memory:")

        outputs = {"result": StructuredDataset}

        super().__init__(
            name=name,
            task_type=self._TASK_TYPE,
            task_config=None,
            interface=Interface(inputs=inputs, outputs=outputs),
            **kwargs,
        )

    def _execute_query(self, params: list, query: str, counter: int, multiple_params: bool):
        """
        This method runs the DuckDBQuery.

        It is a generator that yields exactly one ``QueryOutput``; callers fetch it with ``next()``.

        Args:
            params: Query parameters to use while executing the query
            query: DuckDB query to execute
            counter: Use counter to map user-given arguments to the query parameters
            multiple_params: Set flag to indicate the presence of params for multiple queries

        Yields:
            QueryOutput holding the object returned by DuckDB's ``execute``/``executemany`` and the
            (possibly advanced) parameter counter.

        Raises:
            ValueError: if the query contains a placeholder but no matching parameters exist.
        """
        # "$" / "?" are treated as placeholder markers; queries without them run unparameterised
        if any(x in query for x in ("$", "?")):
            if multiple_params:
                # each parameterised query consumes the next parameter set
                counter += 1
                if not counter < len(params):
                    raise ValueError("Parameter doesn't exist.")
                if "insert" in query.lower():
                    # run executemany disregarding the number of entries to store for an insert query
                    yield QueryOutput(output=self._con.executemany(query, params[counter]), counter=counter)
                else:
                    yield QueryOutput(output=self._con.execute(query, params[counter]), counter=counter)
            else:
                if params:
                    yield QueryOutput(output=self._con.execute(query, params), counter=counter)
                else:
                    raise ValueError("Parameter not specified.")
        else:
            yield QueryOutput(output=self._con.execute(query), counter=counter)

    def execute(self, **kwargs) -> StructuredDataset:
        """Register tabular inputs, run the configured queries and return the last query's result.

        Args:
            kwargs: Task inputs keyed by the input names declared on the task interface.

        Returns:
            StructuredDataset wrapping the Arrow result of the final query (expected to be a SELECT).

        Raises:
            ValueError: on an unsupported input type, an empty query list, or missing query parameters.
        """
        # TODO: Enable iterative download after adding the functionality to structured dataset code.
        params = None
        for key in self.python_interface.inputs:
            val = kwargs.get(key)
            if isinstance(val, StructuredDataset):
                # register structured dataset (materialised as an Arrow table)
                self._con.register(key, val.open(pa.Table).all())
            elif isinstance(val, (pd.DataFrame, pa.Table)):
                # register pandas dataframe/arrow table
                self._con.register(key, val)
            elif isinstance(val, list):
                # copy val into params
                params = val
            elif isinstance(val, str):
                # load into a list
                params = json.loads(val)
            else:
                raise ValueError(f"Expected inputs of type StructuredDataset, str or list, received {type(val)}")

        # normalise to a list so a single-element list is handled the same as a plain string query
        queries = [self._query] if isinstance(self._query, str) else list(self._query)
        if not queries:
            raise ValueError("At least one query must be specified.")

        query_output = QueryOutput()
        # set flag to indicate the presence of params for multiple queries
        multiple_params = isinstance(params[0], list) if params else False

        # run every query but the last, threading the parameter counter through
        for query in queries[:-1]:
            query_output = next(
                self._execute_query(
                    params=params, query=query, counter=query_output.counter, multiple_params=multiple_params
                )
            )

        # fetch query output from the last query
        # expecting a SELECT query
        dataframe = next(
            self._execute_query(
                params=params, query=queries[-1], counter=query_output.counter, multiple_params=multiple_params
            )
        ).output.arrow()
        return StructuredDataset(dataframe=dataframe)