diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index c934a01eecb..468c82dcc0d 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -198,6 +198,7 @@ def network(self) -> Optional[str]: def create( cls, display_name: Optional[str] = None, + endpoint_id: Optional[str] = None, description: Optional[str] = None, labels: Optional[Dict[str, str]] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), @@ -215,6 +216,17 @@ def create( Optional. The user-defined name of the Endpoint. The name can be up to 128 characters long and can be consist of any UTF-8 characters. + endpoint_id (str): + Optional. The ID to use for endpoint, which will become + the final component of the endpoint resource name. If + not provided, Vertex AI will generate a value for this + ID. + + This value should be 1-10 characters, and valid + characters are /[0-9]/. When using HTTP/JSON, this field + is populated based on a query string argument, such as + ``?endpoint_id=12345``. This is the fallback for fields + that are not included in either the URI or the body. project (str): Required. Project to retrieve endpoint from. If not set, project set in aiplatform.init will be used. @@ -276,6 +288,7 @@ def create( return cls._create( api_client=api_client, display_name=display_name, + endpoint_id=endpoint_id, project=project, location=location, description=description, @@ -297,6 +310,7 @@ def _create( display_name: str, project: str, location: str, + endpoint_id: Optional[str] = None, description: Optional[str] = None, labels: Optional[Dict[str, str]] = None, metadata: Optional[Sequence[Tuple[str, str]]] = (), @@ -321,6 +335,17 @@ def _create( location (str): Required. Location to retrieve endpoint from. If not set, location set in aiplatform.init will be used. + endpoint_id (str): + Optional. The ID to use for endpoint, which will become + the final component of the endpoint resource name. If + not provided, Vertex AI will generate a value for this + ID. + + This value should be 1-10 characters, and valid + characters are /[0-9]/. When using HTTP/JSON, this field + is populated based on a query string argument, such as + ``?endpoint_id=12345``. This is the fallback for fields + that are not included in either the URI or the body. description (str): Optional. The description of the Endpoint. labels (Dict[str, str]): @@ -368,6 +393,7 @@ def _create( operation_future = api_client.create_endpoint( parent=parent, endpoint=gapic_endpoint, + endpoint_id=endpoint_id, metadata=metadata, timeout=create_request_timeout, ) diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index a0844fa9b37..dbc301827f7 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -547,6 +547,7 @@ def test_init_aiplatform_with_encryption_key_name_and_create_endpoint( create_endpoint_mock.assert_called_once_with( parent=_TEST_PARENT, endpoint=expected_endpoint, + endpoint_id=None, metadata=(), timeout=None, ) @@ -573,6 +574,7 @@ def test_create(self, create_endpoint_mock, sync): create_endpoint_mock.assert_called_once_with( parent=_TEST_PARENT, endpoint=expected_endpoint, + endpoint_id=None, metadata=(), timeout=None, ) @@ -580,6 +582,31 @@ def test_create(self, create_endpoint_mock, sync): expected_endpoint.name = _TEST_ENDPOINT_NAME assert my_endpoint._gca_resource == expected_endpoint + @pytest.mark.usefixtures("get_endpoint_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_create_with_endpoint_id(self, create_endpoint_mock, sync): + my_endpoint = models.Endpoint.create( + display_name=_TEST_DISPLAY_NAME, + endpoint_id=_TEST_ID, + description=_TEST_DESCRIPTION, + sync=sync, + create_request_timeout=None, + ) + if not sync: + my_endpoint.wait() + + expected_endpoint = gca_endpoint.Endpoint( + display_name=_TEST_DISPLAY_NAME, + description=_TEST_DESCRIPTION, + ) + create_endpoint_mock.assert_called_once_with( + parent=_TEST_PARENT, + endpoint=expected_endpoint, + endpoint_id=_TEST_ID, + metadata=(), + timeout=None, + ) + @pytest.mark.usefixtures("get_endpoint_mock") @pytest.mark.parametrize("sync", [True, False]) def test_create_with_timeout(self, create_endpoint_mock, sync): @@ -599,6 +626,7 @@ def test_create_with_timeout(self, create_endpoint_mock, sync): create_endpoint_mock.assert_called_once_with( parent=_TEST_PARENT, endpoint=expected_endpoint, + endpoint_id=None, metadata=(), timeout=180.0, ) @@ -642,6 +670,7 @@ def test_create_with_description(self, create_endpoint_mock, sync): create_endpoint_mock.assert_called_once_with( parent=_TEST_PARENT, endpoint=expected_endpoint, + endpoint_id=None, metadata=(), timeout=None, ) @@ -665,6 +694,7 @@ def test_create_with_labels(self, create_endpoint_mock, sync): create_endpoint_mock.assert_called_once_with( parent=_TEST_PARENT, endpoint=expected_endpoint, + endpoint_id=None, metadata=(), timeout=None, )