Skip to content

Commit

Permalink
feat: Support Model Serialization in Vertex Experiments(sklearn)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501487417
  • Loading branch information
jaycee-li authored and copybara-github committed Jan 12, 2023
1 parent 94b2f29 commit d4deed3
Show file tree
Hide file tree
Showing 14 changed files with 1,830 additions and 16 deletions.
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,17 @@
log_classification_metrics = (
metadata.metadata._experiment_tracker.log_classification_metrics
)
log_model = metadata.metadata._experiment_tracker.log_model
get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
start_run = metadata.metadata._experiment_tracker.start_run
start_execution = metadata.metadata._experiment_tracker.start_execution
log = metadata.metadata._experiment_tracker.log
log_time_series_metrics = metadata.metadata._experiment_tracker.log_time_series_metrics
end_run = metadata.metadata._experiment_tracker.end_run

save_model = metadata._models.save_model
get_experiment_model = metadata.schema.google.artifact_schema.ExperimentModel.get

Experiment = metadata.experiment_resources.Experiment
ExperimentRun = metadata.experiment_run_resource.ExperimentRun
Artifact = metadata.artifact.Artifact
Expand All @@ -116,11 +120,14 @@
"log_params",
"log_metrics",
"log_classification_metrics",
"log_model",
"log_time_series_metrics",
"get_experiment_df",
"get_pipeline_df",
"start_run",
"start_execution",
"save_model",
"get_experiment_model",
"Artifact",
"AutoMLImageTrainingJob",
"AutoMLTabularTrainingJob",
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/aiplatform/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
is_prebuilt_prediction_container_uri = (
container_uri_builders.is_prebuilt_prediction_container_uri
)
_get_closest_match_prebuilt_container_uri = (
container_uri_builders._get_closest_match_prebuilt_container_uri
)

__all__ = (
"get_prebuilt_prediction_container_uri",
"is_prebuilt_prediction_container_uri",
"_get_closest_match_prebuilt_container_uri",
)
106 changes: 104 additions & 2 deletions google/cloud/aiplatform/helpers/container_uri_builders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 Google LLC
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,9 +14,11 @@

import re
from typing import Optional
import warnings

from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.constants import prediction
from packaging import version


def get_prebuilt_prediction_container_uri(
Expand Down Expand Up @@ -122,3 +124,103 @@ def is_prebuilt_prediction_container_uri(image_uri: str) -> bool:
If the image is prebuilt by Vertex AI prediction.
"""
return re.fullmatch(prediction.CONTAINER_URI_REGEX, image_uri) is not None


# TODO(b/264191784) Deduplicate this method
def _get_closest_match_prebuilt_container_uri(
framework: str,
framework_version: str,
region: Optional[str] = None,
accelerator: str = "cpu",
) -> str:
"""Return a pre-built container uri that is suitable for a specific framework and version.
If there is no exact match for the given version, the closest one that is
higher than the input version will be used.
Args:
framework (str):
Required. The ML framework of the pre-built container. For example,
`"tensorflow"`, `"xgboost"`, or `"sklearn"`
framework_version (str):
Required. The version of the specified ML framework as a string.
region (str):
Optional. AI region or multi-region. Used to select the correct
Artifact Registry multi-region repository and reduce latency.
Must start with `"us"`, `"asia"` or `"europe"`.
Default is location set by `aiplatform.init()`.
accelerator (str):
Optional. The type of accelerator support provided by container. For
example: `"cpu"` or `"gpu"`
Default is `"cpu"`.
Returns:
A string representing the pre-built container uri.
Raises:
ValueError: If the framework doesn't have suitable pre-built container.
"""
URI_MAP = prediction._SERVING_CONTAINER_URI_MAP
DOCS_URI_MESSAGE = (
f"See {prediction._SERVING_CONTAINER_DOCUMENTATION_URL} "
"for complete list of supported containers"
)

# If region not provided, use initializer location
region = region or initializer.global_config.location
region = region.split("-", 1)[0]
framework = framework.lower()

if not URI_MAP.get(region):
raise ValueError(
f"Unsupported container region `{region}`, supported regions are "
f"{', '.join(URI_MAP.keys())}. "
f"{DOCS_URI_MESSAGE}"
)

if not URI_MAP[region].get(framework):
raise ValueError(
f"No containers found for framework `{framework}`. Supported frameworks are "
f"{', '.join(URI_MAP[region].keys())} {DOCS_URI_MESSAGE}"
)

if not URI_MAP[region][framework].get(accelerator):
raise ValueError(
f"{framework} containers do not support `{accelerator}` accelerator. Supported accelerators "
f"are {', '.join(URI_MAP[region][framework].keys())}. {DOCS_URI_MESSAGE}"
)

framework_version = version.Version(framework_version)
available_version_list = [
version.Version(available_version)
for available_version in URI_MAP[region][framework][accelerator].keys()
]
try:
closest_version = min(
[
available_version
for available_version in available_version_list
if available_version >= framework_version
# manually implement Version.major for packaging < 20.0
and available_version._version.release[0]
== framework_version._version.release[0]
]
)
except ValueError:
raise ValueError(
f"You are using `{framework}` version `{framework_version}`. "
f"Vertex pre-built containers support up to `{framework}` version "
f"`{max(available_version_list)}` and don't assume forward compatibility. "
f"Please build your own custom container. {DOCS_URI_MESSAGE}"
) from None

if closest_version != framework_version:
warnings.warn(
f"No exact match for `{framework}` version `{framework_version}`. "
f"Pre-built container for `{framework}` version `{closest_version}` is used. "
f"{DOCS_URI_MESSAGE}"
)

final_uri = URI_MAP[region][framework][accelerator].get(str(closest_version))

return final_uri
Loading

0 comments on commit d4deed3

Please sign in to comment.