diff --git a/samples/model-builder/conftest.py b/samples/model-builder/conftest.py index 5375b870d7..3c5960a2b1 100644 --- a/samples/model-builder/conftest.py +++ b/samples/model-builder/conftest.py @@ -612,13 +612,6 @@ def mock_get_artifact(mock_artifact): yield mock_get_artifact -@pytest.fixture -def mock_artifact_get(mock_artifact): - with patch.object(aiplatform.Artifact, "get") as mock_artifact_get: - mock_artifact_get.return_value = mock_artifact - yield mock_artifact_get - - @pytest.fixture def mock_context_get(mock_context): with patch.object(aiplatform.Context, "get") as mock_context_get: @@ -626,6 +619,31 @@ def mock_context_get(mock_context): yield mock_context_get +@pytest.fixture +def mock_context_list(mock_context): + with patch.object(aiplatform.Context, "list") as mock_context_list: + # Returning list of 2 contexts to avoid confusion with get method + # which returns one unique context. + mock_context_list.return_value = [mock_context, mock_context] + yield mock_context_list + + +@pytest.fixture +def mock_create_schema_base_context(mock_context): + with patch.object( + aiplatform.metadata.schema.base_context.BaseContextSchema, "create" + ) as mock_create_schema_base_context: + mock_create_schema_base_context.return_value = mock_context + yield mock_create_schema_base_context + + +@pytest.fixture +def mock_artifact_get(mock_artifact): + with patch.object(aiplatform.Artifact, "get") as mock_artifact_get: + mock_artifact_get.return_value = mock_artifact + yield mock_artifact_get + + @pytest.fixture def mock_pipeline_job_create(mock_pipeline_job): with patch.object(aiplatform, "PipelineJob") as mock_pipeline_job_create: diff --git a/samples/model-builder/experiment_tracking/create_artifact_sample.py b/samples/model-builder/experiment_tracking/create_artifact_sample.py index 90a7432bea..1f3b7fdffd 100644 --- a/samples/model-builder/experiment_tracking/create_artifact_sample.py +++ b/samples/model-builder/experiment_tracking/create_artifact_sample.py @@ -40,8 +40,6 @@ def create_artifact_sample( project=project, location=location, ) - return artifact - # [END aiplatform_sdk_create_artifact_sample] diff --git a/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample.py b/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample.py index 3d67cacc17..ed8315a617 100644 --- a/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample.py +++ b/samples/model-builder/experiment_tracking/create_artifact_with_sdk_sample.py @@ -36,8 +36,6 @@ def create_artifact_sample( description=description, metadata=metadata, ) - return system_artifact_schema.create(project=project, location=location,) - # [END aiplatform_sdk_create_artifact_with_sdk_sample] diff --git a/samples/model-builder/experiment_tracking/create_context_with_sdk_sample.py b/samples/model-builder/experiment_tracking/create_context_with_sdk_sample.py new file mode 100644 index 0000000000..d688b14c46 --- /dev/null +++ b/samples/model-builder/experiment_tracking/create_context_with_sdk_sample.py @@ -0,0 +1,41 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +from google.cloud import aiplatform +from google.cloud.aiplatform.metadata.schema.system import context_schema + + +# [START aiplatform_sdk_create_context_with_sdk_sample] +def create_context_sample( + display_name: str, + project: str, + location: str, + context_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, +): + aiplatform.init(project=project, location=location) + + return context_schema.Experiment( + display_name=display_name, + context_id=context_id, + metadata=metadata, + schema_version=schema_version, + description=description, + ).create() + +# [END aiplatform_sdk_create_context_with_sdk_sample] diff --git a/samples/model-builder/experiment_tracking/create_context_with_sdk_sample_test.py b/samples/model-builder/experiment_tracking/create_context_with_sdk_sample_test.py new file mode 100644 index 0000000000..b04fc7c69a --- /dev/null +++ b/samples/model-builder/experiment_tracking/create_context_with_sdk_sample_test.py @@ -0,0 +1,38 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import create_context_with_sdk_sample + +import test_constants as constants + + +def test_create_context_sample( + mock_sdk_init, mock_create_schema_base_context, mock_context, +): + exc = create_context_with_sdk_sample.create_context_sample( + display_name=constants.DISPLAY_NAME, + project=constants.PROJECT, + location=constants.LOCATION, + context_id=constants.RESOURCE_ID, + metadata=constants.METADATA, + schema_version=constants.SCHEMA_VERSION, + description=constants.DESCRIPTION, + ) + + mock_sdk_init.assert_called_with( + project=constants.PROJECT, location=constants.LOCATION, + ) + + mock_create_schema_base_context.assert_called_with() + assert exc is mock_context diff --git a/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample.py b/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample.py index ac0faa7065..c364fc732f 100644 --- a/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample.py +++ b/samples/model-builder/experiment_tracking/create_execution_with_sdk_sample.py @@ -43,5 +43,4 @@ def create_execution_sample( execution.assign_output_artifacts(output_artifacts) return execution - # [END aiplatform_sdk_create_execution_with_sdk_sample] diff --git a/samples/model-builder/experiment_tracking/delete_artifact_sample_test.py b/samples/model-builder/experiment_tracking/delete_artifact_sample_test.py index 555b6745ee..bd5aefc102 100644 --- a/samples/model-builder/experiment_tracking/delete_artifact_sample_test.py +++ b/samples/model-builder/experiment_tracking/delete_artifact_sample_test.py @@ -14,18 +14,18 @@ import delete_artifact_sample -import test_constants +import test_constants as constants def test_delete_artifact_sample(mock_artifact, mock_artifact_get): delete_artifact_sample.delete_artifact_sample( - artifact_id=test_constants.RESOURCE_ID, - project=test_constants.PROJECT, - location=test_constants.LOCATION, + artifact_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, ) mock_artifact_get.assert_called_with( - resource_id=test_constants.RESOURCE_ID, - project=test_constants.PROJECT, - location=test_constants.LOCATION, + resource_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, ) diff --git a/samples/model-builder/experiment_tracking/delete_context_sample_test.py b/samples/model-builder/experiment_tracking/delete_context_sample_test.py index 49e2a08c72..6b60acd4c8 100644 --- a/samples/model-builder/experiment_tracking/delete_context_sample_test.py +++ b/samples/model-builder/experiment_tracking/delete_context_sample_test.py @@ -14,18 +14,18 @@ import delete_context_sample -import test_constants +import test_constants as constants def test_delete_context_sample(mock_context_get): delete_context_sample.delete_context_sample( - context_id=test_constants.RESOURCE_ID, - project=test_constants.PROJECT, - location=test_constants.LOCATION, + context_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, ) mock_context_get.assert_called_with( - resource_id=test_constants.RESOURCE_ID, - project=test_constants.PROJECT, - location=test_constants.LOCATION, + resource_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, ) diff --git a/samples/model-builder/experiment_tracking/delete_execution_sample_test.py b/samples/model-builder/experiment_tracking/delete_execution_sample_test.py index 3e6cba5d8e..f5fa850fc6 100644 --- a/samples/model-builder/experiment_tracking/delete_execution_sample_test.py +++ b/samples/model-builder/experiment_tracking/delete_execution_sample_test.py @@ -14,18 +14,18 @@ import delete_execution_sample -import test_constants +import test_constants as constants def test_delete_execution_sample(mock_execution, mock_execution_get): delete_execution_sample.delete_execution_sample( - execution_id=test_constants.RESOURCE_ID, - project=test_constants.PROJECT, - location=test_constants.LOCATION, + execution_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, ) mock_execution_get.assert_called_with( - resource_id=test_constants.RESOURCE_ID, - project=test_constants.PROJECT, - location=test_constants.LOCATION, + resource_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, ) diff --git a/samples/model-builder/experiment_tracking/get_artifact_sample.py b/samples/model-builder/experiment_tracking/get_artifact_sample.py index 93ea031b2c..bfc55f994c 100644 --- a/samples/model-builder/experiment_tracking/get_artifact_sample.py +++ b/samples/model-builder/experiment_tracking/get_artifact_sample.py @@ -27,5 +27,4 @@ def get_artifact_sample( return artifact - # [END aiplatform_sdk_get_artifact_sample] diff --git a/samples/model-builder/experiment_tracking/get_artifact_sample_test.py b/samples/model-builder/experiment_tracking/get_artifact_sample_test.py index 21047e4e7d..f32c9e5fec 100644 --- a/samples/model-builder/experiment_tracking/get_artifact_sample_test.py +++ b/samples/model-builder/experiment_tracking/get_artifact_sample_test.py @@ -14,20 +14,20 @@ import get_artifact_sample -import test_constants +import test_constants as constants def test_get_artifact_sample(mock_artifact, mock_artifact_get): artifact = get_artifact_sample.get_artifact_sample( - artifact_id=test_constants.RESOURCE_ID, - project=test_constants.PROJECT, - location=test_constants.LOCATION, + artifact_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, ) mock_artifact_get.assert_called_with( - resource_id=test_constants.RESOURCE_ID, - project=test_constants.PROJECT, - location=test_constants.LOCATION, + resource_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, ) assert artifact is mock_artifact diff --git a/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample.py b/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample.py index 6c3f11c9ca..92fa56c841 100644 --- a/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample.py +++ b/samples/model-builder/experiment_tracking/get_artifact_with_uri_sample.py @@ -24,8 +24,6 @@ def get_artifact_with_uri_sample( artifact = aiplatform.Artifact.get_with_uri( uri=uri, project=project, location=location ) - return artifact - # [END aiplatform_sdk_get_artifact_with_uri_sample] diff --git a/samples/model-builder/experiment_tracking/get_context_sample.py b/samples/model-builder/experiment_tracking/get_context_sample.py new file mode 100644 index 0000000000..8cd62e78ea --- /dev/null +++ b/samples/model-builder/experiment_tracking/get_context_sample.py @@ -0,0 +1,28 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_get_context_sample] +def get_context_sample( + context_id: str, + project: str, + location: str, +): + context = aiplatform.Context.get( + resource_id=context_id, project=project, location=location) + return context + +# [END aiplatform_sdk_get_context_sample] diff --git a/samples/model-builder/experiment_tracking/get_context_sample_test.py b/samples/model-builder/experiment_tracking/get_context_sample_test.py new file mode 100644 index 0000000000..155a8eb032 --- /dev/null +++ b/samples/model-builder/experiment_tracking/get_context_sample_test.py @@ -0,0 +1,33 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import get_context_sample + +import test_constants as constants + + +def test_get_context_sample(mock_context, mock_context_get): + context = get_context_sample.get_context_sample( + context_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + mock_context_get.assert_called_with( + resource_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + assert context is mock_context diff --git a/samples/model-builder/experiment_tracking/get_execution_sample_test.py b/samples/model-builder/experiment_tracking/get_execution_sample_test.py index 21047e4e7d..583c27859a 100644 --- a/samples/model-builder/experiment_tracking/get_execution_sample_test.py +++ b/samples/model-builder/experiment_tracking/get_execution_sample_test.py @@ -12,22 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import get_artifact_sample +import get_execution_sample import test_constants -def test_get_artifact_sample(mock_artifact, mock_artifact_get): - artifact = get_artifact_sample.get_artifact_sample( - artifact_id=test_constants.RESOURCE_ID, +def test_get_execution_sample(mock_execution, mock_execution_get): + execution = get_execution_sample.get_execution_sample( + execution_id=test_constants.RESOURCE_ID, project=test_constants.PROJECT, location=test_constants.LOCATION, ) - mock_artifact_get.assert_called_with( + mock_execution_get.assert_called_with( resource_id=test_constants.RESOURCE_ID, project=test_constants.PROJECT, location=test_constants.LOCATION, ) - assert artifact is mock_artifact + assert execution is mock_execution diff --git a/samples/model-builder/experiment_tracking/list_artifact_sample.py b/samples/model-builder/experiment_tracking/list_artifact_sample.py index ac0f15d9f3..55f9419b32 100644 --- a/samples/model-builder/experiment_tracking/list_artifact_sample.py +++ b/samples/model-builder/experiment_tracking/list_artifact_sample.py @@ -29,8 +29,6 @@ def list_artifact_sample( location=location) combined_filters = f"{display_name_fitler} AND {create_date_filter}" - return aiplatform.Artifact.list(filter=combined_filters) - # [END aiplatform_sdk_create_artifact_with_sdk_sample] diff --git a/samples/model-builder/experiment_tracking/list_artifact_sample_test.py b/samples/model-builder/experiment_tracking/list_artifact_sample_test.py index 2df444b76c..5eaf9993e6 100644 --- a/samples/model-builder/experiment_tracking/list_artifact_sample_test.py +++ b/samples/model-builder/experiment_tracking/list_artifact_sample_test.py @@ -29,5 +29,7 @@ def test_list_artifact_with_sdk_sample(mock_artifact, mock_list_artifact): filter=f"{constants.DISPLAY_NAME} AND {constants.CREATE_DATE}" ) assert len(artifacts) == 2 + # Returning list of 2 context to avoid confusion with get method + # which returns one unique context. assert artifacts[0] is mock_artifact assert artifacts[1] is mock_artifact diff --git a/samples/model-builder/experiment_tracking/list_context_sample.py b/samples/model-builder/experiment_tracking/list_context_sample.py new file mode 100644 index 0000000000..de07bf51cf --- /dev/null +++ b/samples/model-builder/experiment_tracking/list_context_sample.py @@ -0,0 +1,28 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud import aiplatform + + +# [START aiplatform_sdk_list_context_sample] +def list_context_sample( + context_id: str, + project: str, + location: str, +): + context = aiplatform.Context.list( + resource_id=context_id, project=project, location=location) + return context + +# [END aiplatform_sdk_list_context_sample] diff --git a/samples/model-builder/experiment_tracking/list_context_sample_test.py b/samples/model-builder/experiment_tracking/list_context_sample_test.py new file mode 100644 index 0000000000..d09ce346fe --- /dev/null +++ b/samples/model-builder/experiment_tracking/list_context_sample_test.py @@ -0,0 +1,36 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import list_context_sample + +import test_constants as constants + + +def test_list_context_sample(mock_context, mock_context_list): + contexts = list_context_sample.list_context_sample( + context_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, + ) + + mock_context_list.assert_called_with( + resource_id=constants.RESOURCE_ID, + project=constants.PROJECT, + location=constants.LOCATION, + ) + assert len(contexts) == 2 + # Returning list of 2 context to avoid confusion with get method + # which returns one unique context. + assert contexts[0] is mock_context + assert contexts[1] is mock_context diff --git a/samples/model-builder/experiment_tracking/list_execution_sample_test.py b/samples/model-builder/experiment_tracking/list_execution_sample_test.py index ac53744d1b..cad29a11a4 100644 --- a/samples/model-builder/experiment_tracking/list_execution_sample_test.py +++ b/samples/model-builder/experiment_tracking/list_execution_sample_test.py @@ -29,5 +29,7 @@ def test_list_execution_sample(mock_execution, mock_list_execution): filter=f"{constants.DISPLAY_NAME} AND {constants.CREATE_DATE}" ) assert len(executions) == 2 + # Returning list of 2 executions to avoid confusion with get method + # which returns one unique execution. assert executions[0] is mock_execution assert executions[1] is mock_execution