diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java index f952519d57b6..47b8aec45749 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java @@ -63,7 +63,6 @@ import org.apache.beam.runners.core.construction.graph.ExecutableStage; import org.apache.beam.runners.core.construction.graph.UserStateReference; import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory; -import org.apache.beam.runners.flink.translation.functions.FlinkStreamingSideInputHandlerFactory; import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler; @@ -82,6 +81,7 @@ import org.apache.beam.runners.fnexecution.provisioning.JobInfo; import org.apache.beam.runners.fnexecution.state.StateRequestHandler; import org.apache.beam.runners.fnexecution.state.StateRequestHandlers; +import org.apache.beam.runners.fnexecution.translation.StreamingSideInputHandlerFactory; import org.apache.beam.runners.fnexecution.wire.ByteStringCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.VoidCoder; @@ -314,7 +314,7 @@ private StateRequestHandler getStateRequestHandler(ExecutableStage executableSta checkNotNull(super.sideInputHandler); StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory = Preconditions.checkNotNull( - FlinkStreamingSideInputHandlerFactory.forStage( + StreamingSideInputHandlerFactory.forStage( executableStage, sideInputIds, super.sideInputHandler)); try { sideInputStateHandler = diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStreamingSideInputHandlerFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/StreamingSideInputHandlerFactory.java similarity index 93% rename from runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStreamingSideInputHandlerFactory.java rename to runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/StreamingSideInputHandlerFactory.java index f6f6d55d1ed3..1dfcc4a8d7bf 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStreamingSideInputHandlerFactory.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/translation/StreamingSideInputHandlerFactory.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.runners.flink.translation.functions; +package org.apache.beam.runners.fnexecution.translation; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; @@ -40,13 +40,13 @@ /** * {@link StateRequestHandler} that uses {@link org.apache.beam.runners.core.SideInputHandler} to - * access the Flink broadcast state that represents side inputs. + * access the broadcast state that represents side inputs. */ @SuppressWarnings({ "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556) "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) }) -public class FlinkStreamingSideInputHandlerFactory implements SideInputHandlerFactory { +public class StreamingSideInputHandlerFactory implements SideInputHandlerFactory { // Map from side input id to global PCollection id. private final Map> sideInputToCollection; @@ -56,7 +56,7 @@ public class FlinkStreamingSideInputHandlerFactory implements SideInputHandlerFa * Creates a new state handler for the given stage. Note that this requires a traversal of the * stage itself, so this should only be called once per stage rather than once per bundle. */ - public static FlinkStreamingSideInputHandlerFactory forStage( + public static StreamingSideInputHandlerFactory forStage( ExecutableStage stage, Map> viewMapping, org.apache.beam.runners.core.SideInputHandler runnerHandler) { @@ -76,12 +76,12 @@ public static FlinkStreamingSideInputHandlerFactory forStage( sideInputId.getLocalName())); } - FlinkStreamingSideInputHandlerFactory factory = - new FlinkStreamingSideInputHandlerFactory(sideInputBuilder.build(), runnerHandler); + StreamingSideInputHandlerFactory factory = + new StreamingSideInputHandlerFactory(sideInputBuilder.build(), runnerHandler); return factory; } - private FlinkStreamingSideInputHandlerFactory( + private StreamingSideInputHandlerFactory( Map> sideInputToCollection, org.apache.beam.runners.core.SideInputHandler runnerHandler) { this.sideInputToCollection = sideInputToCollection; diff --git a/runners/samza/build.gradle b/runners/samza/build.gradle index 8a24de9a6ff9..3edf73abb962 100644 --- a/runners/samza/build.gradle +++ b/runners/samza/build.gradle @@ -65,6 +65,7 @@ dependencies { compile "org.apache.samza:samza-yarn_2.11:$samza_version" runtimeOnly "org.apache.kafka:kafka-clients:2.0.1" compile library.java.vendored_grpc_1_36_0 + compile project(path: ":model:fn-execution", configuration: "shadow") compile project(path: ":model:job-management", configuration: "shadow") compile project(path: ":model:pipeline", configuration: "shadow") compile project(":sdks:java:fn-execution") diff --git a/runners/samza/job-server/build.gradle b/runners/samza/job-server/build.gradle index 5a71d5e4b3a8..899b44393610 100644 --- a/runners/samza/job-server/build.gradle +++ b/runners/samza/job-server/build.gradle @@ -70,8 +70,6 @@ createPortableValidatesRunnerTask( environment: BeamModulePlugin.PortableValidatesRunnerConfiguration.Environment.EMBEDDED, testCategories: { includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' - // TODO: BEAM-12347 - excludeCategories 'org.apache.beam.sdk.testing.UsesSideInputs' // TODO: BEAM-12348 excludeCategories 'org.apache.beam.sdk.testing.UsesStatefulParDo' // TODO: BEAM-12349 @@ -102,6 +100,7 @@ createPortableValidatesRunnerTask( excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs' }, testFilter: { + excludeTestsMatching "org.apache.beam.sdk.transforms.FlattenTest.testEmptyFlattenAsSideInput" excludeTestsMatching "org.apache.beam.sdk.transforms.FlattenTest.testFlattenPCollectionsEmptyThenParDo" excludeTestsMatching "org.apache.beam.sdk.transforms.FlattenTest.testFlattenPCollectionsEmpty" } diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java index 062e1a92310f..856ea85766e2 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java @@ -31,6 +31,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.PushbackSideInputDoFnRunner; @@ -44,6 +45,7 @@ import org.apache.beam.runners.fnexecution.control.ExecutableStageContext; import org.apache.beam.runners.fnexecution.control.StageBundleFactory; import org.apache.beam.runners.fnexecution.provisioning.JobInfo; +import org.apache.beam.runners.fnexecution.state.StateRequestHandler; import org.apache.beam.runners.samza.SamzaExecutionContext; import org.apache.beam.runners.samza.SamzaPipelineOptions; import org.apache.beam.runners.samza.util.FutureUtils; @@ -86,6 +88,7 @@ public class DoFnOp implements Op { private final WindowingStrategy windowingStrategy; private final OutputManagerFactory outputManagerFactory; // NOTE: we use HashMap here to guarantee Serializability + // Mapping from view id to a view private final HashMap> idToViewMap; private final String transformFullName; private final String transformId; @@ -107,6 +110,7 @@ public class DoFnOp implements Op { private transient PushbackSideInputDoFnRunner pushbackFnRunner; private transient SideInputHandler sideInputHandler; private transient DoFnInvoker doFnInvoker; + // Mapping from side input id to a view private transient SamzaPipelineOptions samzaPipelineOptions; // This is derivable from pushbackValues which is persisted to a store. @@ -125,7 +129,7 @@ public class DoFnOp implements Op { private transient StageBundleFactory stageBundleFactory; private DoFnSchemaInformation doFnSchemaInformation; private transient boolean bundleDisabled; - private Map> sideInputMapping; + private Map> sideInputMapping; public DoFnOp( TupleTag mainOutputTag, @@ -147,7 +151,7 @@ public DoFnOp( JobInfo jobInfo, Map> idToTupleTagMap, DoFnSchemaInformation doFnSchemaInformation, - Map> sideInputMapping) { + Map> sideInputMapping) { this.mainOutputTag = mainOutputTag; this.doFn = doFn; this.sideInputs = sideInputs; @@ -225,12 +229,18 @@ public void open( final ExecutableStage executableStage = ExecutableStage.fromPayload(stagePayload); stageContext = SamzaExecutableStageContextFactory.getInstance().get(jobInfo); stageBundleFactory = stageContext.getStageBundleFactory(executableStage); + final StateRequestHandler stateRequestHandler = + SamzaStateRequestHandlers.of( + executableStage, + (Map>) sideInputMapping, + sideInputHandler); this.fnRunner = SamzaDoFnRunners.createPortable( samzaPipelineOptions, bundledEventsBagState, outputManagerFactory.create(emitter, outputFutureCollector), stageBundleFactory, + stateRequestHandler, mainOutputTag, idToTupleTagMap, context, @@ -253,7 +263,7 @@ public void open( sideOutputTags, outputCoders, doFnSchemaInformation, - sideInputMapping); + (Map>) sideInputMapping); } this.pushbackFnRunner = diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java index 2232736fa12b..c8f5f850476b 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java @@ -175,6 +175,7 @@ public static DoFnRunner createPortable( BagState> bundledEventsBag, DoFnRunners.OutputManager outputManager, StageBundleFactory stageBundleFactory, + StateRequestHandler stateRequestHandler, TupleTag mainOutputTag, Map> idToTupleTagMap, Context context, @@ -183,7 +184,12 @@ public static DoFnRunner createPortable( (SamzaExecutionContext) context.getApplicationContainerContext(); final DoFnRunner sdkHarnessDoFnRunner = new SdkHarnessDoFnRunner<>( - outputManager, stageBundleFactory, mainOutputTag, idToTupleTagMap, bundledEventsBag); + outputManager, + stageBundleFactory, + mainOutputTag, + idToTupleTagMap, + bundledEventsBag, + stateRequestHandler); return DoFnRunnerWithMetrics.wrap( sdkHarnessDoFnRunner, executionContext.getMetricsContainer(), transformFullName); } @@ -197,18 +203,21 @@ private static class SdkHarnessDoFnRunner implements DoFnRunner> bundledEventsBag; private RemoteBundle remoteBundle; private FnDataReceiver> inputReceiver; + private StateRequestHandler stateRequestHandler; private SdkHarnessDoFnRunner( DoFnRunners.OutputManager outputManager, StageBundleFactory stageBundleFactory, TupleTag mainOutputTag, Map> idToTupleTagMap, - BagState> bundledEventsBag) { + BagState> bundledEventsBag, + StateRequestHandler stateRequestHandler) { this.outputManager = outputManager; this.stageBundleFactory = stageBundleFactory; this.mainOutputTag = mainOutputTag; this.idToTupleTagMap = idToTupleTagMap; this.bundledEventsBag = bundledEventsBag; + this.stateRequestHandler = stateRequestHandler; } @Override @@ -227,9 +236,7 @@ public FnDataReceiver create(String pCollectionId) { remoteBundle = stageBundleFactory.getBundle( - receiverFactory, - StateRequestHandler.unsupported(), - BundleProgressHandler.ignored()); + receiverFactory, stateRequestHandler, BundleProgressHandler.ignored()); // TODO: side input support needs to implement to handle this properly inputReceiver = Iterables.getOnlyElement(remoteBundle.getInputReceivers().values()); diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaStateRequestHandlers.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaStateRequestHandlers.java new file mode 100644 index 000000000000..65749a7daefd --- /dev/null +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaStateRequestHandlers.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.samza.runtime; + +import java.io.IOException; +import java.util.EnumMap; +import java.util.Map; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.SideInputHandler; +import org.apache.beam.runners.core.construction.graph.ExecutableStage; +import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors; +import org.apache.beam.runners.fnexecution.state.StateRequestHandler; +import org.apache.beam.runners.fnexecution.state.StateRequestHandlers; +import org.apache.beam.runners.fnexecution.translation.StreamingSideInputHandlerFactory; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; + +/** + * This class creates {@link StateRequestHandler} for side inputs and states of the Samza portable + * runner. + */ +public class SamzaStateRequestHandlers { + + // TODO: [BEAM-12403] support state handlers + public static StateRequestHandler of( + ExecutableStage executableStage, + Map> sideInputIds, + SideInputHandler sideInputHandler) { + final StateRequestHandler sideInputStateHandler; + if (executableStage.getSideInputs().size() > 0) { + final StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory = + Preconditions.checkNotNull( + StreamingSideInputHandlerFactory.forStage( + executableStage, sideInputIds, sideInputHandler)); + try { + sideInputStateHandler = + StateRequestHandlers.forSideInputHandlerFactory( + ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory); + } catch (IOException e) { + throw new RuntimeException("Failed to initialize SideInputHandler", e); + } + } else { + sideInputStateHandler = StateRequestHandler.unsupported(); + } + + final EnumMap handlerMap = + new EnumMap<>(BeamFnApi.StateKey.TypeCase.class); + handlerMap.put(BeamFnApi.StateKey.TypeCase.ITERABLE_SIDE_INPUT, sideInputStateHandler); + handlerMap.put(BeamFnApi.StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputStateHandler); + handlerMap.put(BeamFnApi.StateKey.TypeCase.MULTIMAP_KEYS_SIDE_INPUT, sideInputStateHandler); + return StateRequestHandlers.delegateBasedUponType(handlerMap); + } +} diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java index 4f3c40928b5e..2818b6ab9e38 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/GroupByKeyTranslator.java @@ -112,7 +112,21 @@ public void translatePortable( PipelineNode.PTransformNode transform, QueryablePipeline pipeline, PortableTranslationContext ctx) { - doTranslatePortable(transform, pipeline, ctx); + final String inputId = ctx.getInputId(transform); + final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId); + final MessageStream>> inputStream = ctx.getMessageStreamById(inputId); + final WindowingStrategy windowingStrategy = + ctx.getPortableWindowStrategy(inputId, pipeline.getComponents()); + final WindowedValue.WindowedValueCoder> windowedInputCoder = + ctx.instantiateCoder(inputId, pipeline.getComponents()); + final TupleTag> outputTag = + new TupleTag<>(Iterables.getOnlyElement(transform.getTransform().getOutputsMap().keySet())); + + final MessageStream>> outputStream = + doTranslatePortable( + input, inputStream, windowingStrategy, windowedInputCoder, outputTag, ctx); + + ctx.registerMessageStream(ctx.getOutputId(transform), outputStream); } @Override @@ -129,48 +143,41 @@ public Map createPortableConfig( return ConfigBuilder.createRocksDBStoreConfig(options); } - private static void doTranslatePortable( - PipelineNode.PTransformNode transform, - QueryablePipeline pipeline, + /** + * The method is used to translate both portable GBK transform as well as grouping side inputs + * into Samza. + */ + static MessageStream>> doTranslatePortable( + RunnerApi.PCollection input, + MessageStream>> inputStream, + WindowingStrategy windowingStrategy, + WindowedValue.WindowedValueCoder> windowedInputCoder, + TupleTag> outputTag, PortableTranslationContext ctx) { - final MessageStream>> inputStream = - ctx.getOneInputMessageStream(transform); final boolean needRepartition = ctx.getSamzaPipelineOptions().getMaxSourceParallelism() > 1; - final WindowingStrategy windowingStrategy = - ctx.getPortableWindowStrategy(transform, pipeline); final Coder windowCoder = windowingStrategy.getWindowFn().windowCoder(); - - final String inputId = ctx.getInputId(transform); - final WindowedValue.WindowedValueCoder> windowedInputCoder = - ctx.instantiateCoder(inputId, pipeline.getComponents()); final KvCoder kvInputCoder = (KvCoder) windowedInputCoder.getValueCoder(); final Coder>> elementCoder = WindowedValue.FullWindowedValueCoder.of(kvInputCoder, windowCoder); - final TupleTag> outputTag = - new TupleTag<>(Iterables.getOnlyElement(transform.getTransform().getOutputsMap().keySet())); - @SuppressWarnings("unchecked") final SystemReduceFn reduceFn = (SystemReduceFn) SystemReduceFn.buffering(kvInputCoder.getValueCoder()); - final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId); final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input); - final MessageStream>> outputStream = - doTranslateGBK( - inputStream, - needRepartition, - reduceFn, - windowingStrategy, - kvInputCoder, - elementCoder, - ctx.getTransformFullName(), - ctx.getTransformId(), - outputTag, - isBounded); - ctx.registerMessageStream(ctx.getOutputId(transform), outputStream); + return doTranslateGBK( + inputStream, + needRepartition, + reduceFn, + windowingStrategy, + kvInputCoder, + elementCoder, + ctx.getTransformFullName(), + ctx.getTransformId(), + outputTag, + isBounded); } private static MessageStream>> doTranslateGBK( diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java index 01be930855e6..0e3e11deebda 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.samza.translation; +import static org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder; + import java.io.IOException; import java.util.ArrayList; import java.util.Collection; @@ -29,7 +31,9 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId; import org.apache.beam.runners.core.construction.ParDoTranslation; +import org.apache.beam.runners.core.construction.RunnerPCollectionView; import org.apache.beam.runners.core.construction.graph.PipelineNode; import org.apache.beam.runners.core.construction.graph.QueryablePipeline; import org.apache.beam.runners.samza.SamzaPipelineOptions; @@ -41,18 +45,27 @@ import org.apache.beam.runners.samza.runtime.SamzaDoFnInvokerRegistrar; import org.apache.beam.runners.samza.util.SamzaPipelineTranslatorUtils; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.ViewFn; import org.apache.beam.sdk.transforms.join.RawUnionValue; import org.apache.beam.sdk.transforms.reflect.DoFnSignature; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.functions.FlatMapFunction; @@ -223,15 +236,39 @@ private static void doTranslatePortable( String inputId = stagePayload.getInput(); final MessageStream> inputStream = ctx.getMessageStreamById(inputId); - // TODO: support side input - if (!stagePayload.getSideInputsList().isEmpty()) { - throw new UnsupportedOperationException( - "Side inputs in portable pipelines are not supported in samza"); + // Analyze side inputs + final List>>> sideInputStreams = new ArrayList<>(); + final Map> sideInputMapping = new HashMap<>(); + final Map> idToViewMapping = new HashMap<>(); + final RunnerApi.Components components = stagePayload.getComponents(); + for (SideInputId sideInputId : stagePayload.getSideInputsList()) { + final String sideInputCollectionId = + components + .getTransformsOrThrow(sideInputId.getTransformId()) + .getInputsOrThrow(sideInputId.getLocalName()); + final WindowingStrategy windowingStrategy = + ctx.getPortableWindowStrategy(sideInputCollectionId, components); + final WindowedValue.WindowedValueCoder coder = + (WindowedValue.WindowedValueCoder) instantiateCoder(sideInputCollectionId, components); + + // Create a runner-side view + final PCollectionView view = createPCollectionView(sideInputId, coder, windowingStrategy); + + // Use GBK to aggregate the side inputs and then broadcast it out + final MessageStream>> broadcastSideInput = + groupAndBroadcastSideInput( + sideInputId, + sideInputCollectionId, + components.getPcollectionsOrThrow(sideInputCollectionId), + (WindowingStrategy) windowingStrategy, + coder, + ctx); + + sideInputStreams.add(broadcastSideInput); + sideInputMapping.put(sideInputId, view); + idToViewMapping.put(getSideInputUniqueId(sideInputId), view); } - // set side inputs to empty until it's supported - final List>> sideInputStreams = Collections.emptyList(); - final Map, Integer> tagToIndexMap = new HashMap<>(); final Map indexToIdMap = new HashMap<>(); final Map> idToTupleTagMap = new HashMap<>(); @@ -261,7 +298,6 @@ private static void doTranslatePortable( // Note: transform.getTransform() is an ExecutableStage, not ParDo, so we need to extract // these info from its components. final DoFnSchemaInformation doFnSchemaInformation = null; - final Map> sideInputMapping = Collections.emptyMap(); final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId); final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input); @@ -274,10 +310,10 @@ private static void doTranslatePortable( windowedInputCoder.getValueCoder(), // input coder not in use windowedInputCoder, Collections.emptyMap(), // output coders not in use - Collections.emptyList(), // sideInputs not in use until side input support + new ArrayList<>(sideInputMapping.values()), new ArrayList<>(idToTupleTagMap.values()), // used by java runner only - SamzaPipelineTranslatorUtils.getPortableWindowStrategy(transform, pipeline), - Collections.emptyMap(), // idToViewMap not in use until side input support + ctx.getPortableWindowStrategy(inputId, stagePayload.getComponents()), + idToViewMapping, new DoFnOp.MultiOutputManagerFactory(tagToIndexMap), ctx.getTransformFullName(), ctx.getTransformId(), @@ -368,6 +404,80 @@ public Map createPortableConfig( return Collections.emptyMap(); } + @SuppressWarnings("unchecked") + private static final ViewFn>, ?> VIEW_FN = + (ViewFn) + new PCollectionViews.MultimapViewFn<>( + (PCollectionViews.TypeDescriptorSupplier>>) + () -> TypeDescriptors.iterables(new TypeDescriptor>() {}), + (PCollectionViews.TypeDescriptorSupplier) TypeDescriptors::voids); + + // This method follows the same way in Flink to create a runner-side Java + // PCollectionView to represent a portable side input. + private static PCollectionView createPCollectionView( + SideInputId sideInputId, + WindowedValue.WindowedValueCoder coder, + WindowingStrategy windowingStrategy) { + + return new RunnerPCollectionView<>( + null, + new TupleTag<>(sideInputId.getLocalName()), + VIEW_FN, + // TODO: support custom mapping fn + windowingStrategy.getWindowFn().getDefaultWindowMappingFn(), + windowingStrategy, + coder.getValueCoder()); + } + + // Group the side input globally with a null key and then broadcast it + // to all tasks. + private static + MessageStream>> groupAndBroadcastSideInput( + SideInputId sideInputId, + String sideInputCollectionId, + RunnerApi.PCollection sideInputPCollection, + WindowingStrategy windowingStrategy, + WindowedValue.WindowedValueCoder coder, + PortableTranslationContext ctx) { + final MessageStream> sideInput = + ctx.getMessageStreamById(sideInputCollectionId); + final MessageStream>> keyedSideInput = + sideInput.map( + opMessage -> { + WindowedValue wv = opMessage.getElement(); + return OpMessage.ofElement(wv.withValue(KV.of(null, wv.getValue()))); + }); + final WindowedValue.WindowedValueCoder> kvCoder = + coder.withValueCoder(KvCoder.of(VoidCoder.of(), coder.getValueCoder())); + final MessageStream>>> groupedSideInput = + GroupByKeyTranslator.doTranslatePortable( + sideInputPCollection, + keyedSideInput, + windowingStrategy, + kvCoder, + new TupleTag<>("main output"), + ctx); + final MessageStream>> nonkeyGroupedSideInput = + groupedSideInput.map( + opMessage -> { + WindowedValue>> wv = opMessage.getElement(); + return OpMessage.ofElement(wv.withValue(wv.getValue().getValue())); + }); + final MessageStream>> broadcastSideInput = + SamzaPublishViewTranslator.doTranslate( + nonkeyGroupedSideInput, + coder.withValueCoder(IterableCoder.of(coder.getValueCoder())), + ctx.getTransformId(), + getSideInputUniqueId(sideInputId), + ctx.getSamzaPipelineOptions()); + + return broadcastSideInput; + } + + private static String getSideInputUniqueId(SideInputId sideInputId) { + return sideInputId.getTransformId() + "-" + sideInputId.getLocalName(); + } + static class SideInputWatermarkFn implements FlatMapFunction, OpMessage>, WatermarkFunction> { diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java index d0cccd5afead..aab744e6ed68 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/PortableTranslationContext.java @@ -29,7 +29,6 @@ import org.apache.beam.runners.core.construction.RehydratedComponents; import org.apache.beam.runners.core.construction.WindowingStrategyTranslation; import org.apache.beam.runners.core.construction.graph.PipelineNode; -import org.apache.beam.runners.core.construction.graph.QueryablePipeline; import org.apache.beam.runners.fnexecution.provisioning.JobInfo; import org.apache.beam.runners.fnexecution.wire.WireCoders; import org.apache.beam.runners.samza.SamzaPipelineOptions; @@ -148,16 +147,12 @@ public WindowedValue.WindowedValueCoder instantiateCoder( } public WindowingStrategy getPortableWindowStrategy( - PipelineNode.PTransformNode transform, QueryablePipeline pipeline) { - String inputId = Iterables.getOnlyElement(transform.getTransform().getInputsMap().values()); - RehydratedComponents rehydratedComponents = - RehydratedComponents.forComponents(pipeline.getComponents()); + String collectionId, RunnerApi.Components components) { + RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components); RunnerApi.WindowingStrategy windowingStrategyProto = - pipeline - .getComponents() - .getWindowingStrategiesOrThrow( - pipeline.getComponents().getPcollectionsOrThrow(inputId).getWindowingStrategyId()); + components.getWindowingStrategiesOrThrow( + components.getPcollectionsOrThrow(collectionId).getWindowingStrategyId()); WindowingStrategy windowingStrategy; try { diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTranslator.java index 308be267a3aa..08b6196b6c43 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTranslator.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/SamzaPublishViewTranslator.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.samza.translation; import java.util.List; +import org.apache.beam.runners.samza.SamzaPipelineOptions; import org.apache.beam.runners.samza.runtime.OpMessage; import org.apache.beam.runners.samza.util.SamzaCoders; import org.apache.beam.sdk.coders.Coder; @@ -35,18 +36,29 @@ public void translate( SamzaPublishView transform, TransformHierarchy.Node node, TranslationContext ctx) { - doTranslate(transform, node, ctx); - } - - private static void doTranslate( - SamzaPublishView transform, - TransformHierarchy.Node node, - TranslationContext ctx) { - final PCollection> input = ctx.getInput(transform); final MessageStream>> inputStream = ctx.getMessageStream(input); @SuppressWarnings("unchecked") final Coder>> elementCoder = (Coder) SamzaCoders.of(input); + final String viewId = ctx.getViewId(transform.getView()); + + final MessageStream>> outputStream = + doTranslate( + inputStream, elementCoder, ctx.getTransformId(), viewId, ctx.getPipelineOptions()); + + ctx.registerViewStream(transform.getView(), outputStream); + } + + /** + * This method is used to translate both native Java PublishView transform as well as portable + * side input broadcasting into Samza. + */ + static MessageStream>> doTranslate( + MessageStream>> inputStream, + Coder>> coder, + String transformId, + String viewId, + SamzaPipelineOptions options) { final MessageStream>> elementStream = inputStream @@ -55,15 +67,10 @@ private static void doTranslate( // TODO: once SAMZA-1580 is resolved, this optimization will go directly inside Samza final MessageStream>> broadcastStream = - ctx.getPipelineOptions().getMaxSourceParallelism() == 1 + options.getMaxSourceParallelism() == 1 ? elementStream - : elementStream.broadcast( - SamzaCoders.toSerde(elementCoder), "view-" + ctx.getTransformId()); + : elementStream.broadcast(SamzaCoders.toSerde(coder), "view-" + transformId); - final String viewId = ctx.getViewId(transform.getView()); - final MessageStream>> outputStream = - broadcastStream.map(element -> OpMessage.ofSideInput(viewId, element)); - - ctx.registerViewStream(transform.getView(), outputStream); + return broadcastStream.map(element -> OpMessage.ofSideInput(viewId, element)); } } diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/util/SamzaPipelineTranslatorUtils.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/util/SamzaPipelineTranslatorUtils.java index 55e7877b3a20..e265ab02e41a 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/util/SamzaPipelineTranslatorUtils.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/util/SamzaPipelineTranslatorUtils.java @@ -19,17 +19,10 @@ import java.io.IOException; import org.apache.beam.model.pipeline.v1.RunnerApi; -import org.apache.beam.runners.core.construction.RehydratedComponents; -import org.apache.beam.runners.core.construction.WindowingStrategyTranslation; import org.apache.beam.runners.core.construction.graph.PipelineNode; -import org.apache.beam.runners.core.construction.graph.QueryablePipeline; import org.apache.beam.runners.fnexecution.wire.WireCoders; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.WindowingStrategy; -import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; /** Utilities for pipeline translation. */ @SuppressWarnings({ @@ -50,35 +43,6 @@ public static WindowedValue.WindowedValueCoder instantiateCoder( } } - public static WindowingStrategy getPortableWindowStrategy( - PipelineNode.PTransformNode transform, QueryablePipeline pipeline) { - String inputId = Iterables.getOnlyElement(transform.getTransform().getInputsMap().values()); - RehydratedComponents rehydratedComponents = - RehydratedComponents.forComponents(pipeline.getComponents()); - - RunnerApi.WindowingStrategy windowingStrategyProto = - pipeline - .getComponents() - .getWindowingStrategiesOrThrow( - pipeline.getComponents().getPcollectionsOrThrow(inputId).getWindowingStrategyId()); - - WindowingStrategy windowingStrategy; - try { - windowingStrategy = - WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents); - } catch (InvalidProtocolBufferException e) { - throw new IllegalStateException( - String.format( - "Unable to hydrate GroupByKey windowing strategy %s.", windowingStrategyProto), - e); - } - - @SuppressWarnings("unchecked") - WindowingStrategy ret = - (WindowingStrategy) windowingStrategy; - return ret; - } - /** * Escape the non-alphabet chars in the name so we can create a physical stream out of it. *