-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
Copy pathhelper_transforms.py
120 lines (99 loc) · 4.01 KB
/
helper_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import collections
import itertools
import typing
import apache_beam as beam
from apache_beam import typehints
from apache_beam.internal.util import ArgumentPlaceholder
from apache_beam.transforms.combiners import _CurriedFn
from apache_beam.utils.windowed_value import WindowedValue
class LiftedCombinePerKey(beam.PTransform):
  """CombinePerKey implementation with mapper-side pre-combining.

  Values are partially combined by PartialGroupByKeyCombiningValues, grouped
  with GroupByKey, and the partial results are merged by FinishCombine.
  """
  def __init__(self, combine_fn, args, kwargs):
    """Rejects deferred arguments, then curries the rest into the CombineFn.

    Args:
      combine_fn: the CombineFn to apply; may already be a _CurriedFn.
      args: extra positional arguments to curry into ``combine_fn``.
      kwargs: extra keyword arguments to curry into ``combine_fn``.

    Raises:
      NotImplementedError: if any argument (including those already held by
        a _CurriedFn) is an ArgumentPlaceholder.
    """
    argument_sources = [args, kwargs.values()]
    if isinstance(combine_fn, _CurriedFn):
      # Arguments already bound to a curried fn need the same check.
      argument_sources += [combine_fn.args, combine_fn.kwargs.values()]

    for candidate in itertools.chain.from_iterable(argument_sources):
      if isinstance(candidate, ArgumentPlaceholder):
        # Deferred side inputs are unsupported here (the original note says
        # Dataflow does not implement them either).
        raise NotImplementedError('Deferred CombineFn side inputs.')

    self._combine_fn = beam.transforms.combiners.curry_combine_fn(
        combine_fn, args, kwargs)

  def expand(self, pcoll):
    """Applies partial combine, GroupByKey, then the final merge to pcoll."""
    pre_combined = pcoll | beam.ParDo(
        PartialGroupByKeyCombiningValues(self._combine_fn))
    grouped = pre_combined | beam.GroupByKey()
    return grouped | beam.ParDo(FinishCombine(self._combine_fn))
class PartialGroupByKeyCombiningValues(beam.DoFn):
  """Folds values into an in-memory cache keyed by (key, window).

  Bundles are treated as in-memory-sized, so nothing is emitted until
  finish_bundle.
  """
  def __init__(self, combine_fn):
    """Stores the CombineFn used to build and update accumulators."""
    self._combine_fn = combine_fn

  def setup(self):
    """Forwards setup to the wrapped CombineFn."""
    self._combine_fn.setup()

  def start_bundle(self):
    """Starts a fresh cache whose missing entries are new accumulators."""
    new_accumulator = self._combine_fn.create_accumulator
    self._cache = collections.defaultdict(new_accumulator)

  def process(self, element, window=beam.DoFn.WindowParam):
    """Adds the element's value to the accumulator for its key and window.

    Args:
      element: a (key, value) pair.
      window: the window of the element, used as part of the cache key.
    """
    key, value = element
    cache_key = (key, window)
    current = self._cache[cache_key]
    self._cache[cache_key] = self._combine_fn.add_input(current, value)

  def finish_bundle(self):
    """Yields one windowed (key, compacted accumulator) per cache entry."""
    for cache_key, accumulator in self._cache.items():
      key, window = cache_key
      # Compact before emitting: the GroupByKey that follows has to encode
      # the accumulator.
      compacted = self._combine_fn.compact(accumulator)
      yield WindowedValue((key, compacted), window.end, (window, ))

  def teardown(self):
    """Forwards teardown to the wrapped CombineFn."""
    self._combine_fn.teardown()

  def default_type_hints(self):
    """Derives keyed type hints from the wrapped CombineFn's hints.

    The first declared input type (or Any when none is declared) is wrapped
    as Tuple[K, ...]; the output type is always Tuple[K, Any].
    """
    hints = self._combine_fn.get_type_hints()
    key_type = typehints.TypeVariable('K')
    if not hints.input_types:
      hints = hints.with_input_types(typehints.Tuple[key_type, typing.Any])
    else:
      positional, keyword = hints.input_types
      first, rest = positional[0], positional[1:]
      keyed_args = (typehints.Tuple[key_type, first], ) + rest
      hints = hints.with_input_types(*keyed_args, **keyword)
    return hints.with_output_types(typehints.Tuple[key_type, typing.Any])
class FinishCombine(beam.DoFn):
  """Merges the partially combined accumulators of each key."""
  def __init__(self, combine_fn):
    """Stores the CombineFn used to merge accumulators and extract output."""
    self._combine_fn = combine_fn

  def setup(self):
    """Forwards setup to the wrapped CombineFn."""
    self._combine_fn.setup()

  def process(self, element):
    """Merges a key's accumulators and returns the extracted output.

    Args:
      element: a (key, accumulators) pair.

    Returns:
      A single-element list holding (key, extracted output).
    """
    key, accumulators = element
    merged = self._combine_fn.merge_accumulators(accumulators)
    return [(key, self._combine_fn.extract_output(merged))]

  def teardown(self):
    """Forwards teardown to the wrapped CombineFn."""
    self._combine_fn.teardown()

  def default_type_hints(self):
    """Derives keyed type hints from the wrapped CombineFn's hints.

    The input is hinted as Tuple[K, Any]. When the CombineFn declares output
    types, the output is hinted as Tuple[K, <its simple output type>].
    """
    key_type = typehints.TypeVariable('K')
    hints = self._combine_fn.get_type_hints().with_input_types(
        typehints.Tuple[key_type, typing.Any])
    if not hints.output_types:
      return hints
    combined_type = hints.simple_output_type('')
    return hints.with_output_types(typehints.Tuple[key_type, combined_type])