Skip to content

Commit

Permalink
[BEAM-13930] Address StateSpec consistency issue between Runner and F…
Browse files Browse the repository at this point in the history
…n API. (#16836)

The ability to mix and match runners and SDKs is accomplished through two portability layers:
1. The Runner API provides an SDK-and-runner-independent definition of a Beam pipeline
2. The Fn API allows a runner to invoke SDK-specific user-defined functions

Apache Beam pipelines support executing stateful DoFns[1]. To support this execution, the Runner API defines multiple user state specifications:
* ReadModifyWriteStateSpec
* BagStateSpec
* OrderedListStateSpec
* CombiningStateSpec
* MapStateSpec
* SetStateSpec

The Fn API[2] defines APIs[3] to get, append and clear user state, currently supporting the BagUserState and MultimapUserState protocols.

Since there is no clear mapping between the Runner API and Fn API state specifications, a runner has no way of knowing whether it supports the Fn API protocol needed to execute a given pipeline. The runner would also have to manage additional runtime metadata recording which protocol was used for each type of state so that it can successfully manage the state’s lifetime and clean it up once it can be garbage collected.

Please see the doc[4] for further details and a proposal on how to address this shortcoming.

1: https://beam.apache.org/blog/stateful-processing/
2: https://github.com/apache/beam/blob/3ad05523f4cdf5122fc319276fcb461f768af39d/model/fn-execution/src/main/proto/beam_fn_api.proto#L742
3: https://s.apache.org/beam-fn-state-api-and-bundle-processing
4: http://doc/1ELKTuRTV3C5jt_YoBBwPdsPa5eoXCCOSKQ3GPzZrK7Q
  • Loading branch information
lukecwik authored Feb 14, 2022
1 parent d62529b commit c558e85
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 10 deletions.
29 changes: 29 additions & 0 deletions model/pipeline/src/main/proto/beam_runner_api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,24 @@ message StandardSideInputTypes {
}
}

// Enumerates the standard user state types. Each value carries a URN (via the
// beam_urn option) that names the state protocol to be used for StateRequests
// made against a user state of that type.
message StandardUserStateTypes {
enum Enum {
// Represents a user state specification that supports a bag.
//
// StateRequests performed on this user state must use
// StateKey.BagUserState.
BAG = 0 [(beam_urn) = "beam:user_state:bag:v1"];

// Represents a user state specification that supports a multimap.
//
// StateRequests performed on this user state must use
// StateKey.MultimapKeysUserState or StateKey.MultimapUserState.
MULTIMAP = 1 [(beam_urn) = "beam:user_state:multimap:v1"];

// TODO(BEAM-10650): Add protocol to support OrderedListState. Until then
// there is no standard URN in this enum for ordered list state.
}
}

// A PCollection!
message PCollection {

Expand Down Expand Up @@ -534,6 +552,7 @@ message ParDoPayload {
}

message StateSpec {
// TODO(BEAM-13930): Deprecate and remove these state specs
oneof spec {
ReadModifyWriteStateSpec read_modify_write_spec = 1;
BagStateSpec bag_spec = 2;
Expand All @@ -542,6 +561,16 @@ message StateSpec {
SetStateSpec set_spec = 5;
OrderedListStateSpec ordered_list_spec = 6;
}

// (Required) URN of the protocol required by this state specification to present
// the desired SDK-specific interface to a UDF.
//
// This protocol defines the SDK harness <-> Runner Harness RPC
// interface for accessing and mutating user state.
//
// See StandardUserStateTypes for an enumeration of all user state types
// defined.
FunctionSpec protocol = 7;
}

message ReadModifyWriteStateSpec {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
import org.apache.beam.model.pipeline.v1.RunnerApi.StandardRequirements;
import org.apache.beam.model.pipeline.v1.RunnerApi.StandardUserStateTypes;
import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
import org.apache.beam.runners.core.construction.PTransformTranslation.TransformTranslator;
import org.apache.beam.sdk.Pipeline;
Expand Down Expand Up @@ -122,6 +123,11 @@ public class ParDoTranslation {
public static final String REQUIRES_ON_WINDOW_EXPIRATION_URN =
"beam:requirement:pardo:on_window_expiration:v1";

/** Represents a user state specification that supports a bag. */
public static final String BAG_USER_STATE = "beam:user_state:bag:v1";
/** Represents a user state specification that supports a multimap. */
public static final String MULTIMAP_USER_STATE = "beam:user_state:multimap:v1";

static {
checkState(
REQUIRES_STATEFUL_PROCESSING_URN.equals(
Expand All @@ -140,6 +146,8 @@ public class ParDoTranslation {
checkState(
REQUIRES_ON_WINDOW_EXPIRATION_URN.equals(
getUrn(StandardRequirements.Enum.REQUIRES_ON_WINDOW_EXPIRATION)));
checkState(BAG_USER_STATE.equals(getUrn(StandardUserStateTypes.Enum.BAG)));
checkState(MULTIMAP_USER_STATE.equals(getUrn(StandardUserStateTypes.Enum.MULTIMAP)));
}

/** The URN for an unknown Java {@link DoFn}. */
Expand Down Expand Up @@ -571,6 +579,7 @@ public RunnerApi.StateSpec dispatchValue(Coder<?> valueCoder) {
.setReadModifyWriteSpec(
RunnerApi.ReadModifyWriteStateSpec.newBuilder()
.setCoderId(registerCoderOrThrow(components, valueCoder)))
.setProtocol(FunctionSpec.newBuilder().setUrn(BAG_USER_STATE))
.build();
}

Expand All @@ -580,6 +589,7 @@ public RunnerApi.StateSpec dispatchBag(Coder<?> elementCoder) {
.setBagSpec(
RunnerApi.BagStateSpec.newBuilder()
.setElementCoderId(registerCoderOrThrow(components, elementCoder)))
.setProtocol(FunctionSpec.newBuilder().setUrn(BAG_USER_STATE))
.build();
}

Expand All @@ -589,6 +599,8 @@ public RunnerApi.StateSpec dispatchOrderedList(Coder<?> elementCoder) {
.setOrderedListSpec(
RunnerApi.OrderedListStateSpec.newBuilder()
.setElementCoderId(registerCoderOrThrow(components, elementCoder)))
// TODO(BEAM-10650): Update with correct protocol once the protocol is defined and
// the SDK harness uses it.
.build();
}

Expand All @@ -600,6 +612,7 @@ public RunnerApi.StateSpec dispatchCombining(
RunnerApi.CombiningStateSpec.newBuilder()
.setAccumulatorCoderId(registerCoderOrThrow(components, accumCoder))
.setCombineFn(CombineTranslation.toProto(combineFn, components)))
.setProtocol(FunctionSpec.newBuilder().setUrn(BAG_USER_STATE))
.build();
}

Expand All @@ -610,6 +623,7 @@ public RunnerApi.StateSpec dispatchMap(Coder<?> keyCoder, Coder<?> valueCoder) {
RunnerApi.MapStateSpec.newBuilder()
.setKeyCoderId(registerCoderOrThrow(components, keyCoder))
.setValueCoderId(registerCoderOrThrow(components, valueCoder)))
.setProtocol(FunctionSpec.newBuilder().setUrn(MULTIMAP_USER_STATE))
.build();
}

Expand All @@ -619,6 +633,7 @@ public RunnerApi.StateSpec dispatchSet(Coder<?> elementCoder) {
.setSetSpec(
RunnerApi.SetStateSpec.newBuilder()
.setElementCoderId(registerCoderOrThrow(components, elementCoder)))
.setProtocol(FunctionSpec.newBuilder().setUrn(MULTIMAP_USER_STATE))
.build();
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
Expand Down Expand Up @@ -225,16 +227,33 @@ public void toTransformProto() throws Exception {
public static class TestStateAndTimerTranslation {

@Parameters(name = "{index}: {0}")
public static Iterable<StateSpec<?>> stateSpecs() {
return ImmutableList.of(
StateSpecs.value(VarIntCoder.of()),
StateSpecs.bag(VarIntCoder.of()),
StateSpecs.set(VarIntCoder.of()),
StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()));
public static Iterable<Object[]> stateSpecs() {
return Arrays.asList(
new Object[][] {
{
StateSpecs.value(VarIntCoder.of()),
FunctionSpec.newBuilder().setUrn(ParDoTranslation.BAG_USER_STATE).build()
},
{
StateSpecs.bag(VarIntCoder.of()),
FunctionSpec.newBuilder().setUrn(ParDoTranslation.BAG_USER_STATE).build()
},
{
StateSpecs.set(VarIntCoder.of()),
FunctionSpec.newBuilder().setUrn(ParDoTranslation.MULTIMAP_USER_STATE).build()
},
{
StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()),
FunctionSpec.newBuilder().setUrn(ParDoTranslation.MULTIMAP_USER_STATE).build()
}
});
}

@Parameter public StateSpec<?> stateSpec;

@Parameter(1)
public FunctionSpec protocol;

@Test
public void testStateSpecToFromProto() throws Exception {
// Encode
Expand All @@ -243,6 +262,8 @@ public void testStateSpecToFromProto() throws Exception {
RunnerApi.StateSpec stateSpecProto =
ParDoTranslation.translateStateSpec(stateSpec, sdkComponents);

assertEquals(stateSpecProto.getProtocol(), protocol);

// Decode
RehydratedComponents rehydratedComponents =
RehydratedComponents.forComponents(sdkComponents.toComponents());
Expand Down
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/portability/common_urns.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardRequirements
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardResourceHints
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardSideInputTypes
from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardUserStateTypes
from apache_beam.portability.api.external_transforms_pb2_urns import ExpansionMethods
from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfo
from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfoSpecs
Expand All @@ -45,6 +46,7 @@
sdf_components = StandardPTransforms.SplittableParDoComponents
group_into_batches_components = StandardPTransforms.GroupIntoBatchesComponents

user_state = StandardUserStateTypes.Enum
side_inputs = StandardSideInputTypes.Enum
coders = StandardCoders.Enum
constants = BeamConstants.Constants
Expand Down
17 changes: 13 additions & 4 deletions sdks/python/apache_beam/transforms/userstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

from apache_beam.coders import Coder
from apache_beam.coders import coders
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms.timeutil import TimeDomain

Expand Down Expand Up @@ -76,7 +77,9 @@ def to_runner_api(self, context):
# type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
return beam_runner_api_pb2.StateSpec(
read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec(
coder_id=context.coders.get_id(self.coder)))
coder_id=context.coders.get_id(self.coder)),
protocol=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.user_state.BAG.urn))


class BagStateSpec(StateSpec):
Expand All @@ -85,7 +88,9 @@ def to_runner_api(self, context):
# type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
return beam_runner_api_pb2.StateSpec(
bag_spec=beam_runner_api_pb2.BagStateSpec(
element_coder_id=context.coders.get_id(self.coder)))
element_coder_id=context.coders.get_id(self.coder)),
protocol=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.user_state.BAG.urn))


class SetStateSpec(StateSpec):
Expand All @@ -94,7 +99,9 @@ def to_runner_api(self, context):
# type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
return beam_runner_api_pb2.StateSpec(
set_spec=beam_runner_api_pb2.SetStateSpec(
element_coder_id=context.coders.get_id(self.coder)))
element_coder_id=context.coders.get_id(self.coder)),
protocol=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.user_state.BAG.urn))


class CombiningValueStateSpec(StateSpec):
Expand Down Expand Up @@ -141,7 +148,9 @@ def to_runner_api(self, context):
return beam_runner_api_pb2.StateSpec(
combining_spec=beam_runner_api_pb2.CombiningStateSpec(
combine_fn=self.combine_fn.to_runner_api(context),
accumulator_coder_id=context.coders.get_id(self.coder)))
accumulator_coder_id=context.coders.get_id(self.coder)),
protocol=beam_runner_api_pb2.FunctionSpec(
urn=common_urns.user_state.BAG.urn))


# TODO(BEAM-9562): Update Timer to have of() and clear() APIs.
Expand Down
33 changes: 33 additions & 0 deletions sdks/python/apache_beam/transforms/userstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from apache_beam.coders import StrUtf8Coder
from apache_beam.coders import VarIntCoder
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners.common import DoFnSignature
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
Expand Down Expand Up @@ -157,6 +160,36 @@ def test_spec_construction(self):
with self.assertRaises(ValueError):
DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))

def test_state_spec_proto_conversion(self):
  """Checks the protocol recorded on each state spec's Runner API proto.

  Every spec exercised here is expected to translate to a StateSpec proto
  whose ``protocol`` field is a FunctionSpec carrying the bag user state URN.
  """
  expected_protocol = beam_runner_api_pb2.FunctionSpec(
      urn=common_urns.user_state.BAG.urn)
  state_specs = [
      BagStateSpec('statename', VarIntCoder()),
      CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10)),
      SetStateSpec('setstatename', VarIntCoder()),
      ReadModifyWriteStateSpec('valuestatename', VarIntCoder()),
  ]
  for state in state_specs:
    # subTest labels a failure with the spec type that produced it.
    with self.subTest(state_spec=type(state).__name__):
      # A fresh context per spec keeps each translation independent.
      context = pipeline_context.PipelineContext()
      state_proto = state.to_runner_api(context)
      # assertEqual is used instead of the deprecated assertEquals alias,
      # which was removed from unittest in Python 3.12.
      self.assertEqual(expected_protocol, state_proto.protocol)

def test_param_construction(self):
with self.assertRaises(ValueError):
DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))
Expand Down

0 comments on commit c558e85

Please sign in to comment.