Skip to content

Commit

Permalink
[Spark Dataset runner] Avoid copying outputs for most cases in ParDo …
Browse files Browse the repository at this point in the history
…translation (related to #24711) (#25624)
  • Loading branch information
Moritz Mack authored Mar 6, 2023
1 parent 4da6025 commit 5a2b93d
Showing 1 changed file with 35 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand All @@ -48,6 +49,7 @@
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.Dataset;
Expand Down Expand Up @@ -113,12 +115,10 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
MetricsAccumulator metrics = MetricsAccumulator.getInstance(cxt.getSparkSession());

TupleTag<OutputT> mainOut = transform.getMainOutputTag();
// Filter out unconsumed PCollections (except mainOut) to potentially avoid the costs of caching
// if not really beneficial.

// Filter out obsolete PCollections to only cache when absolutely necessary
Map<TupleTag<?>, PCollection<?>> outputs =
Maps.filterEntries(
cxt.getOutputs(),
e -> e != null && (e.getKey().equals(mainOut) || !cxt.isLeaf(e.getValue())));
skipObsoleteOutputs(cxt.getOutputs(), mainOut, transform.getAdditionalOutputTags(), cxt);

if (outputs.size() > 1) {
// In case of multiple outputs / tags, map each tag to a column by index.
Expand Down Expand Up @@ -169,6 +169,36 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
}
}

/**
 * Drops additional outputs that are unused, while always retaining the output of {@code mainTag}.
 *
 * <p>If only {@code mainTag} is left after filtering, the translation no longer has to deal with
 * multiple outputs, which can help to avoid unnecessary caching.
 *
 * @param outputs all outputs of the transform, keyed by tag
 * @param mainTag tag of the main output; its entry is never filtered out
 * @param otherTags additional output tags; only consulted if there are exactly two outputs
 * @param cxt translation context used to test whether an output is a leaf
 * @return {@code outputs} itself if there is a single output, or two outputs and the additional
 *     one is not a leaf; otherwise a new map containing only the retained outputs
 */
private Map<TupleTag<?>, PCollection<?>> skipObsoleteOutputs(
    Map<TupleTag<?>, PCollection<?>> outputs,
    TupleTag<?> mainTag,
    TupleTagList otherTags,
    Context cxt) {
  final int numOutputs = outputs.size();

  // Single output: this is the main output, which is always kept.
  if (numOutputs == 1) {
    return outputs;
  }

  // Two outputs: main plus exactly one additional output, so no copy is needed if both are kept.
  if (numOutputs == 2) {
    PCollection<?> additional = checkStateNotNull(outputs.get(otherTags.get(0)));
    if (!cxt.isLeaf(additional)) {
      return outputs;
    }
    return Collections.singletonMap(mainTag, checkStateNotNull(outputs.get(mainTag)));
  }

  // General case: copy over the main output and every output that is not a leaf.
  Map<TupleTag<?>, PCollection<?>> retained = Maps.newHashMapWithExpectedSize(numOutputs);
  for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
    TupleTag<?> tag = output.getKey();
    if (!tag.equals(mainTag) && cxt.isLeaf(output.getValue())) {
      continue; // unused additional output
    }
    retained.put(tag, output.getValue());
  }
  return retained;
}

/**
 * Returns a function that selects the values belonging to column {@code idx}.
 *
 * <p>For a tuple whose first element equals {@code idx}, the function returns the tuple's second
 * element wrapped using {@code listOf}; for any other tuple it returns {@code emptyList()}.
 *
 * @param idx column index to select
 */
static <T> Fun1<Tuple2<Integer, T>, TraversableOnce<T>> selectByColumnIdx(int idx) {
  return tuple -> {
    if (idx == tuple._1) {
      return listOf(tuple._2);
    }
    return emptyList();
  };
}
Expand Down

0 comments on commit 5a2b93d

Please sign in to comment.