-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
enrichment.py
202 lines (175 loc) · 7.42 KB
/
enrichment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from datetime import timedelta
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import TypeVar
from typing import Union
import apache_beam as beam
from apache_beam.coders import coders
from apache_beam.io.requestresponse import DEFAULT_CACHE_ENTRY_TTL_SEC
from apache_beam.io.requestresponse import DEFAULT_TIMEOUT_SECS
from apache_beam.io.requestresponse import Caller
from apache_beam.io.requestresponse import DefaultThrottler
from apache_beam.io.requestresponse import ExponentialBackOffRepeater
from apache_beam.io.requestresponse import PreCallThrottler
from apache_beam.io.requestresponse import RedisCache
from apache_beam.io.requestresponse import Repeater
from apache_beam.io.requestresponse import RequestResponseIO
# Names exported as the public API of this module.
__all__ = [
"EnrichmentSourceHandler",
"Enrichment",
"cross_join",
]
# Generic type variables for the element types flowing into and out of the
# enrichment handler / transform defined below.
InputT = TypeVar('InputT')
OutputT = TypeVar('OutputT')
# Signature of a join function: takes two dictionaries (the request columns and
# the response columns) and returns the combined `beam.Row`.
JoinFn = Callable[[Dict[str, Any], Dict[str, Any]], beam.Row]
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
def has_valid_redis_address(host: str, port: int) -> bool:
  """Report whether a usable Redis host and port were both supplied.

  Args:
    host (str): hostname or IP address of the Redis server.
    port (int): port number of the Redis server.

  Returns:
    `True` when both `host` and `port` are truthy; `False` when either one
    is `None`, an empty string, or zero.
  """
  return bool(host and port)
def cross_join(left: Dict[str, Any], right: Dict[str, Any]) -> beam.Row:
  """Performs a cross join on two `dict` objects.

  Joins the columns of the right row onto the left row. When a column exists
  in both dictionaries, the value from `left` is kept; if the two values
  differ, a warning is logged.

  Neither input dictionary is modified.

  Args:
    left (Dict[str, Any]): input request dictionary.
    right (Dict[str, Any]): response dictionary from the API.

  Returns:
    `beam.Row` containing the merged columns.
  """
  # Work on a shallow copy so the caller's `left` dictionary is left untouched.
  merged = dict(left)
  for k, v in right.items():
    if k not in merged:
      # Don't override the values in left.
      merged[k] = v
    elif merged[k] != v:
      # Pass the values as logging arguments so the message is only formatted
      # when the warning is actually emitted.
      _LOGGER.warning(
          '%s exists in the input row as well the row fetched '
          'from API but have different values - %s and %s. Using the input '
          'value (%s) for the enriched row. You can override this behavior by '
          'passing a custom `join_fn` to Enrichment transform.',
          k,
          merged[k],
          v,
          merged[k])
  return beam.Row(**merged)
class EnrichmentSourceHandler(Caller[InputT, OutputT]):
  """Wrapper class for `apache_beam.io.requestresponse.Caller`.

  Ensure that the implementation of ``__call__`` method returns a tuple
  of `beam.Row` objects.
  """
  def get_cache_key(self, request: InputT) -> str:
    """Returns the request to be cached. This is how the response will be
    looked up in the cache as well.

    Implement this method to provide the key for the cache.
    By default, the entire request is stored as the cache key.

    For example, in `BigTableEnrichmentHandler`, the row key for the element
    is returned here.

    Args:
      request: the element for which a cache key is needed.

    Returns:
      The string ``"request: "`` followed by ``str(request)``.
    """
    # Wrap the operand in a 1-tuple: if `request` is itself a tuple (for
    # example a NamedTuple row), a bare `%` would unpack it into separate
    # format arguments and either raise TypeError or drop the tuple structure.
    return "request: %s" % (request, )
class Enrichment(beam.PTransform[beam.PCollection[InputT],
                                 beam.PCollection[OutputT]]):
  """A :class:`apache_beam.transforms.enrichment.Enrichment` transform to
  enrich elements in a PCollection.

  The transform sends every input element through a
  :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler` via
  :class:`apache_beam.io.requestresponse.RequestResponseIO`, then combines
  each (request, response) pair into one output element using `join_fn`.

  Both the request and the response must expose ``_asdict()`` (for example
  `beam.Row`), because their dictionaries are what `join_fn` receives.

  Args:
    source_handler: Handles source lookup and metadata retrieval.
      Implements the
      :class:`apache_beam.transforms.enrichment.EnrichmentSourceHandler`
    join_fn: A callable that joins the original element with the lookup
      metadata. Defaults to `cross_join`.
    timeout: (Optional) timeout for source requests. Defaults to
      `DEFAULT_TIMEOUT_SECS`.
    repeater: provides method to repeat failed requests to API due to service
      errors. Defaults to an instance of
      :class:`apache_beam.io.requestresponse.ExponentialBackOffRepeater`.
    throttler: provides methods to pre-throttle a request. Defaults to an
      instance of :class:`apache_beam.io.requestresponse.DefaultThrottler`.
  """
  def __init__(
      self,
      source_handler: EnrichmentSourceHandler,
      join_fn: JoinFn = cross_join,
      timeout: Optional[float] = DEFAULT_TIMEOUT_SECS,
      repeater: Repeater = ExponentialBackOffRepeater(),
      throttler: PreCallThrottler = DefaultThrottler()):
    # NOTE: the default `repeater` and `throttler` objects are created once,
    # when this method is defined, and are therefore shared by every
    # Enrichment instance that relies on the defaults.
    self._source_handler = source_handler
    self._join_fn = join_fn
    self._timeout = timeout
    self._repeater = repeater
    self._throttler = throttler
    # No cache unless `with_redis_cache` is called.
    self._cache = None

  def expand(self,
             input_row: beam.PCollection[InputT]) -> beam.PCollection[OutputT]:
    """Applies the source handler to `input_row` and joins the results.

    Args:
      input_row: the PCollection of elements to enrich.

    Returns:
      A PCollection holding the result of `join_fn` for every
      (request, response) pair.
    """
    cache = self._cache
    if cache:
      # For caching with enrichment transform, enrichment handlers provide a
      # get_cache_key() method that returns a unique string formatted
      # request for that row, so requests are always encoded as UTF-8
      # strings. This replaces any request coder set on the cache earlier.
      cache.request_coder = coders.StrUtf8Coder()

    request_response = RequestResponseIO(
        caller=self._source_handler,
        timeout=self._timeout,
        repeater=self._repeater,
        cache=cache,
        throttler=self._throttler)
    responses = input_row | request_response

    # EnrichmentSourceHandler returns a tuple of (request,response).
    return responses | "enrichment_join" >> beam.Map(
        lambda pair: self._join_fn(pair[0]._asdict(), pair[1]._asdict()))

  def with_redis_cache(
      self,
      host: str,
      port: int,
      time_to_live: Union[int, timedelta] = DEFAULT_CACHE_ENTRY_TTL_SEC,
      *,
      request_coder: Optional[coders.Coder] = None,
      response_coder: Optional[coders.Coder] = None,
      **kwargs,
  ):
    """Configure the Redis cache to use with enrichment transform.

    If `host` or `port` is missing (falsy), no cache is configured and the
    transform is returned unchanged.

    Args:
      host (str): The hostname or IP address of the Redis server.
      port (int): The port number of the Redis server.
      time_to_live: `(Union[int, timedelta])` The time-to-live (TTL) for
        records stored in Redis. Provide an integer (in seconds) or a
        `datetime.timedelta` object.
      request_coder: (Optional[`coders.Coder`]) coder for requests stored
        in Redis. Note that `expand` sets the cache's request coder to
        `StrUtf8Coder` when the transform is applied.
      response_coder: (Optional[`coders.Coder`]) coder for decoding responses
        received from Redis.
      kwargs: Optional additional keyword arguments that
        are required to connect to your redis server. Same as `redis.Redis()`.

    Returns:
      This transform, so that calls can be chained.
    """
    if not has_valid_redis_address(host, port):
      return self

    self._cache = RedisCache(  # type: ignore[assignment]
        host=host,
        port=port,
        time_to_live=time_to_live,
        request_coder=request_coder,
        response_coder=response_coder,
        **kwargs)
    return self