Skip to content

Commit

Permalink
fix: Resource created by _construct_sdk_resource_from_gapic should …
Browse files Browse the repository at this point in the history
…use the project from the resource name instead of the default project.

This fixes the following failing assertion:
```
assert aiplatform.Endpoint._construct_sdk_resource_from_gapic(
    aiplatform_models.gca_endpoint_compat.Endpoint(name='projects/my-project/locations/us-central1/publishers/google/models/text-bison@001')
).project == "my-project"
```

PiperOrigin-RevId: 536839787
  • Loading branch information
Ark-kun authored and copybara-github committed May 31, 2023
1 parent 50e0898 commit 162b2f2
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 23 deletions.
7 changes: 6 additions & 1 deletion google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,13 @@ def _construct_sdk_resource_from_gapic(
VertexAiResourceNoun:
An initialized SDK object that represents GAPIC type.
"""
resource_name_parts = utils.extract_project_and_location_from_parent(
gapic_resource.name
)
sdk_resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
project=resource_name_parts.get("project") or project,
location=resource_name_parts.get("location") or location,
credentials=credentials,
)
sdk_resource._gca_resource = gapic_resource
return sdk_resource
Expand Down
18 changes: 10 additions & 8 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,12 +560,13 @@ def _construct_sdk_resource_from_gapic(
Endpoint (aiplatform.Endpoint):
An initialized Endpoint resource.
"""
endpoint = cls._empty_constructor(
project=project, location=location, credentials=credentials
endpoint = super()._construct_sdk_resource_from_gapic(
gapic_resource=gapic_resource,
project=project,
location=location,
credentials=credentials,
)

endpoint._gca_resource = gapic_resource

endpoint._prediction_client = cls._instantiate_prediction_client(
location=endpoint.location,
credentials=credentials,
Expand Down Expand Up @@ -2021,12 +2022,13 @@ def _construct_sdk_resource_from_gapic(
"Cannot import the urllib3 HTTP client. Please install google-cloud-aiplatform[private_endpoints]."
)

endpoint = cls._empty_constructor(
project=project, location=location, credentials=credentials
endpoint = super()._construct_sdk_resource_from_gapic(
gapic_resource=gapic_resource,
project=project,
location=location,
credentials=credentials,
)

endpoint._gca_resource = gapic_resource

endpoint._http_client = urllib3.PoolManager()

return endpoint
Expand Down
14 changes: 0 additions & 14 deletions google/cloud/aiplatform/preview/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,20 +431,6 @@ def list(
credentials=credentials,
)

@classmethod
def _construct_sdk_resource_from_gapic(
cls,
gapic_resource: proto.Message,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "models.DeploymentResourcePool":
drp = cls._empty_constructor(
project=project, location=location, credentials=credentials
)
drp._gca_resource = gapic_resource
return drp


class Endpoint(aiplatform.Endpoint):
@staticmethod
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,3 +2202,23 @@ def test_delete_with_force(self, sdk_undeploy_mock, delete_endpoint_mock, sync):
def test_list(self):
ep_list = aiplatform.PrivateEndpoint.list()
assert ep_list # Ensure list is not empty

def test_construct_sdk_resource_from_gapic_uses_resource_project(self):
PROJECT = "my-project"
LOCATION = "me-west1"
endpoint_name = f"projects/{PROJECT}/locations/{LOCATION}/endpoints/123"
endpoint = aiplatform.Endpoint._construct_sdk_resource_from_gapic(
models.gca_endpoint_compat.Endpoint(name=endpoint_name)
)
assert endpoint.project == PROJECT
assert endpoint.location == LOCATION
assert endpoint.project != _TEST_PROJECT
assert endpoint.location != _TEST_LOCATION

endpoint2 = aiplatform.Endpoint._construct_sdk_resource_from_gapic(
models.gca_endpoint_compat.Endpoint(name=endpoint_name),
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
assert endpoint2.project != _TEST_PROJECT
assert endpoint2.location != _TEST_LOCATION

0 comments on commit 162b2f2

Please sign in to comment.