Skip to content

Commit

Permalink
feat: Add samples for Metadata context list, get, and create (#1525)
Browse files Browse the repository at this point in the history
* feat: Add samples for context list,get and create

* fix lint issues.

* Change import path to aiplatform.Context

* Fix create mock.

* remove duplicate mock method

* Update samples/model-builder/experiment_tracking/get_context_sample_test.py

Co-authored-by: Dan Lee <[email protected]>

* Update samples/model-builder/experiment_tracking/list_context_sample_test.py

Co-authored-by: Dan Lee <[email protected]>

* Update samples/model-builder/experiment_tracking/list_context_sample_test.py

Co-authored-by: Dan Lee <[email protected]>

* Update samples/model-builder/experiment_tracking/get_context_sample_test.py

Co-authored-by: Dan Lee <[email protected]>

* update formatting and comments based on review feedback

Co-authored-by: Dan Lee <[email protected]>
  • Loading branch information
SinaChavoshi and dandhlee authored Jul 25, 2022
1 parent b53e2b5 commit d913e1d
Show file tree
Hide file tree
Showing 20 changed files with 267 additions and 51 deletions.
32 changes: 25 additions & 7 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,20 +612,38 @@ def mock_get_artifact(mock_artifact):
yield mock_get_artifact


@pytest.fixture
def mock_artifact_get(mock_artifact):
with patch.object(aiplatform.Artifact, "get") as mock_artifact_get:
mock_artifact_get.return_value = mock_artifact
yield mock_artifact_get


@pytest.fixture
def mock_context_get(mock_context):
with patch.object(aiplatform.Context, "get") as mock_context_get:
mock_context_get.return_value = mock_context
yield mock_context_get


@pytest.fixture
def mock_context_list(mock_context):
with patch.object(aiplatform.Context, "list") as mock_context_list:
# Returning list of 2 contexts to avoid confusion with get method
# which returns one unique context.
mock_context_list.return_value = [mock_context, mock_context]
yield mock_context_list


@pytest.fixture
def mock_create_schema_base_context(mock_context):
with patch.object(
aiplatform.metadata.schema.base_context.BaseContextSchema, "create"
) as mock_create_schema_base_context:
mock_create_schema_base_context.return_value = mock_context
yield mock_create_schema_base_context


@pytest.fixture
def mock_artifact_get(mock_artifact):
with patch.object(aiplatform.Artifact, "get") as mock_artifact_get:
mock_artifact_get.return_value = mock_artifact
yield mock_artifact_get


@pytest.fixture
def mock_pipeline_job_create(mock_pipeline_job):
with patch.object(aiplatform, "PipelineJob") as mock_pipeline_job_create:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ def create_artifact_sample(
project=project,
location=location,
)

return artifact


# [END aiplatform_sdk_create_artifact_sample]
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def create_artifact_sample(
description=description,
metadata=metadata,
)

return system_artifact_schema.create(project=project, location=location,)


# [END aiplatform_sdk_create_artifact_with_sdk_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

from google.cloud import aiplatform
from google.cloud.aiplatform.metadata.schema.system import context_schema


# [START aiplatform_sdk_create_context_with_sdk_sample]
def create_context_sample(
display_name: str,
project: str,
location: str,
context_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
schema_version: Optional[str] = None,
description: Optional[str] = None,
):
aiplatform.init(project=project, location=location)

return context_schema.Experiment(
display_name=display_name,
context_id=context_id,
metadata=metadata,
schema_version=schema_version,
description=description,
).create()

# [END aiplatform_sdk_create_context_with_sdk_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import create_context_with_sdk_sample

import test_constants as constants


def test_create_context_sample(
mock_sdk_init, mock_create_schema_base_context, mock_context,
):
exc = create_context_with_sdk_sample.create_context_sample(
display_name=constants.DISPLAY_NAME,
project=constants.PROJECT,
location=constants.LOCATION,
context_id=constants.RESOURCE_ID,
metadata=constants.METADATA,
schema_version=constants.SCHEMA_VERSION,
description=constants.DESCRIPTION,
)

mock_sdk_init.assert_called_with(
project=constants.PROJECT, location=constants.LOCATION,
)

mock_create_schema_base_context.assert_called_with()
assert exc is mock_context
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,4 @@ def create_execution_sample(
execution.assign_output_artifacts(output_artifacts)
return execution


# [END aiplatform_sdk_create_execution_with_sdk_sample]
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

import delete_artifact_sample

import test_constants
import test_constants as constants


def test_delete_artifact_sample(mock_artifact, mock_artifact_get):
delete_artifact_sample.delete_artifact_sample(
artifact_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
artifact_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

mock_artifact_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
resource_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

import delete_context_sample

import test_constants
import test_constants as constants


def test_delete_context_sample(mock_context_get):
delete_context_sample.delete_context_sample(
context_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
context_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

mock_context_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
resource_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

import delete_execution_sample

import test_constants
import test_constants as constants


def test_delete_execution_sample(mock_execution, mock_execution_get):
delete_execution_sample.delete_execution_sample(
execution_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
execution_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

mock_execution_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
resource_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,4 @@ def get_artifact_sample(

return artifact


# [END aiplatform_sdk_get_artifact_sample]
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@

import get_artifact_sample

import test_constants
import test_constants as constants


def test_get_artifact_sample(mock_artifact, mock_artifact_get):
artifact = get_artifact_sample.get_artifact_sample(
artifact_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
artifact_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

mock_artifact_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
resource_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

assert artifact is mock_artifact
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ def get_artifact_with_uri_sample(
artifact = aiplatform.Artifact.get_with_uri(
uri=uri, project=project, location=location
)

return artifact


# [END aiplatform_sdk_get_artifact_with_uri_sample]
28 changes: 28 additions & 0 deletions samples/model-builder/experiment_tracking/get_context_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import aiplatform


# [START aiplatform_sdk_get_context_sample]
def get_context_sample(
context_id: str,
project: str,
location: str,
):
context = aiplatform.Context.get(
resource_id=context_id, project=project, location=location)
return context

# [END aiplatform_sdk_get_context_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import get_context_sample

import test_constants as constants


def test_get_context_sample(mock_context, mock_context_get):
context = get_context_sample.get_context_sample(
context_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

mock_context_get.assert_called_with(
resource_id=constants.RESOURCE_ID,
project=constants.PROJECT,
location=constants.LOCATION,
)

assert context is mock_context
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import get_artifact_sample
import get_execution_sample

import test_constants


def test_get_artifact_sample(mock_artifact, mock_artifact_get):
artifact = get_artifact_sample.get_artifact_sample(
artifact_id=test_constants.RESOURCE_ID,
def test_get_execution_sample(mock_execution, mock_execution_get):
execution = get_execution_sample.get_execution_sample(
execution_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

mock_artifact_get.assert_called_with(
mock_execution_get.assert_called_with(
resource_id=test_constants.RESOURCE_ID,
project=test_constants.PROJECT,
location=test_constants.LOCATION,
)

assert artifact is mock_artifact
assert execution is mock_execution
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def list_artifact_sample(
location=location)

combined_filters = f"{display_name_fitler} AND {create_date_filter}"

return aiplatform.Artifact.list(filter=combined_filters)


# [END aiplatform_sdk_create_artifact_with_sdk_sample]
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ def test_list_artifact_with_sdk_sample(mock_artifact, mock_list_artifact):
filter=f"{constants.DISPLAY_NAME} AND {constants.CREATE_DATE}"
)
assert len(artifacts) == 2
# Returning list of 2 context to avoid confusion with get method
# which returns one unique context.
assert artifacts[0] is mock_artifact
assert artifacts[1] is mock_artifact
Loading

0 comments on commit d913e1d

Please sign in to comment.