import contextlib
import os
import shutil
import sqlite3
import tempfile
import typing
from dataclasses import dataclass
from flytekit import FlyteContext, kwtypes, lazy_module
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask
from flytekit.core.shim_task import ShimTaskExecutor
from flytekit.models import task as task_models
if typing.TYPE_CHECKING:
import pandas as pd
else:
pd = lazy_module("pandas")
def unarchive_file(local_path: str, to_dir: str):
"""
Unarchive given archive and returns the unarchived file name. It is expected that only one file is unarchived.
More than one file or 0 files will result in a ``RuntimeError``
"""
archive_dir = os.path.join(to_dir, "_arch")
shutil.unpack_archive(local_path, archive_dir)
# file gets uncompressed into to_dir/_arch/*.*
files = os.listdir(archive_dir)
if not files or len(files) == 0 or len(files) > 1:
raise RuntimeError(f"Uncompressed archive should contain only one file - found {files}!")
return os.path.join(archive_dir, files[0])
[docs]
@dataclass
class SQLite3Config(object):
"""
Use this configuration to configure if sqlite3 files that should be loaded by the task. The file itself is
considered as a database and hence is treated like a configuration
The path to a static sqlite3 compatible database file can be
- within the container
- or from a publicly downloadable source
Args:
uri: default FlyteFile that will be downloaded on execute
compressed: Boolean that indicates if the given file is a compressed archive. Supported file types are
[zip, tar, gztar, bztar, xztar]
"""
uri: str
compressed: bool = False
[docs]
class SQLite3Task(PythonCustomizedContainerTask[SQLite3Config], SQLTask[SQLite3Config]):
"""
Run client side SQLite3 queries that optionally return a FlyteSchema object.
.. note::
This is a pre-built container task. That is, your user container will not be used at task execution time.
Instead the image defined in this task definition will be used instead.
.. literalinclude:: ../../../tests/flytekit/unit/extras/sqlite3/test_task.py
:start-after: # sqlite3_start
:end-before: # sqlite3_end
:language: python
:dedent: 4
See the :ref:`integrations guide <cookbook:integrations_sql_sqlite3>` for additional usage examples and
the base class :py:class:`flytekit.extend.PythonCustomizedContainerTask` as well.
"""
_SQLITE_TASK_TYPE = "sqlite"
def __init__(
self,
name: str,
query_template: str,
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
task_config: typing.Optional[SQLite3Config] = None,
output_schema_type: typing.Optional[typing.Type["FlyteSchema"]] = None, # type: ignore
container_image: typing.Optional[str] = None,
**kwargs,
):
if task_config is None or task_config.uri is None:
raise ValueError("SQLite DB uri is required.")
from flytekit.types.schema import FlyteSchema
outputs = kwtypes(results=output_schema_type if output_schema_type else FlyteSchema)
super().__init__(
name=name,
task_config=task_config,
# if you use your own image, keep in mind to specify the container image here
container_image=container_image or DefaultImages.default_image(),
executor_type=SQLite3TaskExecutor,
task_type=self._SQLITE_TASK_TYPE,
# Sanitize query by removing the newlines at the end of the query. Keep in mind
# that the query can be a multiline string.
query_template=query_template,
inputs=inputs,
outputs=outputs,
**kwargs,
)
@property
def output_columns(self) -> typing.Optional[typing.List[str]]:
c = self.python_interface.outputs["results"].column_names()
return c if c else None
[docs]
def get_custom(self, settings: SerializationSettings) -> typing.Dict[str, typing.Any]:
return {
"query_template": self.query_template,
"uri": self.task_config.uri,
"compressed": self.task_config.compressed,
}
class SQLite3TaskExecutor(ShimTaskExecutor[SQLite3Task]):
def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.Any:
with tempfile.TemporaryDirectory() as temp_dir:
ctx = FlyteContext.current_context()
file_ext = os.path.basename(tt.custom["uri"])
local_path = os.path.join(temp_dir, file_ext)
ctx.file_access.get_data(tt.custom["uri"], local_path)
if tt.custom["compressed"]:
local_path = unarchive_file(local_path, temp_dir)
print(f"Connecting to db {local_path}")
interpolated_query = SQLite3Task.interpolate_query(tt.custom["query_template"], **kwargs)
print(f"Interpolated query {interpolated_query}")
with contextlib.closing(sqlite3.connect(local_path)) as con:
df = pd.read_sql_query(interpolated_query, con)
return df