-
Notifications
You must be signed in to change notification settings - Fork 349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: Pull out reusable code in CustomTrainingJob to use in other training jobs #49
refactor: Pull out reusable code in CustomTrainingJob to use in other training jobs #49
Conversation
53a91f9
to
79c1a16
Compare
) | ||
self._gca_resource = None | ||
|
||
def _create_managed_model(self, model_display_name: str) -> gca_model.Model: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Encapsulated this logic into a function to emphasize the parameter dependencies to generate a managed model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, seems like this isn't used by AutoML. I guess it's specific to Custom Training then? Will move there then
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ended up having each subclass create their own Model
and pass it in to the base class. This is due to the fact that each training job type has different requirements as defined in the yaml files.
|
||
return managed_model | ||
|
||
def _create_input_data_config(self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Encapsulated this logic into a function to emphasize the parameter dependencies to generate a input data config.
base_output_dir: Optional[str] = None): | ||
"""Runs the training job. | ||
""" | ||
if self._has_run: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the only functionality that differs from the existing CustomTrainingJob
implementation.
This validation was previously performed earlier but I've moved it into the base class (i.e. it occurs later than before) to hide it from the subclasses implementation details.
return model | ||
|
||
@property | ||
def state(self) -> gca_pipeline_state.PipelineState: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copied verbatim from CustomModelTrainingJob
) | ||
|
||
@property | ||
def _model_upload_fail_string(self) -> str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copied verbatim from CustomModelTrainingJob
) | ||
|
||
@property | ||
def _has_run(self) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copied verbatim from CustomModelTrainingJob
"""Helper property to check if this training job has been run.""" | ||
return self._gca_resource is not None | ||
|
||
def _assert_has_run(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copied verbatim from CustomModelTrainingJob
] | ||
) | ||
|
||
class DatasetWithSplits(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New class to encapsulate all the parameters that need to be passed together.
self.test_fraction_split=test_fraction_split | ||
|
||
|
||
class TrainingJob(base.AiPlatformResourceNoun): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can potentially make an abc to prevent instantiation. Thoughts?
The preference from the design is to flatten coupled arguments as long the objects they represent are not deeply nested. There is precedence in ML libraries for this pattern: keras.Model.fit with validation_split, similar the API surface discussed here. The aim for this SDK is to feel familiar to practitioners that use these libraries. We'll gate on feedback on whether we should change this.
We should gate this on whether the organization makes sense to group these classes together in the same module not necessarily that there is some threshold of number of lines we want to stay under. Since these all relate to training using the PipelineService I think we're safe to keep them in a single module.
Looking at the code, I'm not sure it would prevent initialization because it looks like the interface is completely defined. We should use an ABC to define an interface that concrete classes should implement. If our parent class already fully implements that interface it should be usable. To me, based on the code, it looks like you may want to separate out some of the shared logic as private methods in the ABC and require the concrete classes implement some of the public interface. Either that or derive AutoMLTablesTrainingJob from CustomTrainingJob.
All methods except for unit tests. |
Thanks for the clarifications @sasha-gitg, they all make sense to me. I will make the changes. |
Sure thing. One consideration for splitting off the base class into its own module is whether other teams will be building off the base class or will it just be us. Since it looks like its just us, a relatively small group of developers, keeping it in the same module sgtm. However, putting everything in one module has the issue that we're combining all the import statements. Examining the import dependencies typically are a great way to check if the coupling of your class. For example, if However, we also want to be able import training jobs in one line:
It seems like we can keep everything in one TrainingJob module but still have multiple files: Would you be open to this approach or do you still prefer all training_job's in one file? |
@sasha-gitg What do you want me to do with the tests? Your CustomTraining tests already ensure coverage of the super class. Do you want me to refactor the tests in some way as well? |
If test are passing with full coverage after refactoring then we can move forward with the tests we have. |
This is a more compelling argument to split the files but I would move this decision to the PR that implements
I don't see value in this considering we expose all classes on the SDKs surface and the expectation is that our internal code imports implementations directly. Our style guide also enforces module level imports.
Yes, this is the approach we would take to namespace all these classes to the same module when splitting them into different files. But if we're going to add a level of indirection we should justify that tradeoff with concrete benefits. As the SDK grows these benefits we'll become more apparent but I would prefer we avoid adding premature complexity. I do think that implementing AutoMLTablesTrainingJob will surface those concrete benefits. |
Sounds good, thanks for the review! I'll move everything into the |
8548eb2
to
8a2b53d
Compare
Sorry, I meant to say that we want to import CustomTrainingJob and AutoMLTablesTrainingJob as part of the same module. Not necessarily a class-level import. |
For tabular Datasets, all their data is exported to | ||
training, to pick and choose from. | ||
training_fraction_split (float): | ||
The fraction of the input data that is to be |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sasha-gitg sometimes I see Required
and sometimes Optional
and sometimes nothing is written here. Any guidance on why nothing is written for this parameter (I copied it from CustomTrainingJobClass).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are generally copied from the protos if they represent the same field: https://github.com/googleapis/python-aiplatform/blob/dev/google/cloud/aiplatform_v1beta1/types/training_pipeline.py#L260
So we inherit the arg commenting from the service.
For comments we add, it's only necessary we mark them as Required if they are indeed so.
Args: | ||
display_name (str): | ||
Required. The user-defined name of this TrainingPipeline. | ||
container_uri (str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this only relevant to Custom Training?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok removed from base class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the docstring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests are broken in the build. Those should pass to ensure this refactor did not break the current class.
Nit: The PR title should technically be called "refactor:..." as we shouldn't be altering behavior with this change.
"model_serving_container_image_uri and model_display_name passed in. " | ||
"Ensure that your training script saves to model to " | ||
"os.environ['AIP_MODEL_DIR']." | ||
return super().run_job( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no need to call super here since this method is not overridden.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
def _model_upload_fail_string(self) -> str: | ||
"""Helper property for model upload failure.""" | ||
return ( | ||
f"Training Pipeline {self.resource_name} is not configured to upload a " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This message seems to only apply to custom training jobs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
sgtm @sasha-gitg! will make the changes today |
@sasha-gitg put in the fixes! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Requested a few minor changes. Also need to resolve merge conflicts with distributed training PR (those changes should only affect the CustomTrainingJob
class).
@@ -26,6 +26,7 @@ | |||
import time | |||
from typing import Callable, List, Optional, Sequence, Union | |||
|
|||
from abc import ABC, abstractmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Module level import here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was wondering about this. So according to the style guide I should be import abc, and then refer use abc.abstractmethod
every time in the code?
What about the line above it?
from typing import Callable, List, Optional, Sequence, Union
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes on abc.abstractmethod
.
typing
is an exception to the rule typing-imports.
Args: | ||
display_name (str): | ||
Required. The user-defined name of this TrainingPipeline. | ||
container_uri (str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the docstring.
input_data_config = gca_training_pipeline.InputDataConfig( | ||
fraction_split=fraction_split, | ||
dataset_id=dataset.name, | ||
gcs_destination=gca_io.GcsDestination( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the field gcs_destination
only used for Custom Training? The site https://cloud.google.com/ai-platform-unified/docs/reference/rest/v1beta1/projects.locations.trainingPipelines#TrainingPipeline.FIELDS.training_task_definition only says:
object (GcsDestination)
The Cloud Storage location where the training data is to be written to. In the given directory a new directory is created with name: dataset-<dataset-id>-<annotation-type>-<timestamp-of-training-call> where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. All training input data is written into that directory.
The AI Platform environment variables representing Cloud Storage data URIs are represented in the Cloud Storage wildcard format to support sharded data. e.g.: "gs://.../training-*.jsonl"
AIP_DATA_FORMAT = "jsonl" for non-tabular data, "csv" for tabular data
AIP_TRAINING_DATA_URI =
"gcsDestination/dataset---/training-*.${AIP_DATA_FORMAT}"
AIP_VALIDATION_DATA_URI =
"gcsDestination/dataset---/validation-*.${AIP_DATA_FORMAT}"
AIP_TEST_DATA_URI =
"gcsDestination/dataset---/test-*.${AIP_DATA_FORMAT}"
What training data is being written? Not too familiar with AutoML yet and it isn't obvious to me why this is needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Yes, I think this is custom training specific. Though, it would be worth confirming by using the API without passing the field.
3e525b7
to
e019332
Compare
Moved around some functions Completed refactor to use base class bug: remove requirement for import_schema_uri when passing in gcs_source (googleapis#46) Ran linters Removed model from TrainingJob and moved to CustomTrainingJob Removed DatasetWithSplits Added doc strings and simplified training_job_base code Moved TrainingJob class into training_jobs.py Removed container_uri from base TrainingJob class Addressed comments Fixed managed model Ran linter Fixed issues with abc, doc string and super call Refactored to create input data config separately
e019332
to
b9f9c8f
Compare
72a78d0
to
8db3cc1
Compare
… training jobs (googleapis#49) * Extracted reusable CustomTrainingJob code into TrainingJob base class Moved around some functions Completed refactor to use base class bug: remove requirement for import_schema_uri when passing in gcs_source (googleapis#46) Ran linters Removed model from TrainingJob and moved to CustomTrainingJob Removed DatasetWithSplits Added doc strings and simplified training_job_base code Moved TrainingJob class into training_jobs.py Removed container_uri from base TrainingJob class Addressed comments Fixed managed model Ran linter Fixed issues with abc, doc string and super call Refactored to create input data config separately * Ran linter Co-authored-by: Ivan Cheung <[email protected]>
… training jobs (googleapis#49) * Extracted reusable CustomTrainingJob code into TrainingJob base class Moved around some functions Completed refactor to use base class bug: remove requirement for import_schema_uri when passing in gcs_source (googleapis#46) Ran linters Removed model from TrainingJob and moved to CustomTrainingJob Removed DatasetWithSplits Added doc strings and simplified training_job_base code Moved TrainingJob class into training_jobs.py Removed container_uri from base TrainingJob class Addressed comments Fixed managed model Ran linter Fixed issues with abc, doc string and super call Refactored to create input data config separately * Ran linter Co-authored-by: Ivan Cheung <[email protected]>
… training jobs (googleapis#49) * Extracted reusable CustomTrainingJob code into TrainingJob base class Moved around some functions Completed refactor to use base class bug: remove requirement for import_schema_uri when passing in gcs_source (googleapis#46) Ran linters Removed model from TrainingJob and moved to CustomTrainingJob Removed DatasetWithSplits Added doc strings and simplified training_job_base code Moved TrainingJob class into training_jobs.py Removed container_uri from base TrainingJob class Addressed comments Fixed managed model Ran linter Fixed issues with abc, doc string and super call Refactored to create input data config separately * Ran linter Co-authored-by: Ivan Cheung <[email protected]>
… training jobs (googleapis#49) * Extracted reusable CustomTrainingJob code into TrainingJob base class Moved around some functions Completed refactor to use base class bug: remove requirement for import_schema_uri when passing in gcs_source (googleapis#46) Ran linters Removed model from TrainingJob and moved to CustomTrainingJob Removed DatasetWithSplits Added doc strings and simplified training_job_base code Moved TrainingJob class into training_jobs.py Removed container_uri from base TrainingJob class Addressed comments Fixed managed model Ran linter Fixed issues with abc, doc string and super call Refactored to create input data config separately * Ran linter Co-authored-by: Ivan Cheung <[email protected]>
… training jobs (googleapis#49) * Extracted reusable CustomTrainingJob code into TrainingJob base class Moved around some functions Completed refactor to use base class bug: remove requirement for import_schema_uri when passing in gcs_source (googleapis#46) Ran linters Removed model from TrainingJob and moved to CustomTrainingJob Removed DatasetWithSplits Added doc strings and simplified training_job_base code Moved TrainingJob class into training_jobs.py Removed container_uri from base TrainingJob class Addressed comments Fixed managed model Ran linter Fixed issues with abc, doc string and super call Refactored to create input data config separately * Ran linter Co-authored-by: Ivan Cheung <[email protected]>
… training jobs (googleapis#49) * Extracted reusable CustomTrainingJob code into TrainingJob base class Moved around some functions Completed refactor to use base class bug: remove requirement for import_schema_uri when passing in gcs_source (googleapis#46) Ran linters Removed model from TrainingJob and moved to CustomTrainingJob Removed DatasetWithSplits Added doc strings and simplified training_job_base code Moved TrainingJob class into training_jobs.py Removed container_uri from base TrainingJob class Addressed comments Fixed managed model Ran linter Fixed issues with abc, doc string and super call Refactored to create input data config separately * Ran linter Co-authored-by: Ivan Cheung <[email protected]>
Summary
There is reusable code in
CustomTrainingJob
, specifically in theinit
andrun
methods. Other training jobs such asAutoMLTablesTrainingJob
can use it as well.The code mainly has to do with creating and running a training pipeline.
The leftover code in
CustomTrainingJob
will thus only be related to custom training and not generic pipeline code.In preparation for: https://b.corp.google.com/issues/172282518
Remaining questions
1. In theTrainingJob.run
method,training_fraction_split
,validation_fraction_split
,test_fraction_split: float
are all only needed when thedataset
parameter is also provided. I've grouped this together intoDatasetWithSplits
. I'd want to replace theCustomTrainingJob.run
parameters with this as well. What do you all think?2. Should we putTrainingJob
into its own file and separate outCustomTrainingJob
andAutoMLTablesTrainingJob
into their own? If not, how should we organize things as to not have massive files?3. Should I makeTrainingJob
an abstract base class to prevent initialization? I can also do this by overridingnew
:4. Docstrings: On which functions are they needed? I'll try to look up the Google guidelines.TODO
- [ ] Write unit tests for the genericTrainingJob
class.Testing
Passes all existing unit tests