Skip to content
This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into owl-bot-copy
Browse files Browse the repository at this point in the history
  • Loading branch information
dizcology authored Jan 12, 2024
2 parents 24299c0 + 1fbf049 commit 9e9bb6a
Show file tree
Hide file tree
Showing 7 changed files with 377 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,9 @@ def __init__(
if self.public_endpoint_domain_name:
self._public_match_client = self._instantiate_public_match_client()

self._match_grpc_stub_cache = {}
self._private_service_connect_ip_address = None

@classmethod
def create(
cls,
Expand Down Expand Up @@ -521,40 +524,85 @@ def _instantiate_public_match_client(

def _instantiate_private_match_service_stub(
self,
deployed_index_id: str,
deployed_index_id: Optional[str] = None,
ip_address: Optional[str] = None,
) -> match_service_pb2_grpc.MatchServiceStub:
"""Helper method to instantiate private match service stub.
Args:
deployed_index_id (str):
Required. The user specified ID of the
DeployedIndex.
Optional. Required for private service access endpoint.
The user specified ID of the DeployedIndex.
ip_address (str):
Optional. Required for private service connect. The ip address
the forwarding rule makes use of.
Returns:
stub (match_service_pb2_grpc.MatchServiceStub):
Initialized match service stub.
Raises:
RuntimeError: No deployed index with id deployed_index_id found
ValueError: Should not set ip address for networks other than
private service connect.
"""
# Find the deployed index by id
deployed_indexes = [
deployed_index
for deployed_index in self.deployed_indexes
if deployed_index.id == deployed_index_id
]
if ip_address:
# Should only set for Private Service Connect
if self.public_endpoint_domain_name:
raise ValueError(
"MatchingEngineIndexEndpoint is set to use ",
"public network. Could not establish connection using "
"provided ip address",
)
elif self.private_service_access_network:
raise ValueError(
"MatchingEngineIndexEndpoint is set to use ",
"private service access network. Could not establish "
"connection using provided ip address",
)
else:
# Private Service Access, find server ip for deployed index
deployed_indexes = [
deployed_index
for deployed_index in self.deployed_indexes
if deployed_index.id == deployed_index_id
]

if not deployed_indexes:
raise RuntimeError(f"No deployed index with id '{deployed_index_id}' found")
if not deployed_indexes:
raise RuntimeError(
f"No deployed index with id '{deployed_index_id}' found"
)

# Retrieve server ip from deployed index
server_ip = deployed_indexes[0].private_endpoints.match_grpc_address
# Retrieve server ip from deployed index
ip_address = deployed_indexes[0].private_endpoints.match_grpc_address

# Set up channel and stub
channel = grpc.insecure_channel("{}:10000".format(server_ip))
return match_service_pb2_grpc.MatchServiceStub(channel)
if ip_address not in self._match_grpc_stub_cache:
# Set up channel and stub
channel = grpc.insecure_channel("{}:10000".format(ip_address))
self._match_grpc_stub_cache[
ip_address
] = match_service_pb2_grpc.MatchServiceStub(channel)
return self._match_grpc_stub_cache[ip_address]

@property
def public_endpoint_domain_name(self) -> Optional[str]:
"""Public endpoint DNS name."""
self._assert_gca_resource_is_available()
return self._gca_resource.public_endpoint_domain_name

@property
def private_service_access_network(self) -> Optional[str]:
""" "Private service access network."""
self._assert_gca_resource_is_available()
return self._gca_resource.network

@property
def private_service_connect_ip_address(self) -> Optional[str]:
""" "Private service connect ip address."""
return self._private_service_connect_ip_address

@private_service_connect_ip_address.setter
def private_service_connect_ip_address(self, ip_address: str) -> Optional[str]:
""" "Setter for private service connect ip address."""
self._private_service_connect_ip_address = ip_address

def update(
self,
display_name: str,
Expand Down Expand Up @@ -1214,6 +1262,7 @@ def find_neighbors(
per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
return_full_datapoint=return_full_datapoint,
)

# Create the FindNeighbors request
Expand Down Expand Up @@ -1300,7 +1349,8 @@ def read_index_datapoints(
if not self._public_match_client:
# Call private match service stub with BatchGetEmbeddings request
embeddings = self._batch_get_embeddings(
deployed_index_id=deployed_index_id, ids=ids
deployed_index_id=deployed_index_id,
ids=ids,
)

response = []
Expand Down Expand Up @@ -1362,7 +1412,8 @@ def _batch_get_embeddings(
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
"""
stub = self._instantiate_private_match_service_stub(
deployed_index_id=deployed_index_id
deployed_index_id=deployed_index_id,
ip_address=self._private_service_connect_ip_address,
)

# Create the batch get embeddings request
Expand All @@ -1384,6 +1435,7 @@ def match(
per_crowding_attribute_num_neighbors: Optional[int] = None,
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
return_full_datapoint: bool = False,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the
specified deployed index for private endpoint only.
Expand Down Expand Up @@ -1415,12 +1467,18 @@ def match(
query time allows user to tune search performance. This value
increase result in both search accuracy and latency increase.
The value should be between 0.0 and 1.0.
return_full_datapoint (bool):
Optional. If set to true, the full datapoints (including all
vector values and of the nearest neighbors are returned.
Note that returning full datapoint will significantly increase the
latency and cost of the query.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""
stub = self._instantiate_private_match_service_stub(
deployed_index_id=deployed_index_id
deployed_index_id=deployed_index_id,
ip_address=self._private_service_connect_ip_address,
)

# Create the batch match request
Expand Down Expand Up @@ -1451,6 +1509,7 @@ def match(
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
embedding_enabled=return_full_datapoint,
)
requests.append(request)

Expand Down
16 changes: 8 additions & 8 deletions tests/system/aiplatform/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_API_ENDPOINT = f"{_TEST_LOCATION}-aiplatform.googleapis.com"
_TEST_IMAGE_DATASET_ID = "1084241610289446912" # permanent_50_flowers_dataset
_TEST_IMAGE_DATASET_ID = "1997950066622464000" # permanent_50_flowers_dataset
_TEST_TEXT_DATASET_ID = (
"6203215905493614592" # permanent_text_entity_extraction_dataset
)
Expand Down Expand Up @@ -390,24 +390,24 @@ def test_export_data_for_custom_training(self, staging_bucket):
# Custom training data export should be generic, hence using the base
# _Dataset class here in test. In practice, users shuold be able to
# use this function in any inhericted classes of _Dataset.
dataset = aiplatform.datasets._Dataset(dataset_name=_TEST_TEXT_DATASET_ID)
dataset = aiplatform.datasets._Dataset(dataset_name=_TEST_IMAGE_DATASET_ID)

split = {
"training_fraction": 0.6,
"validation_fraction": 0.2,
"test_fraction": 0.2,
"training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
"validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
"test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
}

export_data_response = dataset.export_data_for_custom_training(
output_dir=f"gs://{staging_bucket.name}",
annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/text_classification_1.0.0.yaml",
annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml",
split=split,
)

# Ensure three output paths (training, validation and test) are provided
assert len(export_data_response["exported_files"]) == 3
# Ensure data stats are calculated and present
assert export_data_response["data_stats"]["training_data_items_count"] > 0
# Ensure data stats are calculated and correct
assert export_data_response["data_stats"]["training_data_items_count"] == 40

def test_update_dataset(self):
"""Create a new dataset and use update() method to change its display_name, labels, and description.
Expand Down
2 changes: 2 additions & 0 deletions tests/system/aiplatform/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state):
def test_mdm_notification_channel_alert_config(self, shared_state):
self.endpoint = shared_state["resources"][0]
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
# Reset objective_config.explanation_config
objective_config.explanation_config = None
# test model monitoring configurations
job = aiplatform.ModelDeploymentMonitoringJob.create(
display_name=self._make_display_name(key=JOB_NAME),
Expand Down
145 changes: 145 additions & 0 deletions tests/system/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access, g-multiple-import
"""System tests for generative models."""

import pytest

# Google imports
from google import auth
from google.cloud import aiplatform
from tests.system.aiplatform import e2e_base
from vertexai.preview import generative_models


class TestGenerativeModels(e2e_base.TestEndToEnd):
"""System tests for generative models."""

_temp_prefix = "temp_generative_models_test_"

def setup_method(self):
super().setup_method()
credentials, _ = auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
aiplatform.init(
project=e2e_base._PROJECT,
location=e2e_base._LOCATION,
credentials=credentials,
)

def test_generate_content_from_text(self):
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content("Why is sky blue?")
assert response.text

@pytest.mark.asyncio
async def test_generate_content_async(self):
model = generative_models.GenerativeModel("gemini-pro")
response = await model.generate_content_async("Why is sky blue?")
assert response.text

def test_generate_content_streaming(self):
model = generative_models.GenerativeModel("gemini-pro")
stream = model.generate_content("Why is sky blue?", stream=True)
for chunk in stream:
assert chunk.text

@pytest.mark.asyncio
async def test_generate_content_streaming_async(self):
model = generative_models.GenerativeModel("gemini-pro")
async_stream = await model.generate_content_async(
"Why is sky blue?",
stream=True,
)
async for chunk in async_stream:
assert chunk.text

def test_generate_content_with_parameters(self):
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content(
contents="Why is sky blue?",
generation_config=generative_models.GenerationConfig(
temperature=0.1,
top_p=0.95,
top_k=20,
candidate_count=1,
max_output_tokens=100,
stop_sequences=["STOP!"],
),
safety_settings={
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_NONE,
},
)
assert response.text

def test_generate_content_from_list_of_content_dict(self):
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content(
contents=[{"role": "user", "parts": [{"text": "Why is sky blue?"}]}]
)
assert response.text

def test_generate_content_from_remote_image(self):
vision_model = generative_models.GenerativeModel("gemini-pro-vision")
image_part = generative_models.Part.from_uri(
uri="gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg",
mime_type="image/jpeg",
)
response = vision_model.generate_content(image_part)
assert response.text
assert "cat" in response.text

def test_generate_content_from_text_and_remote_image(self):
vision_model = generative_models.GenerativeModel("gemini-pro-vision")
image_part = generative_models.Part.from_uri(
uri="gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg",
mime_type="image/jpeg",
)
response = vision_model.generate_content(
contents=["What is shown in this image?", image_part],
)
assert response.text
assert "cat" in response.text

def test_generate_content_from_text_and_remote_video(self):
vision_model = generative_models.GenerativeModel("gemini-pro-vision")
video_part = generative_models.Part.from_uri(
uri="gs://cloud-samples-data/video/animals.mp4",
mime_type="video/mp4",
)
response = vision_model.generate_content(
contents=["What is in the video?", video_part],
)
assert response.text
assert "Zootopia" in response.text

# Chat

def test_send_message_from_text(self):
model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()
response1 = chat.send_message("I really like fantasy books.")
assert response1.text
assert len(chat.history) == 2

response2 = chat.send_message("What things do I like?.")
assert response2.text
assert len(chat.history) == 4
Loading

0 comments on commit 9e9bb6a

Please sign in to comment.