Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

poc: python translation phase that replaces GBK+CombineValue pairs with CombinePerKey #32592

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import copy
import functools
import itertools
import json
import logging
import operator
from builtins import object
Expand Down Expand Up @@ -71,6 +72,7 @@
KNOWN_COMPOSITES = frozenset([
common_urns.primitives.GROUP_BY_KEY.urn,
common_urns.composites.COMBINE_PER_KEY.urn,
common_urns.combine_components.COMBINE_GROUPED_VALUES.urn,
common_urns.primitives.PAR_DO.urn, # After SDF expansion.
])

Expand Down Expand Up @@ -614,6 +616,7 @@ def parents_map(self):
}



def leaf_transform_stages(
root_ids, # type: Iterable[str]
components, # type: beam_runner_api_pb2.Components
Expand Down Expand Up @@ -795,6 +798,7 @@ def standard_optimize_phases():
annotate_downstream_side_inputs,
annotate_stateful_dofns_as_roots,
fix_side_input_pcoll_coders,
replace_gbk_combinevalue_pairs,
pack_combiners,
lift_combiners,
expand_sdf,
Expand Down Expand Up @@ -1024,6 +1028,87 @@ def get_stage_key(stage):
_DEFAULT_PACK_COMBINERS_LIMIT = 128


def replace_gbk_combinevalue_pairs(stages, context):
  # type: (Iterable[Stage], TransformContext) -> Iterator[Stage]

  """Replaces GroupByKey + CombineValues pairs with a single CombinePerKey.

  The replacement is only done if the GBK's output PCollection is consumed
  *exclusively* by CombineValues transforms.  If the GBK's output is consumed
  by any other transform, the GBK is left untouched.

  Each merged stage carries a 'pretranslated_xforms' annotation holding the
  JSON-encoded unique names of the two source transforms, so later phases
  (and tests) can recover the provenance of the fusion.
  """
  # Materialize up front: `stages` may be a generator and we iterate it twice
  # (once to index consumers, once to rewrite).
  stages = list(stages)

  # Map each PCollection id to the stages that consume it.
  consumers_by_pcoll = collections.defaultdict(
      list)  # type: DefaultDict[str, List[Stage]]
  for stage in stages:
    for transform in stage.transforms:
      for pcoll_id in transform.inputs.values():
        consumers_by_pcoll[pcoll_id].append(stage)

  def _label(transform):
    # Last path segment of a transform's unique name ('' for unnamed
    # transforms); rsplit with [-1] covers both 'a/b' and plain 'b'.
    if not transform.unique_name:
      return ''
    return transform.unique_name.rsplit('/', 1)[-1]

  processed_stages_by_name = set()
  for stage in stages:
    transform = only_element(stage.transforms)
    if transform.unique_name in processed_stages_by_name:
      # Already folded into a CombinePerKey stage emitted earlier.
      continue
    if transform.spec.urn != common_urns.primitives.GROUP_BY_KEY.urn:
      yield stage
      continue
    consumer_transforms = [
        only_transform(consumer.transforms)
        for consumer in consumers_by_pcoll[only_element(
            transform.outputs.values())]
    ]
    if not consumer_transforms or not all(
        consumer.spec.urn ==
        common_urns.combine_components.COMBINE_GROUPED_VALUES.urn
        for consumer in consumer_transforms):
      # GBK output is unconsumed, or consumed by something other than
      # CombineValues: keep the GBK as-is.
      yield stage
      continue
    for consumer in consumer_transforms:
      # Replace GroupByKey + CombineValues with CombinePerKey.  The merged
      # stage's name joins the two source labels, e.g. "GBK+CombineValues".
      name = '%s+%s' % (_label(transform), _label(consumer))
      unique_name = consumer.unique_name.rsplit('/', 1)[0] + '/' + name
      encoded_source_xforms_anno = json.dumps(
          [transform.unique_name, consumer.unique_name]).encode('utf-8')
      # NOTE(review): must_follow is taken from the GBK stage only; if the
      # CombineValues stage had its own must_follow it is dropped — confirm
      # this is acceptable for this POC.
      merged_stage = Stage(
          unique_name,
          [
              beam_runner_api_pb2.PTransform(
                  unique_name=unique_name,
                  inputs={'input': only_element(transform.inputs.values())},
                  spec=beam_runner_api_pb2.FunctionSpec(
                      urn=common_urns.composites.COMBINE_PER_KEY.urn),
                  annotations={
                      'pretranslated_xforms': encoded_source_xforms_anno
                  },
                  environment_id=transform.environment_id),
          ],
          downstream_side_inputs=frozenset(),
          must_follow=stage.must_follow)
      # The fused transform emits directly to the CombineValues outputs.
      merged_stage.transforms[0].outputs.MergeFrom(consumer.outputs)
      processed_stages_by_name.add(consumer.unique_name)
      yield merged_stage
    processed_stages_by_name.add(transform.unique_name)


def pack_per_key_combiners(stages, context, can_pack=lambda s: True):
# type: (Iterable[Stage], TransformContext, Callable[[str], Union[bool, int]]) -> Iterator[Stage]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
#
# pytype: skip-file

import json
import logging
import unittest

Expand Down Expand Up @@ -57,6 +57,27 @@ def expand(self, pcoll):
self.assertEqual(len(key_with_none_stages), 1)
self.assertIn('multiple-key-with-none', key_with_none_stages[0].parent)

def test_replace_gbk_combinevalue_pairs(self):

pipeline = beam.Pipeline()
kvs = [('a', 1), ('a', 2), ('b', 3), ('b', 4)]
_ = pipeline | Create(
kvs, reshuffle=False
) | "MyGBK" >> beam.GroupByKey() | "MyCombiner" >> beam.CombineValues(sum)
pipeline_proto = pipeline.to_runner_api()
_, stages = translations.create_and_optimize_stages(
pipeline_proto, [translations.replace_gbk_combinevalue_pairs],
known_runner_urns=frozenset())
combine_per_key_stages = []
for stage in stages:
for transform in stage.transforms:
if transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn:
combine_per_key_stages.append(stage)
self.assertEqual(len(combine_per_key_stages), 1)
stage = combine_per_key_stages[0]
anno = stage.transforms[0].annotations["pretranslated_xforms"]
assert json.loads(anno.decode('utf-8')) == ["MyGBK", "MyCombiner"]

def test_pack_combiners(self):
class MultipleCombines(beam.PTransform):
def annotations(self):
Expand Down
Loading