diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index ce99958c57fd..9f41ea138bd5 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -117,6 +117,7 @@ import org.apache.beam.sdk.transforms.Combine.GroupedValues;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.GroupIntoBatches;
 import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -214,6 +215,8 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
       "unsafely_attempt_to_process_unbounded_data_in_batch_mode";
 
   private static final Logger LOG = LoggerFactory.getLogger(DataflowRunner.class);
+  private static final String EXPERIMENT_ENABLE_GBK_STATE_MULTIPLEXING =
+      "enable_gbk_state_multiplexing";
 
   /** Provided configuration options. */
   private final DataflowPipelineOptions options;
@@ -596,6 +599,7 @@ protected DataflowRunner(DataflowPipelineOptions options) {
 
   private static class AlwaysCreateViaRead<T>
       implements PTransformOverrideFactory<PBegin, PCollection<T>, Create.Values<T>> {
+
     @Override
     public PTransformOverrideFactory.PTransformReplacement<PBegin, PCollection<T>>
         getReplacementTransform(
@@ -797,6 +801,12 @@ private List<PTransformOverride> getOverrides(boolean streaming) {
             new RedistributeByKeyOverrideFactory()));
 
     if (streaming) {
+      if (DataflowRunner.hasExperiment(options, EXPERIMENT_ENABLE_GBK_STATE_MULTIPLEXING)) {
+        overridesBuilder.add(
+            PTransformOverride.of(
+                PTransformMatchers.classEqualTo(GroupByKey.class),
+                new StateMultiplexingGroupByKeyOverrideFactory<>()));
+      }
       // For update compatibility, always use a Read for Create in streaming mode.
       overridesBuilder
           .add(
@@ -1180,6 +1190,7 @@ private List getDefaultArtifacts() {
   @VisibleForTesting
   static boolean isMultiLanguagePipeline(Pipeline pipeline) {
     class IsMultiLanguageVisitor extends PipelineVisitor.Defaults {
+
       private boolean isMultiLanguage = false;
 
       private void performMultiLanguageTest(Node node) {
@@ -1648,6 +1659,7 @@ private static EnvironmentInfo getEnvironmentInfoFromEnvironmentId(
 
   @AutoValue
   abstract static class EnvironmentInfo {
+
     static EnvironmentInfo create(
         String environmentId, String containerUrl, List<String> capabilities) {
       return new AutoValue_DataflowRunner_EnvironmentInfo(
@@ -2105,6 +2117,7 @@ protected String getKindString() {
   }
 
   private static class StreamingPubsubSinkTranslators {
+
     /** Rewrite {@link StreamingPubsubIOWrite} to the appropriate internal node. */
     static class StreamingPubsubIOWriteTranslator
         implements TransformTranslator<StreamingPubsubIOWrite> {
@@ -2161,6 +2174,7 @@ private static void translate(
 
   private static class SingleOutputExpandableTransformTranslator
       implements TransformTranslator<External.SingleOutputExpandableTransform> {
+
     @Override
     public void translate(
         External.SingleOutputExpandableTransform transform, TranslationContext context) {
@@ -2178,6 +2192,7 @@ public void translate(
 
   private static class MultiOutputExpandableTransformTranslator
       implements TransformTranslator<External.MultiOutputExpandableTransform> {
+
     @Override
     public void translate(
         External.MultiOutputExpandableTransform transform, TranslationContext context) {
@@ -2726,6 +2741,7 @@ static void verifyStateSupportForWindowingStrategy(WindowingStrategy strategy) {
    */
   private static class DataflowPayloadTranslator
       implements TransformPayloadTranslator<PTransform<?, ?>> {
+
     @Override
     public String getUrn(PTransform transform) {
       return "dataflow_stub:" + transform.getClass().getName();
@@ -2750,6 +2766,7 @@ public RunnerApi.FunctionSpec translate(
         })
   @AutoService(TransformPayloadTranslatorRegistrar.class)
   public static class DataflowTransformTranslator implements TransformPayloadTranslatorRegistrar {
+
     @Override
     public Map<? extends Class<? extends PTransform>, ? extends TransformPayloadTranslator>
         getTransformPayloadTranslators() {
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyOverrideFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyOverrideFactory.java
new file mode 100644
index 000000000000..468a0a95d77c
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StateMultiplexingGroupByKeyOverrideFactory.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow;
+
+import org.apache.beam.runners.dataflow.internal.StateMultiplexingGroupByKey;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.util.construction.PTransformReplacements;
+import org.apache.beam.sdk.util.construction.SingleInputOutputOverrideFactory;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+
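+/**
+ * A {@link org.apache.beam.sdk.runners.PTransformOverrideFactory} that replaces {@link GroupByKey}
+ * with {@link StateMultiplexingGroupByKey}, carrying over the {@code fewKeys} hint from the
+ * original transform.
+ */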
+class StateMultiplexingGroupByKeyOverrideFactory<K, V>
+    extends SingleInputOutputOverrideFactory<
+        PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> {
+
+  @Override
+  public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>>
+              transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        StateMultiplexingGroupByKey.create(transform.getTransform().fewKeys()));
+  }
+}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java
new file mode 100644
index 000000000000..661652ce3453
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/internal/StateMultiplexingGroupByKey.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.internal;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark.AfterWatermarkEarlyAndLate;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark.FromEndOfWindow;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.apache.beam.sdk.transforms.windowing.KeyedWindow;
+import org.apache.beam.sdk.transforms.windowing.KeyedWindow.KeyedWindowFn;
+import org.apache.beam.sdk.transforms.windowing.Never.NeverTrigger;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
+import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString.Output;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/**
+ * A GroupByKey implementation that multiplexes many small user keys over a fixed set of sharding
+ * keys to reduce per-key overhead.
+ */
+public class StateMultiplexingGroupByKey<K, V>
+    extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> {
+
+  /*
+   * Keys larger than this threshold will not be multiplexed.
+   */
+  private static final int SMALL_KEY_BYTES_THRESHOLD = 4096;
+
+  private final boolean fewKeys;
+  private final int numShardingKeys;
+
+  private StateMultiplexingGroupByKey(boolean fewKeys) {
+    // TODO: plumb fewKeys to DataflowGroupByKey
+    this.fewKeys = fewKeys;
+    // TODO: make this configurable
+    this.numShardingKeys = 32 << 10;
+  }
+
+  /**
+   * Returns a {@code StateMultiplexingGroupByKey<K, V>} {@code PTransform}.
+   *
+   * @param <K> the type of the keys of the input and output {@code PCollection}s
+   * @param <V> the type of the values of the input {@code PCollection} and the elements of the
+   *     {@code Iterable}s in the output {@code PCollection}
+   */
+  public static <K, V> StateMultiplexingGroupByKey<K, V> create(boolean fewKeys) {
+    return new StateMultiplexingGroupByKey<>(fewKeys);
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  public static void applicableTo(PCollection<?> input) {
+    WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
+    // Verify that the input PCollection is bounded, or that there is windowing/triggering being
+    // used. Without this, the watermark (at end of global window) will never be reached.
+    if (windowingStrategy.getWindowFn() instanceof GlobalWindows
+        && windowingStrategy.getTrigger() instanceof DefaultTrigger
+        && input.isBounded() != IsBounded.BOUNDED) {
+      throw new IllegalStateException(
+          "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without a"
+              + " trigger. Use a Window.into or Window.triggering transform prior to GroupByKey.");
+    }
+
+    // Validate that the trigger does not finish before garbage collection time
+    if (!triggerIsSafe(windowingStrategy)) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Unsafe trigger '%s' may lose data, did you mean to wrap it in "
+                  + "`Repeatedly.forever(...)`?%nSee "
+                  + "https://s.apache.org/finishing-triggers-drop-data "
+                  + "for details.",
+              windowingStrategy.getTrigger()));
+    }
+  }
+
+  @Override
+  public void validate(
+      @Nullable PipelineOptions options,
+      Map<TupleTag<?>, PCollection<?>> inputs,
+      Map<TupleTag<?>, PCollection<?>> outputs) {
+    PCollection<?> input = Iterables.getOnlyElement(inputs.values());
+    KvCoder<K, V> inputCoder = getInputKvCoder(input.getCoder());
+
+    // Ensure that the output coder key and value types aren't different.
+    Coder<?> outputCoder = Iterables.getOnlyElement(outputs.values()).getCoder();
+    KvCoder<K, Iterable<V>> expectedOutputCoder = getOutputKvCoder(inputCoder);
+    if (!expectedOutputCoder.equals(outputCoder)) {
+      throw new IllegalStateException(
+          String.format(
+              "the GroupByKey requires its output coder to be %s but found %s.",
+              expectedOutputCoder, outputCoder));
+    }
+  }
+
+  // Note that Never trigger finishes *at* GC time so it is OK, and
+  // AfterWatermark.fromEndOfWindow() finishes at end-of-window time so it is
+  // OK if there is no allowed lateness.
+  private static boolean triggerIsSafe(WindowingStrategy<?, ?> windowingStrategy) {
+    if (!windowingStrategy.getTrigger().mayFinish()) {
+      return true;
+    }
+
+    if (windowingStrategy.getTrigger() instanceof NeverTrigger) {
+      return true;
+    }
+
+    if (windowingStrategy.getTrigger() instanceof FromEndOfWindow
+        && windowingStrategy.getAllowedLateness().getMillis() == 0) {
+      return true;
+    }
+
+    if (windowingStrategy.getTrigger() instanceof AfterWatermarkEarlyAndLate
+        && windowingStrategy.getAllowedLateness().getMillis() == 0) {
+      return true;
+    }
+
+    if (windowingStrategy.getTrigger() instanceof AfterWatermarkEarlyAndLate
+        && ((AfterWatermarkEarlyAndLate) windowingStrategy.getTrigger()).getLateTrigger() != null) {
+      return true;
+    }
+
+    return false;
+  }
+
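+  // expand() splits the input by encoded-key size: keys whose encoded form exceeds
+  // SMALL_KEY_BYTES_THRESHOLD are grouped directly on their key bytes, while small keys are
+  // re-windowed into KeyedWindows carrying the key, grouped under a bounded set of sharding keys,
+  // and then restored to their original keys and windows before the two branches are flattened.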
+  @Override
+  public PCollection<KV<K, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
+    applicableTo(input);
+    // Verify that the input Coder<KV<K, V>> is a KvCoder, and that the key coder is deterministic.
+    Coder<K> keyCoder = getKeyCoder(input.getCoder());
+    Coder<V> valueCoder = getInputValueCoder(input.getCoder());
+    KvCoder<K, Iterable<V>> outputKvCoder = getOutputKvCoder(input.getCoder());
+
+    try {
+      keyCoder.verifyDeterministic();
+    } catch (NonDeterministicException e) {
+      throw new IllegalStateException("the keyCoder of a GroupByKey must be deterministic", e);
+    }
+    Preconditions.checkArgument(numShardingKeys > 0);
+    final TupleTag<KV<ByteString, V>> largeKeys = new TupleTag<KV<ByteString, V>>() {};
+    final TupleTag<KV<ByteString, V>> smallKeys = new TupleTag<KV<ByteString, V>>() {};
+    WindowingStrategy<?, ?> originalWindowingStrategy = input.getWindowingStrategy();
+
+    PCollectionTuple mapKeysToBytes =
+        input.apply(
+            "MapKeysToBytes",
+            ParDo.of(
+                    new DoFn<KV<K, V>, KV<ByteString, V>>() {
+                      @ProcessElement
+                      public void processElement(ProcessContext c) {
+                        KV<K, V> kv = c.element();
+                        Output output = ByteString.newOutput();
+                        try {
+                          keyCoder.encode(kv.getKey(), output);
+                        } catch (IOException e) {
+                          throw new RuntimeException(e);
+                        }
+
+                        KV<ByteString, V> outputKV = KV.of(output.toByteString(), kv.getValue());
+                        if (outputKV.getKey().size() <= SMALL_KEY_BYTES_THRESHOLD) {
+                          c.output(smallKeys, outputKV);
+                        } else {
+                          c.output(largeKeys, outputKV);
+                        }
+                      }
+                    })
+                .withOutputTags(largeKeys, TupleTagList.of(smallKeys)));
+
+    PCollection<KV<K, Iterable<V>>> largeKeyBranch =
+        mapKeysToBytes
+            .get(largeKeys)
+            .setCoder(KvCoder.of(KeyedWindow.ByteStringCoder.of(), valueCoder))
+            .apply(DataflowGroupByKey.create())
+            .apply(
+                "DecodeKey",
+                MapElements.via(
+                    new SimpleFunction<KV<ByteString, Iterable<V>>, KV<K, Iterable<V>>>() {
+                      @Override
+                      public KV<K, Iterable<V>> apply(KV<ByteString, Iterable<V>> kv) {
+                        try {
+                          return KV.of(keyCoder.decode(kv.getKey().newInput()), kv.getValue());
+                        } catch (IOException e) {
+                          throw new RuntimeException(e);
+                        }
+                      }
+                    }))
+            .setCoder(outputKvCoder);
+
+    WindowFn<?, ?> windowFn = originalWindowingStrategy.getWindowFn();
+    PCollection<KV<K, Iterable<V>>> smallKeyBranch =
+        mapKeysToBytes
+            .get(smallKeys)
+            .apply(Window.into(new KeyedWindowFn<>(windowFn)))
+            .apply(
+                "MapKeys",
+                MapElements.via(
+                    new SimpleFunction<KV<ByteString, V>, KV<Integer, V>>() {
+                      @Override
+                      public KV<Integer, V> apply(KV<ByteString, V> value) {
+                        return KV.of(
+                            value.getKey().hashCode() % numShardingKeys, value.getValue());
+                      }
+                    }))
+            .apply(DataflowGroupByKey.create())
+            .apply(
+                "Restore Keys",
+                ParDo.of(
+                    new DoFn<KV<Integer, Iterable<V>>, KV<K, Iterable<V>>>() {
+                      @ProcessElement
+                      public void processElement(ProcessContext c, BoundedWindow w, PaneInfo pane) {
+                        ByteString key = ((KeyedWindow<?>) w).getKey();
+                        try {
+                          // TODO: confirm that reusing the pane from the KeyedWindow is correct
+                          // here.
+                          c.outputWindowedValue(
+                              KV.of(keyCoder.decode(key.newInput()), c.element().getValue()),
+                              c.timestamp(),
+                              Collections.singleton(((KeyedWindow<?>) w).getWindow()),
+                              pane);
+                        } catch (IOException e) {
+                          throw new RuntimeException(e);
+                        }
+                      }
+                    }))
+            .setWindowingStrategyInternal(originalWindowingStrategy)
+            .setCoder(outputKvCoder);
+    return PCollectionList.of(Arrays.asList(smallKeyBranch, largeKeyBranch))
+        .apply(Flatten.pCollections());
+  }
+
+  /**
+   * Returns the {@code Coder} of the input to this transform, which should be a {@code KvCoder}.
+   */
+  @SuppressWarnings("unchecked")
+  static <K, V> KvCoder<K, V> getInputKvCoder(Coder<KV<K, V>> inputCoder) {
+    if (!(inputCoder instanceof KvCoder)) {
+      throw new IllegalStateException("GroupByKey requires its input to use KvCoder");
+    }
+    return (KvCoder<K, V>) inputCoder;
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Returns the {@code Coder} of the keys of the input to this transform, which is also used as
+   * the {@code Coder} of the keys of the output of this transform.
+   */
+  public static <K, V> Coder<K> getKeyCoder(Coder<KV<K, V>> inputCoder) {
+    return StateMultiplexingGroupByKey.getInputKvCoder(inputCoder).getKeyCoder();
+  }
+
+  /** Returns the {@code Coder} of the values of the input to this transform. */
+  public static <K, V> Coder<V> getInputValueCoder(Coder<KV<K, V>> inputCoder) {
+    return StateMultiplexingGroupByKey.getInputKvCoder(inputCoder).getValueCoder();
+  }
+
+  /** Returns the {@code Coder} of the {@code Iterable} values of the output of this transform. */
+  static <K, V> Coder<Iterable<V>> getOutputValueCoder(Coder<KV<K, V>> inputCoder) {
+    return IterableCoder.of(getInputValueCoder(inputCoder));
+  }
+
+  /** Returns the {@code Coder} of the output of this transform. */
+  public static <K, V> KvCoder<K, Iterable<V>> getOutputKvCoder(Coder<KV<K, V>> inputCoder) {
+    return KvCoder.of(getKeyCoder(inputCoder), getOutputValueCoder(inputCoder));
+  }
+
+  @Override
+  public void populateDisplayData(DisplayData.Builder builder) {
+    super.populateDisplayData(builder);
+    if (fewKeys) {
+      builder.add(DisplayData.item("fewKeys", true).withLabel("Has Few Keys"));
+    }
+  }
+}
diff --git a/runners/prism/java/build.gradle b/runners/prism/java/build.gradle
index 82eb62b9e207..f2dfa2bb1a28 100644
--- a/runners/prism/java/build.gradle
+++ b/runners/prism/java/build.gradle
@@ -178,7 +178,10 @@ def sickbayTests = [
         'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testTwoRequiresTimeSortedInputWithLateData',
         'org.apache.beam.sdk.transforms.ParDoTest$StateTests.testRequiresTimeSortedInputWithLateData',
 
-        // Missing output due to processing time timer skew.
+        // Timer race condition/ordering issue in Prism.
+        'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testTwoTimersSettingEachOtherWithCreateAsInputUnbounded',
+
+        // Missing output due to timer skew.
         'org.apache.beam.sdk.transforms.ParDoTest$TimestampTests.testProcessElementSkew',
 
         // TestStream + BundleFinalization.
@@ -238,6 +241,10 @@ def createPrismValidatesRunnerTask = { name, environmentType ->
        excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
      }
      filter {
+       // Hangs forever with prism. Put here instead of sickbay to allow sickbay runs to terminate.
+       // https://github.com/apache/beam/issues/32222
+       excludeTestsMatching 'org.apache.beam.sdk.transforms.ParDoTest$TimerTests.testEventTimeTimerOrderingWithCreate'
+
        for (String test : sickbayTests) {
          excludeTestsMatching test
        }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/KeyedWindow.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/KeyedWindow.java
new file mode 100644
index 000000000000..c0e9e513afda
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/windowing/KeyedWindow.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms.windowing;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.util.VarInt;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+
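+/**
+ * A {@link BoundedWindow} that pairs an underlying window with the {@link ByteString}-encoded user
+ * key that produced it, allowing a runner to group elements under a small set of sharding keys
+ * while still tracking windows per user key.
+ */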
+public class KeyedWindow<W extends BoundedWindow> extends BoundedWindow {
+
+  private final ByteString key;
+  private final W window;
+
+  public KeyedWindow(ByteString key, W window) {
+    this.key = key;
+    this.window = window;
+  }
+
+  public ByteString getKey() {
+    return key;
+  }
+
+  public W getWindow() {
+    return window;
+  }
+
+  @Override
+  public Instant maxTimestamp() {
+    return window.maxTimestamp();
+  }
+
+  @Override
+  public String toString() {
+    return "KeyedWindow{" + "key='" + key + '\'' + ", window=" + window + '}';
+  }
+
+  @Override
+  public boolean equals(@Nullable Object o) {
+    if (o == null) {
+      return false;
+    }
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof KeyedWindow)) {
+      return false;
+    }
+    KeyedWindow<?> that = (KeyedWindow<?>) o;
+    return Objects.equals(key, that.key) && Objects.equals(window, that.window);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(key, window);
+  }
+
+  public static class KeyedWindowFn<V, W extends BoundedWindow>
+      extends WindowFn<KV<ByteString, V>, KeyedWindow<W>> {
+
+    private final WindowFn<V, W> windowFn;
+
+    public KeyedWindowFn(WindowFn<?, ?> windowFn) {
+      this.windowFn = (WindowFn<V, W>) windowFn;
+    }
+
+    @Override
+    public Collection<KeyedWindow<W>> assignWindows(
+        WindowFn<KV<ByteString, V>, KeyedWindow<W>>.AssignContext c) throws Exception {
+
+      return windowFn
+          .assignWindows(
+              windowFn.new AssignContext() {
+
+                @Override
+                public V element() {
+                  return c.element().getValue();
+                }
+
+                @Override
+                public Instant timestamp() {
+                  return c.timestamp();
+                }
+
+                @Override
+                public BoundedWindow window() {
+                  return c.window();
+                }
+              })
+          .stream()
+          .map(window -> new KeyedWindow<>(c.element().getKey(), window))
+          .collect(Collectors.toList());
+    }
+
+    @Override
+    public void mergeWindows(WindowFn<KV<ByteString, V>, KeyedWindow<W>>.MergeContext c)
+        throws Exception {
+      if (windowFn instanceof NonMergingWindowFn) {
+        return;
+      }
+      HashMap<ByteString, List<W>> keyToWindow = new HashMap<>();
+      c.windows()
+          .forEach(
+              keyedWindow -> {
+                List<W> windows =
+                    keyToWindow.computeIfAbsent(keyedWindow.getKey(), k -> new ArrayList<>());
+                windows.add(keyedWindow.getWindow());
+              });
+      for (Entry<ByteString, List<W>> entry : keyToWindow.entrySet()) {
+        ByteString key = entry.getKey();
+        List<W> windows = entry.getValue();
+        windowFn.mergeWindows(
+            windowFn.new MergeContext() {
+              @Override
+              public Collection<W> windows() {
+                return windows;
+              }
+
+              @Override
+              public void merge(Collection<W> toBeMerged, W mergeResult) throws Exception {
+                List<KeyedWindow<W>> toMergedKeyedWindows =
+                    toBeMerged.stream()
+                        .map(window -> new KeyedWindow<>(key, window))
+                        .collect(Collectors.toList());
+                c.merge(toMergedKeyedWindows, new KeyedWindow<>(key, mergeResult));
+              }
+            });
+      }
+    }
+
+    @Override
+    public boolean isCompatible(WindowFn<?, ?> other) {
+      return (other instanceof KeyedWindowFn)
+          && windowFn.isCompatible(((KeyedWindowFn<?, ?>) other).windowFn);
+    }
+
+    @Override
+    public Coder<KeyedWindow<W>> windowCoder() {
+      return new KeyedWindowCoder<>(windowFn.windowCoder());
+    }
+
+    @Override
+    public WindowMappingFn<KeyedWindow<W>> getDefaultWindowMappingFn() {
+      return new WindowMappingFn<KeyedWindow<W>>() {
+        @Override
+        public KeyedWindow<W> getSideInputWindow(BoundedWindow mainWindow) {
+          Preconditions.checkArgument(mainWindow instanceof KeyedWindow);
+          KeyedWindow<?> mainKeyedWindow = (KeyedWindow<?>) mainWindow;
+          return new KeyedWindow<>(
+              mainKeyedWindow.getKey(),
+              windowFn
+                  .getDefaultWindowMappingFn()
+                  .getSideInputWindow(mainKeyedWindow.getWindow()));
+        }
+      };
+    }
+
+    @Override
+    public boolean isNonMerging() {
+      return windowFn.isNonMerging();
+    }
+
+    @Override
+    public boolean assignsToOneWindow() {
+      return windowFn.assignsToOneWindow();
+    }
+
+    @Override
+    public void verifyCompatibility(WindowFn<?, ?> other) throws IncompatibleWindowException {
+      if (other instanceof KeyedWindowFn) {
+        windowFn.verifyCompatibility(((KeyedWindowFn<?, ?>) other).windowFn);
+      }
+    }
+
+    @Override
+    public void populateDisplayData(DisplayData.Builder builder) {
+      windowFn.populateDisplayData(builder);
+    }
+  }
+
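+  /**
+   * A {@link Coder} for {@link KeyedWindow} that delegates to a {@link KvCoder} over the encoded
+   * key and the underlying window.
+   */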
+  public static class KeyedWindowCoder<W extends BoundedWindow> extends Coder<KeyedWindow<W>> {
+
+    private final KvCoder<ByteString, W> coder;
+
+    public KeyedWindowCoder(Coder<W> windowCoder) {
+      // TODO: consider swapping the order for improved state locality
+      this.coder = KvCoder.of(ByteStringCoder.of(), windowCoder);
+    }
+
+    @Override
+    public void encode(KeyedWindow<W> value, OutputStream outStream) throws IOException {
+      coder.encode(KV.of(value.getKey(), value.getWindow()), outStream);
+    }
+
+    @Override
+    public KeyedWindow<W> decode(InputStream inStream) throws IOException {
+      KV<ByteString, W> decode = coder.decode(inStream);
+      return new KeyedWindow<>(decode.getKey(), decode.getValue());
+    }
+
+    @Override
+    public List<? extends Coder<?>> getCoderArguments() {
+      return coder.getCoderArguments();
+    }
+
+    @Override
+    public void verifyDeterministic() throws NonDeterministicException {
+      coder.verifyDeterministic();
+    }
+
+    @Override
+    public boolean consistentWithEquals() {
+      return coder.getValueCoder().consistentWithEquals();
+    }
+  }
+
+  public static class ByteStringCoder extends AtomicCoder<ByteString> {
+
+    public static ByteStringCoder of() {
+      return INSTANCE;
+    }
+
+    private static final ByteStringCoder INSTANCE = new ByteStringCoder();
+
+    private ByteStringCoder() {}
+
+    @Override
+    public void encode(ByteString value, OutputStream os) throws IOException {
+      VarInt.encode(value.size(), os);
+      value.writeTo(os);
+    }
+
+    @Override
+    public ByteString decode(InputStream is) throws IOException {
+      int size = VarInt.decodeInt(is);
+      return ByteString.readFrom(ByteStreams.limit(is, size), size);
+    }
+  }
+}
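For reviewers who want to exercise the new path end to end, here is a minimal, hypothetical driver (not part of this change) showing how a pipeline author would opt in. Only the experiment string comes from the diff above; the class name and option plumbing are illustrative.

import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class EnableGbkStateMultiplexing {
  public static void main(String[] args) {
    // Illustrative driver: opt into the experiment registered in DataflowRunner above.
    DataflowPipelineOptions options =
        PipelineOptionsFactory.fromArgs(args).withValidation().as(DataflowPipelineOptions.class);
    options.setStreaming(true);
    ExperimentalOptions.addExperiment(
        options.as(ExperimentalOptions.class), "enable_gbk_state_multiplexing");

    Pipeline pipeline = Pipeline.create(options);
    // With the experiment set, any GroupByKey in this streaming pipeline is rewritten to
    // StateMultiplexingGroupByKey by StateMultiplexingGroupByKeyOverrideFactory at submission.
    // ... build the rest of the pipeline as usual ...
    pipeline.run();
  }
}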