Skip to content

Commit

Permalink
fix: fix several go-runtime JSON encoding issues (#1417)
Browse files Browse the repository at this point in the history
This PR:
* Rips out the existing textMarshaler and jsonMarshaler usage from
encoding.go. We may want to add those back for
#1296 down the road, but we
will need to be thoughtful about how we do that. Removing it for now
keeps the logic much more predictable.
* Moves the (un)marshaling logic for `ftl.Option` out of `option.go` and
into `encoding.go`.
* Special-cases both `time.Time` (the only stdlib type we currently
support) and `ftl.Option`. Also `json.RawMessage` for _just_ encoding to
preserve the existing `omitempty` behavior.
* Fixes some existing issues where the Pointer unmarshaling wasn't
actually working correctly
* [eww] Adds a rather grotesque alternative to `Peek()` in
`isNextTokenNull()` because json Decoder does not support Peek.
* [eww] Makes the ftl.Option struct fields public so that they are
settable by `encoding.go`.

Suggestions welcome for both counts of [eww] above :)

Fixes #1247.
Addresses most of #1262, except
`omitempty` is only working for json.RawMessage for now.
  • Loading branch information
deniseli authored May 7, 2024
1 parent e31e64d commit 053f922
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 137 deletions.
161 changes: 104 additions & 57 deletions go-runtime/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,18 @@ package encoding

import (
"bytes"
"encoding"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"reflect"
"strings"
"time"
"unicode"

"github.com/TBD54566975/ftl/backend/schema/strcase"
)

var (
textMarshaler = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
textUnmarshaler = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
jsonMarshaler = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
)

func Marshal(v any) ([]byte, error) {
w := &bytes.Buffer{}
err := encodeValue(reflect.ValueOf(v), w)
Expand All @@ -31,37 +27,29 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {
w.WriteString("null")
return nil
}

t := v.Type()
switch {
case t.Kind() == reflect.Ptr && t.Elem().Implements(jsonMarshaler):
v = v.Elem()
fallthrough

case t.Implements(jsonMarshaler):
enc := v.Interface().(json.Marshaler) //nolint:forcetypeassert
data, err := enc.MarshalJSON()
// Special-cased types
switch {
case t == reflect.TypeFor[time.Time]():
data, err := json.Marshal(v.Interface().(time.Time))
if err != nil {
return err
}
w.Write(data)
return nil

case t.Kind() == reflect.Ptr && t.Elem().Implements(textMarshaler):
v = v.Elem()
fallthrough

case t.Implements(textMarshaler):
enc := v.Interface().(encoding.TextMarshaler) //nolint:forcetypeassert
data, err := enc.MarshalText()
if err != nil {
return err
}
data, err = json.Marshal(string(data))
case t == reflect.TypeFor[json.RawMessage]():
data, err := json.Marshal(v.Interface().(json.RawMessage))
if err != nil {
return err
}
w.Write(data)
return nil

case isOption(v.Type()):
return encodeOption(v, w)
}

switch v.Kind() {
Expand Down Expand Up @@ -107,6 +95,24 @@ func encodeValue(v reflect.Value, w *bytes.Buffer) error {
}
}

var ftlOptionTypePath = "github.com/TBD54566975/ftl/go-runtime/ftl.Option"

func isOption(t reflect.Type) bool {
return strings.HasPrefix(t.PkgPath()+"."+t.Name(), ftlOptionTypePath)
}

func encodeOption(v reflect.Value, w *bytes.Buffer) error {
if v.NumField() != 2 {
return fmt.Errorf("value cannot have type ftl.Option since it has %d fields rather than 2: %v", v.NumField(), v)
}
optionOk := v.Field(1).Bool()
if !optionOk {
w.WriteString("null")
return nil
}
return encodeValue(v.Field(0), w)
}

func encodeStruct(v reflect.Value, w *bytes.Buffer) error {
w.WriteRune('{')
afterFirst := false
Expand Down Expand Up @@ -213,50 +219,34 @@ func Unmarshal(data []byte, v any) error {

func decodeValue(d *json.Decoder, v reflect.Value) error {
if !v.CanSet() {
return fmt.Errorf("cannot set value")
allBytes, _ := io.ReadAll(d.Buffered())
return fmt.Errorf("cannot set value: %v", string(allBytes))
}

t := v.Type()
switch {
case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(jsonUnmarshaler):
v = v.Addr()
fallthrough

case t.Implements(jsonUnmarshaler):
if v.IsNil() {
v.Set(reflect.New(t.Elem()))
}
o := v.Interface()
return d.Decode(&o)

case v.Kind() != reflect.Ptr && v.CanAddr() && v.Addr().Type().Implements(textUnmarshaler):
v = v.Addr()
fallthrough

case t.Implements(textUnmarshaler):
if v.IsNil() {
v.Set(reflect.New(t.Elem()))
}
dec := v.Interface().(encoding.TextUnmarshaler) //nolint:forcetypeassert
var s string
if err := d.Decode(&s); err != nil {
return err
}
return dec.UnmarshalText([]byte(s))
// Special-case types
switch {
case t == reflect.TypeFor[time.Time]():
return d.Decode(v.Addr().Interface())
case isOption(v.Type()):
return decodeOption(d, v)
}

switch v.Kind() {
case reflect.Struct:
return decodeStruct(d, v)

case reflect.Ptr:
if token, err := d.Token(); err != nil {
return err
} else if token == nil {
return handleIfNextTokenIsNull(d, func(d *json.Decoder) error {
v.Set(reflect.Zero(v.Type()))
return nil
}
return decodeValue(d, v.Elem())
}, func(d *json.Decoder) error {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return decodeValue(d, v.Elem())
})

case reflect.Slice:
if v.Type().Elem().Kind() == reflect.Uint8 {
Expand All @@ -278,6 +268,63 @@ func decodeValue(d *json.Decoder, v reflect.Value) error {
}
}

func handleIfNextTokenIsNull(d *json.Decoder, ifNullFn func(*json.Decoder) error, elseFn func(*json.Decoder) error) error {
isNull, err := isNextTokenNull(d)
if err != nil {
return err
}
if isNull {
err = ifNullFn(d)
if err != nil {
return err
}
// Consume the null token
_, err := d.Token()
if err != nil {
return err
}
return nil
}
return elseFn(d)
}

// isNextTokenNull implements a cheap/dirty version of `Peek()`, which json.Decoder does
// not support.
func isNextTokenNull(d *json.Decoder) (bool, error) {
s, err := io.ReadAll(d.Buffered())
if err != nil {
return false, err
}
if len(s) == 0 {
return false, fmt.Errorf("cannot check emptystring for token \"null\"")
}
if s[0] != ':' {
return false, fmt.Errorf("cannot check emptystring for token \"null\"")
}
i := 1
for len(s) > i && unicode.IsSpace(rune(s[i])) {
i++
}
if len(s) < i+4 {
return false, nil
}
return string(s[i:i+4]) == "null", nil
}

func decodeOption(d *json.Decoder, v reflect.Value) error {
return handleIfNextTokenIsNull(d, func(d *json.Decoder) error {
v.FieldByName("Okay").SetBool(false)
return nil
}, func(d *json.Decoder) error {
err := decodeValue(d, v.FieldByName("Val"))
if err != nil {
return err
}
v.FieldByName("Okay").SetBool(true)
return nil
})
}

func decodeStruct(d *json.Decoder, v reflect.Value) error {
if err := expectDelim(d, '{'); err != nil {
return err
Expand Down
14 changes: 14 additions & 0 deletions go-runtime/encoding/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package encoding_test
import (
"reflect"
"testing"
"time"

"github.com/alecthomas/assert/v2"

Expand Down Expand Up @@ -31,6 +32,8 @@ func TestMarshal(t *testing.T) {
{name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}, expected: `{"slice":["hello","world"]}`},
{name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}, expected: `{"map":{"foo":42}}`},
{name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}, expected: `{"option":42}`},
{name: "OptionNull", input: struct{ Option ftl.Option[int] }{ftl.None[int]()}, expected: `{"option":null}`},
{name: "OptionZero", input: struct{ Option ftl.Option[int] }{ftl.Some(0)}, expected: `{"option":0}`},
{name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}, expected: `{"option":42}`},
{name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}, expected: `{"option":{"fooBar":"foo"}}`},
{name: "Unit", input: ftl.Unit{}, expected: `{}`},
Expand Down Expand Up @@ -69,6 +72,9 @@ func TestUnmarshal(t *testing.T) {
{name: "Slice", input: `{"slice":[1,2,3]}`, expected: struct{ Slice []int }{[]int{1, 2, 3}}},
{name: "SliceOfStrings", input: `{"slice":["hello","world"]}`, expected: struct{ Slice []string }{[]string{"hello", "world"}}},
{name: "Map", input: `{"map":{"foo":42}}`, expected: struct{ Map map[string]int }{map[string]int{"foo": 42}}},
{name: "OptionNull", input: `{"option":null}`, expected: struct{ Option ftl.Option[int] }{ftl.None[int]()}},
{name: "OptionNullWhitespace", input: `{"option": null}`, expected: struct{ Option ftl.Option[int] }{ftl.None[int]()}},
{name: "OptionZero", input: `{"option":0}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(0)}},
{name: "Option", input: `{"option":42}`, expected: struct{ Option ftl.Option[int] }{ftl.Some(42)}},
{name: "OptionPtr", input: `{"option":42}`, expected: struct{ Option *ftl.Option[int] }{&somePtr}},
{name: "OptionStruct", input: `{"option":{"fooBar":"foo"}}`, expected: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}},
Expand All @@ -77,6 +83,12 @@ func TestUnmarshal(t *testing.T) {
String string
Unit ftl.Unit
}{String: "something", Unit: ftl.Unit{}}},
// Whitespaces after each `:` and multiple fields to test handling of the
// two potential terminal delimiters: `}` and `,`
{name: "ComplexFormatting", input: `{"option": null, "bool": true}`, expected: struct {
Option ftl.Option[int]
Bool bool
}{ftl.None[int](), true}},
}

for _, tt := range tests {
Expand Down Expand Up @@ -111,7 +123,9 @@ func TestRoundTrip(t *testing.T) {
{name: "Slice", input: struct{ Slice []int }{[]int{1, 2, 3}}},
{name: "SliceOfStrings", input: struct{ Slice []string }{[]string{"hello", "world"}}},
{name: "Map", input: struct{ Map map[string]int }{map[string]int{"foo": 42}}},
{name: "Time", input: struct{ Time time.Time }{time.Date(2009, time.November, 29, 21, 33, 0, 0, time.UTC)}},
{name: "Option", input: struct{ Option ftl.Option[int] }{ftl.Some(42)}},
{name: "OptionNull", input: struct{ Option ftl.Option[int] }{ftl.None[int]()}},
{name: "OptionPtr", input: struct{ Option *ftl.Option[int] }{&somePtr}},
{name: "OptionStruct", input: struct{ Option ftl.Option[inner] }{ftl.Some(inner{"foo"})}},
{name: "Unit", input: ftl.Unit{}},
Expand Down
Loading

0 comments on commit 053f922

Please sign in to comment.