diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index 00bd9e035648..9c7a70ceed74 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 6 + "modification": 7 } diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index 6fe8320e758b..6df505508ae9 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -19,20 +19,27 @@ # Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long # to install Vertex AI Python SDK. +import logging +import time from collections.abc import Iterable from collections.abc import Sequence from typing import Any from typing import Optional +from google.api_core.exceptions import ServerError +from google.api_core.exceptions import TooManyRequests from google.auth.credentials import Credentials import apache_beam as beam import vertexai +from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler +from apache_beam.metrics.metric import Metrics from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import RunInference from apache_beam.ml.transforms.base import EmbeddingsManager from apache_beam.ml.transforms.base import _ImageEmbeddingHandler from apache_beam.ml.transforms.base import _TextEmbeddingHandler +from apache_beam.utils import retry from vertexai.language_models import TextEmbeddingInput from vertexai.language_models import TextEmbeddingModel from vertexai.vision_models import Image @@ -51,6 +58,26 @@ "CLUSTERING" ] _BATCH_SIZE = 5 # Vertex AI limits requests to 5 at a time. +_MSEC_TO_SEC = 1000 + +LOGGER = logging.getLogger("VertexAIEmbeddings") + + +def _retry_on_appropriate_gcp_error(exception): + """ + Retry filter that returns True if a returned HTTP error code is 5xx or 429. + This is used to retry remote requests that fail, most notably 429 + (TooManyRequests.) + + Args: + exception: the returned exception encountered during the request/response + loop. + + Returns: + boolean indication whether or not the exception is a Server Error (5xx) or + a TooManyRequests (429) error. + """ + return isinstance(exception, (TooManyRequests, ServerError)) class _VertexAITextEmbeddingHandler(ModelHandler): @@ -74,6 +101,41 @@ def __init__( self.task_type = task_type self.title = title + # Configure AdaptiveThrottler and throttling metrics for client-side + # throttling behavior. + # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing + # for more details. + self.throttled_secs = Metrics.counter( + VertexAIImageEmbeddings, "cumulativeThrottlingSeconds") + self.throttler = AdaptiveThrottler( + window_ms=1, bucket_ms=1, overload_ratio=2) + + @retry.with_exponential_backoff( + num_retries=5, retry_filter=_retry_on_appropriate_gcp_error) + def get_request( + self, + text_batch: Sequence[TextEmbeddingInput], + model: MultiModalEmbeddingModel, + throttle_delay_secs: int): + while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): + LOGGER.info( + "Delaying request for %d seconds due to previous failures", + throttle_delay_secs) + time.sleep(throttle_delay_secs) + self.throttled_secs.inc(throttle_delay_secs) + + try: + req_time = time.time() + prediction = model.get_embeddings(text_batch) + self.throttler.successful_request(req_time * _MSEC_TO_SEC) + return prediction + except TooManyRequests as e: + LOGGER.warning("request was limited by the service with code %i", e.code) + raise + except Exception as e: + LOGGER.error("unexpected exception raised as part of request, got %s", e) + raise + def run_inference( self, batch: Sequence[str], @@ -89,7 +151,8 @@ def run_inference( text=text, title=self.title, task_type=self.task_type) for text in text_batch ] - embeddings_batch = model.get_embeddings(text_batch) + embeddings_batch = self.get_request( + text_batch=text_batch, model=model, throttle_delay_secs=5) embeddings.extend([el.values for el in embeddings_batch]) return embeddings @@ -173,6 +236,41 @@ def __init__( self.model_name = model_name self.dimension = dimension + # Configure AdaptiveThrottler and throttling metrics for client-side + # throttling behavior. + # See https://docs.google.com/document/d/1ePorJGZnLbNCmLD9mR7iFYOdPsyDA1rDnTpYnbdrzSU/edit?usp=sharing + # for more details. + self.throttled_secs = Metrics.counter( + VertexAIImageEmbeddings, "cumulativeThrottlingSeconds") + self.throttler = AdaptiveThrottler( + window_ms=1, bucket_ms=1, overload_ratio=2) + + @retry.with_exponential_backoff( + num_retries=5, retry_filter=_retry_on_appropriate_gcp_error) + def get_request( + self, + img: Image, + model: MultiModalEmbeddingModel, + throttle_delay_secs: int): + while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): + LOGGER.info( + "Delaying request for %d seconds due to previous failures", + throttle_delay_secs) + time.sleep(throttle_delay_secs) + self.throttled_secs.inc(throttle_delay_secs) + + try: + req_time = time.time() + prediction = model.get_embeddings(image=img, dimension=self.dimension) + self.throttler.successful_request(req_time * _MSEC_TO_SEC) + return prediction + except TooManyRequests as e: + LOGGER.warning("request was limited by the service with code %i", e.code) + raise + except Exception as e: + LOGGER.error("unexpected exception raised as part of request, got %s", e) + raise + def run_inference( self, batch: Sequence[Image], @@ -182,8 +280,7 @@ def run_inference( embeddings = [] # Maximum request size for muli-model embedding models is 1. for img in batch: - embedding_response = model.get_embeddings( - image=img, dimension=self.dimension) + embedding_response = self.get_request(img, model, throttle_delay_secs=5) embeddings.append(embedding_response.image_embedding) return embeddings