diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go index edc686b1ff3f..4350ff971172 100644 --- a/sdks/go/pkg/beam/core/graph/fn.go +++ b/sdks/go/pkg/beam/core/graph/fn.go @@ -475,6 +475,21 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) { return nil, addContext(err, fn) } + // Make sure that all state entries have keys. If they don't set them to the struct field name. + if fn.Recv != nil { + v := reflect.Indirect(reflect.ValueOf(fn.Recv)) + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + if f.CanInterface() { + if ps, ok := f.Interface().(state.PipelineState); ok { + if ps.StateKey() == "" { + f.FieldByName("Key").SetString(v.Type().Field(i).Name) + } + } + } + } + } + // Validate ProcessElement has correct number of main inputs (as indicated by // numMainIn), and that main inputs are before side inputs. processFn := fn.methods[processElementName] diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go index 224b93fedef0..c98a390cabf8 100644 --- a/sdks/go/pkg/beam/core/graph/fn_test.go +++ b/sdks/go/pkg/beam/core/graph/fn_test.go @@ -53,8 +53,8 @@ func TestNewDoFn(t *testing.T) { {dfn: &GoodDoFnCoGbk2{}, opt: CoGBKMainInput(3)}, {dfn: &GoodDoFnCoGbk7{}, opt: CoGBKMainInput(8)}, {dfn: &GoodDoFnCoGbk1wSide{}, opt: NumMainInputs(MainKv)}, - {dfn: &GoodStatefulDoFn{State1: state.MakeValueState[int]("state1")}, opt: NumMainInputs(MainKv)}, - {dfn: &GoodStatefulDoFn2{State1: state.MakeBagState[int]("state1")}, opt: NumMainInputs(MainKv)}, + {dfn: &GoodStatefulDoFn{}, opt: NumMainInputs(MainKv)}, + {dfn: &GoodStatefulDoFn2{}, opt: NumMainInputs(MainKv)}, {dfn: &GoodStatefulDoFn3{State1: state.MakeCombiningState[int, int, int]("state1", func(a, b int) int { return a * b })}, opt: NumMainInputs(MainKv)}, diff --git a/sdks/go/pkg/beam/core/state/state.go b/sdks/go/pkg/beam/core/state/state.go index 44b7a193b756..1a208e476933 100644 --- a/sdks/go/pkg/beam/core/state/state.go +++ b/sdks/go/pkg/beam/core/state/state.go @@ -149,10 +149,6 @@ func (s *Value[T]) Clear(p Provider) error { // StateKey returns the key for this pipeline state entry. func (s Value[T]) StateKey() string { - if s.Key == "" { - // TODO(#22736) - infer the state from the member variable name during pipeline construction. - panic("Value state exists on struct but has not been initialized with a key.") - } return s.Key } @@ -232,10 +228,6 @@ func (s *Bag[T]) Clear(p Provider) error { // StateKey returns the key for this pipeline state entry. func (s Bag[T]) StateKey() string { - if s.Key == "" { - // TODO(#22736) - infer the state from the member variable name during pipeline construction. - panic("Value state exists on struct but has not been initialized with a key.") - } return s.Key } @@ -381,10 +373,6 @@ func (s *Combining[T1, T2, T3]) readAccumulator(p Provider) (interface{}, bool, // StateKey returns the key for this pipeline state entry. func (s Combining[T1, T2, T3]) StateKey() string { - if s.Key == "" { - // TODO(#22736) - infer the state from the member variable name during pipeline construction. - panic("Value state exists on struct but has not been initialized with a key.") - } return s.Key } @@ -515,10 +503,6 @@ func (s *Map[K, V]) Get(p Provider, key K) (V, bool, error) { // StateKey returns the key for this pipeline state entry. func (s Map[K, V]) StateKey() string { - if s.Key == "" { - // TODO(#22736) - infer the state from the member variable name during pipeline construction. - panic("Value state exists on struct but has not been initialized with a key.") - } return s.Key } @@ -638,10 +622,6 @@ func (s *Set[K]) Contains(p Provider, key K) (bool, error) { // StateKey returns the key for this pipeline state entry. func (s Set[K]) StateKey() string { - if s.Key == "" { - // TODO(#22736) - infer the state from the member variable name during pipeline construction. - panic("Value state exists on struct but has not been initialized with a key.") - } return s.Key } diff --git a/sdks/go/test/integration/primitives/state.go b/sdks/go/test/integration/primitives/state.go index ce79be03758a..e422be7eea1e 100644 --- a/sdks/go/test/integration/primitives/state.go +++ b/sdks/go/test/integration/primitives/state.go @@ -83,7 +83,7 @@ func ValueStateParDo() *beam.Pipeline { keyed := beam.ParDo(s, func(w string, emit func(string, int)) { emit(w, 1) }, in) - counts := beam.ParDo(s, &valueStateFn{State1: state.MakeValueState[int]("key1"), State2: state.MakeValueState[string]("key2")}, keyed) + counts := beam.ParDo(s, &valueStateFn{}, keyed) passert.Equals(s, counts, "apple: 1, I", "pear: 1, I", "peach: 1, I", "apple: 2, II", "apple: 3, III", "pear: 2, II") return p @@ -184,7 +184,7 @@ func BagStateParDo() *beam.Pipeline { keyed := beam.ParDo(s, func(w string, emit func(string, int)) { emit(w, 1) }, in) - counts := beam.ParDo(s, &bagStateFn{State1: state.MakeBagState[int]("key1"), State2: state.MakeBagState[string]("key2")}, keyed) + counts := beam.ParDo(s, &bagStateFn{}, keyed) passert.Equals(s, counts, "apple: 0, ", "pear: 0, ", "peach: 0, ", "apple: 1, I", "apple: 2, I,I", "pear: 1, I") return p