diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
index 8d751d5d8173..05f542702f19 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
@@ -152,6 +152,8 @@ public String name() {
   public interface TranslationState extends EncoderProvider {
     <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
 
+    boolean isLeave(PCollection<?> pCollection);
+
     <T> void putDataset(
         PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean cache);
 
@@ -256,6 +258,11 @@ public void putDataset(
       }
     }
 
+    @Override
+    public boolean isLeave(PCollection<?> pCollection) {
+      return getResult(pCollection).dependentTransforms.isEmpty();
+    }
+
     @Override
     public <T> Broadcast<SideInputValues<T>> getSideInputBroadcast(
         PCollection<T> pCollection, SideInputValues.Loader<T> loader) {
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
index 8a3c7579f541..e0bbb2af820e 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
@@ -149,6 +149,11 @@ public void putDataset(
       state.putDataset(pCollection, dataset, cache);
     }
 
+    @Override
+    public boolean isLeave(PCollection<?> pCollection) {
+      return state.isLeave(pCollection);
+    }
+
     @Override
     public Supplier<PipelineOptions> getOptionsSupplier() {
       return state.getOptionsSupplier();
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
index c760efd229c8..64a4f591ff74 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
@@ -19,7 +19,6 @@
 import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator;
 import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
-import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 
 import java.io.Serializable;
 import java.util.ArrayDeque;
@@ -61,17 +60,17 @@
  */
 abstract class DoFnPartitionIteratorFactory<InT, FnOutT, OutT>
     implements Function1<Iterator<WindowedValue<InT>>, Iterator<OutT>>, Serializable {
-  private final String stepName;
-  private final DoFn<InT, FnOutT> doFn;
-  private final DoFnSchemaInformation doFnSchema;
-  private final Supplier<PipelineOptions> options;
-  private final Coder<InT> coder;
-  private final WindowingStrategy<?, ?> windowingStrategy;
-  private final TupleTag<FnOutT> mainOutput;
-  private final List<TupleTag<?>> additionalOutputs;
-  private final Map<TupleTag<?>, Coder<?>> outputCoders;
-  private final Map<String, PCollectionView<?>> sideInputs;
-  private final SideInputReader sideInputReader;
+  protected final String stepName;
+  protected final DoFn<InT, FnOutT> doFn;
+  protected final DoFnSchemaInformation doFnSchema;
+  protected final Supplier<PipelineOptions> options;
+  protected final Coder<InT> coder;
+  protected final WindowingStrategy<?, ?> windowingStrategy;
+  protected final TupleTag<FnOutT> mainOutput;
+  protected final List<TupleTag<?>> additionalOutputs;
+  protected final Map<TupleTag<?>, Coder<?>> outputCoders;
+  protected final Map<String, PCollectionView<?>> sideInputs;
+  protected final SideInputReader sideInputReader;
 
   private DoFnPartitionIteratorFactory(
       AppliedPTransform<PCollection<InT>, ?, MultiOutput<InT, FnOutT>> appliedPT,
@@ -147,7 +146,11 @@ DoFnRunners.OutputManager outputManager(Deque<WindowedValue<OutT>> buffer) {
       return new DoFnRunners.OutputManager() {
         @Override
         public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
-          buffer.add((WindowedValue<OutT>) output);
+          // SingleOut only ever emits the main output. There might, however, be additional
+          // outputs, which are skipped here if unused to avoid caching them.
+          if (mainOutput.equals(tag)) {
+            buffer.add((WindowedValue<OutT>) output);
+          }
         }
       };
     }
@@ -177,8 +180,11 @@ DoFnRunners.OutputManager outputManager(Deque<Tuple2<Integer, WindowedValue<Object>>> buffer) {
       return new DoFnRunners.OutputManager() {
         @Override
         public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
-          Integer columnIdx = checkStateNotNull(tagColIdx.get(tag.getId()), "Unknown tag %s", tag);
-          buffer.add(tuple(columnIdx, (WindowedValue<Object>) output));
+          // Additional unused outputs can be skipped here; in that case columnIdx is null.
+          Integer columnIdx = tagColIdx.get(tag.getId());
+          if (columnIdx != null) {
+            buffer.add(tuple(columnIdx, (WindowedValue<Object>) output));
+          }
         }
       };
     }
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index 3083ff5101b9..4d545e438133 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -110,12 +110,19 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
       throws IOException {
 
     PCollection<InputT> input = (PCollection<InputT>) cxt.getInput();
-    Map<TupleTag<?>, PCollection<?>> outputs = cxt.getOutputs();
     Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input);
 
     SideInputReader sideInputReader =
        createSideInputReader(transform.getSideInputs().values(), cxt);
 
+    TupleTag<OutputT> mainOut = transform.getMainOutputTag();
+    // Filter out unconsumed PCollections (except mainOut) to avoid the cost of caching
+    // where it isn't actually beneficial.
+    Map<TupleTag<?>, PCollection<?>> outputs =
+        Maps.filterEntries(
+            cxt.getOutputs(),
+            e -> e != null && (e.getKey().equals(mainOut) || !cxt.isLeave(e.getValue())));
+
     if (outputs.size() > 1) {
       // In case of multiple outputs / tags, map each tag to a column by index.
       // At the end split the result into multiple datasets selecting one column each.
@@ -176,7 +183,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
         }
       }
     } else {
-      PCollection<OutputT> output = cxt.getOutput(transform.getMainOutputTag());
+      PCollection<OutputT> output = cxt.getOutput(mainOut);
       DoFnPartitionIteratorFactory<InputT, ?, WindowedValue<OutputT>> doFnMapper =
           DoFnPartitionIteratorFactory.singleOutput(
               cxt.getCurrentTransform(), cxt.getOptionsSupplier(), input, sideInputReader);
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
index 33eef26dddda..278fd012d77e 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/SparkSessionRule.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.spark.structuredstreaming;
 
 import static java.util.stream.Collectors.toMap;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
 
 import java.io.Serializable;
 import java.util.Arrays;
@@ -29,6 +30,7 @@
 import org.apache.beam.sdk.values.KV;
 import org.apache.spark.sql.SparkSession;
 import org.junit.rules.ExternalResource;
+import org.junit.rules.TestRule;
 import org.junit.runner.Description;
 import org.junit.runners.model.Statement;
 
@@ -69,6 +71,24 @@ public PipelineOptions configure(PipelineOptions options) {
     return opts;
   }
 
+  /** {@code true} if the session contains cached Datasets or RDDs. */
+  public boolean hasCachedData() {
+    return !session.sharedState().cacheManager().isEmpty()
+        || !session.sparkContext().getPersistentRDDs().isEmpty();
+  }
+
+  public TestRule clearCache() {
+    return new ExternalResource() {
+      @Override
+      protected void after() {
+        // clear cached datasets
+        session.sharedState().cacheManager().clearCache();
+        // clear cached RDDs
+        session.sparkContext().getPersistentRDDs().foreach(fun1(t -> t._2.unpersist(true)));
+      }
+    };
+  }
+
   @Override
   public Statement apply(Statement base, Description description) {
     builder.appName(description.getDisplayName());
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
index f319173ed2bb..672a2db4fe1e 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
@@ -17,14 +17,13 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 
+import static org.junit.Assert.assertTrue;
+
 import java.io.Serializable;
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
@@ -37,23 +36,23 @@
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
+import org.junit.ClassRule;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TestRule;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Test class for beam to spark {@link ParDo} translation. */
 @RunWith(JUnit4.class)
 public class ParDoTest implements Serializable {
-  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
-
-  private static PipelineOptions testOptions() {
-    SparkStructuredStreamingPipelineOptions options =
-        PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
-    options.setRunner(SparkStructuredStreamingRunner.class);
-    options.setTestMode(true);
-    return options;
-  }
+  @ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule();
+
+  @Rule
+  public transient TestPipeline pipeline =
+      TestPipeline.fromOptions(SESSION.createPipelineOptions());
+
+  @Rule public transient TestRule clearCache = SESSION.clearCache();
 
   @Test
   public void testPardo() {
@@ -61,32 +60,42 @@ public void testPardo() {
     PCollection<Integer> input =
         pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(ParDo.of(PLUS_ONE_DOFN));
     PAssert.that(input).containsInAnyOrder(2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   @Test
   public void testPardoWithOutputTagsCachedRDD() {
-    pardoWithOutputTags("MEMORY_ONLY");
+    pardoWithOutputTags("MEMORY_ONLY", true);
+    assertTrue("Expected cached data", SESSION.hasCachedData());
   }
 
   @Test
   public void testPardoWithOutputTagsCachedDataset() {
-    pardoWithOutputTags("MEMORY_AND_DISK");
+    pardoWithOutputTags("MEMORY_AND_DISK", true);
+    assertTrue("Expected cached data", SESSION.hasCachedData());
+  }
+
+  @Test
+  public void testPardoWithUnusedOutputTags() {
+    pardoWithOutputTags("MEMORY_AND_DISK", false);
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
-  private void pardoWithOutputTags(String storageLevel) {
+  private void pardoWithOutputTags(String storageLevel, boolean evaluateAdditionalOutputs) {
     pipeline.getOptions().as(SparkCommonPipelineOptions.class).setStorageLevel(storageLevel);
 
-    TupleTag<Integer> even = new TupleTag<Integer>() {};
-    TupleTag<String> unevenAsString = new TupleTag<String>() {};
+    TupleTag<Integer> mainTag = new TupleTag<Integer>() {};
+    TupleTag<String> additionalUnevenTag = new TupleTag<String>() {};
 
     DoFn<Integer, Integer> doFn =
         new DoFn<Integer, Integer>() {
          @ProcessElement
          public void processElement(@Element Integer i, MultiOutputReceiver out) {
            if (i % 2 == 0) {
-              out.get(even).output(i);
+              out.get(mainTag).output(i);
            } else {
-              out.get(unevenAsString).output(i.toString());
+              out.get(additionalUnevenTag).output(i.toString());
            }
          }
        };
@@ -94,10 +103,12 @@ public void processElement(@Element Integer i, MultiOutputReceiver out) {
     PCollectionTuple outputs =
         pipeline
             .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
-            .apply(ParDo.of(doFn).withOutputTags(even, TupleTagList.of(unevenAsString)));
+            .apply(ParDo.of(doFn).withOutputTags(mainTag, TupleTagList.of(additionalUnevenTag)));
 
-    PAssert.that(outputs.get(even)).containsInAnyOrder(2, 4, 6, 8, 10);
-    PAssert.that(outputs.get(unevenAsString)).containsInAnyOrder("1", "3", "5", "7", "9");
+    PAssert.that(outputs.get(mainTag)).containsInAnyOrder(2, 4, 6, 8, 10);
+    if (evaluateAdditionalOutputs) {
+      PAssert.that(outputs.get(additionalUnevenTag)).containsInAnyOrder("1", "3", "5", "7", "9");
+    }
 
     pipeline.run();
   }
 
@@ -106,10 +117,12 @@ public void testTwoPardoInRow() {
     PCollection<Integer> input =
         pipeline
             .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
-            .apply(ParDo.of(PLUS_ONE_DOFN))
-            .apply(ParDo.of(PLUS_ONE_DOFN));
+            .apply("Plus 1 (1st)", ParDo.of(PLUS_ONE_DOFN))
+            .apply("Plus 1 (2nd)", ParDo.of(PLUS_ONE_DOFN));
     PAssert.that(input).containsInAnyOrder(3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   @Test
@@ -133,6 +146,8 @@ public void processElement(ProcessContext c) {
                 .withSideInputs(sideInputView));
     PAssert.that(input).containsInAnyOrder(4, 5, 6, 7, 8, 9, 10);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   @Test
@@ -158,6 +173,8 @@ public void processElement(ProcessContext c) {
 
     PAssert.that(input).containsInAnyOrder(2, 3, 4, 5, 6, 7, 8, 9, 10);
     pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   @Test
@@ -183,6 +200,8 @@ public void processElement(ProcessContext c) {
                 .withSideInputs(sideInputView));
     PAssert.that(input).containsInAnyOrder(3, 4, 5, 6, 7, 8, 9, 10);
    pipeline.run();
+
+    assertTrue("No usage of cache expected", !SESSION.hasCachedData());
   }
 
   private static final DoFn<Integer, Integer> PLUS_ONE_DOFN =
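
Reviewer note (not part of the patch): the core idea of this change is the leaf check. An additional output whose PCollection has no dependent transforms is never read downstream, so the translator can drop it before translation and thereby skip caching the multi-output intermediate dataset entirely. The following is a minimal standalone sketch of that filtering rule; the names (isLeave, dependentTransformCounts) are illustrative stand-ins for the runner's TranslationState bookkeeping, not the actual classes touched above.

import java.util.LinkedHashMap;
import java.util.Map;

/** Standalone sketch of the output-filtering rule; names are hypothetical. */
public class UnconsumedOutputsSketch {

  /** Stand-in for TranslationState#isLeave: no downstream transform consumes the output. */
  static boolean isLeave(String pCollection, Map<String, Integer> dependentTransformCounts) {
    return dependentTransformCounts.getOrDefault(pCollection, 0) == 0;
  }

  public static void main(String[] args) {
    // Tag -> output PCollection produced by a multi-output ParDo.
    Map<String, String> outputs = new LinkedHashMap<>();
    outputs.put("mainTag", "pcMain");
    outputs.put("evenTag", "pcEven"); // consumed downstream
    outputs.put("oddTag", "pcOdd");   // never consumed (leaf)

    // Number of transforms depending on each PCollection.
    Map<String, Integer> dependentTransformCounts =
        Map.of("pcMain", 1, "pcEven", 2, "pcOdd", 0);

    String mainTag = "mainTag";
    // Keep the main output unconditionally; drop additional outputs nobody consumes.
    outputs
        .entrySet()
        .removeIf(e -> !e.getKey().equals(mainTag) && isLeave(e.getValue(), dependentTransformCounts));

    // With fewer remaining outputs the translator may end up in the single-output
    // path and, more importantly, never has to cache the intermediate dataset.
    System.out.println(outputs); // {mainTag=pcMain, evenTag=pcEven}
  }
}

This mirrors why the tests above assert !SESSION.hasCachedData() when the additional output tag is never consumed: once the unused output is filtered out, only the main output remains and no caching is triggered.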