Skip to content

Commit

Permalink
[BEAM-12370] Support side input in Samza portable runner (#14883)
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyuiscool authored May 27, 2021
1 parent 99aa83d commit 85b85a5
Show file tree
Hide file tree
Showing 12 changed files with 289 additions and 120 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.UserStateReference;
import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory;
import org.apache.beam.runners.flink.translation.functions.FlinkStreamingSideInputHandlerFactory;
import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
Expand All @@ -82,6 +81,7 @@
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
import org.apache.beam.runners.fnexecution.translation.StreamingSideInputHandlerFactory;
import org.apache.beam.runners.fnexecution.wire.ByteStringCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.VoidCoder;
Expand Down Expand Up @@ -314,7 +314,7 @@ private StateRequestHandler getStateRequestHandler(ExecutableStage executableSta
checkNotNull(super.sideInputHandler);
StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
Preconditions.checkNotNull(
FlinkStreamingSideInputHandlerFactory.forStage(
StreamingSideInputHandlerFactory.forStage(
executableStage, sideInputIds, super.sideInputHandler));
try {
sideInputStateHandler =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.flink.translation.functions;
package org.apache.beam.runners.fnexecution.translation;

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
Expand All @@ -40,13 +40,13 @@

/**
* {@link StateRequestHandler} that uses {@link org.apache.beam.runners.core.SideInputHandler} to
* access the Flink broadcast state that represents side inputs.
* access the broadcast state that represents side inputs.
*/
@SuppressWarnings({
"rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
"nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
})
public class FlinkStreamingSideInputHandlerFactory implements SideInputHandlerFactory {
public class StreamingSideInputHandlerFactory implements SideInputHandlerFactory {

// Map from side input id to global PCollection id.
private final Map<SideInputId, PCollectionView<?>> sideInputToCollection;
Expand All @@ -56,7 +56,7 @@ public class FlinkStreamingSideInputHandlerFactory implements SideInputHandlerFa
* Creates a new state handler for the given stage. Note that this requires a traversal of the
* stage itself, so this should only be called once per stage rather than once per bundle.
*/
public static FlinkStreamingSideInputHandlerFactory forStage(
public static StreamingSideInputHandlerFactory forStage(
ExecutableStage stage,
Map<SideInputId, PCollectionView<?>> viewMapping,
org.apache.beam.runners.core.SideInputHandler runnerHandler) {
Expand All @@ -76,12 +76,12 @@ public static FlinkStreamingSideInputHandlerFactory forStage(
sideInputId.getLocalName()));
}

FlinkStreamingSideInputHandlerFactory factory =
new FlinkStreamingSideInputHandlerFactory(sideInputBuilder.build(), runnerHandler);
StreamingSideInputHandlerFactory factory =
new StreamingSideInputHandlerFactory(sideInputBuilder.build(), runnerHandler);
return factory;
}

private FlinkStreamingSideInputHandlerFactory(
private StreamingSideInputHandlerFactory(
Map<SideInputId, PCollectionView<?>> sideInputToCollection,
org.apache.beam.runners.core.SideInputHandler runnerHandler) {
this.sideInputToCollection = sideInputToCollection;
Expand Down
1 change: 1 addition & 0 deletions runners/samza/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ dependencies {
compile "org.apache.samza:samza-yarn_2.11:$samza_version"
runtimeOnly "org.apache.kafka:kafka-clients:2.0.1"
compile library.java.vendored_grpc_1_36_0
compile project(path: ":model:fn-execution", configuration: "shadow")
compile project(path: ":model:job-management", configuration: "shadow")
compile project(path: ":model:pipeline", configuration: "shadow")
compile project(":sdks:java:fn-execution")
Expand Down
3 changes: 1 addition & 2 deletions runners/samza/job-server/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ createPortableValidatesRunnerTask(
environment: BeamModulePlugin.PortableValidatesRunnerConfiguration.Environment.EMBEDDED,
testCategories: {
includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner'
// TODO: BEAM-12347
excludeCategories 'org.apache.beam.sdk.testing.UsesSideInputs'
// TODO: BEAM-12348
excludeCategories 'org.apache.beam.sdk.testing.UsesStatefulParDo'
// TODO: BEAM-12349
Expand Down Expand Up @@ -102,6 +100,7 @@ createPortableValidatesRunnerTask(
excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
},
testFilter: {
excludeTestsMatching "org.apache.beam.sdk.transforms.FlattenTest.testEmptyFlattenAsSideInput"
excludeTestsMatching "org.apache.beam.sdk.transforms.FlattenTest.testFlattenPCollectionsEmptyThenParDo"
excludeTestsMatching "org.apache.beam.sdk.transforms.FlattenTest.testFlattenPCollectionsEmpty"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.PushbackSideInputDoFnRunner;
Expand All @@ -44,6 +45,7 @@
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.samza.SamzaExecutionContext;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
import org.apache.beam.runners.samza.util.FutureUtils;
Expand Down Expand Up @@ -86,6 +88,7 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
private final WindowingStrategy windowingStrategy;
private final OutputManagerFactory<OutT> outputManagerFactory;
// NOTE: we use HashMap here to guarantee Serializability
// Mapping from view id to a view
private final HashMap<String, PCollectionView<?>> idToViewMap;
private final String transformFullName;
private final String transformId;
Expand All @@ -107,6 +110,7 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
private transient PushbackSideInputDoFnRunner<InT, FnOutT> pushbackFnRunner;
private transient SideInputHandler sideInputHandler;
private transient DoFnInvoker<InT, FnOutT> doFnInvoker;
// Mapping from side input id to a view
private transient SamzaPipelineOptions samzaPipelineOptions;

// This is derivable from pushbackValues which is persisted to a store.
Expand All @@ -125,7 +129,7 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {
private transient StageBundleFactory stageBundleFactory;
private DoFnSchemaInformation doFnSchemaInformation;
private transient boolean bundleDisabled;
private Map<String, PCollectionView<?>> sideInputMapping;
private Map<?, PCollectionView<?>> sideInputMapping;

public DoFnOp(
TupleTag<FnOutT> mainOutputTag,
Expand All @@ -147,7 +151,7 @@ public DoFnOp(
JobInfo jobInfo,
Map<String, TupleTag<?>> idToTupleTagMap,
DoFnSchemaInformation doFnSchemaInformation,
Map<String, PCollectionView<?>> sideInputMapping) {
Map<?, PCollectionView<?>> sideInputMapping) {
this.mainOutputTag = mainOutputTag;
this.doFn = doFn;
this.sideInputs = sideInputs;
Expand Down Expand Up @@ -225,12 +229,18 @@ public void open(
final ExecutableStage executableStage = ExecutableStage.fromPayload(stagePayload);
stageContext = SamzaExecutableStageContextFactory.getInstance().get(jobInfo);
stageBundleFactory = stageContext.getStageBundleFactory(executableStage);
final StateRequestHandler stateRequestHandler =
SamzaStateRequestHandlers.of(
executableStage,
(Map<SideInputId, PCollectionView<?>>) sideInputMapping,
sideInputHandler);
this.fnRunner =
SamzaDoFnRunners.createPortable(
samzaPipelineOptions,
bundledEventsBagState,
outputManagerFactory.create(emitter, outputFutureCollector),
stageBundleFactory,
stateRequestHandler,
mainOutputTag,
idToTupleTagMap,
context,
Expand All @@ -253,7 +263,7 @@ public void open(
sideOutputTags,
outputCoders,
doFnSchemaInformation,
sideInputMapping);
(Map<String, PCollectionView<?>>) sideInputMapping);
}

this.pushbackFnRunner =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ public static <InT, FnOutT> DoFnRunner<InT, FnOutT> createPortable(
BagState<WindowedValue<InT>> bundledEventsBag,
DoFnRunners.OutputManager outputManager,
StageBundleFactory stageBundleFactory,
StateRequestHandler stateRequestHandler,
TupleTag<FnOutT> mainOutputTag,
Map<String, TupleTag<?>> idToTupleTagMap,
Context context,
Expand All @@ -183,7 +184,12 @@ public static <InT, FnOutT> DoFnRunner<InT, FnOutT> createPortable(
(SamzaExecutionContext) context.getApplicationContainerContext();
final DoFnRunner<InT, FnOutT> sdkHarnessDoFnRunner =
new SdkHarnessDoFnRunner<>(
outputManager, stageBundleFactory, mainOutputTag, idToTupleTagMap, bundledEventsBag);
outputManager,
stageBundleFactory,
mainOutputTag,
idToTupleTagMap,
bundledEventsBag,
stateRequestHandler);
return DoFnRunnerWithMetrics.wrap(
sdkHarnessDoFnRunner, executionContext.getMetricsContainer(), transformFullName);
}
Expand All @@ -197,18 +203,21 @@ private static class SdkHarnessDoFnRunner<InT, FnOutT> implements DoFnRunner<InT
private final BagState<WindowedValue<InT>> bundledEventsBag;
private RemoteBundle remoteBundle;
private FnDataReceiver<WindowedValue<?>> inputReceiver;
private StateRequestHandler stateRequestHandler;

private SdkHarnessDoFnRunner(
DoFnRunners.OutputManager outputManager,
StageBundleFactory stageBundleFactory,
TupleTag<FnOutT> mainOutputTag,
Map<String, TupleTag<?>> idToTupleTagMap,
BagState<WindowedValue<InT>> bundledEventsBag) {
BagState<WindowedValue<InT>> bundledEventsBag,
StateRequestHandler stateRequestHandler) {
this.outputManager = outputManager;
this.stageBundleFactory = stageBundleFactory;
this.mainOutputTag = mainOutputTag;
this.idToTupleTagMap = idToTupleTagMap;
this.bundledEventsBag = bundledEventsBag;
this.stateRequestHandler = stateRequestHandler;
}

@Override
Expand All @@ -227,9 +236,7 @@ public FnDataReceiver<FnOutT> create(String pCollectionId) {

remoteBundle =
stageBundleFactory.getBundle(
receiverFactory,
StateRequestHandler.unsupported(),
BundleProgressHandler.ignored());
receiverFactory, stateRequestHandler, BundleProgressHandler.ignored());

// TODO: side input support needs to be implemented to handle this properly
inputReceiver = Iterables.getOnlyElement(remoteBundle.getInputReceivers().values());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.samza.runtime;

import java.io.IOException;
import java.util.EnumMap;
import java.util.Map;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.SideInputHandler;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.runners.fnexecution.state.StateRequestHandlers;
import org.apache.beam.runners.fnexecution.translation.StreamingSideInputHandlerFactory;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;

/**
 * Factory for the {@link StateRequestHandler} instances that the Samza portable runner uses for
 * side inputs and states. Only the side-input state key types are registered here at present.
 */
public class SamzaStateRequestHandlers {

  // TODO: [BEAM-12403] support state handlers
  /**
   * Builds a {@link StateRequestHandler} that delegates by state key type, registering one shared
   * handler for the {@code ITERABLE_SIDE_INPUT}, {@code MULTIMAP_SIDE_INPUT} and {@code
   * MULTIMAP_KEYS_SIDE_INPUT} key types.
   *
   * @param executableStage the stage whose side inputs are served
   * @param sideInputIds mapping from side input id to the corresponding view
   * @param sideInputHandler runner-side handler handed to the side input handler factory
   * @return a handler that dispatches on the state key type
   * @throws RuntimeException if creating the side input state handler fails with an {@link
   *     IOException}
   */
  public static StateRequestHandler of(
      ExecutableStage executableStage,
      Map<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInputIds,
      SideInputHandler sideInputHandler) {
    final StateRequestHandler sideInputDelegate =
        sideInputStateHandler(executableStage, sideInputIds, sideInputHandler);

    final EnumMap<BeamFnApi.StateKey.TypeCase, StateRequestHandler> delegatesByType =
        new EnumMap<>(BeamFnApi.StateKey.TypeCase.class);
    final BeamFnApi.StateKey.TypeCase[] sideInputTypes = {
      BeamFnApi.StateKey.TypeCase.ITERABLE_SIDE_INPUT,
      BeamFnApi.StateKey.TypeCase.MULTIMAP_SIDE_INPUT,
      BeamFnApi.StateKey.TypeCase.MULTIMAP_KEYS_SIDE_INPUT
    };
    // All side-input key types share the same delegate.
    for (BeamFnApi.StateKey.TypeCase sideInputType : sideInputTypes) {
      delegatesByType.put(sideInputType, sideInputDelegate);
    }
    return StateRequestHandlers.delegateBasedUponType(delegatesByType);
  }

  /**
   * Returns the handler used for side-input state requests: {@link
   * StateRequestHandler#unsupported()} when the stage declares no side inputs, otherwise a handler
   * backed by a {@link StreamingSideInputHandlerFactory} created for the stage.
   */
  private static StateRequestHandler sideInputStateHandler(
      ExecutableStage executableStage,
      Map<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInputIds,
      SideInputHandler sideInputHandler) {
    if (executableStage.getSideInputs().isEmpty()) {
      return StateRequestHandler.unsupported();
    }

    final StateRequestHandlers.SideInputHandlerFactory factory =
        Preconditions.checkNotNull(
            StreamingSideInputHandlerFactory.forStage(
                executableStage, sideInputIds, sideInputHandler));
    try {
      return StateRequestHandlers.forSideInputHandlerFactory(
          ProcessBundleDescriptors.getSideInputs(executableStage), factory);
    } catch (IOException e) {
      throw new RuntimeException("Failed to initialize SideInputHandler", e);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,21 @@ public void translatePortable(
PipelineNode.PTransformNode transform,
QueryablePipeline pipeline,
PortableTranslationContext ctx) {
doTranslatePortable(transform, pipeline, ctx);
final String inputId = ctx.getInputId(transform);
final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId);
final MessageStream<OpMessage<KV<K, InputT>>> inputStream = ctx.getMessageStreamById(inputId);
final WindowingStrategy<?, BoundedWindow> windowingStrategy =
ctx.getPortableWindowStrategy(inputId, pipeline.getComponents());
final WindowedValue.WindowedValueCoder<KV<K, InputT>> windowedInputCoder =
ctx.instantiateCoder(inputId, pipeline.getComponents());
final TupleTag<KV<K, OutputT>> outputTag =
new TupleTag<>(Iterables.getOnlyElement(transform.getTransform().getOutputsMap().keySet()));

final MessageStream<OpMessage<KV<K, OutputT>>> outputStream =
doTranslatePortable(
input, inputStream, windowingStrategy, windowedInputCoder, outputTag, ctx);

ctx.registerMessageStream(ctx.getOutputId(transform), outputStream);
}

@Override
Expand All @@ -129,48 +143,41 @@ public Map<String, String> createPortableConfig(
return ConfigBuilder.createRocksDBStoreConfig(options);
}

private static <K, InputT, OutputT> void doTranslatePortable(
PipelineNode.PTransformNode transform,
QueryablePipeline pipeline,
/**
* This method is used to translate both the portable GBK transform and grouping side inputs
* into Samza.
*/
static <K, InputT, OutputT> MessageStream<OpMessage<KV<K, OutputT>>> doTranslatePortable(
RunnerApi.PCollection input,
MessageStream<OpMessage<KV<K, InputT>>> inputStream,
WindowingStrategy<?, BoundedWindow> windowingStrategy,
WindowedValue.WindowedValueCoder<KV<K, InputT>> windowedInputCoder,
TupleTag<KV<K, OutputT>> outputTag,
PortableTranslationContext ctx) {
final MessageStream<OpMessage<KV<K, InputT>>> inputStream =
ctx.getOneInputMessageStream(transform);
final boolean needRepartition = ctx.getSamzaPipelineOptions().getMaxSourceParallelism() > 1;
final WindowingStrategy<?, BoundedWindow> windowingStrategy =
ctx.getPortableWindowStrategy(transform, pipeline);
final Coder<BoundedWindow> windowCoder = windowingStrategy.getWindowFn().windowCoder();

final String inputId = ctx.getInputId(transform);
final WindowedValue.WindowedValueCoder<KV<K, InputT>> windowedInputCoder =
ctx.instantiateCoder(inputId, pipeline.getComponents());
final KvCoder<K, InputT> kvInputCoder = (KvCoder<K, InputT>) windowedInputCoder.getValueCoder();
final Coder<WindowedValue<KV<K, InputT>>> elementCoder =
WindowedValue.FullWindowedValueCoder.of(kvInputCoder, windowCoder);

final TupleTag<KV<K, OutputT>> outputTag =
new TupleTag<>(Iterables.getOnlyElement(transform.getTransform().getOutputsMap().keySet()));

@SuppressWarnings("unchecked")
final SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow> reduceFn =
(SystemReduceFn<K, InputT, ?, OutputT, BoundedWindow>)
SystemReduceFn.buffering(kvInputCoder.getValueCoder());

final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId);
final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input);

final MessageStream<OpMessage<KV<K, OutputT>>> outputStream =
doTranslateGBK(
inputStream,
needRepartition,
reduceFn,
windowingStrategy,
kvInputCoder,
elementCoder,
ctx.getTransformFullName(),
ctx.getTransformId(),
outputTag,
isBounded);
ctx.registerMessageStream(ctx.getOutputId(transform), outputStream);
return doTranslateGBK(
inputStream,
needRepartition,
reduceFn,
windowingStrategy,
kvInputCoder,
elementCoder,
ctx.getTransformFullName(),
ctx.getTransformId(),
outputTag,
isBounded);
}

private static <K, InputT, OutputT> MessageStream<OpMessage<KV<K, OutputT>>> doTranslateGBK(
Expand Down
Loading

0 comments on commit 85b85a5

Please sign in to comment.