diff --git a/spanner/go.mod b/spanner/go.mod index e6a7ae3870e8..832e2d860095 100644 --- a/spanner/go.mod +++ b/spanner/go.mod @@ -31,3 +31,6 @@ require ( golang.org/x/text v0.3.7 // indirect google.golang.org/appengine v1.6.7 // indirect ) + +// To prevent failing builds as proto changes are not available in go-genproto until Public GA +replace google.golang.org/genproto v0.0.0-20221014173430-6e2ab493f96b => github.com/harshachinta/go-genproto v0.0.0-20221020104338-f731337b715d diff --git a/spanner/go.sum b/spanner/go.sum index 245a1473569a..6568691bda7d 100644 --- a/spanner/go.sum +++ b/spanner/go.sum @@ -298,6 +298,8 @@ github.com/googleapis/gax-go/v2 v2.5.1 h1:kBRZU0PSuI7PspsSb/ChWoVResUcwNVIdpB049 github.com/googleapis/gax-go/v2 v2.5.1/go.mod h1:h6B0KMMFNtI2ddbGJn3T3ZbwkeT6yqEF02fYlzkUCyo= github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/harshachinta/go-genproto v0.0.0-20221020104338-f731337b715d h1:kgI+FZ58gWJsD64f1Cck9bcbQjDHlrC/VPUdVkU+VBI= +github.com/harshachinta/go-genproto v0.0.0-20221020104338-f731337b715d/go.mod h1:45EK0dUbEZ2NHjCeAd2LXmyjAgGUGrpGROgjhC3ADck= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= diff --git a/spanner/integration_test.go b/spanner/integration_test.go index dfa7a94c321b..2e594400a194 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -40,6 +40,9 @@ import ( database "cloud.google.com/go/spanner/admin/database/apiv1" instance "cloud.google.com/go/spanner/admin/instance/apiv1" "cloud.google.com/go/spanner/internal" + pb "cloud.google.com/go/spanner/testdata/protos" + "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" "go.opencensus.io/stats/view" "go.opencensus.io/tag" "google.golang.org/api/iterator" @@ -51,6 +54,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + "google.golang.org/protobuf/testing/protocmp" ) const ( @@ -411,8 +415,10 @@ func initIntegrationTests() (cleanup func()) { return func() { // Delete this test instance. - instanceName := fmt.Sprintf("projects/%v/instances/%v", testProjectID, testInstanceID) - deleteInstanceAndBackups(ctx, instanceName) + if testInstanceID != "go-int-test-proto-column" { + instanceName := fmt.Sprintf("projects/%v/instances/%v", testProjectID, testInstanceID) + deleteInstanceAndBackups(ctx, instanceName) + } // Delete other test instances that may be lingering around. cleanupInstances() databaseAdmin.Close() @@ -2013,6 +2019,300 @@ func TestIntegration_BasicTypes(t *testing.T) { } } +// Test encoding/decoding non-struct Cloud Spanner Proto or Array of Proto Column types. +func TestIntegration_BasicTypes_ProtoColumns(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + stmts := []string{} + client, _, cleanup := prepareIntegrationTestForProtoColumns(ctx, t, DefaultSessionPoolConfig, stmts) + defer cleanup() + + singerProtoEnum := pb.Genre_ROCK + singerProtoMessage := pb.SingerInfo{ + SingerId: proto.Int64(1), + BirthDate: proto.String("January"), + Nationality: proto.String("Country1"), + Genre: &singerProtoEnum, + } + singer2ProtoEnum := pb.Genre_FOLK + singer2ProtoMessage := pb.SingerInfo{ + SingerId: proto.Int64(1), + BirthDate: proto.String("January"), + Nationality: proto.String("Country1"), + Genre: &singer2ProtoEnum, + } + bytesSingerProtoMessage, _ := proto.Marshal(&singerProtoMessage) + + tests := []struct { + col string + val interface{} + want interface{} + }{ + // Proto Message + {col: "ProtoMessage", val: &singerProtoMessage, want: singerProtoMessage}, + {col: "ProtoMessage", val: &singerProtoMessage, want: NullProtoMessage{&singerProtoMessage, true}}, + {col: "ProtoMessage", val: NullProtoMessage{&singerProtoMessage, true}, want: singerProtoMessage}, + {col: "ProtoMessage", val: NullProtoMessage{&singerProtoMessage, true}, want: NullProtoMessage{&singerProtoMessage, true}}, + {col: "ProtoMessage", val: nil, want: NullProtoMessage{}}, + // Proto Enum + {col: "ProtoEnum", val: pb.Genre_ROCK, want: singerProtoEnum}, + {col: "ProtoEnum", val: pb.Genre_ROCK, want: NullProtoEnum{&singerProtoEnum, true}}, + {col: "ProtoEnum", val: NullProtoEnum{pb.Genre_ROCK, true}, want: singerProtoEnum}, + {col: "ProtoEnum", val: NullProtoEnum{pb.Genre_ROCK, true}, want: NullProtoEnum{&singerProtoEnum, true}}, + {col: "ProtoEnum", val: nil, want: NullProtoEnum{}}, + // Test Compatibility between Int64 and ProtoEnum + {col: "Int64a", val: pb.Genre_ROCK, want: int64(3)}, + {col: "Int64a", val: pb.Genre_ROCK, want: singerProtoEnum}, + {col: "Int64a", val: 3, want: singerProtoEnum}, + {col: "Int64a", val: int64(3)}, + {col: "ProtoEnum", val: int64(3), want: singerProtoEnum}, + {col: "ProtoEnum", val: int64(3), want: int64(3)}, + {col: "ProtoEnum", val: pb.Genre_ROCK, want: int64(3)}, + {col: "ProtoEnum", val: pb.Genre_ROCK, want: singerProtoEnum}, + // Test Compatibility between Bytes and ProtoMessage + {col: "Bytes", val: &singerProtoMessage, want: bytesSingerProtoMessage}, + {col: "Bytes", val: &singerProtoMessage, want: singerProtoMessage}, + {col: "Bytes", val: bytesSingerProtoMessage, want: singerProtoMessage}, + {col: "Bytes", val: bytesSingerProtoMessage}, + {col: "ProtoMessage", val: bytesSingerProtoMessage, want: singerProtoMessage}, + {col: "ProtoMessage", val: bytesSingerProtoMessage, want: bytesSingerProtoMessage}, + {col: "ProtoMessage", val: &singerProtoMessage, want: bytesSingerProtoMessage}, + {col: "ProtoMessage", val: &singerProtoMessage, want: singerProtoMessage}, + // Test Compatibility between NullInt64 and NullProtoEnum + {col: "Int64a", val: NullProtoEnum{pb.Genre_ROCK, true}, want: NullInt64{3, true}}, + {col: "Int64a", val: NullProtoEnum{pb.Genre_ROCK, true}, want: NullProtoEnum{&singerProtoEnum, true}}, + {col: "Int64a", val: NullInt64{3, true}, want: NullProtoEnum{&singerProtoEnum, true}}, + {col: "Int64a", val: NullInt64{3, false}, want: NullProtoEnum{}}, + {col: "ProtoEnum", val: NullInt64{3, true}, want: singerProtoEnum}, + {col: "ProtoEnum", val: NullInt64{3, true}, want: NullProtoEnum{&singerProtoEnum, true}}, + {col: "ProtoEnum", val: NullInt64{3, true}, want: NullInt64{3, true}}, + {col: "ProtoEnum", val: NullProtoEnum{pb.Genre_ROCK, true}, want: NullInt64{3, true}}, + // Test Compatibility between Bytes and NullProtoMessage + {col: "Bytes", val: NullProtoMessage{&singerProtoMessage, true}, want: bytesSingerProtoMessage}, + {col: "Bytes", val: NullProtoMessage{&singerProtoMessage, true}, want: NullProtoMessage{&singerProtoMessage, true}}, + {col: "Bytes", val: bytesSingerProtoMessage, want: NullProtoMessage{&singerProtoMessage, true}}, + {col: "Bytes", val: bytesSingerProtoMessage}, + {col: "ProtoMessage", val: bytesSingerProtoMessage, want: NullProtoMessage{&singerProtoMessage, true}}, + {col: "ProtoMessage", val: []byte(nil), want: NullProtoMessage{}}, + {col: "ProtoMessage", val: NullProtoMessage{&singerProtoMessage, true}, want: bytesSingerProtoMessage}, + {col: "ProtoMessage", val: NullProtoMessage{&singerProtoMessage, true}, want: NullProtoMessage{&singerProtoMessage, true}}, + // Array of Proto Messages : Tests insert and read operations on ARRAY type column + {col: "ProtoMessageArray", val: []*pb.SingerInfo{&singerProtoMessage, &singer2ProtoMessage}}, + {col: "ProtoMessageArray", val: []*pb.SingerInfo(nil)}, + {col: "ProtoMessageArray", val: []*pb.SingerInfo{}}, + // Array of Proto Enum : Tests insert and read operations on ARRAY type column + // 1. Insert and Read data using Enum value array (ex: [enum1, enum2]) + {col: "ProtoEnumArray", val: []pb.Genre{pb.Genre_ROCK, pb.Genre_FOLK}, want: []pb.Genre{singerProtoEnum, singer2ProtoEnum}}, + {col: "ProtoEnumArray", val: []pb.Genre{}}, + {col: "ProtoEnumArray", val: []pb.Genre(nil)}, + // 2. Insert and Read data using Enum pointer array (ex: [*enum1, *enum2]) + {col: "ProtoEnumArray", val: []*pb.Genre{&singerProtoEnum, &singer2ProtoEnum}}, + {col: "ProtoEnumArray", val: []*pb.Genre{}}, + {col: "ProtoEnumArray", val: []*pb.Genre(nil)}, + // 3. Insert data using Enum value array (ex: [enum1, enum2]), Read data using Enum pointer array (ex: [*enum1, *enum2]) + {col: "ProtoEnumArray", val: []pb.Genre{pb.Genre_ROCK, pb.Genre_FOLK}, want: []*pb.Genre{&singerProtoEnum, &singer2ProtoEnum}}, + {col: "ProtoEnumArray", val: []pb.Genre{}, want: []*pb.Genre{}}, + {col: "ProtoEnumArray", val: []pb.Genre(nil), want: []*pb.Genre(nil)}, + // 4. Insert data using Enum pointer array (ex: [*enum1, *enum2]), Read data using Enum value array (ex: [enum1, enum2]) + {col: "ProtoEnumArray", val: []*pb.Genre{&singerProtoEnum, &singer2ProtoEnum}, want: []pb.Genre{singerProtoEnum, singer2ProtoEnum}}, + {col: "ProtoEnumArray", val: []*pb.Genre{}, want: []pb.Genre{}}, + {col: "ProtoEnumArray", val: []*pb.Genre(nil), want: []pb.Genre(nil)}, + // Tests Compatibility between Array of Int64 and Array of ProtoEnum type + {col: "ProtoEnumArray", val: []int64{3, 2}, want: []pb.Genre{singerProtoEnum, singer2ProtoEnum}}, + {col: "ProtoEnumArray", val: []int64{3, 2}, want: []*pb.Genre{&singerProtoEnum, &singer2ProtoEnum}}, + {col: "ProtoEnumArray", val: []int64(nil), want: []pb.Genre(nil)}, + {col: "ProtoEnumArray", val: []int64{}, want: []pb.Genre{}}, + {col: "ProtoEnumArray", val: []pb.Genre{singerProtoEnum, singer2ProtoEnum}, want: []int64{3, 2}}, + {col: "ProtoEnumArray", val: []pb.Genre{}, want: []int64{}}, + {col: "ProtoEnumArray", val: []*pb.Genre{&singerProtoEnum, &singer2ProtoEnum}, want: []int64{3, 2}}, + {col: "ProtoEnumArray", val: []*pb.Genre(nil), want: []int64(nil)}, + {col: "ProtoEnumArray", val: []*pb.Genre{}, want: []int64{}}, + {col: "ProtoEnumArray", val: []*pb.Genre{&singerProtoEnum, &singer2ProtoEnum, nil}, want: []NullInt64{{3, true}, {2, true}, {}}}, + {col: "Int64Array", val: []int64{3, 2}, want: []pb.Genre{singerProtoEnum, singer2ProtoEnum}}, + {col: "Int64Array", val: []int64{3, 2}, want: []*pb.Genre{&singerProtoEnum, &singer2ProtoEnum}}, + {col: "Int64Array", val: []pb.Genre{singerProtoEnum, singer2ProtoEnum}, want: []int64{3, 2}}, + {col: "Int64Array", val: []*pb.Genre{&singerProtoEnum, &singer2ProtoEnum}, want: []int64{3, 2}}, + {col: "Int64Array", val: []pb.Genre(nil), want: []int64(nil)}, + // Tests Compatibility between Array of Bytes and Array of ProtoMessages type + {col: "ProtoMessageArray", val: []*pb.SingerInfo{&singerProtoMessage}, want: [][]byte{bytesSingerProtoMessage}}, + {col: "ProtoMessageArray", val: [][]byte{bytesSingerProtoMessage}}, + {col: "ProtoMessageArray", val: [][]byte{bytesSingerProtoMessage}, want: []*pb.SingerInfo{&singerProtoMessage}}, + {col: "ProtoMessageArray", val: [][]byte(nil), want: []*pb.SingerInfo(nil)}, + {col: "ProtoMessageArray", val: []*pb.SingerInfo(nil), want: [][]byte(nil)}, + {col: "BytesArray", val: []*pb.SingerInfo{&singerProtoMessage}}, + {col: "BytesArray", val: []*pb.SingerInfo{&singerProtoMessage}, want: [][]byte{bytesSingerProtoMessage}}, + {col: "BytesArray", val: [][]byte{bytesSingerProtoMessage}, want: []*pb.SingerInfo{&singerProtoMessage}}, + {col: "BytesArray", val: [][]byte(nil), want: []*pb.SingerInfo(nil)}, + {col: "BytesArray", val: []*pb.SingerInfo(nil), want: [][]byte(nil)}, + } + + // Write rows into table first using DML. + statements := make([]Statement, 0) + for i, test := range tests { + stmt := NewStatement(fmt.Sprintf("INSERT INTO Types (RowId, `%s`) VALUES (@id, @value)", test.col)) + // Note: We are not setting the parameter type here to ensure that it + // can be automatically recognized when it is actually needed. + stmt.Params["id"] = i + stmt.Params["value"] = test.val + statements = append(statements, stmt) + } + + // Delete all the rows so we can insert them using mutations as well. + _, err := client.Apply(ctx, []*Mutation{Delete("Types", AllKeys())}) + if err != nil { + t.Fatalf("failed to delete all rows: %v", err) + } + + // Verify that we can insert the rows using mutations. + var muts []*Mutation + for i, test := range tests { + muts = append(muts, InsertOrUpdate("Types", []string{"RowID", test.col}, []interface{}{i, test.val})) + } + if _, err := client.Apply(ctx, muts, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + + for i, test := range tests { + row, err := client.Single().ReadRow(ctx, "Types", []interface{}{i}, []string{test.col}) + if err != nil { + t.Fatalf("Unable to fetch row %v: %v", i, err) + } + verifyDirectPathRemoteAddress(t) + // Create new instance of type of test.want. + want := test.want + if want == nil { + want = test.val + } + gotp := reflect.New(reflect.TypeOf(want)) + v := gotp.Interface() + + switch nullValue := v.(type) { + case *NullProtoMessage: + nullValue.ProtoMessageVal = &pb.SingerInfo{} + case *NullProtoEnum: + var singerProtoEnumDefault pb.Genre + nullValue.ProtoEnumVal = &singerProtoEnumDefault + default: + } + + if err := row.Column(0, v); err != nil { + t.Errorf("%d: col:%v val:%#v, %v", i, test.col, test.val, err) + continue + } + got := reflect.Indirect(gotp).Interface() + + // One of the test cases is checking NaN handling. Given + // NaN!=NaN, we can't use reflect to test for it. + if isNaN(got) && isNaN(want) { + continue + } + + switch v.(type) { + case proto.Message: + if diff := cmp.Diff(got, test.want, protocmp.Transform()); diff != "" { + t.Errorf("unexpected difference:\n%v", diff) + continue + } + default: + if !testEqual(got, want) { + t.Errorf("%d: col:%v val:%#v, got %#v, want %#v", i, test.col, test.val, got, want) + continue + } + } + } +} + +// Test errors during decoding non-struct Cloud Spanner Proto or Array of Proto Column types. +func TestIntegration_BasicTypes_ProtoColumns_Errors(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + stmts := []string{} + client, _, cleanup := prepareIntegrationTestForProtoColumns(ctx, t, DefaultSessionPoolConfig, stmts) + defer cleanup() + + singerProtoEnum := pb.Genre_ROCK + singerProtoMessage := pb.SingerInfo{ + SingerId: proto.Int64(1), + BirthDate: proto.String("January"), + Nationality: proto.String("Country1"), + Genre: &singerProtoEnum, + } + protoMessageNilError := "*protos.SingerInfo cannot support NULL SQL values" + protoEnumNilError := "*protos.Genre cannot support NULL SQL values" + + tests := []struct { + col string + val interface{} + want interface{} + wantCode codes.Code + errString string + }{ + // Array of Proto Messages : Tests read operation on ARRAY type column that has untyped/typed nil elements in array + {col: "ProtoMessageArray", val: []*pb.SingerInfo{&singerProtoMessage, nil}, wantCode: codes.InvalidArgument, errString: protoMessageNilError}, + {col: "ProtoMessageArray", val: []*pb.SingerInfo{nil, nil}, wantCode: codes.InvalidArgument, errString: protoMessageNilError}, + {col: "ProtoMessageArray", val: []*pb.SingerInfo{&singerProtoMessage, (*pb.SingerInfo)(nil)}, wantCode: codes.InvalidArgument, errString: protoMessageNilError}, + {col: "ProtoMessageArray", val: []*pb.SingerInfo{(*pb.SingerInfo)(nil), nil}, wantCode: codes.InvalidArgument, errString: protoMessageNilError}, + {col: "ProtoMessageArray", val: []*pb.SingerInfo{(*pb.SingerInfo)(nil), (*pb.SingerInfo)(nil)}, wantCode: codes.InvalidArgument, errString: protoMessageNilError}, + // Array of Proto Enum : Tests read operations on ARRAY type column that has untyped/typed nil elements in array + {col: "ProtoEnumArray", val: []*pb.Genre{&singerProtoEnum, nil}, wantCode: codes.InvalidArgument, errString: protoEnumNilError}, + {col: "ProtoEnumArray", val: []*pb.Genre{nil, nil}, wantCode: codes.InvalidArgument, errString: protoEnumNilError}, + {col: "ProtoEnumArray", val: []*pb.Genre{nil, (*pb.Genre)(nil)}, wantCode: codes.InvalidArgument, errString: protoEnumNilError}, + {col: "ProtoEnumArray", val: []*pb.Genre{&singerProtoEnum, (*pb.Genre)(nil)}, wantCode: codes.InvalidArgument, errString: protoEnumNilError}, + {col: "ProtoEnumArray", val: []*pb.Genre{(*pb.Genre)(nil), (*pb.Genre)(nil)}, wantCode: codes.InvalidArgument, errString: protoEnumNilError}, + } + + // Delete all the rows. + _, err := client.Apply(ctx, []*Mutation{Delete("Types", AllKeys())}) + if err != nil { + t.Fatalf("failed to delete all rows: %v", err) + } + + // Insert the rows using mutations. + var muts []*Mutation + for i, test := range tests { + muts = append(muts, InsertOrUpdate("Types", []string{"RowID", test.col}, []interface{}{i, test.val})) + } + if _, err := client.Apply(ctx, muts, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + + for i, test := range tests { + row, err := client.Single().ReadRow(ctx, "Types", []interface{}{i}, []string{test.col}) + if err != nil { + t.Fatalf("Unable to fetch row %v: %v", i, err) + } + verifyDirectPathRemoteAddress(t) + // Create new instance of type of test.want. + want := test.want + if want == nil { + want = test.val + } + gotp := reflect.New(reflect.TypeOf(want)) + v := gotp.Interface() + + switch nullValue := v.(type) { + case *NullProtoMessage: + nullValue.ProtoMessageVal = &pb.SingerInfo{} + case *NullProtoEnum: + var singerProtoEnumDefault pb.Genre + nullValue.ProtoEnumVal = &singerProtoEnumDefault + default: + } + err = row.Column(0, v) + if err == nil { + t.Errorf("Column: %s, expected error %v during decoding, got %v", test.col, test.errString, err) + continue + } + if msg, ok := matchError(err, test.wantCode, test.errString); !ok { + t.Fatal(msg) + } + } +} + // Test decoding Cloud Spanner STRUCT type. func TestIntegration_StructTypes(t *testing.T) { t.Parallel() @@ -3668,6 +3968,10 @@ func prepareIntegrationTestForPG(ctx context.Context, t *testing.T, spc SessionP return prepareDBAndClient(ctx, t, spc, statements, adminpb.DatabaseDialect_POSTGRESQL) } +func prepareIntegrationTestForProtoColumns(ctx context.Context, t *testing.T, spc SessionPoolConfig, statements []string) (*Client, string, func()) { + return prepareDBAndClientForProtoColumns(ctx, t, spc, statements, testDialect) +} + func prepareDBAndClient(ctx context.Context, t *testing.T, spc SessionPoolConfig, statements []string, dbDialect adminpb.DatabaseDialect) (*Client, string, func()) { if databaseAdmin == nil { t.Skip("Integration tests skipped") @@ -3714,6 +4018,26 @@ func prepareDBAndClient(ctx context.Context, t *testing.T, spc SessionPoolConfig } } +func prepareDBAndClientForProtoColumns(ctx context.Context, t *testing.T, spc SessionPoolConfig, statements []string, dbDialect adminpb.DatabaseDialect) (*Client, string, func()) { + if databaseAdmin == nil { + t.Skip("Integration tests skipped") + } + // Construct a unique test DB name. + dbName := dbNameSpace.New() + // TODO: Remove this + testInstanceID = "go-int-test-proto-column" + dbName = "go_int_test_proto_column_db" + + dbPath := fmt.Sprintf("projects/%v/instances/%v/databases/%v", testProjectID, testInstanceID, dbName) + client, err := createClient(ctx, dbPath, spc) + if err != nil { + t.Fatalf("cannot create data client on DB %v: %v", dbPath, err) + } + return client, dbPath, func() { + client.Close() + } +} + func cleanupInstances() { if instanceAdmin == nil { // Integration tests skipped. diff --git a/spanner/protoutils.go b/spanner/protoutils.go index ebe03c1b0cf7..cbe40ad45941 100644 --- a/spanner/protoutils.go +++ b/spanner/protoutils.go @@ -23,8 +23,10 @@ import ( "time" "cloud.google.com/go/civil" + "github.com/golang/protobuf/proto" proto3 "github.com/golang/protobuf/ptypes/struct" sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/protobuf/reflect/protoreflect" ) // Helpers to generate protobuf values and Cloud Spanner types. @@ -128,3 +130,20 @@ func structType(fields ...*sppb.StructType_Field) *sppb.Type { func nullProto() *proto3.Value { return &proto3.Value{Kind: &proto3.Value_NullValue{NullValue: proto3.NullValue_NULL_VALUE}} } + +func protoMessageType(fqn string) *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_PROTO, ProtoTypeFqn: fqn} +} + +func protoEnumType(fqn string) *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_ENUM, ProtoTypeFqn: fqn} +} + +func protoMessageProto(m proto.Message) *proto3.Value { + var b, _ = proto.Marshal(m) + return &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: base64.StdEncoding.EncodeToString(b)}} +} + +func protoEnumProto(e protoreflect.Enum) *proto3.Value { + return &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: strconv.FormatInt(int64(e.Number()), 10)}} +} diff --git a/spanner/testdata/protos/singer.pb.go b/spanner/testdata/protos/singer.pb.go new file mode 100644 index 000000000000..68648e0fb93e --- /dev/null +++ b/spanner/testdata/protos/singer.pb.go @@ -0,0 +1,243 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v3.21.5 +// source: protos/singer.proto + +package protos + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Genre int32 + +const ( + Genre_POP Genre = 0 + Genre_JAZZ Genre = 1 + Genre_FOLK Genre = 2 + Genre_ROCK Genre = 3 +) + +// Enum value maps for Genre. +var ( + Genre_name = map[int32]string{ + 0: "POP", + 1: "JAZZ", + 2: "FOLK", + 3: "ROCK", + } + Genre_value = map[string]int32{ + "POP": 0, + "JAZZ": 1, + "FOLK": 2, + "ROCK": 3, + } +) + +func (x Genre) Enum() *Genre { + p := new(Genre) + *p = x + return p +} + +func (x Genre) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Genre) Descriptor() protoreflect.EnumDescriptor { + return file_protos_singer_proto_enumTypes[0].Descriptor() +} + +func (Genre) Type() protoreflect.EnumType { + return &file_protos_singer_proto_enumTypes[0] +} + +func (x Genre) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Do not use. +func (x *Genre) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = Genre(num) + return nil +} + +// Deprecated: Use Genre.Descriptor instead. +func (Genre) EnumDescriptor() ([]byte, []int) { + return file_protos_singer_proto_rawDescGZIP(), []int{0} +} + +type SingerInfo struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SingerId *int64 `protobuf:"varint,1,opt,name=singer_id,json=singerId" json:"singer_id,omitempty"` + BirthDate *string `protobuf:"bytes,2,opt,name=birth_date,json=birthDate" json:"birth_date,omitempty"` + Nationality *string `protobuf:"bytes,3,opt,name=nationality" json:"nationality,omitempty"` + Genre *Genre `protobuf:"varint,4,opt,name=genre,enum=spanner.examples.music.Genre" json:"genre,omitempty"` +} + +func (x *SingerInfo) Reset() { + *x = SingerInfo{} + if protoimpl.UnsafeEnabled { + mi := &file_protos_singer_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SingerInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SingerInfo) ProtoMessage() {} + +func (x *SingerInfo) ProtoReflect() protoreflect.Message { + mi := &file_protos_singer_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SingerInfo.ProtoReflect.Descriptor instead. +func (*SingerInfo) Descriptor() ([]byte, []int) { + return file_protos_singer_proto_rawDescGZIP(), []int{0} +} + +func (x *SingerInfo) GetSingerId() int64 { + if x != nil && x.SingerId != nil { + return *x.SingerId + } + return 0 +} + +func (x *SingerInfo) GetBirthDate() string { + if x != nil && x.BirthDate != nil { + return *x.BirthDate + } + return "" +} + +func (x *SingerInfo) GetNationality() string { + if x != nil && x.Nationality != nil { + return *x.Nationality + } + return "" +} + +func (x *SingerInfo) GetGenre() Genre { + if x != nil && x.Genre != nil { + return *x.Genre + } + return Genre_POP +} + +var File_protos_singer_proto protoreflect.FileDescriptor + +var file_protos_singer_proto_rawDesc = []byte{ + 0x0a, 0x13, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x2f, 0x73, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x16, 0x73, 0x70, 0x61, 0x6e, 0x6e, 0x65, 0x72, 0x2e, 0x65, + 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2e, 0x6d, 0x75, 0x73, 0x69, 0x63, 0x22, 0x9f, 0x01, + 0x0a, 0x0a, 0x53, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a, 0x09, + 0x73, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x08, 0x73, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x62, 0x69, 0x72, + 0x74, 0x68, 0x5f, 0x64, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x62, + 0x69, 0x72, 0x74, 0x68, 0x44, 0x61, 0x74, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x6e, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x61, 0x6c, 0x69, 0x74, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x6e, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x69, 0x74, 0x79, 0x12, 0x33, 0x0a, 0x05, 0x67, 0x65, + 0x6e, 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x73, 0x70, 0x61, 0x6e, + 0x6e, 0x65, 0x72, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2e, 0x6d, 0x75, 0x73, + 0x69, 0x63, 0x2e, 0x47, 0x65, 0x6e, 0x72, 0x65, 0x52, 0x05, 0x67, 0x65, 0x6e, 0x72, 0x65, 0x2a, + 0x2e, 0x0a, 0x05, 0x47, 0x65, 0x6e, 0x72, 0x65, 0x12, 0x07, 0x0a, 0x03, 0x50, 0x4f, 0x50, 0x10, + 0x00, 0x12, 0x08, 0x0a, 0x04, 0x4a, 0x41, 0x5a, 0x5a, 0x10, 0x01, 0x12, 0x08, 0x0a, 0x04, 0x46, + 0x4f, 0x4c, 0x4b, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x52, 0x4f, 0x43, 0x4b, 0x10, 0x03, 0x42, + 0x0a, 0x5a, 0x08, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, +} + +var ( + file_protos_singer_proto_rawDescOnce sync.Once + file_protos_singer_proto_rawDescData = file_protos_singer_proto_rawDesc +) + +func file_protos_singer_proto_rawDescGZIP() []byte { + file_protos_singer_proto_rawDescOnce.Do(func() { + file_protos_singer_proto_rawDescData = protoimpl.X.CompressGZIP(file_protos_singer_proto_rawDescData) + }) + return file_protos_singer_proto_rawDescData +} + +var file_protos_singer_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_protos_singer_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_protos_singer_proto_goTypes = []interface{}{ + (Genre)(0), // 0: spanner.examples.music.Genre + (*SingerInfo)(nil), // 1: spanner.examples.music.SingerInfo +} +var file_protos_singer_proto_depIdxs = []int32{ + 0, // 0: spanner.examples.music.SingerInfo.genre:type_name -> spanner.examples.music.Genre + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_protos_singer_proto_init() } +func file_protos_singer_proto_init() { + if File_protos_singer_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_protos_singer_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SingerInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_protos_singer_proto_rawDesc, + NumEnums: 1, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_protos_singer_proto_goTypes, + DependencyIndexes: file_protos_singer_proto_depIdxs, + EnumInfos: file_protos_singer_proto_enumTypes, + MessageInfos: file_protos_singer_proto_msgTypes, + }.Build() + File_protos_singer_proto = out.File + file_protos_singer_proto_rawDesc = nil + file_protos_singer_proto_goTypes = nil + file_protos_singer_proto_depIdxs = nil +} diff --git a/spanner/testdata/protos/singer.proto b/spanner/testdata/protos/singer.proto new file mode 100644 index 000000000000..a7720a8a6887 --- /dev/null +++ b/spanner/testdata/protos/singer.proto @@ -0,0 +1,34 @@ +/* +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +syntax = "proto2"; + +package spanner.examples.music; +option go_package = "./protos"; + +message SingerInfo { + optional int64 singer_id = 1; + optional string birth_date = 2; + optional string nationality = 3; + optional Genre genre = 4; +} + +enum Genre { + POP = 0; + JAZZ = 1; + FOLK = 2; + ROCK = 3; +} diff --git a/spanner/value.go b/spanner/value.go index 1927c560e2e5..c5dadc85a1a6 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -36,6 +36,7 @@ import ( proto3 "github.com/golang/protobuf/ptypes/struct" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" + "google.golang.org/protobuf/reflect/protoreflect" ) const ( @@ -113,6 +114,9 @@ var ( commitTimestamp = time.Unix(0, 0).In(time.FixedZone("CommitTimestamp placeholder", 0xDB)) jsonNullBytes = []byte("null") + + protoMsgReflectType = reflect.TypeOf((*proto.Message)(nil)).Elem() + protoEnumReflectType = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem() ) // Encoder is the interface implemented by a custom type that can be encoded to @@ -876,6 +880,102 @@ func (n *PGNumeric) UnmarshalJSON(payload []byte) error { return nil } +// NullProtoMessage represents a Cloud Spanner PROTO that may be NULL. +// To write a NULL value using NullProtoMessage set ProtoMessageVal to typed nil and set Valid to true. +type NullProtoMessage struct { + ProtoMessageVal proto.Message // ProtoMessageVal contains the value when Valid is true, and nil when NULL. + Valid bool // Valid is true if ProtoMessageVal is not NULL. +} + +// IsNull implements NullableValue.IsNull for NullProtoMessage. +func (n NullProtoMessage) IsNull() bool { + return !n.Valid +} + +// String implements Stringer.String for NullProtoMessage. +func (n NullProtoMessage) String() string { + if !n.Valid { + return nullString + } + return n.ProtoMessageVal.String() +} + +// MarshalJSON implements json.Marshaler.MarshalJSON for NullProtoMessage. +func (n NullProtoMessage) MarshalJSON() ([]byte, error) { + if n.Valid { + return json.Marshal(n.ProtoMessageVal) + } + return jsonNullBytes, nil +} + +// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullProtoMessage. +func (n *NullProtoMessage) UnmarshalJSON(payload []byte) error { + if payload == nil { + return fmt.Errorf("payload should not be nil") + } + if bytes.Equal(payload, jsonNullBytes) { + n.ProtoMessageVal = nil + n.Valid = false + return nil + } + err := json.Unmarshal(payload, n.ProtoMessageVal) + if err != nil { + return fmt.Errorf("payload cannot be converted to a proto message: err: %s", err) + } + n.Valid = true + return nil +} + +// NullProtoEnum represents a Cloud Spanner ENUM that may be NULL. +// To write a NULL value using NullProtoEnum set ProtoEnumVal to typed nil and set Valid to true. +type NullProtoEnum struct { + ProtoEnumVal protoreflect.Enum // ProtoEnumVal contains the value when Valid is true, and nil when NULL. + Valid bool // Valid is true if ProtoEnumVal is not NULL. +} + +// IsNull implements NullableValue.IsNull for NullProtoEnum. +func (n NullProtoEnum) IsNull() bool { + return !n.Valid +} + +// String implements Stringer.String for NullProtoEnum. +func (n NullProtoEnum) String() string { + if !n.Valid { + return nullString + } + return fmt.Sprintf("%v", n.ProtoEnumVal) +} + +// MarshalJSON implements json.Marshaler.MarshalJSON for NullProtoEnum. +func (n NullProtoEnum) MarshalJSON() ([]byte, error) { + if n.Valid && n.ProtoEnumVal != nil { + return []byte(fmt.Sprintf("%v", n.ProtoEnumVal.Number())), nil + } + return jsonNullBytes, nil +} + +// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullProtoEnum. +func (n *NullProtoEnum) UnmarshalJSON(payload []byte) error { + if payload == nil { + return fmt.Errorf("payload should not be nil") + } + if bytes.Equal(payload, jsonNullBytes) { + n.ProtoEnumVal = nil + n.Valid = false + return nil + } + if reflect.ValueOf(n.ProtoEnumVal).Kind() != reflect.Ptr { + return errNotAPointerField(n, n.ProtoEnumVal) + } + num, err := strconv.ParseInt(string(payload), 10, 64) + if err != nil { + return fmt.Errorf("payload cannot be converted to Enum: got %v", string(payload)) + } + reflect.ValueOf(n.ProtoEnumVal).Elem().SetInt(num) + n.Valid = true + return nil +} + // NullRow represents a Cloud Spanner STRUCT that may be NULL. // See also the document for Row. // Note that NullRow is not a valid Cloud Spanner column Type. @@ -937,12 +1037,22 @@ func errNilDst(dst interface{}) error { return spannerErrorf(codes.InvalidArgument, "cannot decode into nil type %T", dst) } +// errNilDstField returns error for decoding into nil interface{} of Value field in NullProtoMessage or NullProtoEnum. +func errNilDstField(dst interface{}, field string) error { + return spannerErrorf(codes.InvalidArgument, "field %s in %T cannot be nil", field, dst) +} + // errNilArrElemType returns error for input Cloud Spanner data type being a array but without a // non-nil array element type. func errNilArrElemType(t *sppb.Type) error { return spannerErrorf(codes.FailedPrecondition, "array type %v is with nil array element type", t) } +// errNotValidSrc returns error if Valid field is false for NullProtoMessage and NullProtoEnum +func errNotValidSrc(dst interface{}) error { + return spannerErrorf(codes.InvalidArgument, "field \"Valid\" of %T cannot be set to false when writing data to Cloud Spanner. Use typed nil in %T to write null values to Cloud Spanner", dst, dst) +} + func errUnsupportedEmbeddedStructFields(fname string) error { return spannerErrorf(codes.InvalidArgument, "Embedded field: %s. Embedded and anonymous fields are not allowed "+ "when converting Go structs to Cloud Spanner STRUCT values. To create a STRUCT value with an "+ @@ -960,6 +1070,16 @@ func errBadEncoding(v *proto3.Value, err error) error { return spannerErrorf(codes.FailedPrecondition, "%v wasn't correctly encoded: <%v>", v, err) } +// errNotAPointer returns error for decoding a non pointer type. +func errNotAPointer(dst interface{}) error { + return spannerErrorf(codes.InvalidArgument, "destination %T must be a pointer", dst) +} + +// errNotAPointerField returns error for decoding a non pointer type. +func errNotAPointerField(dst interface{}, dstField interface{}) error { + return spannerErrorf(codes.InvalidArgument, "destination %T in %T must be a pointer", dstField, dst) +} + func parseNullTime(v *proto3.Value, p *NullTime, code sppb.TypeCode, isNull bool) error { if p == nil { return errNilDst(p) @@ -1120,7 +1240,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO if p == nil { return errNilDst(p) } - if code != sppb.TypeCode_BYTES { + if code != sppb.TypeCode_BYTES && code != sppb.TypeCode_PROTO { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -1140,7 +1260,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO if p == nil { return errNilDst(p) } - if acode != sppb.TypeCode_BYTES { + if acode != sppb.TypeCode_BYTES && acode != sppb.TypeCode_PROTO { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -1160,7 +1280,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO if p == nil { return errNilDst(p) } - if code != sppb.TypeCode_INT64 { + if code != sppb.TypeCode_INT64 && code != sppb.TypeCode_ENUM { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -1179,7 +1299,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO if p == nil { return errNilDst(p) } - if code != sppb.TypeCode_INT64 { + if code != sppb.TypeCode_INT64 && code != sppb.TypeCode_ENUM { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -1210,7 +1330,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO if p == nil { return errNilDst(p) } - if acode != sppb.TypeCode_INT64 { + if acode != sppb.TypeCode_INT64 && acode != sppb.TypeCode_ENUM { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -1244,7 +1364,7 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO if p == nil { return errNilDst(p) } - if acode != sppb.TypeCode_INT64 { + if acode != sppb.TypeCode_INT64 && acode != sppb.TypeCode_ENUM { return errTypeMismatch(code, acode, ptr) } if isNull { @@ -1844,6 +1964,94 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO *p = y case *GenericColumnValue: *p = GenericColumnValue{Type: t, Value: v} + case protoreflect.Enum: + if p == nil { + return errNilDst(p) + } + if reflect.ValueOf(p).Kind() != reflect.Ptr { + return errNotAPointer(p) + } + if code != sppb.TypeCode_ENUM && code != sppb.TypeCode_INT64 { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + return errDstNotForNull(ptr) + } + y, err := getIntegerFromStringValue(v) + if err != nil { + return err + } + reflect.ValueOf(p).Elem().SetInt(y) + case *NullProtoEnum: + if p == nil { + return errNilDst(p) + } + if p.ProtoEnumVal == nil { + return errNilDstField(p, "ProtoEnumVal") + } + if reflect.ValueOf(p.ProtoEnumVal).Kind() != reflect.Ptr { + return errNotAPointer(p) + } + if code != sppb.TypeCode_ENUM && code != sppb.TypeCode_INT64 { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = NullProtoEnum{} + break + } + y, err := getIntegerFromStringValue(v) + if err != nil { + return err + } + reflect.ValueOf(p.ProtoEnumVal).Elem().SetInt(y) + p.Valid = true + case proto.Message: + if p == nil { + return errNilDst(p) + } + if reflect.ValueOf(p).Kind() != reflect.Ptr { + return errNotAPointer(p) + } + if code != sppb.TypeCode_PROTO && code != sppb.TypeCode_BYTES { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + return errDstNotForNull(ptr) + } + y, err := getBytesFromStringValue(v) + if err != nil { + return err + } + err = proto.Unmarshal(y, p) + if err != nil { + return err + } + case *NullProtoMessage: + if p == nil { + return errNilDst(p) + } + if p.ProtoMessageVal == nil { + return errNilDstField(p, "ProtoMessageVal") + } + if reflect.ValueOf(p.ProtoMessageVal).Kind() != reflect.Ptr { + return errNotAPointer(p.ProtoMessageVal) + } + if code != sppb.TypeCode_PROTO && code != sppb.TypeCode_BYTES { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = NullProtoMessage{} + break + } + y, err := getBytesFromStringValue(v) + if err != nil { + return err + } + err = proto.Unmarshal(y, p.ProtoMessageVal) + if err != nil { + return err + } + p.Valid = true default: // Check if the pointer is a custom type that implements spanner.Decoder // interface. @@ -1864,6 +2072,43 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO return decodableType.decodeValueToCustomType(v, t, acode, atypeAnnotation, ptr) } + rv := reflect.ValueOf(ptr) + typ := rv.Type() + // Check if the interface{} is a pointer and is of type array of proto columns + if typ.Kind() == reflect.Ptr && isAnArrayOfProtoColumn(ptr) && code == sppb.TypeCode_ARRAY { + if isNull { + break + } + // Get the user-defined type of the proto array + etyp := typ.Elem().Elem() + switch acode { + case sppb.TypeCode_PROTO, sppb.TypeCode_BYTES: + if etyp.Implements(protoMsgReflectType) { + if etyp.Kind() == reflect.Ptr { + x, err := getListValue(v) + if err != nil { + return err + } + return decodeProtoMessagePtrArray(x, t.ArrayElementType, rv) + } else { + return errTypeMismatch(code, acode, ptr) + } + } + case sppb.TypeCode_ENUM, sppb.TypeCode_INT64: + if etyp.Implements(protoEnumReflectType) { + x, err := getListValue(v) + if err != nil { + return err + } + if etyp.Kind() == reflect.Ptr { + return decodeProtoEnumPtrArray(x, t.ArrayElementType, rv) + } else { + return decodeProtoEnumArray(x, t.ArrayElementType, rv) + } + } + } + } + // Check if the proto encoding is for an array of structs. if !(code == sppb.TypeCode_ARRAY && acode == sppb.TypeCode_STRUCT) { return errTypeMismatch(code, acode, ptr) @@ -2485,6 +2730,32 @@ func errSrcVal(v *proto3.Value, want string) error { v, v.GetKind(), want) } +// getIntegerFromStringValue returns the integer value of the string value encoded in proto3.Value v +func getIntegerFromStringValue(v *proto3.Value) (int64, error) { + x, err := getStringValue(v) + if err != nil { + return 0, err + } + y, err := strconv.ParseInt(x, 10, 64) + if err != nil { + return 0, errBadEncoding(v, err) + } + return y, nil +} + +// getBytesFromStringValue returns the bytes value of the string value encoded in proto3.Value v +func getBytesFromStringValue(v *proto3.Value) ([]byte, error) { + x, err := getStringValue(v) + if err != nil { + return nil, err + } + y, err := base64.StdEncoding.DecodeString(x) + if err != nil { + return nil, errBadEncoding(v, err) + } + return y, nil +} + // getStringValue returns the string value encoded in proto3.Value v whose // kind is proto3.Value_StringValue. func getStringValue(v *proto3.Value) (string, error) { @@ -2890,6 +3161,69 @@ func decodeByteArray(pb *proto3.ListValue) ([][]byte, error) { return a, nil } +// decodeProtoMessagePtrArray decodes proto3.ListValue pb into a *proto.Message slice. +// The elements in the array implements proto.Message interface only if the element is a pointer (e.g. *ProtoMessage). +// However, if the element is a value (e.g. ProtoMessage), then it does not implement proto.Message. +// Therefore, decodeProtoMessagePtrArray allows decoding of proto message array if the array element is a pointer only. +func decodeProtoMessagePtrArray(pb *proto3.ListValue, t *sppb.Type, rv reflect.Value) error { + if pb == nil { + return errNilListValue("PROTO") + } + etyp := rv.Type().Elem().Elem().Elem() + a := reflect.MakeSlice(rv.Type().Elem(), len(pb.Values), len(pb.Values)) + for i, v := range pb.Values { + msg := reflect.New(etyp).Interface().(proto.Message) + if err := decodeValue(v, t, msg); err != nil { + return errDecodeArrayElement(i, v, "PROTO", err) + } + a.Index(i).Set(reflect.ValueOf(msg)) + } + rv.Elem().Set(a) + return nil +} + +// decodeProtoEnumPtrArray decodes proto3.ListValue pb into a *protoreflect.Enum slice. +func decodeProtoEnumPtrArray(pb *proto3.ListValue, t *sppb.Type, rv reflect.Value) error { + if pb == nil { + return errNilListValue("ENUM") + } + etyp := rv.Type().Elem().Elem().Elem() + a := reflect.MakeSlice(rv.Type().Elem(), len(pb.Values), len(pb.Values)) + for i, v := range pb.Values { + enum := reflect.New(etyp).Interface().(protoreflect.Enum) + if err := decodeValue(v, t, enum); err != nil { + return errDecodeArrayElement(i, v, "ENUM", err) + } + a.Index(i).Set(reflect.ValueOf(enum)) + } + rv.Elem().Set(a) + return nil +} + +// decodeProtoEnumArray decodes proto3.ListValue pb into a protoreflect.Enum slice. +func decodeProtoEnumArray(pb *proto3.ListValue, t *sppb.Type, rv reflect.Value) error { + if pb == nil { + return errNilListValue("ENUM") + } + a := reflect.MakeSlice(rv.Type().Elem(), len(pb.Values), len(pb.Values)) + // decodeValue method can decode only if ENUM is a pointer type. + // As the ENUM element in the Array is not a pointer type we cannot use decodeValue method + // and hence handle it separately. + for i, v := range pb.Values { + x, err := getStringValue(v) + if err != nil { + return err + } + y, err := strconv.ParseInt(x, 10, 64) + if err != nil { + return errBadEncoding(v, err) + } + a.Index(i).SetInt(y) + } + rv.Elem().Set(a) + return nil +} + // decodeNullTimeArray decodes proto3.ListValue pb into a NullTime slice. func decodeNullTimeArray(pb *proto3.ListValue) ([]NullTime, error) { if pb == nil { @@ -3564,6 +3898,43 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { pt = proto.Clone(v.Type).(*sppb.Type) case []GenericColumnValue: return nil, nil, errEncoderUnsupportedType(v) + case protoreflect.Enum: + if v != nil { + var protoEnumfqn string + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || !rv.IsNil() { + pb.Kind = stringKind(strconv.FormatInt(int64(v.Number()), 10)) + protoEnumfqn = string(v.Descriptor().FullName()) + } else { + defaultType := reflect.Zero(rv.Type().Elem()).Interface().(protoreflect.Enum) + protoEnumfqn = string(defaultType.Descriptor().FullName()) + } + pt = protoEnumType(protoEnumfqn) + } + case NullProtoEnum: + if v.Valid { + return encodeValue(v.ProtoEnumVal) + } else { + return nil, nil, errNotValidSrc(v) + } + case proto.Message: + if v != nil { + if proto.MessageReflect(v).IsValid() { + bytes, err := proto.Marshal(v) + if err != nil { + return nil, nil, err + } + pb.Kind = stringKind(base64.StdEncoding.EncodeToString(bytes)) + } + protoMessagefqn := string(proto.MessageReflect(v).Descriptor().FullName()) + pt = protoMessageType(protoMessagefqn) + } + case NullProtoMessage: + if v.Valid { + return encodeValue(v.ProtoMessageVal) + } else { + return nil, nil, errNotValidSrc(v) + } default: // Check if the value is a custom type that implements spanner.Encoder // interface. @@ -3585,7 +3956,7 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { return encodeValue(converted) } - if !isStructOrArrayOfStructValue(v) { + if !isStructOrArrayOfStructValue(v) && !isAnArrayOfProtoColumn(v) { return nil, nil, errEncoderUnsupportedType(v) } typ := reflect.TypeOf(v) @@ -3598,7 +3969,11 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { // Value is a slice of Go struct values/ptrs. if typ.Kind() == reflect.Slice { - return encodeStructArray(v) + if isAnArrayOfProtoColumn(v) { + return encodeProtoArray(v) + } else { + return encodeStructArray(v) + } } } return pb, pt, nil @@ -3851,6 +4226,42 @@ func encodeStructArray(v interface{}) (*proto3.Value, *sppb.Type, error) { return listProto(values...), listType(elemTyp), nil } +// Encodes a slice of proto messages or enum in v to the spanner Value and Type +// protos. +func encodeProtoArray(v interface{}) (*proto3.Value, *sppb.Type, error) { + pb := nullProto() + var pt *sppb.Type + var err error + sliceval := reflect.ValueOf(v) + etyp := reflect.TypeOf(v).Elem() + + if etyp.Implements(protoMsgReflectType) { + if !sliceval.IsNil() { + pb, err = encodeProtoMessageArray(sliceval.Len(), func(i int) reflect.Value { return sliceval.Index(i) }) + if err != nil { + return nil, nil, err + } + } + defaultInstance := reflect.Zero(etyp).Interface().(proto.Message) + protoMessagefqn := string(proto.MessageReflect(defaultInstance).Descriptor().FullName()) + pt = listType(protoMessageType(protoMessagefqn)) + } else if etyp.Implements(protoEnumReflectType) { + if !sliceval.IsNil() { + pb, err = encodeProtoEnumArray(sliceval.Len(), func(i int) reflect.Value { return sliceval.Index(i) }) + if err != nil { + return nil, nil, err + } + } + if etyp.Kind() == reflect.Ptr { + etyp = etyp.Elem() + } + defaultInstance := reflect.Zero(etyp).Interface().(protoreflect.Enum) + protoEnumfqn := string(defaultInstance.Descriptor().FullName()) + pt = listType(protoEnumType(protoEnumfqn)) + } + return pb, pt, nil +} + func isStructOrArrayOfStructValue(v interface{}) bool { typ := reflect.TypeOf(v) if typ.Kind() == reflect.Slice { @@ -3862,6 +4273,17 @@ func isStructOrArrayOfStructValue(v interface{}) bool { return typ.Kind() == reflect.Struct } +func isAnArrayOfProtoColumn(v interface{}) bool { + typ := reflect.TypeOf(v) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if typ.Kind() == reflect.Slice { + typ = typ.Elem() + } + return typ.Implements(protoMsgReflectType) || typ.Implements(protoEnumReflectType) +} + func isSupportedMutationType(v interface{}) bool { switch v.(type) { case nil, string, *string, NullString, []string, []*string, []NullString, @@ -3872,7 +4294,7 @@ func isSupportedMutationType(v interface{}) bool { time.Time, *time.Time, []time.Time, []*time.Time, NullTime, []NullTime, civil.Date, *civil.Date, []civil.Date, []*civil.Date, NullDate, []NullDate, big.Rat, *big.Rat, []big.Rat, []*big.Rat, NullNumeric, []NullNumeric, - GenericColumnValue: + GenericColumnValue, proto.Message, protoreflect.Enum, NullProtoMessage, NullProtoEnum: return true default: // Check if the custom type implements spanner.Encoder interface. @@ -3880,6 +4302,10 @@ func isSupportedMutationType(v interface{}) bool { return true } + if isAnArrayOfProtoColumn(v) { + return true + } + decodableType := getDecodableSpannerType(v, false) return decodableType != spannerTypeUnknown && decodableType != spannerTypeInvalid } @@ -3916,6 +4342,32 @@ func encodeArray(len int, at func(int) interface{}) (*proto3.Value, error) { return listProto(vs...), nil } +func encodeProtoMessageArray(len int, at func(int) reflect.Value) (*proto3.Value, error) { + vs := make([]*proto3.Value, len) + var err error + for i := 0; i < len; i++ { + v := at(i).Interface().(proto.Message) + vs[i], _, err = encodeValue(v) + if err != nil { + return nil, err + } + } + return listProto(vs...), nil +} + +func encodeProtoEnumArray(len int, at func(int) reflect.Value) (*proto3.Value, error) { + vs := make([]*proto3.Value, len) + var err error + for i := 0; i < len; i++ { + v := at(i).Interface().(protoreflect.Enum) + vs[i], _, err = encodeValue(v) + if err != nil { + return nil, err + } + } + return listProto(vs...), nil +} + func spannerTagParser(t reflect.StructTag) (name string, keep bool, other interface{}, err error) { if s := t.Get("spanner"); s != "" { if s == "-" { diff --git a/spanner/value_test.go b/spanner/value_test.go index 01dd88eb3c0a..db64620c1717 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -27,10 +27,12 @@ import ( "cloud.google.com/go/civil" "cloud.google.com/go/internal/testutil" + pb "cloud.google.com/go/spanner/testdata/protos" "github.com/golang/protobuf/proto" proto3 "github.com/golang/protobuf/ptypes/struct" "github.com/google/go-cmp/cmp" sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/protobuf/testing/protocmp" ) var ( @@ -247,17 +249,37 @@ func TestEncodeValue(t *testing.T) { maxNumValuePtr, _ := (&big.Rat{}).SetString("99999999999999999999999999999.999999999") minNumValuePtr, _ := (&big.Rat{}).SetString("-99999999999999999999999999999.999999999") + singer1ProtoEnum := pb.Genre_ROCK + singer1ProtoMsg := &pb.SingerInfo{ + SingerId: proto.Int64(1), + BirthDate: proto.String("January"), + Nationality: proto.String("Country1"), + Genre: &singer1ProtoEnum, + } + + singer2ProtoEnum := pb.Genre_FOLK + singer2ProtoMsg := &pb.SingerInfo{ + SingerId: proto.Int64(2), + BirthDate: proto.String("February"), + Nationality: proto.String("Country2"), + Genre: &singer2ProtoEnum, + } + protoMessagefqn := "spanner.examples.music.SingerInfo" + protoEnumfqn := "spanner.examples.music.Genre" + var ( - tString = stringType() - tInt = intType() - tBool = boolType() - tFloat = floatType() - tBytes = bytesType() - tTime = timeType() - tDate = dateType() - tNumeric = numericType() - tJSON = jsonType() - tPGNumeric = pgNumericType() + tString = stringType() + tInt = intType() + tBool = boolType() + tFloat = floatType() + tBytes = bytesType() + tTime = timeType() + tDate = dateType() + tNumeric = numericType() + tJSON = jsonType() + tPGNumeric = pgNumericType() + tProtoMessage = protoMessageType(protoMessagefqn) + tProtoEnum = protoEnumType(protoEnumfqn) ) for i, test := range []struct { in interface{} @@ -459,6 +481,27 @@ func TestEncodeValue(t *testing.T) { {CustomPGNumeric{Valid: false}, nullProto(), tPGNumeric, "PG Numeric with a null value"}, {[]CustomPGNumeric(nil), nullProto(), listType(tPGNumeric), "null []PGNumeric"}, {[]CustomPGNumeric{{"123.456", true}, {Valid: false}}, listProto(stringProto("123.456"), nullProto()), listType(tPGNumeric), "[]PGNumeric"}, + // PROTO MESSAGE AND PROTO ENUM + {singer1ProtoMsg, protoMessageProto(singer1ProtoMsg), tProtoMessage, "Proto Message"}, + {singer1ProtoEnum, protoEnumProto(singer1ProtoEnum), tProtoEnum, "Proto Enum"}, + {(*pb.SingerInfo)(nil), nullProto(), tProtoMessage, "Proto Message with nil"}, + {(*pb.Genre)(nil), nullProto(), tProtoEnum, "Proto Enum with nil"}, + {NullProtoMessage{singer1ProtoMsg, true}, protoMessageProto(singer1ProtoMsg), tProtoMessage, "NullProto with value"}, + {NullProtoEnum{singer1ProtoEnum, true}, protoEnumProto(singer1ProtoEnum), tProtoEnum, "NullEnum with value"}, + {NullProtoMessage{(*pb.SingerInfo)(nil), true}, nullProto(), tProtoMessage, "NullProto with value nil"}, + {NullProtoEnum{(*pb.Genre)(nil), true}, nullProto(), tProtoEnum, "NullEnum with value nil"}, + // ARRAY OF PROTO MESSAGES AND PROTO ENUM + {[]*pb.SingerInfo{singer1ProtoMsg, singer2ProtoMsg}, listProto(protoMessageProto(singer1ProtoMsg), protoMessageProto(singer2ProtoMsg)), listType(tProtoMessage), "Array of Proto Message"}, + {[]*pb.SingerInfo{singer1ProtoMsg, singer2ProtoMsg, nil, (*pb.SingerInfo)(nil)}, listProto(protoMessageProto(singer1ProtoMsg), protoMessageProto(singer2ProtoMsg), nullProto(), nullProto()), listType(tProtoMessage), "Array of Proto Message with nil values"}, + {[]*pb.Genre{&singer1ProtoEnum, &singer2ProtoEnum}, listProto(protoEnumProto(singer1ProtoEnum), protoEnumProto(singer2ProtoEnum)), listType(tProtoEnum), "Array of Proto Enum 1"}, + {[]pb.Genre{singer1ProtoEnum, singer2ProtoEnum}, listProto(protoEnumProto(singer1ProtoEnum), protoEnumProto(singer2ProtoEnum)), listType(tProtoEnum), "Array of Proto Enum 2"}, + {[]*pb.Genre{&singer1ProtoEnum, &singer2ProtoEnum, nil, (*pb.Genre)(nil)}, listProto(protoEnumProto(singer1ProtoEnum), protoEnumProto(singer2ProtoEnum), nullProto(), nullProto()), listType(tProtoEnum), "Array of Proto Enum with nil values"}, + {[]*pb.SingerInfo{}, listProto(), listType(tProtoMessage), "Empty Array of Proto Message"}, + {[]*pb.SingerInfo(nil), nullProto(), listType(tProtoMessage), "Nil array of Proto Message"}, + {[]*pb.Genre{}, listProto(), listType(tProtoEnum), "Empty Array of Proto Enum 1"}, + {[]*pb.Genre(nil), nullProto(), listType(tProtoEnum), "Nil Array of Proto Enum 1"}, + {[]pb.Genre{}, listProto(), listType(tProtoEnum), "Empty Array of Proto Enum 2"}, + {[]pb.Genre(nil), nullProto(), listType(tProtoEnum), "Nil Array of Proto Enum 2"}, } { got, gotType, err := encodeValue(test.in) if err != nil { @@ -501,6 +544,9 @@ func TestEncodeInvalidValues(t *testing.T) { // CUSTOM NUMERIC {desc: "custom numeric type with invalid scale component", in: CustomNumeric(*invalidNumPtr1), errMsg: "max scale for a numeric is 9. The requested numeric has more"}, {desc: "custom numeric type with invalid whole component", in: CustomNumeric(*invalidNumPtr2), errMsg: "max precision for the whole component of a numeric is 29. The requested numeric has a whole component with precision 30"}, + // PROTO MESSAGE AND PROTO ENUM + {desc: "Invalid Null Proto", in: NullProtoMessage{}, errMsg: "spanner: code = \"InvalidArgument\", desc = \"field \\\"Valid\\\" of spanner.NullProtoMessage cannot be set to false when writing data to Cloud Spanner. Use typed nil in spanner.NullProtoMessage to write null values to Cloud Spanner\""}, + {desc: "Invalid Null Enum", in: NullProtoEnum{}, errMsg: "spanner: code = \"InvalidArgument\", desc = \"field \\\"Valid\\\" of spanner.NullProtoEnum cannot be set to false when writing data to Cloud Spanner. Use typed nil in spanner.NullProtoEnum to write null values to Cloud Spanner\""}, } { _, _, err := encodeValue(test.in) if err == nil { @@ -1393,6 +1439,23 @@ func TestDecodeValue(t *testing.T) { var dNilPtr *civil.Date d2Value := d2 + singerEnumValue := pb.Genre_ROCK + singerProtoMsg := pb.SingerInfo{ + SingerId: proto.Int64(1), + BirthDate: proto.String("January"), + Nationality: proto.String("Country1"), + Genre: &singerEnumValue, + } + singer2ProtoEnum := pb.Genre_FOLK + singer2ProtoMsg := pb.SingerInfo{ + SingerId: proto.Int64(2), + BirthDate: proto.String("February"), + Nationality: proto.String("Country2"), + Genre: &singer2ProtoEnum, + } + protoMessagefqn := "spanner.examples.music.SingerInfo" + protoEnumfqn := "spanner.examples.music.Genre" + for _, test := range []struct { desc string proto *proto3.Value @@ -1787,6 +1850,23 @@ func TestDecodeValue(t *testing.T) { {desc: "decode NULL array of bool to CustomStructToNull", proto: nullProto(), protoType: listType(boolType()), want: customStructToNull{}}, {desc: "decode NULL array of float to CustomStructToNull", proto: nullProto(), protoType: listType(floatType()), want: customStructToNull{}}, {desc: "decode NULL array of string to CustomStructToNull", proto: nullProto(), protoType: listType(stringType()), want: customStructToNull{}}, + // PROTO MESSAGE AND PROTO ENUM + {desc: "decode PROTO to proto.Message", proto: protoMessageProto(&singerProtoMsg), protoType: protoMessageType(protoMessagefqn), want: singerProtoMsg}, + {desc: "decode ENUM to protoreflect.Enum", proto: protoEnumProto(pb.Genre_ROCK), protoType: protoEnumType(protoEnumfqn), want: singerEnumValue}, + {desc: "decode PROTO to NullProto", proto: protoMessageProto(&singerProtoMsg), protoType: protoMessageType(protoMessagefqn), want: NullProtoMessage{&singerProtoMsg, true}}, + {desc: "decode NULL to NullProto", proto: nullProto(), protoType: protoMessageType(protoMessagefqn), want: NullProtoMessage{}}, + {desc: "decode ENUM to NullEnum", proto: protoEnumProto(pb.Genre_ROCK), protoType: protoEnumType(protoEnumfqn), want: NullProtoEnum{&singerEnumValue, true}}, + {desc: "decode NULL to NullEnum", proto: nullProto(), protoType: protoEnumType(protoEnumfqn), want: NullProtoEnum{}}, + // ARRAY OF PROTO MESSAGES AND PROTO ENUM + {desc: "decode ARRAY> to []*pb.SingerInfo", proto: listProto(protoMessageProto(&singerProtoMsg), protoMessageProto(&singer2ProtoMsg)), protoType: listType(protoMessageType(protoMessagefqn)), want: []*pb.SingerInfo{&singerProtoMsg, &singer2ProtoMsg}}, + {desc: "decode ARRAY> to []*pb.Genre", proto: listProto(protoEnumProto(pb.Genre_ROCK), protoEnumProto(pb.Genre_FOLK)), protoType: listType(protoEnumType(protoEnumfqn)), want: []*pb.Genre{&singerEnumValue, &singer2ProtoEnum}}, + {desc: "decode ARRAY> to []pb.Genre", proto: listProto(protoEnumProto(pb.Genre_ROCK), protoEnumProto(pb.Genre_FOLK)), protoType: listType(protoEnumType(protoEnumfqn)), want: []pb.Genre{singerEnumValue, singer2ProtoEnum}}, + {desc: "decode NULL to []*pb.SingerInfo", proto: nullProto(), protoType: listType(protoMessageType(protoMessagefqn)), want: []*pb.SingerInfo(nil)}, + {desc: "decode NULL to []*pb.Genre", proto: nullProto(), protoType: listType(protoEnumType(protoEnumfqn)), want: []*pb.Genre(nil)}, + {desc: "decode NULL to []pb.Genre", proto: nullProto(), protoType: listType(protoEnumType(protoEnumfqn)), want: []pb.Genre(nil)}, + {desc: "decode empty array to []*pb.SingerInfo", proto: listProto(), protoType: listType(protoMessageType(protoMessagefqn)), want: []*pb.SingerInfo{}}, + {desc: "decode empty array to []*pb.Genre", proto: listProto(), protoType: listType(protoEnumType(protoEnumfqn)), want: []*pb.Genre{}}, + {desc: "decode empty array to []pb.Genre", proto: listProto(), protoType: listType(protoEnumType(protoEnumfqn)), want: []pb.Genre{}}, } { gotp := reflect.New(reflect.TypeOf(test.want)) v := gotp.Interface() @@ -1806,6 +1886,11 @@ func TestDecodeValue(t *testing.T) { nullValue.Time = time.Unix(100, 100) case *NullDate: nullValue.Date = civil.DateOf(time.Unix(100, 200)) + case *NullProtoMessage: + nullValue.ProtoMessageVal = &pb.SingerInfo{} + case *NullProtoEnum: + var singerProtoEnumDefault pb.Genre + nullValue.ProtoEnumVal = &singerProtoEnumDefault default: } err := decodeValue(test.proto, test.protoType, v) @@ -1820,8 +1905,15 @@ func TestDecodeValue(t *testing.T) { continue } got := reflect.Indirect(gotp).Interface() - if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{}, customStructToNull{})) { - t.Errorf("%s: unexpected decoding result - got %v (%T), want %v (%T)", test.desc, got, got, test.want, test.want) + switch v.(type) { + case proto.Message: + if diff := cmp.Diff(got, test.want, protocmp.Transform()); diff != "" { + t.Errorf("unexpected difference in proto message :\n%v", diff) + } + default: + if !testutil.Equal(got, test.want, cmp.AllowUnexported(CustomNumeric{}, CustomTime{}, CustomDate{}, Row{}, big.Rat{}, big.Int{}, customStructToNull{})) { + t.Errorf("%s: unexpected decoding result - got %v (%T), want %v (%T)", test.desc, got, got, test.want, test.want) + } } } } @@ -2625,6 +2717,15 @@ func TestJSONMarshal_NullTypes(t *testing.T) { msg := Message{"Alice", "Hello", 1294706395881547000} jsonStr := `{"Name":"Alice","Body":"Hello","Time":1294706395881547000}` + singerProtoEnum := pb.Genre_ROCK + singerProtoMessage := pb.SingerInfo{ + SingerId: proto.Int64(1), + BirthDate: proto.String("January"), + Nationality: proto.String("Country1"), + Genre: &singerProtoEnum, + } + singerProtoMessageJsonStr := `{"singer_id":1,"birth_date":"January","nationality":"Country1","genre":3}` + type testcase struct { input interface{} expect string @@ -2716,6 +2817,27 @@ func TestJSONMarshal_NullTypes(t *testing.T) { {input: PGNumeric{}, expect: "null"}, }, }, + { + "NullProtoMessage", + []testcase{ + {input: NullProtoMessage{&singerProtoMessage, true}, expect: singerProtoMessageJsonStr}, + {input: &NullProtoMessage{&singerProtoMessage, true}, expect: singerProtoMessageJsonStr}, + {input: &NullProtoMessage{&singerProtoMessage, false}, expect: "null"}, + {input: NullProtoMessage{}, expect: "null"}, + }, + }, + { + "NullProtoEnum", + []testcase{ + {input: NullProtoEnum{singerProtoEnum, true}, expect: "3"}, + {input: NullProtoEnum{&singerProtoEnum, true}, expect: "3"}, + {input: &NullProtoEnum{singerProtoEnum, true}, expect: "3"}, + {input: &NullProtoEnum{&singerProtoEnum, true}, expect: "3"}, + {input: NullProtoEnum{singerProtoEnum, false}, expect: "null"}, + {input: NullProtoEnum{nil, true}, expect: "null"}, + {input: NullProtoEnum{}, expect: "null"}, + }, + }, } { t.Run(test.name, func(t *testing.T) { for _, tc := range test.cases { @@ -2732,6 +2854,14 @@ func TestJSONMarshal_NullTypes(t *testing.T) { // Test converting json strings to nullable types. func TestJSONUnmarshal_NullTypes(t *testing.T) { jsonStr := `{"Body":"Hello","Name":"Alice","Time":1294706395881547000}` + singerProtoEnum := pb.Genre_ROCK + singerProtoMessage := pb.SingerInfo{ + SingerId: proto.Int64(1), + BirthDate: proto.String("January"), + Nationality: proto.String("Country1"), + Genre: &singerProtoEnum, + } + singerProtoMessageJsonStr := `{"singer_id":1,"birth_date":"January","nationality":"Country1","genre":3}` type testcase struct { input []byte @@ -2838,6 +2968,26 @@ func TestJSONUnmarshal_NullTypes(t *testing.T) { {input: []byte(`"123.456`), got: PGNumeric{}, isNull: true, expect: nullString, expectError: true}, }, }, + { + "NullProtoMessage", + []testcase{ + {input: []byte(singerProtoMessageJsonStr), got: NullProtoMessage{&pb.SingerInfo{}, true}, isNull: false, expect: singerProtoMessage.String(), expectError: false}, + {input: []byte("null"), got: NullProtoMessage{&pb.SingerInfo{}, true}, isNull: true, expect: nullString, expectError: false}, + {input: nil, got: NullProtoMessage{&pb.SingerInfo{}, true}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(""), got: NullProtoMessage{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(`{invalid_json_string}`), got: NullProtoMessage{}, isNull: true, expect: nullString, expectError: true}, + }, + }, + { + "NullProtoEnum", + []testcase{ + {input: []byte("3"), got: NullProtoEnum{&singerProtoEnum, true}, isNull: false, expect: singerProtoEnum.String(), expectError: false}, + {input: []byte("null"), got: NullProtoEnum{}, isNull: true, expect: nullString, expectError: false}, + {input: nil, got: NullProtoEnum{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(""), got: NullProtoEnum{}, isNull: true, expect: nullString, expectError: true}, + {input: []byte(`"hello`), got: NullProtoEnum{}, isNull: true, expect: nullString, expectError: true}, + }, + }, } { t.Run(test.name, func(t *testing.T) { for _, tc := range test.cases { @@ -2869,6 +3019,12 @@ func TestJSONUnmarshal_NullTypes(t *testing.T) { case PGNumeric: err := json.Unmarshal(tc.input, &v) expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) + case NullProtoMessage: + err := json.Unmarshal(tc.input, &v) + expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) + case NullProtoEnum: + err := json.Unmarshal(tc.input, &v) + expectUnmarshalNullableTypes(t, err, v, tc.isNull, tc.expect, tc.expectError) default: t.Fatalf("Unknown type: %T", v) }