Source code for flytekit.core.artifact

from __future__ import annotations

import datetime
import typing
from datetime import timedelta
from typing import Optional, Union

from flyteidl.core import artifact_id_pb2 as art_id
from flyteidl.core.artifact_id_pb2 import Granularity
from flyteidl.core.artifact_id_pb2 import Operator as Op
from google.protobuf.timestamp_pb2 import Timestamp

from flytekit.core.context_manager import FlyteContextManager, OutputMetadata, SerializableToString
from flytekit.core.sentinel import DYNAMIC_INPUT_BINDING
from flytekit.loggers import logger

TIME_PARTITION_KWARG = "time_partition"
MAX_PARTITIONS = 10


class InputsBase(object):
    """
    A class that provides better partition semantics.
    Used when invoking an Artifact to bind partition keys to input values.
    If there's a good reason to use a metaclass in the future we can, but a simple instance suffices for now.
    """

    def __getattr__(self, name: str) -> art_id.InputBindingData:
        return art_id.InputBindingData(var=name)


Inputs = InputsBase()
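
# Example (illustrative sketch): ``Inputs`` is typically used inside an Annotated[] declaration to bind a
# partition key to one of the task's inputs at registration time. The artifact, task, and dataframe below
# are assumed for illustration only.
#
#     RideCountData = Artifact(name="ride_count_data", partition_keys=["region"])
#
#     @task
#     def count_rides(region: str) -> Annotated[pd.DataFrame, RideCountData(region=Inputs.region)]:
#         ...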

# This type represents a user output
O = typing.TypeVar("O")


class ArtifactIDSpecification(object):
    """
    This is a special object that helps specify how Artifacts are to be created. See the comment in the
    call function of the main Artifact class. Also see the handling code in transform_variable_map for more
    information. There's a limited set of information that we ultimately need in a TypedInterface, so it
    doesn't make sense to carry the full Artifact object around. This object should be sufficient, despite
    having a pointer to the main artifact.
    """

    def __init__(self, a: Artifact):
        self.artifact = a
        self.partitions: Optional[Partitions] = None
        self.time_partition: Optional[TimePartition] = None

    def __call__(self, *args, **kwargs):
        return self.bind_partitions(*args, **kwargs)

    def bind_partitions(self, *args, **kwargs) -> ArtifactIDSpecification:
        # See the parallel function in the main Artifact class for more information.
        if len(args) > 0:
            raise ValueError("Cannot set partition values by position")

        if TIME_PARTITION_KWARG in kwargs:
            if not self.artifact.time_partitioned:
                raise ValueError("Cannot bind time partition to non-time partitioned artifact")
            p = kwargs[TIME_PARTITION_KWARG]
            if isinstance(p, datetime.datetime):
                t = Timestamp()
                t.FromDatetime(p)
                self.time_partition = TimePartition(
                    value=art_id.LabelValue(time_value=t),
                    granularity=self.artifact.time_partition_granularity or Granularity.DAY,
                )
            elif isinstance(p, art_id.InputBindingData):
                self.time_partition = TimePartition(
                    value=art_id.LabelValue(input_binding=p),
                    granularity=self.artifact.time_partition_granularity or Granularity.DAY,
                )
            else:
                raise ValueError(f"Time partition needs to be input binding data or static string, not {p}")
            # Given the context, shouldn't need to set further reference_artifacts.

            del kwargs[TIME_PARTITION_KWARG]
        else:
            # If the user has not set a time partition, fall back to a dynamic binding.
            if self.artifact.time_partitioned and self.time_partition is None:
                logger.debug(f"Time partition not bound for {self.artifact.name}, setting to dynamic binding.")
                self.time_partition = TimePartition(value=DYNAMIC_INPUT_BINDING)

        if self.artifact.partition_keys and len(self.artifact.partition_keys) > 0:
            p = self.partitions or Partitions(None)
            # k is the partition key, v should be static, or an input to the task or workflow
            for k, v in kwargs.items():
                if not self.artifact.partition_keys or k not in self.artifact.partition_keys:
                    raise ValueError(f"Partition key {k} not found in {self.artifact.partition_keys}")
                if isinstance(v, art_id.InputBindingData):
                    p.partitions[k] = Partition(art_id.LabelValue(input_binding=v), name=k)
                elif isinstance(v, str):
                    p.partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k)
                else:
                    raise ValueError(f"Partition key {k} needs to be input binding data or static string, not {v}")

            for k in self.artifact.partition_keys:
                if k not in p.partitions:
                    logger.debug(f"Partition {k} not bound for {self.artifact.name}, setting to dynamic binding.")
                    p.partitions[k] = Partition(value=DYNAMIC_INPUT_BINDING, name=k)
            # Given the context, shouldn't need to set further reference_artifacts.
            self.partitions = p

        return self

    def to_partial_artifact_id(self) -> art_id.ArtifactID:
        # This function should only be called by transform_variable_map
        artifact_id = self.artifact.to_id_idl()
        # Use the partitions from this object, but since protobuf doesn't allow in-place replacement, generate a new object
        p = Serializer.partitions_to_idl(self.partitions)
        tp = None
        if self.artifact.time_partitioned:
            if not self.time_partition:
                raise ValueError(
                    f"Artifact {artifact_id.artifact_key} requires a time partition, but it hasn't been bound."
                )
            tp = self.time_partition.to_flyte_idl()

        if self.artifact.partition_keys:
            required = len(self.artifact.partition_keys)
            fulfilled = len(p.value) if p else 0
            if required != fulfilled:
                raise ValueError(
                    f"Artifact {artifact_id.artifact_key} requires {required} partitions, but only {fulfilled} are "
                    f"bound."
                )
        artifact_id = art_id.ArtifactID(
            artifact_key=artifact_id.artifact_key,
            partitions=p,
            time_partition=tp,
            version=artifact_id.version,  # this should almost never be set since setting it
            # hardcodes the query to one version
        )
        return artifact_id

    def __repr__(self):
        return f"ArtifactIDSpecification({self.artifact.name}, {self.artifact.partition_keys}, TP: {self.artifact.time_partitioned})"

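# Example (sketch): calling an Artifact inside Annotated[] produces an ArtifactIDSpecification whose partition
# values are bound either statically or to task/workflow inputs. The names below are illustrative only.
#
#     Pricing = Artifact(name="pricing", partition_keys=["region"], time_partitioned=True)
#
#     @task
#     def t1(run_date: datetime.datetime) -> Annotated[pd.DataFrame, Pricing(region="dubai", time_partition=Inputs.run_date)]:
#         ...
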

class ArtifactQuery(object):
    def __init__(
        self,
        artifact: Artifact,
        name: str,
        project: Optional[str] = None,
        domain: Optional[str] = None,
        time_partition: Optional[TimePartition] = None,
        partitions: Optional[Partitions] = None,
        tag: Optional[str] = None,
    ):
        if not name:
            raise ValueError("Cannot create query without name")

        # Normally, if you just do MyData.query(partitions={"region": Inputs.region}), the input value is simply
        # used to fill in the partition. But if you do
        #   MyData.query(region=OtherArtifact.partitions.region)
        # then you now have a dependency on the other artifact. This list keeps track of all the other Artifacts
        # you've referenced.
        self.artifact = artifact
        bindings: typing.List[Artifact] = []
        if time_partition:
            if time_partition.reference_artifact and time_partition.reference_artifact is not artifact:
                bindings.append(time_partition.reference_artifact)
        if partitions and partitions.partitions:
            for k, v in partitions.partitions.items():
                if v.reference_artifact and v.reference_artifact is not artifact:
                    bindings.append(v.reference_artifact)

        self.name = name
        self.project = project
        self.domain = domain
        self.time_partition = time_partition
        self.partitions = partitions
        self.tag = tag
        if len(bindings) > 0:
            b = set(bindings)
            if len(b) > 1:
                raise ValueError(f"Multiple bindings found in query {self}")
            self.binding: Optional[Artifact] = bindings[0]
        else:
            self.binding = None

    @property
    def bound(self) -> bool:
        if self.artifact.time_partitioned and not (self.time_partition and self.time_partition.value):
            return False
        if self.artifact.partition_keys:
            artifact_partitions = set(self.artifact.partition_keys)
            query_partitions = set()
            if self.partitions and self.partitions.partitions:
                pp = self.partitions.partitions
                query_partitions = set([k for k in pp.keys() if pp[k].value])

            if artifact_partitions != query_partitions:
                logger.error(
                    f"Query on {self.artifact.name} missing query params {artifact_partitions - query_partitions}"
                )
                return False

        return True

    def to_flyte_idl(
        self,
        **kwargs,
    ) -> art_id.ArtifactQuery:
        return Serializer.artifact_query_to_idl(self, **kwargs)

    def get_time_partition_str(self, **kwargs) -> str:
        tp_str = ""
        if self.time_partition:
            tp = self.time_partition.value
            if tp.HasField("time_value"):
                tp = tp.time_value.ToDatetime()
                tp_str += f" Time partition: {tp}"
            elif tp.HasField("input_binding"):
                var = tp.input_binding.var
                if var not in kwargs:
                    raise ValueError(f"Time partition input binding {var} not found in kwargs")
                else:
                    tp_str += f" Time partition from input<{var}>,"
        return tp_str

    def get_partition_str(self, **kwargs) -> str:
        p_str = ""
        if self.partitions and self.partitions.partitions and len(self.partitions.partitions) > 0:
            p_str = " Partitions: "
            for k, v in self.partitions.partitions.items():
                if v.value and v.value.HasField("static_value"):
                    p_str += f"{k}={v.value.static_value}, "
                elif v.value and v.value.HasField("input_binding"):
                    var = v.value.input_binding.var
                    if var not in kwargs:
                        raise ValueError(f"Partition input binding {var} not found in kwargs")
                    else:
                        p_str += f"{k} from input<{var}>, "
        return p_str.rstrip("\n\r, ")

    def get_str(self, **kwargs):
        # Detailed string that explains the query a bit more; used at run time
        tp_str = self.get_time_partition_str(**kwargs)
        p_str = self.get_partition_str(**kwargs)

        return f"'{self.artifact.name}'...{tp_str}{p_str}"

    def __str__(self):
        # Default string used for printing --help
        return f"Artifact Query: on {self.artifact.name}"

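# Example (sketch): a query is typically consumed as the default value of a workflow or launch plan input, so
# that the artifact service resolves it at execution time. The workflow and dataframe type are illustrative.
#
#     Pricing = Artifact(name="pricing", partition_keys=["region"])
#
#     @workflow
#     def use_pricing(df: pd.DataFrame = Pricing.query(region="dubai")):
#         ...
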

class TimePartition(object):
    def __init__(
        self,
        value: Union[art_id.LabelValue, art_id.InputBindingData, str, datetime.datetime, None],
        op: Optional[Op] = None,
        other: Optional[timedelta] = None,
        granularity: Granularity = Granularity.DAY,
    ):
        if isinstance(value, str):
            raise ValueError(f"value to a time partition shouldn't be a str {value}")
        elif isinstance(value, datetime.datetime):
            t = Timestamp()
            t.FromDatetime(value)
            value = art_id.LabelValue(time_value=t)
        elif isinstance(value, art_id.InputBindingData):
            value = art_id.LabelValue(input_binding=value)
        # else should already be a LabelValue or None
        self.value: art_id.LabelValue = value
        self.op = op
        self.other = other
        self.reference_artifact: Optional[Artifact] = None
        self.granularity = granularity

    def __add__(self, other: timedelta) -> TimePartition:
        tp = TimePartition(self.value, op=Op.PLUS, other=other, granularity=self.granularity)
        tp.reference_artifact = self.reference_artifact
        return tp

    def __sub__(self, other: timedelta) -> TimePartition:
        tp = TimePartition(self.value, op=Op.MINUS, other=other, granularity=self.granularity)
        tp.reference_artifact = self.reference_artifact
        return tp

    def to_flyte_idl(self, **kwargs) -> Optional[art_id.TimePartition]:
        return Serializer.time_partition_to_idl(self, **kwargs)

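# Example (sketch): time partitions support simple offset arithmetic via __add__/__sub__, e.g. querying the
# previous day's partition relative to another artifact's time partition. Artifact names are illustrative.
#
#     UpstreamReport = Artifact(name="upstream_report", time_partitioned=True)
#     DailySummary = Artifact(name="daily_summary", time_partitioned=True)
#
#     q = DailySummary.query(time_partition=UpstreamReport.time_partition - timedelta(days=1))
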

class Partition(object):
    def __init__(self, value: Optional[art_id.LabelValue], name: str):
        self.name = name
        self.value = value
        self.reference_artifact: Optional[Artifact] = None


class Partitions(object):
    def __init__(self, partitions: Optional[typing.Mapping[str, Union[str, art_id.InputBindingData, Partition]]]):
        self._partitions = {}
        if partitions:
            for k, v in partitions.items():
                if isinstance(v, Partition):
                    self._partitions[k] = v
                elif isinstance(v, art_id.InputBindingData):
                    self._partitions[k] = Partition(art_id.LabelValue(input_binding=v), name=k)
                else:
                    self._partitions[k] = Partition(art_id.LabelValue(static_value=v), name=k)
        self.reference_artifact: Optional[Artifact] = None

    @property
    def partitions(self) -> Optional[typing.Dict[str, Partition]]:
        return self._partitions

    def set_reference_artifact(self, artifact: Artifact):
        self.reference_artifact = artifact
        if self.partitions:
            for p in self.partitions.values():
                p.reference_artifact = artifact

    def __getattr__(self, item):
        if self.partitions and item in self.partitions:
            return self.partitions[item]
        raise AttributeError(f"Partition {item} not found in {self}")

    def to_flyte_idl(self, **kwargs) -> Optional[art_id.Partitions]:
        return Serializer.partitions_to_idl(self, **kwargs)


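# Example (sketch): partition values can be static strings, input bindings, or references to another artifact's
# partitions, which creates a dependency on that artifact (see the ArtifactQuery constructor above). The
# artifact names here are illustrative only.
#
#     Regions = Artifact(name="regions", partition_keys=["region"])
#     Prices = Artifact(name="prices", partition_keys=["region"])
#
#     q = Prices.query(region=Regions.partitions.region)
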
class Artifact(object):
    """
    An Artifact is effectively just a metadata layer on top of data that exists in Flyte. Most data of interest
    will be the output of tasks and workflows. The other category is user uploads.

    This Python class has limited purpose, as a way for users to specify that tasks/workflows create Artifacts
    and the manner (i.e. name, partitions) in which they are created.

    Control creation parameters at task/workflow execution time ::

        @task
        def t1() -> Annotated[nn.Module, Artifact(name="my.artifact.name")]:
            ...
    """

    def __init__(
        self,
        project: Optional[str] = None,
        domain: Optional[str] = None,
        name: Optional[str] = None,
        version: Optional[str] = None,
        time_partitioned: bool = False,
        time_partition: Optional[TimePartition] = None,
        time_partition_granularity: Optional[Granularity] = None,
        partition_keys: Optional[typing.List[str]] = None,
        partitions: Optional[Union[Partitions, typing.Dict[str, str]]] = None,
    ):
        """
        :param project: Should not be directly user provided, the project/domain will come from the project/domain
          of the execution that produced the output. These values will be filled in automatically when retrieving
          however.
        :param domain: See above.
        :param name: The name of the Artifact. This should be user provided.
        :param version: Version of the Artifact, typically the execution ID, plus some additional entropy.
          Not user provided.
        :param time_partitioned: Whether or not this Artifact will have a time partition.
        :param time_partition: If you want to manually pass in the full TimePartition object.
        :param time_partition_granularity: If you don't want to manually pass in the full TimePartition object,
          but want to control the granularity when one is automatically created for you. Note that consistency
          checking is limited while in alpha.
        :param partition_keys: This is a list of keys that will be used to partition the Artifact. These are not
          the values. Values are set via a () on the artifact and will end up in the partition_values field.
        :param partitions: This is a dictionary of partition keys to values.
        """
        if not name:
            raise ValueError("Can't instantiate an Artifact without a name.")
        self.project = project
        self.domain = domain
        self.name = name
        self.version = version
        self.time_partitioned = time_partitioned
        self._time_partition = None
        if time_partition:
            self._time_partition = time_partition
            self._time_partition.reference_artifact = self
        self.partition_keys = partition_keys
        self._partitions: Optional[Partitions] = None
        self.time_partition_granularity = time_partition_granularity
        if partitions:
            if isinstance(partitions, dict):
                self._partitions = Partitions(partitions)
                self.partition_keys = list(partitions.keys())
            elif isinstance(partitions, Partitions):
                self._partitions = partitions
                if not partitions.partitions:
                    raise ValueError("Partitions must be non-empty")
                self.partition_keys = list(partitions.partitions.keys())
            else:
                raise ValueError(f"Partitions must be a dict or Partitions object, not {type(partitions)}")
            self._partitions.set_reference_artifact(self)
        if not partitions and partition_keys:
            # this should be the only time where we create Partition objects with None
            p = {k: Partition(None, name=k) for k in partition_keys}
            self._partitions = Partitions(p)
            self._partitions.set_reference_artifact(self)
        if self.partition_keys and len(self.partition_keys) > MAX_PARTITIONS:
            raise ValueError("There is a hard limit of 10 partition keys per artifact currently.")

    def __call__(self, *args, **kwargs) -> ArtifactIDSpecification:
        """
        This __call__ should only ever happen in the context of a task or workflow's output, to be used in an
        Annotated[] call. The other styles will go through different call functions.
        """
        # Can't guarantee the order in which time/non-time partitions are bound so create the helper object and
        # invoke the function there.
        partial_id = ArtifactIDSpecification(self)
        return partial_id.bind_partitions(*args, **kwargs)

    @property
    def partitions(self) -> Optional[Partitions]:
        return self._partitions

    @property
    def time_partition(self) -> TimePartition:
        if not self.time_partitioned:
            raise ValueError(f"Artifact {self.name} is not time partitioned")
        if not self._time_partition and self.time_partitioned:
            self._time_partition = TimePartition(None, granularity=self.time_partition_granularity or Granularity.DAY)
            self._time_partition.reference_artifact = self
        return self._time_partition

    def __str__(self):
        tp_str = f"  time partition={self.time_partition}\n" if self.time_partitioned else ""
        return (
            f"Artifact: project={self.project}, domain={self.domain}, name={self.name}, version={self.version}\n"
            f"  name={self.name}\n"
            f"  partitions={self.partitions}\n"
            f"{tp_str}"
        )

    def __repr__(self):
        return self.__str__()

    def create_from(
        self, o: O, card: Optional[SerializableToString] = None, *args: SerializableToString, **kwargs
    ) -> O:
        """
        This function allows users to declare partition values dynamically from the body of a task. Note that you'll
        still need to annotate your task function output with the relevant Artifact object. Below, one of the
        partition values is bound to an input, and the other is set at runtime. Note that since tasks are not run at
        compile time, flytekit cannot check that you've bound all the partition values. It's up to you to ensure that
        you've done so.

            Pricing = Artifact(name="pricing", partition_keys=["region"])
            EstError = Artifact(name="estimation_error", partition_keys=["dataset"], time_partitioned=True)

            @task
            def t1() -> Tuple[Annotated[pd.DataFrame, Pricing], Annotated[float, EstError]]:
                df = get_pricing_results()
                dt = get_time()
                return Pricing.create_from(df, region="dubai"), \
                    EstError.create_from(msq_error, dataset="train", time_partition=dt)

        You can mix and match with the input syntax as well.

            @task
            def my_task() -> Annotated[pd.DataFrame, RideCountData(region=Inputs.region)]:
                ...
                return RideCountData.create_from(df, time_partition=datetime.datetime.now())
        """
        omt = FlyteContextManager.current_context().output_metadata_tracker
        additional = [card]
        additional.extend(args)
        filtered_additional: typing.List[SerializableToString] = [a for a in additional if a is not None]
        if not omt:
            logger.debug(f"Output metadata tracker not found, not annotating {o}")
        else:
            partition_vals = {}
            time_partition = None
            for k, v in kwargs.items():
                if k == "time_partition":
                    time_partition = v
                else:
                    partition_vals[k] = str(v)
            omt.add(
                o,
                OutputMetadata(
                    self,
                    time_partition=time_partition if time_partition else None,
                    dynamic_partitions=partition_vals if partition_vals else None,
                    additional_items=filtered_additional if filtered_additional else None,
                ),
            )
        return o

    def query(
        self,
        project: Optional[str] = None,
        domain: Optional[str] = None,
        time_partition: Optional[Union[datetime.datetime, TimePartition, art_id.InputBindingData]] = None,
        partitions: Optional[Union[typing.Dict[str, str], Partitions]] = None,
        **kwargs,
    ) -> ArtifactQuery:
        if self.partition_keys:
            fn_args = {"project", "domain", "time_partition", "partitions", "tag"}
            k = set(self.partition_keys)
            if len(fn_args & k) > 0:
                raise ValueError(
                    f"There are conflicting partition key names {fn_args ^ k}, please rename them or use a"
                    f" partitions object"
                )
        if partitions and kwargs:
            raise ValueError("Please specify either kwargs or a partitions object, not both")
        p_obj: Optional[Partitions] = None
        if kwargs:
            p_obj = Partitions(kwargs)
            p_obj.reference_artifact = self  # only set top level
        if partitions and isinstance(partitions, dict):
            p_obj = Partitions(partitions)
            p_obj.reference_artifact = self  # only set top level
        tp = None
        if time_partition:
            if isinstance(time_partition, TimePartition):
                tp = TimePartition(
                    time_partition.value,
                    op=time_partition.op,
                    other=time_partition.other,
                    granularity=self.time_partition_granularity or Granularity.DAY,
                )
                tp.reference_artifact = time_partition.reference_artifact
            else:
                tp = TimePartition(time_partition)
                tp.reference_artifact = self

        tp = tp or (self.time_partition if self.time_partitioned else None)

        aq = ArtifactQuery(
            artifact=self,
            name=self.name,
            project=project or self.project or None,
            domain=domain or self.domain or None,
            time_partition=tp,
            partitions=p_obj or self.partitions,
        )

        return aq

    @property
    def concrete_artifact_id(self) -> art_id.ArtifactID:
        # This property is used when you want to ensure that this is a materialized artifact, all fields are known.
        if self.name is None or self.project is None or self.domain is None or self.version is None:
            raise ValueError("Cannot create artifact id without name, project, domain, version")
        return self.to_id_idl()

    def embed_as_query(
        self,
        partition: Optional[str] = None,
        bind_to_time_partition: Optional[bool] = None,
        expr: Optional[str] = None,
        op: Optional[Op] = None,
    ) -> art_id.ArtifactQuery:
        """
        This should only be called in the context of a Trigger.

        :param partition: Can embed a time partition
        :param bind_to_time_partition: Set to true if you want to bind to a time partition
        :param expr: Only valid if there's a time partition.
        :param op: If expr is given, then op is what to do with it.
        """
        t = None
        if expr and (partition or bind_to_time_partition):
            t = art_id.TimeTransform(transform=expr, op=op)

        aq = art_id.ArtifactQuery(
            binding=art_id.ArtifactBindingData(
                partition_key=partition,
                bind_to_time_partition=bind_to_time_partition,
                time_transform=t,
            )
        )

        return aq

    def to_id_idl(self) -> art_id.ArtifactID:
        """
        Converts this object to the IDL representation.
        This is here instead of translator because it's in the interface, a relatively simple proto object
        that's exposed to the user.
        """
        p = Serializer.partitions_to_idl(self.partitions)
        tp = Serializer.time_partition_to_idl(self.time_partition) if self.time_partitioned else None

        i = art_id.ArtifactID(
            artifact_key=art_id.ArtifactKey(
                project=self.project,
                domain=self.domain,
                name=self.name,
            ),
            version=self.version,
            partitions=p,
            time_partition=tp,
        )

        return i


class ArtifactSerializationHandler(typing.Protocol):
    """
    This protocol defines the interface for serializing artifact-related entities down to Flyte IDL.
    """

    def partitions_to_idl(self, p: Optional[Partitions], **kwargs) -> Optional[art_id.Partitions]: ...

    def time_partition_to_idl(self, tp: Optional[TimePartition], **kwargs) -> Optional[art_id.TimePartition]: ...

    def artifact_query_to_idl(self, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery: ...


class DefaultArtifactSerializationHandler(ArtifactSerializationHandler):
    def partitions_to_idl(self, p: Optional[Partitions], **kwargs) -> Optional[art_id.Partitions]:
        if p and p.partitions:
            pp = {}
            for k, v in p.partitions.items():
                if v.value is None:
                    # For specifying partitions in the Variable partial id
                    pp[k] = art_id.LabelValue(static_value="")
                else:
                    pp[k] = v.value
            return art_id.Partitions(value=pp)
        return None

    def time_partition_to_idl(self, tp: Optional[TimePartition], **kwargs) -> Optional[art_id.TimePartition]:
        if tp:
            return art_id.TimePartition(value=tp.value, granularity=tp.granularity)
        return None

    def artifact_query_to_idl(self, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery:
        ak = art_id.ArtifactKey(
            name=aq.name,
            project=aq.project,
            domain=aq.domain,
        )
        p = self.partitions_to_idl(aq.partitions)
        tp = self.time_partition_to_idl(aq.time_partition)
        i = art_id.ArtifactID(
            artifact_key=ak,
            partitions=p,
            time_partition=tp,
        )

        aq = art_id.ArtifactQuery(
            artifact_id=i,
        )
        return aq


class Serializer(object):
    serializer: ArtifactSerializationHandler = DefaultArtifactSerializationHandler()

    @classmethod
    def register_serializer(cls, serializer: ArtifactSerializationHandler):
        cls.serializer = serializer

    @classmethod
    def partitions_to_idl(cls, p: Optional[Partitions], **kwargs) -> Optional[art_id.Partitions]:
        return cls.serializer.partitions_to_idl(p, **kwargs)

    @classmethod
    def time_partition_to_idl(cls, tp: TimePartition, **kwargs) -> Optional[art_id.TimePartition]:
        return cls.serializer.time_partition_to_idl(tp, **kwargs)

    @classmethod
    def artifact_query_to_idl(cls, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery:
        return cls.serializer.artifact_query_to_idl(aq, **kwargs)
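
# Example (sketch): the module-level Serializer delegates to a swappable handler, so a plugin that needs to
# emit different IDL can register its own. The handler below simply delegates to the default and is illustrative.
#
#     class MyHandler(DefaultArtifactSerializationHandler):
#         def artifact_query_to_idl(self, aq: ArtifactQuery, **kwargs) -> art_id.ArtifactQuery:
#             # customize the generated query IDL here
#             return super().artifact_query_to_idl(aq, **kwargs)
#
#     Serializer.register_serializer(MyHandler())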