From 5cb9635edbfffb4661bfd806506abd64a40f3645 Mon Sep 17 00:00:00 2001 From: Sam R Date: Wed, 29 Sep 2021 13:01:03 -0700 Subject: [PATCH] Fix BEAM-12984 --- .../runners/interactive/pipeline_fragment.py | 12 +++++++ .../interactive/pipeline_fragment_test.py | 35 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py b/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py index 7564a765ad65..84fdc9ad24fa 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_fragment.py @@ -190,6 +190,18 @@ def _mark_necessary_transforms_and_pcolls(self, runner_pcolls_to_user_pcolls): break # Mark the AppliedPTransform as necessary. necessary_transforms.add(producer) + + # Also mark composites that are not the root transform. If the root + # transform is added, then all transforms are incorrectly marked as + # necessary. If composites are not handled, then there will be + # orphaned PCollections. + if producer.parent is not None: + necessary_transforms.update(producer.parts) + + # This will recursively add all the PCollections in this composite. + for part in producer.parts: + updated_all_inputs.update(part.outputs.values()) + # Record all necessary input and side input PCollections. updated_all_inputs.update(producer.inputs) # pylint: disable=map-builtin-not-iterating diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py index 6e9d327e3021..c6d28f6e7189 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_fragment_test.py @@ -129,6 +129,41 @@ def test_fragment_does_not_prune_teststream(self): # resulting graph is invalid and the following call will raise an exception. fragment.to_runner_api() + @patch('IPython.get_ipython', new_callable=mock_get_ipython) + def test_pipeline_composites(self, cell): + """Tests that composites are supported. + """ + with cell: # Cell 1 + p = beam.Pipeline(ir.InteractiveRunner()) + ib.watch({'p': p}) + + with cell: # Cell 2 + # pylint: disable=range-builtin-not-iterating + init = p | 'Init' >> beam.Create(range(5)) + + with cell: # Cell 3 + # Have a composite within a composite to test that all transforms under a + # composite are added. + + @beam.ptransform_fn + def Bar(pcoll): + return pcoll | beam.Map(lambda n: n) + + @beam.ptransform_fn + def Foo(pcoll): + p1 = pcoll | beam.Map(lambda n: n) + p2 = pcoll | beam.Map(str) + bar = pcoll | Bar() + return {'pc1': p1, 'pc2': p2, 'bar': bar} + + res = init | Foo() + + ib.watch(locals()) + pc = res['pc1'] + + result = pf.PipelineFragment([pc]).run() + self.assertEqual([0, 1, 2, 3, 4], list(result.get(pc))) + if __name__ == '__main__': unittest.main()