Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Spark Dataset runner] Avoid copying outputs for most cases in ParDo translation #25624

Merged
merged 1 commit into from
Mar 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand All @@ -49,6 +50,7 @@
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
Expand Down Expand Up @@ -122,12 +124,10 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
MetricsAccumulator metrics = MetricsAccumulator.getInstance(cxt.getSparkSession());

TupleTag<OutputT> mainOut = transform.getMainOutputTag();
// Filter out unconsumed PCollections (except mainOut) to potentially avoid the costs of caching
// if not really beneficial.

// Filter out obsolete PCollections to only cache when absolutely necessary
Map<TupleTag<?>, PCollection<?>> outputs =
Maps.filterEntries(
cxt.getOutputs(),
e -> e != null && (e.getKey().equals(mainOut) || !cxt.isLeaf(e.getValue())));
skipObsoleteOutputs(cxt.getOutputs(), mainOut, transform.getAdditionalOutputTags(), cxt);

if (outputs.size() > 1) {
// In case of multiple outputs / tags, map each tag to a column by index.
Expand Down Expand Up @@ -202,6 +202,36 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
}
}

/**
 * Removes outputs that {@link Context#isLeaf} reports as leaves (i.e. unused outputs), while
 * always retaining the entry for {@code mainTag}.
 *
 * <p>When only {@code mainTag} remains, the caller sees a single output, which can help to avoid
 * unnecessary caching for transforms that declare multiple outputs.
 *
 * @param outputs all outputs of the transform, keyed by tag
 * @param mainTag tag of the main output; its entry is never dropped
 * @param otherTags additional output tags; only the first one is consulted, and only when {@code
 *     outputs} holds exactly two entries
 * @param cxt translation context used to test whether an output is a leaf
 * @return {@code outputs} itself if nothing is dropped in the one- or two-output case, otherwise
 *     a new map containing only the retained entries
 */
private Map<TupleTag<?>, PCollection<?>> skipObsoleteOutputs(
    Map<TupleTag<?>, PCollection<?>> outputs,
    TupleTag<?> mainTag,
    TupleTagList otherTags,
    Context cxt) {
  final int count = outputs.size();

  if (count == 1) {
    // A single output: always keep main output.
    return outputs;
  }

  if (count == 2) {
    // Main output plus one additional output: no need to build a filtered copy, the result is
    // either the input map as is or a singleton holding just the main output.
    PCollection<?> other = checkStateNotNull(outputs.get(otherTags.get(0)));
    if (!cxt.isLeaf(other)) {
      return outputs;
    }
    return Collections.singletonMap(mainTag, checkStateNotNull(outputs.get(mainTag)));
  }

  // General case: copy over the main output and every output that is not a leaf.
  Map<TupleTag<?>, PCollection<?>> kept = Maps.newHashMapWithExpectedSize(count);
  for (Map.Entry<TupleTag<?>, PCollection<?>> entry : outputs.entrySet()) {
    boolean obsolete = !entry.getKey().equals(mainTag) && cxt.isLeaf(entry.getValue());
    if (obsolete) {
      continue;
    }
    kept.put(entry.getKey(), entry.getValue());
  }
  return kept;
}

/**
 * Creates a function that selects the values belonging to a single column index.
 *
 * <p>For a tuple whose first component equals {@code idx}, the function returns {@code listOf}
 * applied to the tuple's second component; for any other tuple it returns {@code emptyList()}.
 *
 * @param idx the column index to select
 */
static <T> Fun1<Tuple2<Integer, T>, TraversableOnce<T>> selectByColumnIdx(int idx) {
  return indexed -> {
    if (idx != indexed._1) {
      // Value belongs to a different column, emit nothing.
      return emptyList();
    }
    return listOf(indexed._2);
  };
}
Expand Down