# Source code for flytekit.extras.tensorflow.record
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Type, Union
import tensorflow as tf
from dataclasses_json import DataClassJsonMixin
from tensorflow.python.data.ops.readers import TFRecordDatasetV2
from typing_extensions import Annotated, get_args, get_origin
from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.directory import TFRecordsDirectory
from flytekit.types.file import TFRecordFile
@dataclass
class TFRecordDatasetConfig(DataClassJsonMixin):
    """
    Reader options used when building a ``tf.data.TFRecordDataset`` from the
    records of one or more TFRecord files. Every option defaults to ``None``.

    Args:
      compression_type: A scalar evaluating to one of "" (no compression), "ZLIB", or "GZIP".
      buffer_size: The number of bytes in the read buffer. If None, a sensible default
        for both local and remote file systems is used.
      num_parallel_reads: The number of files to read in parallel. If greater than one,
        the records of files read in parallel are output in an interleaved order.
      name: A name for the operation.
    """

    # Field order is part of the generated ``__init__`` signature; keep it stable.
    compression_type: Optional[str] = None
    buffer_size: Optional[int] = None
    num_parallel_reads: Optional[int] = None
    name: Optional[str] = None
def extract_metadata_and_uri(
    lv: Literal, t: Type[Union[TFRecordFile, TFRecordsDirectory]]
) -> Tuple[Union[TFRecordFile, TFRecordsDirectory], TFRecordDatasetConfig]:
    """
    Read the blob URI from a literal and resolve the dataset config attached to ``t``.

    Args:
        lv: Literal expected to carry a blob scalar; the URI is read from
            ``lv.scalar.blob.uri``.
        t: Target Python type, either bare or wrapped in ``Annotated[...]`` with a
            ``TFRecordDatasetConfig`` among its metadata.

    Returns:
        A ``(uri, config)`` tuple. ``config`` is the first ``TFRecordDatasetConfig``
        found in the ``Annotated`` metadata, or a default-constructed one when ``t``
        is not ``Annotated``.

    Raises:
        TypeTransformerFailedError: If ``lv`` does not expose ``scalar.blob.uri``, or if
            ``t`` is ``Annotated`` but none of its metadata is a ``TFRecordDatasetConfig``.
    """
    try:
        uri = lv.scalar.blob.uri
    except AttributeError as exc:
        # Must be raised here: without it ``uri`` stays unbound and the failure
        # would surface further down as an unrelated UnboundLocalError.
        raise TypeTransformerFailedError(f"Cannot convert from {lv} to {t}") from exc

    if get_origin(t) is not Annotated:
        return uri, TFRecordDatasetConfig()

    # get_args(Annotated[X, m1, m2, ...]) yields (X, m1, m2, ...). Look through all
    # metadata items so extra annotations next to the config do not break unpacking.
    _, *annotations = get_args(t)
    for annotation in annotations:
        if isinstance(annotation, TFRecordDatasetConfig):
            return uri, annotation
    raise TypeTransformerFailedError(f"{t}'s metadata needs to be of type TFRecordDatasetConfig")
# [docs]  (documentation-link marker left over from the rendered page; not code)
class TensorFlowRecordFileTransformer(TypeTransformer[TFRecordFile]):
    """
    TypeTransformer that serialises a value to a single TFRecord file and
    deserialises such a file back into a ``tf.data.TFRecordDataset``.
    https://www.tensorflow.org/tutorials/load_data/tfrecord
    """

    TENSORFLOW_FORMAT = "TensorFlowRecord"

    def __init__(self):
        super().__init__(name="TensorFlow Record File", t=TFRecordFile)

    def get_literal_type(self, t: Type[TFRecordFile]) -> LiteralType:
        """Return a blob literal type with ``TENSORFLOW_FORMAT`` and SINGLE dimensionality."""
        single_blob = _core_types.BlobType(
            format=self.TENSORFLOW_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        return LiteralType(blob=single_blob)

    def to_literal(
        self, ctx: FlyteContext, python_val: TFRecordFile, python_type: Type[TFRecordFile], expected: LiteralType
    ) -> Literal:
        """
        Write ``python_val.SerializeToString()`` as one record into a local
        ``0000.tfrecord`` file, hand that file to ``ctx.file_access.put_raw_data``,
        and wrap the returned path in a blob literal.
        """
        blob_type = _core_types.BlobType(
            format=self.TENSORFLOW_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        blob_metadata = BlobMetadata(type=blob_type)

        staging_dir = ctx.file_access.get_random_local_directory()
        record_path = os.path.join(staging_dir, "0000.tfrecord")
        with tf.io.TFRecordWriter(record_path) as record_writer:
            record_writer.write(python_val.SerializeToString())

        uploaded_uri = ctx.file_access.put_raw_data(record_path)
        blob = Blob(metadata=blob_metadata, uri=uploaded_uri)
        return Literal(scalar=Scalar(blob=blob))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordFile]
    ) -> TFRecordDatasetV2:
        """
        Fetch the blob referenced by ``lv`` to a local path and build a
        ``tf.data.TFRecordDataset`` over it, using the reader options resolved
        by ``extract_metadata_and_uri`` for ``expected_python_type``.
        """
        uri, config = extract_metadata_and_uri(lv, expected_python_type)
        download_path = ctx.file_access.get_random_local_path()
        ctx.file_access.get_data(uri, download_path, is_multipart=False)
        return tf.data.TFRecordDataset(
            filenames=[download_path],
            compression_type=config.compression_type,
            buffer_size=config.buffer_size,
            num_parallel_reads=config.num_parallel_reads,
            name=config.name,
        )

    def guess_python_type(self, literal_type: LiteralType) -> Type[TFRecordFile]:
        """
        Return ``TFRecordFile`` for a SINGLE blob whose format is ``TENSORFLOW_FORMAT``.

        Raises:
            ValueError: If ``literal_type`` is not such a blob.
        """
        blob = literal_type.blob
        is_single_tfrecord = (
            blob is not None
            and blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
            and blob.format == self.TENSORFLOW_FORMAT
        )
        if not is_single_tfrecord:
            raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
        return TFRecordFile
# [docs]  (documentation-link marker left over from the rendered page; not code)
class TensorFlowRecordsDirTransformer(TypeTransformer[TFRecordsDirectory]):
    """
    TypeTransformer that supports serialising and deserialising to and from a
    directory of TFRecord files.
    https://www.tensorflow.org/tutorials/load_data/tfrecord
    """

    TENSORFLOW_FORMAT = "TensorFlowRecord"

    def __init__(self):
        super().__init__(name="TensorFlow Record Directory", t=TFRecordsDirectory)

    def get_literal_type(self, t: Type[TFRecordsDirectory]) -> LiteralType:
        """Return a blob literal type with ``TENSORFLOW_FORMAT`` and MULTIPART dimensionality."""
        return LiteralType(
            blob=_core_types.BlobType(
                format=self.TENSORFLOW_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
            )
        )

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: TFRecordsDirectory,
        python_type: Type[TFRecordsDirectory],
        expected: LiteralType,
    ) -> Literal:
        """
        Write each element of ``python_val`` to its own ``part_{i}.tfrecord`` file in
        a fresh local directory, pass that directory to
        ``ctx.file_access.put_raw_data``, and wrap the returned path in a blob literal.

        Each element must provide ``SerializeToString()``.
        """
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.TENSORFLOW_FORMAT,
                dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
            )
        )
        local_dir = ctx.file_access.get_random_local_directory()
        for i, val in enumerate(python_val):
            # os.path.join keeps path construction portable and consistent with
            # the single-file transformer.
            local_path = os.path.join(local_dir, f"part_{i}.tfrecord")
            with tf.io.TFRecordWriter(local_path) as writer:
                writer.write(val.SerializeToString())
        remote_path = ctx.file_access.put_raw_data(local_dir)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordsDirectory]
    ) -> TFRecordDatasetV2:
        """
        Fetch the multipart blob referenced by ``lv`` into a fresh local directory and
        build a ``tf.data.TFRecordDataset`` over every entry found there, using the
        reader options resolved by ``extract_metadata_and_uri``.
        """
        uri, metadata = extract_metadata_and_uri(lv, expected_python_type)
        local_dir = ctx.file_access.get_random_local_directory()
        ctx.file_access.get_data(uri, local_dir, is_multipart=True)
        # Context-manage the scandir iterator so its directory handle is released
        # even if iteration fails. NOTE: os.scandir yields entries in arbitrary order.
        with os.scandir(local_dir) as entries:
            filenames = [entry.path for entry in entries]
        return tf.data.TFRecordDataset(
            filenames=filenames,
            compression_type=metadata.compression_type,
            buffer_size=metadata.buffer_size,
            num_parallel_reads=metadata.num_parallel_reads,
            name=metadata.name,
        )

    def guess_python_type(self, literal_type: LiteralType) -> Type[TFRecordsDirectory]:
        """
        Return ``TFRecordsDirectory`` for a MULTIPART blob whose format is ``TENSORFLOW_FORMAT``.

        Raises:
            ValueError: If ``literal_type`` is not such a blob.
        """
        if (
            literal_type.blob is not None
            and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
            and literal_type.blob.format == self.TENSORFLOW_FORMAT
        ):
            return TFRecordsDirectory
        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
# Register one instance of each transformer with the TypeEngine when this module is executed
# (directory transformer first, then the single-file transformer).
TypeEngine.register(TensorFlowRecordsDirTransformer())
TypeEngine.register(TensorFlowRecordFileTransformer())