From d4b1dab49f8ccc1b27048d4dbeb3cfc2ac63ffd1 Mon Sep 17 00:00:00 2001 From: Moritz Mack Date: Fri, 27 Jan 2023 10:53:49 +0100 Subject: [PATCH] [Spark Dataset runner] Avoid copying outputs for most cases in ParDo translation (related to #24711) --- .../batch/ParDoTranslatorBatch.java | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java index 40d26be8a8b1..f6d1c8ff85ed 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -49,6 +50,7 @@ import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.rdd.RDD; @@ -122,12 +124,10 @@ public void translate(ParDo.MultiOutput transform, Context cxt) MetricsAccumulator metrics = MetricsAccumulator.getInstance(cxt.getSparkSession()); TupleTag mainOut = transform.getMainOutputTag(); - // Filter out unconsumed PCollections (except mainOut) to potentially avoid the costs of caching - // if not really beneficial. + + // Filter out obsolete PCollections to only cache when absolutely necessary Map, PCollection> outputs = - Maps.filterEntries( - cxt.getOutputs(), - e -> e != null && (e.getKey().equals(mainOut) || !cxt.isLeaf(e.getValue()))); + skipObsoleteOutputs(cxt.getOutputs(), mainOut, transform.getAdditionalOutputTags(), cxt); if (outputs.size() > 1) { // In case of multiple outputs / tags, map each tag to a column by index. @@ -202,6 +202,36 @@ public void translate(ParDo.MultiOutput transform, Context cxt) } } + /** + * Filter out obsolete, unused output tags except for {@code mainTag}. + * + *

This can help to avoid unnecessary caching in case of multiple outputs if only {@code + * mainTag} is consumed. + */ + private Map, PCollection> skipObsoleteOutputs( + Map, PCollection> outputs, + TupleTag mainTag, + TupleTagList otherTags, + Context cxt) { + switch (outputs.size()) { + case 1: + return outputs; // always keep main output + case 2: + TupleTag otherTag = otherTags.get(0); + return cxt.isLeaf(checkStateNotNull(outputs.get(otherTag))) + ? Collections.singletonMap(mainTag, checkStateNotNull(outputs.get(mainTag))) + : outputs; + default: + Map, PCollection> filtered = Maps.newHashMapWithExpectedSize(outputs.size()); + for (Map.Entry, PCollection> e : outputs.entrySet()) { + if (e.getKey().equals(mainTag) || !cxt.isLeaf(e.getValue())) { + filtered.put(e.getKey(), e.getValue()); + } + } + return filtered; + } + } + static Fun1, TraversableOnce> selectByColumnIdx(int idx) { return t -> idx == t._1 ? listOf(t._2) : emptyList(); }