From 762edd7f3a64f076dbee156fa48b8a7e5e6a512f Mon Sep 17 00:00:00 2001 From: Moritz Mack Date: Thu, 22 Sep 2022 12:30:03 +0200 Subject: [PATCH] Improved pipeline translation in SparkStructuredStreamingRunner (#22446) * Closes #22445: Improved pipeline translation in SparkStructuredStreamingRunner (also closes #22382) --- .../translation/helpers/EncoderFactory.java | 12 +- .../translation/utils/ScalaInterop.java} | 27 +- .../spark/structuredstreaming/Constants.java | 25 - .../SparkStructuredStreamingRunner.java | 51 +- .../io/BoundedDatasetFactory.java | 324 ++++++++++ .../streaming => io}/package-info.java | 4 +- .../metrics/WithMetricsSupport.java | 2 +- .../AbstractTranslationContext.java | 235 ------- .../translation/PipelineTranslator.java | 57 +- .../translation/TransformTranslator.java | 198 +++++- .../translation/TranslationContext.java | 124 +++- .../translation/batch/AggregatorCombiner.java | 270 -------- .../translation/batch/Aggregators.java | 591 ++++++++++++++++++ .../batch/CombineGloballyTranslatorBatch.java | 121 ++++ .../batch/CombinePerKeyTranslatorBatch.java | 181 ++++-- .../CreatePCollectionViewTranslatorBatch.java | 24 +- .../translation/batch/DatasetSourceBatch.java | 240 ------- .../translation/batch/DoFnFunction.java | 164 ----- .../batch/DoFnMapPartitionsFactory.java | 224 +++++++ .../batch/FlattenTranslatorBatch.java | 60 +- .../translation/batch/GroupByKeyHelpers.java | 106 ++++ .../batch/GroupByKeyTranslatorBatch.java | 298 +++++++-- .../batch/ImpulseTranslatorBatch.java | 26 +- .../batch/ParDoTranslatorBatch.java | 315 +++++----- .../batch/PipelineTranslatorBatch.java | 28 +- .../translation/batch/ProcessContext.java | 138 ---- .../batch/ReadSourceTranslatorBatch.java | 76 +-- .../batch/ReshuffleTranslatorBatch.java | 30 - .../batch/WindowAssignTranslatorBatch.java | 90 ++- .../translation/helpers/CoderHelpers.java | 10 +- .../translation/helpers/EncoderFactory.java | 71 ++- .../translation/helpers/EncoderHelpers.java | 546 +++++++++++++++- .../translation/helpers/MultiOutputCoder.java | 84 --- .../translation/helpers/RowHelpers.java | 75 --- .../translation/helpers/SchemaHelpers.java | 39 -- .../translation/helpers/WindowingHelpers.java | 82 --- .../streaming/DatasetSourceStreaming.java | 25 - .../PipelineTranslatorStreaming.java | 93 --- .../ReadSourceTranslatorStreaming.java | 87 --- .../translation/utils/ScalaInterop.java | 114 ++++ .../metrics/sink/InMemoryMetrics.java | 2 +- .../translation/batch/AggregatorsTest.java | 370 +++++++++++ .../batch/CombineGloballyTest.java | 155 +++++ ...ombineTest.java => CombinePerKeyTest.java} | 92 ++- .../translation/batch/ComplexSourceTest.java | 15 +- .../translation/batch/FlattenTest.java | 12 +- .../translation/batch/GroupByKeyTest.java | 152 +++-- .../translation/batch/ParDoTest.java | 54 +- .../translation/batch/SimpleSourceTest.java | 12 +- .../translation/batch/WindowAssignTest.java | 12 +- .../helpers/EncoderHelpersTest.java | 210 ++++++- .../spark/SparkCommonPipelineOptions.java | 6 + .../runners/spark/SparkPipelineOptions.java | 6 - 53 files changed, 4084 insertions(+), 2281 deletions(-) rename runners/spark/{3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java => 2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java} (61%) delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java create mode 100644 
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java rename runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/{translation/streaming => io}/package-info.java (83%) delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java create mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java create mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java create mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java create mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java delete mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java create mode 100644 runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java create mode 100644 runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java create mode 100644 runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java rename runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/{CombineTest.java => CombinePerKeyTest.java} (71%) diff --git a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java index 54b400f08d00..2b86ec839c9e 100644 --- 
a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java +++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java @@ -17,26 +17,24 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; + import org.apache.spark.sql.Encoder; import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; import org.apache.spark.sql.types.DataType; -import scala.collection.immutable.List; -import scala.collection.immutable.Nil$; -import scala.collection.mutable.WrappedArray; import scala.reflect.ClassTag$; public class EncoderFactory { static Encoder create( Expression serializer, Expression deserializer, Class clazz) { - // TODO Isolate usage of Scala APIs in utility https://github.com/apache/beam/issues/22382 - List serializers = Nil$.MODULE$.$colon$colon(serializer); return new ExpressionEncoder<>( SchemaHelpers.binarySchema(), false, - serializers, + listOf(serializer), deserializer, ClassTag$.MODULE$.apply(clazz)); } @@ -46,6 +44,6 @@ static Encoder create( * input arg is {@code null}. */ static Expression invokeIfNotNull(Class cls, String fun, DataType type, Expression... args) { - return new StaticInvoke(cls, type, fun, new WrappedArray.ofRef<>(args), true, true); + return new StaticInvoke(cls, type, fun, seqOf(args), true, true); } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java similarity index 61% rename from runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java rename to runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java index 2406c0f49ab5..c5bc71af6026 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java +++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java @@ -15,17 +15,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; +package org.apache.beam.runners.spark.structuredstreaming.translation.utils; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.KV; -import org.apache.spark.api.java.function.MapFunction; +import scala.collection.Seq; +import scala.collection.immutable.List; +import scala.collection.immutable.Nil$; +import scala.collection.mutable.WrappedArray; -/** Helper functions for working with {@link org.apache.beam.sdk.values.KV}. */ -public final class KVHelpers { +/** Utilities for easier interoperability with the Spark Scala API. */ +public class ScalaInterop { + private ScalaInterop() {} - /** A Spark {@link MapFunction} for extracting the key out of a {@link KV} for GBK for example. 
*/ - public static MapFunction>, K> extractKey() { - return wv -> wv.getValue().getKey(); + public static Seq seqOf(T... t) { + return new WrappedArray.ofRef<>(t); + } + + public static Seq listOf(T t) { + return emptyList().$colon$colon(t); + } + + public static List emptyList() { + return (List) Nil$.MODULE$; } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java deleted file mode 100644 index 08c187ce6c6d..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming; - -public class Constants { - - public static final String BEAM_SOURCE_OPTION = "beam-source"; - public static final String DEFAULT_PARALLELISM = "default-parallelism"; - public static final String PIPELINE_OPTIONS = "pipeline-options"; -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java index b1de9e941e40..68f54ac93bf0 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java @@ -29,10 +29,9 @@ import org.apache.beam.runners.spark.structuredstreaming.metrics.CompositeSource; import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; import org.apache.beam.runners.spark.structuredstreaming.metrics.SparkBeamMetricSource; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext; import org.apache.beam.runners.spark.structuredstreaming.translation.batch.PipelineTranslatorBatch; -import org.apache.beam.runners.spark.structuredstreaming.translation.streaming.PipelineTranslatorStreaming; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineRunner; import org.apache.beam.sdk.metrics.MetricsEnvironment; @@ -41,6 +40,7 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.PipelineOptionsValidator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import 
org.apache.spark.SparkEnv$; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.metrics.MetricsSystem; @@ -48,24 +48,34 @@ import org.slf4j.LoggerFactory; /** - * SparkStructuredStreamingRunner is based on spark structured streaming framework and is no more - * based on RDD/DStream API. See - * https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html It is still - * experimental, its coverage of the Beam model is partial. The SparkStructuredStreamingRunner - * translate operations defined on a pipeline to a representation executable by Spark, and then - * submitting the job to Spark to be executed. If we wanted to run a Beam pipeline with the default - * options of a single threaded spark instance in local mode, we would do the following: + * A Spark runner built on top of Spark's SQL Engine (Structured + * Streaming framework). * - *

{@code Pipeline p = [logic for pipeline creation] SparkStructuredStreamingPipelineResult - * result = (SparkStructuredStreamingPipelineResult) p.run(); } + *

This runner is experimental; its coverage of the Beam model is still partial. Due to + * limitations of the Structured Streaming framework (e.g. lack of support for multiple stateful + * operators), streaming mode is not yet supported by this runner. + * + *

The runner translates transforms defined on a Beam pipeline to Spark `Dataset` transformations + * (leveraging the high-level Dataset API) and then submits these to Spark to be executed. + * + *

To run a Beam pipeline with the default options using Spark's local mode, we would do the + * following: + * + *

{@code
+ * Pipeline p = [logic for pipeline creation]
+ * PipelineResult result = p.run();
+ * }
* *

To create a pipeline runner to run against a different spark cluster, with a custom master url * we would do the following: * - *

{@code Pipeline p = [logic for pipeline creation] SparkStructuredStreamingPipelineOptions - * options = SparkPipelineOptionsFactory.create(); options.setSparkMaster("spark://host:port"); - * SparkStructuredStreamingPipelineResult result = (SparkStructuredStreamingPipelineResult) p.run(); - * } + *

{@code
+ * Pipeline p = [logic for pipeline creation]
+ * SparkCommonPipelineOptions options = p.getOptions().as(SparkCommonPipelineOptions.class);
+ * options.setSparkMaster("spark://host:port");
+ * PipelineResult result = p.run();
+ * }
*/ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) @@ -135,7 +145,7 @@ public SparkStructuredStreamingPipelineResult run(final Pipeline pipeline) { AggregatorsAccumulator.clear(); MetricsAccumulator.clear(); - final AbstractTranslationContext translationContext = translatePipeline(pipeline); + final TranslationContext translationContext = translatePipeline(pipeline); final ExecutorService executorService = Executors.newSingleThreadExecutor(); final Future submissionFuture = @@ -169,8 +179,10 @@ public SparkStructuredStreamingPipelineResult run(final Pipeline pipeline) { return result; } - private AbstractTranslationContext translatePipeline(Pipeline pipeline) { + private TranslationContext translatePipeline(Pipeline pipeline) { PipelineTranslator.detectTranslationMode(pipeline, options); + Preconditions.checkArgument( + !options.isStreaming(), "%s does not support streaming pipelines.", getClass().getName()); // Default to using the primitive versions of Read.Bounded and Read.Unbounded for non-portable // execution. @@ -182,10 +194,7 @@ private AbstractTranslationContext translatePipeline(Pipeline pipeline) { PipelineTranslator.replaceTransforms(pipeline, options); prepareFilesToStage(options); - PipelineTranslator pipelineTranslator = - options.isStreaming() - ? new PipelineTranslatorStreaming(options) - : new PipelineTranslatorBatch(options); + PipelineTranslator pipelineTranslator = new PipelineTranslatorBatch(options); final JavaSparkContext jsc = JavaSparkContext.fromSparkContext( diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java new file mode 100644 index 000000000000..83dc98f3c100 --- /dev/null +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.io; + +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static scala.collection.JavaConverters.asScalaIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.IntSupplier; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.BoundedSource.BoundedReader; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.spark.InterruptibleIterator; +import org.apache.spark.Partition; +import org.apache.spark.SparkContext; +import org.apache.spark.TaskContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import scala.Option; +import scala.collection.Iterator; +import scala.reflect.ClassTag; + +public class BoundedDatasetFactory { + private BoundedDatasetFactory() {} + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}. + * + *

Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization. + * This makes this approach, for the time being, significantly less performant than creating a + * dataset from an RDD. + */ + public static Dataset> createDatasetFromRows( + SparkSession session, + BoundedSource source, + SerializablePipelineOptions options, + Encoder> encoder) { + Params params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + BeamTable table = new BeamTable<>(source, params); + LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty()); + return Dataset.ofRows(session, logicalPlan).as(encoder); + } + + /** + * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}. + * + *

This is currently the most efficient approach as it avoid any serialization overhead. + */ + public static Dataset> createDatasetFromRDD( + SparkSession session, + BoundedSource source, + SerializablePipelineOptions options, + Encoder> encoder) { + Params params = new Params<>(encoder, options, session.sparkContext().defaultParallelism()); + RDD> rdd = new BoundedRDD<>(session.sparkContext(), source, params); + return session.createDataset(rdd, encoder); + } + + /** An {@link RDD} for a bounded Beam source. */ + private static class BoundedRDD extends RDD> { + final BoundedSource source; + final Params params; + + public BoundedRDD(SparkContext sc, BoundedSource source, Params params) { + super(sc, emptyList(), ClassTag.apply(WindowedValue.class)); + this.source = source; + this.params = params; + } + + @Override + public Iterator> compute(Partition split, TaskContext context) { + return new InterruptibleIterator<>( + context, + asScalaIterator(new SourcePartitionIterator<>((SourcePartition) split, params))); + } + + @Override + public Partition[] getPartitions() { + return SourcePartition.partitionsOf(source, params).toArray(new Partition[0]); + } + } + + /** A Spark {@link Table} for a bounded Beam source supporting batch reads only. */ + private static class BeamTable implements Table, SupportsRead { + final BoundedSource source; + final Params params; + + BeamTable(BoundedSource source, Params params) { + this.source = source; + this.params = params; + } + + public Encoder> getEncoder() { + return params.encoder; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) { + return () -> + new Scan() { + @Override + public StructType readSchema() { + return params.encoder.schema(); + } + + @Override + public Batch toBatch() { + return new BeamBatch<>(source, params); + } + }; + } + + @Override + public String name() { + return "BeamSource<" + source.getClass().getName() + ">"; + } + + @Override + public StructType schema() { + return params.encoder.schema(); + } + + @Override + public Set capabilities() { + return ImmutableSet.of(TableCapability.BATCH_READ); + } + + private static class BeamBatch implements Batch, Serializable { + final BoundedSource source; + final Params params; + + private BeamBatch(BoundedSource source, Params params) { + this.source = source; + this.params = params; + } + + @Override + public InputPartition[] planInputPartitions() { + return SourcePartition.partitionsOf(source, params).toArray(new InputPartition[0]); + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return p -> new BeamPartitionReader<>(((SourcePartition) p), params); + } + } + + private static class BeamPartitionReader implements PartitionReader { + final SourcePartitionIterator iterator; + final Serializer> serializer; + transient @Nullable InternalRow next; + + BeamPartitionReader(SourcePartition partition, Params params) { + iterator = new SourcePartitionIterator<>(partition, params); + serializer = ((ExpressionEncoder>) params.encoder).createSerializer(); + } + + @Override + public boolean next() throws IOException { + if (iterator.hasNext()) { + next = serializer.apply(iterator.next()); + return true; + } + return false; + } + + @Override + public InternalRow get() { + if (next == null) { + throw new IllegalStateException("Next not available"); + } + return next; + } + + @Override + public void close() throws IOException { + next = null; + iterator.close(); + } + } + } + + /** A Spark partition wrapping the partitioned Beam {@link 
BoundedSource}. */ + private static class SourcePartition implements Partition, InputPartition { + final BoundedSource source; + final int index; + + SourcePartition(BoundedSource source, IntSupplier idxSupplier) { + this.source = source; + this.index = idxSupplier.getAsInt(); + } + + static List> partitionsOf(BoundedSource source, Params params) { + try { + PipelineOptions options = params.options.get(); + long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions; + List> split = (List>) source.split(desiredSize, options); + IntSupplier idxSupplier = new AtomicInteger(0)::getAndIncrement; + return split.stream().map(s -> new SourcePartition<>(s, idxSupplier)).collect(toList()); + } catch (Exception e) { + throw new RuntimeException( + "Error splitting BoundedSource " + source.getClass().getCanonicalName(), e); + } + } + + @Override + public int index() { + return index; + } + + @Override + public int hashCode() { + return index; + } + } + + /** A partition iterator on a partitioned Beam {@link BoundedSource}. */ + private static class SourcePartitionIterator extends AbstractIterator> + implements Closeable { + BoundedReader reader; + boolean started = false; + + public SourcePartitionIterator(SourcePartition partition, Params params) { + try { + reader = partition.source.createReader(params.options.get()); + } catch (IOException e) { + throw new RuntimeException("Failed to create reader from a BoundedSource.", e); + } + } + + @Override + @SuppressWarnings("nullness") // ok, reader not used any longer + public void close() throws IOException { + if (reader != null) { + endOfData(); + try { + reader.close(); + } finally { + reader = null; + } + } + } + + @Override + protected WindowedValue computeNext() { + try { + if (started ? reader.advance() : start()) { + return timestampedValueInGlobalWindow(reader.getCurrent(), reader.getCurrentTimestamp()); + } else { + close(); + return endOfData(); + } + } catch (IOException e) { + throw new RuntimeException("Failed to start or advance reader.", e); + } + } + + private boolean start() throws IOException { + started = true; + return reader.start(); + } + } + + /** Shared parameters. */ + private static class Params implements Serializable { + final Encoder> encoder; + final SerializablePipelineOptions options; + final int numPartitions; + + Params( + Encoder> encoder, SerializablePipelineOptions options, int numPartitions) { + checkArgument(numPartitions > 0, "Number of partitions must be greater than zero."); + this.encoder = encoder; + this.options = options; + this.numPartitions = numPartitions; + } + } +} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java similarity index 83% rename from runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java rename to runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java index 67f3613e056b..23de70c705b3 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java @@ -16,5 +16,5 @@ * limitations under the License. */ -/** Internal utilities to translate Beam pipelines to Spark streaming. 
*/ -package org.apache.beam.runners.spark.structuredstreaming.translation.streaming; +/** Spark-specific transforms for I/O. */ +package org.apache.beam.runners.spark.structuredstreaming.io; diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java index d48a229996f7..c9233a128c16 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java @@ -36,7 +36,6 @@ *

{@link MetricRegistry} is not an interface, so this is not a by-the-book decorator. That said, * it delegates all metric related getters to the "decorated" instance. */ -@SuppressWarnings({"rawtypes"}) // required by interface public class WithMetricsSupport extends MetricRegistry { private final MetricRegistry internalMetricRegistry; @@ -70,6 +69,7 @@ public SortedMap getCounters(final MetricFilter filter) { } @Override + @SuppressWarnings({"rawtypes"}) // required by interface public SortedMap getGauges(final MetricFilter filter) { ImmutableSortedMap.Builder builder = new ImmutableSortedMap.Builder<>(Ordering.from(String.CASE_INSENSITIVE_ORDER)); diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java deleted file mode 100644 index aed287ba6d56..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.spark.structuredstreaming.translation; - -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.core.construction.TransformInputs; -import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.VoidCoder; -import org.apache.beam.sdk.runners.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; -import org.apache.spark.api.java.function.ForeachFunction; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.ForeachWriter; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.streaming.DataStreamWriter; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Base class that gives a context for {@link PTransform} translation: keeping track of the - * datasets, the {@link SparkSession}, the current transform being translated. - */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public abstract class AbstractTranslationContext { - - private static final Logger LOG = LoggerFactory.getLogger(AbstractTranslationContext.class); - - /** All the datasets of the DAG. */ - private final Map> datasets; - /** datasets that are not used as input to other datasets (leaves of the DAG). 
*/ - private final Set> leaves; - - private final SerializablePipelineOptions serializablePipelineOptions; - - @SuppressFBWarnings("URF_UNREAD_FIELD") // make spotbugs happy - private AppliedPTransform currentTransform; - - @SuppressFBWarnings("URF_UNREAD_FIELD") // make spotbugs happy - private final SparkSession sparkSession; - - private final Map, Dataset> broadcastDataSets; - - public AbstractTranslationContext(SparkStructuredStreamingPipelineOptions options) { - this.sparkSession = SparkSessionFactory.getOrCreateSession(options); - this.serializablePipelineOptions = new SerializablePipelineOptions(options); - this.datasets = new HashMap<>(); - this.leaves = new HashSet<>(); - this.broadcastDataSets = new HashMap<>(); - } - - public SparkSession getSparkSession() { - return sparkSession; - } - - public SerializablePipelineOptions getSerializableOptions() { - return serializablePipelineOptions; - } - - // -------------------------------------------------------------------------------------------- - // Transforms methods - // -------------------------------------------------------------------------------------------- - public void setCurrentTransform(AppliedPTransform currentTransform) { - this.currentTransform = currentTransform; - } - - public AppliedPTransform getCurrentTransform() { - return currentTransform; - } - - // -------------------------------------------------------------------------------------------- - // Datasets methods - // -------------------------------------------------------------------------------------------- - @SuppressWarnings("unchecked") - public Dataset emptyDataset() { - return (Dataset) sparkSession.emptyDataset(EncoderHelpers.fromBeamCoder(VoidCoder.of())); - } - - @SuppressWarnings("unchecked") - public Dataset> getDataset(PValue value) { - Dataset dataset = datasets.get(value); - // assume that the Dataset is used as an input if retrieved here. So it is not a leaf anymore - leaves.remove(dataset); - return (Dataset>) dataset; - } - - /** - * TODO: All these 3 methods (putDataset*) are temporary and they are used only for generics type - * checking. We should unify them in the future. 
- */ - public void putDatasetWildcard(PValue value, Dataset> dataset) { - if (!datasets.containsKey(value)) { - datasets.put(value, dataset); - leaves.add(dataset); - } - } - - public void putDataset(PValue value, Dataset> dataset) { - if (!datasets.containsKey(value)) { - datasets.put(value, dataset); - leaves.add(dataset); - } - } - - public void setSideInputDataset( - PCollectionView value, Dataset> set) { - if (!broadcastDataSets.containsKey(value)) { - broadcastDataSets.put(value, set); - } - } - - @SuppressWarnings("unchecked") - public Dataset getSideInputDataSet(PCollectionView value) { - return (Dataset) broadcastDataSets.get(value); - } - - // -------------------------------------------------------------------------------------------- - // PCollections methods - // -------------------------------------------------------------------------------------------- - public PValue getInput() { - return Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); - } - - public Map, PCollection> getInputs() { - return currentTransform.getInputs(); - } - - public PValue getOutput() { - return Iterables.getOnlyElement(currentTransform.getOutputs().values()); - } - - public Map, PCollection> getOutputs() { - return currentTransform.getOutputs(); - } - - @SuppressWarnings("unchecked") - public Map, Coder> getOutputCoders() { - return currentTransform.getOutputs().entrySet().stream() - .filter(e -> e.getValue() instanceof PCollection) - .collect(Collectors.toMap(Map.Entry::getKey, e -> ((PCollection) e.getValue()).getCoder())); - } - - // -------------------------------------------------------------------------------------------- - // Pipeline methods - // -------------------------------------------------------------------------------------------- - - /** Starts the pipeline. */ - public void startPipeline() { - SparkStructuredStreamingPipelineOptions options = - serializablePipelineOptions.get().as(SparkStructuredStreamingPipelineOptions.class); - int datasetIndex = 0; - for (Dataset dataset : leaves) { - if (options.isStreaming()) { - // TODO: deal with Beam Discarding, Accumulating and Accumulating & Retracting outputmodes - // with DatastreamWriter.outputMode - DataStreamWriter dataStreamWriter = dataset.writeStream(); - // spark sets a default checkpoint dir if not set. - if (options.getCheckpointDir() != null) { - dataStreamWriter = - dataStreamWriter.option("checkpointLocation", options.getCheckpointDir()); - } - launchStreaming(dataStreamWriter.foreach(new NoOpForeachWriter<>())); - } else { - if (options.getTestMode()) { - LOG.debug("**** dataset {} catalyst execution plans ****", ++datasetIndex); - dataset.explain(true); - } - // apply a dummy fn just to apply foreach action that will trigger the pipeline run in - // spark - dataset.foreach((ForeachFunction) t -> {}); - } - } - } - - public abstract void launchStreaming(DataStreamWriter dataStreamWriter); - - public static void printDatasetContent(Dataset dataset) { - // cannot use dataset.show because dataset schema is binary so it will print binary - // code. 
- List windowedValues = dataset.collectAsList(); - for (WindowedValue windowedValue : windowedValues) { - LOG.debug("**** dataset content {} ****", windowedValue.toString()); - } - } - - private static class NoOpForeachWriter extends ForeachWriter { - - @Override - public boolean open(long partitionId, long epochId) { - return false; - } - - @Override - public void process(T value) { - // do nothing - } - - @Override - public void close(Throwable errorOrNull) { - // do nothing - } - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java index 0f851d9588d9..0fa48fc3d3ea 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java @@ -17,24 +17,25 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation; +import java.io.IOException; import org.apache.beam.runners.core.construction.PTransformTranslation; -import org.apache.beam.runners.spark.structuredstreaming.translation.batch.PipelineTranslatorBatch; -import org.apache.beam.runners.spark.structuredstreaming.translation.streaming.PipelineTranslatorStreaming; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.options.StreamingOptions; +import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; +import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts. * It also does the pipeline preparation: mode detection, transforms replacement, classpath - * preparation. If we have a streaming job, it is instantiated as a {@link - * PipelineTranslatorStreaming}. If we have a batch job, it is instantiated as a {@link - * PipelineTranslatorBatch}. + * preparation. */ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) @@ -42,7 +43,7 @@ public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults { private int depth = 0; private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class); - protected AbstractTranslationContext translationContext; + protected TranslationContext translationContext; // -------------------------------------------------------------------------------------------- // Pipeline preparation methods @@ -123,22 +124,25 @@ private static String genSpaces(int n) { } /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */ - protected abstract TransformTranslator getTransformTranslator(TransformHierarchy.Node node); + protected abstract @Nullable < + InT extends PInput, OutT extends POutput, TransformT extends PTransform> + TransformTranslator getTransformTranslator( + @Nullable TransformT transform); /** Apply the given TransformTranslator to the given node. 
*/ - private > void applyTransformTranslator( - TransformHierarchy.Node node, TransformTranslator transformTranslator) { + private > + void applyTransformTranslator( + TransformHierarchy.Node node, + TransformT transform, + TransformTranslator transformTranslator) { // create the applied PTransform on the translationContext - translationContext.setCurrentTransform(node.toAppliedPTransform(getPipeline())); - - // avoid type capture - @SuppressWarnings("unchecked") - T typedTransform = (T) node.getTransform(); - @SuppressWarnings("unchecked") - TransformTranslator typedTransformTranslator = (TransformTranslator) transformTranslator; - - // apply the transformTranslator - typedTransformTranslator.translateTransform(typedTransform, translationContext); + AppliedPTransform> appliedTransform = + (AppliedPTransform) node.toAppliedPTransform(getPipeline()); + try { + transformTranslator.translate(transform, appliedTransform, translationContext); + } catch (IOException e) { + throw new RuntimeException(e); + } } // -------------------------------------------------------------------------------------------- @@ -164,10 +168,12 @@ public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) { LOG.debug("{} enterCompositeTransform- {}", genSpaces(depth), node.getFullName()); depth++; - TransformTranslator transformTranslator = getTransformTranslator(node); + PTransform transform = (PTransform) node.getTransform(); + TransformTranslator> transformTranslator = + getTransformTranslator(transform); if (transformTranslator != null) { - applyTransformTranslator(node, transformTranslator); + applyTransformTranslator(node, transform, transformTranslator); LOG.debug("{} translated- {}", genSpaces(depth), node.getFullName()); return CompositeBehavior.DO_NOT_ENTER_TRANSFORM; } else { @@ -187,16 +193,19 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { // get the transformation corresponding to the node we are // currently visiting and translate it into its Spark alternative. 
- TransformTranslator transformTranslator = getTransformTranslator(node); + PTransform transform = (PTransform) node.getTransform(); + TransformTranslator> transformTranslator = + getTransformTranslator(transform); + if (transformTranslator == null) { String transformUrn = PTransformTranslation.urnForTransform(node.getTransform()); throw new UnsupportedOperationException( "The transform " + transformUrn + " is currently not supported."); } - applyTransformTranslator(node, transformTranslator); + applyTransformTranslator(node, transform, transformTranslator); } - public AbstractTranslationContext getTranslationContext() { + public TranslationContext getTranslationContext() { return translationContext; } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java index 61580aed2192..d991a0d9148d 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java @@ -17,15 +17,197 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation; -import java.io.Serializable; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.getOnlyElement; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.core.construction.TransformInputs; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.SparkSession; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import scala.Tuple2; +import scala.reflect.ClassTag; + +/** + * Supports translation between a Beam transform, and Spark's operations on Datasets. + * + *

WARNING: Do not make this class serializable! It could easily hide situations where + * unnecessary references leak into Spark closures. + */ +public abstract class TransformTranslator< + InT extends PInput, OutT extends POutput, TransformT extends PTransform> { + + protected abstract void translate(TransformT transform, Context cxt) throws IOException; + + public final void translate( + TransformT transform, + AppliedPTransform> appliedTransform, + TranslationContext cxt) + throws IOException { + translate(transform, new Context(appliedTransform, cxt)); + } + + protected class Context { + private final AppliedPTransform> transform; + private final TranslationContext cxt; + private @MonotonicNonNull InT pIn = null; + private @MonotonicNonNull OutT pOut = null; + + protected Context( + AppliedPTransform> transform, TranslationContext cxt) { + this.transform = transform; + this.cxt = cxt; + } + + public InT getInput() { + if (pIn == null) { + pIn = (InT) getOnlyElement(TransformInputs.nonAdditionalInputs(transform)); + } + return pIn; + } + + public Map, PCollection> getInputs() { + return transform.getInputs(); + } + + public OutT getOutput() { + if (pOut == null) { + pOut = (OutT) getOnlyElement(transform.getOutputs().values()); + } + return pOut; + } + + public PCollection getOutput(TupleTag tag) { + PCollection pc = (PCollection) transform.getOutputs().get(tag); + if (pc == null) { + throw new IllegalStateException("No output for tag " + tag); + } + return pc; + } + + public Map, PCollection> getOutputs() { + return transform.getOutputs(); + } + + public AppliedPTransform> getCurrentTransform() { + return transform; + } + + public Dataset> getDataset(PCollection pCollection) { + return cxt.getDataset(pCollection); + } + + public void putDataset(PCollection pCollection, Dataset> dataset) { + cxt.putDataset(pCollection, dataset); + } + + public SerializablePipelineOptions getSerializableOptions() { + return cxt.getSerializableOptions(); + } + + public PipelineOptions getOptions() { + return cxt.getSerializableOptions().get(); + } + + // FIXME Types don't guarantee anything! + public void setSideInputDataset( + PCollectionView value, Dataset> set) { + cxt.setSideInputDataset(value, set); + } + + public Dataset getSideInputDataset(PCollectionView sideInput) { + return cxt.getSideInputDataSet(sideInput); + } + + public Dataset> createDataset( + List> data, Encoder> enc) { + return data.isEmpty() + ? cxt.getSparkSession().emptyDataset(enc) + : cxt.getSparkSession().createDataset(data, enc); + } + + public Broadcast broadcast(T value) { + return cxt.getSparkSession().sparkContext().broadcast(value, (ClassTag) ClassTag.AnyRef()); + } + + public SparkSession getSparkSession() { + return cxt.getSparkSession(); + } + + public Encoder encoderOf(Coder coder) { + return coder instanceof KvCoder ? 
kvEncoderOf((KvCoder) coder) : getOrCreateEncoder(coder); + } + + public Encoder> kvEncoderOf(KvCoder coder) { + return cxt.encoderOf(coder, c -> kvEncoder(keyEncoderOf(coder), valueEncoderOf(coder))); + } + + public Encoder keyEncoderOf(KvCoder coder) { + return getOrCreateEncoder(coder.getKeyCoder()); + } + + public Encoder valueEncoderOf(KvCoder coder) { + return getOrCreateEncoder(coder.getValueCoder()); + } + + public Encoder> windowedEncoder(Coder coder) { + return windowedValueEncoder(encoderOf(coder), windowEncoder()); + } + + public Encoder> windowedEncoder(Encoder enc) { + return windowedValueEncoder(enc, windowEncoder()); + } + + public Encoder> tupleEncoder(Encoder e1, Encoder e2) { + return Encoders.tuple(e1, e2); + } + + public Encoder> windowedEncoder( + Coder coder, Coder windowCoder) { + return windowedValueEncoder(encoderOf(coder), getOrCreateWindowCoder(windowCoder)); + } + + public Encoder windowEncoder() { + checkState( + !transform.getInputs().isEmpty(), "Transform has no inputs, cannot get windowCoder!"); + Coder coder = + ((PCollection) getInput()).getWindowingStrategy().getWindowFn().windowCoder(); + return cxt.encoderOf(coder, c -> encoderFor(c)); + } + + private Encoder getOrCreateWindowCoder( + Coder coder) { + return cxt.encoderOf((Coder) coder, c -> encoderFor(c)); + } -/** Supports translation between a Beam transform, and Spark's operations on Datasets. */ -@SuppressWarnings({ - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) -public interface TransformTranslator extends Serializable { + private Encoder getOrCreateEncoder(Coder coder) { + return cxt.encoderOf(coder, c -> encoderFor(c)); + } + } - /** Base class for translators of {@link PTransform}. */ - void translateTransform(TransformT transform, AbstractTranslationContext context); + protected Coder windowCoder(PCollection pc) { + return (Coder) pc.getWindowingStrategy().getWindowFn().windowCoder(); + } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java index 12cb2d2fef00..617aa67c5feb 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java @@ -17,27 +17,125 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation; -import java.util.concurrent.TimeoutException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; -import org.apache.spark.sql.streaming.DataStreamWriter; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PValue; +import org.apache.spark.api.java.function.ForeachFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.SparkSession; +import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** - * Subclass of {@link - * org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext} that - * address spark breaking changes. + * Base class that gives a context for {@link PTransform} translation: keeping track of the + * datasets, the {@link SparkSession}, the current transform being translated. */ -public class TranslationContext extends AbstractTranslationContext { +public class TranslationContext { + + private static final Logger LOG = LoggerFactory.getLogger(TranslationContext.class); + + /** All the datasets of the DAG. */ + private final Map> datasets; + /** datasets that are not used as input to other datasets (leaves of the DAG). */ + private final Set> leaves; + + private final SerializablePipelineOptions serializablePipelineOptions; + + private final SparkSession sparkSession; + + private final Map, Dataset> broadcastDataSets; + + private final Map, ExpressionEncoder> encoders; public TranslationContext(SparkStructuredStreamingPipelineOptions options) { - super(options); + this.sparkSession = SparkSessionFactory.getOrCreateSession(options); + this.serializablePipelineOptions = new SerializablePipelineOptions(options); + this.datasets = new HashMap<>(); + this.leaves = new HashSet<>(); + this.broadcastDataSets = new HashMap<>(); + this.encoders = new HashMap<>(); + } + + public SparkSession getSparkSession() { + return sparkSession; + } + + public SerializablePipelineOptions getSerializableOptions() { + return serializablePipelineOptions; + } + + public Encoder encoderOf(Coder coder, Function, Encoder> loadFn) { + return (Encoder) encoders.computeIfAbsent(coder, (Function) loadFn); + } + + @SuppressWarnings("unchecked") // can't be avoided + public Dataset> getDataset(PCollection pCollection) { + Dataset dataset = Preconditions.checkStateNotNull(datasets.get(pCollection)); + // assume that the Dataset is used as an input if retrieved here. So it is not a leaf anymore + leaves.remove(dataset); + return (Dataset>) dataset; + } + + public void putDataset(PCollection pCollection, Dataset> dataset) { + if (!datasets.containsKey(pCollection)) { + datasets.put(pCollection, dataset); + leaves.add(dataset); + } + } + + public void setSideInputDataset( + PCollectionView value, Dataset> set) { + if (!broadcastDataSets.containsKey(value)) { + broadcastDataSets.put(value, set); + } + } + + @SuppressWarnings("unchecked") // can't be avoided + public Dataset getSideInputDataSet(PCollectionView value) { + return (Dataset) Preconditions.checkStateNotNull(broadcastDataSets.get(value)); + } + + /** + * Starts the batch pipeline, streaming is not supported. 
+ * + * @see org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner + */ + public void startPipeline() { + encoders.clear(); + + SparkStructuredStreamingPipelineOptions options = + serializablePipelineOptions.get().as(SparkStructuredStreamingPipelineOptions.class); + int datasetIndex = 0; + for (Dataset dataset : leaves) { + if (options.getTestMode()) { + LOG.debug("**** dataset {} catalyst execution plans ****", ++datasetIndex); + dataset.explain(true); + } + // force evaluation using a dummy foreach action + dataset.foreach((ForeachFunction) t -> {}); + } } - @Override - public void launchStreaming(DataStreamWriter dataStreamWriter) { - try { - dataStreamWriter.start(); - } catch (TimeoutException e) { - throw new RuntimeException("A timeout occurred when running the streaming pipeline", e); + public static void printDatasetContent(Dataset> dataset) { + // cannot use dataset.show because dataset schema is binary so it will print binary + // code. + List> windowedValues = dataset.collectAsList(); + for (WindowedValue windowedValue : windowedValues) { + System.out.println(windowedValue); } } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java deleted file mode 100644 index d0f46ea807c2..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.spark.structuredstreaming.translation.batch; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.IterableCoder; -import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.PaneInfo; -import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; -import org.apache.beam.sdk.transforms.windowing.WindowFn; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.expressions.Aggregator; -import org.joda.time.Instant; -import scala.Tuple2; - -/** An {@link Aggregator} for the Spark Batch Runner. - * The accumulator is a {@code Iterable> because an {@code InputT} can be in multiple windows. So, when accumulating {@code InputT} values, we create one accumulator per input window. - * */ -class AggregatorCombiner - extends Aggregator< - WindowedValue>, - Iterable>, - Iterable>> { - - private final Combine.CombineFn combineFn; - private WindowingStrategy windowingStrategy; - private TimestampCombiner timestampCombiner; - private Coder accumulatorCoder; - private IterableCoder> bufferEncoder; - private IterableCoder> outputCoder; - - public AggregatorCombiner( - Combine.CombineFn combineFn, - WindowingStrategy windowingStrategy, - Coder accumulatorCoder, - Coder outputCoder) { - this.combineFn = combineFn; - this.windowingStrategy = (WindowingStrategy) windowingStrategy; - this.timestampCombiner = windowingStrategy.getTimestampCombiner(); - this.accumulatorCoder = accumulatorCoder; - this.bufferEncoder = - IterableCoder.of( - WindowedValue.FullWindowedValueCoder.of( - accumulatorCoder, windowingStrategy.getWindowFn().windowCoder())); - this.outputCoder = - IterableCoder.of( - WindowedValue.FullWindowedValueCoder.of( - outputCoder, windowingStrategy.getWindowFn().windowCoder())); - } - - @Override - public Iterable> zero() { - return new ArrayList<>(); - } - - private Iterable> createAccumulator(WindowedValue> inputWv) { - // need to create an accumulator because combineFn can modify its input accumulator. 
- AccumT accumulator = combineFn.createAccumulator(); - AccumT accumT = combineFn.addInput(accumulator, inputWv.getValue().getValue()); - return Lists.newArrayList( - WindowedValue.of(accumT, inputWv.getTimestamp(), inputWv.getWindows(), inputWv.getPane())); - } - - @Override - public Iterable> reduce( - Iterable> accumulators, WindowedValue> inputWv) { - return merge(accumulators, createAccumulator(inputWv)); - } - - @Override - public Iterable> merge( - Iterable> accumulators1, - Iterable> accumulators2) { - - // merge the windows of all the accumulators - Iterable> accumulators = Iterables.concat(accumulators1, accumulators2); - Set accumulatorsWindows = collectAccumulatorsWindows(accumulators); - Map windowToMergeResult; - try { - windowToMergeResult = mergeWindows(windowingStrategy, accumulatorsWindows); - } catch (Exception e) { - throw new RuntimeException("Unable to merge accumulators windows", e); - } - - // group accumulators by their merged window - Map>> mergedWindowToAccumulators = new HashMap<>(); - for (WindowedValue accumulatorWv : accumulators) { - // Encode a version of the accumulator if it is in multiple windows. The combineFn is able to - // mutate the accumulator instance and this could lead to incorrect results if the same - // instance is merged across multiple windows so we decode a new instance as needed. This - // prevents issues during merging of accumulators. - byte[] encodedAccumT = null; - if (accumulatorWv.getWindows().size() > 1) { - try { - encodedAccumT = CoderUtils.encodeToByteArray(accumulatorCoder, accumulatorWv.getValue()); - } catch (CoderException e) { - throw new RuntimeException( - String.format( - "Unable to encode accumulator %s with coder %s.", - accumulatorWv.getValue(), accumulatorCoder), - e); - } - } - for (BoundedWindow accumulatorWindow : accumulatorWv.getWindows()) { - W mergedWindowForAccumulator = windowToMergeResult.get(accumulatorWindow); - mergedWindowForAccumulator = - (mergedWindowForAccumulator == null) - ? (W) accumulatorWindow - : mergedWindowForAccumulator; - - // Decode a copy of the accumulator when necessary. 
- AccumT accumT; - if (encodedAccumT != null) { - try { - accumT = CoderUtils.decodeFromByteArray(accumulatorCoder, encodedAccumT); - } catch (CoderException e) { - throw new RuntimeException( - String.format( - "Unable to encode accumulator %s with coder %s.", - accumulatorWv.getValue(), accumulatorCoder), - e); - } - } else { - accumT = accumulatorWv.getValue(); - } - - // we need only the timestamp and the AccumT, we create a tuple - Tuple2 accumAndInstant = - new Tuple2<>( - accumT, - timestampCombiner.assign(mergedWindowForAccumulator, accumulatorWv.getTimestamp())); - if (mergedWindowToAccumulators.get(mergedWindowForAccumulator) == null) { - mergedWindowToAccumulators.put( - mergedWindowForAccumulator, Lists.newArrayList(accumAndInstant)); - } else { - mergedWindowToAccumulators.get(mergedWindowForAccumulator).add(accumAndInstant); - } - } - } - // merge the accumulators for each mergedWindow - List> result = new ArrayList<>(); - for (Map.Entry>> entry : - mergedWindowToAccumulators.entrySet()) { - W mergedWindow = entry.getKey(); - List> accumsAndInstantsForMergedWindow = entry.getValue(); - - // we need to create the first accumulator because combineFn.mergerAccumulators can modify the - // first accumulator - AccumT first = combineFn.createAccumulator(); - Iterable accumulatorsToMerge = - Iterables.concat( - Collections.singleton(first), - accumsAndInstantsForMergedWindow.stream() - .map(x -> x._1()) - .collect(Collectors.toList())); - result.add( - WindowedValue.of( - combineFn.mergeAccumulators(accumulatorsToMerge), - timestampCombiner.combine( - accumsAndInstantsForMergedWindow.stream() - .map(x -> x._2()) - .collect(Collectors.toList())), - mergedWindow, - PaneInfo.NO_FIRING)); - } - return result; - } - - @Override - public Iterable> finish(Iterable> reduction) { - List> result = new ArrayList<>(); - for (WindowedValue windowedValue : reduction) { - result.add(windowedValue.withValue(combineFn.extractOutput(windowedValue.getValue()))); - } - return result; - } - - @Override - public Encoder>> bufferEncoder() { - return EncoderHelpers.fromBeamCoder(bufferEncoder); - } - - @Override - public Encoder>> outputEncoder() { - return EncoderHelpers.fromBeamCoder(outputCoder); - } - - private Set collectAccumulatorsWindows(Iterable> accumulators) { - Set windows = new HashSet<>(); - for (WindowedValue accumulator : accumulators) { - for (BoundedWindow untypedWindow : accumulator.getWindows()) { - @SuppressWarnings("unchecked") - W window = (W) untypedWindow; - windows.add(window); - } - } - return windows; - } - - private Map mergeWindows(WindowingStrategy windowingStrategy, Set windows) - throws Exception { - WindowFn windowFn = windowingStrategy.getWindowFn(); - - if (!windowingStrategy.needsMerge()) { - // Return an empty map, indicating that every window is not merged. 
- return Collections.emptyMap(); - } - - Map windowToMergeResult = new HashMap<>(); - windowFn.mergeWindows(new MergeContextImpl(windowFn, windows, windowToMergeResult)); - return windowToMergeResult; - } - - private class MergeContextImpl extends WindowFn.MergeContext { - - private Set windows; - private Map windowToMergeResult; - - MergeContextImpl(WindowFn windowFn, Set windows, Map windowToMergeResult) { - windowFn.super(); - this.windows = windows; - this.windowToMergeResult = windowToMergeResult; - } - - @Override - public Collection windows() { - return windows; - } - - @Override - public void merge(Collection toBeMerged, W mergeResult) throws Exception { - for (W w : toBeMerged) { - windowToMergeResult.put(w, mergeResult); - } - } - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java new file mode 100644 index 000000000000..45026f9d8bda --- /dev/null +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java @@ -0,0 +1,591 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mutablePairEncoder; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators.peekingIterator; + +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.BiFunction; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.PeekingIterator; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.util.MutablePair; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.PolyNull; +import org.joda.time.Instant; + +public class Aggregators { + + /** + * Creates simple value {@link Aggregator} that is not window aware. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + public static Aggregator value( + CombineFn fn, + Fun1 valueFn, + Encoder accEnc, + Encoder outEnc) { + return new ValueAggregator<>(fn, valueFn, accEnc, outEnc); + } + + /** + * Creates windowed Spark {@link Aggregator} depending on the provided Beam {@link WindowFn}s. + * + *

Specialized implementations are provided for: + *

  • {@link Sessions} + *
  • Non-merging window functions + *
  • Merging window functions + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + public static + Aggregator, ?, Collection>> windowedValue( + CombineFn fn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc) { + if (!windowing.needsMerge()) { + return new NonMergingWindowedAggregator<>(fn, valueFn, windowing, windowEnc, accEnc, outEnc); + } else if (windowing.getWindowFn().getClass().equals(Sessions.class)) { + return new SessionsAggregator<>(fn, valueFn, windowing, (Encoder) windowEnc, accEnc, outEnc); + } + return new MergingWindowedAggregator<>(fn, valueFn, windowing, windowEnc, accEnc, outEnc); + } + + /** + * Simple value {@link Aggregator} that is not window aware. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + private static class ValueAggregator + extends CombineFnAggregator { + + public ValueAggregator( + CombineFn fn, + Fun1 valueFn, + Encoder accEnc, + Encoder outEnc) { + super(fn, valueFn, accEnc, outEnc); + } + + @Override + public AccT zero() { + return emptyAcc(); + } + + @Override + public AccT reduce(AccT buff, InT in) { + return addToAcc(buff, value(in)); + } + + @Override + public AccT merge(AccT b1, AccT b2) { + return mergeAccs(b1, b2); + } + + @Override + public ResT finish(AccT buff) { + return extract(buff); + } + } + + /** + * Specialized windowed Spark {@link Aggregator} for Beam {@link WindowFn}s of type {@link + * Sessions}. The aggregator uses a {@link TreeMap} as buffer to maintain ordering of the {@link + * IntervalWindow}s and merge these more efficiently. + * + *
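Editor's aside, not part of this patch: a small sketch of the IntervalWindow operations the TreeMap-based reduce below relies on, assuming the Beam SDK and joda-time are on the classpath; the class name is illustrative.

import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.joda.time.Duration;
import org.joda.time.Instant;

public class SessionMergeSketch {
  public static void main(String[] args) {
    Instant start = new Instant(0);
    // Two overlapping session windows: [0, 10min) and [5min, 15min).
    IntervalWindow w1 = new IntervalWindow(start, start.plus(Duration.standardMinutes(10)));
    IntervalWindow w2 =
        new IntervalWindow(
            start.plus(Duration.standardMinutes(5)), start.plus(Duration.standardMinutes(15)));

    if (w1.intersects(w2)) {
      // The aggregator would merge the two accumulators here and key them by the spanned window.
      IntervalWindow merged = w1.span(w2);
      System.out.println(merged); // covers [0, 15min)
    }
  }
}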

    For efficiency, this aggregator re-implements {@link + * Sessions#mergeWindows(WindowFn.MergeContext)} to leverage the already sorted buffer. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + private static class SessionsAggregator + extends WindowedAggregator< + ValT, + AccT, + ResT, + InT, + IntervalWindow, + TreeMap>> { + + SessionsAggregator( + CombineFn combineFn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc) { + super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class) TreeMap.class); + checkArgument(windowing.getWindowFn().getClass().equals(Sessions.class)); + } + + @Override + public final TreeMap> zero() { + return new TreeMap<>(); + } + + @Override + @SuppressWarnings("keyfor") + public TreeMap> reduce( + TreeMap> buff, WindowedValue input) { + for (IntervalWindow window : (Collection) input.getWindows()) { + @MonotonicNonNull MutablePair acc = null; + @MonotonicNonNull IntervalWindow first = null, last = null; + // start with window before or equal to new window (if exists) + @Nullable Entry> lower = buff.floorEntry(window); + if (lower != null && window.intersects(lower.getKey())) { + // if intersecting, init accumulator and extend window to span both + acc = lower.getValue(); + window = window.span(lower.getKey()); + first = last = lower.getKey(); + } + // merge following windows in order if they intersect, then stop + for (Entry> entry : + buff.tailMap(window, false).entrySet()) { + MutablePair entryAcc = entry.getValue(); + IntervalWindow entryWindow = entry.getKey(); + if (window.intersects(entryWindow)) { + // extend window and merge accumulators + window = window.span(entryWindow); + acc = acc == null ? entryAcc : mergeAccs(window, acc, entryAcc); + if (first == null) { + // there was no previous (lower) window intersecting the input window + first = last = entryWindow; + } else { + last = entryWindow; + } + } else { + break; // stop, later windows won't intersect either + } + } + if (first != null && last != null) { + // remove entire subset from from first to last after it got merged into acc + buff.navigableKeySet().subSet(first, true, last, true).clear(); + } + // add input and get accumulator for new (potentially merged) window + buff.put(window, addToAcc(window, acc, value(input), input.getTimestamp())); + } + return buff; + } + + @Override + public TreeMap> merge( + TreeMap> b1, + TreeMap> b2) { + if (b1.isEmpty()) { + return b2; + } else if (b2.isEmpty()) { + return b1; + } + // Init new tree map to merge both buffers + TreeMap> res = zero(); + PeekingIterator>> it1 = + peekingIterator(b1.entrySet().iterator()); + PeekingIterator>> it2 = + peekingIterator(b2.entrySet().iterator()); + + @Nullable MutablePair acc = null; + @Nullable IntervalWindow window = null; + while (it1.hasNext() || it2.hasNext()) { + // pick iterator with the smallest window ahead and forward it + Entry> nextMin = + (it1.hasNext() && it2.hasNext()) + ? it1.peek().getKey().compareTo(it2.peek().getKey()) <= 0 ? it1.next() : it2.next() + : it1.hasNext() ? 
it1.next() : it2.next(); + if (window != null && window.intersects(nextMin.getKey())) { + // extend window and merge accumulators if intersecting + window = window.span(nextMin.getKey()); + acc = mergeAccs(window, acc, nextMin.getValue()); + } else { + // store window / accumulator if necessary and continue with next minimum + if (window != null && acc != null) { + res.put(window, acc); + } + acc = nextMin.getValue(); + window = nextMin.getKey(); + } + } + if (window != null && acc != null) { + res.put(window, acc); + } + return res; + } + } + + /** + * Merging windowed Spark {@link Aggregator} using a Map of {@link BoundedWindow}s as aggregation + * buffer. When reducing new input, a windowed accumulator is created for each new window of the + * input that doesn't overlap with existing windows. Otherwise, if the window is known or + * overlaps, the window is extended accordingly and accumulators are merged. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + private static class MergingWindowedAggregator + extends NonMergingWindowedAggregator { + + private final WindowFn windowFn; + + public MergingWindowedAggregator( + CombineFn combineFn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc) { + super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc); + windowFn = (WindowFn) windowing.getWindowFn(); + } + + @Override + protected Map> reduce( + Map> buff, + Collection windows, + ValT value, + Instant timestamp) { + if (buff.isEmpty()) { + // no windows yet to be merged, use the non-merging behavior of super + return super.reduce(buff, windows, value, timestamp); + } + // Merge multiple windows into one target window using the reducer function if the window + // already exists. Otherwise, the input value is added to the accumulator. Merged windows are + // removed from the accumulator map. + Function> accFn = + target -> + (acc, w) -> { + MutablePair accW = buff.remove(w); + return (accW != null) + ? mergeAccs(w, acc, accW) + : addToAcc(w, acc, value, timestamp); + }; + Set unmerged = mergeWindows(buff, ImmutableSet.copyOf(windows), accFn); + if (!unmerged.isEmpty()) { + // remaining windows don't have to be merged + return super.reduce(buff, unmerged, value, timestamp); + } + return buff; + } + + @Override + public Map> merge( + Map> b1, + Map> b2) { + // Merge multiple windows into one target window using the reducer function. Merged windows + // are removed from both accumulator maps + Function> reduceFn = + target -> (acc, w) -> mergeAccs(w, mergeAccs(w, acc, b1.remove(w)), b2.remove(w)); + + Set unmerged = b2.keySet(); + unmerged = mergeWindows(b1, unmerged, reduceFn); + if (!unmerged.isEmpty()) { + // keep only unmerged windows in 2nd accumulator map, continue using "non-merging" merge + b2.keySet().retainAll(unmerged); + return super.merge(b1, b2); + } + return b1; + } + + /** Reduce function to merge multiple windowed accumulator values into one target window. */ + private interface ReduceFn + extends BiFunction, BoundedWindow, MutablePair> {} + + /** + * Attempt to merge windows of accumulator map with additional windows using the reducer + * function. The reducer function must support {@code null} as zero value. + * + * @return The subset of additional windows that don't require a merge. 
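Editor's aside, not part of this patch: a sketch of how a merging WindowFn reports merges through its MergeContext callback, which is the hook this mergeWindows helper builds on; Sessions and the window bounds are purely illustrative.

import java.util.Arrays;
import java.util.Collection;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.joda.time.Duration;
import org.joda.time.Instant;

public class MergeWindowsSketch {
  public static void main(String[] args) throws Exception {
    Sessions sessions = Sessions.withGapDuration(Duration.standardMinutes(10));
    Instant start = new Instant(0);
    Collection<IntervalWindow> windows =
        Arrays.asList(
            new IntervalWindow(start, start.plus(Duration.standardMinutes(10))),
            new IntervalWindow(
                start.plus(Duration.standardMinutes(5)), start.plus(Duration.standardMinutes(15))));

    // Sessions merges every group of overlapping windows and reports each group once.
    sessions.mergeWindows(
        sessions.new MergeContext() {
          @Override
          public Collection<IntervalWindow> windows() {
            return windows;
          }

          @Override
          public void merge(Collection<IntervalWindow> toBeMerged, IntervalWindow mergeResult) {
            // The aggregator would fold the accumulators of toBeMerged into mergeResult here.
            System.out.println(toBeMerged + " -> " + mergeResult);
          }
        });
  }
}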
+ */ + private Set mergeWindows( + Map> buff, + Set newWindows, + Function> reduceFn) { + try { + Set newUnmerged = new HashSet<>(newWindows); + windowFn.mergeWindows( + windowFn.new MergeContext() { + @Override + public Collection windows() { + return Sets.union(buff.keySet(), newWindows); + } + + @Override + public void merge(Collection merges, BoundedWindow target) { + buff.put( + target, merges.stream().reduce(null, reduceFn.apply(target), combiner(target))); + newUnmerged.removeAll(merges); + } + }); + return newUnmerged; + } catch (Exception e) { + throw new RuntimeException("Unable to merge accumulators windows", e); + } + } + } + + /** + * Non-merging windowed Spark {@link Aggregator} using a Map of {@link BoundedWindow}s as + * aggregation buffer. When reducing new input, a windowed accumulator is created for each new + * window of the input. Otherwise, if the window is known, the accumulators are merged. + * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + */ + private static class NonMergingWindowedAggregator + extends WindowedAggregator< + ValT, AccT, ResT, InT, BoundedWindow, Map>> { + + public NonMergingWindowedAggregator( + CombineFn combineFn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc) { + super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class) Map.class); + } + + @Override + public Map> zero() { + return new HashMap<>(); + } + + @Override + public final Map> reduce( + Map> buff, WindowedValue input) { + Collection windows = (Collection) input.getWindows(); + return reduce(buff, windows, value(input), input.getTimestamp()); + } + + protected Map> reduce( + Map> buff, + Collection windows, + ValT value, + Instant timestamp) { + // for each window add the value to the accumulator + for (BoundedWindow window : windows) { + buff.compute(window, (w, acc) -> addToAcc(w, acc, value, timestamp)); + } + return buff; + } + + @Override + public Map> merge( + Map> b1, + Map> b2) { + if (b1.isEmpty()) { + return b2; + } else if (b2.isEmpty()) { + return b1; + } + if (b2.size() > b1.size()) { + return merge(b2, b1); + } + // merge entries of (smaller) 2nd agg buffer map into first by merging the accumulators + b2.forEach((w, acc) -> b1.merge(w, acc, combiner(w))); + return b1; + } + } + + /** + * Abstract base of a Spark {@link Aggregator} on {@link WindowedValue}s using a Map of {@link W} + * as aggregation buffer. 
+ * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} / {@link Aggregator} result type + * @param {@link Aggregator} input type + * @param bounded window type + * @param aggregation buffer {@link W} + */ + private abstract static class WindowedAggregator< + ValT, + AccT, + ResT, + InT, + W extends @NonNull BoundedWindow, + MapT extends Map>> + extends CombineFnAggregator< + ValT, AccT, ResT, WindowedValue, MapT, Collection>> { + private final TimestampCombiner tsCombiner; + + public WindowedAggregator( + CombineFn combineFn, + Fun1, ValT> valueFn, + WindowingStrategy windowing, + Encoder windowEnc, + Encoder accEnc, + Encoder> outEnc, + Class clazz) { + super( + combineFn, + valueFn, + mapEncoder(windowEnc, mutablePairEncoder(encoderOf(Instant.class), accEnc), clazz), + collectionEncoder(outEnc)); + tsCombiner = windowing.getTimestampCombiner(); + } + + protected final Instant resolveTimestamp(BoundedWindow w, Instant t1, Instant t2) { + return tsCombiner.merge(w, t1, t2); + } + + /** Init accumulator with initial input value and timestamp. */ + protected final MutablePair initAcc(ValT value, Instant timestamp) { + return new MutablePair<>(timestamp, addToAcc(emptyAcc(), value)); + } + + /** Merge timestamped accumulators. */ + protected final > @PolyNull T mergeAccs( + W window, @PolyNull T a1, @PolyNull T a2) { + if (a1 == null || a2 == null) { + return a1 == null ? a2 : a1; + } + return (T) a1.update(resolveTimestamp(window, a1._1, a2._1), mergeAccs(a1._2, a2._2)); + } + + @SuppressWarnings("nullness") // may return null + protected BinaryOperator> combiner(W target) { + return (a1, a2) -> mergeAccs(target, a1, a2); + } + + /** Add an input value to a nullable accumulator. */ + protected final MutablePair addToAcc( + W window, @Nullable MutablePair acc, ValT val, Instant ts) { + if (acc == null) { + return initAcc(val, ts); + } + return acc.update(resolveTimestamp(window, acc._1, ts), addToAcc(acc._2, val)); + } + + @Override + @SuppressWarnings("nullness") // entries are non null + public final Collection> finish(MapT buffer) { + return Collections2.transform(buffer.entrySet(), this::windowedValue); + } + + private WindowedValue windowedValue(Entry> e) { + return WindowedValue.of(extract(e.getValue()._2), e.getValue()._1, e.getKey(), NO_FIRING); + } + } + + /** + * Abstract base of Spark {@link Aggregator}s using a Beam {@link CombineFn}. 
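Editor's aside, not part of this patch: the CombineFn calls this base class delegates to, shown with a toy sum CombineFn (illustrative, not taken from the patch). They line up with Spark's Aggregator lifecycle: zero, reduce, merge, finish.

import java.util.Arrays;
import org.apache.beam.sdk.transforms.Combine.CombineFn;

public class CombineFnLifecycleSketch {
  // Trivial sum CombineFn; the accumulator is a single-element array so it can be mutated in place.
  static class SumFn extends CombineFn<Integer, int[], Integer> {
    @Override
    public int[] createAccumulator() {
      return new int[] {0};
    }

    @Override
    public int[] addInput(int[] acc, Integer input) {
      acc[0] += input;
      return acc;
    }

    @Override
    public int[] mergeAccumulators(Iterable<int[]> accs) {
      int[] merged = createAccumulator();
      for (int[] acc : accs) {
        merged[0] += acc[0];
      }
      return merged;
    }

    @Override
    public Integer extractOutput(int[] acc) {
      return acc[0];
    }
  }

  public static void main(String[] args) {
    SumFn fn = new SumFn();
    int[] a1 = fn.addInput(fn.createAccumulator(), 1); // zero() + reduce()
    int[] a2 = fn.addInput(fn.createAccumulator(), 2);
    int[] merged = fn.mergeAccumulators(Arrays.asList(a1, a2)); // merge()
    System.out.println(fn.extractOutput(merged)); // finish() -> 3
  }
}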
+ * + * @param {@link CombineFn} input type + * @param {@link CombineFn} accumulator type + * @param {@link CombineFn} result type + * @param {@link Aggregator} input type + * @param {@link Aggregator} buffer type + * @param {@link Aggregator} output type + */ + private abstract static class CombineFnAggregator + extends Aggregator { + private final CombineFn fn; + private final Fun1 valueFn; + private final Encoder bufferEnc; + private final Encoder outputEnc; + + public CombineFnAggregator( + CombineFn fn, + Fun1 valueFn, + Encoder bufferEnc, + Encoder outputEnc) { + this.fn = fn; + this.valueFn = valueFn; + this.bufferEnc = bufferEnc; + this.outputEnc = outputEnc; + } + + protected final ValT value(InT in) { + return valueFn.apply(in); + } + + protected final AccT emptyAcc() { + return fn.createAccumulator(); + } + + protected final AccT mergeAccs(AccT a1, AccT a2) { + return fn.mergeAccumulators(ImmutableList.of(a1, a2)); + } + + protected final AccT addToAcc(AccT acc, ValT val) { + return fn.addInput(acc, val); + } + + protected final ResT extract(AccT acc) { + return fn.extractOutput(acc); + } + + @Override + public Encoder bufferEncoder() { + return bufferEnc; + } + + @Override + public Encoder outputEncoder() { + return outputEnc; + } + } +} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java new file mode 100644 index 000000000000..5bc017134e91 --- /dev/null +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static scala.collection.Iterator.single; + +import java.util.Collection; +import java.util.Map; +import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.CombineFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import scala.collection.Iterator; + +/** + * Translator for {@link Combine.Globally} using a Spark {@link Aggregator}. + * + *

    To minimize the amount of data shuffled, this first reduces the data per partition using + * {@link Aggregator#reduce}, gathers the partial results (using {@code coalesce(1)}) and finally + * merges these using {@link Aggregator#merge}. + * + *
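Editor's aside, not part of this patch: a minimal, self-contained sketch of the same two-phase shape (reduce each partition locally, then coalesce the few partial results and merge them), written against plain Spark SQL Java APIs rather than this runner's helpers; class and app names are illustrative.

import java.util.Collections;
import org.apache.spark.api.java.function.MapPartitionsFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class TwoPhaseCombineSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate();
    Dataset<Long> nums = spark.range(1, 101);

    // Phase 1: fold every partition into a single partial sum, so only one value
    // per partition leaves the executors (mirrors Aggregator#reduce).
    Dataset<Long> partials =
        nums.mapPartitions(
            (MapPartitionsFunction<Long, Long>) it -> {
              long sum = 0;
              while (it.hasNext()) {
                sum += it.next();
              }
              return Collections.singletonList(sum).iterator();
            },
            Encoders.LONG());

    // Phase 2: gather the partial sums into one partition and merge them
    // (mirrors Aggregator#merge followed by Aggregator#finish).
    Dataset<Long> total =
        partials.coalesce(1).mapPartitions(
            (MapPartitionsFunction<Long, Long>) it -> {
              long sum = 0;
              while (it.hasNext()) {
                sum += it.next();
              }
              return Collections.singletonList(sum).iterator();
            },
            Encoders.LONG());

    System.out.println(total.collectAsList()); // [5050]
    spark.stop();
  }
}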

    TODOs: + *

  • any missing features? + */ +class CombineGloballyTranslatorBatch + extends TransformTranslator, PCollection, Combine.Globally> { + + @Override + public void translate(Combine.Globally transform, Context cxt) { + WindowingStrategy windowing = cxt.getInput().getWindowingStrategy(); + CombineFn combineFn = (CombineFn) transform.getFn(); + + Coder inputCoder = cxt.getInput().getCoder(); + Coder outputCoder = cxt.getOutput().getCoder(); + Coder accumCoder = accumulatorCoder(combineFn, inputCoder, cxt); + + Encoder outEnc = cxt.encoderOf(outputCoder); + Encoder accEnc = cxt.encoderOf(accumCoder); + Encoder> wvOutEnc = cxt.windowedEncoder(outEnc); + + Dataset> dataset = cxt.getDataset(cxt.getInput()); + + final Dataset> result; + if (GroupByKeyHelpers.eligibleForGlobalGroupBy(windowing, true)) { + Aggregator agg = Aggregators.value(combineFn, v -> v, accEnc, outEnc); + + // Drop window and restore afterwards, produces single global aggregation result + result = aggregate(dataset, agg, value(), windowedValue(), wvOutEnc); + } else { + Aggregator, ?, Collection>> agg = + Aggregators.windowedValue( + combineFn, value(), windowing, cxt.windowEncoder(), accEnc, wvOutEnc); + + // Produces aggregation result per window + result = + aggregate(dataset, agg, v -> v, fun1(out -> ScalaInterop.scalaIterator(out)), wvOutEnc); + } + cxt.putDataset(cxt.getOutput(), result); + } + + /** + * Aggregate dataset globally without using key. + * + *

    There is no global, typed version of {@link Dataset#agg(Map)} on datasets. This reduces all + * partitions first, and then merges them to receive the final result. + */ + private static Dataset> aggregate( + Dataset> ds, + Aggregator agg, + Fun1, AggInT> valueFn, + Fun1>> finishFn, + Encoder> enc) { + // reduce partition using aggregator + Fun1>, Iterator> reduce = + fun1(it -> single(it.map(valueFn).foldLeft(agg.zero(), agg::reduce))); + // merge reduced partitions using aggregator + Fun1, Iterator>> merge = + fun1(it -> finishFn.apply(agg.finish(it.hasNext() ? it.reduce(agg::merge) : agg.zero()))); + + return ds.mapPartitions(reduce, agg.bufferEncoder()).coalesce(1).mapPartitions(merge, enc); + } + + private Coder accumulatorCoder( + CombineFn fn, Coder valueCoder, Context cxt) { + try { + return fn.getAccumulatorCoder(cxt.getInput().getPipeline().getCoderRegistry(), valueCoder); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + } + + private static Fun1>> windowedValue() { + return v -> single(WindowedValue.valueInGlobalWindow(v)); + } +} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java index 2b0cf8be9955..f990cd114f98 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java @@ -17,98 +17,141 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; + +import java.util.Collection; import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; 
import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; import scala.Tuple2; +import scala.collection.TraversableOnce; -@SuppressWarnings({ - "rawtypes" // TODO(https://github.com/apache/beam/issues/20447) -}) -class CombinePerKeyTranslatorBatch - implements TransformTranslator< - PTransform>, PCollection>>> { +/** + * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with a Spark {@link + * Aggregator}. + * + *

      + *
    • When using the default global window, window information is dropped and restored after the + * aggregation. + *
    • For non-merging windows, windows are exploded and moved into a composite key for better + * distribution. After the aggregation, windowed values are restored from the composite key. + *
    • All other cases use an aggregator on windowed values that is optimized for the current + * windowing strategy. + *
    + * + * TODOs: + *
  • combine with context (CombineFnWithContext)? + *
  • combine with sideInputs? + *
  • other there other missing features? + */ +class CombinePerKeyTranslatorBatch + extends TransformTranslator< + PCollection>, PCollection>, Combine.PerKey> { @Override - public void translateTransform( - PTransform>, PCollection>> transform, - AbstractTranslationContext context) { + public void translate(Combine.PerKey transform, Context cxt) { + WindowingStrategy windowing = cxt.getInput().getWindowingStrategy(); + CombineFn combineFn = (CombineFn) transform.getFn(); + + KvCoder inputCoder = (KvCoder) cxt.getInput().getCoder(); + KvCoder outputCoder = (KvCoder) cxt.getOutput().getCoder(); + + Encoder keyEnc = cxt.keyEncoderOf(inputCoder); + Encoder> inputEnc = cxt.encoderOf(inputCoder); + Encoder>> wvOutputEnc = cxt.windowedEncoder(outputCoder); + Encoder accumEnc = accumEncoder(combineFn, inputCoder.getValueCoder(), cxt); + + final Dataset>> result; + + boolean globalGroupBy = eligibleForGlobalGroupBy(windowing, true); + boolean groupByWindow = eligibleForGroupByWindow(windowing, true); - Combine.PerKey combineTransform = (Combine.PerKey) transform; - @SuppressWarnings("unchecked") - final PCollection> input = (PCollection>) context.getInput(); - @SuppressWarnings("unchecked") - final PCollection> output = (PCollection>) context.getOutput(); - @SuppressWarnings("unchecked") - final Combine.CombineFn combineFn = - (Combine.CombineFn) combineTransform.getFn(); - WindowingStrategy windowingStrategy = input.getWindowingStrategy(); + if (globalGroupBy || groupByWindow) { + Aggregator, ?, OutT> valueAgg = + Aggregators.value(combineFn, KV::getValue, accumEnc, cxt.valueEncoderOf(outputCoder)); - Dataset>> inputDataset = context.getDataset(input); + if (globalGroupBy) { + // Drop window and group by key globally to run the aggregation (combineFn), afterwards the + // global window is restored + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .mapValues(value(), inputEnc) + .agg(valueAgg.toColumn()) + .map(globalKV(), wvOutputEnc); + } else { + Encoder> windowedKeyEnc = + cxt.tupleEncoder(cxt.windowEncoder(), keyEnc); - KvCoder inputCoder = (KvCoder) input.getCoder(); - Coder keyCoder = inputCoder.getKeyCoder(); - KvCoder outputKVCoder = (KvCoder) output.getCoder(); - Coder outputCoder = outputKVCoder.getValueCoder(); + // Group by window and key to run the aggregation (combineFn) + result = + cxt.getDataset(cxt.getInput()) + .flatMap(explodeWindowedKey(value()), cxt.tupleEncoder(windowedKeyEnc, inputEnc)) + .groupByKey(fun1(Tuple2::_1), windowedKeyEnc) + .mapValues(fun1(Tuple2::_2), inputEnc) + .agg(valueAgg.toColumn()) + .map(windowedKV(), wvOutputEnc); + } + } else { + // Optimized aggregator for non-merging and session window functions, all others depend on + // windowFn.mergeWindows + Aggregator>, ?, Collection>> aggregator = + Aggregators.windowedValue( + combineFn, + valueValue(), + windowing, + cxt.windowEncoder(), + accumEnc, + cxt.windowedEncoder(outputCoder.getValueCoder())); + result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .agg(aggregator.toColumn()) + .flatMap(explodeWindows(), wvOutputEnc); + } + + cxt.putDataset(cxt.getOutput(), result); + } + + private static + Fun1>>, TraversableOnce>>> + explodeWindows() { + return t -> + ScalaInterop.scalaIterator(t._2).map(wv -> wv.withValue(KV.of(t._1, wv.getValue()))); + } - KeyValueGroupedDataset>> groupedDataset = - inputDataset.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder)); + private static Fun1, WindowedValue>> globalKV() { + return t -> 
WindowedValue.valueInGlobalWindow(KV.of(t._1, t._2)); + } - Coder accumulatorCoder = null; + private Encoder accumEncoder( + CombineFn fn, Coder valueCoder, Context cxt) { try { - accumulatorCoder = - combineFn.getAccumulatorCoder( - input.getPipeline().getCoderRegistry(), inputCoder.getValueCoder()); + CoderRegistry registry = cxt.getInput().getPipeline().getCoderRegistry(); + return cxt.encoderOf(fn.getAccumulatorCoder(registry, valueCoder)); } catch (CannotProvideCoderException e) { throw new RuntimeException(e); } - - Dataset>>> combinedDataset = - groupedDataset.agg( - new AggregatorCombiner( - combineFn, windowingStrategy, accumulatorCoder, outputCoder) - .toColumn()); - - // expand the list into separate elements and put the key back into the elements - WindowedValue.WindowedValueCoder> wvCoder = - WindowedValue.FullWindowedValueCoder.of( - outputKVCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - Dataset>> outputDataset = - combinedDataset.flatMap( - (FlatMapFunction< - Tuple2>>, WindowedValue>>) - tuple2 -> { - K key = tuple2._1(); - Iterable> windowedValues = tuple2._2(); - List>> result = new ArrayList<>(); - for (WindowedValue windowedValue : windowedValues) { - KV kv = KV.of(key, windowedValue.getValue()); - result.add( - WindowedValue.of( - kv, - windowedValue.getTimestamp(), - windowedValue.getWindows(), - windowedValue.getPane())); - } - return result.iterator(); - }, - EncoderHelpers.fromBeamCoder(wvCoder)); - context.putDataset(output, outputDataset); } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java index ae1eeced3281..271512920239 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java @@ -19,39 +19,25 @@ import java.io.IOException; import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.sdk.runners.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.spark.sql.Dataset; class CreatePCollectionViewTranslatorBatch - implements TransformTranslator, PCollection>> { + extends TransformTranslator< + PCollection, PCollection, View.CreatePCollectionView> { @Override - public void translateTransform( - PTransform, PCollection> transform, - AbstractTranslationContext context) { + public void translate(View.CreatePCollectionView transform, Context context) { Dataset> inputDataSet = context.getDataset(context.getInput()); - @SuppressWarnings("unchecked") - AppliedPTransform< - PCollection, - PCollection, - PTransform, PCollection>> - application = - (AppliedPTransform< - PCollection, - PCollection, - PTransform, PCollection>>) - context.getCurrentTransform(); PCollectionView input; try { - input = 
CreatePCollectionViewTranslation.getView(application); + input = CreatePCollectionViewTranslation.getView(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java deleted file mode 100644 index 46bde96c30cb..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.batch; - -import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION; -import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM; -import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS; -import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; - -import java.io.IOException; -import java.io.Serializable; -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.core.serialization.Base64Serializer; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SchemaHelpers; -import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.connector.catalog.SupportsRead; -import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableCapability; -import org.apache.spark.sql.connector.catalog.TableProvider; -import org.apache.spark.sql.connector.expressions.Transform; -import org.apache.spark.sql.connector.read.Batch; -import org.apache.spark.sql.connector.read.InputPartition; -import org.apache.spark.sql.connector.read.PartitionReader; -import org.apache.spark.sql.connector.read.PartitionReaderFactory; -import org.apache.spark.sql.connector.read.Scan; -import org.apache.spark.sql.connector.read.ScanBuilder; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * 
Spark DataSourceV2 API was removed in Spark3. This is a Beam source wrapper using the new spark 3 - * source API. - */ -public class DatasetSourceBatch implements TableProvider { - - private static final StructType BINARY_SCHEMA = SchemaHelpers.binarySchema(); - - public DatasetSourceBatch() {} - - @Override - public StructType inferSchema(CaseInsensitiveStringMap options) { - return BINARY_SCHEMA; - } - - @Override - public boolean supportsExternalMetadata() { - return true; - } - - @Override - public Table getTable( - StructType schema, Transform[] partitioning, Map properties) { - return new DatasetSourceBatchTable(); - } - - private static class DatasetSourceBatchTable implements SupportsRead { - - @Override - public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { - return new ScanBuilder() { - - @Override - public Scan build() { - return new Scan() { // scan for Batch reading - - @Override - public StructType readSchema() { - return BINARY_SCHEMA; - } - - @Override - public Batch toBatch() { - return new BeamBatch<>(options); - } - }; - } - }; - } - - @Override - public String name() { - return "BeamSource"; - } - - @Override - public StructType schema() { - return BINARY_SCHEMA; - } - - @Override - public Set capabilities() { - final ImmutableSet capabilities = - ImmutableSet.of(TableCapability.BATCH_READ); - return capabilities; - } - - private static class BeamBatch implements Batch, Serializable { - - private final int numPartitions; - private final BoundedSource source; - private final SerializablePipelineOptions serializablePipelineOptions; - - private BeamBatch(CaseInsensitiveStringMap options) { - if (Strings.isNullOrEmpty(options.get(BEAM_SOURCE_OPTION))) { - throw new RuntimeException("Beam source was not set in DataSource options"); - } - this.source = - Base64Serializer.deserializeUnchecked( - options.get(BEAM_SOURCE_OPTION), BoundedSource.class); - - if (Strings.isNullOrEmpty(DEFAULT_PARALLELISM)) { - throw new RuntimeException("Spark default parallelism was not set in DataSource options"); - } - this.numPartitions = Integer.parseInt(options.get(DEFAULT_PARALLELISM)); - checkArgument(numPartitions > 0, "Number of partitions must be greater than zero."); - - if (Strings.isNullOrEmpty(options.get(PIPELINE_OPTIONS))) { - throw new RuntimeException("Beam pipelineOptions were not set in DataSource options"); - } - this.serializablePipelineOptions = - new SerializablePipelineOptions(options.get(PIPELINE_OPTIONS)); - } - - @Override - public InputPartition[] planInputPartitions() { - PipelineOptions options = serializablePipelineOptions.get(); - long desiredSizeBytes; - - try { - desiredSizeBytes = source.getEstimatedSizeBytes(options) / numPartitions; - List> splits = source.split(desiredSizeBytes, options); - InputPartition[] result = new InputPartition[splits.size()]; - int i = 0; - for (BoundedSource split : splits) { - result[i++] = new BeamInputPartition<>(split); - } - return result; - } catch (Exception e) { - throw new RuntimeException( - "Error in splitting BoundedSource " + source.getClass().getCanonicalName(), e); - } - } - - @Override - public PartitionReaderFactory createReaderFactory() { - return new PartitionReaderFactory() { - - @Override - public PartitionReader createReader(InputPartition partition) { - return new BeamPartitionReader( - ((BeamInputPartition) partition).getSource(), serializablePipelineOptions); - } - }; - } - - private static class BeamInputPartition implements InputPartition { - - private final BoundedSource source; - - private 
BeamInputPartition(BoundedSource source) { - this.source = source; - } - - public BoundedSource getSource() { - return source; - } - } - - private static class BeamPartitionReader implements PartitionReader { - - private final BoundedSource source; - private final BoundedSource.BoundedReader reader; - private boolean started; - private boolean closed; - - BeamPartitionReader( - BoundedSource source, SerializablePipelineOptions serializablePipelineOptions) { - this.started = false; - this.closed = false; - this.source = source; - // reader is not serializable so lazy initialize it - try { - reader = - source.createReader(serializablePipelineOptions.get().as(PipelineOptions.class)); - } catch (IOException e) { - throw new RuntimeException("Error creating BoundedReader ", e); - } - } - - @Override - public boolean next() throws IOException { - if (!started) { - started = true; - return reader.start(); - } else { - return !closed && reader.advance(); - } - } - - @Override - public InternalRow get() { - WindowedValue windowedValue = - WindowedValue.timestampedValueInGlobalWindow( - reader.getCurrent(), reader.getCurrentTimestamp()); - return RowHelpers.storeWindowedValueInRow(windowedValue, source.getOutputCoder()); - } - - @Override - public void close() throws IOException { - closed = true; - reader.close(); - } - } - } - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java deleted file mode 100644 index 42a809fdd970..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.spark.structuredstreaming.translation.batch; - -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import org.apache.beam.runners.core.DoFnRunner; -import org.apache.beam.runners.core.DoFnRunners; -import org.apache.beam.runners.core.construction.SerializablePipelineOptions; -import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator; -import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast; -import org.apache.beam.runners.spark.structuredstreaming.translation.utils.CachedSideInputReader; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFnSchemaInformation; -import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.LinkedListMultimap; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; -import org.apache.spark.api.java.function.MapPartitionsFunction; -import scala.Tuple2; - -/** - * Encapsulates a {@link DoFn} inside a Spark {@link - * org.apache.spark.api.java.function.MapPartitionsFunction}. - * - *

    We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index and must tag - * all outputs with the output number. Afterwards a filter will filter out those elements that are - * not to be in a specific output. - */ -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class DoFnFunction - implements MapPartitionsFunction, Tuple2, WindowedValue>> { - - private final MetricsContainerStepMapAccumulator metricsAccum; - private final String stepName; - private final DoFn doFn; - private transient boolean wasSetupCalled; - private final WindowingStrategy windowingStrategy; - private final Map, WindowingStrategy> sideInputs; - private final SerializablePipelineOptions serializableOptions; - private final List> additionalOutputTags; - private final TupleTag mainOutputTag; - private final Coder inputCoder; - private final Map, Coder> outputCoderMap; - private final SideInputBroadcast broadcastStateData; - private DoFnSchemaInformation doFnSchemaInformation; - private Map> sideInputMapping; - - public DoFnFunction( - MetricsContainerStepMapAccumulator metricsAccum, - String stepName, - DoFn doFn, - WindowingStrategy windowingStrategy, - Map, WindowingStrategy> sideInputs, - SerializablePipelineOptions serializableOptions, - List> additionalOutputTags, - TupleTag mainOutputTag, - Coder inputCoder, - Map, Coder> outputCoderMap, - SideInputBroadcast broadcastStateData, - DoFnSchemaInformation doFnSchemaInformation, - Map> sideInputMapping) { - this.metricsAccum = metricsAccum; - this.stepName = stepName; - this.doFn = doFn; - this.windowingStrategy = windowingStrategy; - this.sideInputs = sideInputs; - this.serializableOptions = serializableOptions; - this.additionalOutputTags = additionalOutputTags; - this.mainOutputTag = mainOutputTag; - this.inputCoder = inputCoder; - this.outputCoderMap = outputCoderMap; - this.broadcastStateData = broadcastStateData; - this.doFnSchemaInformation = doFnSchemaInformation; - this.sideInputMapping = sideInputMapping; - } - - @Override - public Iterator, WindowedValue>> call(Iterator> iter) - throws Exception { - if (!wasSetupCalled && iter.hasNext()) { - DoFnInvokers.tryInvokeSetupFor(doFn, serializableOptions.get()); - wasSetupCalled = true; - } - - DoFnOutputManager outputManager = new DoFnOutputManager(); - - DoFnRunner doFnRunner = - DoFnRunners.simpleRunner( - serializableOptions.get(), - doFn, - CachedSideInputReader.of(new SparkSideInputReader(sideInputs, broadcastStateData)), - outputManager, - mainOutputTag, - additionalOutputTags, - new NoOpStepContext(), - inputCoder, - outputCoderMap, - windowingStrategy, - doFnSchemaInformation, - sideInputMapping); - - DoFnRunnerWithMetrics doFnRunnerWithMetrics = - new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum); - - return new ProcessContext<>( - doFn, doFnRunnerWithMetrics, outputManager, Collections.emptyIterator()) - .processPartition(iter) - .iterator(); - } - - private class DoFnOutputManager - implements ProcessContext.ProcessOutputManager, WindowedValue>> { - - private final Multimap, WindowedValue> outputs = LinkedListMultimap.create(); - - @Override - public void clear() { - outputs.clear(); - } - - @Override - public Iterator, WindowedValue>> iterator() { - Iterator, WindowedValue>> entryIter = outputs.entries().iterator(); - return Iterators.transform(entryIter, this.entryToTupleFn()); - } - - private Function, Tuple2> entryToTupleFn() { - return en -> new Tuple2<>(en.getKey(), en.getValue()); - } - - @Override - 
public synchronized void output(TupleTag tag, WindowedValue output) { - outputs.put(tag, output); - } - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java new file mode 100644 index 000000000000..c02e07319af7 --- /dev/null +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toCollection; +import static java.util.stream.Collectors.toMap; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists.newArrayListWithCapacity; + +import java.io.Serializable; +import java.util.ArrayDeque; +import java.util.Collection; +import java.util.Deque; +import java.util.List; +import java.util.Map; +import org.apache.beam.runners.core.DoFnRunner; +import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.DoFnRunners.OutputManager; +import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext; +import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader; +import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.CachedSideInputReader; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun2; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; +import org.apache.spark.api.java.function.MapPartitionsFunction; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.collection.Iterator; + +/** + * Encapsulates a {@link DoFn} inside a Spark {@link + * org.apache.spark.api.java.function.MapPartitionsFunction}. + */ +class DoFnMapPartitionsFactory implements Serializable { + private final String stepName; + + private final DoFn doFn; + private final DoFnSchemaInformation doFnSchema; + private final SerializablePipelineOptions options; + + private final Coder coder; + private final WindowingStrategy windowingStrategy; + private final TupleTag mainOutput; + private final List> additionalOutputs; + private final Map, Coder> outputCoders; + + private final Map> sideInputs; + private final Map, WindowingStrategy> sideInputWindows; + private final SideInputBroadcast broadcastStateData; + + DoFnMapPartitionsFactory( + String stepName, + DoFn doFn, + DoFnSchemaInformation doFnSchema, + SerializablePipelineOptions options, + PCollection input, + TupleTag mainOutput, + Map, PCollection> outputs, + Map> sideInputs, + SideInputBroadcast broadcastStateData) { + this.stepName = stepName; + this.doFn = doFn; + this.doFnSchema = doFnSchema; + this.options = options; + this.coder = input.getCoder(); + this.windowingStrategy = input.getWindowingStrategy(); + this.mainOutput = mainOutput; + this.additionalOutputs = additionalOutputs(outputs, mainOutput); + this.outputCoders = outputCoders(outputs); + this.sideInputs = sideInputs; + this.sideInputWindows = sideInputWindows(sideInputs.values()); + this.broadcastStateData = broadcastStateData; + } + + /** Create the {@link MapPartitionsFunction} using the provided output function. */ + Fun1>, Iterator> create( + Fun2, WindowedValue, OutputT> outputFn) { + return it -> + it.hasNext() + ? scalaIterator(new DoFnPartitionIt<>(outputFn, it)) + : (Iterator) Iterator.empty(); + } + + // FIXME Add support for TimerInternals.TimerData + /** + * Partition iterator that lazily processes each element from the (input) iterator on demand + * producing zero, one or more output elements as output (via an internal buffer). + * + *

    When initializing the iterator for a partition {@code setup} followed by {@code startBundle} + * is called. + */ + private class DoFnPartitionIt extends AbstractIterator { + private final Deque buffer; + private final DoFnRunner doFnRunner; + private final Iterator> partitionIt; + + private boolean isBundleFinished; + + DoFnPartitionIt( + Fun2, WindowedValue, OutputT> outputFn, + Iterator> partitionIt) { + this.buffer = new ArrayDeque<>(); + this.doFnRunner = metricsRunner(simpleRunner(outputFn, buffer)); + this.partitionIt = partitionIt; + // Before starting to iterate over the partition, invoke setup and then startBundle + DoFnInvokers.tryInvokeSetupFor(doFn, options.get()); + try { + doFnRunner.startBundle(); + } catch (RuntimeException re) { + DoFnInvokers.invokerFor(doFn).invokeTeardown(); + throw re; + } + } + + @Override + protected OutputT computeNext() { + try { + while (true) { + if (!buffer.isEmpty()) { + return buffer.remove(); + } + if (partitionIt.hasNext()) { + // grab the next element and process it. + doFnRunner.processElement((WindowedValue) partitionIt.next()); + } else { + if (!isBundleFinished) { + isBundleFinished = true; + doFnRunner.finishBundle(); + continue; // finishBundle can produce more output + } + DoFnInvokers.invokerFor(doFn).invokeTeardown(); + return endOfData(); + } + } + } catch (RuntimeException re) { + DoFnInvokers.invokerFor(doFn).invokeTeardown(); + throw re; + } + } + } + + private DoFnRunner simpleRunner( + Fun2, WindowedValue, OutputT> outputFn, Deque buffer) { + OutputManager outputManager = + new OutputManager() { + @Override + public void output(TupleTag tag, WindowedValue output) { + buffer.add(outputFn.apply(tag, output)); + } + }; + SideInputReader sideInputReader = + CachedSideInputReader.of(new SparkSideInputReader(sideInputWindows, broadcastStateData)); + return DoFnRunners.simpleRunner( + options.get(), + doFn, + sideInputReader, + outputManager, + mainOutput, + additionalOutputs, + new NoOpStepContext(), + coder, + outputCoders, + windowingStrategy, + doFnSchema, + sideInputs); + } + + private DoFnRunner metricsRunner(DoFnRunner runner) { + return new DoFnRunnerWithMetrics<>(stepName, runner, MetricsAccumulator.getInstance()); + } + + private static Map, WindowingStrategy> sideInputWindows( + Collection> views) { + return views.stream().collect(toMap(identity(), DoFnMapPartitionsFactory::windowingStrategy)); + } + + private static WindowingStrategy windowingStrategy(PCollectionView view) { + PCollection pc = view.getPCollection(); + if (pc == null) { + throw new IllegalStateException("PCollection not available for " + view); + } + return pc.getWindowingStrategy(); + } + + private static List> additionalOutputs( + Map, PCollection> outputs, TupleTag mainOutput) { + return outputs.keySet().stream() + .filter(t -> !t.equals(mainOutput)) + .collect(toCollection(() -> newArrayListWithCapacity(outputs.size() - 1))); + } + + private static Map, Coder> outputCoders(Map, PCollection> outputs) { + return outputs.entrySet().stream() + .collect(toMap(Map.Entry::getKey, e -> e.getValue().getCoder())); + } +} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java index db361f7753e1..eb0713049984 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java +++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java @@ -17,49 +17,47 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; -import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; import java.util.Collection; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; +import java.util.Iterator; import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; -import org.apache.beam.sdk.values.PValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) class FlattenTranslatorBatch - implements TransformTranslator, PCollection>> { + extends TransformTranslator, PCollection, Flatten.PCollections> { @Override - public void translateTransform( - PTransform, PCollection> transform, - AbstractTranslationContext context) { - Collection> pcollectionList = context.getInputs().values(); - Dataset> result = null; - if (pcollectionList.isEmpty()) { - result = context.emptyDataset(); - } else { - for (PValue pValue : pcollectionList) { - checkArgument( - pValue instanceof PCollection, - "Got non-PCollection input to flatten: %s of type %s", - pValue, - pValue.getClass().getSimpleName()); - @SuppressWarnings("unchecked") - PCollection pCollection = (PCollection) pValue; - Dataset> current = context.getDataset(pCollection); - if (result == null) { - result = current; - } else { - result = result.union(current); - } + public void translate(Flatten.PCollections transform, Context cxt) { + Collection> pCollections = cxt.getInputs().values(); + Coder outputCoder = cxt.getOutput().getCoder(); + Encoder> outputEnc = + cxt.windowedEncoder(outputCoder, windowCoder(cxt.getOutput())); + + Dataset> result; + Iterator> pcIt = (Iterator) pCollections.iterator(); + if (pcIt.hasNext()) { + result = getDataset(pcIt.next(), outputCoder, outputEnc, cxt); + while (pcIt.hasNext()) { + result = result.union(getDataset(pcIt.next(), outputCoder, outputEnc, cxt)); } + } else { + result = cxt.createDataset(ImmutableList.of(), outputEnc); } - context.putDataset(context.getOutput(), result); + cxt.putDataset(cxt.getOutput(), result); + } + + private Dataset> getDataset( + PCollection pc, Coder coder, Encoder> enc, Context cxt) { + Dataset> current = cxt.getDataset(pc); + // if coders don't match, map using identity function to replace encoder + return pc.getCoder().equals(coder) ? 
current : current.map(fun1(v -> v), enc); } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java new file mode 100644 index 000000000000..28ab07114c6a --- /dev/null +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.sdk.transforms.windowing.TimestampCombiner.END_OF_WINDOW; + +import java.util.Collection; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop; +import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowingStrategy; +import scala.Tuple2; +import scala.collection.TraversableOnce; + +/** + * Package private helpers to support translating grouping transforms using `groupByKey` such as + * {@link GroupByKeyTranslatorBatch} or {@link CombinePerKeyTranslatorBatch}. + */ +class GroupByKeyHelpers { + + private GroupByKeyHelpers() {} + + /** + * Checks if it's possible to use an optimized `groupByKey` that also moves the window into the + * key. + * + * @param windowing The windowing strategy + * @param endOfWindowOnly Flag if to limit this optimization to {@link + * TimestampCombiner#END_OF_WINDOW}. + */ + static boolean eligibleForGroupByWindow( + WindowingStrategy windowing, boolean endOfWindowOnly) { + return !windowing.needsMerge() + && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW) + && windowing.getWindowFn().windowCoder().consistentWithEquals(); + } + + /** + * Checks if it's possible to use an optimized `groupByKey` for the global window. + * + * @param windowing The windowing strategy + * @param endOfWindowOnly Flag if to limit this optimization to {@link + * TimestampCombiner#END_OF_WINDOW}. 
+ */ + static boolean eligibleForGlobalGroupBy( + WindowingStrategy windowing, boolean endOfWindowOnly) { + return windowing.getWindowFn() instanceof GlobalWindows + && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW); + } + + /** + * Explodes a windowed {@link KV} assigned to potentially multiple {@link BoundedWindow}s to a + * traversable of composite keys {@code (BoundedWindow, Key)} and value. + */ + static + Fun1>, TraversableOnce, T>>> + explodeWindowedKey(Fun1>, T> valueFn) { + return v -> { + T value = valueFn.apply(v); + K key = v.getValue().getKey(); + Collection windows = (Collection) v.getWindows(); + return ScalaInterop.scalaIterator(windows).map(w -> tuple(tuple(w, key), value)); + }; + } + + static Fun1, V>, WindowedValue>> windowedKV() { + return t -> windowedKV(t._1, t._2); + } + + static WindowedValue> windowedKV(Tuple2 key, V value) { + return WindowedValue.of(KV.of(key._2, value), key._1.maxTimestamp(), key._1, NO_FIRING); + } + + static Fun1, V> value() { + return v -> v.getValue(); + } + + static Fun1>, V> valueValue() { + return v -> v.getValue().getValue(); + } + + static Fun1>, K> valueKey() { + return v -> v.getValue().getKey(); + } +} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java index 6391ba4600cf..61306cb993c8 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java @@ -17,74 +17,274 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue; +import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static 
org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.collect_list; +import static org.apache.spark.sql.functions.explode; +import static org.apache.spark.sql.functions.max; +import static org.apache.spark.sql.functions.min; +import static org.apache.spark.sql.functions.struct; + import java.io.Serializable; import org.apache.beam.runners.core.InMemoryStateInternals; -import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.ReduceFnRunner; import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.runners.core.SystemReduceFn; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.catalyst.expressions.CreateArray; +import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.Seq; +import scala.collection.immutable.List; +/** + * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the build-in aggregation + * function {@code collect_list} when applicable. + * + *

    Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the + * latter case the entire group (iterator) has to be loaded into memory as well. Either way there's + * a risk of OOM errors. When disabling {@link #useCollectList}, a more memory-sensitive iterable is + * used that can be traversed just once. Attempting to traverse the iterable again will throw. + * + *

      + *
    • When using the default global window, window information is dropped and restored after the + * aggregation. + *
    • For non-merging windows, windows are exploded and moved into a composite key for better + * distribution. However, to keep the amount of shuffled data low, this is only done if values + * are assigned to a single window or if there are only a few keys and distributing the data is + * important. After the aggregation, windowed values are restored from the composite key. +
    • All other cases are implemented using the SDK {@link ReduceFnRunner}. + *
    + */ class GroupByKeyTranslatorBatch - implements TransformTranslator< - PTransform>, PCollection>>>> { + extends TransformTranslator< + PCollection>, PCollection>>, GroupByKey> { + + /** Literal of binary encoded Pane info. */ + private static final Expression PANE_NO_FIRING = lit(toByteArray(NO_FIRING, PaneInfoCoder.of())); + + /** Defaults for value in single global window. */ + private static final List GLOBAL_WINDOW_DETAILS = + windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY})); + + private boolean useCollectList = true; + + public GroupByKeyTranslatorBatch() {} + + public GroupByKeyTranslatorBatch(boolean useCollectList) { + this.useCollectList = useCollectList; + } @Override - public void translateTransform( - PTransform>, PCollection>>> transform, - AbstractTranslationContext context) { - - @SuppressWarnings("unchecked") - final PCollection> inputPCollection = (PCollection>) context.getInput(); - Dataset>> input = context.getDataset(inputPCollection); - WindowingStrategy windowingStrategy = inputPCollection.getWindowingStrategy(); - KvCoder kvCoder = (KvCoder) inputPCollection.getCoder(); - Coder valueCoder = kvCoder.getValueCoder(); - - // group by key only - Coder keyCoder = kvCoder.getKeyCoder(); - KeyValueGroupedDataset>> groupByKeyOnly = - input.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder)); - - // group also by windows - WindowedValue.FullWindowedValueCoder>> outputCoder = - WindowedValue.FullWindowedValueCoder.of( - KvCoder.of(keyCoder, IterableCoder.of(valueCoder)), - windowingStrategy.getWindowFn().windowCoder()); - Dataset>>> output = - groupByKeyOnly.flatMapGroups( - new GroupAlsoByWindowViaOutputBufferFn<>( - windowingStrategy, - new InMemoryStateInternalsFactory<>(), - SystemReduceFn.buffering(valueCoder), - context.getSerializableOptions()), - EncoderHelpers.fromBeamCoder(outputCoder)); - - context.putDataset(context.getOutput(), output); + public void translate(GroupByKey transform, Context cxt) { + WindowingStrategy windowing = cxt.getInput().getWindowingStrategy(); + TimestampCombiner tsCombiner = windowing.getTimestampCombiner(); + + Dataset>> input = cxt.getDataset(cxt.getInput()); + + KvCoder inputCoder = (KvCoder) cxt.getInput().getCoder(); + KvCoder> outputCoder = (KvCoder>) cxt.getOutput().getCoder(); + + Encoder valueEnc = cxt.valueEncoderOf(inputCoder); + Encoder keyEnc = cxt.keyEncoderOf(inputCoder); + + // In batch we can ignore triggering and allowed lateness parameters + final Dataset>>> result; + + if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) { + // Collects all values per key in memory. This might be problematic if there's few keys only + // or some highly skewed distribution. + result = + input + .groupBy(col("value.key").as("key")) + .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) + .select( + inGlobalWindow( + keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))), + windowTimestamp(tsCombiner))); + + } else if (eligibleForGlobalGroupBy(windowing, true)) { + // Produces an iterable that can be traversed exactly once. However, on the plus side, data is + // not collected in memory until serialized or done by the user. 
+ result = + cxt.getDataset(cxt.getInput()) + .groupByKey(valueKey(), keyEnc) + .mapValues(valueValue(), cxt.valueEncoderOf(inputCoder)) + .mapGroups(fun2((k, it) -> KV.of(k, iterableOnce(it))), cxt.kvEncoderOf(outputCoder)) + .map(fun1(WindowedValue::valueInGlobalWindow), cxt.windowedEncoder(outputCoder)); + + } else if (useCollectList + && eligibleForGroupByWindow(windowing, false) + && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) { + // Using the window as part of the key should help to better distribute the data. However, if + // values are assigned to multiple windows, more data would be shuffled around. If there's few + // keys only, this is still valuable. + // Collects all values per key & window in memory. + result = + input + .select(explode(col("windows")).as("window"), col("value"), col("timestamp")) + .groupBy(col("value.key"), col("window")) + .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner)) + .select( + inSingleWindow( + keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))), + col("window").as(cxt.windowEncoder()), + windowTimestamp(tsCombiner))); + + } else if (eligibleForGroupByWindow(windowing, true) + && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) { + // Using the window as part of the key should help to better distribute the data. However, if + // values are assigned to multiple windows, more data would be shuffled around. If there's few + // keys only, this is still valuable. + // Produces an iterable that can be traversed exactly once. However, on the plus side, data is + // not collected in memory until serialized or done by the user. + Encoder> windowedKeyEnc = + cxt.tupleEncoder(cxt.windowEncoder(), keyEnc); + result = + cxt.getDataset(cxt.getInput()) + .flatMap(explodeWindowedKey(valueValue()), cxt.tupleEncoder(windowedKeyEnc, valueEnc)) + .groupByKey(fun1(Tuple2::_1), windowedKeyEnc) + .mapValues(fun1(Tuple2::_2), valueEnc) + .mapGroups( + fun2((wKey, it) -> windowedKV(wKey, iterableOnce(it))), + cxt.windowedEncoder(outputCoder)); + + } else { + // Collects all values per key in memory. This might be problematic if there's few keys only + // or some highly skewed distribution. + + // FIXME Revisit this case, implementation is far from ideal: + // - iterator traversed at least twice, forcing materialization in memory + + // group by key, then by windows + result = + input + .groupByKey(valueKey(), keyEnc) + .flatMapGroups( + new GroupAlsoByWindowViaOutputBufferFn<>( + windowing, + (SerStateInternalsFactory) key -> InMemoryStateInternals.forKey(key), + SystemReduceFn.buffering(inputCoder.getValueCoder()), + cxt.getSerializableOptions()), + cxt.windowedEncoder(outputCoder)); + } + + cxt.putDataset(cxt.getOutput(), result); + } + + /** Serializable In-memory state internals factory. */ + private interface SerStateInternalsFactory extends StateInternalsFactory, Serializable {} + + private Encoder> iterableEnc(Encoder enc) { + // safe to use list encoder with collect list + return (Encoder) collectionEncoder(enc); + } + + private static Column[] timestampAggregator(TimestampCombiner tsCombiner) { + if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) { + return new Column[0]; // no aggregation needed + } + Column agg = + tsCombiner.equals(TimestampCombiner.EARLIEST) + ? 
min(col("timestamp")) + : max(col("timestamp")); + return new Column[] {agg.as("timestamp")}; + } + + private static Expression windowTimestamp(TimestampCombiner tsCombiner) { + if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) { + // null will be set to END_OF_WINDOW by the respective deserializer + return litNull(DataTypes.LongType); + } + return col("timestamp").expr(); } /** - * In-memory state internals factory. - * - * @param State key type. + * Java {@link Iterable} from Scala {@link Iterator} that can be iterated just once so that we + * don't have to load all data into memory. */ - static class InMemoryStateInternalsFactory implements StateInternalsFactory, Serializable { - @Override - public StateInternals stateInternalsForKey(K key) { - return InMemoryStateInternals.forKey(key); - } + private static Iterable iterableOnce(Iterator it) { + return () -> { + checkState(!it.isEmpty(), "Iterator on values can only be consumed once!"); + return javaIterator(it); + }; + } + + private TypedColumn> keyValue(TypedColumn key, TypedColumn value) { + return struct(key.as("key"), value.as("value")).as(kvEncoder(key.encoder(), value.encoder())); + } + + private static TypedColumn> inGlobalWindow( + TypedColumn value, Expression ts) { + List fields = concat(timestampedValue(value, ts), GLOBAL_WINDOW_DETAILS); + Encoder> enc = + windowedValueEncoder(value.encoder(), encoderOf(GlobalWindow.class)); + return (TypedColumn>) new Column(new CreateNamedStruct(fields)).as(enc); + } + + public static TypedColumn> inSingleWindow( + TypedColumn value, TypedColumn window, Expression ts) { + Expression windows = new CreateArray(listOf(window.expr())); + Seq fields = concat(timestampedValue(value, ts), windowDetails(windows)); + Encoder> enc = windowedValueEncoder(value.encoder(), window.encoder()); + return (TypedColumn>) new Column(new CreateNamedStruct(fields)).as(enc); + } + + private static List timestampedValue(Column value, Expression ts) { + return seqOf(lit("value"), value.expr(), lit("timestamp"), ts).toList(); + } + + private static List windowDetails(Expression windows) { + return seqOf(lit("windows"), windows, lit("pane"), PANE_NO_FIRING).toList(); + } + + private static Expression lit(T t) { + return Literal$.MODULE$.apply(t); + } + + @SuppressWarnings("nullness") // NULL literal + private static Expression litNull(DataType dataType) { + return new Literal(null, dataType); } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java index 65f496c772ba..cf0d2e7ab093 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java @@ -17,33 +17,27 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; -import java.util.Collections; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; +import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; + import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; import 
org.apache.beam.sdk.coders.ByteArrayCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.Impulse; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.spark.sql.Dataset; public class ImpulseTranslatorBatch - implements TransformTranslator>> { + extends TransformTranslator, Impulse> { @Override - public void translateTransform( - PTransform> transform, AbstractTranslationContext context) { - Coder> windowedValueCoder = - WindowedValue.FullWindowedValueCoder.of(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE); + public void translate(Impulse transform, Context cxt) { Dataset> dataset = - context - .getSparkSession() - .createDataset( - Collections.singletonList(WindowedValue.valueInGlobalWindow(new byte[0])), - EncoderHelpers.fromBeamCoder(windowedValueCoder)); - context.putDataset(context.getOutput(), dataset); + cxt.createDataset( + ImmutableList.of(WindowedValue.valueInGlobalWindow(EMPTY_BYTE_ARRAY)), + cxt.windowedEncoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE)); + cxt.putDataset(cxt.getOutput(), dataset); } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java index 52c2d5ae6420..131b285be138 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java @@ -17,64 +17,81 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; +import static java.util.stream.Collectors.toList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.storage.StorageLevel.MEMORY_ONLY; import java.io.IOException; +import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Map.Entry; +import javax.annotation.Nullable; +import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.construction.ParDoTranslation; -import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator; -import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers; 
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOutputCoder; import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.FilterFunction; -import org.apache.spark.api.java.function.MapFunction; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams; +import org.apache.spark.rdd.RDD; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.storage.StorageLevel; +import scala.Function1; import scala.Tuple2; +import scala.collection.Iterator; +import scala.reflect.ClassTag; /** - * TODO: Add support for state and timers. + * Translator for {@link ParDo.MultiOutput} based on {@link DoFnRunners#simpleRunner}. * - * @param - * @param + *

    Each tag is encoded as an individual column, each with its own schema and encoder. + * + *

    TODO: + *

  • Add support for state and timers. + *
  • Add support for SplittableDoFn */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) class ParDoTranslatorBatch - implements TransformTranslator, PCollectionTuple>> { + extends TransformTranslator< + PCollection, PCollectionTuple, ParDo.MultiOutput> { + + private static final ClassTag> WINDOWED_VALUE_CTAG = + ClassTag.apply(WindowedValue.class); + + private static final ClassTag>> TUPLE2_CTAG = + ClassTag.apply(Tuple2.class); @Override - public void translateTransform( - PTransform, PCollectionTuple> transform, - AbstractTranslationContext context) { - String stepName = context.getCurrentTransform().getFullName(); + public void translate(ParDo.MultiOutput transform, Context cxt) + throws IOException { + String stepName = cxt.getCurrentTransform().getFullName(); + + SparkCommonPipelineOptions opts = cxt.getOptions().as(SparkCommonPipelineOptions.class); + StorageLevel storageLevel = StorageLevel.fromString(opts.getStorageLevel()); // Check for not supported advanced features // TODO: add support of Splittable DoFn - DoFn doFn = getDoFn(context); + DoFn doFn = transform.getFn(); checkState( !DoFnSignatures.isSplittable(doFn), "Not expected to directly translate splittable DoFn, should have been overridden: %s", @@ -86,98 +103,124 @@ public void translateTransform( checkState( !DoFnSignatures.requiresTimeSortedInput(doFn), - "@RequiresTimeSortedInput is not " + "supported for the moment"); - - DoFnSchemaInformation doFnSchemaInformation = - ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); - - // Init main variables - PValue input = context.getInput(); - Dataset> inputDataSet = context.getDataset(input); - Map, PCollection> outputs = context.getOutputs(); - TupleTag mainOutputTag = getTupleTag(context); - List> outputTags = new ArrayList<>(outputs.keySet()); - WindowingStrategy windowingStrategy = - ((PCollection) input).getWindowingStrategy(); - Coder inputCoder = ((PCollection) input).getCoder(); - Coder windowCoder = windowingStrategy.getWindowFn().windowCoder(); - - // construct a map from side input to WindowingStrategy so that - // the DoFn runner can map main-input windows to side input windows - List> sideInputs = getSideInputs(context); - Map, WindowingStrategy> sideInputStrategies = new HashMap<>(); - for (PCollectionView sideInput : sideInputs) { - sideInputStrategies.put(sideInput, sideInput.getPCollection().getWindowingStrategy()); - } - - SideInputBroadcast broadcastStateData = createBroadcastSideInputs(sideInputs, context); + "@RequiresTimeSortedInput is not supported for the moment"); - Map, Coder> outputCoderMap = context.getOutputCoders(); - MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance(); + TupleTag mainOutputTag = transform.getMainOutputTag(); - List> additionalOutputTags = new ArrayList<>(); - for (TupleTag tag : outputTags) { - if (!tag.equals(mainOutputTag)) { - additionalOutputTags.add(tag); - } - } + DoFnSchemaInformation doFnSchema = + ParDoTranslation.getSchemaInformation(cxt.getCurrentTransform()); - Map> sideInputMapping = - ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); - @SuppressWarnings("unchecked") - DoFnFunction doFnWrapper = - new DoFnFunction( - metricsAccum, + PCollection input = (PCollection) cxt.getInput(); + DoFnMapPartitionsFactory factory = + new DoFnMapPartitionsFactory<>( stepName, doFn, - windowingStrategy, - sideInputStrategies, - 
context.getSerializableOptions(), - additionalOutputTags, + doFnSchema, + cxt.getSerializableOptions(), + input, mainOutputTag, - inputCoder, - outputCoderMap, - broadcastStateData, - doFnSchemaInformation, - sideInputMapping); - - MultiOutputCoder multipleOutputCoder = - MultiOutputCoder.of(SerializableCoder.of(TupleTag.class), outputCoderMap, windowCoder); - Dataset, WindowedValue>> allOutputs = - inputDataSet.mapPartitions(doFnWrapper, EncoderHelpers.fromBeamCoder(multipleOutputCoder)); - if (outputs.entrySet().size() > 1) { - allOutputs.persist(); - for (Map.Entry, PCollection> output : outputs.entrySet()) { - pruneOutputFilteredByTag(context, allOutputs, output, windowCoder); + cxt.getOutputs(), + transform.getSideInputs(), + createBroadcastSideInputs(transform.getSideInputs().values(), cxt)); + + Dataset> inputDs = cxt.getDataset(input); + if (cxt.getOutputs().size() > 1) { + // In case of multiple outputs / tags, map each tag to a column by index. + // At the end split the result into multiple datasets selecting one column each. + Map, Integer> tags = ImmutableMap.copyOf(zipwithIndex(cxt.getOutputs().keySet())); + + List>> encoders = + createEncoders(cxt.getOutputs(), (Iterable>) tags.keySet(), cxt); + + Function1>, Iterator>>> + doFnMapper = factory.create((tag, v) -> tuple(tags.get(tag), (WindowedValue) v)); + + // FIXME What's the strategy to unpersist Datasets / RDDs? + + // If using storage level MEMORY_ONLY, it's best to persist the dataset as RDD to avoid any + // serialization / use of encoders. Persisting a Dataset, even if using a "deserialized" + // storage level, involves converting the data to the internal representation (InternalRow) + // by use of an encoder. + // For any other storage level, persist as Dataset, so we can select columns by TupleTag + // individually without restoring the entire row. 
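For orientation, a minimal sketch (not part of the change itself) of how the persist strategy above is steered from pipeline options; it assumes the storageLevel option keeps its name on SparkCommonPipelineOptions and uses PipelineOptionsFactory and org.apache.spark.storage.StorageLevel.

  // Sketch: MEMORY_ONLY keeps multi-output ParDo results cached as a plain RDD (no encoder
  // round-trip); any other level, e.g. MEMORY_AND_DISK, caches a Dataset of tagged columns.
  SparkCommonPipelineOptions opts =
      PipelineOptionsFactory.fromArgs("--storageLevel=MEMORY_AND_DISK")
          .as(SparkCommonPipelineOptions.class);
  StorageLevel level = StorageLevel.fromString(opts.getStorageLevel());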
+ if (MEMORY_ONLY().equals(storageLevel)) { + + RDD>> allTagsRDD = + inputDs.rdd().mapPartitions(doFnMapper, false, TUPLE2_CTAG); + allTagsRDD.persist(); + + // divide into separate output datasets per tag + for (Entry, Integer> e : tags.entrySet()) { + TupleTag key = (TupleTag) e.getKey(); + Integer id = e.getValue(); + + RDD> rddByTag = + allTagsRDD + .filter(fun1(t -> t._1.equals(id))) + .map(fun1(Tuple2::_2), WINDOWED_VALUE_CTAG); + + cxt.putDataset( + cxt.getOutput(key), cxt.getSparkSession().createDataset(rddByTag, encoders.get(id))); + } + } else { + // Persist as wide rows with one column per TupleTag to support different schemas + Dataset>> allTagsDS = + inputDs.mapPartitions(doFnMapper, oneOfEncoder(encoders)); + allTagsDS.persist(storageLevel); + + // divide into separate output datasets per tag + for (Entry, Integer> e : tags.entrySet()) { + TupleTag key = (TupleTag) e.getKey(); + Integer id = e.getValue(); + + // Resolve specific column matching the tuple tag (by id) + TypedColumn>, WindowedValue> col = + (TypedColumn) col(id.toString()).as(encoders.get(id)); + + cxt.putDataset(cxt.getOutput(key), allTagsDS.filter(col.isNotNull()).select(col)); + } } } else { - Coder outputCoder = ((PCollection) outputs.get(mainOutputTag)).getCoder(); - Coder> windowedValueCoder = - (Coder>) (Coder) WindowedValue.getFullCoder(outputCoder, windowCoder); - Dataset> outputDataset = - allOutputs.map( - (MapFunction, WindowedValue>, WindowedValue>) - value -> value._2, - EncoderHelpers.fromBeamCoder(windowedValueCoder)); - context.putDatasetWildcard(outputs.entrySet().iterator().next().getValue(), outputDataset); + PCollection output = cxt.getOutput(mainOutputTag); + Dataset> mainDS = + inputDs.mapPartitions( + factory.create((tag, value) -> (WindowedValue) value), + cxt.windowedEncoder(output.getCoder())); + + cxt.putDataset(output, mainDS); + } + } + + private List>> createEncoders( + Map, PCollection> outputs, Iterable> columns, Context ctx) { + return Streams.stream(columns) + .map(tag -> ctx.windowedEncoder(getCoder(outputs.get(tag), tag))) + .collect(toList()); + } + + private Coder getCoder(@Nullable PCollection pc, TupleTag tag) { + if (pc == null) { + throw new NullPointerException("No PCollection for tag " + tag); } + return (Coder) pc.getCoder(); } - private static SideInputBroadcast createBroadcastSideInputs( - List> sideInputs, AbstractTranslationContext context) { - JavaSparkContext jsc = - JavaSparkContext.fromSparkContext(context.getSparkSession().sparkContext()); + // FIXME Better ways? 
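A condensed sketch (not part of the change itself) of the collect-and-broadcast pattern that createBroadcastSideInputs below follows; the helper name broadcastCoded is made up, the rest relies on the existing CoderHelpers and standard Spark APIs.

  // Sketch: collect the side-input Dataset on the driver, serialize each element with its coder,
  // and broadcast the resulting byte arrays to the executors.
  static <T> Broadcast<List<byte[]>> broadcastCoded(
      SparkSession session, Dataset<WindowedValue<T>> sideInput, Coder<WindowedValue<T>> coder) {
    List<byte[]> coded = new ArrayList<>();
    for (WindowedValue<T> value : sideInput.collectAsList()) {
      coded.add(CoderHelpers.toByteArray(value, coder));
    }
    return JavaSparkContext.fromSparkContext(session.sparkContext()).broadcast(coded);
  }

In the patch itself the broadcast goes through the translation context (context.broadcast(...)) so that the underlying Spark context is managed in a single place.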
+ private SideInputBroadcast createBroadcastSideInputs( + Collection> sideInputs, Context context) { SideInputBroadcast sideInputBroadcast = new SideInputBroadcast(); for (PCollectionView sideInput : sideInputs) { + PCollection pc = sideInput.getPCollection(); + if (pc == null) { + throw new NullPointerException("PCollection for SideInput is null"); + } Coder windowCoder = - sideInput.getPCollection().getWindowingStrategy().getWindowFn().windowCoder(); - + pc.getWindowingStrategy().getWindowFn().windowCoder(); Coder> windowedValueCoder = (Coder>) - (Coder) - WindowedValue.getFullCoder(sideInput.getPCollection().getCoder(), windowCoder); - Dataset> broadcastSet = context.getSideInputDataSet(sideInput); + (Coder) WindowedValue.getFullCoder(pc.getCoder(), windowCoder); + Dataset> broadcastSet = context.getSideInputDataset(sideInput); List> valuesList = broadcastSet.collectAsList(); List codedValues = new ArrayList<>(); for (WindowedValue v : valuesList) { @@ -185,73 +228,17 @@ private static SideInputBroadcast createBroadcastSideInputs( } sideInputBroadcast.add( - sideInput.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder); + sideInput.getTagInternal().getId(), context.broadcast(codedValues), windowedValueCoder); } return sideInputBroadcast; } - private List> getSideInputs(AbstractTranslationContext context) { - List> sideInputs; - try { - sideInputs = ParDoTranslation.getSideInputs(context.getCurrentTransform()); - } catch (IOException e) { - throw new RuntimeException(e); - } - return sideInputs; - } - - private TupleTag getTupleTag(AbstractTranslationContext context) { - TupleTag mainOutputTag; - try { - mainOutputTag = ParDoTranslation.getMainOutputTag(context.getCurrentTransform()); - } catch (IOException e) { - throw new RuntimeException(e); - } - return mainOutputTag; - } - - @SuppressWarnings("unchecked") - private DoFn getDoFn(AbstractTranslationContext context) { - DoFn doFn; - try { - doFn = (DoFn) ParDoTranslation.getDoFn(context.getCurrentTransform()); - } catch (IOException e) { - throw new RuntimeException(e); - } - return doFn; - } - - private void pruneOutputFilteredByTag( - AbstractTranslationContext context, - Dataset, WindowedValue>> allOutputs, - Map.Entry, PCollection> output, - Coder windowCoder) { - Dataset, WindowedValue>> filteredDataset = - allOutputs.filter(new DoFnFilterFunction(output.getKey())); - Coder> windowedValueCoder = - (Coder>) - (Coder) - WindowedValue.getFullCoder( - ((PCollection) output.getValue()).getCoder(), windowCoder); - Dataset> outputDataset = - filteredDataset.map( - (MapFunction, WindowedValue>, WindowedValue>) - value -> value._2, - EncoderHelpers.fromBeamCoder(windowedValueCoder)); - context.putDatasetWildcard(output.getValue(), outputDataset); - } - - static class DoFnFilterFunction implements FilterFunction, WindowedValue>> { - - private final TupleTag key; - - DoFnFilterFunction(TupleTag key) { - this.key = key; - } - - @Override - public boolean call(Tuple2, WindowedValue> value) { - return value._1.equals(key); + private static Collection> zipwithIndex(Collection col) { + ArrayList> zipped = new ArrayList<>(col.size()); + int i = 0; + for (T t : col) { + zipped.add(new SimpleImmutableEntry<>(t, i++)); } + return zipped; } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java index 
5789db6cd304..62d79632cfad 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java @@ -25,7 +25,6 @@ import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; @@ -34,6 +33,8 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -41,10 +42,6 @@ * only the components specific to batch: registry of batch {@link TransformTranslator} and registry * lookup code. */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) public class PipelineTranslatorBatch extends PipelineTranslator { // -------------------------------------------------------------------------------------------- @@ -65,23 +62,24 @@ public class PipelineTranslatorBatch extends PipelineTranslator { static { TRANSFORM_TRANSLATORS.put(Impulse.class, new ImpulseTranslatorBatch()); - TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch()); - TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch()); + TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch<>()); + TRANSFORM_TRANSLATORS.put(Combine.Globally.class, new CombineGloballyTranslatorBatch<>()); + TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch<>()); // TODO: Do we need to have a dedicated translator for {@code Reshuffle} if it's deprecated? // TRANSFORM_TRANSLATORS.put(Reshuffle.class, new ReshuffleTranslatorBatch()); - TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch()); + TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch<>()); - TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch()); + TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch<>()); - TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch()); + TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch<>()); TRANSFORM_TRANSLATORS.put( - SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch()); + SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>()); TRANSFORM_TRANSLATORS.put( - View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch()); + View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch<>()); } public PipelineTranslatorBatch(SparkStructuredStreamingPipelineOptions options) { @@ -90,8 +88,10 @@ public PipelineTranslatorBatch(SparkStructuredStreamingPipelineOptions options) /** Returns a translator for the given node, if it is possible, otherwise null. 
*/ @Override - protected TransformTranslator getTransformTranslator(TransformHierarchy.Node node) { - @Nullable PTransform transform = node.getTransform(); + @Nullable + protected > + TransformTranslator getTransformTranslator( + @Nullable TransformT transform) { // Root of the graph is null if (transform == null) { return null; diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java deleted file mode 100644 index db64bfd19f39..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.batch; - -import java.util.ArrayList; -import java.util.Iterator; -import org.apache.beam.runners.core.DoFnRunner; -import org.apache.beam.runners.core.DoFnRunners.OutputManager; -import org.apache.beam.runners.core.TimerInternals; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; - -/** Spark runner process context processes Spark partitions using Beam's {@link DoFnRunner}. */ -class ProcessContext { - - private final DoFn doFn; - private final DoFnRunner doFnRunner; - private final ProcessOutputManager outputManager; - private final Iterator timerDataIterator; - - ProcessContext( - DoFn doFn, - DoFnRunner doFnRunner, - ProcessOutputManager outputManager, - Iterator timerDataIterator) { - - this.doFn = doFn; - this.doFnRunner = doFnRunner; - this.outputManager = outputManager; - this.timerDataIterator = timerDataIterator; - } - - Iterable processPartition(Iterator> partition) { - - // skip if partition is empty. - if (!partition.hasNext()) { - return new ArrayList<>(); - } - - // process the partition; finishBundle() is called from within the output iterator. 
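The ProcessContext class removed above (and the new DoFnMapPartitionsFactory that supersedes it) turns each Spark partition into a single Beam bundle: the bundle is started once, every input element is processed as it is pulled, and the bundle is finished when the input is exhausted, with outputs surfaced lazily through an iterator. The following self-contained sketch illustrates that iterator pattern; the class name and functional parameters are hypothetical stand-ins, not the runner's actual DoFnRunner-based implementation.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Function;

/** Hypothetical sketch of lazy, bundle-style processing of one partition. */
class BundleIterator<InT, OutT> implements Iterator<OutT> {

  private final Iterator<InT> input;
  private final Runnable startBundle;                     // stands in for DoFnRunner.startBundle
  private final Function<InT, List<OutT>> processElement; // stands in for DoFnRunner.processElement
  private final Runnable finishBundle;                    // stands in for DoFnRunner.finishBundle

  private Iterator<OutT> pending = Collections.emptyIterator();
  private boolean started = false;
  private boolean finished = false;

  BundleIterator(
      Iterator<InT> input,
      Runnable startBundle,
      Function<InT, List<OutT>> processElement,
      Runnable finishBundle) {
    this.input = input;
    this.startBundle = startBundle;
    this.processElement = processElement;
    this.finishBundle = finishBundle;
  }

  @Override
  public boolean hasNext() {
    while (!pending.hasNext()) {
      if (!started) {
        started = true;
        startBundle.run(); // called once before the first element
      }
      if (input.hasNext()) {
        // each input element may produce zero, one or more outputs
        pending = processElement.apply(input.next()).iterator();
      } else if (!finished) {
        finished = true;
        finishBundle.run(); // a real runner may still flush buffered output here
        pending = new ArrayList<OutT>().iterator();
      } else {
        return false;
      }
    }
    return true;
  }

  @Override
  public OutT next() {
    if (!hasNext()) {
      throw new NoSuchElementException();
    }
    return pending.next();
  }
}
```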
- return this.getOutputIterable(partition, doFnRunner); - } - - private void clearOutput() { - outputManager.clear(); - } - - private Iterator getOutputIterator() { - return outputManager.iterator(); - } - - private Iterable getOutputIterable( - final Iterator> iter, - final DoFnRunner doFnRunner) { - return () -> new ProcCtxtIterator(iter, doFnRunner); - } - - interface ProcessOutputManager extends OutputManager, Iterable { - void clear(); - } - - private class ProcCtxtIterator extends AbstractIterator { - - private final Iterator> inputIterator; - private final DoFnRunner doFnRunner; - private Iterator outputIterator; - private boolean isBundleStarted; - private boolean isBundleFinished; - - ProcCtxtIterator( - Iterator> iterator, DoFnRunner doFnRunner) { - this.inputIterator = iterator; - this.doFnRunner = doFnRunner; - this.outputIterator = getOutputIterator(); - } - - @Override - protected OutputT computeNext() { - try { - // Process each element from the (input) iterator, which produces, zero, one or more - // output elements (of type V) in the output iterator. Note that the output - // collection (and iterator) is reset between each call to processElement, so the - // collection only holds the output values for each call to processElement, rather - // than for the whole partition (which would use too much memory). - if (!isBundleStarted) { - isBundleStarted = true; - // call startBundle() before beginning to process the partition. - doFnRunner.startBundle(); - } - - while (true) { - if (outputIterator.hasNext()) { - return outputIterator.next(); - } - - clearOutput(); - if (inputIterator.hasNext()) { - // grab the next element and process it. - doFnRunner.processElement(inputIterator.next()); - outputIterator = getOutputIterator(); - } else if (timerDataIterator.hasNext()) { - outputIterator = getOutputIterator(); - } else { - // no more input to consume, but finishBundle can produce more output - if (!isBundleFinished) { - isBundleFinished = true; - doFnRunner.finishBundle(); - outputIterator = getOutputIterator(); - continue; // try to consume outputIterator from start of loop - } - DoFnInvokers.invokerFor(doFn).invokeTeardown(); - return endOfData(); - } - } - } catch (final RuntimeException re) { - DoFnInvokers.invokerFor(doFn).invokeTeardown(); - throw re; - } - } - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java index ebeb8a96eda4..30b599c7e5ec 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java @@ -17,72 +17,38 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; -import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION; -import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM; -import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS; - import java.io.IOException; -import org.apache.beam.runners.core.construction.ReadTranslation; -import org.apache.beam.runners.core.serialization.Base64Serializer; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; 
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions; +import org.apache.beam.runners.core.construction.SplittableParDo; +import org.apache.beam.runners.spark.structuredstreaming.io.BoundedDatasetFactory; import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers; import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.runners.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.Encoder; import org.apache.spark.sql.SparkSession; +/** + * Translator for a {@link SplittableParDo.PrimitiveBoundedRead} that creates a Dataset via an RDD + * to avoid an additional serialization roundtrip. + */ class ReadSourceTranslatorBatch - implements TransformTranslator>> { + extends TransformTranslator, SplittableParDo.PrimitiveBoundedRead> { - private static final String sourceProviderClass = DatasetSourceBatch.class.getCanonicalName(); - - @SuppressWarnings("unchecked") @Override - public void translateTransform( - PTransform> transform, AbstractTranslationContext context) { - AppliedPTransform, PTransform>> rootTransform = - (AppliedPTransform, PTransform>>) - context.getCurrentTransform(); - - BoundedSource source; - try { - source = ReadTranslation.boundedSourceFromTransform(rootTransform); - } catch (IOException e) { - throw new RuntimeException(e); - } - SparkSession sparkSession = context.getSparkSession(); - - String serializedSource = Base64Serializer.serializeUnchecked(source); - Dataset rowDataset = - sparkSession - .read() - .format(sourceProviderClass) - .option(BEAM_SOURCE_OPTION, serializedSource) - .option( - DEFAULT_PARALLELISM, - String.valueOf(context.getSparkSession().sparkContext().defaultParallelism())) - .option(PIPELINE_OPTIONS, context.getSerializableOptions().toString()) - .load(); - - // extract windowedValue from Row - WindowedValue.FullWindowedValueCoder windowedValueCoder = - WindowedValue.FullWindowedValueCoder.of( - source.getOutputCoder(), GlobalWindow.Coder.INSTANCE); - - Dataset> dataset = - rowDataset.map( - RowHelpers.extractWindowedValueFromRowMapFunction(windowedValueCoder), - EncoderHelpers.fromBeamCoder(windowedValueCoder)); - - PCollection output = (PCollection) context.getOutput(); - context.putDataset(output, dataset); + public void translate(SplittableParDo.PrimitiveBoundedRead transform, Context cxt) + throws IOException { + SparkSession session = cxt.getSparkSession(); + BoundedSource source = transform.getSource(); + SerializablePipelineOptions options = cxt.getSerializableOptions(); + + Encoder> encoder = + cxt.windowedEncoder(source.getOutputCoder(), GlobalWindow.Coder.INSTANCE); + + cxt.putDataset( + cxt.getOutput(), + BoundedDatasetFactory.createDatasetFromRDD(session, source, options, encoder)); } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java deleted file mode 100644 
index a88d5454667f..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.batch; - -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.sdk.transforms.Reshuffle; - -/** TODO: Should be removed if {@link Reshuffle} won't be translated. */ -class ReshuffleTranslatorBatch implements TransformTranslator> { - - @Override - public void translateTransform( - Reshuffle transform, AbstractTranslationContext context) {} -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java index 875a983b4017..3b993a3ce193 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java @@ -17,45 +17,83 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; +import java.util.Collection; import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.WindowingHelpers; -import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.joda.time.Instant; -@SuppressWarnings({ - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) class WindowAssignTranslatorBatch - implements TransformTranslator, PCollection>> { + extends 
TransformTranslator, PCollection, Window.Assign> { @Override - public void translateTransform( - PTransform, PCollection> transform, AbstractTranslationContext context) { - - Window.Assign assignTransform = (Window.Assign) transform; - @SuppressWarnings("unchecked") - final PCollection input = (PCollection) context.getInput(); - @SuppressWarnings("unchecked") - final PCollection output = (PCollection) context.getOutput(); - - Dataset> inputDataset = context.getDataset(input); - if (WindowingHelpers.skipAssignWindows(assignTransform, context)) { - context.putDataset(output, inputDataset); + public void translate(Window.Assign transform, Context cxt) { + WindowFn windowFn = transform.getWindowFn(); + PCollection input = cxt.getInput(); + Dataset> inputDataset = cxt.getDataset(input); + + if (windowFn == null || skipAssignWindows(windowFn, input)) { + cxt.putDataset(cxt.getOutput(), inputDataset); } else { - WindowFn windowFn = assignTransform.getWindowFn(); - WindowedValue.FullWindowedValueCoder windowedValueCoder = - WindowedValue.FullWindowedValueCoder.of(input.getCoder(), windowFn.windowCoder()); Dataset> outputDataset = inputDataset.map( - WindowingHelpers.assignWindowsMapFunction(windowFn), - EncoderHelpers.fromBeamCoder(windowedValueCoder)); - context.putDataset(output, outputDataset); + assignWindows(windowFn), + cxt.windowedEncoder(input.getCoder(), windowFn.windowCoder())); + + cxt.putDataset(cxt.getOutput(), outputDataset); } } + + /** + * Checks if the window transformation should be applied or skipped. + * + *

<p>
    Avoid running assign windows if both source and destination are global window or if the user + * has not specified the WindowFn (meaning they are just messing with triggering or allowed + * lateness). + */ + private boolean skipAssignWindows(WindowFn newFn, PCollection input) { + WindowFn currentFn = input.getWindowingStrategy().getWindowFn(); + return currentFn instanceof GlobalWindows && newFn instanceof GlobalWindows; + } + + private static + MapFunction, WindowedValue> assignWindows(WindowFn windowFn) { + return value -> { + final BoundedWindow window = getOnlyWindow(value); + final T element = value.getValue(); + final Instant timestamp = value.getTimestamp(); + Collection windows = + windowFn.assignWindows( + windowFn.new AssignContext() { + + @Override + public T element() { + return element; + } + + @Override + public @NonNull Instant timestamp() { + return timestamp; + } + + @Override + public @NonNull BoundedWindow window() { + return window; + } + }); + return WindowedValue.of(element, timestamp, windows, value.getPane()); + }; + } + + private static BoundedWindow getOnlyWindow(WindowedValue wv) { + return Iterables.getOnlyElement((Iterable) wv.getWindows()); + } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java index fe3f39ef51e9..f8c63bc34f14 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java @@ -17,10 +17,9 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.IOException; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.util.CoderUtils; /** Serialization utility class. */ public final class CoderHelpers { @@ -35,13 +34,11 @@ private CoderHelpers() {} * @return Byte array representing serialized object. */ public static byte[] toByteArray(T value, Coder coder) { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); try { - coder.encode(value, baos); + return CoderUtils.encodeToByteArray(coder, value); } catch (IOException e) { throw new IllegalStateException("Error encoding value: " + value, e); } - return baos.toByteArray(); } /** @@ -53,9 +50,8 @@ public static byte[] toByteArray(T value, Coder coder) { * @return Deserialized object. 
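With this change CoderHelpers simply delegates byte-array round trips to Beam's CoderUtils instead of managing its own streams. A minimal usage sketch (the class name is made up for illustration):

```java
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.CoderUtils;

public class CoderRoundTrip {
  public static void main(String[] args) throws CoderException {
    StringUtf8Coder coder = StringUtf8Coder.of();
    // Encode a value to bytes and decode it back, as CoderHelpers.toByteArray/fromByteArray do.
    byte[] bytes = CoderUtils.encodeToByteArray(coder, "hello");
    String decoded = CoderUtils.decodeFromByteArray(coder, bytes);
    System.out.println(decoded); // hello
  }
}
```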
*/ public static T fromByteArray(byte[] serialized, Coder coder) { - ByteArrayInputStream bais = new ByteArrayInputStream(serialized); try { - return coder.decode(bais); + return CoderUtils.decodeFromByteArray(coder, serialized); } catch (IOException e) { throw new IllegalStateException("Error decoding bytes for coder: " + coder, e); } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java index c7d69c0b8ad8..e70cc7253f8d 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java @@ -17,13 +17,17 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; + import java.lang.reflect.Constructor; import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.objects.Invoke; +import org.apache.spark.sql.catalyst.expressions.objects.NewInstance; import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke; import org.apache.spark.sql.types.DataType; -import scala.collection.immutable.Nil$; -import scala.collection.mutable.WrappedArray; +import scala.Option; import scala.reflect.ClassTag; public class EncoderFactory { @@ -31,6 +35,12 @@ public class EncoderFactory { private static final Constructor STATIC_INVOKE_CONSTRUCTOR = (Constructor) StaticInvoke.class.getConstructors()[0]; + private static final Constructor INVOKE_CONSTRUCTOR = + (Constructor) Invoke.class.getConstructors()[0]; + + private static final Constructor NEW_INSTANCE_CONSTRUCTOR = + (Constructor) NewInstance.class.getConstructors()[0]; + static ExpressionEncoder create( Expression serializer, Expression deserializer, Class clazz) { return new ExpressionEncoder<>(serializer, deserializer, ClassTag.apply(clazz)); @@ -39,21 +49,68 @@ static ExpressionEncoder create( /** * Invoke method {@code fun} on Class {@code cls}, immediately propagating {@code null} if any * input arg is {@code null}. - * - *

<p>
    To address breaking interfaces between various version of Spark 3 these are created - * reflectively. This is fine as it's just needed once to create the query plan. */ static Expression invokeIfNotNull(Class cls, String fun, DataType type, Expression... args) { + return invoke(cls, fun, type, true, args); + } + + /** Invoke method {@code fun} on Class {@code cls}. */ + static Expression invoke(Class cls, String fun, DataType type, Expression... args) { + return invoke(cls, fun, type, false, args); + } + + private static Expression invoke( + Class cls, String fun, DataType type, boolean propagateNull, Expression... args) { try { + // To address breaking interfaces between various version of Spark 3, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) { case 6: // Spark 3.1.x return STATIC_INVOKE_CONSTRUCTOR.newInstance( - cls, type, fun, new WrappedArray.ofRef<>(args), true, true); + cls, type, fun, seqOf(args), propagateNull, true); case 8: // Spark 3.2.x, 3.3.x return STATIC_INVOKE_CONSTRUCTOR.newInstance( - cls, type, fun, new WrappedArray.ofRef<>(args), Nil$.MODULE$, true, true, true); + cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true); + default: + throw new RuntimeException("Unsupported version of Spark"); + } + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + /** Invoke method {@code fun} on {@code obj} with provided {@code args}. */ + static Expression invoke( + Expression obj, String fun, DataType type, boolean nullable, Expression... args) { + try { + // To address breaking interfaces between various version of Spark 3, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. + switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) { + case 6: + return INVOKE_CONSTRUCTOR.newInstance(obj, fun, type, seqOf(args), false, nullable); + case 8: + return INVOKE_CONSTRUCTOR.newInstance( + obj, fun, type, seqOf(args), emptyList(), false, nullable, true); + default: + throw new RuntimeException("Unsupported version of Spark"); + } + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + throw new RuntimeException(ex); + } + } + + static Expression newInstance(Class cls, DataType type, Expression... args) { + try { + // To address breaking interfaces between various version of Spark 3, expressions are + // created reflectively. This is fine as it's just needed once to create the query plan. 
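The factory instantiates Catalyst expression classes reflectively and dispatches on the constructor's parameter count because the constructor signatures differ between Spark 3.1 and 3.2/3.3. The stand-alone sketch below shows the same arity-dispatch pattern against a made-up Target class; it illustrates the technique only and is not the Spark-specific code.

```java
import java.lang.reflect.Constructor;

public class ArityDispatchExample {

  /** Made-up stand-in for a class whose constructor shape changed between library versions. */
  public static class Target {
    final String label;

    public Target(String a, String b, boolean flag) {
      this.label = a + ":" + b + ":" + flag;
    }
  }

  @SuppressWarnings("unchecked")
  static Target create(String a, String b) {
    // Grab whatever constructor exists at runtime and dispatch on its arity, mirroring how
    // EncoderFactory switches on STATIC_INVOKE_CONSTRUCTOR.getParameterCount().
    Constructor<Target> ctor = (Constructor<Target>) Target.class.getConstructors()[0];
    try {
      switch (ctor.getParameterCount()) {
        case 2: // hypothetical "older" shape: (a, b)
          return ctor.newInstance(a, b);
        case 3: // hypothetical "newer" shape: (a, b, flag)
          return ctor.newInstance(a, b, true);
        default:
          throw new IllegalStateException("Unsupported constructor arity");
      }
    } catch (ReflectiveOperationException e) {
      throw new RuntimeException(e);
    }
  }

  public static void main(String[] args) {
    System.out.println(create("x", "y").label); // x:y:true
  }
}
```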
+ switch (NEW_INSTANCE_CONSTRUCTOR.getParameterCount()) { + case 5: + return NEW_INSTANCE_CONSTRUCTOR.newInstance(cls, seqOf(args), true, type, Option.empty()); + case 6: + return NEW_INSTANCE_CONSTRUCTOR.newInstance( + cls, seqOf(args), emptyList(), true, type, Option.empty()); default: throw new RuntimeException("Unsupported version of Spark"); } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java index 68738cf03080..f89f6bdb9d30 100644 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -17,44 +17,488 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invoke; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invokeIfNotNull; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.newInstance; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.match; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf; +import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple; import static org.apache.spark.sql.types.DataTypes.BinaryType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; +import java.math.BigDecimal; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Function; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.SerializerBuildHelper; +import org.apache.spark.sql.catalyst.SerializerBuildHelper.MapElementInformation; import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Coalesce; +import 
org.apache.spark.sql.catalyst.expressions.CreateNamedStruct; +import org.apache.spark.sql.catalyst.expressions.EqualTo; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.GetStructField; +import org.apache.spark.sql.catalyst.expressions.If; +import org.apache.spark.sql.catalyst.expressions.IsNotNull; +import org.apache.spark.sql.catalyst.expressions.IsNull; import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.catalyst.expressions.Literal$; +import org.apache.spark.sql.catalyst.expressions.MapKeys; +import org.apache.spark.sql.catalyst.expressions.MapValues; +import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.ArrayType; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; import org.apache.spark.sql.types.ObjectType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.MutablePair; import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import scala.Option; +import scala.Some; +import scala.Tuple2; +import scala.collection.IndexedSeq; +import scala.collection.JavaConverters; +import scala.collection.Seq; +/** {@link Encoders} utility class. */ public class EncoderHelpers { private static final DataType OBJECT_TYPE = new ObjectType(Object.class); + private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class); + private static final DataType WINDOWED_VALUE = new ObjectType(WindowedValue.class); + private static final DataType KV_TYPE = new ObjectType(KV.class); + private static final DataType MUTABLE_PAIR_TYPE = new ObjectType(MutablePair.class); + + // Collections / maps of these types can be (de)serialized without (de)serializing each member + private static final Set> PRIMITIV_TYPES = + ImmutableSet.of( + Boolean.class, + Byte.class, + Short.class, + Integer.class, + Long.class, + Float.class, + Double.class); + + // Default encoders by class + private static final Map, Encoder> DEFAULT_ENCODERS = new HashMap<>(); + + // Factory for default encoders by class + private static final Function, @Nullable Encoder> ENCODER_FACTORY = + cls -> { + if (cls.equals(PaneInfo.class)) { + return paneInfoEncoder(); + } else if (cls.equals(GlobalWindow.class)) { + return binaryEncoder(GlobalWindow.Coder.INSTANCE, false); + } else if (cls.equals(IntervalWindow.class)) { + return binaryEncoder(IntervalWindowCoder.of(), false); + } else if (cls.equals(Instant.class)) { + return instantEncoder(); + } else if (cls.equals(String.class)) { + return Encoders.STRING(); + } else if (cls.equals(Boolean.class)) { + return Encoders.BOOLEAN(); + } else if (cls.equals(Integer.class)) { + return Encoders.INT(); + } else if (cls.equals(Long.class)) { + return Encoders.LONG(); + } else if (cls.equals(Float.class)) { + return Encoders.FLOAT(); + } else if (cls.equals(Double.class)) { + return Encoders.DOUBLE(); + } else if (cls.equals(BigDecimal.class)) { + return Encoders.DECIMAL(); + } else if (cls.equals(byte[].class)) { + return Encoders.BINARY(); + } else if (cls.equals(Byte.class)) { + return Encoders.BYTE(); + } else if (cls.equals(Short.class)) { + return Encoders.SHORT(); + } + return null; + }; + + private static @Nullable Encoder getOrCreateDefaultEncoder(Class cls) { + return (Encoder) 
DEFAULT_ENCODERS.computeIfAbsent(cls, ENCODER_FACTORY); + } + + /** Gets or creates a default {@link Encoder} for {@link T}. */ + public static Encoder encoderOf(Class cls) { + Encoder enc = getOrCreateDefaultEncoder(cls); + if (enc == null) { + throw new IllegalArgumentException("No default coder available for class " + cls); + } + return enc; + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType} + * delegating to a Beam {@link Coder} underneath. + * + *

<p>
    Note: For common types, if available, default Spark {@link Encoder}s are used instead. + * + * @param coder Beam {@link Coder} + */ + public static Encoder encoderFor(Coder coder) { + Encoder enc = getOrCreateDefaultEncoder(coder.getEncodedTypeDescriptor().getRawType()); + return enc != null ? enc : binaryEncoder(coder, true); + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link StructType} with fields {@code value}, + * {@code timestamp}, {@code windows} and {@code pane}. + * + * @param value {@link Encoder} to encode field `{@code value}`. + * @param window {@link Encoder} to encode individual windows in field `{@code windows}` + */ + public static Encoder> windowedValueEncoder( + Encoder value, Encoder window) { + Encoder timestamp = encoderOf(Instant.class); + Encoder pane = encoderOf(PaneInfo.class); + Encoder> windows = collectionEncoder(window); + Expression serializer = + serializeWindowedValue(rootRef(WINDOWED_VALUE, true), value, timestamp, windows, pane); + Expression deserializer = + deserializeWindowedValue(rootCol(serializer.dataType()), value, timestamp, windows, pane); + return EncoderFactory.create(serializer, deserializer, WindowedValue.class); + } + + /** + * Creates a one-of Spark {@link Encoder} of {@link StructType} where each alternative is + * represented as colum / field named by its index with a separate {@link Encoder} each. + * + *

<p>
    Externally this is represented as tuple {@code (index, data)} where an index corresponds to + * an {@link Encoder} in the provided list. + * + * @param encoders {@link Encoder}s for each alternative. + */ + public static Encoder> oneOfEncoder(List> encoders) { + Expression serializer = serializeOneOf(rootRef(TUPLE2_TYPE, true), encoders); + Expression deserializer = deserializeOneOf(rootCol(serializer.dataType()), encoders); + return EncoderFactory.create(serializer, deserializer, Tuple2.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link KV} of {@link StructType} with fields {@code key} + * and {@code value}. + * + * @param key {@link Encoder} to encode field `{@code key}`. + * @param value {@link Encoder} to encode field `{@code value}` + */ + public static Encoder> kvEncoder(Encoder key, Encoder value) { + Expression serializer = serializeKV(rootRef(KV_TYPE, true), key, value); + Expression deserializer = deserializeKV(rootCol(serializer.dataType()), key, value); + return EncoderFactory.create(serializer, deserializer, KV.class); + } + + /** + * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s with nullable + * elements. + * + * @param enc {@link Encoder} to encode collection elements + */ + public static Encoder> collectionEncoder(Encoder enc) { + return collectionEncoder(enc, true); + } + + /** + * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s. + * + * @param enc {@link Encoder} to encode collection elements + * @param nullable Allow nullable collection elements + */ + public static Encoder> collectionEncoder(Encoder enc, boolean nullable) { + DataType type = new ObjectType(Collection.class); + Expression serializer = serializeSeq(rootRef(type, true), enc, nullable); + Expression deserializer = deserializeSeq(rootCol(serializer.dataType()), enc, nullable, true); + return EncoderFactory.create(serializer, deserializer, Collection.class); + } + + /** + * Creates a Spark {@link Encoder} of {@link MapType} that deserializes to {@link MapT}. + * + * @param key {@link Encoder} to encode keys + * @param value {@link Encoder} to encode values + * @param cls Specific class to use, supported are {@link HashMap} and {@link TreeMap} + */ + public static , K, V> Encoder mapEncoder( + Encoder key, Encoder value, Class cls) { + Expression serializer = mapSerializer(rootRef(new ObjectType(cls), true), key, value); + Expression deserializer = mapDeserializer(rootCol(serializer.dataType()), key, value, cls); + return EncoderFactory.create(serializer, deserializer, cls); + } + + /** + * Creates a Spark {@link Encoder} for Spark's {@link MutablePair} of {@link StructType} with + * fields `{@code _1}` and `{@code _2}`. + * + *

<p>
    This is intended to be used in places such as aggregators. + * + * @param enc1 {@link Encoder} to encode `{@code _1}` + * @param enc2 {@link Encoder} to encode `{@code _2}` + */ + public static Encoder> mutablePairEncoder( + Encoder enc1, Encoder enc2) { + Expression serializer = serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, true), enc1, enc2); + Expression deserializer = deserializeMutablePair(rootCol(serializer.dataType()), enc1, enc2); + return EncoderFactory.create(serializer, deserializer, MutablePair.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link PaneInfo} of {@link DataTypes#BinaryType + * BinaryType}. + */ + private static Encoder paneInfoEncoder() { + DataType type = new ObjectType(PaneInfo.class); + return EncoderFactory.create( + invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, rootRef(type, false)), + invokeIfNotNull(Utils.class, "paneInfoFromBytes", type, rootCol(BinaryType)), + PaneInfo.class); + } /** - * Wrap a Beam coder into a Spark Encoder using Catalyst Expression Encoders (which uses java code - * generation). + * Creates a Spark {@link Encoder} for Joda {@link Instant} of {@link DataTypes#LongType + * LongType}. */ - public static Encoder fromBeamCoder(Coder coder) { - Class clazz = coder.getEncodedTypeDescriptor().getRawType(); - // Class T could be private, therefore use OBJECT_TYPE to not risk an IllegalAccessError + private static Encoder instantEncoder() { + DataType type = new ObjectType(Instant.class); + Expression instant = rootRef(type, true); + Expression millis = rootCol(LongType); return EncoderFactory.create( - beamSerializer(rootRef(OBJECT_TYPE, true), coder), - beamDeserializer(rootCol(BinaryType), coder), - clazz); + nullSafe(instant, invoke(instant, "getMillis", LongType, false)), + nullSafe(millis, invoke(Instant.class, "ofEpochMilli", type, millis)), + Instant.class); + } + + /** + * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType} + * delegating to a Beam {@link Coder} underneath. + * + * @param coder Beam {@link Coder} + * @param nullable If to allow nullable items + */ + private static Encoder binaryEncoder(Coder coder, boolean nullable) { + Literal litCoder = lit(coder, Coder.class); + // T could be private, use OBJECT_TYPE for code generation to not risk an IllegalAccessError + return EncoderFactory.create( + invokeIfNotNull( + CoderHelpers.class, + "toByteArray", + BinaryType, + rootRef(OBJECT_TYPE, nullable), + litCoder), + invokeIfNotNull( + CoderHelpers.class, "fromByteArray", OBJECT_TYPE, rootCol(BinaryType), litCoder), + coder.getEncodedTypeDescriptor().getRawType()); + } + + private static Expression serializeWindowedValue( + Expression in, + Encoder valueEnc, + Encoder timestampEnc, + Encoder> windowsEnc, + Encoder paneEnc) { + return serializerObject( + in, + tuple("value", serializeField(in, valueEnc, "getValue")), + tuple("timestamp", serializeField(in, timestampEnc, "getTimestamp")), + tuple("windows", serializeField(in, windowsEnc, "getWindows")), + tuple("pane", serializeField(in, paneEnc, "getPane"))); + } + + private static Expression serializerObject(Expression in, Tuple2... 
fields) { + return SerializerBuildHelper.createSerializerForObject(in, seqOf(fields)); + } + + private static Expression deserializeWindowedValue( + Expression in, + Encoder valueEnc, + Encoder timestampEnc, + Encoder> windowsEnc, + Encoder paneEnc) { + Expression value = deserializeField(in, valueEnc, 0, "value"); + Expression windows = deserializeField(in, windowsEnc, 2, "windows"); + Expression timestamp = deserializeField(in, timestampEnc, 1, "timestamp"); + Expression pane = deserializeField(in, paneEnc, 3, "pane"); + // set timestamp to end of window (maxTimestamp) if null + timestamp = + ifNotNull(timestamp, invoke(Utils.class, "maxTimestamp", timestamp.dataType(), windows)); + Expression[] fields = new Expression[] {value, timestamp, windows, pane}; + + return nullSafe(pane, invoke(WindowedValue.class, "of", WINDOWED_VALUE, fields)); + } + + private static Expression serializeMutablePair( + Expression in, Encoder enc1, Encoder enc2) { + return serializerObject( + in, + tuple("_1", serializeField(in, enc1, "_1")), + tuple("_2", serializeField(in, enc2, "_2"))); } - /** Catalyst Expression that serializes elements using Beam {@link Coder}. */ - private static Expression beamSerializer(Expression obj, Coder coder) { - Expression[] args = {obj, lit(coder, Coder.class)}; - return EncoderFactory.invokeIfNotNull(CoderHelpers.class, "toByteArray", BinaryType, args); + private static Expression deserializeMutablePair( + Expression in, Encoder enc1, Encoder enc2) { + Expression field1 = deserializeField(in, enc1, 0, "_1"); + Expression field2 = deserializeField(in, enc2, 1, "_2"); + return invoke(MutablePair.class, "apply", MUTABLE_PAIR_TYPE, field1, field2); } - /** Catalyst Expression that deserializes elements using Beam {@link Coder}. */ - private static Expression beamDeserializer(Expression bytes, Coder coder) { - Expression[] args = {bytes, lit(coder, Coder.class)}; - return EncoderFactory.invokeIfNotNull(CoderHelpers.class, "fromByteArray", OBJECT_TYPE, args); + private static Expression serializeKV( + Expression in, Encoder keyEnc, Encoder valueEnc) { + return serializerObject( + in, + tuple("key", serializeField(in, keyEnc, "getKey")), + tuple("value", serializeField(in, valueEnc, "getValue"))); + } + + private static Expression deserializeKV( + Expression in, Encoder keyEnc, Encoder valueEnc) { + Expression key = deserializeField(in, keyEnc, 0, "key"); + Expression value = deserializeField(in, valueEnc, 1, "value"); + return invoke(KV.class, "of", KV_TYPE, key, value); + } + + public static Expression serializeOneOf(Expression in, List> encoders) { + Expression type = invoke(in, "_1", IntegerType, false); + Expression[] args = new Expression[encoders.size() * 2]; + for (int i = 0; i < encoders.size(); i++) { + args[i * 2] = lit(String.valueOf(i)); + args[i * 2 + 1] = serializeOneOfField(in, type, encoders.get(i), i); + } + return new CreateNamedStruct(seqOf(args)); + } + + public static Expression deserializeOneOf(Expression in, List> encoders) { + Expression[] args = new Expression[encoders.size()]; + for (int i = 0; i < encoders.size(); i++) { + args[i] = deserializeOneOfField(in, encoders.get(i), i); + } + return new Coalesce(seqOf(args)); + } + + private static Expression serializeOneOfField( + Expression in, Expression type, Encoder enc, int typeIdx) { + Expression litNull = lit(null, serializedType(enc)); + Expression value = invoke(in, "_2", deserializedType(enc), false); + return new If(new EqualTo(type, lit(typeIdx)), serialize(value, enc), litNull); + } + + private 
static Expression deserializeOneOfField(Expression in, Encoder enc, int idx) { + GetStructField field = new GetStructField(in, idx, Option.empty()); + Expression litNull = lit(null, TUPLE2_TYPE); + Expression newTuple = newInstance(Tuple2.class, TUPLE2_TYPE, lit(idx), deserialize(field, enc)); + return new If(new IsNull(field), litNull, newTuple); + } + + private static Expression serializeField(Expression in, Encoder enc, String getterName) { + Expression ref = serializer(enc).collect(match(BoundReference.class)).head(); + return serialize(invoke(in, getterName, ref.dataType(), ref.nullable()), enc); + } + + private static Expression deserializeField( + Expression in, Encoder enc, int idx, String name) { + return deserialize(new GetStructField(in, idx, new Some<>(name)), enc); + } + + // Note: Currently this doesn't support nullable primitive values + private static Expression mapSerializer(Expression map, Encoder key, Encoder value) { + DataType keyType = deserializedType(key); + DataType valueType = deserializedType(value); + return SerializerBuildHelper.createSerializerForMap( + map, + new MapElementInformation(keyType, false, e -> serialize(e, key)), + new MapElementInformation(valueType, false, e -> serialize(e, value))); + } + + private static , K, V> Expression mapDeserializer( + Expression in, Encoder key, Encoder value, Class cls) { + Preconditions.checkArgument(cls.isAssignableFrom(HashMap.class) || cls.equals(TreeMap.class)); + Expression keys = deserializeSeq(new MapKeys(in), key, false, false); + Expression values = deserializeSeq(new MapValues(in), value, false, false); + String fn = cls.equals(TreeMap.class) ? "toTreeMap" : "toMap"; + return invoke( + Utils.class, fn, new ObjectType(cls), keys, values, mapItemType(key), mapItemType(value)); + } + + // serialized type for primitive types (avoid boxing!), otherwise the deserialized type + private static Literal mapItemType(Encoder enc) { + return lit(isPrimitiveEnc(enc) ? serializedType(enc) : deserializedType(enc), DataType.class); + } + + private static Expression serializeSeq(Expression in, Encoder enc, boolean nullable) { + if (isPrimitiveEnc(enc)) { + Expression array = invoke(in, "toArray", new ObjectType(Object[].class), false); + return SerializerBuildHelper.createSerializerForGenericArray( + array, serializedType(enc), nullable); + } + Expression seq = invoke(Utils.class, "toSeq", new ObjectType(Seq.class), in); + return MapObjects$.MODULE$.apply( + exp -> serialize(exp, enc), seq, deserializedType(enc), nullable, Option.empty()); + } + + private static Expression deserializeSeq( + Expression in, Encoder enc, boolean nullable, boolean asJava) { + DataType type = serializedType(enc); // input type is the serializer result type + if (isPrimitiveEnc(enc)) { + ObjectType listType = new ObjectType(List.class); + return asJava ? invoke(Utils.class, "toList", listType, in, lit(type, DataType.class)) : in; + } + Option> optCls = asJava ? 
Option.apply(List.class) : Option.empty(); + return MapObjects$.MODULE$.apply(exp -> deserialize(exp, enc), in, type, nullable, optCls); + } + + private static boolean isPrimitiveEnc(Encoder enc) { + return PRIMITIV_TYPES.contains(enc.clsTag().runtimeClass()); + } + + private static Expression serialize(Expression input, Encoder enc) { + return serializer(enc).transformUp(replace(BoundReference.class, input)); + } + + private static Expression deserialize(Expression input, Encoder enc) { + return deserializer(enc).transformUp(replace(GetColumnByOrdinal.class, input)); + } + + private static Expression serializer(Encoder enc) { + return ((ExpressionEncoder) enc).objSerializer(); + } + + private static Expression deserializer(Encoder enc) { + return ((ExpressionEncoder) enc).objDeserializer(); + } + + private static DataType serializedType(Encoder enc) { + return ((ExpressionEncoder) enc).objSerializer().dataType(); + } + + private static DataType deserializedType(Encoder enc) { + return ((ExpressionEncoder) enc).objDeserializer().dataType(); } private static Expression rootRef(DataType dt, boolean nullable) { @@ -65,7 +509,77 @@ private static Expression rootCol(DataType dt) { return new GetColumnByOrdinal(0, dt); } + private static Expression nullSafe(Expression in, Expression out) { + return new If(new IsNull(in), lit(null, out.dataType()), out); + } + + private static Expression ifNotNull(Expression expr, Expression otherwise) { + return new If(new IsNotNull(expr), expr, otherwise); + } + + private static Expression lit(T t) { + return Literal$.MODULE$.apply(t); + } + + @SuppressWarnings("nullness") // literal NULL is allowed + private static Expression lit(@Nullable T t, DataType dataType) { + return new Literal(t, dataType); + } + private static Literal lit(T obj, Class cls) { return Literal.fromObject(obj, new ObjectType(cls)); } + + /** Encoder / expression utils that are called from generated code. */ + public static class Utils { + + public static PaneInfo paneInfoFromBytes(byte[] bytes) { + return CoderHelpers.fromByteArray(bytes, PaneInfoCoder.of()); + } + + public static byte[] paneInfoToBytes(PaneInfo pane) { + return CoderHelpers.toByteArray(pane, PaneInfoCoder.of()); + } + + /** The end of the only window (max timestamp). 
*/ + public static Instant maxTimestamp(Iterable windows) { + return Iterables.getOnlyElement(windows).maxTimestamp(); + } + + public static List toList(ArrayData arrayData, DataType type) { + return JavaConverters.seqAsJavaList(arrayData.toSeq(type)); + } + + public static Seq toSeq(ArrayData arrayData) { + return arrayData.toSeq(OBJECT_TYPE); + } + + public static Seq toSeq(Collection col) { + if (col instanceof List) { + return JavaConverters.asScalaBuffer((List) col); + } + return JavaConverters.collectionAsScalaIterable(col).toSeq(); + } + + public static TreeMap toTreeMap( + ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + return toMap(new TreeMap<>(), keys, values, keyType, valueType); + } + + public static HashMap toMap( + ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + HashMap map = Maps.newHashMapWithExpectedSize(keys.numElements()); + return toMap(map, keys, values, keyType, valueType); + } + + private static > MapT toMap( + MapT map, ArrayData keys, ArrayData values, DataType keyType, DataType valueType) { + IndexedSeq keysSeq = keys.toSeq(keyType); + IndexedSeq valuesSeq = values.toSeq(valueType); + for (int i = 0; i < keysSeq.size(); i++) { + map.put(keysSeq.apply(i), valuesSeq.apply(i)); + } + return map; + } + } } diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java deleted file mode 100644 index f77fcea67960..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.Map; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.CustomCoder; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.TupleTag; -import scala.Tuple2; - -/** - * Coder to serialize and deserialize {@code}Tuple2, WindowedValue{/@code} to be used - * in spark encoders while applying {@link org.apache.beam.sdk.transforms.DoFn}. 
- * - * @param type of the elements in the collection - */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class MultiOutputCoder extends CustomCoder, WindowedValue>> { - Coder tupleTagCoder; - Map, Coder> coderMap; - Coder windowCoder; - - public static MultiOutputCoder of( - Coder tupleTagCoder, - Map, Coder> coderMap, - Coder windowCoder) { - return new MultiOutputCoder(tupleTagCoder, coderMap, windowCoder); - } - - private MultiOutputCoder( - Coder tupleTagCoder, - Map, Coder> coderMap, - Coder windowCoder) { - this.tupleTagCoder = tupleTagCoder; - this.coderMap = coderMap; - this.windowCoder = windowCoder; - } - - @Override - public void encode(Tuple2, WindowedValue> tuple2, OutputStream outStream) - throws IOException { - TupleTag tupleTag = tuple2._1(); - tupleTagCoder.encode(tupleTag, outStream); - Coder valueCoder = (Coder) coderMap.get(tupleTag); - WindowedValue.FullWindowedValueCoder wvCoder = - WindowedValue.FullWindowedValueCoder.of(valueCoder, windowCoder); - wvCoder.encode(tuple2._2(), outStream); - } - - @Override - public Tuple2, WindowedValue> decode(InputStream inStream) - throws CoderException, IOException { - TupleTag tupleTag = (TupleTag) tupleTagCoder.decode(inStream); - Coder valueCoder = (Coder) coderMap.get(tupleTag); - WindowedValue.FullWindowedValueCoder wvCoder = - WindowedValue.FullWindowedValueCoder.of(valueCoder, windowCoder); - WindowedValue wv = wvCoder.decode(inStream); - return Tuple2.apply(tupleTag, wv); - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java deleted file mode 100644 index 9b5d5da2b2cd..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; - -import static scala.collection.JavaConversions.asScalaBuffer; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.spark.api.java.function.MapFunction; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.InternalRow; - -/** Helper functions for working with {@link Row}. 
*/ -public final class RowHelpers { - - /** - * A Spark {@link MapFunction} for extracting a {@link WindowedValue} from a Row in which the - * {@link WindowedValue} was serialized to bytes using its {@link - * WindowedValue.WindowedValueCoder}. - * - * @param The type of the object. - * @return A {@link MapFunction} that accepts a {@link Row} and returns its {@link WindowedValue}. - */ - public static MapFunction> extractWindowedValueFromRowMapFunction( - WindowedValue.WindowedValueCoder windowedValueCoder) { - return (MapFunction>) - value -> { - // there is only one value put in each Row by the InputPartitionReader - byte[] bytes = (byte[]) value.get(0); - return windowedValueCoder.decode(new ByteArrayInputStream(bytes)); - }; - } - - /** - * Serialize a windowedValue to bytes using windowedValueCoder {@link - * WindowedValue.FullWindowedValueCoder} and stores it an InternalRow. - */ - public static InternalRow storeWindowedValueInRow( - WindowedValue windowedValue, Coder coder) { - List list = new ArrayList<>(); - // serialize the windowedValue to bytes array to comply with dataset binary schema - WindowedValue.FullWindowedValueCoder windowedValueCoder = - WindowedValue.FullWindowedValueCoder.of(coder, GlobalWindow.Coder.INSTANCE); - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); - try { - windowedValueCoder.encode(windowedValue, byteArrayOutputStream); - byte[] bytes = byteArrayOutputStream.toByteArray(); - list.add(bytes); - } catch (IOException e) { - throw new RuntimeException(e); - } - return InternalRow.apply(asScalaBuffer(list).toList()); - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java deleted file mode 100644 index 71dca5264dd8..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; - -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -/** A {@link SchemaHelpers} for the Spark Batch Runner. 
*/ -public class SchemaHelpers { - private static final StructType BINARY_SCHEMA = - new StructType( - new StructField[] { - StructField.apply("binaryStructField", DataTypes.BinaryType, true, Metadata.empty()) - }); - - public static StructType binarySchema() { - // we use a binary schema for now because: - // using a empty schema raises a indexOutOfBoundsException - // using a NullType schema stores null in the elements - return BINARY_SCHEMA; - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java deleted file mode 100644 index 5085eb9f7964..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; - -import java.util.Collection; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; -import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.transforms.windowing.WindowFn; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; -import org.apache.spark.api.java.function.MapFunction; -import org.joda.time.Instant; - -/** Helper functions for working with windows. */ -public final class WindowingHelpers { - - /** - * Checks if the window transformation should be applied or skipped. - * - *

    Avoid running assign windows if both source and destination are global window or if the user - * has not specified the WindowFn (meaning they are just messing with triggering or allowed - * lateness). - */ - @SuppressWarnings("unchecked") - public static boolean skipAssignWindows( - Window.Assign transform, AbstractTranslationContext context) { - WindowFn windowFnToApply = (WindowFn) transform.getWindowFn(); - PCollection input = (PCollection) context.getInput(); - WindowFn windowFnOfInput = input.getWindowingStrategy().getWindowFn(); - return windowFnToApply == null - || (windowFnOfInput instanceof GlobalWindows && windowFnToApply instanceof GlobalWindows); - } - - public static - MapFunction, WindowedValue> assignWindowsMapFunction( - WindowFn windowFn) { - return (MapFunction, WindowedValue>) - windowedValue -> { - final BoundedWindow boundedWindow = Iterables.getOnlyElement(windowedValue.getWindows()); - final T element = windowedValue.getValue(); - final Instant timestamp = windowedValue.getTimestamp(); - Collection windows = - windowFn.assignWindows( - windowFn.new AssignContext() { - - @Override - public T element() { - return element; - } - - @Override - public Instant timestamp() { - return timestamp; - } - - @Override - public BoundedWindow window() { - return boundedWindow; - } - }); - return WindowedValue.of(element, timestamp, windows, windowedValue.getPane()); - }; - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java deleted file mode 100644 index 5eb60f68cb34..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.streaming; - -/** - * Spark structured streaming framework does not support more than one aggregation in streaming mode - * because of watermark implementation. 
As a consequence, this runner, does not support streaming - * mode yet see https://github.com/apache/beam/issues/20241 - */ -class DatasetSourceStreaming {} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java deleted file mode 100644 index 73d99efa4630..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.streaming; - -import java.util.HashMap; -import java.util.Map; -import org.apache.beam.runners.core.construction.SplittableParDo; -import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; -import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.runners.TransformHierarchy; -import org.apache.beam.sdk.transforms.PTransform; -import org.checkerframework.checker.nullness.qual.Nullable; - -/** - * {@link PipelineTranslator} for executing a {@link Pipeline} in Spark in streaming mode. This - * contains only the components specific to streaming: registry of streaming {@link - * TransformTranslator} and registry lookup code. 
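The DatasetSourceStreaming stub and the streaming translators removed here were placeholders only: as the deleted comment notes, Spark Structured Streaming does not allow more than one aggregation per streaming query, so the streaming path stays unsupported (issue #20241). A minimal standalone illustration of that Spark restriction (plain Spark, not Beam; behavior as of the Spark 3.x versions targeted here, shown with hypothetical names):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

class MultipleStreamingAggregationsSketch {
  static void illustrate(SparkSession spark) throws Exception {
    // Unbounded test source with columns (timestamp, value).
    Dataset<Row> rate = spark.readStream().format("rate").load();
    Dataset<Row> counts = rate.groupBy("value").count();        // first streaming aggregation: allowed
    Dataset<Row> recounted = counts.groupBy("count").count();   // second streaming aggregation
    // Starting the query fails with an AnalysisException because multiple
    // streaming aggregations are not supported on these Spark versions.
    recounted.writeStream().outputMode("complete").format("console").start();
  }
}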
- */ -@SuppressWarnings({ - "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) - "nullness" // TODO(https://github.com/apache/beam/issues/20497) -}) -public class PipelineTranslatorStreaming extends PipelineTranslator { - // -------------------------------------------------------------------------------------------- - // Transform Translator Registry - // -------------------------------------------------------------------------------------------- - - @SuppressWarnings("rawtypes") - private static final Map, TransformTranslator> TRANSFORM_TRANSLATORS = - new HashMap<>(); - - // TODO the ability to have more than one TransformTranslator per URN - // that could be dynamically chosen by a predicated that evaluates based on PCollection - // obtainable though node.getInputs.getValue() - // See - // https://github.com/seznam/euphoria/blob/master/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/SparkFlowTranslator.java#L83 - // And - // https://github.com/seznam/euphoria/blob/master/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/SparkFlowTranslator.java#L106 - - static { - // TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch()); - // TRANSFORM_TRANSLATORS.put(Combine.Globally.class, new CombineGloballyTranslatorBatch()); - // TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch()); - - // TODO: Do we need to have a dedicated translator for {@code Reshuffle} if it's deprecated? - // TRANSFORM_TRANSLATORS.put(Reshuffle.class, new ReshuffleTranslatorBatch()); - - // TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch()); - // - // TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch()); - // - // TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch()); - - TRANSFORM_TRANSLATORS.put( - SplittableParDo.PrimitiveUnboundedRead.class, new ReadSourceTranslatorStreaming()); - - // TRANSFORM_TRANSLATORS - // .put(View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch()); - } - - public PipelineTranslatorStreaming(SparkStructuredStreamingPipelineOptions options) { - translationContext = new TranslationContext(options); - } - - /** Returns a translator for the given node, if it is possible, otherwise null. */ - @Override - protected TransformTranslator getTransformTranslator(TransformHierarchy.Node node) { - @Nullable PTransform transform = node.getTransform(); - // Root of the graph is null - if (transform == null) { - return null; - } - return TRANSFORM_TRANSLATORS.get(transform.getClass()); - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java deleted file mode 100644 index 8abc8771a4e8..000000000000 --- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.structuredstreaming.translation.streaming; - -import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION; -import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM; -import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS; - -import java.io.IOException; -import org.apache.beam.runners.core.construction.ReadTranslation; -import org.apache.beam.runners.core.serialization.Base64Serializer; -import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext; -import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers; -import org.apache.beam.sdk.io.UnboundedSource; -import org.apache.beam.sdk.runners.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PBegin; -import org.apache.beam.sdk.values.PCollection; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; - -class ReadSourceTranslatorStreaming - implements TransformTranslator>> { - - private static final String sourceProviderClass = DatasetSourceStreaming.class.getCanonicalName(); - - @SuppressWarnings("unchecked") - @Override - public void translateTransform( - PTransform> transform, AbstractTranslationContext context) { - AppliedPTransform, PTransform>> rootTransform = - (AppliedPTransform, PTransform>>) - context.getCurrentTransform(); - - UnboundedSource source; - try { - source = ReadTranslation.unboundedSourceFromTransform(rootTransform); - } catch (IOException e) { - throw new RuntimeException(e); - } - SparkSession sparkSession = context.getSparkSession(); - - String serializedSource = Base64Serializer.serializeUnchecked(source); - Dataset rowDataset = - sparkSession - .readStream() - .format(sourceProviderClass) - .option(BEAM_SOURCE_OPTION, serializedSource) - .option( - DEFAULT_PARALLELISM, - String.valueOf(context.getSparkSession().sparkContext().defaultParallelism())) - .option(PIPELINE_OPTIONS, context.getSerializableOptions().toString()) - .load(); - - // extract windowedValue from Row - WindowedValue.FullWindowedValueCoder windowedValueCoder = - WindowedValue.FullWindowedValueCoder.of( - source.getOutputCoder(), GlobalWindow.Coder.INSTANCE); - Dataset> dataset = - rowDataset.map( - RowHelpers.extractWindowedValueFromRowMapFunction(windowedValueCoder), - EncoderHelpers.fromBeamCoder(windowedValueCoder)); - - PCollection output = (PCollection) context.getOutput(); - context.putDataset(output, dataset); - } -} diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java new file mode 100644 index 000000000000..1908cbf2bba9 --- /dev/null +++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.utils; + +import java.io.Serializable; +import org.checkerframework.checker.nullness.qual.NonNull; +import scala.Function1; +import scala.Function2; +import scala.PartialFunction; +import scala.Tuple2; +import scala.collection.Iterator; +import scala.collection.JavaConverters; +import scala.collection.Seq; +import scala.collection.immutable.List; +import scala.collection.immutable.Nil$; +import scala.collection.mutable.WrappedArray; + +/** Utilities for easier interoperability with the Spark Scala API. */ +public class ScalaInterop { + private ScalaInterop() {} + + public static Seq seqOf(T... t) { + return new WrappedArray.ofRef<>(t); + } + + public static List concat(List a, List b) { + return b.$colon$colon$colon(a); + } + + public static Seq listOf(T t) { + return emptyList().$colon$colon(t); + } + + public static List emptyList() { + return (List) Nil$.MODULE$; + } + + /** Scala {@link Iterator} of Java {@link Iterable}. */ + public static Iterator scalaIterator(Iterable iterable) { + return scalaIterator(iterable.iterator()); + } + + /** Scala {@link Iterator} of Java {@link java.util.Iterator}. */ + public static Iterator scalaIterator(java.util.Iterator it) { + return JavaConverters.asScalaIterator(it); + } + + /** Java {@link java.util.Iterator} of Scala {@link Iterator}. 
*/ + public static java.util.Iterator javaIterator(Iterator it) { + return JavaConverters.asJavaIterator(it); + } + + public static Tuple2 tuple(T1 t1, T2 t2) { + return new Tuple2<>(t1, t2); + } + + public static PartialFunction replace( + Class clazz, T replace) { + return new PartialFunction() { + + @Override + public boolean isDefinedAt(T x) { + return clazz.isAssignableFrom(x.getClass()); + } + + @Override + public T apply(T x) { + return replace; + } + }; + } + + public static PartialFunction match(Class clazz) { + return new PartialFunction() { + + @Override + public boolean isDefinedAt(T x) { + return clazz.isAssignableFrom(x.getClass()); + } + + @Override + public V apply(T x) { + return (V) x; + } + }; + } + + public static Fun1 fun1(Fun1 fun) { + return fun; + } + + public static Fun2 fun2(Fun2 fun) { + return fun; + } + + public interface Fun1 extends Function1, Serializable {} + + public interface Fun2 extends Function2, Serializable {} +} diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java index f994f7712b32..7f2eaa10e809 100644 --- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java @@ -49,7 +49,7 @@ public InMemoryMetrics(final Properties properties, final MetricRegistry metricR internalMetricRegistry = metricRegistry; } - @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"}) + @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"}) // because of getGauges public static T valueOf(final String name) { // this might fail in case we have multiple aggregators with the same suffix after // the last dot, but it should be good enough for tests. diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java new file mode 100644 index 000000000000..b5b07db0e38b --- /dev/null +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
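A short usage sketch for the ScalaInterop helpers added above (the dataset and names below are illustrative, not taken from the runner code): fun1 lets a Java lambda be passed where Spark's Scala API expects a serializable scala.Function1, and tuple builds the corresponding Tuple2 values.

import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;

import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

class ScalaInteropSketch {
  static Dataset<Tuple2<String, Integer>> wordLengths(SparkSession spark) {
    Dataset<String> words = spark.createDataset(Arrays.asList("a", "bb"), Encoders.STRING());
    // The Scala overload of Dataset#map takes a scala.Function1; fun1 adapts the
    // Java lambda into a serializable Function1 so it can be passed directly.
    return words.map(
        fun1((String w) -> tuple(w, w.length())),
        Encoders.tuple(Encoders.STRING(), Encoders.INT()));
  }
}

Because Fun1 extends both scala.Function1 and Serializable, the lambda stays serializable for distribution to executors without the boilerplate of an anonymous Function1 subclass.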
+ */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor; +import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.joda.time.Duration.standardMinutes; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; +import java.util.TreeMap; +import java.util.stream.Collectors; +import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.transforms.windowing.WindowMappingFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.util.MutablePair; +import org.hamcrest.Matcher; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; + +@RunWith(Enclosed.class) +public class AggregatorsTest { + + // just something easy readable + private static final Instant NOW = Instant.parse("2000-01-01T00:00Z"); + + /** Tests for NonMergingWindowedAggregator in {@link Aggregators}. 
*/ + public static class NonMergingWindowedAggregatorTest { + + private SlidingWindows sliding = + SlidingWindows.of(standardMinutes(15)).every(standardMinutes(5)); + + private Aggregator< + WindowedValue, + Map>, + Collection>> + agg = windowedAgg(sliding); + + @Test + public void testReduce() { + Map> acc; + + acc = agg.reduce(agg.zero(), windowedValue(1, at(10))); + assertThat( + acc, + equalsToMap( + KV.of(intervalWindow(0, 15), pair(at(10), 1)), + KV.of(intervalWindow(5, 20), pair(at(10), 1)), + KV.of(intervalWindow(10, 25), pair(at(10), 1)))); + + acc = agg.reduce(acc, windowedValue(2, at(16))); + assertThat( + acc, + equalsToMap( + KV.of(intervalWindow(0, 15), pair(at(10), 1)), + KV.of(intervalWindow(5, 20), pair(at(16), 3)), + KV.of(intervalWindow(10, 25), pair(at(16), 3)), + KV.of(intervalWindow(15, 30), pair(at(16), 2)))); + } + + @Test + public void testMerge() { + Map> acc; + + assertThat(agg.merge(agg.zero(), agg.zero()), equalTo(agg.zero())); + + acc = mapOf(KV.of(intervalWindow(0, 15), pair(at(0), 1))); + + assertThat(agg.merge(acc, agg.zero()), equalTo(acc)); + assertThat(agg.merge(agg.zero(), acc), equalTo(acc)); + + acc = agg.merge(acc, acc); + assertThat(acc, equalsToMap(KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)))); + + acc = agg.merge(acc, mapOf(KV.of(intervalWindow(5, 20), pair(at(5), 3)))); + assertThat( + acc, + equalsToMap( + KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)), + KV.of(intervalWindow(5, 20), pair(at(5), 3)))); + + acc = agg.merge(mapOf(KV.of(intervalWindow(10, 25), pair(at(10), 4))), acc); + assertThat( + acc, + equalsToMap( + KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)), + KV.of(intervalWindow(5, 20), pair(at(5), 3)), + KV.of(intervalWindow(10, 25), pair(at(10), 4)))); + } + + private WindowedValue windowedValue(Integer value, Instant ts) { + return WindowedValue.of(value, ts, sliding.assignWindows(ts), PaneInfo.NO_FIRING); + } + } + + /** + * Shared implementation of tests for SessionsAggregator and MergingWindowedAggregator in {@link + * Aggregators}. + */ + public abstract static class AbstractSessionsTest< + AccT extends Map>> { + + static final Duration SESSIONS_GAP = standardMinutes(15); + + final Aggregator, AccT, Collection>> agg; + + AbstractSessionsTest(WindowFn windowFn) { + agg = windowedAgg(windowFn); + } + + abstract AccT accOf(KV>... 
entries); + + @Test + public void testReduce() { + AccT acc; + + acc = agg.reduce(agg.zero(), sessionValue(10, at(0))); + assertThat(acc, equalsToMap(KV.of(sessionWindow(0), pair(at(0), 10)))); + + // 2nd session after 1st + acc = agg.reduce(acc, sessionValue(7, at(20))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 10)), KV.of(sessionWindow(20), pair(at(20), 7)))); + + // merge into 2nd session + acc = agg.reduce(acc, sessionValue(6, at(18))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 10)), + KV.of(sessionWindow(18, 35), pair(at(20), 7 + 6)))); + + // merge into 2nd session + acc = agg.reduce(acc, sessionValue(5, at(21))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 10)), + KV.of(sessionWindow(18, 36), pair(at(21), 7 + 6 + 5)))); + + // 3rd session after 2nd + acc = agg.reduce(acc, sessionValue(2, NOW.plus(standardMinutes(45)))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 10)), + KV.of(sessionWindow(18, 36), pair(at(21), 7 + 6 + 5)), + KV.of(sessionWindow(45), pair(at(45), 2)))); + + // merge with 1st and 2nd + acc = agg.reduce(acc, sessionValue(1, at(10))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0, 36), pair(at(21), 10 + 7 + 6 + 5 + 1)), + KV.of(sessionWindow(45), pair(at(45), 2)))); + } + + @Test + public void testMerge() { + AccT acc; + + assertThat(agg.merge(agg.zero(), agg.zero()), equalTo(agg.zero())); + + acc = accOf(KV.of(sessionWindow(0), pair(at(0), 1))); + + assertThat(agg.merge(acc, agg.zero()), equalTo(acc)); + assertThat(agg.merge(agg.zero(), acc), equalTo(acc)); + + acc = agg.merge(acc, acc); + assertThat(acc, equalsToMap(KV.of(sessionWindow(0), pair(at(0), 1 + 1)))); + + acc = agg.merge(acc, accOf(KV.of(sessionWindow(20), pair(at(20), 2)))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 1 + 1)), + KV.of(sessionWindow(20), pair(at(20), 2)))); + + acc = agg.merge(accOf(KV.of(sessionWindow(40), pair(at(40), 3))), acc); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0), pair(at(0), 1 + 1)), + KV.of(sessionWindow(20), pair(at(20), 2)), + KV.of(sessionWindow(40), pair(at(40), 3)))); + + acc = agg.merge(acc, accOf(KV.of(sessionWindow(10), pair(at(10), 4)))); + assertThat( + acc, + equalsToMap( + KV.of(sessionWindow(0, 35), pair(at(20), 1 + 1 + 2 + 4)), + KV.of(sessionWindow(40), pair(at(40), 3)))); + + acc = agg.merge(accOf(KV.of(sessionWindow(5, 45), pair(at(30), 5))), acc); + assertThat( + acc, equalsToMap(KV.of(sessionWindow(0, 55), pair(at(40), 1 + 1 + 2 + 4 + 3 + 5)))); + } + + private WindowedValue sessionValue(Integer value, Instant ts) { + return WindowedValue.of(value, ts, new IntervalWindow(ts, SESSIONS_GAP), PaneInfo.NO_FIRING); + } + + private IntervalWindow sessionWindow(int fromMinutes) { + return new IntervalWindow(at(fromMinutes), SESSIONS_GAP); + } + + private static IntervalWindow sessionWindow(int fromMinutes, int toMinutes) { + return intervalWindow(fromMinutes, toMinutes); + } + } + + /** Tests for specialized SessionsAggregator in {@link Aggregators}. */ + public static class SessionsAggregatorTest + extends AbstractSessionsTest>> { + + public SessionsAggregatorTest() { + super(Sessions.withGapDuration(SESSIONS_GAP)); + } + + @Override + TreeMap> accOf( + KV>... entries) { + return new TreeMap<>(mapOf(entries)); + } + } + + /** Tests for MergingWindowedAggregator in {@link Aggregators}. 
*/ + public static class MergingWindowedAggregatorTest + extends AbstractSessionsTest>> { + + public MergingWindowedAggregatorTest() { + super(new CustomSessions<>()); + } + + @Override + Map> accOf( + KV>... entries) { + return mapOf(entries); + } + + /** Wrapper around {@link Sessions} to test the MergingWindowedAggregator. */ + private static class CustomSessions extends WindowFn { + private final Sessions sessions = Sessions.withGapDuration(SESSIONS_GAP); + + @Override + public Collection assignWindows(WindowFn.AssignContext c) { + return sessions.assignWindows((WindowFn.AssignContext) c); + } + + @Override + public void mergeWindows(WindowFn.MergeContext c) throws Exception { + sessions.mergeWindows((WindowFn.MergeContext) c); + } + + @Override + public boolean isCompatible(WindowFn other) { + return sessions.isCompatible(other); + } + + @Override + public Coder windowCoder() { + return sessions.windowCoder(); + } + + @Override + public WindowMappingFn getDefaultWindowMappingFn() { + return sessions.getDefaultWindowMappingFn(); + } + } + } + + private static IntervalWindow intervalWindow(int fromMinutes, int toMinutes) { + return new IntervalWindow(at(fromMinutes), at(toMinutes)); + } + + private static Instant at(int minutes) { + return NOW.plus(standardMinutes(minutes)); + } + + private static Matcher>> equalsToMap( + KV>... entries) { + return equalTo(mapOf(entries)); + } + + private static Map> mapOf( + KV>... entries) { + return Arrays.asList(entries).stream().collect(Collectors.toMap(KV::getKey, KV::getValue)); + } + + private static MutablePair pair(Instant ts, int value) { + return new MutablePair<>(ts, value); + } + + private static + Aggregator, AccT, Collection>> windowedAgg( + WindowFn windowFn) { + Encoder intEnc = EncoderHelpers.encoderOf(Integer.class); + Encoder windowEnc = encoderFor((Coder) IntervalWindow.getCoder()); + Encoder> outputEnc = windowedValueEncoder(intEnc, windowEnc); + + WindowingStrategy windowing = + WindowingStrategy.of(windowFn).withTimestampCombiner(TimestampCombiner.LATEST); + + Aggregator, ?, Collection>> agg = + Aggregators.windowedValue( + new SimpleSum(), WindowedValue::getValue, windowing, windowEnc, intEnc, outputEnc); + return (Aggregator) agg; + } + + private static class SimpleSum extends Combine.CombineFn { + + @Override + public Integer createAccumulator() { + return 0; + } + + @Override + public Integer addInput(Integer acc, Integer input) { + return acc + input; + } + + @Override + public Integer mergeAccumulators(Iterable accs) { + return Streams.stream(accs.iterator()).reduce((a, b) -> a + b).orElseGet(() -> 0); + } + + @Override + public Integer extractOutput(Integer acc) { + return acc; + } + } +} diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java new file mode 100644 index 000000000000..dca8b664bd3d --- /dev/null +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.structuredstreaming.translation.batch; + +import java.io.Serializable; +import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; +import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Combine.BinaryCombineFn; +import org.apache.beam.sdk.transforms.CombineFnBase; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Test class for beam to spark {@link Combine#globally(CombineFnBase.GlobalCombineFn)} translation. 
+ */ +@RunWith(JUnit4.class) +public class CombineGloballyTest implements Serializable { + @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions()); + + private static PipelineOptions testOptions() { + SparkStructuredStreamingPipelineOptions options = + PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class); + options.setRunner(SparkStructuredStreamingRunner.class); + options.setTestMode(true); + return options; + } + + @Test + public void testCombineGlobally() { + PCollection input = + pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(Sum.integersGlobally()); + PAssert.that(input).containsInAnyOrder(55); + // uses combine per key + pipeline.run(); + } + + @Test + public void testCombineGloballyPreservesWindowing() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(2, new Instant(2)), + TimestampedValue.of(3, new Instant(11)), + TimestampedValue.of(4, new Instant(3)), + TimestampedValue.of(5, new Instant(11)), + TimestampedValue.of(6, new Instant(12)))) + .apply(Window.into(FixedWindows.of(Duration.millis(10)))) + .apply(Sum.integersGlobally().withoutDefaults()); + PAssert.that(input).containsInAnyOrder(7, 14); + pipeline.run(); + } + + @Test + public void testCombineGloballyWithSlidingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(3, new Instant(2)), + TimestampedValue.of(5, new Instant(3)), + TimestampedValue.of(2, new Instant(1)), + TimestampedValue.of(4, new Instant(2)), + TimestampedValue.of(6, new Instant(3)))) + .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1)))) + .apply(Sum.integersGlobally().withoutDefaults()); + PAssert.that(input) + .containsInAnyOrder(1 + 2, 1 + 2 + 3 + 4, 1 + 3 + 5 + 2 + 4 + 6, 3 + 4 + 5 + 6, 5 + 6); + pipeline.run(); + } + + @Test + public void testCombineGloballyWithMergingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(2, new Instant(5)), + TimestampedValue.of(4, new Instant(11)), + TimestampedValue.of(6, new Instant(12)))) + .apply(Window.into(Sessions.withGapDuration(Duration.millis(5)))) + .apply(Sum.integersGlobally().withoutDefaults()); + + PAssert.that(input).containsInAnyOrder(2 /*window [5-10)*/, 10 /*window [11-17)*/); + pipeline.run(); + } + + @Test + public void testCountGloballyWithSlidingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of("a", new Instant(1)), + TimestampedValue.of("a", new Instant(2)), + TimestampedValue.of("a", new Instant(2)))) + .apply(Window.into(SlidingWindows.of(Duration.millis(2)).every(Duration.millis(1)))); + PCollection output = + input.apply(Combine.globally(Count.combineFn()).withoutDefaults()); + PAssert.that(output).containsInAnyOrder(1L, 3L, 2L); + pipeline.run(); + } + + @Test + public void testBinaryCombineWithSlidingWindows() { + PCollection input = + pipeline + .apply( + Create.timestamped( + TimestampedValue.of(1, new Instant(1)), + TimestampedValue.of(3, new Instant(2)), + TimestampedValue.of(5, new Instant(3)))) + .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1)))) + .apply( + Combine.globally(BinaryCombineFn.of((i1, i2) -> i1 > i2 ? 
i1 : i2)) + .withoutDefaults()); + PAssert.that(input).containsInAnyOrder(1, 3, 5, 5, 5); + pipeline.run(); + } +} diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java similarity index 71% rename from runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java rename to runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java index 52e60a3db545..c8b25b3355dd 100644 --- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java @@ -22,65 +22,44 @@ import java.util.List; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner; -import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.CombineFnBase; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.SerializableBiFunction; +import org.apache.beam.sdk.transforms.Distinct; import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.SlidingWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.joda.time.Duration; import org.joda.time.Instant; -import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Test class for beam to spark {@link org.apache.beam.sdk.transforms.Combine} translation. */ +/** + * Test class for beam to spark {@link + * org.apache.beam.sdk.transforms.Combine#perKey(CombineFnBase.GlobalCombineFn)} translation. 
+ */ @RunWith(JUnit4.class) -public class CombineTest implements Serializable { - private static Pipeline pipeline; +public class CombinePerKeyTest implements Serializable { + @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions()); - @BeforeClass - public static void beforeClass() { + private static PipelineOptions testOptions() { SparkStructuredStreamingPipelineOptions options = PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class); options.setRunner(SparkStructuredStreamingRunner.class); options.setTestMode(true); - pipeline = Pipeline.create(options); - } - - @Test - public void testCombineGlobally() { - PCollection input = - pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(Sum.integersGlobally()); - PAssert.that(input).containsInAnyOrder(55); - // uses combine per key - pipeline.run(); - } - - @Test - public void testCombineGloballyPreservesWindowing() { - PCollection input = - pipeline - .apply( - Create.timestamped( - TimestampedValue.of(1, new Instant(1)), - TimestampedValue.of(2, new Instant(2)), - TimestampedValue.of(3, new Instant(11)), - TimestampedValue.of(4, new Instant(3)), - TimestampedValue.of(5, new Instant(11)), - TimestampedValue.of(6, new Instant(12)))) - .apply(Window.into(FixedWindows.of(Duration.millis(10)))) - .apply(Combine.globally(Sum.ofIntegers()).withoutDefaults()); - PAssert.that(input).containsInAnyOrder(7, 14); + return options; } @Test @@ -99,6 +78,17 @@ public void testCombinePerKey() { pipeline.run(); } + @Test + public void testDistinctViaCombinePerKey() { + List elems = Lists.newArrayList(1, 2, 3, 3, 4, 4, 4, 4, 5, 5); + + // Distinct is implemented in terms of CombinePerKey + PCollection result = pipeline.apply(Create.of(elems)).apply(Distinct.create()); + + PAssert.that(result).containsInAnyOrder(1, 2, 3, 4, 5); + pipeline.run(); + } + @Test public void testCombinePerKeyPreservesWindowing() { PCollection> input = @@ -142,22 +132,26 @@ public void testCombinePerKeyWithSlidingWindows() { } @Test - public void testBinaryCombineWithSlidingWindows() { - PCollection input = + public void testCombineByKeyWithMergingWindows() { + PCollection> input = pipeline .apply( Create.timestamped( - TimestampedValue.of(1, new Instant(1)), - TimestampedValue.of(3, new Instant(2)), - TimestampedValue.of(5, new Instant(3)))) - .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1)))) - .apply( - Combine.globally( - Combine.BinaryCombineFn.of( - (SerializableBiFunction) - (integer1, integer2) -> integer1 > integer2 ? 
integer1 : integer2)) - .withoutDefaults()); - PAssert.that(input).containsInAnyOrder(1, 3, 5, 5, 5); + TimestampedValue.of(KV.of(1, 1), new Instant(5)), + TimestampedValue.of(KV.of(1, 3), new Instant(7)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(5)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12)))) + .apply(Window.into(Sessions.withGapDuration(Duration.millis(5)))) + .apply(Sum.integersPerKey()); + + PAssert.that(input) + .containsInAnyOrder( + KV.of(1, 9), // window [5-16) + KV.of(2, 2), // window [5-10) + KV.of(2, 10) // window [11-17) + ); pipeline.run(); } diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java index 0175d03f8753..582a31a05a6a 100644 --- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java @@ -27,13 +27,15 @@ import java.util.List; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner; -import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.values.PCollection; import org.junit.BeforeClass; import org.junit.ClassRule; +import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; @@ -46,15 +48,18 @@ public class ComplexSourceTest implements Serializable { private static File file; private static List lines = createLines(30); - private static Pipeline pipeline; + @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions()); - @BeforeClass - public static void beforeClass() throws IOException { + private static PipelineOptions testOptions() { SparkStructuredStreamingPipelineOptions options = PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class); options.setRunner(SparkStructuredStreamingRunner.class); options.setTestMode(true); - pipeline = Pipeline.create(options); + return options; + } + + @BeforeClass + public static void beforeClass() throws IOException { file = createFile(lines); } diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java index e126d06e6852..50b443da9ae6 100644 --- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java @@ -20,14 +20,15 @@ import java.io.Serializable; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner; -import org.apache.beam.sdk.Pipeline; +import 
org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionList; -import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -35,15 +36,14 @@ /** Test class for beam to spark flatten translation. */ @RunWith(JUnit4.class) public class FlattenTest implements Serializable { - private static Pipeline pipeline; + @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions()); - @BeforeClass - public static void beforeClass() { + private static PipelineOptions testOptions() { SparkStructuredStreamingPipelineOptions options = PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class); options.setRunner(SparkStructuredStreamingRunner.class); options.setTestMode(true); - pipeline = Pipeline.create(options); + return options; } @Test diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java index 07850232853a..1a84466b319b 100644 --- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java @@ -17,30 +17,41 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; +import static java.util.Arrays.stream; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.mapping; +import static java.util.stream.Collectors.toList; import static org.apache.beam.sdk.testing.SerializableMatchers.containsInAnyOrder; import static org.hamcrest.MatcherAssert.assertThat; import java.io.Serializable; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner; -import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.SerializableMatcher; +import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Sessions; +import org.apache.beam.sdk.transforms.windowing.SlidingWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TimestampedValue; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.joda.time.Duration; import org.joda.time.Instant; -import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -48,15 +59,14 @@ /** Test class for beam to spark {@link ParDo} translation. */ @RunWith(JUnit4.class) public class GroupByKeyTest implements Serializable { - private static Pipeline pipeline; + @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions()); - @BeforeClass - public static void beforeClass() { + private static PipelineOptions testOptions() { SparkStructuredStreamingPipelineOptions options = PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class); options.setRunner(SparkStructuredStreamingRunner.class); options.setTestMode(true); - pipeline = Pipeline.create(options); + return options; } @Test @@ -64,54 +74,89 @@ public void testGroupByKeyPreservesWindowing() { pipeline .apply( Create.timestamped( - TimestampedValue.of(KV.of(1, 1), new Instant(1)), - TimestampedValue.of(KV.of(1, 3), new Instant(2)), - TimestampedValue.of(KV.of(1, 5), new Instant(11)), - TimestampedValue.of(KV.of(2, 2), new Instant(3)), - TimestampedValue.of(KV.of(2, 4), new Instant(11)), - TimestampedValue.of(KV.of(2, 6), new Instant(12)))) + shuffleRandomly( + TimestampedValue.of(KV.of(1, 1), new Instant(1)), + TimestampedValue.of(KV.of(1, 3), new Instant(2)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(3)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12))))) .apply(Window.into(FixedWindows.of(Duration.millis(10)))) .apply(GroupByKey.create()) - // do manual assertion for windows because Passert do not support multiple kv with same key - // (because multiple windows) + // Passert do not support multiple kv with same key (because multiple windows) .apply( ParDo.of( - new DoFn>, KV>>() { + new AssertContains<>( + KV.of(1, containsInAnyOrder(1, 3)), // window [0-10) + KV.of(1, containsInAnyOrder(5)), // window [10-20) + KV.of(2, containsInAnyOrder(4, 6)), // window [10-20) + KV.of(2, containsInAnyOrder(2)) // window [0-10) + ))); + pipeline.run(); + } - @ProcessElement - public void processElement(ProcessContext context) { - KV> element = context.element(); - if (element.getKey() == 1) { - if (Iterables.size(element.getValue()) == 2) { - assertThat(element.getValue(), containsInAnyOrder(1, 3)); // window [0-10) - } else { - assertThat(element.getValue(), containsInAnyOrder(5)); // window [10-20) - } - } else { // key == 2 - if (Iterables.size(element.getValue()) == 2) { - assertThat(element.getValue(), containsInAnyOrder(4, 6)); // window [10-20) - } else { - assertThat(element.getValue(), containsInAnyOrder(2)); // window [0-10) - } - } - context.output(element); - } - })); + @Test + public void testGroupByKeyExplodesMultipleWindows() { + pipeline + .apply( + Create.timestamped( + shuffleRandomly( + TimestampedValue.of(KV.of(1, 1), new Instant(5)), + TimestampedValue.of(KV.of(1, 3), new Instant(7)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(5)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12))))) + .apply(Window.into(SlidingWindows.of(Duration.millis(10)).every(Duration.millis(5)))) + .apply(GroupByKey.create()) + // Passert do not support multiple kv with same 
key (because multiple windows) + .apply( + ParDo.of( + new AssertContains<>( + KV.of(1, containsInAnyOrder(1, 3)), // window [0-10) + KV.of(1, containsInAnyOrder(1, 3, 5)), // window [5-15) + KV.of(1, containsInAnyOrder(5)), // window [10-20) + KV.of(2, containsInAnyOrder(2)), // window [0-10) + KV.of(2, containsInAnyOrder(2, 4, 6)), // window [5-15) + KV.of(2, containsInAnyOrder(4, 6)) // window [10-20) + ))); + pipeline.run(); + } + + @Test + public void testGroupByKeyWithMergingWindows() { + pipeline + .apply( + Create.timestamped( + shuffleRandomly( + TimestampedValue.of(KV.of(1, 1), new Instant(5)), + TimestampedValue.of(KV.of(1, 3), new Instant(7)), + TimestampedValue.of(KV.of(1, 5), new Instant(11)), + TimestampedValue.of(KV.of(2, 2), new Instant(5)), + TimestampedValue.of(KV.of(2, 4), new Instant(11)), + TimestampedValue.of(KV.of(2, 6), new Instant(12))))) + .apply(Window.into(Sessions.withGapDuration(Duration.millis(5)))) + .apply(GroupByKey.create()) + // Passert do not support multiple kv with same key (because multiple windows) + .apply( + ParDo.of( + new AssertContains<>( + KV.of(1, containsInAnyOrder(1, 3, 5)), // window [5-16) + KV.of(2, containsInAnyOrder(2)), // window [5-10) + KV.of(2, containsInAnyOrder(4, 6)) // window [11-17) + ))); pipeline.run(); } @Test public void testGroupByKey() { - List> elems = new ArrayList<>(); - elems.add(KV.of(1, 1)); - elems.add(KV.of(1, 3)); - elems.add(KV.of(1, 5)); - elems.add(KV.of(2, 2)); - elems.add(KV.of(2, 4)); - elems.add(KV.of(2, 6)); + List> elems = + shuffleRandomly( + KV.of(1, 1), KV.of(1, 3), KV.of(1, 5), KV.of(2, 2), KV.of(2, 4), KV.of(2, 6)); PCollection>> input = pipeline.apply(Create.of(elems)).apply(GroupByKey.create()); + PAssert.thatMap(input) .satisfies( results -> { @@ -121,4 +166,27 @@ public void testGroupByKey() { }); pipeline.run(); } + + static class AssertContains extends DoFn>, Void> { + private final Map>>> byKey; + + public AssertContains(KV>>... matchers) { + byKey = stream(matchers).collect(groupingBy(KV::getKey, mapping(KV::getValue, toList()))); + } + + @ProcessElement + public void processElement(@Element KV> elem) { + assertThat("Unexpected key: " + elem.getKey(), byKey.containsKey(elem.getKey())); + List values = ImmutableList.copyOf(elem.getValue()); + assertThat( + "Unexpected values " + values + " for key " + elem.getKey(), + byKey.get(elem.getKey()).stream().anyMatch(m -> m.matches(values))); + } + } + + private List shuffleRandomly(T... 
elems) { + ArrayList list = Lists.newArrayList(elems); + Collections.shuffle(list); + return list; + } } diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java index 16d9a8b7fa87..f319173ed2bb 100644 --- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java @@ -20,19 +20,24 @@ import java.io.Serializable; import java.util.List; import java.util.Map; +import org.apache.beam.runners.spark.SparkCommonPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner; -import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; -import org.junit.BeforeClass; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -40,15 +45,14 @@ /** Test class for beam to spark {@link ParDo} translation. 
*/ @RunWith(JUnit4.class) public class ParDoTest implements Serializable { - private static Pipeline pipeline; + @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions()); - @BeforeClass - public static void beforeClass() { + private static PipelineOptions testOptions() { SparkStructuredStreamingPipelineOptions options = PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class); options.setRunner(SparkStructuredStreamingRunner.class); options.setTestMode(true); - pipeline = Pipeline.create(options); + return options; } @Test @@ -59,6 +63,44 @@ public void testPardo() { pipeline.run(); } + @Test + public void testPardoWithOutputTagsCachedRDD() { + pardoWithOutputTags("MEMORY_ONLY"); + } + + @Test + public void testPardoWithOutputTagsCachedDataset() { + pardoWithOutputTags("MEMORY_AND_DISK"); + } + + private void pardoWithOutputTags(String storageLevel) { + pipeline.getOptions().as(SparkCommonPipelineOptions.class).setStorageLevel(storageLevel); + + TupleTag even = new TupleTag() {}; + TupleTag unevenAsString = new TupleTag() {}; + + DoFn doFn = + new DoFn() { + @ProcessElement + public void processElement(@Element Integer i, MultiOutputReceiver out) { + if (i % 2 == 0) { + out.get(even).output(i); + } else { + out.get(unevenAsString).output(i.toString()); + } + } + }; + + PCollectionTuple outputs = + pipeline + .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + .apply(ParDo.of(doFn).withOutputTags(even, TupleTagList.of(unevenAsString))); + + PAssert.that(outputs.get(even)).containsInAnyOrder(2, 4, 6, 8, 10); + PAssert.that(outputs.get(unevenAsString)).containsInAnyOrder("1", "3", "5", "7", "9"); + pipeline.run(); + } + @Test public void testTwoPardoInRow() { PCollection input = diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java index 70cdca630b9b..0f16b6442221 100644 --- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java +++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java @@ -20,12 +20,13 @@ import java.io.Serializable; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions; import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner; -import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; -import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -33,15 +34,14 @@ /** Test class for beam to spark source translation. 
 @RunWith(JUnit4.class)
 public class SimpleSourceTest implements Serializable {
-  private static Pipeline pipeline;
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
   }
 
   @Test
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
index b8b41010a24b..28efe754ddf6 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
@@ -20,9 +20,10 @@
 import java.io.Serializable;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
 import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
@@ -31,7 +32,7 @@
 import org.apache.beam.sdk.values.TimestampedValue;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -39,15 +40,14 @@
 /** Test class for beam to spark window assign translation. */
 @RunWith(JUnit4.class)
 public class WindowAssignTest implements Serializable {
-  private static Pipeline pipeline;
+  @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
 
-  @BeforeClass
-  public static void beforeClass() {
+  private static PipelineOptions testOptions() {
     SparkStructuredStreamingPipelineOptions options =
         PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
     options.setRunner(SparkStructuredStreamingRunner.class);
     options.setTestMode(true);
-    pipeline = Pipeline.create(options);
+    return options;
   }
 
   @Test
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
index c8a8fba8d281..ab6e3083c54a 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
@@ -18,32 +18,95 @@
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
 import static java.util.Arrays.asList;
-import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.fromBeamCoder;
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toMap;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates.notNull;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.StringType;
+import static org.apache.spark.sql.types.DataTypes.createStructField;
+import static org.apache.spark.sql.types.DataTypes.createStructType;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertEquals;
+import static org.hamcrest.Matchers.instanceOf;
 
-import java.util.Arrays;
+import java.math.BigDecimal;
+import java.math.MathContext;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.TreeMap;
+import java.util.function.Function;
 import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
+import org.apache.beam.sdk.coders.BigDecimalCoder;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.coders.BigEndianLongCoder;
+import org.apache.beam.sdk.coders.BigEndianShortCoder;
+import org.apache.beam.sdk.coders.BooleanCoder;
+import org.apache.beam.sdk.coders.ByteCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.DelegateCoder;
+import org.apache.beam.sdk.coders.DoubleCoder;
+import org.apache.beam.sdk.coders.FloatCoder;
+import org.apache.beam.sdk.coders.InstantCoder;
+import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.joda.time.Instant;
 import org.junit.ClassRule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import scala.Tuple2;
 
 /** Test of the wrapping of Beam Coders as Spark ExpressionEncoders. */
 @RunWith(JUnit4.class)
 public class EncoderHelpersTest {
+  @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule("local[1]");
 
-  @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule();
+  private static final Encoder<GlobalWindow> windowEnc =
+      EncoderHelpers.encoderOf(GlobalWindow.class);
+
+  private static final Map<Coder<?>, List<?>> BASIC_CASES =
+      ImmutableMap.<Coder<?>, List<?>>builder()
+          .put(BooleanCoder.of(), asList(true, false, null))
+          .put(ByteCoder.of(), asList((byte) 1, null))
+          .put(BigEndianShortCoder.of(), asList((short) 1, null))
+          .put(BigEndianIntegerCoder.of(), asList(1, 2, 3, null))
+          .put(VarIntCoder.of(), asList(1, 2, 3, null))
+          .put(BigEndianLongCoder.of(), asList(1L, 2L, 3L, null))
+          .put(VarLongCoder.of(), asList(1L, 2L, 3L, null))
+          .put(FloatCoder.of(), asList((float) 1.0, (float) 2.0, null))
+          .put(DoubleCoder.of(), asList(1.0, 2.0, null))
+          .put(StringUtf8Coder.of(), asList("1", "2", null))
+          .put(BigDecimalCoder.of(), asList(bigDecimalOf(1L), bigDecimalOf(2L), null))
+          .put(InstantCoder.of(), asList(Instant.ofEpochMilli(1), null))
+          .build();
 
   private <T> Dataset<T> createDataset(List<T> data, Encoder<T> encoder) {
     Dataset<T> ds = sessionRule.getSession().createDataset(data, encoder);
@@ -52,10 +115,14 @@ private <T> Dataset<T> createDataset(List<T> data, Encoder<T> encoder) {
   }
 
   @Test
-  public void beamCoderToSparkEncoderTest() {
-    List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> dataset = createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
-    assertEquals(data, dataset.collectAsList());
+  public void testBeamEncoderMappings() {
+    BASIC_CASES.forEach(
+        (coder, data) -> {
+          Encoder<?> encoder = encoderFor(coder);
+          serializeAndDeserialize(data.get(0), (Encoder) encoder);
+          Dataset<?> dataset = createDataset(data, (Encoder) encoder);
+          assertThat(dataset.collect(), equalTo(data.toArray()));
+        });
   }
 
   @Test
@@ -63,10 +130,135 @@ public void testBeamEncoderOfPrivateType() {
     // Verify concrete types are not used in coder generation.
     // In case of private types this would cause an IllegalAccessError.
     List<PrivateString> data = asList(new PrivateString("1"), new PrivateString("2"));
-    Dataset<PrivateString> dataset = createDataset(data, fromBeamCoder(PrivateString.CODER));
+    Dataset<PrivateString> dataset = createDataset(data, encoderFor(PrivateString.CODER));
+    assertThat(dataset.collect(), equalTo(data.toArray()));
+  }
+
+  @Test
+  public void testBeamWindowedValueEncoderMappings() {
+    BASIC_CASES.forEach(
+        (coder, data) -> {
+          List<WindowedValue<?>> windowed =
+              Lists.transform(data, WindowedValue::valueInGlobalWindow);
+
+          Encoder<?> encoder = windowedValueEncoder(encoderFor(coder), windowEnc);
+          serializeAndDeserialize(windowed.get(0), (Encoder) encoder);
+
+          Dataset<?> dataset = createDataset(windowed, (Encoder) encoder);
+          assertThat(dataset.collect(), equalTo(windowed.toArray()));
+        });
+  }
+
+  @Test
+  public void testCollectionEncoder() {
+    BASIC_CASES.forEach(
+        (coder, data) -> {
+          Encoder<Collection<?>> encoder = collectionEncoder(encoderFor(coder), true);
+          Collection<?> collection = Collections.unmodifiableCollection(data);
+
+          Dataset<Collection<?>> dataset = createDataset(asList(collection), (Encoder) encoder);
+          assertThat(dataset.head(), equalTo(data));
+        });
+  }
+
+  private void testMapEncoder(Class<?> cls, Function<Map<?, ?>, Map<?, ?>> decorator) {
+    BASIC_CASES.forEach(
+        (coder, data) -> {
+          Encoder<?> enc = encoderFor(coder);
+          Encoder<Map<?, ?>> mapEncoder = mapEncoder(enc, enc, (Class) cls);
+          Map<?, ?> map =
+              decorator.apply(
+                  data.stream().filter(notNull()).collect(toMap(identity(), identity())));
+
+          Dataset<Map<?, ?>> dataset = createDataset(asList(map), mapEncoder);
+          Map<?, ?> head = dataset.head();
+          assertThat(head, equalTo(map));
+          assertThat(head, instanceOf(cls));
+        });
+  }
+
+  @Test
+  public void testMapEncoder() {
+    testMapEncoder(Map.class, identity());
+  }
+
+  @Test
+  public void testHashMapEncoder() {
+    testMapEncoder(HashMap.class, identity());
+  }
+
+  @Test
+  public void testTreeMapEncoder() {
+    testMapEncoder(TreeMap.class, TreeMap::new);
+  }
+
+  @Test
+  public void testBeamBinaryEncoder() {
+    List<List<String>> data = asList(asList("a1", "a2", "a3"), asList("b1", "b2"), asList("c1"));
+
+    Encoder<List<String>> encoder = encoderFor(ListCoder.of(StringUtf8Coder.of()));
+    serializeAndDeserialize(data.get(0), encoder);
+
+    Dataset<List<String>> dataset = createDataset(data, encoder);
     assertThat(dataset.collect(), equalTo(data.toArray()));
   }
 
+  @Test
+  public void testEncoderForKVCoder() {
+    List<KV<Integer, String>> data =
+        asList(KV.of(1, "value1"), KV.of(null, "value2"), KV.of(3, null));
+
+    Encoder<KV<Integer, String>> encoder =
+        kvEncoder(encoderFor(VarIntCoder.of()), encoderFor(StringUtf8Coder.of()));
+    serializeAndDeserialize(data.get(0), encoder);
+
+    Dataset<KV<Integer, String>> dataset = createDataset(data, encoder);
+
+    StructType kvSchema =
+        createStructType(
+            new StructField[] {
+              createStructField("key", IntegerType, true),
+              createStructField("value", StringType, true)
+            });
+
+    assertThat(dataset.schema(), equalTo(kvSchema));
+    assertThat(dataset.collectAsList(), equalTo(data));
+  }
+
+  @Test
+  public void testOneOffEncoder() {
+    List<Coder<?>> coders = ImmutableList.copyOf(BASIC_CASES.keySet());
+    List<Encoder<?>> encoders = coders.stream().map(EncoderHelpers::encoderFor).collect(toList());
+
+    // build oneOf tuples of type index and corresponding value
+    List<Tuple2<Integer, Object>> data =
+        BASIC_CASES.entrySet().stream()
+            .map(e -> tuple(coders.indexOf(e.getKey()), (Object) e.getValue().get(0)))
+            .collect(toList());
+
+    // dataset is a sparse dataset with only one column set per row
+    Dataset<Tuple2<Integer, Object>> dataset = createDataset(data, oneOfEncoder((List) encoders));
+    assertThat(dataset.collectAsList(), equalTo(data));
+  }
+
+  // fix scale/precision to system default to compare using equals
+  private static BigDecimal bigDecimalOf(long l) {
+    DecimalType type = DecimalType.SYSTEM_DEFAULT();
+    return new BigDecimal(l, new MathContext(type.precision())).setScale(type.scale());
+  }
+
+  // test an explicit serialization roundtrip
+  private static <T> void serializeAndDeserialize(T data, Encoder<T> enc) {
+    ExpressionEncoder<T> bound = (ExpressionEncoder<T>) enc;
+    bound =
+        bound.resolveAndBind(bound.resolveAndBind$default$1(), bound.resolveAndBind$default$2());
+
+    InternalRow row = bound.createSerializer().apply(data);
+    T deserialized = bound.createDeserializer().apply(row);
+
+    assertThat(deserialized, equalTo(data));
+  }
+
   private static class PrivateString {
     private static final Coder<PrivateString> CODER =
         DelegateCoder.of(
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
index 95e7865e6f9e..461cfc80917e 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
@@ -50,6 +50,12 @@
   void setCheckpointDir(String checkpointDir);
 
+  @Description("Batch default storage level")
+  @Default.String("MEMORY_ONLY")
+  String getStorageLevel();
+
+  void setStorageLevel(String storageLevel);
+
   @Description("Enable/disable sending aggregator values to Spark's metric sinks")
   @Default.Boolean(true)
   Boolean getEnableSparkMetricSinks();
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
index ecf933336c97..9a5229f21ae1 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
@@ -34,12 +34,6 @@ public interface SparkPipelineOptions extends SparkCommonPipelineOptions {
   void setBatchIntervalMillis(Long batchInterval);
 
-  @Description("Batch default storage level")
-  @Default.String("MEMORY_ONLY")
-  String getStorageLevel();
-
-  void setStorageLevel(String storageLevel);
-
   @Description("Minimum time to spend on read, for each micro-batch.")
   @Default.Long(200)
   Long getMinReadTimeMillis();
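
Note on the storageLevel option moved into SparkCommonPipelineOptions above: the new pardoWithOutputTags test switches it between "MEMORY_ONLY" (the declared default) and "MEMORY_AND_DISK". As a rough illustration only, and not code from this patch (the class and method names below are made up for the example), a translator could map the configured string onto a Spark StorageLevel before persisting an intermediate Dataset:

    import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.storage.StorageLevel;

    final class StorageLevelSketch {

      // Hypothetical helper, not part of the runner: persists a Dataset using the
      // level configured via SparkCommonPipelineOptions (default "MEMORY_ONLY").
      // StorageLevel.fromString and Dataset.persist are existing Spark APIs.
      static <T> Dataset<T> cacheWithConfiguredLevel(
          Dataset<T> dataset, SparkCommonPipelineOptions options) {
        StorageLevel level = StorageLevel.fromString(options.getStorageLevel());
        return dataset.persist(level);
      }
    }

With standard PipelineOptions parsing the same setting should also be selectable on the command line, e.g. --storageLevel=MEMORY_AND_DISK.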