Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for SparkRunner streaming stateful processing #33267

Merged
merged 21 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6a02a0e
change SparkAndTimers as public class
twosom Nov 26, 2024
592e175
move getBatchDuration method from SparkGroupAlsoByWindowViaWindowSet …
twosom Nov 26, 2024
c5d453c
move checkpointIfNeeded from SparkGroupAlsoByWindowViaWindowSet to Tr…
twosom Nov 26, 2024
23e3e18
chore : add license header in StateAndTimers
twosom Nov 28, 2024
13fb728
refactor : add public modifier for using different packages
twosom Nov 28, 2024
b974707
chore : remove @Nullable annotation
twosom Dec 1, 2024
cc6e7f3
feat : change SparkStateInternals to public
twosom Dec 1, 2024
9713650
feat : implementation spark stateful ParDo
twosom Dec 1, 2024
23c492c
feat : change SparkStateInternals and SparkTimerInternals
twosom Dec 3, 2024
e7ff5d0
chore : spotless
twosom Dec 3, 2024
62f737b
feat : change StatefulStreamingParDoEvaluator to use mapPartitions (f…
twosom Dec 3, 2024
821de6e
feat : add test for StatefulStreamingParDoEvaluator
twosom Dec 3, 2024
ac8b0bf
feat : change spark_runner.gradle for test state and timers
twosom Dec 3, 2024
e0d3f8e
chore : remove Logger in StatefulStreamingParDoEvaluatorTest
twosom Dec 3, 2024
a9f482a
chore : spotlessApply
twosom Dec 3, 2024
fd883c7
chore : spotlessApply
twosom Dec 3, 2024
a8d7094
update javadoc
twosom Dec 3, 2024
f8b9916
refactor : replace flatMapToPair mapToPair chain to single flatMapToPair
twosom Dec 3, 2024
7493edc
touch trigger files
twosom Dec 6, 2024
51071ae
Merge branch 'master' into spark-streaming-stateful-pardo
twosom Dec 15, 2024
eda50b2
Merge branch 'master' into spark-streaming-stateful-pardo
twosom Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test",
"https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test"
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test"
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test"
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33267": "noting that PR #33267 should run this test",
"https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test"
}
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@

## New Features / Improvements

* Added support for stateful processing in Spark Runner for streaming pipelines. Timer functionality is not yet supported and will be implemented in a future release ([#33237](https://github.com/apache/beam/issues/33237)).
* Improved batch performance of SparkRunner's GroupByKey ([#20943](https://github.com/apache/beam/pull/20943)).
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)).
* This enables initial Java GroupIntoBatches support.
* Support OrderedListState in Prism ([#32929](https://github.com/apache/beam/issues/32929)).
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## Breaking Changes

Expand Down
2 changes: 1 addition & 1 deletion runners/spark/spark_runner.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def validatesRunnerStreaming = tasks.register("validatesRunnerStreaming", Test)
excludeCategories 'org.apache.beam.sdk.testing.UsesSdkHarnessEnvironment'

// State and Timers
excludeCategories 'org.apache.beam.sdk.testing.UsesStatefulParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'
excludeCategories 'org.apache.beam.sdk.testing.UsesTimersInParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap'
excludeCategories 'org.apache.beam.sdk.testing.UsesLoopingTimer'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.util.ArrayList;
import java.util.LinkedHashMap;
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet.StateAndTimers;
import org.apache.beam.runners.spark.stateful.StateAndTimers;
import org.apache.beam.runners.spark.translation.ValueAndCoderKryoSerializer;
import org.apache.beam.runners.spark.translation.ValueAndCoderLazySerializable;
import org.apache.beam.runners.spark.util.ByteArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
*/
package org.apache.beam.runners.spark.stateful;

import static org.apache.beam.runners.spark.translation.TranslationUtils.checkpointIfNeeded;
import static org.apache.beam.runners.spark.translation.TranslationUtils.getBatchDuration;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -35,7 +38,6 @@
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
import org.apache.beam.runners.core.triggers.TriggerStateMachines;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.translation.ReifyTimestampsAndWindowsFunction;
import org.apache.beam.runners.spark.translation.TranslationUtils;
Expand All @@ -60,10 +62,8 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.AbstractIterator;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.dstream.DStream;
Expand Down Expand Up @@ -100,27 +100,6 @@ public class SparkGroupAlsoByWindowViaWindowSet implements Serializable {
private static final Logger LOG =
LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class);

/** State and Timers wrapper. */
public static class StateAndTimers implements Serializable {
// Serializable state for internals (namespace to state tag to coded value).
private final Table<String, String, byte[]> state;
private final Collection<byte[]> serTimers;

private StateAndTimers(
final Table<String, String, byte[]> state, final Collection<byte[]> timers) {
this.state = state;
this.serTimers = timers;
}

Table<String, String, byte[]> getState() {
return state;
}

Collection<byte[]> getTimers() {
return serTimers;
}
}

private static class OutputWindowedValueHolder<K, V>
implements OutputWindowedValue<KV<K, Iterable<V>>> {
private final List<WindowedValue<KV<K, Iterable<V>>>> windowedValues = new ArrayList<>();
Expand Down Expand Up @@ -348,7 +327,7 @@ private Collection<TimerInternals.TimerData> filterTimersEligibleForProcessing(

// empty outputs are filtered later using DStream filtering
final StateAndTimers updated =
new StateAndTimers(
StateAndTimers.of(
stateInternals.getState(),
SparkTimerInternals.serializeTimers(
timerInternals.getTimers(), timerDataCoder));
Expand Down Expand Up @@ -466,21 +445,6 @@ private static <W extends BoundedWindow> TimerInternals.TimerDataCoderV2 timerDa
return TimerInternals.TimerDataCoderV2.of(windowingStrategy.getWindowFn().windowCoder());
}

private static void checkpointIfNeeded(
final DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream,
final SerializablePipelineOptions options) {

final Long checkpointDurationMillis = getBatchDuration(options);

if (checkpointDurationMillis > 0) {
firedStream.checkpoint(new Duration(checkpointDurationMillis));
}
}

private static Long getBatchDuration(final SerializablePipelineOptions options) {
return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
}

private static <K, InputT> JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> stripStateValues(
final DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream,
final Coder<K> keyCoder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
class SparkStateInternals<K> implements StateInternals {
public class SparkStateInternals<K> implements StateInternals {

private final K key;
// Serializable state for internals (namespace to state tag to coded value).
Expand All @@ -79,11 +79,11 @@ private SparkStateInternals(K key, Table<String, String, byte[]> stateTable) {
this.stateTable = stateTable;
}

static <K> SparkStateInternals<K> forKey(K key) {
public static <K> SparkStateInternals<K> forKey(K key) {
return new SparkStateInternals<>(key);
}

static <K> SparkStateInternals<K> forKeyAndState(
public static <K> SparkStateInternals<K> forKeyAndState(
K key, Table<String, String, byte[]> stateTable) {
return new SparkStateInternals<>(key, stateTable);
}
Expand Down Expand Up @@ -412,17 +412,25 @@ public void put(MapKeyT key, MapValueT value) {
@Override
public ReadableState<MapValueT> computeIfAbsent(
MapKeyT key, Function<? super MapKeyT, ? extends MapValueT> mappingFunction) {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
Map<MapKeyT, MapValueT> sparkMapState = readAsMap();
MapValueT current = sparkMapState.get(key);
if (current == null) {
put(key, mappingFunction.apply(key));
}
return ReadableStates.immediate(current);
}

private Map<MapKeyT, MapValueT> readAsMap() {
Map<MapKeyT, MapValueT> mapState = readValue();
if (mapState == null) {
mapState = new HashMap<>();
}
return mapState;
}

@Override
public void remove(MapKeyT key) {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
Map<MapKeyT, MapValueT> sparkMapState = readAsMap();
sparkMapState.remove(key);
writeValue(sparkMapState);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public Collection<TimerData> getTimers() {
return timers;
}

void addTimers(Iterator<TimerData> timers) {
public void addTimers(Iterator<TimerData> timers) {
while (timers.hasNext()) {
TimerData timer = timers.next();
this.timers.add(timer);
Expand Down Expand Up @@ -163,7 +163,8 @@ public void setTimer(
Instant target,
Instant outputTimestamp,
TimeDomain timeDomain) {
throw new UnsupportedOperationException("Setting a timer by ID not yet supported.");
this.setTimer(
TimerData.of(timerId, timerFamilyId, namespace, target, outputTimestamp, timeDomain));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.stateful;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.Collection;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;

/** State and Timers wrapper. */
@AutoValue
public abstract class StateAndTimers implements Serializable {
public abstract Table<String, String, byte[]> getState();

public abstract Collection<byte[]> getTimers();

public static StateAndTimers of(
final Table<String, String, byte[]> state, final Collection<byte[]> timers) {
return new AutoValue_StateAndTimers.Builder().setState(state).setTimers(timers).build();
}

@AutoValue.Builder
abstract static class Builder {
abstract Builder setState(Table<String, String, byte[]> state);

abstract Builder setTimers(Collection<byte[]> timers);

abstract StateAndTimers build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
import org.joda.time.Instant;

/** DoFnRunner decorator which registers {@link MetricsContainerImpl}. */
class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
public class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, OutputT> {
private final DoFnRunner<InputT, OutputT> delegate;
private final String stepName;
private final MetricsContainerStepMapAccumulator metricsAccum;

DoFnRunnerWithMetrics(
public DoFnRunnerWithMetrics(
String stepName,
DoFnRunner<InputT, OutputT> delegate,
MetricsContainerStepMapAccumulator metricsAccum) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
* Processes Spark's input data iterators using Beam's {@link
* org.apache.beam.runners.core.DoFnRunner}.
*/
interface SparkInputDataProcessor<FnInputT, FnOutputT, OutputT> {
public interface SparkInputDataProcessor<FnInputT, FnOutputT, OutputT> {

/**
* @return {@link OutputManager} to be used by {@link org.apache.beam.runners.core.DoFnRunner} for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
import org.apache.beam.sdk.transforms.DoFn;

/** Holds current processing context for {@link SparkInputDataProcessor}. */
class SparkProcessContext<K, InputT, OutputT> {
public class SparkProcessContext<K, InputT, OutputT> {
private final String stepName;
private final DoFn<InputT, OutputT> doFn;
private final DoFnRunner<InputT, OutputT> doFnRunner;
private final Iterator<TimerInternals.TimerData> timerDataIterator;
private final K key;

SparkProcessContext(
public SparkProcessContext(
String stepName,
DoFn<InputT, OutputT> doFn,
DoFnRunner<InputT, OutputT> doFnRunner,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.apache.beam.runners.core.InMemoryStateInternals;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateInternalsFactory;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.util.ByteArray;
Expand Down Expand Up @@ -54,8 +56,10 @@
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.dstream.DStream;
import scala.Tuple2;

/** A set of utilities to help translating Beam transformations into Spark transformations. */
Expand Down Expand Up @@ -258,6 +262,52 @@ public Boolean call(Tuple2<TupleTag<V>, WindowedValue<?>> input) {
}
}

/**
* Retrieves the batch duration in milliseconds from Spark pipeline options.
*
* @param options The serializable pipeline options containing Spark-specific settings
* @return The checkpoint duration in milliseconds as specified in SparkPipelineOptions
*/
public static Long getBatchDuration(final SerializablePipelineOptions options) {
return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
}

/**
* Reject timers {@link DoFn}.
*
* @param doFn the {@link DoFn} to possibly reject.
*/
public static void rejectTimers(DoFn<?, ?> doFn) {
DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
if (signature.timerDeclarations().size() > 0
|| signature.timerFamilyDeclarations().size() > 0) {
throw new UnsupportedOperationException(
String.format(
"Found %s annotations on %s, but %s cannot yet be used with timers in the %s.",
DoFn.TimerId.class.getSimpleName(),
doFn.getClass().getName(),
DoFn.class.getSimpleName(),
SparkRunner.class.getSimpleName()));
}
}

/**
* Checkpoints the given DStream if checkpointing is enabled in the pipeline options.
*
* @param dStream The DStream to be checkpointed
* @param options The SerializablePipelineOptions containing configuration settings including
* batch duration
*/
public static void checkpointIfNeeded(
final DStream<?> dStream, final SerializablePipelineOptions options) {

final Long checkpointDurationMillis = getBatchDuration(options);

if (checkpointDurationMillis > 0) {
dStream.checkpoint(new Duration(checkpointDurationMillis));
}
}

/**
* Reject state and timers {@link DoFn}.
*
Expand Down
Loading
Loading