Skip to content

Commit

Permalink
feat: Support Model Serialization in Vertex Experiments(xgboost)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501505712
  • Loading branch information
jaycee-li authored and copybara-github committed Jan 12, 2023
1 parent d4deed3 commit fe75eba
Show file tree
Hide file tree
Showing 6 changed files with 364 additions and 29 deletions.
85 changes: 78 additions & 7 deletions google/cloud/aiplatform/metadata/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

import importlib
import os
import pickle
import tempfile
Expand Down Expand Up @@ -45,12 +46,17 @@
"save_method": "_save_sklearn_model",
"load_method": "_load_sklearn_model",
"model_file": "model.pkl",
}
},
"xgboost": {
"save_method": "_save_xgboost_model",
"load_method": "_load_xgboost_model",
"model_file": "model.bst",
},
}


def save_model(
model: "sklearn.base.BaseEstimator", # noqa: F821
model: Union["sklearn.base.BaseEstimator", "xgb.Booster"], # noqa: F821
artifact_id: Optional[str] = None,
*,
uri: Optional[str] = None,
Expand All @@ -63,7 +69,7 @@ def save_model(
) -> google_artifact_schema.ExperimentModel:
"""Saves a ML model into a MLMD artifact.
Supported model frameworks: sklearn.
Supported model frameworks: sklearn, xgboost.
Example usage:
aiplatform.init(project="my-project", location="my-location", staging_bucket="gs://my-bucket")
Expand All @@ -72,7 +78,7 @@ def save_model(
aiplatform.save_model(model, "my-sklearn-model")
Args:
model (sklearn.base.BaseEstimator):
model (Union["sklearn.base.BaseEstimator", "xgb.Booster"]):
Required. A machine learning model.
artifact_id (str):
Optional. The resource id of the artifact. This id must be globally unique
Expand Down Expand Up @@ -116,10 +122,23 @@ def save_model(
except ImportError:
pass
else:
if isinstance(model, sklearn.base.BaseEstimator):
# An instance of sklearn.base.BaseEstimator might be a sklearn model
# or a xgboost/lightgbm model implemented on top of sklearn.
if isinstance(
model, sklearn.base.BaseEstimator
) and model.__class__.__module__.startswith("sklearn"):
framework_name = "sklearn"
framework_version = sklearn.__version__

try:
import xgboost as xgb
except ImportError:
pass
else:
if isinstance(model, (xgb.Booster, xgb.XGBModel)):
framework_name = "xgboost"
framework_version = xgb.__version__

if framework_name not in _FRAMEWORK_SPECS:
raise ValueError(
f"Model type {model.__class__.__module__}.{model.__class__.__name__} not supported."
Expand Down Expand Up @@ -305,9 +324,24 @@ def _save_sklearn_model(
pickle.dump(model, f, protocol=_PICKLE_PROTOCOL)


def _save_xgboost_model(
model: Union["xgb.Booster", "xgb.XGBModel"], # noqa: F821
path: str,
):
"""Saves a xgboost model.
Args:
model (Union[xgb.Booster, xgb.XGBModel]):
Requred. A xgboost model.
path (str):
Required. The local path to save the model.
"""
model.save_model(path)


def load_model(
model: Union[str, google_artifact_schema.ExperimentModel]
) -> "sklearn.base.BaseEstimator": # noqa: F821
) -> Union["sklearn.base.BaseEstimator", "xgb.Booster"]: # noqa: F821
"""Retrieves the original ML model from an ExperimentModel resource.
Args:
Expand Down Expand Up @@ -375,7 +409,44 @@ def _load_sklearn_model(
return sk_model


# TODO(b/264893283)
def _load_xgboost_model(
model_file: str,
model_artifact: google_artifact_schema.ExperimentModel,
) -> Union["xgb.Booster", "xgb.XGBModel"]: # noqa: F821
"""Loads a xgboost model from local path.
Args:
model_file (str):
Required. A local model file to load.
model_artifact (google_artifact_schema.ExperimentModel):
Required. The artifact that saved the model.
Returns:
The xgboost model instance.
Raises:
ImportError: if xgboost is not installed.
"""
try:
import xgboost as xgb
except ImportError:
raise ImportError(
"xgboost is not installed and is required for loading models."
) from None

if xgb.__version__ < model_artifact.framework_version:
_LOGGER.warning(
f"The original model was saved via xgboost {model_artifact.framework_version}. "
f"You are using xgboost {xgb.__version__}."
"Attempting to load model..."
)

module, class_name = model_artifact.model_class.rsplit(".", maxsplit=1)
xgb_model = getattr(importlib.import_module(module), class_name)()
xgb_model.load_model(model_file)

return xgb_model


def register_model(
model: Union[str, google_artifact_schema.ExperimentModel],
*,
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/aiplatform/metadata/experiment_run_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ def log_classification_metrics(
@_v1_not_supported
def log_model(
self,
model: "sklearn.base.BaseEstimator", # noqa: F821
model: Union["sklearn.base.BaseEstimator", "xgb.Booster"], # noqa: F821
artifact_id: Optional[str] = None,
*,
uri: Optional[str] = None,
Expand All @@ -1121,7 +1121,7 @@ def log_model(
) -> google_artifact_schema.ExperimentModel:
"""Saves a ML model into a MLMD artifact and log it to this ExperimentRun.
Supported model frameworks: sklearn.
Supported model frameworks: sklearn, xgboost.
Example usage:
model = LinearRegression()
Expand All @@ -1136,7 +1136,7 @@ def log_model(
aiplatform.log_model(model, "my-sklearn-model")
Args:
model (sklearn.base.BaseEstimator):
model (Union["sklearn.base.BaseEstimator", "xgb.Booster"]):
Required. A machine learning model.
artifact_id (str):
Optional. The resource id of the artifact. This id must be globally unique
Expand Down
6 changes: 3 additions & 3 deletions google/cloud/aiplatform/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def log_classification_metrics(

def log_model(
self,
model: "sklearn.base.BaseEstimator", # noqa: F821
model: Union["sklearn.base.BaseEstimator", "xgb.Booster"], # noqa: F821
artifact_id: Optional[str] = None,
*,
uri: Optional[str] = None,
Expand All @@ -489,7 +489,7 @@ def log_model(
) -> google_artifact_schema.ExperimentModel:
"""Saves a ML model into a MLMD artifact and log it to this ExperimentRun.
Supported model frameworks: sklearn.
Supported model frameworks: sklearn, xgboost.
Example usage:
model = LinearRegression()
Expand All @@ -504,7 +504,7 @@ def log_model(
aiplatform.log_model(model, "my-sklearn-model")
Args:
model (sklearn.base.BaseEstimator):
model (Union["sklearn.base.BaseEstimator", "xgb.Booster"]):
Required. A machine learning model.
artifact_id (str):
Optional. The resource id of the artifact. This id must be globally unique
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import copy
from typing import Optional, Dict, List, Sequence
from typing import Optional, Dict, List, Sequence, Union

from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import explain
Expand Down Expand Up @@ -742,7 +742,9 @@ def framework_version(self) -> Optional[str]:
def model_class(self) -> Optional[str]:
return self.metadata.get("modelClass")

def load_model(self) -> "sklearn.base.BaseEstimator": # noqa: F821
def load_model(
self,
) -> Union["sklearn.base.BaseEstimator", "xgb.Booster"]: # noqa: F821
"""Retrieves the original ML model from an ExperimentModel.
Example usage:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
testing_extra_require = (
full_extra_require
+ profiler_extra_require
+ ["grpcio-testing", "pytest-asyncio", "pytest-xdist", "ipython", "kfp"]
+ ["grpcio-testing", "pytest-asyncio", "pytest-xdist", "ipython", "kfp", "xgboost"]
)


Expand Down
Loading

0 comments on commit fe75eba

Please sign in to comment.