Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
rohdesamuel committed Sep 29, 2021
1 parent e10af1b commit 5cb9635
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
12 changes: 12 additions & 0 deletions sdks/python/apache_beam/runners/interactive/pipeline_fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ def _mark_necessary_transforms_and_pcolls(self, runner_pcolls_to_user_pcolls):
break
# Mark the AppliedPTransform as necessary.
necessary_transforms.add(producer)

# Also mark composites that are not the root transform. If the root
# transform is added, then all transforms are incorrectly marked as
# necessary. If composites are not handled, then there will be
# orphaned PCollections.
if producer.parent is not None:
necessary_transforms.update(producer.parts)

# This will recursively add all the PCollections in this composite.
for part in producer.parts:
updated_all_inputs.update(part.outputs.values())

# Record all necessary input and side input PCollections.
updated_all_inputs.update(producer.inputs)
# pylint: disable=map-builtin-not-iterating
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,41 @@ def test_fragment_does_not_prune_teststream(self):
# resulting graph is invalid and the following call will raise an exception.
fragment.to_runner_api()

@patch('IPython.get_ipython', new_callable=mock_get_ipython)
def test_pipeline_composites(self, cell):
  """Tests that composites are supported.

  Applies a composite transform (``Foo``) that itself contains another
  composite (``Bar``), builds a ``PipelineFragment`` from only one of the
  PCollections returned by ``Foo``, runs it, and checks the elements read
  back for that PCollection.

  Args:
    cell: The mock produced by patching ``IPython.get_ipython``. It is used
      as a context manager around each group of statements below
      (presumably each ``with cell:`` stands for one executed notebook
      cell -- confirm against ``mock_get_ipython``).
  """
  with cell:  # Cell 1
    # Create the pipeline on the interactive runner and watch it as 'p'.
    p = beam.Pipeline(ir.InteractiveRunner())
    ib.watch({'p': p})

  with cell:  # Cell 2
    # pylint: disable=range-builtin-not-iterating
    init = p | 'Init' >> beam.Create(range(5))

  with cell:  # Cell 3
    # Have a composite within a composite to test that all transforms under a
    # composite are added.

    @beam.ptransform_fn
    def Bar(pcoll):
      # Inner composite: a single identity Map.
      return pcoll | beam.Map(lambda n: n)

    @beam.ptransform_fn
    def Foo(pcoll):
      # Outer composite: two Maps plus the nested composite, all applied to
      # the same input, returned together as a dict of PCollections.
      p1 = pcoll | beam.Map(lambda n: n)
      p2 = pcoll | beam.Map(str)
      bar = pcoll | Bar()
      return {'pc1': p1, 'pc2': p2, 'bar': bar}

    res = init | Foo()

    # Watch every local bound so far in this test method.
    ib.watch(locals())
    pc = res['pc1']

  # Only 'pc1' is handed to the fragment; 'pc2' and 'bar' are not requested.
  result = pf.PipelineFragment([pc]).run()
  # 'pc1' comes from the identity Map, so it should hold the created range.
  self.assertEqual([0, 1, 2, 3, 4], list(result.get(pc)))


# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
  unittest.main()

0 comments on commit 5cb9635

Please sign in to comment.