from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job
from flytekit.models import common as _common
from . import training_job as _training_job


class HyperparameterTuningObjectiveType(object):
MINIMIZE = _pb2_hpo_job.HyperparameterTuningObjectiveType.MINIMIZE
MAXIMIZE = _pb2_hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE


class HyperparameterTuningObjective(_common.FlyteIdlEntity):
"""
HyperparameterTuningObjective is a data structure that contains the target metric and the
objective of the hyperparameter tuning.
https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-metrics.html
"""
def __init__(
self,
objective_type: int,
metric_name: str,
):
self._objective_type = objective_type
self._metric_name = metric_name
@property
def objective_type(self) -> int:
"""
        Enum value of HyperparameterTuningObjectiveType. Determines whether the hyperparameter tuning job
        minimizes or maximizes the specified metric.
:rtype: int
"""
return self._objective_type
@property
def metric_name(self) -> str:
"""
        The target metric name, which is the user-defined name of the metric specified in the
        training job's algorithm specification.
:rtype: str
"""
return self._metric_name
    def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningObjective:
        return _pb2_hpo_job.HyperparameterTuningObjective(
            objective_type=self._objective_type,
            metric_name=self._metric_name,
        )
    @classmethod
def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningObjective):
return cls(
objective_type=pb2_object.objective_type,
metric_name=pb2_object.metric_name,
)
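
# Example (illustrative sketch, not part of the original module): constructing an
# objective that maximizes a user-defined metric. The metric name
# "validation:accuracy" is an assumed placeholder; it must match a metric declared
# in the training job's algorithm specification.
#
#   objective = HyperparameterTuningObjective(
#       objective_type=HyperparameterTuningObjectiveType.MAXIMIZE,
#       metric_name="validation:accuracy",
#   )
#   # the model round-trips through its protobuf form
#   restored = HyperparameterTuningObjective.from_flyte_idl(objective.to_flyte_idl())
#   assert restored.metric_name == "validation:accuracy"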


class HyperparameterTuningStrategy(object):
BAYESIAN = _pb2_hpo_job.HyperparameterTuningStrategy.BAYESIAN
RANDOM = _pb2_hpo_job.HyperparameterTuningStrategy.RANDOM


class TrainingJobEarlyStoppingType(object):
OFF = _pb2_hpo_job.TrainingJobEarlyStoppingType.OFF
AUTO = _pb2_hpo_job.TrainingJobEarlyStoppingType.AUTO


class HyperparameterTuningJobConfig(_common.FlyteIdlEntity):
"""
The specification of the hyperparameter tuning process
https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-ex-tuning-job.html#automatic-model-tuning-ex-low-tuning-config
"""
def __init__(
self,
tuning_strategy: int,
tuning_objective: HyperparameterTuningObjective,
        training_job_early_stopping_type: int,
):
self._tuning_strategy = tuning_strategy
self._tuning_objective = tuning_objective
self._training_job_early_stopping_type = training_job_early_stopping_type
@property
def tuning_strategy(self) -> int:
"""
        Enum value of HyperparameterTuningStrategy. The strategy used when searching the hyperparameter space.
:rtype: int
"""
return self._tuning_strategy
@property
def tuning_objective(self) -> HyperparameterTuningObjective:
"""
The target metric and the objective of the hyperparameter tuning.
:rtype: HyperparameterTuningObjective
"""
return self._tuning_objective
@property
def training_job_early_stopping_type(self) -> int:
"""
        Enum value of TrainingJobEarlyStoppingType. When the training jobs launched by the hyperparameter
        tuning job are not improving significantly, they can be stopped early. This attribute determines
        how early stopping is applied.
        Note that only a subset of the built-in algorithms supports early stopping.
see: https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-early-stopping.html
:rtype: int
"""
return self._training_job_early_stopping_type
    def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningJobConfig:
return _pb2_hpo_job.HyperparameterTuningJobConfig(
tuning_strategy=self._tuning_strategy,
tuning_objective=self._tuning_objective.to_flyte_idl(),
training_job_early_stopping_type=self._training_job_early_stopping_type,
)
    @classmethod
def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJobConfig):
return cls(
tuning_strategy=pb2_object.tuning_strategy,
tuning_objective=HyperparameterTuningObjective.from_flyte_idl(pb2_object.tuning_objective),
training_job_early_stopping_type=pb2_object.training_job_early_stopping_type,
)
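
# Example (illustrative sketch): a tuning configuration that searches the
# hyperparameter space with the Bayesian strategy, maximizes the (assumed)
# "validation:accuracy" metric, and lets SageMaker stop underperforming
# training jobs automatically.
#
#   config = HyperparameterTuningJobConfig(
#       tuning_strategy=HyperparameterTuningStrategy.BAYESIAN,
#       tuning_objective=HyperparameterTuningObjective(
#           objective_type=HyperparameterTuningObjectiveType.MAXIMIZE,
#           metric_name="validation:accuracy",
#       ),
#       training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO,
#   )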


class HyperparameterTuningJob(_common.FlyteIdlEntity):
    """
    A hyperparameter tuning job model: resource limits on the sweep (the total number of training jobs
    and how many may run in parallel) together with the underlying training job to tune.
    """
def __init__(
self,
max_number_of_training_jobs: int,
max_parallel_training_jobs: int,
training_job: _training_job.TrainingJob,
):
self._max_number_of_training_jobs = max_number_of_training_jobs
self._max_parallel_training_jobs = max_parallel_training_jobs
self._training_job = training_job
@property
def max_number_of_training_jobs(self) -> int:
"""
The maximum number of training jobs that a hyperparameter tuning job can launch.
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ResourceLimits.html
:rtype: int
"""
return self._max_number_of_training_jobs
@property
def max_parallel_training_jobs(self) -> int:
"""
        The maximum number of training jobs that a hyperparameter tuning job can run concurrently.
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ResourceLimits.html
:rtype: int
"""
return self._max_parallel_training_jobs
@property
def training_job(self) -> _training_job.TrainingJob:
"""
        The reference to the underlying training job that the hyperparameter tuning job launches during the tuning process
:rtype: _training_job.TrainingJob
"""
return self._training_job
def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningJob:
return _pb2_hpo_job.HyperparameterTuningJob(
max_number_of_training_jobs=self._max_number_of_training_jobs,
max_parallel_training_jobs=self._max_parallel_training_jobs,
training_job=self._training_job.to_flyte_idl(), # SDK task has already serialized it
)
@classmethod
def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJob):
return cls(
max_number_of_training_jobs=pb2_object.max_number_of_training_jobs,
max_parallel_training_jobs=pb2_object.max_parallel_training_jobs,
training_job=_training_job.TrainingJob.from_flyte_idl(pb2_object.training_job),
)
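
# Example (illustrative sketch): assembling the full tuning job model.
# ``my_training_job`` is assumed to be an already-constructed
# ``training_job.TrainingJob``; the limits below cap the sweep at 10 training
# jobs total, with at most 2 running in parallel.
#
#   hpo_job = HyperparameterTuningJob(
#       max_number_of_training_jobs=10,
#       max_parallel_training_jobs=2,
#       training_job=my_training_job,
#   )
#   pb = hpo_job.to_flyte_idl()  # protobuf message handed to the Flyte backend plugin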