Skip to content

Commit

Permalink
Batch optimized SparkRunner groupByKey (#33322)
Browse files Browse the repository at this point in the history
* feat : optimized SparkRunner batch groupByKey

* update CHANGES.md

* touch trigger files

* remove unused test
  • Loading branch information
twosom authored Dec 13, 2024
1 parent 994d2f0 commit a6061fe
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 94 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 5
"modification": 6
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test"
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test"
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"comment": "Modify this file in a trivial way to cause this test suite to run",
"https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test",
"https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test",
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test"
"https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test",
"https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test"
}
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

## New Features / Improvements

* Improved batch performance of SparkRunner's GroupByKey ([#20943](https://github.com/apache/beam/pull/20943)).
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Support OnWindowExpiration in Prism ([#32211](https://github.com/apache/beam/issues/32211)).
* This enables initial Java GroupIntoBatches support.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
*/
package org.apache.beam.runners.spark.translation;

import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.sdk.coders.Coder;
Expand All @@ -27,6 +30,7 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
Expand All @@ -49,18 +53,36 @@ public static <K, V> JavaRDD<KV<K, Iterable<WindowedValue<V>>>> groupByKeyOnly(
@Nullable Partitioner partitioner) {
// we use coders to convert objects in the PCollection to byte arrays, so they
// can be transferred over the network for the shuffle.
JavaPairRDD<ByteArray, byte[]> pairRDD =
rdd.map(new ReifyTimestampsAndWindowsFunction<>())
.mapToPair(TranslationUtils.toPairFunction())
.mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder));

// If no partitioner is passed, the default group by key operation is called
JavaPairRDD<ByteArray, Iterable<byte[]>> groupedRDD =
(partitioner != null) ? pairRDD.groupByKey(partitioner) : pairRDD.groupByKey();

return groupedRDD
.mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder))
.map(new TranslationUtils.FromPairFunction<>());
final JavaPairRDD<ByteArray, byte[]> pairRDD =
rdd.mapPartitionsToPair(
(Iterator<WindowedValue<KV<K, V>>> iter) ->
Iterators.transform(
iter,
(WindowedValue<KV<K, V>> wv) -> {
final K key = wv.getValue().getKey();
final WindowedValue<V> windowedValue = wv.withValue(wv.getValue().getValue());
final ByteArray keyBytes =
new ByteArray(CoderHelpers.toByteArray(key, keyCoder));
final byte[] windowedValueBytes =
CoderHelpers.toByteArray(windowedValue, wvCoder);
return Tuple2.apply(keyBytes, windowedValueBytes);
}));

final JavaPairRDD<ByteArray, List<byte[]>> combined =
GroupNonMergingWindowsFunctions.combineByKey(pairRDD, partitioner).cache();

return combined.mapPartitions(
(Iterator<Tuple2<ByteArray, List<byte[]>>> iter) ->
Iterators.transform(
iter,
(Tuple2<ByteArray, List<byte[]>> tuple) -> {
final K key = CoderHelpers.fromByteArray(tuple._1().getValue(), keyCoder);
final List<WindowedValue<V>> windowedValues =
tuple._2().stream()
.map(bytes -> CoderHelpers.fromByteArray(bytes, wvCoder))
.collect(Collectors.toList());
return KV.of(key, windowedValues);
}));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
*/
package org.apache.beam.runners.spark.translation;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.runners.spark.util.ByteArray;
Expand All @@ -41,6 +43,9 @@
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -259,34 +264,83 @@ private WindowedValue<KV<K, V>> decodeItem(Tuple2<ByteArray, byte[]> item) {
}

/**
* Group all values with a given key for that composite key with Spark's groupByKey, dropping the
* Window (which must be GlobalWindow) and returning the grouped result in the appropriate global
* window.
* Groups values with a given key using Spark's combineByKey operation in the Global Window
* context. The window information (which must be GlobalWindow) is dropped during processing, and
* the grouped results are returned in the appropriate global window with the maximum timestamp.
*
* <p>This implementation uses {@link JavaPairRDD#combineByKey} for better performance compared to
* {@link JavaPairRDD#groupByKey}, as it allows for local aggregation before shuffle operations.
*/
static <K, V, W extends BoundedWindow>
JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupByKeyInGlobalWindow(
JavaRDD<WindowedValue<KV<K, V>>> rdd,
Coder<K> keyCoder,
Coder<V> valueCoder,
Partitioner partitioner) {
JavaPairRDD<ByteArray, byte[]> rawKeyValues =
rdd.mapToPair(
wv ->
new Tuple2<>(
new ByteArray(CoderHelpers.toByteArray(wv.getValue().getKey(), keyCoder)),
CoderHelpers.toByteArray(wv.getValue().getValue(), valueCoder)));

JavaPairRDD<ByteArray, Iterable<byte[]>> grouped =
(partitioner == null) ? rawKeyValues.groupByKey() : rawKeyValues.groupByKey(partitioner);
return grouped.map(
kvs ->
WindowedValue.timestampedValueInGlobalWindow(
KV.of(
CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder),
Iterables.transform(
kvs._2,
encodedValue -> CoderHelpers.fromByteArray(encodedValue, valueCoder))),
GlobalWindow.INSTANCE.maxTimestamp(),
PaneInfo.ON_TIME_AND_ONLY_FIRING));
final JavaPairRDD<ByteArray, byte[]> rawKeyValues =
rdd.mapPartitionsToPair(
(Iterator<WindowedValue<KV<K, V>>> iter) ->
Iterators.transform(
iter,
(WindowedValue<KV<K, V>> wv) -> {
final ByteArray keyBytes =
new ByteArray(CoderHelpers.toByteArray(wv.getValue().getKey(), keyCoder));
final byte[] valueBytes =
CoderHelpers.toByteArray(wv.getValue().getValue(), valueCoder);
return Tuple2.apply(keyBytes, valueBytes);
}));

JavaPairRDD<ByteArray, List<byte[]>> combined = combineByKey(rawKeyValues, partitioner).cache();

return combined.mapPartitions(
(Iterator<Tuple2<ByteArray, List<byte[]>>> iter) ->
Iterators.transform(
iter,
kvs ->
WindowedValue.timestampedValueInGlobalWindow(
KV.of(
CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder),
Iterables.transform(
kvs._2(),
encodedValue ->
CoderHelpers.fromByteArray(encodedValue, valueCoder))),
GlobalWindow.INSTANCE.maxTimestamp(),
PaneInfo.ON_TIME_AND_ONLY_FIRING)));
}

/**
* Combines values by key using Spark's {@link JavaPairRDD#combineByKey} operation.
*
* @param rawKeyValues Input RDD of key-value pairs
* @param partitioner Optional custom partitioner for data distribution
* @return RDD with values combined into Lists per key
*/
static JavaPairRDD<ByteArray, List<byte[]>> combineByKey(
JavaPairRDD<ByteArray, byte[]> rawKeyValues, @Nullable Partitioner partitioner) {

final Function<byte[], List<byte[]>> createCombiner =
value -> {
List<byte[]> list = new ArrayList<>();
list.add(value);
return list;
};

final Function2<List<byte[]>, byte[], List<byte[]>> mergeValues =
(list, value) -> {
list.add(value);
return list;
};

final Function2<List<byte[]>, List<byte[]>, List<byte[]>> mergeCombiners =
(list1, list2) -> {
list1.addAll(list2);
return list1;
};

if (partitioner == null) {
return rawKeyValues.combineByKey(createCombiner, mergeValues, mergeCombiners);
}

return rawKeyValues.combineByKey(createCombiner, mergeValues, mergeCombiners, partitioner);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
package org.apache.beam.runners.spark.translation;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Iterator;
Expand All @@ -45,9 +39,6 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Bytes;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Assert;
Expand Down Expand Up @@ -121,54 +112,6 @@ public void testGbkIteratorValuesCannotBeReiterated() throws Coder.NonDeterminis
}
}

@Test
@SuppressWarnings({"rawtypes", "unchecked"})
public void testGroupByKeyInGlobalWindowWithPartitioner() {
// mocking
Partitioner mockPartitioner = mock(Partitioner.class);
JavaRDD mockRdd = mock(JavaRDD.class);
Coder mockKeyCoder = mock(Coder.class);
Coder mockValueCoder = mock(Coder.class);
JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class);
JavaPairRDD mockGrouped = mock(JavaPairRDD.class);

when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues);
when(mockRawKeyValues.groupByKey(any(Partitioner.class)))
.thenAnswer(
invocation -> {
Partitioner partitioner = invocation.getArgument(0);
assertEquals(partitioner, mockPartitioner);
return mockGrouped;
});
when(mockGrouped.map(any())).thenReturn(mock(JavaRDD.class));

GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
mockRdd, mockKeyCoder, mockValueCoder, mockPartitioner);

verify(mockRawKeyValues, never()).groupByKey();
verify(mockRawKeyValues, times(1)).groupByKey(any(Partitioner.class));
}

@Test
@SuppressWarnings({"rawtypes", "unchecked"})
public void testGroupByKeyInGlobalWindowWithoutPartitioner() {
// mocking
JavaRDD mockRdd = mock(JavaRDD.class);
Coder mockKeyCoder = mock(Coder.class);
Coder mockValueCoder = mock(Coder.class);
JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class);
JavaPairRDD mockGrouped = mock(JavaPairRDD.class);

when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues);
when(mockRawKeyValues.groupByKey()).thenReturn(mockGrouped);

GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
mockRdd, mockKeyCoder, mockValueCoder, null);

verify(mockRawKeyValues, times(1)).groupByKey();
verify(mockRawKeyValues, never()).groupByKey(any(Partitioner.class));
}

private GroupByKeyIterator<String, Integer, GlobalWindow> createGbkIterator()
throws Coder.NonDeterministicException {
return createGbkIterator(
Expand Down

0 comments on commit a6061fe

Please sign in to comment.