diff --git a/google/cloud/aiplatform/metadata/artifact.py b/google/cloud/aiplatform/metadata/artifact.py index 2c30d51bfb..65ee2cb92b 100644 --- a/google/cloud/aiplatform/metadata/artifact.py +++ b/google/cloud/aiplatform/metadata/artifact.py @@ -31,6 +31,7 @@ from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform.metadata import resource from google.cloud.aiplatform.metadata import utils as metadata_utils +from google.cloud.aiplatform.metadata.schema import base_artifact from google.cloud.aiplatform.utils import rest_utils @@ -326,6 +327,56 @@ def create( credentials=credentials, ) + @classmethod + def create_from_base_artifact_schema( + cls, + *, + base_artifact_schema: "base_artifact.BaseArtifactSchema", + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "Artifact": + """Creates a new Metadata Artifact from a BaseArtifactSchema class instance. + + Args: + base_artifact_schema (BaseArtifactSchema): + Required. An instance of the BaseArtifactSchema class that can be + provided instead of providing artifact specific parameters. + metadata_store_id (str): + Optional. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores//artifacts/ + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Optional. Project used to create this Artifact. Overrides project set in + aiplatform.init. + location (str): + Optional. Location used to create this Artifact. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to create this Artifact. Overrides + credentials set in aiplatform.init. + + Returns: + Artifact: Instantiated representation of the managed Metadata Artifact. 
+ """ + + return cls.create( + resource_id=base_artifact_schema.artifact_id, + schema_title=base_artifact_schema.schema_title, + uri=base_artifact_schema.uri, + display_name=base_artifact_schema.display_name, + schema_version=base_artifact_schema.schema_version, + description=base_artifact_schema.description, + metadata=base_artifact_schema.metadata, + state=base_artifact_schema.state, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + @property def uri(self) -> Optional[str]: "Uri for this Artifact." diff --git a/google/cloud/aiplatform/metadata/execution.py b/google/cloud/aiplatform/metadata/execution.py index 9a85bce36f..895417fc64 100644 --- a/google/cloud/aiplatform/metadata/execution.py +++ b/google/cloud/aiplatform/metadata/execution.py @@ -31,6 +31,7 @@ from google.cloud.aiplatform.metadata import artifact from google.cloud.aiplatform.metadata import metadata_store from google.cloud.aiplatform.metadata import resource +from google.cloud.aiplatform.metadata.schema import base_execution class Execution(resource._Resource): @@ -166,6 +167,57 @@ def create( return self + @classmethod + def create_from_base_execution_schema( + cls, + *, + base_execution_schema: "base_execution.BaseExecutionSchema", + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "Execution": + """ + Creates a new Metadata Execution. + + Args: + base_execution_schema (BaseExecutionSchema): + An instance of the BaseExecutionSchema class that can be + provided instead of providing schema specific parameters. + metadata_store_id (str): + Optional. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores//artifacts/ + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Optional. Project used to create this Execution. 
Overrides project set in + aiplatform.init. + location (str): + Optional. Location used to create this Execution. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to create this Execution. Overrides + credentials set in aiplatform.init. + + Returns: + Execution: Instantiated representation of the managed Metadata Execution. + + """ + execution_resource = cls.create( + state=base_execution_schema.state, + schema_title=base_execution_schema.schema_title, + resource_id=base_execution_schema.execution_id, + display_name=base_execution_schema.display_name, + schema_version=base_execution_schema.schema_version, + metadata=base_execution_schema.metadata, + description=base_execution_schema.description, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + return execution_resource + def __enter__(self): if self.state is not gca_execution.Execution.State.RUNNING: self.update(state=gca_execution.Execution.State.RUNNING) diff --git a/google/cloud/aiplatform/metadata/schema/base_artifact.py b/google/cloud/aiplatform/metadata/schema/base_artifact.py new file mode 100644 index 0000000000..c89d989edd --- /dev/null +++ b/google/cloud/aiplatform/metadata/schema/base_artifact.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import abc + +from typing import Optional, Dict + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform.compat.types import artifact as gca_artifact +from google.cloud.aiplatform.metadata import artifact +from google.cloud.aiplatform.metadata import constants + + +class BaseArtifactSchema(metaclass=abc.ABCMeta): + """Base class for Metadata Artifact types.""" + + @property + @classmethod + @abc.abstractmethod + def schema_title(cls) -> str: + """Identifies the Vertex Metadata schema title used by the resource.""" + pass + + def __init__( + self, + *, + artifact_id: Optional[str] = None, + uri: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + + """Initializes the Artifact with the given name, URI and metadata. + + This is the base class for defining various artifact types, which can be + passed to google.Artifact to create a corresponding resource. + Artifacts carry a `metadata` field, which is a dictionary for storing + metadata related to this artifact. Subclasses from BaseArtifactSchema can enforce + various structure and field requirements for the metadata field. + + Args: + artifact_id (str): + Optional. The portion of the Artifact name with + the following format, this is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. 
+ metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + self.artifact_id = artifact_id + self.uri = uri + self.display_name = display_name + self.schema_version = schema_version or constants._DEFAULT_SCHEMA_VERSION + self.description = description + self.metadata = metadata + self.state = state + + def create( + self, + *, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "artifact.Artifact": + """Creates a new Metadata Artifact. + + Args: + metadata_store_id (str): + Optional. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores//artifacts/ + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Optional. Project used to create this Artifact. Overrides project set in + aiplatform.init. + location (str): + Optional. Location used to create this Artifact. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to create this Artifact. Overrides + credentials set in aiplatform.init. + Returns: + Artifact: Instantiated representation of the managed Metadata Artifact. 
+ """ + return artifact.Artifact.create_from_base_artifact_schema( + base_artifact_schema=self, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) diff --git a/google/cloud/aiplatform/metadata/schema/base_execution.py b/google/cloud/aiplatform/metadata/schema/base_execution.py new file mode 100644 index 0000000000..811b7d9791 --- /dev/null +++ b/google/cloud/aiplatform/metadata/schema/base_execution.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import abc + +from typing import Optional, Dict + +from google.auth import credentials as auth_credentials + +from google.cloud.aiplatform.compat.types import execution as gca_execution +from google.cloud.aiplatform.metadata import constants +from google.cloud.aiplatform.metadata import execution + + +class BaseExecutionSchema(metaclass=abc.ABCMeta): + """Base class for Metadata Execution schema.""" + + @property + @classmethod + @abc.abstractmethod + def schema_title(cls) -> str: + """Identifies the Vertex Metadata schema title used by the resource.""" + pass + + def __init__( + self, + *, + state: Optional[ + gca_execution.Execution.State + ] = gca_execution.Execution.State.RUNNING, + execution_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + metadata: Optional[Dict] = None, + description: Optional[str] = None, + ): + + """Initializes the Execution with the given name, URI and metadata. + + Args: + state (gca_execution.Execution.State.RUNNING): + Optional. State of this Execution. Defaults to RUNNING. + execution_id (str): + Optional. The portion of the Execution name with + the following format, this is globally unique in a metadataStore. + projects/123/locations/us-central1/metadataStores//executions/. + display_name (str): + Optional. The user-defined name of the Execution. + schema_version (str): + Optional. schema_version specifies the version used by the Execution. + If not set, defaults to use the latest version. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Execution. + description (str): + Optional. Describes the purpose of the Execution to be created. 
+ """ + self.state = state + self.execution_id = execution_id + self.display_name = display_name + self.schema_version = schema_version or constants._DEFAULT_SCHEMA_VERSION + self.metadata = metadata + self.description = description + + def create( + self, + *, + metadata_store_id: Optional[str] = "default", + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "execution.Execution": + """Creates a new Metadata Execution. + + Args: + metadata_store_id (str): + Optional. The portion of the resource name with + the format: + projects/123/locations/us-central1/metadataStores//executions/ + If not provided, the MetadataStore's ID will be set to "default". + project (str): + Optional. Project used to create this Execution. Overrides project set in + aiplatform.init. + location (str): + Optional. Location used to create this Execution. Overrides location set in + aiplatform.init. + credentials (auth_credentials.Credentials): + Optional. Custom credentials used to create this Execution. Overrides + credentials set in aiplatform.init. + Returns: + Execution: Instantiated representation of the managed Metadata Execution. + + """ + self.execution = execution.Execution.create_from_base_execution_schema( + base_execution_schema=self, + metadata_store_id=metadata_store_id, + project=project, + location=location, + credentials=credentials, + ) + return self.execution diff --git a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py new file mode 100644 index 0000000000..99e0fb0ba6 --- /dev/null +++ b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py @@ -0,0 +1,270 @@ +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Optional, Dict + +from google.cloud.aiplatform.compat.types import artifact as gca_artifact +from google.cloud.aiplatform.metadata.schema import base_artifact +from google.cloud.aiplatform.metadata.schema import utils + +# The artifact property key for the resource_name +_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME = "resourceName" + + +class VertexDataset(base_artifact.BaseArtifactSchema): + """An artifact representing a Vertex Dataset.""" + + schema_title = "google.VertexDataset" + + def __init__( + self, + *, + vertex_dataset_name: str, + artifact_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + vertex_dataset_name (str): + The name of the Dataset resource, in a form of + projects/{project}/locations/{location}/datasets/{dataset}. For + more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.datasets/get + This is used to generate the resource uri as follows: + https://{service-endpoint}/v1/{dataset_name}, + where {service-endpoint} is one of the supported service endpoints at + https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. 
+ display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + extended_metadata[_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME] = vertex_dataset_name + + super(VertexDataset, self).__init__( + uri=utils.create_uri_from_resource_name(resource_name=vertex_dataset_name), + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class VertexModel(base_artifact.BaseArtifactSchema): + """An artifact representing a Vertex Model.""" + + schema_title = "google.VertexModel" + + def __init__( + self, + *, + vertex_model_name: str, + artifact_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + vertex_model_name (str): + The name of the Model resource, in a form of + projects/{project}/locations/{location}/models/{model}. 
For + more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.models/get + This is used to generate the resource uri as follows: + https://{service-endpoint}/v1/{vertex_model_name}, + where {service-endpoint} is one of the supported service endpoints at + https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. 
+ """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + extended_metadata[_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME] = vertex_model_name + + super(VertexModel, self).__init__( + uri=utils.create_uri_from_resource_name(resource_name=vertex_model_name), + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class VertexEndpoint(base_artifact.BaseArtifactSchema): + """An artifact representing a Vertex Endpoint.""" + + schema_title = "google.VertexEndpoint" + + def __init__( + self, + *, + vertex_endpoint_name: str, + artifact_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + vertex_endpoint_name (str): + The name of the Endpoint resource, in a form of + projects/{project}/locations/{location}/endpoints/{endpoint}. For + more details, see + https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/get + This is used to generate the resource uri as follows: + https://{service-endpoint}/v1/{vertex_endpoint_name}, + where {service-endpoint} is one of the supported service endpoints at + https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. 
Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + extended_metadata[_ARTIFACT_PROPERTY_KEY_RESOURCE_NAME] = vertex_endpoint_name + + super(VertexEndpoint, self).__init__( + uri=utils.create_uri_from_resource_name(resource_name=vertex_endpoint_name), + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class UnmanagedContainerModel(base_artifact.BaseArtifactSchema): + """An artifact representing a Vertex Unmanaged Container Model.""" + + schema_title = "google.UnmanagedContainerModel" + + def __init__( + self, + *, + predict_schema_ta: utils.PredictSchemata, + container_spec: utils.ContainerSpec, + artifact_id: Optional[str] = None, + uri: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + predict_schema_ta (PredictSchemata): + An instance of PredictSchemata which holds instance, parameter and prediction schema uris. + container_spec (ContainerSpec): + An instance of ContainerSpec which holds the container configuration for the model. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + uri (str): + Optional. The uniform resource identifier of the artifact file. 
May be empty if there is no actual + artifact file. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + extended_metadata["predictSchemata"] = predict_schema_ta.to_dict() + extended_metadata["containerSpec"] = container_spec.to_dict() + + super(UnmanagedContainerModel, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) diff --git a/google/cloud/aiplatform/metadata/schema/system/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/system/artifact_schema.py new file mode 100644 index 0000000000..f3491a5573 --- /dev/null +++ b/google/cloud/aiplatform/metadata/schema/system/artifact_schema.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +from typing import Optional, Dict + +from google.cloud.aiplatform.compat.types import artifact as gca_artifact +from google.cloud.aiplatform.metadata.schema import base_artifact + + +class Model(base_artifact.BaseArtifactSchema): + """Artifact type for model.""" + + schema_title = "system.Model" + + def __init__( + self, + *, + uri: Optional[str] = None, + artifact_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + display_name (str): + Optional. The user-defined name of the base. + schema_version (str): + Optional. schema_version specifies the version used by the base. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. 
This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + super(Model, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class Artifact(base_artifact.BaseArtifactSchema): + """A generic artifact.""" + + schema_title = "system.Artifact" + + def __init__( + self, + *, + uri: Optional[str] = None, + artifact_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + display_name (str): + Optional. The user-defined name of the base. + schema_version (str): + Optional. schema_version specifies the version used by the base. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. 
+ """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + super(Artifact, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class Dataset(base_artifact.BaseArtifactSchema): + """An artifact representing a system Dataset.""" + + schema_title = "system.Dataset" + + def __init__( + self, + *, + uri: Optional[str] = None, + artifact_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + display_name (str): + Optional. The user-defined name of the base. + schema_version (str): + Optional. schema_version specifies the version used by the base. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. 
+ """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + super(Dataset, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class Metrics(base_artifact.BaseArtifactSchema): + """Artifact schema for scalar metrics.""" + + schema_title = "system.Metrics" + + def __init__( + self, + *, + accuracy: Optional[float] = None, + precision: Optional[float] = None, + recall: Optional[float] = None, + f1score: Optional[float] = None, + mean_absolute_error: Optional[float] = None, + mean_squared_error: Optional[float] = None, + uri: Optional[str] = None, + artifact_id: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + accuracy (float): + Optional. + precision (float): + Optional. + recall (float): + Optional. + f1score (float): + Optional. + mean_absolute_error (float): + Optional. + mean_squared_error (float): + Optional. + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + display_name (str): + Optional. The user-defined name of the base. + schema_version (str): + Optional. schema_version specifies the version used by the base. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. 
The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + if accuracy: + extended_metadata["accuracy"] = accuracy + if precision: + extended_metadata["precision"] = precision + if recall: + extended_metadata["recall"] = recall + if f1score: + extended_metadata["f1score"] = f1score + if mean_absolute_error: + extended_metadata["mean_absolute_error"] = mean_absolute_error + if mean_squared_error: + extended_metadata["mean_squared_error"] = mean_squared_error + + super(Metrics, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) diff --git a/google/cloud/aiplatform/metadata/schema/system/execution_schema.py b/google/cloud/aiplatform/metadata/schema/system/execution_schema.py new file mode 100644 index 0000000000..68c96902cb --- /dev/null +++ b/google/cloud/aiplatform/metadata/schema/system/execution_schema.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- + +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

import copy
from typing import Optional, Dict

from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.metadata.schema import base_execution


class ContainerExecution(base_execution.BaseExecutionSchema):
    """Execution schema for a container execution."""

    schema_title = "system.ContainerExecution"

    def __init__(
        self,
        *,
        state: Optional[
            gca_execution.Execution.State
        ] = gca_execution.Execution.State.RUNNING,
        execution_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        state (gca_execution.Execution.State):
            Optional. State of this Execution. Defaults to RUNNING.
        execution_id (str):
            Optional. The portion of the Execution name with
            the following format, this is globally unique in a metadataStore.
            projects/123/locations/us-central1/metadataStores//executions/.
        display_name (str):
            Optional. The user-defined name of the Execution.
        schema_version (str):
            Optional. schema_version specifies the version used by the Execution.
            If not set, defaults to use the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Execution.
        description (str):
            Optional. Describes the purpose of the Execution to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(ContainerExecution, self).__init__(
            execution_id=execution_id,
            state=state,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )


class CustomJobExecution(base_execution.BaseExecutionSchema):
    """Execution schema for a custom job execution."""

    schema_title = "system.CustomJobExecution"

    def __init__(
        self,
        *,
        state: Optional[
            gca_execution.Execution.State
        ] = gca_execution.Execution.State.RUNNING,
        execution_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        state (gca_execution.Execution.State):
            Optional. State of this Execution. Defaults to RUNNING.
        execution_id (str):
            Optional. The portion of the Execution name with
            the following format, this is globally unique in a metadataStore.
            projects/123/locations/us-central1/metadataStores//executions/.
        display_name (str):
            Optional. The user-defined name of the Execution.
        schema_version (str):
            Optional. schema_version specifies the version used by the Execution.
            If not set, defaults to use the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Execution.
        description (str):
            Optional. Describes the purpose of the Execution to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(CustomJobExecution, self).__init__(
            execution_id=execution_id,
            state=state,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )


class Run(base_execution.BaseExecutionSchema):
    """Execution schema for root run execution."""

    schema_title = "system.Run"

    def __init__(
        self,
        *,
        state: Optional[
            gca_execution.Execution.State
        ] = gca_execution.Execution.State.RUNNING,
        execution_id: Optional[str] = None,
        display_name: Optional[str] = None,
        schema_version: Optional[str] = None,
        metadata: Optional[Dict] = None,
        description: Optional[str] = None,
    ):
        """Args:
        state (gca_execution.Execution.State):
            Optional. State of this Execution. Defaults to RUNNING.
        execution_id (str):
            Optional. The portion of the Execution name with
            the following format, this is globally unique in a metadataStore.
            projects/123/locations/us-central1/metadataStores//executions/.
        display_name (str):
            Optional. The user-defined name of the Execution.
        schema_version (str):
            Optional. schema_version specifies the version used by the Execution.
            If not set, defaults to use the latest version.
        metadata (Dict):
            Optional. Contains the metadata information that will be stored in the Execution.
        description (str):
            Optional. Describes the purpose of the Execution to be created.
        """
        extended_metadata = copy.deepcopy(metadata) if metadata else {}
        super(Run, self).__init__(
            execution_id=execution_id,
            state=state,
            display_name=display_name,
            schema_version=schema_version,
            description=description,
            metadata=extended_metadata,
        )
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re

from typing import Optional, Dict, List
from dataclasses import dataclass


@dataclass
class PredictSchemata:
    """A class holding instance, parameter and prediction schema uris.

    Args:
        instance_schema_uri (str):
            Required. Points to a YAML file stored on Google Cloud Storage
            describing the format of a single instance, which are used in
            PredictRequest.instances, ExplainRequest.instances and
            BatchPredictionJob.input_config. The schema is defined as an
            OpenAPI 3.0.2 Schema Object.
        parameters_schema_uri (str):
            Required. Points to a YAML file stored on Google Cloud Storage
            describing the parameters of prediction and explanation via
            PredictRequest.parameters, ExplainRequest.parameters and
            BatchPredictionJob.model_parameters. The schema is defined as an
            OpenAPI 3.0.2 Schema Object.
        prediction_schema_uri (str):
            Required. Points to a YAML file stored on Google Cloud Storage
            describing the format of a single prediction produced by this Model
            , which are returned via PredictResponse.predictions,
            ExplainResponse.explanations, and BatchPredictionJob.output_config.
            The schema is defined as an OpenAPI 3.0.2 Schema Object.
    """

    instance_schema_uri: str
    parameters_schema_uri: str
    prediction_schema_uri: str

    def to_dict(self):
        """ML metadata schema dictionary representation of this DataClass"""
        results = {}
        results["instanceSchemaUri"] = self.instance_schema_uri
        results["parametersSchemaUri"] = self.parameters_schema_uri
        results["predictionSchemaUri"] = self.prediction_schema_uri

        return results
@dataclass
class ContainerSpec:
    """Container configuration for the model.

    Args:
        image_uri (str):
            Required. URI of the Docker image to be used as the custom
            container for serving predictions. This URI must identify an image
            in Artifact Registry or Container Registry.
        command (Sequence[str]):
            Optional. Specifies the command that runs when the container
            starts. This overrides the container's `ENTRYPOINT`.
        args (Sequence[str]):
            Optional. Specifies arguments for the command that runs when the
            container starts. This overrides the container's `CMD`.
        env (Sequence[Dict[str, str]]):
            Optional. List of environment variables to set in the container,
            each given as a ``{"name": ..., "value": ...}`` mapping. After the
            container starts running, code running in the container can read
            these environment variables. Later entries in this list can
            reference earlier entries, e.g. ``{"name": "VAR_2", "value":
            "$(VAR_1) bar"}``; if the order is switched the expansion does not
            occur. This corresponds to the ``env`` field of the Kubernetes
            Containers v1 core API.
        ports (Sequence[int]):
            Optional. List of ports to expose from the container. Vertex AI
            sends any prediction requests that it receives to the first port on
            this list. Vertex AI also sends liveness and health checks.
        predict_route (str):
            Optional. HTTP path on the container to send prediction requests
            to. Vertex AI forwards requests sent using
            projects.locations.endpoints.predict to this path on the
            container's IP address and port, and returns the container's
            response in the API response. If not specified, it defaults to
            /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict when
            the Model is deployed to an Endpoint.
        health_route (str):
            Optional. HTTP path on the container to send health checks to.
            Vertex AI intermittently sends GET requests to this path on the
            container's IP address and port to check that the container is
            healthy.
    """

    image_uri: str
    command: Optional[List[str]] = None
    args: Optional[List[str]] = None
    env: Optional[List[Dict[str, str]]] = None
    ports: Optional[List[int]] = None
    predict_route: Optional[str] = None
    health_route: Optional[str] = None

    def to_dict(self):
        """ML metadata schema dictionary representation of this DataClass.

        Only fields that are set (truthy) are included, so empty optional
        fields never appear in the stored metadata.
        """
        results = {}
        results["imageUri"] = self.image_uri
        if self.command:
            results["command"] = self.command
        if self.args:
            results["args"] = self.args
        if self.env:
            results["env"] = self.env
        if self.ports:
            results["ports"] = self.ports
        if self.predict_route:
            results["predictRoute"] = self.predict_route
        if self.health_route:
            results["healthRoute"] = self.health_route

        return results


# FIX: the return annotation was `-> bool`; the function returns the URI
# string (or raises ValueError), never a boolean.
def create_uri_from_resource_name(resource_name: str) -> str:
    """Construct the service URI for a given resource_name.

    Args:
        resource_name (str):
            The name of the Vertex resource, in a form of
            projects/{project}/locations/{location}/{resource_type}/{resource_id}
    Returns:
        The resource URI in the form of:
        https://{service-endpoint}/v1/{resource_name},
        where {service-endpoint} is one of the supported service endpoints at
        https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints
    Raises:
        ValueError: If resource_name does not match the specified format.
    """
    # TODO: support nested resource names such as models/123/evaluations/456
    match_results = re.match(
        r"^projects\/[A-Za-z0-9-]*\/locations\/([A-Za-z0-9-]*)\/[A-Za-z0-9-]*\/[A-Za-z0-9-]*$",
        resource_name,
    )
    if not match_results:
        raise ValueError(f"Invalid resource_name format for {resource_name}.")

    # The location segment selects the regional service endpoint.
    location = match_results.group(1)
    return f"https://{location}-aiplatform.googleapis.com/v1/{resource_name}"
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json

import pytest

from google.cloud import aiplatform
from google.cloud.aiplatform.metadata.schema.google import (
    artifact_schema as google_artifact_schema,
)
from google.cloud.aiplatform.metadata.schema.system import (
    artifact_schema as system_artifact_schema,
)
from google.cloud.aiplatform.metadata.schema.system import (
    execution_schema as system_execution_schema,
)
from tests.system.aiplatform import e2e_base


@pytest.mark.usefixtures("tear_down_resources")
class TestMetadataSchema(e2e_base.TestEndToEnd):
    """End-to-end tests creating Artifacts/Executions from schema classes
    against the live Metadata service (default metadata store)."""

    _temp_prefix = "tmpvrtxmlmdsdk-e2e"

    def setup_class(cls):
        # Truncating the name because of resource id constraints from the service
        cls.artifact_display_name = cls._make_display_name("base-artifact")[:30]
        cls.artifact_id = cls._make_display_name("base-artifact-id")[:30]
        cls.artifact_uri = cls._make_display_name("base-uri")
        cls.artifact_metadata = {"test_property": "test_value"}
        cls.artifact_description = cls._make_display_name("base-description")
        cls.execution_display_name = cls._make_display_name("base-execution")[:30]
        cls.execution_description = cls._make_display_name("base-description")

    def test_system_dataset_artifact_create(self):
        # Creates a system.Dataset artifact and verifies the service echoes
        # back the fields that were supplied.
        aiplatform.init(
            project=e2e_base._PROJECT,
            location=e2e_base._LOCATION,
        )

        artifact = system_artifact_schema.Dataset(
            display_name=self.artifact_display_name,
            uri=self.artifact_uri,
            metadata=self.artifact_metadata,
            description=self.artifact_description,
        ).create()

        assert artifact.display_name == self.artifact_display_name
        # Compare as sorted JSON so proto map ordering does not matter.
        assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps(
            self.artifact_metadata, sort_keys=True
        )
        assert artifact.schema_title == "system.Dataset"
        assert artifact.description == self.artifact_description
        assert "/metadataStores/default/artifacts/" in artifact.resource_name

    def test_google_dataset_artifact_create(self):
        # google.VertexDataset derives uri and metadata["resourceName"] from
        # the vertex_dataset_name it is given.
        aiplatform.init(
            project=e2e_base._PROJECT,
            location=e2e_base._LOCATION,
        )
        vertex_dataset_name = f"projects/{e2e_base._PROJECT}/locations/{e2e_base._LOCATION}/datasets/dataset"
        artifact = google_artifact_schema.VertexDataset(
            vertex_dataset_name=vertex_dataset_name,
            display_name=self.artifact_display_name,
            metadata=self.artifact_metadata,
            description=self.artifact_description,
        ).create()
        expected_metadata = self.artifact_metadata.copy()
        expected_metadata["resourceName"] = vertex_dataset_name

        assert artifact.display_name == self.artifact_display_name
        assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps(
            expected_metadata, sort_keys=True
        )
        assert artifact.schema_title == "google.VertexDataset"
        assert artifact.description == self.artifact_description
        assert "/metadataStores/default/artifacts/" in artifact.resource_name
        assert (
            artifact.uri
            == f"https://{e2e_base._LOCATION}-aiplatform.googleapis.com/v1/{vertex_dataset_name}"
        )

    def test_execution_create_using_system_schema_class(self):
        # Creates a system.CustomJobExecution and verifies round-tripped fields.
        aiplatform.init(
            project=e2e_base._PROJECT,
            location=e2e_base._LOCATION,
        )

        execution = system_execution_schema.CustomJobExecution(
            display_name=self.execution_display_name,
            description=self.execution_description,
        ).create()

        assert execution.display_name == self.execution_display_name
        assert execution.schema_title == "system.CustomJobExecution"
        assert execution.description == self.execution_description
        assert "/metadataStores/default/executions/" in execution.resource_name
# -*- coding: utf-8 -*-

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import pytest

from importlib import reload
from unittest import mock
from unittest.mock import patch

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
from google.cloud.aiplatform.compat.types import execution as gca_execution
from google.cloud.aiplatform.metadata import metadata
from google.cloud.aiplatform.metadata.schema import base_artifact
from google.cloud.aiplatform.metadata.schema import base_execution
from google.cloud.aiplatform.metadata.schema.google import (
    artifact_schema as google_artifact_schema,
)
from google.cloud.aiplatform.metadata.schema.system import (
    artifact_schema as system_artifact_schema,
)
from google.cloud.aiplatform.metadata.schema.system import (
    execution_schema as system_execution_schema,
)
from google.cloud.aiplatform.metadata.schema import utils
from google.cloud.aiplatform_v1 import MetadataServiceClient
from google.cloud.aiplatform_v1 import Artifact as GapicArtifact
from google.cloud.aiplatform_v1 import Execution as GapicExecution


# project
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
_TEST_METADATA_STORE = "test-metadata-store"
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"

# resource attributes
_TEST_ARTIFACT_STATE = gca_artifact.Artifact.State.STATE_UNSPECIFIED
_TEST_EXECUTION_STATE = gca_execution.Execution.State.STATE_UNSPECIFIED
_TEST_URI = "test-uri"
_TEST_DISPLAY_NAME = "test-display-name"
_TEST_SCHEMA_TITLE = "test.Example"
_TEST_SCHEMA_VERSION = "0.0.1"
_TEST_DESCRIPTION = "test description"
_TEST_METADATA = {"test-param1": 1, "test-param2": "test-value", "test-param3": True}
_TEST_UPDATED_METADATA = {
    "test-param1": 2,
    "test-param2": "test-value-1",
    "test-param3": False,
}

# artifact
_TEST_ARTIFACT_ID = "test-artifact-id"
_TEST_ARTIFACT_NAME = f"{_TEST_PARENT}/artifacts/{_TEST_ARTIFACT_ID}"

# execution
_TEST_EXECUTION_ID = "test-execution-id"
_TEST_EXECUTION_NAME = f"{_TEST_PARENT}/executions/{_TEST_EXECUTION_ID}"


@pytest.fixture
def create_artifact_mock():
    # Patches the service client so no RPC is made; returns a canned Artifact.
    with patch.object(MetadataServiceClient, "create_artifact") as create_artifact_mock:
        create_artifact_mock.return_value = GapicArtifact(
            name=_TEST_ARTIFACT_NAME,
            display_name=_TEST_DISPLAY_NAME,
            schema_title=_TEST_SCHEMA_TITLE,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_METADATA,
            state=GapicArtifact.State.STATE_UNSPECIFIED,
        )
        yield create_artifact_mock


@pytest.fixture
def create_execution_mock():
    # Patches the service client so no RPC is made; returns a canned Execution.
    with patch.object(
        MetadataServiceClient, "create_execution"
    ) as create_execution_mock:
        create_execution_mock.return_value = GapicExecution(
            name=_TEST_EXECUTION_NAME,
            display_name=_TEST_DISPLAY_NAME,
            schema_title=_TEST_SCHEMA_TITLE,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_METADATA,
            state=GapicExecution.State.RUNNING,
        )
        yield create_execution_mock


class TestMetadataBaseArtifactSchema:
    def setup_method(self):
        reload(initializer)
        reload(metadata)
        reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    # FIX: corrected "instatiated" -> "instantiated" in the test name.
    def test_base_class_instantiated_uses_schema_title(self):
        class TestArtifact(base_artifact.BaseArtifactSchema):
            schema_title = _TEST_SCHEMA_TITLE

        artifact = TestArtifact()
        assert artifact.schema_title == _TEST_SCHEMA_TITLE

    def test_base_class_parameters_overrides_default_values(self):
        class TestArtifact(base_artifact.BaseArtifactSchema):
            schema_title = _TEST_SCHEMA_TITLE

        artifact = TestArtifact(
            state=_TEST_ARTIFACT_STATE,
            schema_version=_TEST_SCHEMA_VERSION,
            artifact_id=_TEST_ARTIFACT_ID,
            uri=_TEST_URI,
            display_name=_TEST_DISPLAY_NAME,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
        )
        # FIX: removed a duplicated `assert artifact.state` line.
        assert artifact.state == _TEST_ARTIFACT_STATE
        assert artifact.schema_version == _TEST_SCHEMA_VERSION
        assert artifact.artifact_id == _TEST_ARTIFACT_ID
        assert artifact.schema_title == _TEST_SCHEMA_TITLE
        assert artifact.uri == _TEST_URI
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert artifact.metadata == _TEST_UPDATED_METADATA

    def test_base_class_without_schema_title_raises_error(self):
        with pytest.raises(TypeError):
            base_artifact.BaseArtifactSchema()

    # FIX: dropped redundant @pytest.mark.usefixtures("create_artifact_mock");
    # the fixture is already requested via the function parameter.
    def test_create_is_called_with_default_parameters(self, create_artifact_mock):
        aiplatform.init(project=_TEST_PROJECT)

        class TestArtifact(base_artifact.BaseArtifactSchema):
            schema_title = _TEST_SCHEMA_TITLE

        artifact = TestArtifact(
            uri=_TEST_URI,
            display_name=_TEST_DISPLAY_NAME,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
            state=_TEST_ARTIFACT_STATE,
        )
        artifact.create(metadata_store_id=_TEST_METADATA_STORE)
        create_artifact_mock.assert_called_once_with(
            parent=f"{_TEST_PARENT}/metadataStores/{_TEST_METADATA_STORE}",
            artifact=mock.ANY,
            artifact_id=None,
        )
        _, _, kwargs = create_artifact_mock.mock_calls[0]
        assert kwargs["artifact"].schema_title == _TEST_SCHEMA_TITLE
        assert kwargs["artifact"].uri == _TEST_URI
        assert kwargs["artifact"].display_name == _TEST_DISPLAY_NAME
        assert kwargs["artifact"].description == _TEST_DESCRIPTION
        assert kwargs["artifact"].metadata == _TEST_UPDATED_METADATA
        assert kwargs["artifact"].state == _TEST_ARTIFACT_STATE


class TestMetadataBaseExecutionSchema:
    def setup_method(self):
        reload(initializer)
        reload(metadata)
        reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    def test_base_class_overrides_default_schema_title(self):
        class TestExecution(base_execution.BaseExecutionSchema):
            schema_title = _TEST_SCHEMA_TITLE

        execution = TestExecution()
        assert execution.schema_title == _TEST_SCHEMA_TITLE

    def test_base_class_parameters_overrides_default_values(self):
        class TestExecution(base_execution.BaseExecutionSchema):
            schema_title = _TEST_SCHEMA_TITLE

        execution = TestExecution(
            state=_TEST_EXECUTION_STATE,
            schema_version=_TEST_SCHEMA_VERSION,
            execution_id=_TEST_EXECUTION_ID,
            display_name=_TEST_DISPLAY_NAME,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
        )
        assert execution.state == _TEST_EXECUTION_STATE
        assert execution.schema_version == _TEST_SCHEMA_VERSION
        assert execution.execution_id == _TEST_EXECUTION_ID
        assert execution.schema_title == _TEST_SCHEMA_TITLE
        assert execution.display_name == _TEST_DISPLAY_NAME
        assert execution.description == _TEST_DESCRIPTION
        assert execution.metadata == _TEST_UPDATED_METADATA

    def test_base_class_without_schema_title_raises_error(self):
        with pytest.raises(TypeError):
            base_execution.BaseExecutionSchema()

    # FIX: dropped redundant @pytest.mark.usefixtures("create_execution_mock");
    # the fixture is already requested via the function parameter.
    def test_create_method_calls_gapic_library_with_correct_parameters(
        self, create_execution_mock
    ):
        aiplatform.init(project=_TEST_PROJECT)

        class TestExecution(base_execution.BaseExecutionSchema):
            schema_title = _TEST_SCHEMA_TITLE

        execution = TestExecution(
            state=_TEST_EXECUTION_STATE,
            display_name=_TEST_DISPLAY_NAME,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
        )
        execution.create(metadata_store_id=_TEST_METADATA_STORE)
        create_execution_mock.assert_called_once_with(
            parent=f"{_TEST_PARENT}/metadataStores/{_TEST_METADATA_STORE}",
            execution=mock.ANY,
            execution_id=None,
        )
        _, _, kwargs = create_execution_mock.mock_calls[0]
        assert kwargs["execution"].schema_title == _TEST_SCHEMA_TITLE
        assert kwargs["execution"].state == _TEST_EXECUTION_STATE
        assert kwargs["execution"].display_name == _TEST_DISPLAY_NAME
        assert kwargs["execution"].description == _TEST_DESCRIPTION
        assert kwargs["execution"].metadata == _TEST_UPDATED_METADATA


class TestMetadataGoogleArtifactSchema:
    def setup_method(self):
        reload(initializer)
        reload(metadata)
        reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    def test_vertex_dataset_schema_title_is_set_correctly(self):
        artifact = google_artifact_schema.VertexDataset(
            vertex_dataset_name=_TEST_ARTIFACT_NAME,
        )
        assert artifact.schema_title == "google.VertexDataset"

    def test_vertex_dataset_constructor_parameters_are_set_correctly(self):
        artifact = google_artifact_schema.VertexDataset(
            vertex_dataset_name=f"{_TEST_PARENT}/datasets/dataset-id",
            display_name=_TEST_DISPLAY_NAME,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata={},
        )
        assert (
            artifact.uri
            == "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/datasets/dataset-id"
        )
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert artifact.metadata == {
            "resourceName": "projects/test-project/locations/us-central1/datasets/dataset-id"
        }
        assert artifact.schema_version == _TEST_SCHEMA_VERSION

    def test_vertex_model_schema_title_is_set_correctly(self):
        artifact = google_artifact_schema.VertexModel(
            vertex_model_name=_TEST_ARTIFACT_NAME,
        )
        assert artifact.schema_title == "google.VertexModel"

    def test_vertex_model_constructor_parameters_are_set_correctly(self):
        artifact = google_artifact_schema.VertexModel(
            vertex_model_name=f"{_TEST_PARENT}/models/model-id",
            display_name=_TEST_DISPLAY_NAME,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata={},
        )
        assert (
            artifact.uri
            == "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/models/model-id"
        )
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert artifact.metadata == {
            "resourceName": "projects/test-project/locations/us-central1/models/model-id"
        }
        assert artifact.schema_version == _TEST_SCHEMA_VERSION

    def test_vertex_endpoint_schema_title_is_set_correctly(self):
        artifact = google_artifact_schema.VertexEndpoint(
            vertex_endpoint_name=_TEST_ARTIFACT_NAME,
        )
        assert artifact.schema_title == "google.VertexEndpoint"

    def test_vertex_endpoint_constructor_parameters_are_set_correctly(self):
        artifact = google_artifact_schema.VertexEndpoint(
            vertex_endpoint_name=f"{_TEST_PARENT}/endpoints/endpoint-id",
            display_name=_TEST_DISPLAY_NAME,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata={},
        )
        assert (
            artifact.uri
            == "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/endpoints/endpoint-id"
        )
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert artifact.metadata == {
            "resourceName": "projects/test-project/locations/us-central1/endpoints/endpoint-id"
        }
        assert artifact.schema_version == _TEST_SCHEMA_VERSION

    def test_unmanaged_container_model_title_is_set_correctly(self):
        predict_schema_ta = utils.PredictSchemata(
            instance_schema_uri="instance_uri",
            prediction_schema_uri="prediction_uri",
            parameters_schema_uri="parameters_uri",
        )

        container_spec = utils.ContainerSpec(
            image_uri="gcr.io/test_container_image_uri"
        )
        artifact = google_artifact_schema.UnmanagedContainerModel(
            predict_schema_ta=predict_schema_ta,
            container_spec=container_spec,
        )
        assert artifact.schema_title == "google.UnmanagedContainerModel"

    def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self):
        predict_schema_ta = utils.PredictSchemata(
            instance_schema_uri="instance_uri",
            prediction_schema_uri="prediction_uri",
            parameters_schema_uri="parameters_uri",
        )

        container_spec = utils.ContainerSpec(
            image_uri="gcr.io/test_container_image_uri"
        )

        artifact = google_artifact_schema.UnmanagedContainerModel(
            predict_schema_ta=predict_schema_ta,
            container_spec=container_spec,
            artifact_id=_TEST_ARTIFACT_ID,
            uri=_TEST_URI,
            display_name=_TEST_DISPLAY_NAME,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
        )
        expected_metadata = {
            "test-param1": 2,
            "test-param2": "test-value-1",
            "test-param3": False,
            "predictSchemata": {
                "instanceSchemaUri": "instance_uri",
                "parametersSchemaUri": "parameters_uri",
                "predictionSchemaUri": "prediction_uri",
            },
            "containerSpec": {"imageUri": "gcr.io/test_container_image_uri"},
        }

        assert artifact.artifact_id == _TEST_ARTIFACT_ID
        assert artifact.uri == _TEST_URI
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert json.dumps(artifact.metadata) == json.dumps(expected_metadata)
        assert artifact.schema_version == _TEST_SCHEMA_VERSION


class TestMetadataSystemArtifactSchema:
    def setup_method(self):
        reload(initializer)
        reload(metadata)
        reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    def test_system_dataset_schema_title_is_set_correctly(self):
        artifact = system_artifact_schema.Dataset()
        assert artifact.schema_title == "system.Dataset"

    def test_system_dataset_constructor_parameters_are_set_correctly(self):
        artifact = system_artifact_schema.Dataset(
            uri=_TEST_URI,
            artifact_id=_TEST_ARTIFACT_ID,
            display_name=_TEST_DISPLAY_NAME,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
        )
        assert artifact.uri == _TEST_URI
        assert artifact.artifact_id == _TEST_ARTIFACT_ID
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert artifact.metadata == _TEST_UPDATED_METADATA
        assert artifact.schema_version == _TEST_SCHEMA_VERSION

    def test_system_artifact_schema_title_is_set_correctly(self):
        artifact = system_artifact_schema.Artifact()
        assert artifact.schema_title == "system.Artifact"

    def test_system_artifact_constructor_parameters_are_set_correctly(self):
        artifact = system_artifact_schema.Artifact(
            uri=_TEST_URI,
            artifact_id=_TEST_ARTIFACT_ID,
            display_name=_TEST_DISPLAY_NAME,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
        )
        assert artifact.uri == _TEST_URI
        assert artifact.artifact_id == _TEST_ARTIFACT_ID
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert artifact.metadata == _TEST_UPDATED_METADATA
        assert artifact.schema_version == _TEST_SCHEMA_VERSION

    def test_system_model_schema_title_is_set_correctly(self):
        artifact = system_artifact_schema.Model()
        assert artifact.schema_title == "system.Model"

    def test_system_model_constructor_parameters_are_set_correctly(self):
        artifact = system_artifact_schema.Model(
            uri=_TEST_URI,
            artifact_id=_TEST_ARTIFACT_ID,
            display_name=_TEST_DISPLAY_NAME,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
        )
        assert artifact.uri == _TEST_URI
        assert artifact.artifact_id == _TEST_ARTIFACT_ID
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert artifact.metadata == _TEST_UPDATED_METADATA
        assert artifact.schema_version == _TEST_SCHEMA_VERSION

    def test_system_metrics_schema_title_is_set_correctly(self):
        artifact = system_artifact_schema.Metrics()
        assert artifact.schema_title == "system.Metrics"

    def test_system_metrics_values_default_to_none(self):
        artifact = system_artifact_schema.Metrics()
        assert artifact.metadata == {}

    def test_system_metrics_constructor_parameters_are_set_correctly(self):
        artifact = system_artifact_schema.Metrics(
            accuracy=0.1,
            precision=0.2,
            recall=0.3,
            f1score=0.4,
            mean_absolute_error=0.5,
            mean_squared_error=0.6,
            artifact_id=_TEST_ARTIFACT_ID,
            uri=_TEST_URI,
            display_name=_TEST_DISPLAY_NAME,
            schema_version=_TEST_SCHEMA_VERSION,
            description=_TEST_DESCRIPTION,
            metadata=_TEST_UPDATED_METADATA,
        )
        assert artifact.uri == _TEST_URI
        assert artifact.artifact_id == _TEST_ARTIFACT_ID
        assert artifact.display_name == _TEST_DISPLAY_NAME
        assert artifact.description == _TEST_DESCRIPTION
        assert artifact.schema_version == _TEST_SCHEMA_VERSION
        assert artifact.metadata["accuracy"] == 0.1
        assert artifact.metadata["precision"] == 0.2
        assert artifact.metadata["recall"] == 0.3
        assert artifact.metadata["f1score"] == 0.4
        assert artifact.metadata["mean_absolute_error"] == 0.5
        assert artifact.metadata["mean_squared_error"] == 0.6


class TestMetadataSystemSchemaExecution:
    def setup_method(self):
        reload(initializer)
        reload(metadata)
        reload(aiplatform)

    def teardown_method(self):
        initializer.global_pool.shutdown(wait=True)

    # Test system.Execution Schemas
    def test_system_container_execution_schema_title_is_set_correctly(self):
        execution = system_execution_schema.ContainerExecution()
        assert execution.schema_title == "system.ContainerExecution"

    def test_system_custom_job_execution_schema_title_is_set_correctly(self):
        execution = system_execution_schema.CustomJobExecution()
        assert execution.schema_title == "system.CustomJobExecution"
+ + def test_system_run_execution_schema_title_is_set_correctly(self): + execution = system_execution_schema.Run() + assert execution.schema_title == "system.Run" + + +class TestMetadataUtils: + def setup_method(self): + reload(initializer) + reload(metadata) + reload(aiplatform) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_predict_schemata_to_dict_method_returns_correct_schema(self): + predict_schema_ta = utils.PredictSchemata( + instance_schema_uri="instance_uri", + prediction_schema_uri="prediction_uri", + parameters_schema_uri="parameters_uri", + ) + expected_results = { + "instanceSchemaUri": "instance_uri", + "parametersSchemaUri": "parameters_uri", + "predictionSchemaUri": "prediction_uri", + } + + assert json.dumps(predict_schema_ta.to_dict()) == json.dumps(expected_results) + + def test_container_spec_to_dict_method_returns_correct_schema(self): + container_spec = utils.ContainerSpec( + image_uri="gcr.io/some_container_image_uri", + command=["test_command"], + args=["test_args"], + env=[{"env_var_name": "env_var_value"}], + ports=[1], + predict_route="test_prediction_rout", + health_route="test_health_rout", + ) + + expected_results = { + "imageUri": "gcr.io/some_container_image_uri", + "command": ["test_command"], + "args": ["test_args"], + "env": [{"env_var_name": "env_var_value"}], + "ports": [1], + "predictRoute": "test_prediction_rout", + "healthRoute": "test_health_rout", + } + + assert json.dumps(container_spec.to_dict()) == json.dumps(expected_results)