From fced4f616be9057a68c29119bec96ed2106271cf Mon Sep 17 00:00:00 2001 From: Sylvain Cleymans Date: Fri, 15 Jan 2021 10:55:50 +1300 Subject: [PATCH] Add support for nullable types This allows to differentiate between an omitted value and a null value in an input struct. --- graphql_test.go | 184 ++++++++++++++++++++++++++++++++- internal/exec/packer/packer.go | 62 +++++++---- nullable_types.go | 150 +++++++++++++++++++++++++++ 3 files changed, 376 insertions(+), 20 deletions(-) create mode 100644 nullable_types.go diff --git a/graphql_test.go b/graphql_test.go index fe8fedd9..a2106358 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -2997,7 +2997,7 @@ func TestInput(t *testing.T) { }) } -type inputArgumentsHello struct {} +type inputArgumentsHello struct{} type inputArgumentsScalarMismatch1 struct{} @@ -3755,3 +3755,185 @@ func TestPointerReturnForNonNull(t *testing.T) { }, }) } + +type nullableInput struct { + String graphql.NullString + Int graphql.NullInt + Bool graphql.NullBool + Time graphql.NullTime + Float graphql.NullFloat +} + +type nullableResult struct { + String string + Int string + Bool string + Time string + Float string +} + +type nullableResolver struct { +} + +func (r *nullableResolver) TestNullables(args struct { + Input *nullableInput +}) nullableResult { + var res nullableResult + if args.Input.String.Set { + if args.Input.String.Value == nil { + res.String = "" + } else { + res.String = *args.Input.String.Value + } + } + + if args.Input.Int.Set { + if args.Input.Int.Value == nil { + res.Int = "" + } else { + res.Int = fmt.Sprintf("%d", *args.Input.Int.Value) + } + } + + if args.Input.Float.Set { + if args.Input.Float.Value == nil { + res.Float = "" + } else { + res.Float = fmt.Sprintf("%.2f", *args.Input.Float.Value) + } + } + + if args.Input.Bool.Set { + if args.Input.Bool.Value == nil { + res.Bool = "" + } else { + res.Bool = fmt.Sprintf("%t", *args.Input.Bool.Value) + } + } + + if args.Input.Time.Set { + if args.Input.Time.Value == nil { + res.Time = "" + } else { + res.Time = args.Input.Time.Value.Format(time.RFC3339) + } + } + + return res +} + +func TestNullable(t *testing.T) { + schema := ` + scalar Time + + input MyInput { + string: String + int: Int + float: Float + bool: Boolean + time: Time + } + + type Result { + string: String! + int: String! + float: String! + bool: String! + time: String! + } + + type Query { + testNullables(input: MyInput): Result! + } + ` + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(schema, &nullableResolver{}, graphql.UseFieldResolvers()), + Query: ` + query { + testNullables(input: { + string: "test" + int: 1234 + float: 42.42 + bool: true + time: "2021-01-02T15:04:05Z" + }) { + string + int + float + bool + time + } + } + `, + ExpectedResult: ` + { + "testNullables": { + "string": "test", + "int": "1234", + "float": "42.42", + "bool": "true", + "time": "2021-01-02T15:04:05Z" + } + } + `, + }, + { + Schema: graphql.MustParseSchema(schema, &nullableResolver{}, graphql.UseFieldResolvers()), + Query: ` + query { + testNullables(input: { + string: null + int: null + float: null + bool: null + time: null + }) { + string + int + float + bool + time + } + } + `, + ExpectedResult: ` + { + "testNullables": { + "string": "", + "int": "", + "float": "", + "bool": "", + "time": "" + } + } + `, + }, + { + Schema: graphql.MustParseSchema(schema, &nullableResolver{}, graphql.UseFieldResolvers()), + Query: ` + query { + testNullables(input: {}) { + string + int + float + bool + time + } + } + `, + ExpectedResult: ` + { + "testNullables": { + "string": "", + "int": "", + "float": "", + "bool": "", + "time": "" + } + } + `, + }, + }) +} diff --git a/internal/exec/packer/packer.go b/internal/exec/packer/packer.go index fca88da3..deadacb8 100644 --- a/internal/exec/packer/packer.go +++ b/internal/exec/packer/packer.go @@ -78,24 +78,37 @@ func (b *Builder) assignPacker(target *packer, schemaType common.Type, reflectTy func (b *Builder) makePacker(schemaType common.Type, reflectType reflect.Type) (packer, error) { t, nonNull := unwrapNonNull(schemaType) if !nonNull { - if reflectType.Kind() != reflect.Ptr { - return nil, fmt.Errorf("%s is not a pointer", reflectType) - } - elemType := reflectType.Elem() - addPtr := true - if _, ok := t.(*schema.InputObject); ok { - elemType = reflectType // keep pointer for input objects - addPtr = false - } - elem, err := b.makeNonNullPacker(t, elemType) - if err != nil { - return nil, err + if reflectType.Kind() == reflect.Ptr { + elemType := reflectType.Elem() + addPtr := true + if _, ok := t.(*schema.InputObject); ok { + elemType = reflectType // keep pointer for input objects + addPtr = false + } + elem, err := b.makeNonNullPacker(t, elemType) + if err != nil { + return nil, err + } + return &nullPacker{ + elemPacker: elem, + valueType: reflectType, + addPtr: addPtr, + }, nil + } else if isNullable(reflectType) { + elemType := reflectType + addPtr := false + elem, err := b.makeNonNullPacker(t, elemType) + if err != nil { + return nil, err + } + return &nullPacker{ + elemPacker: elem, + valueType: reflectType, + addPtr: addPtr, + }, nil + } else { + return nil, fmt.Errorf("%s is not a pointer or a nullable type", reflectType) } - return &nullPacker{ - elemPacker: elem, - valueType: reflectType, - addPtr: addPtr, - }, nil } return b.makeNonNullPacker(t, reflectType) @@ -266,7 +279,7 @@ type nullPacker struct { } func (p *nullPacker) Pack(value interface{}) (reflect.Value, error) { - if value == nil { + if value == nil && !isNullable(p.valueType) { return reflect.Zero(p.valueType), nil } @@ -305,7 +318,7 @@ type unmarshalerPacker struct { } func (p *unmarshalerPacker) Pack(value interface{}) (reflect.Value, error) { - if value == nil { + if value == nil && !isNullable(p.ValueType) { return reflect.Value{}, errors.Errorf("got null for non-null") } @@ -369,3 +382,14 @@ func unwrapNonNull(t common.Type) (common.Type, bool) { func stripUnderscore(s string) string { return strings.Replace(s, "_", "", -1) } + +// NullUnmarshaller is an unmarshaller that can handle a nil input +type NullUnmarshaller interface { + Unmarshaler + Nullable() +} + +func isNullable(t reflect.Type) bool { + _, ok := reflect.New(t).Interface().(NullUnmarshaller) + return ok +} diff --git a/nullable_types.go b/nullable_types.go new file mode 100644 index 00000000..531863ae --- /dev/null +++ b/nullable_types.go @@ -0,0 +1,150 @@ +package graphql + +import ( + "fmt" +) + +// NullString is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullString struct { + Value *string + Set bool +} + +func (NullString) ImplementsGraphQLType(name string) bool { + return name == "String" +} + +func (s *NullString) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + switch v := input.(type) { + case string: + s.Value = &v + return nil + default: + return fmt.Errorf("wrong type for String: %T", v) + } +} + +func (s *NullString) Nullable() {} + +// NullBool is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullBool struct { + Value *bool + Set bool +} + +func (NullBool) ImplementsGraphQLType(name string) bool { + return name == "Boolean" +} + +func (s *NullBool) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + switch v := input.(type) { + case bool: + s.Value = &v + return nil + default: + return fmt.Errorf("wrong type for Boolean: %T", v) + } +} + +func (s *NullBool) Nullable() {} + +// NullInt is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullInt struct { + Value *int32 + Set bool +} + +func (NullInt) ImplementsGraphQLType(name string) bool { + return name == "Int" +} + +func (s *NullInt) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + switch v := input.(type) { + case int32: + s.Value = &v + return nil + default: + return fmt.Errorf("wrong type for Int: %T", v) + } +} + +func (s *NullInt) Nullable() {} + +// NullFloat is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullFloat struct { + Value *float64 + Set bool +} + +func (NullFloat) ImplementsGraphQLType(name string) bool { + return name == "Float" +} + +func (s *NullFloat) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + switch v := input.(type) { + case float64: + s.Value = &v + return nil + default: + return fmt.Errorf("wrong type for Float: %T", v) + } +} + +func (s *NullFloat) Nullable() {} + +// NullTime is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullTime struct { + Value *Time + Set bool +} + +func (NullTime) ImplementsGraphQLType(name string) bool { + return name == "Time" +} + +func (s *NullTime) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + s.Value = new(Time) + return s.Value.UnmarshalGraphQL(input) +} + +func (s *NullTime) Nullable() {}