# Source code for flytekitplugins.awssagemaker_inference.workflow

from typing import Any, Dict, Optional, Tuple, Type

from flytekit import Workflow, kwtypes

from .task import (
    SageMakerDeleteEndpointConfigTask,
    SageMakerDeleteEndpointTask,
    SageMakerDeleteModelTask,
    SageMakerEndpointConfigTask,
    SageMakerEndpointTask,
    SageMakerModelTask,
)


def create_deployment_task(
    name: str,
    task_type: Any,
    config: Dict[str, Any],
    region: str,
    inputs: Optional[Dict[str, Type]],
    images: Optional[Dict[str, Any]],
    region_at_runtime: bool,
) -> Tuple[Any, Optional[Dict[str, Type]]]:
    """
    Instantiate a single SageMaker deployment task.

    :param name: Name of the task to create.
    :param task_type: Task class to instantiate; it is called with ``name``, ``config``,
        ``region``, ``inputs`` and ``images`` keyword arguments.
    :param config: Configuration passed through to the task.
    :param region: Region passed through to the task.
    :param inputs: Mapping of task input names to their types, or ``None``.
    :param images: Images passed through to the task, or ``None``.
    :param region_at_runtime: When True, a ``region: str`` input is added to the task inputs.
    :return: A tuple of the task instance and the input-type mapping it was created with
        (including ``region`` when ``region_at_runtime`` is set).
    """
    if region_at_runtime:
        if inputs:
            # Build a new mapping rather than updating in place, so the caller's dict
            # (e.g. a user-supplied ``model_input_types``) does not gain a "region" key.
            inputs = {**inputs, "region": str}
        else:
            inputs = kwtypes(region=str)
    return (
        task_type(name=name, config=config, region=region, inputs=inputs, images=images),
        inputs,
    )


def create_sagemaker_deployment(
    name: str,
    model_config: Dict[str, Any],
    endpoint_config_config: Dict[str, Any],
    endpoint_config: Dict[str, Any],
    images: Optional[Dict[str, Any]] = None,
    model_input_types: Optional[Dict[str, Type]] = None,
    endpoint_config_input_types: Optional[Dict[str, Type]] = None,
    endpoint_input_types: Optional[Dict[str, Type]] = None,
    region: Optional[str] = None,
    region_at_runtime: bool = False,
) -> Workflow:
    """
    Creates SageMaker model, endpoint config and endpoint.

    :param name: Suffix used in the workflow name (``sagemaker-deploy-{name}``) and in each
        generated task name.
    :param model_config: Configuration for the SageMaker model creation API call.
    :param endpoint_config_config: Configuration for the SageMaker endpoint configuration creation API call.
    :param endpoint_config: Configuration for the SageMaker endpoint creation API call.
    :param images: A dictionary of images for SageMaker model creation.
    :param model_input_types: Mapping of SageMaker model configuration inputs to their types.
    :param endpoint_config_input_types: Mapping of SageMaker endpoint configuration inputs to their types.
    :param endpoint_input_types: Mapping of SageMaker endpoint inputs to their types.
    :param region: The region for SageMaker API calls.
    :param region_at_runtime: Set this to True if you want to provide the region at runtime.
    :return: A workflow that runs the model, endpoint-config and endpoint tasks in that order
        and exposes the endpoint task's ``result`` output as ``wf_output``.
    :raises ValueError: If neither ``region`` nor ``region_at_runtime`` is provided.
    """
    if not any((region, region_at_runtime)):
        raise ValueError("Region parameter is required.")

    wf = Workflow(name=f"sagemaker-deploy-{name}")

    if region_at_runtime:
        wf.add_workflow_input("region", str)

    inputs = {
        SageMakerModelTask: {
            "input_types": model_input_types,
            "name": "sagemaker-model",
            "images": True,
            "config": model_config,
        },
        SageMakerEndpointConfigTask: {
            "input_types": endpoint_config_input_types,
            "name": "sagemaker-endpoint-config",
            "images": False,
            "config": endpoint_config_config,
        },
        SageMakerEndpointTask: {
            "input_types": endpoint_input_types,
            "name": "sagemaker-endpoint",
            "images": False,
            "config": endpoint_config,
        },
    }

    nodes = []
    for key, value in inputs.items():
        input_types = value["input_types"]
        obj, new_input_types = create_deployment_task(
            name=f"{value['name']}-{name}",
            task_type=key,
            config=value["config"],
            region=region,
            # Pass a copy: the helper may add a "region" entry, which must not leak
            # into the caller-supplied input-type dict.
            inputs=dict(input_types) if input_types else input_types,
            images=images if value["images"] else None,
            region_at_runtime=region_at_runtime,
        )

        input_dict = {}
        if isinstance(new_input_types, dict):
            for param, t in new_input_types.items():
                # Handles the scenario when the same input is present during different API calls.
                if param not in wf.inputs.keys():
                    wf.add_workflow_input(param, t)
                input_dict[param] = wf.inputs[param]

        node = wf.add_entity(obj, **input_dict)

        # Run the tasks strictly in sequence: model -> endpoint config -> endpoint.
        if len(nodes) > 0:
            nodes[-1] >> node
        nodes.append(node)

    # The last node is the endpoint-creation task.
    wf.add_workflow_output("wf_output", nodes[-1].outputs["result"], str)
    return wf
def create_delete_task(
    name: str,
    task_type: Any,
    config: Dict[str, Any],
    region: str,
    value: str,
    region_at_runtime: bool,
) -> Any:
    """
    Instantiate a SageMaker delete task.

    The task takes a single string input named ``value``. A ``region`` string input is
    appended when ``region_at_runtime`` is set.

    :param name: Name of the task to create.
    :param task_type: Task class to instantiate.
    :param config: Configuration passed through to the task.
    :param region: Region passed through to the task.
    :param value: Name of the string input the task receives.
    :param region_at_runtime: Whether to add a ``region`` input.
    :return: The instantiated task.
    """
    input_spec = {value: str}
    if region_at_runtime:
        input_spec["region"] = str
    return task_type(
        name=name,
        config=config,
        region=region,
        inputs=kwtypes(**input_spec),
    )
def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_at_runtime: bool = False) -> Workflow:
    """
    Deletes SageMaker model, endpoint config and endpoint.

    The resulting workflow takes ``endpoint_name``, ``endpoint_config_name`` and ``model_name``
    inputs (plus ``region`` when ``region_at_runtime`` is set). It deletes the endpoint first,
    then the endpoint config, then the model.

    :param name: The prefix to be added to the task names.
    :param region: The region to use for SageMaker API calls.
    :param region_at_runtime: Set this to True if you want to provide the region at runtime.
    :raises ValueError: If neither ``region`` nor ``region_at_runtime`` is provided.
    """
    if not (region or region_at_runtime):
        raise ValueError("Region parameter is required.")

    wf = Workflow(name=f"sagemaker-delete-deployment-{name}")
    if region_at_runtime:
        wf.add_workflow_input("region", str)

    deletion_steps = (
        (SageMakerDeleteEndpointTask, "endpoint_name"),
        (SageMakerDeleteEndpointConfigTask, "endpoint_config_name"),
        (SageMakerDeleteModelTask, "model_name"),
    )

    previous = None
    for task_cls, input_name in deletion_steps:
        # e.g. "endpoint_config_name" -> "endpoint-config" (task name) and
        # "EndpointConfigName" (config key).
        slug = input_name.replace("_name", "").replace("_", "-")
        config_key = input_name.title().replace("_", "")

        task = create_delete_task(
            name=f"sagemaker-delete-{slug}-{name}",
            task_type=task_cls,
            config={config_key: f"{{inputs.{input_name}}}"},
            region=region,
            value=input_name,
            region_at_runtime=region_at_runtime,
        )
        wf.add_workflow_input(input_name, str)

        node_inputs = {input_name: wf.inputs[input_name]}
        if region_at_runtime:
            node_inputs["region"] = wf.inputs["region"]
        node = wf.add_entity(task, **node_inputs)

        # Chain each deletion after the previous one.
        if previous is not None:
            previous >> node
        previous = node

    return wf