"""Converts this parameter to its parameter value representation."""
self._proto_parameter_value_class(), + gca_study_compat.StudySpec.ParameterSpec.CategoricalValueSpec, + ): + proto_parameter_value_spec = ( + gca_study_compat_v1beta1.StudySpec.ParameterSpec.CategoricalValueSpec() + ) + elif isinstance( + self._proto_parameter_value_class(), + gca_study_compat.StudySpec.ParameterSpec.DiscreteValueSpec, + ): + proto_parameter_value_spec = ( + gca_study_compat_v1beta1.StudySpec.ParameterSpec.DiscreteValueSpec() + ) + else: + proto_parameter_value_spec = self._proto_parameter_value_class() + + for self_attr_key, proto_attr_key in self._parameter_value_map: + setattr( + proto_parameter_value_spec, proto_attr_key, getattr(self, self_attr_key) + ) + return proto_parameter_value_spec + def _to_parameter_spec( self, parameter_id: str ) -> gca_study_compat.StudySpec.ParameterSpec: @@ -129,6 +172,46 @@ def _to_parameter_spec( return parameter_spec + def _to_parameter_spec_v1beta1( + self, parameter_id: str + ) -> gca_study_compat_v1beta1.StudySpec.ParameterSpec: + """Converts this parameter to ParameterSpec.""" + conditions = [] + if self.conditional_parameter_spec is not None: + for (conditional_param_id, spec) in self.conditional_parameter_spec.items(): + condition = ( + gca_study_compat_v1beta1.StudySpec.ParameterSpec.ConditionalParameterSpec() + ) + if self._parameter_spec_value_key == _INT_VALUE_SPEC: + condition.parent_int_values = gca_study_compat_v1beta1.StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition( + values=spec.parent_values + ) + elif self._parameter_spec_value_key == _CATEGORICAL_VALUE_SPEC: + condition.parent_categorical_values = gca_study_compat_v1beta1.StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition( + values=spec.parent_values + ) + elif self._parameter_spec_value_key == _DISCRETE_VALUE_SPEC: + condition.parent_discrete_values = gca_study_compat_v1beta1.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition( + values=spec.parent_values + ) + 
"google.cloud.aiplatform.preview.jobs.HyperparameterTuningJob"
study_v1beta1, ) from google.cloud.aiplatform.compat.types import ( execution_v1beta1 as gcs_execution_compat, ) from google.cloud.aiplatform.compat.types import io_v1beta1 as gca_io_compat from google.cloud.aiplatform.metadata import constants as metadata_constants +from google.cloud.aiplatform import hyperparameter_tuning from google.cloud.aiplatform.utils import console_utils import proto @@ -103,7 +108,7 @@ def __init__( } ] - my_job = aiplatform.CustomJob( + my_job = aiplatform.preview.jobs.CustomJob( display_name='my_job', worker_pool_specs=worker_pool_specs, labels={'my_key': 'my_value'}, @@ -464,3 +469,366 @@ def submit( else: custom_jobs = [custom_job] run_context.update({metadata_constants._CUSTOM_JOB_KEY: custom_jobs}) + + +class HyperparameterTuningJob(jobs.HyperparameterTuningJob): + """Vertex AI Hyperparameter Tuning Job.""" + + def __init__( + self, + # TODO(b/223262536): Make display_name parameter fully optional in next major release + display_name: str, + custom_job: CustomJob, + metric_spec: Dict[str, str], + parameter_spec: Dict[str, hyperparameter_tuning._ParameterSpec], + max_trial_count: int, + parallel_trial_count: int, + max_failed_trial_count: int = 0, + search_algorithm: Optional[str] = None, + measurement_selection: Optional[str] = "best", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + labels: Optional[Dict[str, str]] = None, + encryption_spec_key_name: Optional[str] = None, + ): + """ + Configures a HyperparameterTuning Job. 
custom_job=custom_job,
parameter_spec (Dict[str, hyperparameter_tuning._ParameterSpec]): + Required. Dictionary representing parameters to optimize. The dictionary key is the parameter_id, + which is passed into your training job as a command line key word argument, and the + dictionary value is the parameter specification of the metric. + + + from google.cloud.aiplatform import hyperparameter_tuning as hpt + + parameter_spec={ + 'decay': hpt.DoubleParameterSpec(min=1e-7, max=1, scale='linear'), + 'learning_rate': hpt.DoubleParameterSpec(min=1e-7, max=1, scale='linear'), + 'batch_size': hpt.DiscreteParameterSpec(values=[4, 8, 16, 32, 64, 128], scale='linear') + } + + Supported parameter specifications can be found under aiplatform.hyperparameter_tuning. + These parameter specifications are currently supported: + DoubleParameterSpec, IntegerParameterSpec, CategoricalParameterSpec, DiscreteParameterSpec
of type `IntegerParameterSpec`, `CategoricalParameterSpec`, + or `DiscreteParameterSpec`.
International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + encryption_spec_key_name (str): + Optional. Customer-managed encryption key options for a + HyperparameterTuningJob. If this is set, then + all resources created by the + HyperparameterTuningJob will be encrypted with + the provided encryption key. + """ + + super(jobs.HyperparameterTuningJob, self).__init__( + project=project, location=location, credentials=credentials + ) + + metrics = [ + study_v1beta1.StudySpec.MetricSpec(metric_id=metric_id, goal=goal.upper()) + for metric_id, goal in metric_spec.items() + ] + + parameters = [ + parameter._to_parameter_spec_v1beta1(parameter_id=parameter_id) + for parameter_id, parameter in parameter_spec.items() + ] + + study_spec = study_v1beta1.StudySpec( + metrics=metrics, + parameters=parameters, + algorithm=hyperparameter_tuning.SEARCH_ALGORITHM_TO_PROTO_VALUE[ + search_algorithm + ], + measurement_selection_type=hyperparameter_tuning.MEASUREMENT_SELECTION_TO_PROTO_VALUE[ + measurement_selection + ], + ) + + if not display_name: + display_name = self.__class__._generate_display_name() + + self._gca_resource = ( + gca_hyperparameter_tuning_job_compat.HyperparameterTuningJob( + display_name=display_name, + study_spec=study_spec, + max_trial_count=max_trial_count, + parallel_trial_count=parallel_trial_count, + max_failed_trial_count=max_failed_trial_count, + trial_job_spec=copy.deepcopy(custom_job.job_spec), + labels=labels, + encryption_spec=initializer.global_config.get_encryption_spec( + encryption_spec_key_name=encryption_spec_key_name, + select_version=compat.V1BETA1, + ), + ) + ) + + def _get_gca_resource( + self, + resource_name: str, + parent_resource_name_fields: Optional[Dict[str, str]] = None, + ) -> proto.Message: + """Returns GAPIC service representation of client class resource. + + Args: + resource_name (str): Required. A fully-qualified resource name or ID. 
"""Helper method to ensure network synchronization and to run the configured HyperparameterTuningJob.
+ enable_web_access (bool): + Whether you want Vertex AI to enable interactive shell access + to training containers. + https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell + tensorboard (str): + Optional. The name of a Vertex AI + [Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard] + resource to which this CustomJob will upload Tensorboard + logs. Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + + The training script should write Tensorboard to following Vertex AI environment + variable: + + AIP_TENSORBOARD_LOG_DIR + + `service_account` is required with provided `tensorboard`. + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + sync (bool): + Whether to execute this method synchronously. If False, this method + will unblock and it will be executed in a concurrent Future. + create_request_timeout (float): + Optional. The timeout for the create request in seconds. + disable_retries (bool): + Indicates if the job should retry for internal errors after the + job starts running. If True, overrides + `restart_job_on_worker_restart` to False. 
+ """ + if service_account: + self._gca_resource.trial_job_spec.service_account = service_account + + if network: + self._gca_resource.trial_job_spec.network = network + + if timeout or restart_job_on_worker_restart or disable_retries: + duration = duration_pb2.Duration(seconds=timeout) if timeout else None + self._gca_resource.trial_job_spec.scheduling = ( + gca_custom_job_compat.Scheduling( + timeout=duration, + restart_job_on_worker_restart=restart_job_on_worker_restart, + disable_retries=disable_retries, + ) + ) + + if enable_web_access: + self._gca_resource.trial_job_spec.enable_web_access = enable_web_access + + if tensorboard: + self._gca_resource.trial_job_spec.tensorboard = tensorboard + + _LOGGER.log_create_with_lro(self.__class__) + + self._gca_resource = self.api_client.select_version( + "v1beta1" + ).create_hyperparameter_tuning_job( + parent=self._parent, + hyperparameter_tuning_job=self._gca_resource, + timeout=create_request_timeout, + ) + + _LOGGER.log_create_complete_with_getter( + self.__class__, self._gca_resource, "hpt_job" + ) + + _LOGGER.info("View HyperparameterTuningJob:\n%s" % self._dashboard_uri()) + + if tensorboard: + _LOGGER.info( + "View Tensorboard:\n%s" + % console_utils.custom_job_tensorboard_console_uri( + tensorboard, self.resource_name + ) + ) + + self._block_until_complete() diff --git a/tests/unit/aiplatform/test_hyperparameter_tuning_job_persistent_resource.py b/tests/unit/aiplatform/test_hyperparameter_tuning_job_persistent_resource.py new file mode 100644 index 0000000000..9b528c5c87 --- /dev/null +++ b/tests/unit/aiplatform/test_hyperparameter_tuning_job_persistent_resource.py @@ -0,0 +1,319 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +from importlib import reload +from unittest import mock +from unittest.mock import patch + +from google.cloud import aiplatform +from google.cloud.aiplatform.compat.services import ( + job_service_client_v1beta1, +) +from google.cloud.aiplatform import hyperparameter_tuning as hpt +from google.cloud.aiplatform.compat.types import ( + custom_job_v1beta1, + encryption_spec_v1beta1, + hyperparameter_tuning_job_v1beta1, + io_v1beta1, + job_state_v1beta1 as gca_job_state_compat, + study_v1beta1 as gca_study_compat, +) +from google.cloud.aiplatform.preview import jobs +import constants as test_constants +import pytest + +from google.protobuf import duration_pb2 + +_TEST_PROJECT = test_constants.ProjectConstants._TEST_PROJECT +_TEST_LOCATION = test_constants.ProjectConstants._TEST_LOCATION +_TEST_ID = "1028944691210842416" +_TEST_DISPLAY_NAME = test_constants.TrainingJobConstants._TEST_DISPLAY_NAME + +_TEST_PARENT = test_constants.ProjectConstants._TEST_PARENT + +_TEST_HYPERPARAMETERTUNING_JOB_NAME = ( + f"{_TEST_PARENT}/hyperparameterTuningJobs/{_TEST_ID}" +) + +_TEST_PREBUILT_CONTAINER_IMAGE = "gcr.io/cloud-aiplatform/container:image" + +_TEST_RUN_ARGS = test_constants.TrainingJobConstants._TEST_RUN_ARGS +_TEST_EXPERIMENT = "test-experiment" +_TEST_EXPERIMENT_RUN = "test-experiment-run" + +_TEST_STAGING_BUCKET = test_constants.TrainingJobConstants._TEST_STAGING_BUCKET + +# CMEK encryption +_TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_1234" +_TEST_DEFAULT_ENCRYPTION_SPEC = encryption_spec_v1beta1.EncryptionSpec( + 
kms_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME +) + +_TEST_SERVICE_ACCOUNT = test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT + +_TEST_METRIC_SPEC_KEY = "test-metric" +_TEST_METRIC_SPEC_VALUE = "maximize" + +_TEST_PARALLEL_TRIAL_COUNT = 8 +_TEST_MAX_TRIAL_COUNT = 64 +_TEST_MAX_FAILED_TRIAL_COUNT = 4 +_TEST_SEARCH_ALGORITHM = "random" +_TEST_MEASUREMENT_SELECTION = "best" + +_TEST_NETWORK = test_constants.TrainingJobConstants._TEST_NETWORK + +_TEST_TIMEOUT = test_constants.TrainingJobConstants._TEST_TIMEOUT +_TEST_RESTART_JOB_ON_WORKER_RESTART = ( + test_constants.TrainingJobConstants._TEST_RESTART_JOB_ON_WORKER_RESTART +) +_TEST_DISABLE_RETRIES = test_constants.TrainingJobConstants._TEST_DISABLE_RETRIES + +_TEST_LABELS = test_constants.ProjectConstants._TEST_LABELS + +_TEST_CONDITIONAL_PARAMETER_DECAY = hpt.DoubleParameterSpec( + min=1e-07, max=1, scale="linear", parent_values=[32, 64] +) +_TEST_CONDITIONAL_PARAMETER_LR = hpt.DoubleParameterSpec( + min=1e-07, max=1, scale="linear", parent_values=[4, 8, 16] +) + + +# Persistent Resource +_TEST_PERSISTENT_RESOURCE_ID = "test-persistent-resource-1" + +_TEST_TRIAL_JOB_SPEC = custom_job_v1beta1.CustomJobSpec( + worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC, + base_output_directory=io_v1beta1.GcsDestination( + output_uri_prefix=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR + ), + scheduling=custom_job_v1beta1.Scheduling( + timeout=duration_pb2.Duration(seconds=_TEST_TIMEOUT), + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + disable_retries=_TEST_DISABLE_RETRIES, + ), + service_account=test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, +) + +_TEST_BASE_HYPERPARAMETER_TUNING_JOB_WITH_PERSISTENT_RESOURCE_PROTO = hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob( + display_name=_TEST_DISPLAY_NAME, + study_spec=gca_study_compat.StudySpec( + metrics=[ + 
gca_study_compat.StudySpec.MetricSpec( + metric_id=_TEST_METRIC_SPEC_KEY, goal=_TEST_METRIC_SPEC_VALUE.upper() + ) + ], + parameters=[ + gca_study_compat.StudySpec.ParameterSpec( + parameter_id="lr", + scale_type=gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LOG_SCALE, + double_value_spec=gca_study_compat.StudySpec.ParameterSpec.DoubleValueSpec( + min_value=0.001, max_value=0.1 + ), + ), + gca_study_compat.StudySpec.ParameterSpec( + parameter_id="units", + scale_type=gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LINEAR_SCALE, + integer_value_spec=gca_study_compat.StudySpec.ParameterSpec.IntegerValueSpec( + min_value=4, max_value=1028 + ), + ), + gca_study_compat.StudySpec.ParameterSpec( + parameter_id="activation", + categorical_value_spec=gca_study_compat.StudySpec.ParameterSpec.CategoricalValueSpec( + values=["relu", "sigmoid", "elu", "selu", "tanh"] + ), + ), + gca_study_compat.StudySpec.ParameterSpec( + parameter_id="batch_size", + scale_type=gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LINEAR_SCALE, + discrete_value_spec=gca_study_compat.StudySpec.ParameterSpec.DiscreteValueSpec( + values=[4, 8, 16, 32, 64] + ), + conditional_parameter_specs=[ + gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec( + parent_discrete_values=gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition( + values=[32, 64] + ), + parameter_spec=gca_study_compat.StudySpec.ParameterSpec( + double_value_spec=gca_study_compat.StudySpec.ParameterSpec.DoubleValueSpec( + min_value=1e-07, max_value=1 + ), + scale_type=gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LINEAR_SCALE, + parameter_id="decay", + ), + ), + gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec( + parent_discrete_values=gca_study_compat.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition( + values=[4, 8, 16] + ), + parameter_spec=gca_study_compat.StudySpec.ParameterSpec( + 
double_value_spec=gca_study_compat.StudySpec.ParameterSpec.DoubleValueSpec( + min_value=1e-07, max_value=1 + ), + scale_type=gca_study_compat.StudySpec.ParameterSpec.ScaleType.UNIT_LINEAR_SCALE, + parameter_id="learning_rate", + ), + ), + ], + ), + ], + algorithm=gca_study_compat.StudySpec.Algorithm.RANDOM_SEARCH, + measurement_selection_type=gca_study_compat.StudySpec.MeasurementSelectionType.BEST_MEASUREMENT, + ), + parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT, + max_trial_count=_TEST_MAX_TRIAL_COUNT, + max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT, + trial_job_spec=_TEST_TRIAL_JOB_SPEC, + labels=_TEST_LABELS, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, +) + + +def _get_hyperparameter_tuning_job_proto(state=None, name=None, error=None): + hyperparameter_tuning_job_proto = copy.deepcopy( + _TEST_BASE_HYPERPARAMETER_TUNING_JOB_WITH_PERSISTENT_RESOURCE_PROTO + ) + hyperparameter_tuning_job_proto.name = name + hyperparameter_tuning_job_proto.state = state + hyperparameter_tuning_job_proto.error = error + + return hyperparameter_tuning_job_proto + + +@pytest.fixture +def create_preview_hyperparameter_tuning_job_mock(): + with mock.patch.object( + job_service_client_v1beta1.JobServiceClient, "create_hyperparameter_tuning_job" + ) as create_preview_hyperparameter_tuning_job_mock: + create_preview_hyperparameter_tuning_job_mock.return_value = ( + _get_hyperparameter_tuning_job_proto( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ) + ) + yield create_preview_hyperparameter_tuning_job_mock + + +@pytest.fixture +def get_hyperparameter_tuning_job_mock(): + with patch.object( + job_service_client_v1beta1.JobServiceClient, "get_hyperparameter_tuning_job" + ) as get_hyperparameter_tuning_job_mock: + get_hyperparameter_tuning_job_mock.side_effect = [ + _get_hyperparameter_tuning_job_proto( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ), + 
_get_hyperparameter_tuning_job_proto( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + ), + _get_hyperparameter_tuning_job_proto( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + ), + _get_hyperparameter_tuning_job_proto( + name=_TEST_HYPERPARAMETERTUNING_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + ), + ] + yield get_hyperparameter_tuning_job_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestHyperparameterTuningJobPersistentResource: + def setup_method(self): + reload(aiplatform.initializer) + reload(aiplatform) + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_hyperparameter_tuning_job_with_persistent_resource( + self, + create_preview_hyperparameter_tuning_job_mock, + get_hyperparameter_tuning_job_mock, + sync, + ): + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + custom_job = jobs.CustomJob( + display_name=test_constants.TrainingJobConstants._TEST_DISPLAY_NAME, + worker_pool_specs=test_constants.TrainingJobConstants._TEST_WORKER_POOL_SPEC, + base_output_dir=test_constants.TrainingJobConstants._TEST_BASE_OUTPUT_DIR, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + ) + + job = jobs.HyperparameterTuningJob( + display_name=_TEST_DISPLAY_NAME, + custom_job=custom_job, + metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE}, + parameter_spec={ + "lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"), + "units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"), + "activation": hpt.CategoricalParameterSpec( + values=["relu", "sigmoid", "elu", "selu", "tanh"] + ), + "batch_size": hpt.DiscreteParameterSpec( + values=[4, 8, 16, 32, 64], + scale="linear", 
+ conditional_parameter_spec={ + "decay": _TEST_CONDITIONAL_PARAMETER_DECAY, + "learning_rate": _TEST_CONDITIONAL_PARAMETER_LR, + }, + ), + }, + parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT, + max_trial_count=_TEST_MAX_TRIAL_COUNT, + max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT, + search_algorithm=_TEST_SEARCH_ALGORITHM, + measurement_selection=_TEST_MEASUREMENT_SELECTION, + labels=_TEST_LABELS, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + sync=sync, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + ) + + job.wait() + + expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto() + + create_preview_hyperparameter_tuning_job_mock.assert_called_once_with( + parent=_TEST_PARENT, + hyperparameter_tuning_job=expected_hyperparameter_tuning_job, + timeout=None, + ) + + assert job.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + assert job.network == _TEST_NETWORK + assert job.trials == []