Source code for flytekitplugins.vaex.sd_transformers

import os
import typing

from flytekit import FlyteContext, StructuredDatasetType, lazy_module
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.types.structured.structured_dataset import (
    PARQUET,
    StructuredDataset,
    StructuredDatasetDecoder,
    StructuredDatasetEncoder,
    StructuredDatasetTransformerEngine,
)

pd = lazy_module("pandas")
vaex = lazy_module("vaex")


class VaexDataFrameToParquetEncodingHandler(StructuredDatasetEncoder):
    """Structured-dataset encoder that writes a Vaex dataframe out as Parquet."""

    def __init__(self):
        # Registers for vaex.dataframe.DataFrameLocal with the PARQUET format.
        # NOTE(review): the second positional argument is None -- presumably the
        # storage protocol (i.e. no restriction); confirm against
        # StructuredDatasetEncoder.__init__.
        super().__init__(vaex.dataframe.DataFrameLocal, None, PARQUET)

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        """Export the wrapped Vaex dataframe to Parquet and return its literal.

        The dataframe is exported to a single file named ``00000`` inside a
        fresh local directory obtained from ``ctx.file_access``; that directory
        is then passed to ``ctx.file_access.put_raw_data`` and the value it
        returns becomes the literal's ``uri``.

        :param ctx: Flyte context whose ``file_access`` provides the local
            staging directory and the raw-data upload.
        :param structured_dataset: Wrapper whose ``dataframe`` attribute holds
            the Vaex dataframe to export.
        :param structured_dataset_type: Type information stored verbatim in the
            returned literal's metadata.
        :return: A ``literals.StructuredDataset`` pointing at the uploaded data.
        """
        frame = typing.cast(vaex.dataframe.DataFrameLocal, structured_dataset.dataframe)

        staging_dir = ctx.file_access.get_random_local_directory()
        # f"{0:05}" evaluates to "00000", the file name the decoder reads back.
        parquet_file = os.path.join(staging_dir, f"{0:05}")
        frame.export_parquet(parquet_file)

        remote_uri = ctx.file_access.put_raw_data(staging_dir)
        metadata = StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type)
        return literals.StructuredDataset(uri=remote_uri, metadata=metadata)
class ParquetToVaexDataFrameDecodingHandler(StructuredDatasetDecoder):
    """Structured-dataset decoder that opens Parquet data as a Vaex dataframe."""

    def __init__(self):
        # Registers for vaex.dataframe.DataFrameLocal with the PARQUET format.
        # NOTE(review): the second positional argument is None -- presumably the
        # storage protocol (i.e. no restriction); confirm against
        # StructuredDatasetDecoder.__init__.
        super().__init__(vaex.dataframe.DataFrameLocal, None, PARQUET)

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> vaex.dataframe.DataFrameLocal:
        """Fetch the dataset behind ``flyte_value`` and open it with Vaex.

        The data at ``flyte_value.uri`` is fetched into a fresh local directory
        (``is_multipart=True``) and the file named ``00000`` inside it is
        opened. When the task metadata declares columns, only those columns are
        selected from the opened dataframe.

        :param ctx: Flyte context whose ``file_access`` performs the download.
        :param flyte_value: Literal carrying the ``uri`` of the stored dataset.
        :param current_task_metadata: Metadata whose ``structured_dataset_type``
            may list the columns to keep.
        :return: The opened Vaex dataframe, restricted to the declared columns
            if any were given.
        """
        download_dir = ctx.file_access.get_random_local_directory()
        ctx.file_access.get_data(flyte_value.uri, download_dir, is_multipart=True)
        parquet_file = f"{download_dir}/00000"

        # Resolve the requested column names (if any) before opening the file.
        wanted = None
        sd_type = current_task_metadata.structured_dataset_type
        if sd_type and sd_type.columns:
            wanted = [col.name for col in sd_type.columns]

        frame = vaex.open(parquet_file)
        if wanted is None:
            return frame
        return frame[wanted]
class VaexDataFrameRenderer:
    """
    Render the summary statistics of a Vaex dataframe as an HTML table.
    """

    def to_html(self, df: vaex.dataframe.DataFrameLocal) -> str:
        """Return an HTML table built from ``df.describe()``.

        :param df: The Vaex dataframe to summarise.
        :return: HTML markup for the describe table, with one row per statistic
            and one column per dataframe column.
        :raises AssertionError: If ``df`` is not a ``vaex.dataframe.DataFrameLocal``.
        """
        assert isinstance(df, vaex.dataframe.DataFrameLocal)
        # vaex documents describe() as returning a pandas DataFrame whose index
        # holds the statistic names and whose columns are the dataframe's
        # columns. The previous code transposed that frame and re-wrapped it
        # with ``columns=describe_df.columns``; for DataFrame input pandas
        # treats ``columns`` as label selection, and those labels no longer
        # exist after the transpose, so every cell came out as NaN.
        # ``index=False`` then hid the only meaningful labels. Render the
        # describe frame directly and keep its index so the statistic names
        # are shown.
        return df.describe().to_html()


# Module-level registration: importing this module makes the Vaex encoder,
# decoder and renderer available to the structured-dataset transformer engine.
StructuredDatasetTransformerEngine.register(VaexDataFrameToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToVaexDataFrameDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(vaex.dataframe.DataFrameLocal, VaexDataFrameRenderer())