From 52de070a4bf340c90b938d4f4cccbd41147ef5c5 Mon Sep 17 00:00:00 2001
From: Ivan Cheung
Date: Wed, 18 Nov 2020 20:10:25 +0900
Subject: [PATCH] refactor: Pull out reusable code in CustomTrainingJob to use
 in other training jobs (#49)

* Extracted reusable CustomTrainingJob code into TrainingJob base class

Moved around some functions

Completed refactor to use base class

bug: remove requirement for import_schema_uri when passing in gcs_source (#46)

Ran linters

Removed model from TrainingJob and moved to CustomTrainingJob

Removed DatasetWithSplits

Added doc strings and simplified training_job_base code

Moved TrainingJob class into training_jobs.py

Removed container_uri from base TrainingJob class

Addressed comments

Fixed managed model

Ran linter

Fixed issues with abc, doc string and super call

Refactored to create input data config separately

* Ran linter

Co-authored-by: Ivan Cheung
---
 google/cloud/aiplatform/training_jobs.py | 558 ++++++++++++++++-------
 1 file changed, 388 insertions(+), 170 deletions(-)

diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py
index 689768d82c..978509361b 100644
--- a/google/cloud/aiplatform/training_jobs.py
+++ b/google/cloud/aiplatform/training_jobs.py
@@ -26,6 +26,7 @@
 import time
 from typing import Callable, Dict, List, Optional, NamedTuple, Sequence, Union
+import abc
 
 from google.auth import credentials as auth_credentials
 from google.cloud.aiplatform import base
@@ -66,6 +67,313 @@
 )
 
 
+class _TrainingJob(base.AiPlatformResourceNoun):
+    client_class = pipeline_service_client.PipelineServiceClient
+    _is_client_prediction_client = False
+
+    def __init__(
+        self,
+        display_name: str,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ):
+        """Constructs a Training Job.
+
+        Args:
+            display_name (str):
+                Required. The user-defined name of this TrainingPipeline.
+            project (str):
+                Optional project to retrieve model from. If not set, project set in
+                aiplatform.init will be used.
+            location (str):
+                Optional location to retrieve model from. If not set, location set in
+                aiplatform.init will be used.
+            credentials (auth_credentials.Credentials):
+                Optional credentials to use to retrieve the model.
+        """
+        utils.validate_display_name(display_name)
+
+        super().__init__(project=project, location=location, credentials=credentials)
+        self._display_name = display_name
+        self._project = project
+        self._credentials = credentials
+        self._gca_resource = None
+
+    @property
+    @abc.abstractmethod
+    def _model_upload_fail_string(self) -> str:
+        """Helper property for model upload failure."""
+        pass
+
+    @abc.abstractmethod
+    def run(self) -> Optional[models.Model]:
+        """Runs the training job. Should call _run_job internally."""
+        pass
+
+    def _create_input_data_config(
+        self,
+        dataset: Optional[datasets.Dataset],
+        training_fraction_split: float,
+        validation_fraction_split: float,
+        test_fraction_split: float,
+    ) -> gca_training_pipeline.InputDataConfig:
+        """Constructs an input data config to pass to the training pipeline.
+        Override this to create a custom config.
+
+        Args:
+            dataset (aiplatform.Dataset):
+                Optional. The dataset used for training. If not provided, no
+                input data config is created and the fraction splits are ignored.
+            training_fraction_split (float):
+                The fraction of the input data that is to be
+                used to train the Model. This is ignored if Dataset is not provided.
+            validation_fraction_split (float):
+                The fraction of the input data that is to be
+                used to validate the Model. This is ignored if Dataset is not provided.
+            test_fraction_split (float):
+                The fraction of the input data that is to be
+                used to evaluate the Model. This is ignored if Dataset is not provided.
+        """
+
+        input_data_config = None
+        if dataset:
+            # Create fraction split spec
+            fraction_split = gca_training_pipeline.FractionSplit(
+                training_fraction=training_fraction_split,
+                validation_fraction=validation_fraction_split,
+                test_fraction=test_fraction_split,
+            )
+
+            # create input data config
+            input_data_config = gca_training_pipeline.InputDataConfig(
+                fraction_split=fraction_split, dataset_id=dataset.name,
+            )
+
+        return input_data_config
+
+    def _run_job(
+        self,
+        training_task_definition: str,
+        training_task_inputs: dict,
+        dataset: Optional[datasets.Dataset],
+        training_fraction_split: float,
+        validation_fraction_split: float,
+        test_fraction_split: float,
+        model: Optional[gca_model.Model] = None,
+    ) -> Optional[models.Model]:
+        """Runs the training job.
+
+        Args:
+            training_task_definition (str):
+                Required. A Google Cloud Storage path to the
+                YAML file that defines the training task which
+                is responsible for producing the model artifact,
+                and may also include additional auxiliary work.
+                The definition files that can be used here are
+                found in gs://google-cloud-aiplatform/schema/trainingjob/definition/.
+                Note: The URI given on output will be immutable
+                and probably different, including the URI scheme,
+                than the one given on input. The output URI will
+                point to a location where the user only has
+                read access.
+            training_task_inputs (~.struct.Value):
+                Required. The training task's parameter(s), as specified in
+                the ``training_task_definition``'s ``inputs``.
+            dataset (aiplatform.Dataset):
+                Optional. The dataset within the same Project from which data will be used to train the Model. The
+                Dataset must use schema compatible with Model being trained,
+                and what is compatible should be described in the used
+                TrainingPipeline's [training_task_definition]
+                [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
+                For tabular Datasets, all their data is exported to
+                training, to pick and choose from.
+            training_fraction_split (float):
+                The fraction of the input data that is to be
+                used to train the Model. This is ignored if Dataset is not provided.
+            validation_fraction_split (float):
+                The fraction of the input data that is to be
+                used to validate the Model. This is ignored if Dataset is not provided.
+            test_fraction_split (float):
+                The fraction of the input data that is to be
+                used to evaluate the Model. This is ignored if Dataset is not provided.
+            model (~.model.Model):
+                Optional. Describes the Model that may be uploaded (via
+                [ModelService.UploadModel][]) by this TrainingPipeline. The
+                TrainingPipeline's ``training_task_definition`` should make
+                clear whether this Model description should be populated, and
+                if there are any special requirements regarding how it should
+                be filled. If nothing is mentioned in the
+                ``training_task_definition``, then it should be assumed that
+                this field should not be filled and the training task either
+                uploads the Model without a need of this information, or that
+                training task does not support uploading a Model as part of
+                the pipeline.
+
+                When the Pipeline's state becomes
+                ``PIPELINE_STATE_SUCCEEDED`` and the trained Model has been
+                uploaded into AI Platform, then the model_to_upload's
+                resource ``name`` is populated. The Model is always uploaded
+                into the Project and Location in which this pipeline is.
+
+        Returns:
+            model: The trained AI Platform Model resource or None if training did
+                not produce an AI Platform Model.
+        """
+
+        if self._has_run:
+            raise RuntimeError("Training has already run.")
+
+        input_data_config = self._create_input_data_config(
+            dataset=dataset,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
+            test_fraction_split=test_fraction_split,
+        )
+
+        # create training pipeline
+        training_pipeline = gca_training_pipeline.TrainingPipeline(
+            display_name=self._display_name,
+            training_task_definition=training_task_definition,
+            training_task_inputs=training_task_inputs,
+            model_to_upload=model,
+            input_data_config=input_data_config,
+        )
+
+        training_pipeline = self.api_client.create_training_pipeline(
+            parent=initializer.global_config.common_location_path(
+                self.project, self.location
+            ),
+            training_pipeline=training_pipeline,
+        )
+
+        self._gca_resource = training_pipeline
+
+        _LOGGER.info("View Training:\n%s" % self._dashboard_uri())
+
+        model = self._get_model()
+
+        if model is None:
+            _LOGGER.warning(
+                "Training did not produce a Managed Model, returning None. "
+                + self._model_upload_fail_string
+            )
+
+        return model
+
+    @property
+    def state(self) -> gca_pipeline_state.PipelineState:
+        """Current training state."""
+        self._assert_has_run()
+        return self._gca_resource.state
+
+    def get_model(self) -> Optional[models.Model]:
+        """AI Platform Model produced by this training, if one was produced.
+
+        Returns:
+            model: AI Platform Model produced by this training or None if a model was
+                not produced by this training.
+        """
+        self._assert_has_run()
+        if not self._gca_resource.model_to_upload:
+            raise RuntimeError(self._model_upload_fail_string)
+
+        return self._get_model()
+
+    def _get_model(self) -> Optional[models.Model]:
+        """Helper method to get and instantiate the Model to Upload.
+
+        Returns:
+            model: AI Platform Model if training succeeded and produced an AI Platform
+                Model. None otherwise.
+
+        Raises:
+            RuntimeError: If training failed.
+        """
+        self._block_until_complete()
+
+        if self.has_failed:
+            raise RuntimeError(
+                f"Training Pipeline {self.resource_name} failed. No model available."
+            )
+
+        if not self._gca_resource.model_to_upload:
+            return None
+
+        if self._gca_resource.model_to_upload.name:
+            fields = utils.extract_fields_from_resource_name(
+                self._gca_resource.model_to_upload.name
+            )
+            return models.Model(
+                fields.id, project=fields.project, location=fields.location
+            )
+
+    def _block_until_complete(self):
+        """Helper method to block and check on job until complete."""
+
+        # Used these numbers so failures surface fast
+        wait = 5  # start at five seconds
+        max_wait = 60 * 5  # 5 minute wait
+        multiplier = 2  # scale wait by 2 every iteration
+
+        while self.state not in _PIPELINE_COMPLETE_STATES:
+            self._sync_gca_resource()
+            time.sleep(wait)
+            _LOGGER.info(
+                "Training %s current state:\n%s"
+                % (self._gca_resource.name, self._gca_resource.state)
+            )
+            wait = min(wait * multiplier, max_wait)
+
+        self._raise_failure()
+
+        if self._gca_resource.model_to_upload and not self.has_failed:
+            _LOGGER.info(
+                "Model available at %s" % self._gca_resource.model_to_upload.name
+            )
+
+    def _raise_failure(self):
+        """Helper method to raise failure if TrainingPipeline fails.
+
+        Raises:
+            RuntimeError: If training failed."""
+        if self._gca_resource.error.code != code_pb2.OK:
+            raise RuntimeError("Training failed with:\n%s" % self._gca_resource.error)
+
+    @property
+    def has_failed(self) -> bool:
+        """Returns True if training has failed. False otherwise."""
+        self._assert_has_run()
+        return self.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED
+
+    def _dashboard_uri(self) -> str:
+        """Helper method to compose the dashboard uri where training can be viewed."""
+        fields = utils.extract_fields_from_resource_name(self.resource_name)
+        url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/training/{fields.id}?project={fields.project}"
+        return url
+
+    def _sync_gca_resource(self):
+        """Helper method to sync the local gca_source against the service."""
+        self._gca_resource = self.api_client.get_training_pipeline(
+            name=self.resource_name
+        )
+
+    @property
+    def _has_run(self) -> bool:
+        """Helper property to check if this training job has been run."""
+        return self._gca_resource is not None
+
+    def _assert_has_run(self):
+        """Helper method to assert that this training has run."""
+        if not self._has_run:
+            raise RuntimeError(
+                "TrainingPipeline has not been launched. You must run this"
+                " TrainingPipeline using TrainingPipeline.run."
+            )
+
+
 def _timestamped_gcs_dir(root_gcs_path: str, dir_name_prefix: str) -> str:
     """Composes a timestamped GCS directory.
@@ -524,16 +832,13 @@ def chief_worker_pool(
 
 # TODO(b/172368325) add scheduling, custom_job.Scheduling
 
 
-class CustomTrainingJob(base.AiPlatformResourceNoun):
+class CustomTrainingJob(_TrainingJob):
     """Class to launch a Custom Training Job in AI Platform using a script.
 
     Takes a training implementation as a python script and executes that script
     in Cloud AI Platform Training.
     """
 
-    client_class = pipeline_service_client.PipelineServiceClient
-    _is_client_prediction_client = False
-
     # TODO(b/172365796) add remainder of model optional arguments
     def __init__(
         self,
@@ -617,15 +922,15 @@ def __init__(
                 Bucket used to stage source and training artifacts. Overrides
                 staging_bucket set in aiplatform.init.
         """
-        utils.validate_display_name(display_name)
-        super().__init__(project=project, location=location, credentials=credentials)
-        self._display_name = display_name
-        self._script_path = script_path
+        super().__init__(
+            display_name=display_name,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+        self._container_uri = container_uri
         self._requirements = requirements
-        self._staging_bucket = staging_bucket
-        self._project = project
-        self._credentials = credentials
         self._model_serving_container_image_uri = model_serving_container_image_uri
         self._model_serving_container_predict_route = (
             model_serving_container_predict_route
@@ -633,7 +938,55 @@ def __init__(
         self._model_serving_container_health_route = (
             model_serving_container_health_route
         )
-        self._gca_resource = None
+
+        self._script_path = script_path
+        self._staging_bucket = staging_bucket
+
+    def _create_input_data_config(
+        self,
+        dataset: Optional[datasets.Dataset],
+        training_fraction_split: float,
+        validation_fraction_split: float,
+        test_fraction_split: float,
+    ) -> gca_training_pipeline.InputDataConfig:
+        """Constructs an input data config to pass to the training pipeline.
+        Override this to create a custom config.
+
+        Args:
+            dataset (aiplatform.Dataset):
+                Optional. The dataset used for training. If not provided, no
+                input data config is created and the fraction splits are ignored.
+            training_fraction_split (float):
+                The fraction of the input data that is to be
+                used to train the Model. This is ignored if Dataset is not provided.
+            validation_fraction_split (float):
+                The fraction of the input data that is to be
+                used to validate the Model. This is ignored if Dataset is not provided.
+            test_fraction_split (float):
+                The fraction of the input data that is to be
+                used to evaluate the Model. This is ignored if Dataset is not provided.
+        """
+
+        input_data_config = None
+
+        if dataset:
+            # Create fraction split spec
+            fraction_split = gca_training_pipeline.FractionSplit(
+                training_fraction=training_fraction_split,
+                validation_fraction=validation_fraction_split,
+                test_fraction=test_fraction_split,
+            )
+
+            # create input data config
+            input_data_config = gca_training_pipeline.InputDataConfig(
+                fraction_split=fraction_split,
+                dataset_id=dataset.name,
+                gcs_destination=gca_io.GcsDestination(
+                    output_uri_prefix=self._base_output_dir
+                ),
+            )
+
+        return input_data_config
 
 # TODO(b/172365904) add filter split, training_pipeline.FilterSplit
 # TODO(b/172366411) predefined filter split training_pipeline.PredfinedFilterSplit
@@ -776,10 +1129,21 @@ def run(
         if args:
             spec["pythonPackageSpec"]["args"] = args
 
-        managed_model = None
+        training_task_inputs = json_format.ParseDict(
+            {
+                "workerPoolSpecs": worker_pool_specs,
+                "baseOutputDirectory": {"output_uri_prefix": base_output_dir},
+            },
+            struct_pb2.Value(),
+        )
+
+        training_task_definition = schema.training_job.definition.custom_task
 
+        # create model payload
+        managed_model = None
         if model_display_name:
             utils.validate_display_name(model_display_name)
+
             container_spec = gca_model.ModelContainerSpec(
                 image_uri=self._model_serving_container_image_uri,
                 predict_route=self._model_serving_container_predict_route,
@@ -790,126 +1154,24 @@ def run(
                 display_name=model_display_name, container_spec=container_spec
             )
 
-        input_data_config = None
-        if dataset:
-            # Create fraction split spec
-            fraction_split = gca_training_pipeline.FractionSplit(
-                training_fraction=training_fraction_split,
-                validation_fraction=validation_fraction_split,
-                test_fraction=test_fraction_split,
-            )
-
-            # create input data config
-            input_data_config = gca_training_pipeline.InputDataConfig(
-                fraction_split=fraction_split,
-                dataset_id=dataset.name,
-                gcs_destination=gca_io.GcsDestination(
-                    output_uri_prefix=base_output_dir
-                ),
-            )
+        self._base_output_dir = base_output_dir
 
-        # create training pipeline
-        training_pipeline = gca_training_pipeline.TrainingPipeline(
-            display_name=self._display_name,
-            training_task_definition=schema.training_job.definition.custom_task,
-            training_task_inputs=json_format.ParseDict(
-                {
-                    "workerPoolSpecs": worker_pool_specs,
-                    "baseOutputDirectory": {"output_uri_prefix": base_output_dir},
-                },
-                struct_pb2.Value(),
-            ),
-            model_to_upload=managed_model,
-            input_data_config=input_data_config,
+        model = self._run_job(
+            training_task_definition=training_task_definition,
+            training_task_inputs=training_task_inputs,
+            dataset=dataset,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
+            test_fraction_split=test_fraction_split,
+            model=managed_model,
         )
 
-        training_pipeline = self.api_client.create_training_pipeline(
-            parent=initializer.global_config.common_location_path(
-                self.project, self.location
-            ),
-            training_pipeline=training_pipeline,
-        )
-
-        self._gca_resource = training_pipeline
+        self._base_output_dir = None
 
-        _LOGGER.info("View Training:\n%s" % self._dashboard_uri())
_LOGGER.info("Training Output directory:\n%s " % base_output_dir) - model = self._get_model() - - if model is None: - _LOGGER.warn( - "Training did not produce a Managed Model returning None. " - + self._model_upload_fail_string - ) return model - def _sync_gca_resource(self): - """Helper method to sync the local gca_source against the service.""" - self._gca_resource = self.api_client.get_training_pipeline( - name=self.resource_name - ) - - def _block_until_complete(self): - """Helper method to block and check on job until complete.""" - - # Used these numbers so failures surface fast - wait = 5 # start at five seconds - max_wait = 60 * 5 # 5 minute wait - multiplier = 2 # scale wait by 2 every iteration - - while self.state not in _PIPELINE_COMPLETE_STATES: - self._sync_gca_resource() - time.sleep(wait) - _LOGGER.info( - "Training %s current state:\n%s" - % (self._gca_resource.name, self._gca_resource.state) - ) - wait = min(wait * multiplier, max_wait) - - self._raise_failure() - - if self._gca_resource.model_to_upload and not self.has_failed: - _LOGGER.info( - "Model available at %s" % self._gca_resource.model_to_upload.name - ) - - def _raise_failure(self): - """Helper method to raise failure if TrainingPipeline fails. - - Raises: - RuntimeError: If training failed.""" - if self._gca_resource.error.code != code_pb2.OK: - raise RuntimeError("Training failed with:\n%s" % self._gca_resource.error) - - def _get_model(self) -> Optional[models.Model]: - """Helper method to get and instantiate the Model to Upload. - - Returns: - model: AI Platform Model if training succeeded and produced an AI Platform - Model. None otherwise. - - Raises: - RuntimeError if Training failed. - """ - self._block_until_complete() - - if self.has_failed: - raise RuntimeError( - f"Training Pipeline {self.resource_name} failed. No model available." - ) - - if not self._gca_resource.model_to_upload: - return None - - if self._gca_resource.model_to_upload.name: - fields = utils.extract_fields_from_resource_name( - self._gca_resource.model_to_upload.name - ) - return models.Model( - fields.id, project=fields.project, location=fields.location - ) - @property def _model_upload_fail_string(self) -> str: """Helper property for model upload failure.""" @@ -921,50 +1183,6 @@ def _model_upload_fail_string(self) -> str: "os.environ['AIP_MODEL_DIR']." ) - def get_model(self) -> Optional[models.Model]: - """AI Platform Model produced by this training, if one was produced. - - Returns: - model: AI Platform Model produced by this training or None if a model was - not produced by this training. - """ - self._assert_has_run() - if not self._gca_resource.model_to_upload: - raise RuntimeError(self._model_upload_fail_string) - - return self._get_model() - - @property - def _has_run(self) -> bool: - """Helper property to check if this training job has been run.""" - return self._gca_resource is not None - - def _assert_has_run(self): - """Helper method to assert that this training has run.""" - if not self._has_run: - raise RuntimeError( - "TrainingPipeline has not been launched. You must run this" - " TrainingPipeline using TrainingPipeline.run. " - ) - - @property - def state(self) -> gca_pipeline_state.PipelineState: - """Current training state.""" - self._assert_has_run() - return self._gca_resource.state - - @property - def has_failed(self) -> bool: - """Returns True if training has failed. 
-        self._assert_has_run()
-        return self.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED
-
-    def _dashboard_uri(self) -> str:
-        """Helper method to compose the dashboard uri where training can be viewed."""
-        fields = utils.extract_fields_from_resource_name(self.resource_name)
-        url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/training/{fields.id}?project={fields.project}"
-        return url
-
 
-class AutoMLTablesTrainingJob(base.AiPlatformResourceNoun):
+class AutoMLTablesTrainingJob(_TrainingJob):
     pass
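
Editor's note (not part of the patch): the snippet below is a minimal usage
sketch of how CustomTrainingJob might be called after this refactor. It assumes
only the constructor and run() parameters visible in this diff; the project,
bucket, script path, and container image URIs are hypothetical placeholders,
and aiplatform.init is assumed to be configured as in the library's README.

    # Hypothetical usage sketch; all names below are illustrative placeholders.
    from google.cloud import aiplatform
    from google.cloud.aiplatform import training_jobs

    aiplatform.init(project="example-project", location="us-central1")

    job = training_jobs.CustomTrainingJob(
        display_name="example-training-job",
        script_path="trainer/task.py",
        container_uri="gcr.io/example-project/training:latest",
        requirements=["pandas>=1.0"],
        model_serving_container_image_uri="gcr.io/example-project/serving:latest",
        staging_bucket="gs://example-staging-bucket",
    )

    # run() assembles the worker pool specs and training task inputs, then
    # delegates pipeline creation and polling to _TrainingJob._run_job.
    model = job.run(
        dataset=None,  # optional; fraction splits are ignored without a Dataset
        model_display_name="example-model",
        training_fraction_split=0.8,
        validation_fraction_split=0.1,
        test_fraction_split=0.1,
    )

    if model is None:
        print("Training completed but did not produce a managed model.")

The base class owns the pipeline lifecycle (create, poll, surface failures,
fetch the uploaded model), so subclasses such as AutoMLTablesTrainingJob only
need to supply training_task_definition, training_task_inputs, and an optional
input data config override.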