[Spark Dataset runner] Break lineage of dataset to reduce Spark planning overhead in case of large query plans #25187

Merged · 1 commit · Feb 3, 2023
@@ -49,17 +49,17 @@ interface NamedDataset<T> {
Dataset<WindowedValue<T>> dataset();
}

private final Collection<? extends NamedDataset<?>> leaveDatasets;
private final Collection<? extends NamedDataset<?>> leaves;
private final SparkSession session;

EvaluationContext(Collection<? extends NamedDataset<?>> leaveDatasets, SparkSession session) {
this.leaveDatasets = leaveDatasets;
EvaluationContext(Collection<? extends NamedDataset<?>> leaves, SparkSession session) {
this.leaves = leaves;
this.session = session;
}

/** Trigger evaluation of all leave datasets. */
/** Trigger evaluation of all leaf datasets. */
public void evaluate() {
for (NamedDataset<?> ds : leaveDatasets) {
for (NamedDataset<?> ds : leaves) {
final Dataset<?> dataset = ds.dataset();
if (dataset == null) {
continue;
@@ -21,11 +21,14 @@
import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
import static org.apache.spark.storage.StorageLevel.MEMORY_ONLY;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
@@ -80,6 +83,12 @@
public abstract class PipelineTranslator {
private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);

// Threshold to limit query plan complexity and avoid unnecessary planning overhead. Currently this
// is fairly low; Catalyst won't be able to optimize beyond ParDos anyway. Until there's
// dedicated support for schema transforms, there is little value in allowing more complex plans at
// this point.
private static final int PLAN_COMPLEXITY_THRESHOLD = 6;

public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
}
@@ -129,12 +138,22 @@ public EvaluationContext translate(
*/
private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
private final String name;
private final float complexityFactor;
private float planComplexity = 0;

private @MonotonicNonNull Dataset<WindowedValue<T>> dataset = null;
private @MonotonicNonNull Broadcast<SideInputValues<T>> sideInputBroadcast = null;

// dependent downstream transforms (if empty this is a leaf)
private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
// upstream dependencies (required inputs)
private final List<TranslationResult<?>> dependencies;

private TranslationResult(PCollection<?> pCol) {
private TranslationResult(
PCollection<?> pCol, float complexityFactor, List<TranslationResult<?>> dependencies) {
this.name = pCol.getName();
this.complexityFactor = complexityFactor;
this.dependencies = dependencies;
}

@Override
@@ -146,13 +165,37 @@ public String name() {
public @Nullable Dataset<WindowedValue<T>> dataset() {
return dataset;
}

private boolean isLeaf() {
return dependentTransforms.isEmpty();
}

private int usages() {
return dependentTransforms.size();
}

private void resetPlanComplexity() {
planComplexity = 1;
}

/** Estimate complexity of query plan by multiplying complexities of all dependencies. */
private float estimatePlanComplexity() {
if (planComplexity > 0) {
return planComplexity;
}
float complexity = 1 + complexityFactor;
for (TranslationResult<?> result : dependencies) {
complexity *= result.estimatePlanComplexity();
}
return (planComplexity = complexity);
}
}

/** Shared, mutable state during the translation of a pipeline, discarded afterwards. */
public interface TranslationState extends EncoderProvider {
<T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);

boolean isLeave(PCollection<?> pCollection);
boolean isLeaf(PCollection<?> pCollection);

<T> void putDataset(
PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean cache);
@@ -188,6 +231,7 @@ private class TranslatingVisitor extends PTransformVisitor implements Translatio
private final PipelineOptions options;
private final Supplier<PipelineOptions> optionsSupplier;
private final StorageLevel storageLevel;
private final boolean isMemoryOnly;

private final Set<TranslationResult<?>> leaves;

@@ -200,6 +244,7 @@ public TranslatingVisitor(
this.options = options;
this.optionsSupplier = new BroadcastOptions(sparkSession, options);
this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
this.isMemoryOnly = storageLevel.equals(MEMORY_ONLY());
this.encoders = new HashMap<>();
this.leaves = new HashSet<>();
}
@@ -247,20 +292,37 @@ public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
public <T> void putDataset(
PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean cache) {
TranslationResult<T> result = getResult(pCollection);
if (cache && result.dependentTransforms.size() > 1) {
LOG.info("Dataset {} will be cached.", result.name);
result.dataset = dataset.persist(storageLevel); // use NONE to disable
} else {
result.dataset = dataset;
if (result.dependentTransforms.isEmpty()) {
leaves.add(result);
result.dataset = dataset;

if (!cache && isMemoryOnly) {
result.resetPlanComplexity(); // cached as RDD in memory, which breaks lineage
} else if (cache && result.usages() > 1) {
if (isMemoryOnly) {
// Cache as RDD in memory only; this also helps break the lineage of complex query plans.
LOG.info("Dataset {} will be cached in-memory as RDD for reuse.", result.name);
result.dataset = sparkSession.createDataset(dataset.rdd().persist(), dataset.encoder());
result.resetPlanComplexity();
} else {
LOG.info("Dataset {} will be cached for reuse.", result.name);
dataset.persist(storageLevel); // use NONE to disable
}
}

if (result.estimatePlanComplexity() > PLAN_COMPLEXITY_THRESHOLD) {
// Break lineage of the dataset to limit planning overhead for complex query plans.
LOG.info("Breaking lineage of dataset {} to limit complexity of query plan.", result.name);
result.dataset = sparkSession.createDataset(dataset.rdd(), dataset.encoder());
result.resetPlanComplexity();
}

if (result.isLeaf()) {
leaves.add(result);
}
}

@Override
public boolean isLeave(PCollection<?> pCollection) {
return getResult(pCollection).dependentTransforms.isEmpty();
public boolean isLeaf(PCollection<?> pCollection) {
return getResult(pCollection).isLeaf();
}

@Override
@@ -325,15 +387,18 @@ <InT extends PInput, OutT extends POutput> void visit(
Node node,
PTransform<InT, OutT> transform,
TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
// add new translation result for every output of `transform`
for (PCollection<?> pOut : node.getOutputs().values()) {
results.put(pOut, new TranslationResult<>(pOut));
}
// track `transform` as downstream dependency for every input
// Track `transform` as a downstream dependency of every input and, conversely,
// every input as a dependency of each output of `transform`.
List<TranslationResult<?>> dependencies = new ArrayList<>(node.getInputs().size());
for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
dependencies.add(input);
input.dependentTransforms.add(transform);
}
// add new translation result for every output of `transform`
for (PCollection<?> pOut : node.getOutputs().values()) {
results.put(pOut, new TranslationResult<>(pOut, translator.complexityFactor, dependencies));
}
}
}
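
For intuition, the complexity bookkeeping added above can be modeled with a small standalone sketch. The PlanNode class below is purely illustrative and not part of the runner: each node's estimate is (1 + complexityFactor) multiplied by the estimates of all upstream dependencies, mirroring estimatePlanComplexity(), and main() shows after how many chained ParDos (factor 0.2f, see the translators below) the threshold of 6 is first exceeded, at which point putDataset would break lineage by re-creating the Dataset from its RDD.

// Illustrative sketch only: models estimatePlanComplexity() and the threshold check
// outside the runner. PlanNode and its fields are hypothetical.
import java.util.Arrays;
import java.util.List;

class PlanNode {
  static final int PLAN_COMPLEXITY_THRESHOLD = 6;

  final float complexityFactor;      // per-transform factor, e.g. 0.2f for ParDo
  final List<PlanNode> dependencies; // upstream results this node is built from
  float planComplexity = 0;          // memoized estimate, 0 means "not yet computed"

  PlanNode(float complexityFactor, PlanNode... dependencies) {
    this.complexityFactor = complexityFactor;
    this.dependencies = Arrays.asList(dependencies);
  }

  // Mirrors estimatePlanComplexity(): (1 + factor) times the complexities of all dependencies.
  float estimate() {
    if (planComplexity > 0) {
      return planComplexity;
    }
    float complexity = 1 + complexityFactor;
    for (PlanNode dep : dependencies) {
      complexity *= dep.estimate();
    }
    return (planComplexity = complexity);
  }

  public static void main(String[] args) {
    // A linear chain: Impulse (factor 0) followed by chained ParDos (factor 0.2 each).
    PlanNode node = new PlanNode(0);
    int parDos = 0;
    while (node.estimate() <= PLAN_COMPLEXITY_THRESHOLD) {
      node = new PlanNode(0.2f, node);
      parDos++;
    }
    // 1.2^10 is roughly 6.19, so the threshold of 6 is first exceeded after about
    // ten chained ParDos; the runner would then break lineage via
    // createDataset(dataset.rdd(), dataset.encoder()) and reset the estimate.
    System.out.println("Threshold exceeded after " + parDos + " ParDos: " + node.estimate());
  }
}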

@@ -59,6 +59,15 @@
public abstract class TransformTranslator<
InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>> {

// Factor to help estimate the complexity of the Spark execution plan. This is used to limit
// complexity by breaking lineage where necessary to avoid overly large plans, which can become
// very expensive during planning in the Catalyst optimizer.
protected final float complexityFactor;

protected TransformTranslator(float complexityFactor) {
this.complexityFactor = complexityFactor;
}

protected abstract void translate(TransformT transform, Context cxt) throws IOException;

final void translate(
@@ -150,8 +159,8 @@ public <T> void putDataset(
}

@Override
public boolean isLeave(PCollection<?> pCollection) {
return state.isLeave(pCollection);
public boolean isLeaf(PCollection<?> pCollection) {
return state.isLeaf(pCollection);
}

@Override
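
As a summary of the factors assigned in the diffs below: the concrete translators pass their factor via super(...), using 0.2f for the Combine variants, GroupByKey and ParDo, 0.1f for Flatten and Reshuffle, 0.05f for bounded source reads and window assignment, and 0 for Impulse. Read together with the threshold of 6 above, a straight chain of ParDos multiplies the estimate by 1.2 per step, so lineage is broken after roughly ten of them, while cheaper transforms allow correspondingly longer chains; this is an interpretation of the numbers in this change, not a stated design rule.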
@@ -51,6 +51,10 @@
class CombineGloballyTranslatorBatch<InT, AccT, OutT>
extends TransformTranslator<PCollection<InT>, PCollection<OutT>, Combine.Globally<InT, OutT>> {

CombineGloballyTranslatorBatch() {
super(0.2f);
}

@Override
protected void translate(Combine.Globally<InT, OutT> transform, Context cxt) {
WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
@@ -43,6 +43,10 @@ class CombineGroupedValuesTranslatorBatch<K, InT, AccT, OutT>
PCollection<KV<K, OutT>>,
Combine.GroupedValues<K, InT, OutT>> {

CombineGroupedValuesTranslatorBatch() {
super(0.2f);
}

@Override
protected void translate(Combine.GroupedValues<K, InT, OutT> transform, Context cxt)
throws IOException {
@@ -69,6 +69,10 @@ class CombinePerKeyTranslatorBatch<K, InT, AccT, OutT>
extends TransformTranslator<
PCollection<KV<K, InT>>, PCollection<KV<K, OutT>>, Combine.PerKey<K, InT, OutT>> {

CombinePerKeyTranslatorBatch() {
super(0.2f);
}

@Override
public void translate(Combine.PerKey<K, InT, OutT> transform, Context cxt) {
WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
@@ -34,6 +34,10 @@
class FlattenTranslatorBatch<T>
extends TransformTranslator<PCollectionList<T>, PCollection<T>, Flatten.PCollections<T>> {

FlattenTranslatorBatch() {
super(0.1f);
}

@Override
public void translate(Flatten.PCollections<T> transform, Context cxt) {
Collection<PCollection<?>> pCollections = cxt.getInputs().values();
@@ -110,9 +110,12 @@ class GroupByKeyTranslatorBatch<K, V>

private boolean useCollectList = true;

GroupByKeyTranslatorBatch() {}
GroupByKeyTranslatorBatch() {
super(0.2f);
}

GroupByKeyTranslatorBatch(boolean useCollectList) {
super(0.2f);
this.useCollectList = useCollectList;
}

@@ -31,6 +31,10 @@

class ImpulseTranslatorBatch extends TransformTranslator<PBegin, PCollection<byte[]>, Impulse> {

ImpulseTranslatorBatch() {
super(0);
}

@Override
public void translate(Impulse transform, Context cxt) {
Dataset<WindowedValue<byte[]>> dataset =
@@ -79,6 +79,10 @@ class ParDoTranslatorBatch<InputT, OutputT>
private static final ClassTag<Tuple2<Integer, WindowedValue<Object>>> TUPLE2_CTAG =
ClassTag.apply(Tuple2.class);

ParDoTranslatorBatch() {
super(0.2f);
}

@Override
public boolean canTranslate(ParDo.MultiOutput<InputT, OutputT> transform) {
DoFn<InputT, OutputT> doFn = transform.getFn();
@@ -123,7 +127,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
Map<TupleTag<?>, PCollection<?>> outputs =
Maps.filterEntries(
cxt.getOutputs(),
e -> e != null && (e.getKey().equals(mainOut) || !cxt.isLeave(e.getValue())));
e -> e != null && (e.getKey().equals(mainOut) || !cxt.isLeaf(e.getValue())));

if (outputs.size() > 1) {
// In case of multiple outputs / tags, map each tag to a column by index.
@@ -38,6 +38,10 @@
class ReadSourceTranslatorBatch<T>
extends TransformTranslator<PBegin, PCollection<T>, SplittableParDo.PrimitiveBoundedRead<T>> {

ReadSourceTranslatorBatch() {
super(0.05f);
}

@Override
public void translate(SplittableParDo.PrimitiveBoundedRead<T> transform, Context cxt)
throws IOException {
@@ -50,6 +54,7 @@ public void translate(SplittableParDo.PrimitiveBoundedRead<T> transform, Context

cxt.putDataset(
cxt.getOutput(),
BoundedDatasetFactory.createDatasetFromRDD(session, source, options, encoder));
BoundedDatasetFactory.createDatasetFromRDD(session, source, options, encoder),
false);
}
}
@@ -31,6 +31,10 @@
class ReshuffleTranslatorBatch<K, V>
extends TransformTranslator<PCollection<KV<K, V>>, PCollection<KV<K, V>>, Reshuffle<K, V>> {

ReshuffleTranslatorBatch() {
super(0.1f);
}

@Override
protected void translate(Reshuffle<K, V> transform, Context cxt) throws IOException {
Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput());
@@ -40,6 +44,10 @@ protected void translate(Reshuffle<K, V> transform, Context cxt) throws IOExcept
static class ViaRandomKey<V>
extends TransformTranslator<PCollection<V>, PCollection<V>, Reshuffle.ViaRandomKey<V>> {

ViaRandomKey() {
super(0.1f);
}

@Override
protected void translate(Reshuffle.ViaRandomKey<V> transform, Context cxt) throws IOException {
Dataset<WindowedValue<V>> input = cxt.getDataset(cxt.getInput());
@@ -34,6 +34,10 @@
class WindowAssignTranslatorBatch<T>
extends TransformTranslator<PCollection<T>, PCollection<T>, Window.Assign<T>> {

WindowAssignTranslatorBatch() {
super(0.05f);
}

@Override
public void translate(Window.Assign<T> transform, Context cxt) {
WindowFn<T, ?> windowFn = transform.getWindowFn();