Skip to content

Commit

Permalink
[BEAM-12513] Schemas and Coders (apache#15632)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostluck authored and dmitriikuzinepam committed Nov 2, 2021
1 parent 7d1e459 commit d12d0a2
Show file tree
Hide file tree
Showing 5 changed files with 535 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sdks/go/examples/snippets/04transforms.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ func filterWordsAbove(word string, lengthCutOffIter func(*float64) bool, emitAbo
var cutOff float64
ok := lengthCutOffIter(&cutOff)
if !ok {
return fmt.Errorf("No length cutoff provided.")
return fmt.Errorf("no length cutoff provided")
}
if float64(len(word)) > cutOff {
emitAboveCutoff(word)
Expand Down
143 changes: 143 additions & 0 deletions sdks/go/examples/snippets/06schemas.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package snippets

import (
"fmt"
"io"
"reflect"
"time"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
)

// [START schema_define]

type Purchase struct {
// ID of the user who made the purchase.
UserID string `beam:"userId"`
// Identifier of the item that was purchased.
ItemID int64 `beam:"itemId"`
// The shipping address, a nested type.
ShippingAddress ShippingAddress `beam:"shippingAddress"`
// The cost of the item in cents.
Cost int64 `beam:"cost"`
// The transactions that paid for this purchase.
// A slice since the purchase might be spread out over multiple
// credit cards.
Transactions []Transaction `beam:"transactions"`
}

type ShippingAddress struct {
StreetAddress string `beam:"streetAddress"`
City string `beam:"city"`
State *string `beam:"state"`
Country string `beam:"country"`
PostCode string `beam:"postCode"`
}

type Transaction struct {
Bank string `beam:"bank"`
PurchaseAmount float64 `beam:"purchaseAmount"`
}

// [END schema_define]

// Validate that the interface is being implemented.
var _ beam.SchemaProvider = &TimestampNanosProvider{}

// [START schema_logical_provider]

// TimestampNanos is a logical type using time.Time, but
// encodes as a schema type.
type TimestampNanos time.Time

func (tn TimestampNanos) Seconds() int64 {
return time.Time(tn).Unix()
}
func (tn TimestampNanos) Nanos() int32 {
return int32(time.Time(tn).UnixNano() % 1000000000)
}

// tnStorage is the storage schema for TimestampNanos.
type tnStorage struct {
Seconds int64 `beam:"seconds"`
Nanos int32 `beam:"nanos"`
}

var (
// reflect.Type of the Value type of TimestampNanos
tnType = reflect.TypeOf((*TimestampNanos)(nil)).Elem()
tnStorageType = reflect.TypeOf((*tnStorage)(nil)).Elem()
)

// TimestampNanosProvider implements the beam.SchemaProvider interface.
type TimestampNanosProvider struct{}

// FromLogicalType converts checks if the given type is TimestampNanos, and if so
// returns the storage type.
func (p *TimestampNanosProvider) FromLogicalType(rt reflect.Type) (reflect.Type, error) {
if rt != tnType {
return nil, fmt.Errorf("unable to provide schema.LogicalType for type %v, want %v", rt, tnType)
}
return tnStorageType, nil
}

// BuildEncoder builds a Beam schema encoder for the TimestampNanos type.
func (p *TimestampNanosProvider) BuildEncoder(rt reflect.Type) (func(interface{}, io.Writer) error, error) {
if _, err := p.FromLogicalType(rt); err != nil {
return nil, err
}
enc, err := coder.RowEncoderForStruct(tnStorageType)
if err != nil {
return nil, err
}
return func(iface interface{}, w io.Writer) error {
v := iface.(TimestampNanos)
return enc(tnStorage{
Seconds: v.Seconds(),
Nanos: v.Nanos(),
}, w)
}, nil
}

// BuildDecoder builds a Beam schema decoder for the TimestampNanos type.
func (p *TimestampNanosProvider) BuildDecoder(rt reflect.Type) (func(io.Reader) (interface{}, error), error) {
if _, err := p.FromLogicalType(rt); err != nil {
return nil, err
}
dec, err := coder.RowDecoderForStruct(tnStorageType)
if err != nil {
return nil, err
}
return func(r io.Reader) (interface{}, error) {
s, err := dec(r)
if err != nil {
return nil, err
}
tn := s.(tnStorage)
return TimestampNanos(time.Unix(tn.Seconds, int64(tn.Nanos))), nil
}, nil
}

// [END schema_logical_provider]

func LogicalTypeExample() {
// [START schema_logical_register]
beam.RegisterSchemaProvider(tnType, &TimestampNanosProvider{})
// [END schema_logical_register]
}
170 changes: 170 additions & 0 deletions sdks/go/examples/snippets/06schemas_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package snippets

import (
"fmt"
"reflect"
"testing"
"time"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder/testutil"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx/schema"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
)

func atomicSchemaField(name string, typ pipepb.AtomicType) *pipepb.Field {
return &pipepb.Field{
Name: name,
Type: &pipepb.FieldType{
TypeInfo: &pipepb.FieldType_AtomicType{
AtomicType: typ,
},
},
}
}

func rowSchemaField(name string, typ *pipepb.Schema) *pipepb.Field {
return &pipepb.Field{
Name: name,
Type: &pipepb.FieldType{
TypeInfo: &pipepb.FieldType_RowType{
RowType: &pipepb.RowType{
Schema: typ,
},
},
},
}
}

func listSchemaField(name string, typ *pipepb.Field) *pipepb.Field {
return &pipepb.Field{
Name: name,
Type: &pipepb.FieldType{
TypeInfo: &pipepb.FieldType_ArrayType{
ArrayType: &pipepb.ArrayType{
ElementType: typ.GetType(),
},
},
},
}
}

func nillable(f *pipepb.Field) *pipepb.Field {
f.Type.Nullable = true
return f
}

func TestSchemaTypes(t *testing.T) {
transactionSchema := &pipepb.Schema{
Fields: []*pipepb.Field{
atomicSchemaField("bank", pipepb.AtomicType_STRING),
atomicSchemaField("purchaseAmount", pipepb.AtomicType_DOUBLE),
},
}
shippingAddressSchema := &pipepb.Schema{
Fields: []*pipepb.Field{
atomicSchemaField("streetAddress", pipepb.AtomicType_STRING),
atomicSchemaField("city", pipepb.AtomicType_STRING),
nillable(atomicSchemaField("state", pipepb.AtomicType_STRING)),
atomicSchemaField("country", pipepb.AtomicType_STRING),
atomicSchemaField("postCode", pipepb.AtomicType_STRING),
},
}

tests := []struct {
rt reflect.Type
st *pipepb.Schema
preReg func(reg *schema.Registry)
}{{
rt: reflect.TypeOf(Transaction{}),
st: transactionSchema,
}, {
rt: reflect.TypeOf(ShippingAddress{}),
st: shippingAddressSchema,
}, {
rt: reflect.TypeOf(Purchase{}),
st: &pipepb.Schema{
Fields: []*pipepb.Field{
atomicSchemaField("userId", pipepb.AtomicType_STRING),
atomicSchemaField("itemId", pipepb.AtomicType_INT64),
rowSchemaField("shippingAddress", shippingAddressSchema),
atomicSchemaField("cost", pipepb.AtomicType_INT64),
listSchemaField("transactions",
rowSchemaField("n/a", transactionSchema)),
},
},
}, {
rt: tnType,
st: &pipepb.Schema{
Fields: []*pipepb.Field{
atomicSchemaField("seconds", pipepb.AtomicType_INT64),
atomicSchemaField("nanos", pipepb.AtomicType_INT32),
},
},
preReg: func(reg *schema.Registry) {
reg.RegisterLogicalType(schema.ToLogicalType(tnType.Name(), tnType, tnStorageType))
},
}}
for _, test := range tests {
t.Run(fmt.Sprintf("%v", test.rt), func(t *testing.T) {
reg := schema.NewRegistry()
if test.preReg != nil {
test.preReg(reg)
}
{
got, err := reg.FromType(test.rt)
if err != nil {
t.Fatalf("error FromType(%v) = %v", test.rt, err)
}
if d := cmp.Diff(test.st, got,
protocmp.Transform(),
protocmp.IgnoreFields(proto.Message(&pipepb.Schema{}), "id", "options"),
); d != "" {
t.Errorf("diff (-want, +got): %v", d)
}
}
})
}
}

func TestSchema_validate(t *testing.T) {
tests := []struct {
rt reflect.Type
p beam.SchemaProvider
logical, storage interface{}
}{
{
rt: tnType,
p: &TimestampNanosProvider{},
logical: TimestampNanos(time.Unix(2300003, 456789)),
storage: tnStorage{},
},
}
for _, test := range tests {
sc := &testutil.SchemaCoder{
CmpOptions: cmp.Options{
cmp.Comparer(func(a, b TimestampNanos) bool {
return a.Seconds() == b.Seconds() && a.Nanos() == b.Nanos()
})},
}
sc.Validate(t, test.rt, test.p.BuildEncoder, test.p.BuildDecoder, test.storage, test.logical)
}
}
20 changes: 19 additions & 1 deletion sdks/go/pkg/beam/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package beam

import (
"fmt"
"io"
"reflect"

Expand Down Expand Up @@ -55,7 +56,24 @@ import (
// is called in a package init() function.
func RegisterSchemaProvider(rt reflect.Type, provider interface{}) {
p := provider.(SchemaProvider)
schema.RegisterLogicalTypeProvider(rt, p.FromLogicalType)
switch rt.Kind() {
case reflect.Interface:
schema.RegisterLogicalTypeProvider(rt, p.FromLogicalType)
case reflect.Ptr:
if rt.Elem().Kind() != reflect.Struct {
panic(fmt.Sprintf("beam.RegisterSchemaProvider: unsupported type kind for schema provider %v is a %v, must be interface, struct or *struct.", rt, rt.Kind()))
}
fallthrough
case reflect.Struct:
st, err := p.FromLogicalType(rt)
if err != nil {
panic(fmt.Sprintf("beam.RegisterSchemaProvider: schema type provider for %v, doesn't support that type", rt))
}
schema.RegisterLogicalType(schema.ToLogicalType(rt.Name(), rt, st))
default:
panic(fmt.Sprintf("beam.RegisterSchemaProvider: unsupported type kind for schema provider %v is a %v, must be interface, struct or *struct.", rt, rt.Kind()))
}

coder.RegisterSchemaProviders(rt, p.BuildEncoder, p.BuildDecoder)
}

Expand Down
Loading

0 comments on commit d12d0a2

Please sign in to comment.