From 801252ab613dd6601f23918cc46a28cded4b0d98 Mon Sep 17 00:00:00 2001 From: Ritesh Ghorse Date: Tue, 20 Feb 2024 12:22:16 -0500 Subject: [PATCH] [Python] Redis cache support for enrichment transform (#30307) * redis cache support for enrichment transform * update rrio test * convert rrio tests to execute pipeline, tagged outputs, cache_request_key * fix rrio exception test, add to changes * get_cache_key, doc comments, changed __all__ * fix pydoc * use pool manager * revert test * use property() for setter --- CHANGES.md | 2 +- sdks/python/apache_beam/io/requestresponse.py | 595 +++++++++++++++--- .../apache_beam/io/requestresponse_it_test.py | 192 ++++-- .../apache_beam/io/requestresponse_test.py | 3 +- .../io/requestresponse_tests_requirements.txt | 18 + .../apache_beam/transforms/enrichment.py | 85 ++- .../enrichment_handlers/bigtable.py | 10 +- .../enrichment_handlers/bigtable_it_test.py | 94 ++- .../transforms/enrichment_it_test.py | 2 +- sdks/python/pytest.ini | 2 + sdks/python/scripts/generate_pydoc.sh | 6 +- .../python/test-suites/dataflow/common.gradle | 30 + sdks/python/test-suites/direct/common.gradle | 28 + 13 files changed, 909 insertions(+), 158 deletions(-) create mode 100644 sdks/python/apache_beam/io/requestresponse_tests_requirements.txt diff --git a/CHANGES.md b/CHANGES.md index 7a460962a33e..ea59c3e964a4 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -65,8 +65,8 @@ ## New Features / Improvements -* [Enrichment Transform](https://s.apache.org/enrichment-transform) along with GCP BigTable handler added to Python SDK ([#30001](https://github.com/apache/beam/pull/30001)). * Allow writing clustered and not time partitioned BigQuery tables (Java) ([#30094](https://github.com/apache/beam/pull/30094)). +* Redis cache support added to RequestResponseIO and Enrichment transform (Python) ([#30307](https://github.com/apache/beam/pull/30307)) * Merged sdks/java/fn-execution and runners/core-construction-java into the main SDK. These artifacts were never meant for users, but noting that they no longer exist. These are steps to bring portability into the core SDK alongside all other core functionality. diff --git a/sdks/python/apache_beam/io/requestresponse.py b/sdks/python/apache_beam/io/requestresponse.py index 63ec7061d3e5..706bce95f5ee 100644 --- a/sdks/python/apache_beam/io/requestresponse.py +++ b/sdks/python/apache_beam/io/requestresponse.py @@ -19,16 +19,26 @@ import abc import concurrent.futures import contextlib +import enum +import json import logging import sys import time +from datetime import timedelta +from typing import Any +from typing import Dict from typing import Generic from typing import Optional +from typing import Tuple from typing import TypeVar +from typing import Union from google.api_core.exceptions import TooManyRequests import apache_beam as beam +import redis +from apache_beam import pvalue +from apache_beam.coders import coders from apache_beam.io.components.adaptive_throttler import AdaptiveThrottler from apache_beam.metrics import Metrics from apache_beam.ml.inference.vertex_ai_inference import MSEC_TO_SEC @@ -37,10 +47,24 @@ RequestT = TypeVar('RequestT') ResponseT = TypeVar('ResponseT') -DEFAULT_TIMEOUT_SECS = 30 # seconds +# DEFAULT_TIMEOUT_SECS represents the time interval for completing the request +# with external source. +DEFAULT_TIMEOUT_SECS = 30 + +# DEFAULT_CACHE_ENTRY_TTL_SEC represents the total time-to-live +# for cache record. +DEFAULT_CACHE_ENTRY_TTL_SEC = 24 * 60 * 60 _LOGGER = logging.getLogger(__name__) +__all__ = [ + 'RequestResponseIO', + 'ExponentialBackOffRepeater', + 'DefaultThrottler', + 'NoOpsRepeater', + 'RedisCache', +] + class UserCodeExecutionException(Exception): """Base class for errors related to calling Web APIs.""" @@ -90,6 +114,7 @@ class Caller(contextlib.AbstractContextManager, abc.ABC, Generic[RequestT, ResponseT]): """Interface for user custom code intended for API calls. + For setup and teardown of clients when applicable, implement the ``__enter__`` and ``__exit__`` methods respectively.""" @abc.abstractmethod @@ -107,16 +132,27 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): return None + def get_cache_key(self, request: RequestT) -> str: + """Returns the request to be cached. + + This is how the response will be looked up in the cache as well. + By default, entire request is cached as the key for the cache. + Implement this method to override the key for the cache. + For example, in `BigTableEnrichmentHandler`, the row key for the element + is returned here. + """ + return "" + class ShouldBackOff(abc.ABC): """ - ShouldBackOff provides mechanism to apply adaptive throttling. + Provides mechanism to apply adaptive throttling. """ pass class Repeater(abc.ABC): - """Repeater provides mechanism to repeat requests for a + """Provides mechanism to repeat requests for a configurable condition.""" @abc.abstractmethod def repeat( @@ -125,17 +161,17 @@ def repeat( request: RequestT, timeout: float, metrics_collector: Optional[_MetricsCollector]) -> ResponseT: - """repeat method is called from the RequestResponseIO when - a repeater is enabled. + """Implements a repeater strategy for RequestResponseIO when a repeater + is enabled. Args: - caller: :class:`apache_beam.io.requestresponse.Caller` object that calls - the API. + caller: a `~apache_beam.io.requestresponse.Caller` object that + calls the API. request: input request to repeat. timeout: time to wait for the request to complete. metrics_collector: (Optional) a - ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object to - collect the metrics for RequestResponseIO. + `~apache_beam.io.requestresponse._MetricsCollector` object + to collect the metrics for RequestResponseIO. """ pass @@ -167,9 +203,10 @@ def _execute_request( class ExponentialBackOffRepeater(Repeater): - """Exponential BackOff Repeater uses exponential backoff retry strategy for - exceptions due to the remote service such as TooManyRequests (HTTP 429), - UserCodeTimeoutException, UserCodeQuotaException. + """Configure exponential backoff retry strategy. + + It retries for exceptions due to the remote service such as + TooManyRequests (HTTP 429), UserCodeTimeoutException, UserCodeQuotaException. It utilizes the decorator :func:`apache_beam.utils.retry.with_exponential_backoff`. @@ -189,20 +226,19 @@ def repeat( a repeater is enabled. Args: - caller: :class:`apache_beam.io.requestresponse.Caller` object that + caller: a `~apache_beam.io.requestresponse.Caller` object that calls the API. request: input request to repeat. timeout: time to wait for the request to complete. metrics_collector: (Optional) a - ``:class:`apache_beam.io.requestresponse._MetricsCollector``` object to + `~apache_beam.io.requestresponse._MetricsCollector` object to collect the metrics for RequestResponseIO. """ return _execute_request(caller, request, timeout, metrics_collector) class NoOpsRepeater(Repeater): - """ - NoOpsRepeater executes a request just once irrespective of any exception. + """Executes a request just once irrespective of any exception. """ def repeat( self, @@ -213,18 +249,8 @@ def repeat( return _execute_request(caller, request, timeout, metrics_collector) -class CacheReader(abc.ABC): - """CacheReader provides mechanism to read from the cache.""" - pass - - -class CacheWriter(abc.ABC): - """CacheWriter provides mechanism to write to the cache.""" - pass - - class PreCallThrottler(abc.ABC): - """PreCallThrottler provides a throttle mechanism before sending request.""" + """Provides a throttle mechanism before sending request.""" pass @@ -251,75 +277,16 @@ def __init__( self.delay_secs = delay_secs -class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], - beam.PCollection[ResponseT]]): - """A :class:`RequestResponseIO` transform to read and write to APIs. - - Processes an input :class:`~apache_beam.pvalue.PCollection` of requests - by making a call to the API as defined in :class:`Caller`'s `__call__` - and returns a :class:`~apache_beam.pvalue.PCollection` of responses. - """ - def __init__( - self, - caller: Caller[RequestT, ResponseT], - timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, - should_backoff: Optional[ShouldBackOff] = None, - repeater: Repeater = ExponentialBackOffRepeater(), - cache_reader: Optional[CacheReader] = None, - cache_writer: Optional[CacheWriter] = None, - throttler: PreCallThrottler = DefaultThrottler(), - ): - """ - Instantiates a RequestResponseIO transform. - - Args: - caller (~apache_beam.io.requestresponse.Caller): an implementation of - `Caller` object that makes call to the API. - timeout (float): timeout value in seconds to wait for response from API. - should_backoff (~apache_beam.io.requestresponse.ShouldBackOff): - (Optional) provides methods for backoff. - repeater (~apache_beam.io.requestresponse.Repeater): provides method to - repeat failed requests to API due to service errors. Defaults to - :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to - repeat requests with exponential backoff. - cache_reader (~apache_beam.io.requestresponse.CacheReader): (Optional) - provides methods to read external cache. - cache_writer (~apache_beam.io.requestresponse.CacheWriter): (Optional) - provides methods to write to external cache. - throttler (~apache_beam.io.requestresponse.PreCallThrottler): - provides methods to pre-throttle a request. Defaults to - :class:`apache_beam.io.requestresponse.DefaultThrottler` for - client-side adaptive throttling using - :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler` - """ - self._caller = caller - self._timeout = timeout - self._should_backoff = should_backoff - if repeater: - self._repeater = repeater - else: - self._repeater = NoOpsRepeater() - self._cache_reader = cache_reader - self._cache_writer = cache_writer - self._throttler = throttler +class _FilterCacheReadFn(beam.DoFn): + """A `DoFn` that partitions cache reads. - def expand( - self, - requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: - # TODO(riteshghorse): handle Cache and Throttle PTransforms when available. - if isinstance(self._throttler, DefaultThrottler): - return requests | _Call( - caller=self._caller, - timeout=self._timeout, - should_backoff=self._should_backoff, - repeater=self._repeater, - throttler=self._throttler) + It emits to main output for successful cache read requests or + to the tagged output - `cache_misses` - otherwise.""" + def process(self, element: Tuple[RequestT, ResponseT], *args, **kwargs): + if not element[1]: + yield pvalue.TaggedOutput('cache_misses', element[0]) else: - return requests | _Call( - caller=self._caller, - timeout=self._timeout, - should_backoff=self._should_backoff, - repeater=self._repeater) + yield element class _Call(beam.PTransform[beam.PCollection[RequestT], @@ -333,15 +300,11 @@ class _Call(beam.PTransform[beam.PCollection[RequestT], regulate the duration of each call, defaults to 30 seconds. Args: - caller (:class:`apache_beam.io.requestresponse.Caller`): a callable - object that invokes API call. + caller: a `Caller` object that invokes API call. timeout (float): timeout value in seconds to wait for response from API. - should_backoff (~apache_beam.io.requestresponse.ShouldBackOff): - (Optional) provides methods for backoff. - repeater (~apache_beam.io.requestresponse.Repeater): (Optional) provides - methods to repeat requests to API. - throttler (~apache_beam.io.requestresponse.PreCallThrottler): - (Optional) provides methods to pre-throttle a request. + should_backoff: (Optional) provides methods for backoff. + repeater: (Optional) provides methods to repeat requests to API. + throttler: (Optional) provides methods to pre-throttle a request. """ def __init__( self, @@ -411,3 +374,431 @@ def process(self, request: RequestT, *args, **kwargs): def teardown(self): self._metrics_collector.teardown_counter.inc(1) self._caller.__exit__(*sys.exc_info()) + + +class Cache(abc.ABC): + """Base Cache class for + :class:`apache_beam.io.requestresponse.RequestResponseIO`. + + For adding cache support to RequestResponseIO, implement this class. + """ + @abc.abstractmethod + def get_read(self): + """returns a PTransform that reads from the cache.""" + pass + + @abc.abstractmethod + def get_write(self): + """returns a PTransform that writes to the cache.""" + pass + + @property + @abc.abstractmethod + def request_coder(self): + """request coder to use with Cache.""" + pass + + @request_coder.setter + @abc.abstractmethod + def request_coder(self, request_coder: coders.Coder): + """sets the request coder to use with Cache.""" + pass + + @property + @abc.abstractmethod + def source_caller(self): + """Actual caller that is using the cache.""" + pass + + @source_caller.setter + @abc.abstractmethod + def source_caller(self, caller: Caller): + """Sets the source caller for + :class:`apache_beam.io.requestresponse.RequestResponseIO` to pull + cache request key from respective callers.""" + pass + + +class _RedisMode(enum.Enum): + """ + Mode of operation for redis cache when using + `~apache_beam.io.requestresponse._RedisCaller`. + """ + READ = 0 + WRITE = 1 + + +class _RedisCaller(Caller): + """An implementation of + `~apache_beam.io.requestresponse.Caller` for Redis client. + + It provides the functionality for making requests to Redis server using + :class:`apache_beam.io.requestresponse.RequestResponseIO`. + """ + def __init__( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta], + *, + request_coder: Optional[coders.Coder], + response_coder: Optional[coders.Coder], + kwargs: Optional[Dict[str, Any]] = None, + source_caller: Optional[Caller] = None, + mode: _RedisMode, + ): + """ + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + request_coder: (Optional[`coders.Coder`]) coder for requests stored + in Redis. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + kwargs: Optional(Dict[str, Any]) additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. + source_caller: (Optional[`Caller`]): The source caller using this Redis + cache in case of fetching the cache request to store in Redis. + mode: `_RedisMode` An enum type specifying the operational mode of + the `_RedisCaller`. + """ + self.host, self.port = host, port + self.time_to_live = time_to_live + self.request_coder = request_coder + self.response_coder = response_coder + self.kwargs = kwargs + self.source_caller = source_caller + self.mode = mode + + def __enter__(self): + self.client = redis.Redis(self.host, self.port, **self.kwargs) + + def __call__(self, element, *args, **kwargs): + if self.mode == _RedisMode.READ: + cache_request = self.source_caller.get_cache_key(element) + # check if the caller is a enrichment handler. EnrichmentHandler + # provides the request format for cache. + if cache_request: + encoded_request = self.request_coder.encode(cache_request) + else: + encoded_request = self.request_coder.encode(element) + + encoded_response = self.client.get(encoded_request) + if not encoded_response: + # no cache entry present for this request. + return element, None + + if self.response_coder is None: + try: + response_dict = json.loads(encoded_response.decode('utf-8')) + response = beam.Row(**response_dict) + except Exception: + _LOGGER.warning( + 'cannot decode response from redis cache for %s.' % element) + return element, None + else: + response = self.response_coder.decode(encoded_response) + return element, response + else: + cache_request = self.source_caller.get_cache_key(element[0]) + if cache_request: + encoded_request = self.request_coder.encode(cache_request) + else: + encoded_request = self.request_coder.encode(element[0]) + if self.response_coder is None: + try: + encoded_response = json.dumps(element[1]._asdict()).encode('utf-8') + except Exception: + _LOGGER.warning( + 'cannot encode response %s for %s to store in ' + 'redis cache.' % (element[1], element[0])) + return element + else: + encoded_response = self.response_coder.encode(element[1]) + # Write to cache with TTL. Set nx to True to prevent overwriting for the + # same key. + self.client.set( + encoded_request, encoded_response, self.time_to_live, nx=True) + return element + + def __exit__(self, exc_type, exc_val, exc_tb): + self.client.close() + + +class _ReadFromRedis(beam.PTransform[beam.PCollection[RequestT], + beam.PCollection[ResponseT]]): + """A `PTransform` that performs Redis cache read.""" + def __init__( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta], + *, + kwargs: Optional[Dict[str, Any]] = None, + request_coder: Optional[coders.Coder], + response_coder: Optional[coders.Coder], + source_caller: Optional[Caller[RequestT, ResponseT]] = None, + ): + """ + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + kwargs: Optional(Dict[str, Any]) additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. + request_coder: (Optional[`coders.Coder`]) coder for requests stored + in Redis. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + source_caller: (Optional[`Caller`]): The source caller using this Redis + cache in case of fetching the cache request to store in Redis. + """ + self.request_coder = request_coder + self.response_coder = response_coder + self.redis_caller = _RedisCaller( + host, + port, + time_to_live, + request_coder=self.request_coder, + response_coder=self.response_coder, + kwargs=kwargs, + source_caller=source_caller, + mode=_RedisMode.READ) + + def expand( + self, + requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: + return requests | RequestResponseIO(self.redis_caller) + + +class _WriteToRedis(beam.PTransform[beam.PCollection[Tuple[RequestT, + ResponseT]], + beam.PCollection[ResponseT]]): + """A `PTransfrom` that performs write to Redis cache.""" + def __init__( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta], + *, + kwargs: Optional[Dict[str, Any]] = None, + request_coder: Optional[coders.Coder], + response_coder: Optional[coders.Coder], + source_caller: Optional[Caller[RequestT, ResponseT]] = None, + ): + """ + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + kwargs: Optional(Dict[str, Any]) additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. + request_coder: (Optional[`coders.Coder`]) coder for requests stored + in Redis. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + source_caller: (Optional[`Caller`]): The source caller using this Redis + cache in case of fetching the cache request to store in Redis. + """ + self.request_coder = request_coder + self.response_coder = response_coder + self.redis_caller = _RedisCaller( + host, + port, + time_to_live, + request_coder=self.request_coder, + response_coder=self.response_coder, + kwargs=kwargs, + source_caller=source_caller, + mode=_RedisMode.WRITE) + + def expand( + self, elements: beam.PCollection[Tuple[RequestT, ResponseT]] + ) -> beam.PCollection[ResponseT]: + return elements | RequestResponseIO(self.redis_caller) + + +def ensure_coders_exist(request_coder): + """checks if the coder exists to encode the request for caching.""" + if not request_coder: + raise ValueError( + 'need request coder to be able to use ' + 'Cache with RequestResponseIO.') + + +class RedisCache(Cache): + """Configure cache using Redis for + :class:`apache_beam.io.requestresponse.RequestResponseIO`.""" + def __init__( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta] = DEFAULT_CACHE_ENTRY_TTL_SEC, + *, + request_coder: Optional[coders.Coder] = None, + response_coder: Optional[coders.Coder] = None, + **kwargs, + ): + """ + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + request_coder: (Optional[`coders.Coder`]) coder for encoding requests. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + kwargs: Optional additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. + """ + self._host = host + self._port = port + self._time_to_live = time_to_live + self._request_coder = request_coder + self._response_coder = response_coder + self._kwargs = kwargs if kwargs else {} + self._source_caller = None + + def get_read(self): + """get_read returns a PTransform for reading from the cache.""" + ensure_coders_exist(self._request_coder) + return _ReadFromRedis( + self._host, + self._port, + time_to_live=self._time_to_live, + kwargs=self._kwargs, + request_coder=self._request_coder, + response_coder=self._response_coder, + source_caller=self._source_caller) + + def get_write(self): + """returns a PTransform for writing to the cache.""" + ensure_coders_exist(self._request_coder) + return _WriteToRedis( + self._host, + self._port, + time_to_live=self._time_to_live, + kwargs=self._kwargs, + request_coder=self._request_coder, + response_coder=self._response_coder, + source_caller=self._source_caller) + + @property + def source_caller(self): + return self._source_caller + + @source_caller.setter + def source_caller(self, source_caller: Caller): + self._source_caller = source_caller + + @property + def request_coder(self): + return self._request_coder + + @request_coder.setter + def request_coder(self, request_coder: coders.Coder): + self._request_coder = request_coder + + +class RequestResponseIO(beam.PTransform[beam.PCollection[RequestT], + beam.PCollection[ResponseT]]): + """A :class:`RequestResponseIO` transform to read and write to APIs. + + Processes an input :class:`~apache_beam.pvalue.PCollection` of requests + by making a call to the API as defined in `Caller`'s `__call__` method + and returns a :class:`~apache_beam.pvalue.PCollection` of responses. + """ + def __init__( + self, + caller: Caller[RequestT, ResponseT], + timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, + should_backoff: Optional[ShouldBackOff] = None, + repeater: Repeater = ExponentialBackOffRepeater(), + cache: Optional[Cache] = None, + throttler: PreCallThrottler = DefaultThrottler(), + ): + """ + Instantiates a RequestResponseIO transform. + + Args: + caller: an implementation of + `Caller` object that makes call to the API. + timeout (float): timeout value in seconds to wait for response from API. + should_backoff: (Optional) provides methods for backoff. + repeater: provides method to repeat failed requests to API due to service + errors. Defaults to + :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to + repeat requests with exponential backoff. + cache: (Optional) a `~apache_beam.io.requestresponse.Cache` object + to use the appropriate cache. + throttler: provides methods to pre-throttle a request. Defaults to + :class:`apache_beam.io.requestresponse.DefaultThrottler` for + client-side adaptive throttling using + :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler` + """ + self._caller = caller + self._timeout = timeout + self._should_backoff = should_backoff + if repeater: + self._repeater = repeater + else: + self._repeater = NoOpsRepeater() + self._cache = cache + self._throttler = throttler + + def expand( + self, + requests: beam.PCollection[RequestT]) -> beam.PCollection[ResponseT]: + # TODO(riteshghorse): handle Throttle PTransforms when available. + + if self._cache: + self._cache.source_caller = self._caller + + inputs = requests + + if self._cache: + # read from cache. + outputs = inputs | self._cache.get_read() + # filter responses that are None and send them to the Call transform + # to fetch a value from external service. + cached_responses, inputs = (outputs + | beam.ParDo(_FilterCacheReadFn() + ).with_outputs( + 'cache_misses', main='cached_responses')) + + if isinstance(self._throttler, DefaultThrottler): + # DefaultThrottler applies throttling in the DoFn of + # Call PTransform. + responses = ( + inputs + | _Call( + caller=self._caller, + timeout=self._timeout, + should_backoff=self._should_backoff, + repeater=self._repeater, + throttler=self._throttler)) + else: + # No throttling mechanism. The requests are made to the external source + # as they come. + responses = ( + inputs + | _Call( + caller=self._caller, + timeout=self._timeout, + should_backoff=self._should_backoff, + repeater=self._repeater)) + + if self._cache: + # write to cache. + _ = responses | self._cache.get_write() + return (cached_responses, responses) | beam.Flatten() + + return responses diff --git a/sdks/python/apache_beam/io/requestresponse_it_test.py b/sdks/python/apache_beam/io/requestresponse_it_test.py index 396347c58d16..bd8c63dea587 100644 --- a/sdks/python/apache_beam/io/requestresponse_it_test.py +++ b/sdks/python/apache_beam/io/requestresponse_it_test.py @@ -15,6 +15,7 @@ # limitations under the License. # import base64 +import logging import sys import typing import unittest @@ -22,15 +23,19 @@ from typing import Tuple from typing import Union +import pytest import urllib3 import apache_beam as beam +from apache_beam.coders import coders from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.testing.test_pipeline import TestPipeline # pylint: disable=ungrouped-imports try: + from testcontainers.redis import RedisContainer from apache_beam.io.requestresponse import Caller + from apache_beam.io.requestresponse import RedisCache from apache_beam.io.requestresponse import RequestResponseIO from apache_beam.io.requestresponse import UserCodeExecutionException from apache_beam.io.requestresponse import UserCodeQuotaException @@ -41,6 +46,8 @@ _PAYLOAD = base64.b64encode(bytes('payload', 'utf-8')) _HTTP_ENDPOINT_ADDRESS_FLAG = '--httpEndpointAddress' +_LOGGER = logging.getLogger(__name__) + class EchoITOptions(PipelineOptions): """Shared options for running integration tests on a deployed @@ -52,6 +59,7 @@ class EchoITOptions(PipelineOptions): def _add_argparse_args(cls, parser) -> None: parser.add_argument( _HTTP_ENDPOINT_ADDRESS_FLAG, + default='http://10.138.0.32:8080', dest='http_endpoint_address', help='The HTTP address of the Echo API endpoint; must being with ' 'http(s)://') @@ -95,7 +103,8 @@ def __call__(self, request: Request, *args, **kwargs) -> EchoResponse: or a ``UserCodeQuotaException``. """ try: - resp = urllib3.request( + http = urllib3.PoolManager() + resp = http.request( "POST", self.url, json={ @@ -118,6 +127,18 @@ def __call__(self, request: Request, *args, **kwargs) -> EchoResponse: raise UserCodeExecutionException(e) +class ValidateResponse(beam.DoFn): + """Validates response received from Mock API server.""" + def process(self, element, *args, **kwargs): + if (element.id != 'echo-should-never-exceed-quota' or + element.payload != _PAYLOAD): + raise ValueError( + 'got EchoResponse(id: %s, payload: %s), want ' + 'EchoResponse(id: echo-should-never-exceed-quota, ' + 'payload: %s' % (element.id, element.payload, _PAYLOAD)) + + +@pytest.mark.uses_mock_api class EchoHTTPCallerTestIT(unittest.TestCase): options: Union[EchoITOptions, None] = None client: Union[EchoHTTPCaller, None] = None @@ -131,58 +152,157 @@ def setUpClass(cls) -> None: cls.client = EchoHTTPCaller(http_endpoint_address) - def setUp(self) -> None: - client, options = EchoHTTPCallerTestIT._get_client_and_options() - - req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD) - try: - # The following is needed to exceed the API - client(req) - client(req) - client(req) - except UserCodeExecutionException as e: - if not isinstance(e, UserCodeQuotaException): - raise e - @classmethod def _get_client_and_options(cls) -> Tuple[EchoHTTPCaller, EchoITOptions]: assert cls.options is not None assert cls.client is not None return cls.client, cls.options - def test_given_valid_request_receives_response(self): + def test_request_response_io(self): client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD) + with TestPipeline(is_integration_test=True) as test_pipeline: + output = ( + test_pipeline + | 'Create PCollection' >> beam.Create([req]) + | 'RRIO Transform' >> RequestResponseIO(client) + | 'Validate' >> beam.ParDo(ValidateResponse())) + self.assertIsNotNone(output) - response: EchoResponse = client(req) - self.assertEqual(req.id, response.id) - self.assertEqual(req.payload, response.payload) +class ValidateCacheResponses(beam.DoFn): + """Validates that the responses are fetched from the cache.""" + def process(self, element, *args, **kwargs): + if not element[1] or 'cached-' not in element[1]: + raise ValueError( + 'responses not fetched from cache even though cache ' + 'entries are present.') - def test_given_exceeded_quota_should_raise(self): - client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = Request(id=options.should_exceed_quota_id, payload=_PAYLOAD) +class ValidateCallerResponses(beam.DoFn): + """Validates that the responses are fetched from the caller.""" + def process(self, element, *args, **kwargs): + if not element[1] or 'ACK-' not in element[1]: + raise ValueError('responses not fetched from caller when they should.') - self.assertRaises(UserCodeQuotaException, lambda: client(req)) - def test_not_found_should_raise(self): - client, _ = EchoHTTPCallerTestIT._get_client_and_options() +class FakeCallerForCache(Caller[str, str]): + def __init__(self, use_cache: bool = False): + self.use_cache = use_cache - req = Request(id='i-dont-exist-quota-id', payload=_PAYLOAD) - self.assertRaisesRegex( - UserCodeExecutionException, "Not Found", lambda: client(req)) + def __enter__(self): + pass - def test_request_response_io(self): - client, options = EchoHTTPCallerTestIT._get_client_and_options() - req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD) - with TestPipeline(is_integration_test=True) as test_pipeline: - output = ( + def __call__(self, element, *args, **kwargs): + if self.use_cache: + return None, None + + return element, 'ACK-{element}' + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +@pytest.mark.uses_redis +class TestRedisCache(unittest.TestCase): + def setUp(self) -> None: + self.retries = 3 + self._start_container() + + def test_rrio_cache_all_miss(self): + """Cache is empty so all responses are fetched from caller.""" + caller = FakeCallerForCache() + req = ['redis', 'cachetools', 'memcache'] + cache = RedisCache( + self.host, + self.port, + time_to_live=30, + request_coder=coders.StrUtf8Coder(), + response_coder=coders.StrUtf8Coder()) + with TestPipeline(is_integration_test=True) as p: + _ = ( + p + | beam.Create(req) + | RequestResponseIO(caller, cache=cache) + | beam.ParDo(ValidateCallerResponses())) + + def test_rrio_cache_all_hit(self): + """Validate that records are fetched from cache.""" + caller = FakeCallerForCache() + requests = ['foo', 'bar'] + responses = ['cached-foo', 'cached-bar'] + coder = coders.StrUtf8Coder() + for i in range(len(requests)): + enc_req = coder.encode(requests[i]) + enc_resp = coder.encode(responses[i]) + self.client.setex(enc_req, 120, enc_resp) + cache = RedisCache( + self.host, + self.port, + time_to_live=30, + request_coder=coders.StrUtf8Coder(), + response_coder=coders.StrUtf8Coder()) + with TestPipeline(is_integration_test=True) as p: + _ = ( + p + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache) + | beam.ParDo(ValidateCacheResponses())) + + def test_rrio_cache_miss_and_hit(self): + """Run two back-to-back pipelines, one with pulling the data from caller + and other from the cache.""" + caller = FakeCallerForCache() + requests = ['beam', 'flink', 'spark'] + cache = RedisCache( + self.host, + self.port, + request_coder=coders.StrUtf8Coder(), + response_coder=coders.StrUtf8Coder()) + with TestPipeline(is_integration_test=True) as p: + _ = ( + p + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache) + | beam.ParDo(ValidateCallerResponses())) + + caller = FakeCallerForCache(use_cache=True) + with TestPipeline(is_integration_test=True) as p: + _ = ( + p + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache) + | beam.ParDo(ValidateCallerResponses())) + + def test_rrio_no_coder_exception(self): + caller = FakeCallerForCache() + requests = ['beam', 'flink', 'spark'] + cache = RedisCache(self.host, self.port) + with self.assertRaises(ValueError): + test_pipeline = beam.Pipeline() + _ = ( test_pipeline - | 'Create PCollection' >> beam.Create([req]) - | 'RRIO Transform' >> RequestResponseIO(client)) - self.assertIsNotNone(output) + | beam.Create(requests) + | RequestResponseIO(caller, cache=cache)) + res = test_pipeline.run() + res.wait_until_finish() + + def tearDown(self) -> None: + self.container.stop() + + def _start_container(self): + for i in range(self.retries): + try: + self.container = RedisContainer(image='redis:7.2.4') + self.container.start() + self.host = self.container.get_container_host_ip() + self.port = self.container.get_exposed_port(6379) + self.client = self.container.get_client() + break + except Exception as e: + if i == self.retries - 1: + _LOGGER.error('Unable to start redis container for RRIO tests.') + raise e if __name__ == '__main__': diff --git a/sdks/python/apache_beam/io/requestresponse_test.py b/sdks/python/apache_beam/io/requestresponse_test.py index 6d807c2a8eb8..cfc2fe5e668d 100644 --- a/sdks/python/apache_beam/io/requestresponse_test.py +++ b/sdks/python/apache_beam/io/requestresponse_test.py @@ -23,7 +23,8 @@ # pylint: disable=ungrouped-imports try: from google.api_core.exceptions import TooManyRequests - from apache_beam.io.requestresponse import Caller, DefaultThrottler + from apache_beam.io.requestresponse import Caller + from apache_beam.io.requestresponse import DefaultThrottler from apache_beam.io.requestresponse import RequestResponseIO from apache_beam.io.requestresponse import UserCodeExecutionException from apache_beam.io.requestresponse import UserCodeTimeoutException diff --git a/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt b/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt new file mode 100644 index 000000000000..1d8869705097 --- /dev/null +++ b/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +redis>=5.0.0 diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index a2f961be6437..93344835e930 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -15,18 +15,23 @@ # limitations under the License. # import logging +from datetime import timedelta from typing import Any from typing import Callable from typing import Dict from typing import Optional from typing import TypeVar +from typing import Union import apache_beam as beam +from apache_beam.coders import coders +from apache_beam.io.requestresponse import DEFAULT_CACHE_ENTRY_TTL_SEC from apache_beam.io.requestresponse import DEFAULT_TIMEOUT_SECS from apache_beam.io.requestresponse import Caller from apache_beam.io.requestresponse import DefaultThrottler from apache_beam.io.requestresponse import ExponentialBackOffRepeater from apache_beam.io.requestresponse import PreCallThrottler +from apache_beam.io.requestresponse import RedisCache from apache_beam.io.requestresponse import Repeater from apache_beam.io.requestresponse import RequestResponseIO @@ -44,8 +49,15 @@ _LOGGER = logging.getLogger(__name__) +def has_valid_redis_address(host: str, port: int) -> bool: + """returns `True` if both host and port are not `None`.""" + if host and port: + return True + return False + + def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row: - """cross_join performs a cross join on two `dict` objects. + """performs a cross join on two `dict` objects. Joins the columns of the right row onto the left row. @@ -71,20 +83,29 @@ def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row: class EnrichmentSourceHandler(Caller[InputT, OutputT]): - """Wrapper class for :class:`apache_beam.io.requestresponse.Caller`. + """Wrapper class for `apache_beam.io.requestresponse.Caller`. Ensure that the implementation of ``__call__`` method returns a tuple of `beam.Row` objects. """ - pass + def get_cache_key(self, request: InputT) -> str: + """Returns the request to be cached. This is how the response will be + looked up in the cache as well. + + Implement this method to provide the key for the cache. + By default, the entire request is stored as the cache key. + + For example, in `BigTableEnrichmentHandler`, the row key for the element + is returned here. + """ + return "request: %s" % request class Enrichment(beam.PTransform[beam.PCollection[InputT], beam.PCollection[OutputT]]): """A :class:`apache_beam.transforms.enrichment.Enrichment` transform to enrich elements in a PCollection. - **NOTE:** This transform and its implementation are under development and - do not provide backward compatibility guarantees. + Uses the :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler` to enrich elements by joining the metadata from external source. @@ -100,12 +121,11 @@ class Enrichment(beam.PTransform[beam.PCollection[InputT], join_fn: A lambda function to join original element with lookup metadata. Defaults to `CROSS_JOIN`. timeout: (Optional) timeout for source requests. Defaults to 30 seconds. - repeater (~apache_beam.io.requestresponse.Repeater): provides method to - repeat failed requests to API due to service errors. Defaults to + repeater: provides method to repeat failed requests to API due to service + errors. Defaults to :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater` to repeat requests with exponential backoff. - throttler (~apache_beam.io.requestresponse.PreCallThrottler): - provides methods to pre-throttle a request. Defaults to + throttler: provides methods to pre-throttle a request. Defaults to :class:`apache_beam.io.requestresponse.DefaultThrottler` for client-side adaptive throttling using :class:`apache_beam.io.components.adaptive_throttler.AdaptiveThrottler`. @@ -116,8 +136,8 @@ def __init__( join_fn: JoinFn = cross_join, timeout: Optional[float] = DEFAULT_TIMEOUT_SECS, repeater: Repeater = ExponentialBackOffRepeater(), - throttler: PreCallThrottler = DefaultThrottler(), - ): + throttler: PreCallThrottler = DefaultThrottler()): + self._cache = None self._source_handler = source_handler self._join_fn = join_fn self._timeout = timeout @@ -126,12 +146,55 @@ def __init__( def expand(self, input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]: + # For caching with enrichment transform, enrichment handlers provide a + # get_cache_key() method that returns a unique string formatted + # request for that row. + request_coder = coders.StrUtf8Coder() + if self._cache: + self._cache.request_coder = request_coder + fetched_data = input_row | RequestResponseIO( caller=self._source_handler, timeout=self._timeout, repeater=self._repeater, + cache=self._cache, throttler=self._throttler) # EnrichmentSourceHandler returns a tuple of (request,response). return fetched_data | beam.Map( lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict())) + + def with_redis_cache( + self, + host: str, + port: int, + time_to_live: Union[int, timedelta] = DEFAULT_CACHE_ENTRY_TTL_SEC, + *, + request_coder: Optional[coders.Coder] = None, + response_coder: Optional[coders.Coder] = None, + **kwargs, + ): + """Configure the Redis cache to use with enrichment transform. + + Args: + host (str): The hostname or IP address of the Redis server. + port (int): The port number of the Redis server. + time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for + records stored in Redis. Provide an integer (in seconds) or a + `datetime.timedelta` object. + request_coder: (Optional[`coders.Coder`]) coder for requests stored + in Redis. + response_coder: (Optional[`coders.Coder`]) coder for decoding responses + received from Redis. + kwargs: Optional additional keyword arguments that + are required to connect to your redis server. Same as `redis.Redis()`. + """ + if has_valid_redis_address(host, port): + self._cache = RedisCache( # type: ignore[assignment] + host=host, + port=port, + time_to_live=time_to_live, + request_coder=request_coder, + response_coder=response_coder, + **kwargs) + return self diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index 873dd156cb87..943000a9f6bb 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -53,9 +53,8 @@ class ExceptionLevel(Enum): class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): - """BigTableEnrichmentHandler is a handler for - :class:`apache_beam.transforms.enrichment.Enrichment` transform to interact - with GCP BigTable. + """A handler for :class:`apache_beam.transforms.enrichment.Enrichment` + transform to interact with GCP BigTable. Args: project_id (str): GCP project-id of the BigTable cluster. @@ -161,3 +160,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.client = None self.instance = None self._table = None + + def get_cache_key(self, request: beam.Row) -> str: + """Returns a string formatted with row key since it is unique to + a request made to `Bigtable`.""" + return "%s: %s" % (self._row_key, request._asdict()[self._row_key]) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py index 86fc438960d3..b792bc8ba946 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable_it_test.py @@ -16,15 +16,18 @@ # import datetime +import logging import unittest from typing import Dict from typing import List from typing import NamedTuple from typing import Tuple +from unittest.mock import MagicMock import pytest import apache_beam as beam +from apache_beam.coders import coders from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import BeamAssertException @@ -33,11 +36,14 @@ from google.api_core.exceptions import NotFound from google.cloud.bigtable import Client from google.cloud.bigtable.row_filters import ColumnRangeFilter + from testcontainers.redis import RedisContainer from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler from apache_beam.transforms.enrichment_handlers.bigtable import ExceptionLevel except ImportError: - raise unittest.SkipTest('GCP BigTable dependencies are not installed.') + raise unittest.SkipTest('Bigtable test dependencies are not installed.') + +_LOGGER = logging.getLogger(__name__) class ValidateResponse(beam.DoFn): @@ -142,7 +148,7 @@ def create_rows(table): row.commit() -@pytest.mark.it_postcommit +@pytest.mark.uses_redis class TestBigTableEnrichment(unittest.TestCase): def setUp(self): self.project_id = 'apache-beam-testing' @@ -160,8 +166,25 @@ def setUp(self): instance = client.instance(self.instance_id) self.table = instance.table(self.table_id) create_rows(self.table) + self.retries = 3 + self._start_container() + + def _start_container(self): + for i in range(self.retries): + try: + self.container = RedisContainer(image='redis:7.2.4') + self.container.start() + self.host = self.container.get_container_host_ip() + self.port = self.container.get_exposed_port(6379) + self.client = self.container.get_client() + break + except Exception as e: + if i == self.retries - 1: + _LOGGER.error('Unable to start redis container for RRIO tests.') + raise e def tearDown(self) -> None: + self.container.stop() self.table = None def test_enrichment_with_bigtable(self): @@ -336,6 +359,73 @@ def test_enrichment_with_bigtable_with_timestamp(self): expected_enriched_fields, include_timestamp=True))) + def test_bigtable_enrichment_with_redis(self): + """ + In this test, we run two pipelines back to back. + + In the first pipeline, we run a simple bigtable enrichment pipeline with + zero cache records. Therefore, it makes call to the Bigtable source and + ultimately writes to the cache with a TTL of 300 seconds. + + For the second pipeline, we mock the `BigTableEnrichmentHandler`'s + `__call__` method to always return a `None` response. However, this change + won't impact the second pipeline because the Enrichment transform first + checks the cache to fulfill requests. Since all requests are cached, it + will return from there without making calls to the Bigtable source. + """ + expected_fields = [ + 'sale_id', 'customer_id', 'product_id', 'quantity', 'product' + ] + expected_enriched_fields = { + 'product': ['product_name', 'product_stock'], + } + start_column = 'product_name'.encode() + column_filter = ColumnRangeFilter(self.column_family_id, start_column) + bigtable = BigTableEnrichmentHandler( + project_id=self.project_id, + instance_id=self.instance_id, + table_id=self.table_id, + row_key=self.row_key, + row_filter=column_filter) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create1" >> beam.Create(self.req) + | "Enrich W/ BigTable1" >> Enrichment(bigtable).with_redis_cache( + self.host, self.port, 300) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + + # manually check cache entry + c = coders.StrUtf8Coder() + for req in self.req: + key = bigtable.get_cache_key(req) + response = self.client.get(c.encode(key)) + if not response: + raise ValueError("No cache entry found for %s" % key) + + actual = BigTableEnrichmentHandler.__call__ + BigTableEnrichmentHandler.__call__ = MagicMock( + return_value=( + beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), + beam.Row())) + + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create2" >> beam.Create(self.req) + | "Enrich W/ BigTable2" >> Enrichment(bigtable).with_redis_cache( + self.host, self.port) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + BigTableEnrichmentHandler.__call__ = actual + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_it_test.py b/sdks/python/apache_beam/transforms/enrichment_it_test.py index 89842cb18be0..4a45fae2e869 100644 --- a/sdks/python/apache_beam/transforms/enrichment_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_it_test.py @@ -109,7 +109,7 @@ def process(self, element: beam.Row, *args, **kwargs): raise BeamAssertException(f"Expected a not None field: {field}") -@pytest.mark.it_postcommit +@pytest.mark.uses_mock_api class TestEnrichment(unittest.TestCase): options: Union[EchoITOptions, None] = None client: Union[SampleHTTPEnrichment, None] = None diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini index 4ffbb4524c06..c95aa5974da7 100644 --- a/sdks/python/pytest.ini +++ b/sdks/python/pytest.ini @@ -64,6 +64,8 @@ markers = uses_tf: tests that utilize tensorflow. uses_transformers: tests that utilize transformers in some way. vertex_ai_postcommit: vertex ai postcommits that need additional deps. + uses_redis: enrichment transform tests that need redis. + uses_mock_api: tests that uses the mock API cluster. # Default timeout intended for unit tests. # If certain tests need a different value, please see the docs on how to diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index 82740ae67c9f..f5307e43f041 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -134,7 +134,7 @@ autodoc_member_order = 'bysource' autodoc_mock_imports = ["tensorrt", "cuda", "torch", "onnxruntime", "onnx", "tensorflow", "tensorflow_hub", "tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost", "datatable", "transformers", - "sentence_transformers", + "sentence_transformers", "redis", ] # Allow a special section for documenting DataFrame API @@ -204,6 +204,10 @@ ignore_identifiers = [ 'apache_beam.typehints.typehints.validate_composite_type_param()', 'apache_beam.utils.windowed_value._IntervalWindowBase', 'apache_beam.coders.coder_impl.StreamCoderImpl', + 'apache_beam.io.requestresponse.Caller', + 'apache_beam.io.requestresponse.Repeater', + 'apache_beam.io.requestresponse.PreCallThrottler', + 'apache_beam.io.requestresponse.Cache', # Private classes which are used within the same module 'apache_beam.transforms.external_test.PayloadBase', diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index 5fc1751a9686..d29d85af3011 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -487,6 +487,35 @@ task tftTests { } } +// Tests that depend on Mock API:https://github.com/apache/beam/tree/master/.test-infra/mock-apis. . +task mockAPITests { + dependsOn 'initializeForDataflowJob' + dependsOn ':sdks:python:sdist' + def requirementsFile = "${rootDir}/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt" + doFirst { + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile" + } + } + doLast { + def testOpts = basicTestOpts + def argMap = [ + "test_opts": testOpts, + "collect": "uses_mock_api", + "runner": "TestDataflowRunner", + "project": "apache-beam-testing", + "region": "us-west1", + "requirements_file": "$requirementsFile" + ] + def cmdArgs = mapToArgString(argMap) + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs" + } + } +} + // add all RunInference E2E tests that run on DataflowRunner // As of now, this test suite is enable in py38 suite as the base NVIDIA image used for Tensor RT // contains Python 3.8. @@ -495,6 +524,7 @@ project.tasks.register("inferencePostCommitIT") { dependsOn = [ 'tensorRTtests', 'vertexAIInferenceTest', + 'mockAPITests', ] } diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle index 771e0be19cfe..657f7adf801d 100644 --- a/sdks/python/test-suites/direct/common.gradle +++ b/sdks/python/test-suites/direct/common.gradle @@ -364,6 +364,33 @@ task transformersInferenceTest { } } +// Enrichment transform tests that uses Redis +task enrichmentRedisTest { + dependsOn 'installGcpTest' + dependsOn ':sdks:python:sdist' + def requirementsFile = "${rootDir}/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt" + doFirst { + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pip install -r $requirementsFile" + } + } + doLast { + def testOpts = basicTestOpts + def argMap = [ + "test_opts": testOpts, + "suite": "postCommitIT-direct-py${pythonVersionSuffix}", + "collect": "uses_redis", + "runner": "TestDirectRunner" + ] + def cmdArgs = mapToArgString(argMap) + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && ${runScriptsDir}/run_integration_test.sh $cmdArgs" + } + } +} + // Add all the RunInference framework IT tests to this gradle task that runs on Direct Runner Post commit suite. project.tasks.register("inferencePostCommitIT") { dependsOn = [ @@ -372,6 +399,7 @@ project.tasks.register("inferencePostCommitIT") { 'tensorflowInferenceTest', 'xgboostInferenceTest', 'transformersInferenceTest', + 'enrichmentRedisTest', // (TODO) https://github.com/apache/beam/issues/25799 // uncomment tfx bsl tests once tfx supports protobuf 4.x // 'tfxInferenceTest',