Skip to content

Commit

Permalink
Improve encoding speed by 14%-50%
Browse files Browse the repository at this point in the history
Cached encoding functions and struct field names to boost speed:
- map 14%
- struct 30%
- slice 50%
  • Loading branch information
Faye Amacker committed Oct 18, 2019
1 parent 6939d12 commit 29e05d7
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 60 deletions.
4 changes: 2 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ func (d *decodeState) parseMap(t cborType, count int, v reflect.Value) error {
}

func (d *decodeState) parseStruct(t cborType, count int, v reflect.Value) error {
flds := getStructFields(v.Type(), false)
flds := getStructFields(v.Type())

hasSize := count >= 0
for i := 0; (hasSize && i < count) || (!hasSize && !d.foundBreak()); i++ {
Expand Down Expand Up @@ -602,7 +602,7 @@ func (d *decodeState) parseStruct(t cborType, count int, v reflect.Value) error
if d.err == nil {
if typeError, ok := err.(*UnmarshalTypeError); ok {
typeError.Struct = v.Type().String()
typeError.Field = string(keyBytes)
typeError.Field = f.name
d.err = typeError
} else {
d.err = err
Expand Down
98 changes: 67 additions & 31 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ var (
cborNan = []byte{0xf9, 0x7e, 0x00}
cborPositiveInfinity = []byte{0xf9, 0x7c, 0x00}
cborNegativeInfinity = []byte{0xf9, 0xfc, 0x00}
typeBinaryMarshaler = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem()
)
var typeBinaryMarshaler = reflect.TypeOf((*encoding.BinaryMarshaler)(nil)).Elem()

// EncOptions specifies encoding options.
type EncOptions struct {
Expand Down Expand Up @@ -86,8 +86,7 @@ func encode(e *encodeState, v reflect.Value, opts EncOptions) (int, error) {
e.Write(cborNil)
return 1, nil
}

f := getEncodeFunc(v)
f := getEncodeFunc(v.Type())
if f == nil {
return 0, &UnsupportedTypeError{v.Type()}
}
Expand Down Expand Up @@ -199,14 +198,16 @@ func encodeArray(e *encodeState, v reflect.Value, opts EncOptions) (int, error)
if v.Type().Elem().Kind() == reflect.Uint8 {
return encodeByteString(e, v, opts)
}

f := getEncodeFunc(v.Type().Elem())
if f == nil {
return 0, &UnsupportedTypeError{v.Type()}
}
if v.Len() == 0 {
return encodeTypeAndAdditionalValue(e, byte(cborTypeArray), uint64(0)), nil
}

n := encodeTypeAndAdditionalValue(e, byte(cborTypeArray), uint64(v.Len()))
for i := 0; i < v.Len(); i++ {
n1, err := encode(e, v.Index(i), opts)
n1, err := f(e, v.Index(i), opts)
if err != nil {
return 0, err
}
Expand All @@ -216,27 +217,32 @@ func encodeArray(e *encodeState, v reflect.Value, opts EncOptions) (int, error)
}

func encodeMap(e *encodeState, v reflect.Value, opts EncOptions) (int, error) {
if opts.Canonical {
return encodeMapCanonical(e, v, opts)
}
kf := getEncodeFunc(v.Type().Key())
ef := getEncodeFunc(v.Type().Elem())
if kf == nil || ef == nil {
return 0, &UnsupportedTypeError{v.Type()}
}
if v.IsNil() {
return e.Write(cborNil)
}
if v.Len() == 0 {
return encodeTypeAndAdditionalValue(e, byte(cborTypeMap), uint64(0)), nil
}
if opts.Canonical {
return encodeMapCanonical(e, v, opts)
}
n := encodeTypeAndAdditionalValue(e, byte(cborTypeMap), uint64(v.Len()))
iter := v.MapRange()
for iter.Next() {
kn, err := encode(e, iter.Key(), opts)
n1, err := kf(e, iter.Key(), opts)
if err != nil {
return 0, err
}
en, err := encode(e, iter.Value(), opts)
n2, err := ef(e, iter.Value(), opts)
if err != nil {
return 0, err
}
n += kn + en
n += n1 + n2
}
return n, nil
}
Expand Down Expand Up @@ -286,18 +292,29 @@ func returnByCanonical(s *byCanonical) {
}

func encodeMapCanonical(e *encodeState, v reflect.Value, opts EncOptions) (int, error) {
kf := getEncodeFunc(v.Type().Key())
ef := getEncodeFunc(v.Type().Elem())
if kf == nil || ef == nil {
return 0, &UnsupportedTypeError{v.Type()}
}
if v.IsNil() {
return e.Write(cborNil)
}
if v.Len() == 0 {
return encodeTypeAndAdditionalValue(e, byte(cborTypeMap), uint64(0)), nil
}
pairEncodeState := newEncodeState() // accumulated cbor encoded map key-value pairs
pairs := newByCanonical(v.Len()) // for sorting keys

iter := v.MapRange()
for iter.Next() {
n1, err := encode(pairEncodeState, iter.Key(), opts)
n1, err := kf(pairEncodeState, iter.Key(), opts)
if err != nil {
returnEncodeState(pairEncodeState)
returnByCanonical(pairs)
return 0, err
}
n2, err := encode(pairEncodeState, iter.Value(), opts)
n2, err := ef(pairEncodeState, iter.Value(), opts)
if err != nil {
returnEncodeState(pairEncodeState)
returnByCanonical(pairs)
Expand Down Expand Up @@ -326,11 +343,14 @@ func encodeMapCanonical(e *encodeState, v reflect.Value, opts EncOptions) (int,
}

func encodeStruct(e *encodeState, v reflect.Value, opts EncOptions) (int, error) {
flds := getStructFields(v.Type(), opts.Canonical)
flds := getEncodingStructFields(v.Type(), opts.Canonical)

kve := newEncodeState() // encode key-value pairs based on struct field tag options
kvcount := 0
for _, f := range flds {
if f.ef == nil {
return 0, &UnsupportedTypeError{v.Type()}
}
fv, err := fieldByIndex(v, f.idx)
if err != nil {
returnEncodeState(kve)
Expand All @@ -339,8 +359,12 @@ func encodeStruct(e *encodeState, v reflect.Value, opts EncOptions) (int, error)
if !fv.IsValid() || (f.omitempty && isEmptyValue(fv)) {
continue
}
encodeStringInternal(kve, f.name, opts)
_, err = encode(kve, fv, opts)
if f.cborNameLen > 0 {
kve.Write(f.cborName[:f.cborNameLen])
} else {
encodeStringInternal(kve, f.name, opts)
}
_, err = f.ef(kve, fv, opts)
if err != nil {
returnEncodeState(kve)
return 0, err
Expand All @@ -349,17 +373,10 @@ func encodeStruct(e *encodeState, v reflect.Value, opts EncOptions) (int, error)
}

n := encodeTypeAndAdditionalValue(e, byte(cborTypeMap), uint64(kvcount))
n1, _ := e.Write(kve.Bytes())
n1, err := e.Write(kve.Bytes())

returnEncodeState(kve)
return n + n1, nil
}

func encodePtr(e *encodeState, v reflect.Value, opts EncOptions) (int, error) {
if v.IsNil() {
return e.Write(cborNil)
}
return encode(e, v.Elem(), opts)
return n + n1, err
}

func encodeIntf(e *encodeState, v reflect.Value, opts EncOptions) (int, error) {
Expand Down Expand Up @@ -402,14 +419,26 @@ func encodeBinaryMarshalerType(e *encodeState, v reflect.Value, opts EncOptions)
return n1 + n2, nil
}

func getEncodeFunc(v reflect.Value) encodeFunc {
if reflect.PtrTo(v.Type()).Implements(typeBinaryMarshaler) {
if v.Type() == typeTime {
func getEncodeIndirectValueFunc(f encodeFunc) encodeFunc {
return func(e *encodeState, v reflect.Value, opts EncOptions) (int, error) {
for v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
if v.Kind() == reflect.Ptr && v.IsNil() {
return e.Write(cborNil)
}
return f(e, v, opts)
}
}

func getEncodeFunc(t reflect.Type) encodeFunc {
if reflect.PtrTo(t).Implements(typeBinaryMarshaler) {
if t == typeTime {
return encodeTime
}
return encodeBinaryMarshalerType
}
switch v.Kind() {
switch t.Kind() {
case reflect.Bool:
return encodeBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
Expand All @@ -429,7 +458,14 @@ func getEncodeFunc(v reflect.Value) encodeFunc {
case reflect.Struct:
return encodeStruct
case reflect.Ptr:
return encodePtr
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
f := getEncodeFunc(t)
if f == nil {
return f
}
return getEncodeIndirectValueFunc(f)
case reflect.Interface:
return encodeIntf
default:
Expand Down
50 changes: 44 additions & 6 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ var exMarshalTests = []marshalTest{

var marshalErrorTests = []marshalErrorTest{
{"channel can't be marshalled", make(chan bool), "cbor: unsupported type: chan bool"},
{"slice of channel can't be marshalled", make([]chan bool, 10), "cbor: unsupported type: []chan bool"},
{"slice of pointer to channel can't be marshalled", make([]*chan bool, 10), "cbor: unsupported type: []*chan bool"},
{"map of channel can't be marshalled", make(map[string]chan bool), "cbor: unsupported type: map[string]chan bool"},
{"struct of channel can't be marshalled", struct{ Chan chan bool }{}, "cbor: unsupported type: struct { Chan chan bool }"},
{"function can't be marshalled", func(i int) int { return i * i }, "cbor: unsupported type: func(int) int"},
{"complex can't be marshalled", complex(100, 8), "cbor: unsupported type: complex128"},
}
Expand All @@ -268,13 +272,24 @@ func TestInvalidTypeMarshal(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
b, err := cbor.Marshal(&tc.value, cbor.EncOptions{})
if err == nil {
t.Errorf("Marshal(%v) doesn't return an error, want error %q", tc.value, tc.wantErrorMsg)
t.Errorf("Marshal(%v, cbor.EncOptions{}) doesn't return an error, want error %q", tc.value, tc.wantErrorMsg)
} else if _, ok := err.(*cbor.UnsupportedTypeError); !ok {
t.Errorf("Marshal(%v) error type %T, want *cbor.UnsupportedTypeError", tc.value, err)
t.Errorf("Marshal(%v, cbor.EncOptions{}) error type %T, want *cbor.UnsupportedTypeError", tc.value, err)
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("Marshal(%v) error %s, want %s", tc.value, err, tc.wantErrorMsg)
t.Errorf("Marshal(%v, cbor.EncOptions{}) error %s, want %s", tc.value, err, tc.wantErrorMsg)
} else if b != nil {
t.Errorf("Marshal(%v) = 0x%0x, want nil", tc.value, b)
t.Errorf("Marshal(%v, cbor.EncOptions{}) = 0x%0x, want nil", tc.value, b)
}

b, err = cbor.Marshal(&tc.value, cbor.EncOptions{Canonical: true})
if err == nil {
t.Errorf("Marshal(%v, cbor.EncOptions{Canonical: true}) doesn't return an error, want error %q", tc.value, tc.wantErrorMsg)
} else if _, ok := err.(*cbor.UnsupportedTypeError); !ok {
t.Errorf("Marshal(%v, cbor.EncOptions{Canonical: true}) error type %T, want *cbor.UnsupportedTypeError", tc.value, err)
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("Marshal(%v, cbor.EncOptions{Canonical: true}) error %s, want %s", tc.value, err, tc.wantErrorMsg)
} else if b != nil {
t.Errorf("Marshal(%v, cbor.EncOptions{Canonical: true}) = 0x%0x, want nil", tc.value, b)
}
})
}
Expand Down Expand Up @@ -522,10 +537,13 @@ func TestMarshalStructCanonical(t *testing.T) {
func testMarshal(t *testing.T, testCases []marshalTest) {
for _, tc := range testCases {
for _, value := range tc.values {
if _, err := cbor.Marshal(value, cbor.EncOptions{}); err != nil {
t.Errorf("Marshal(%v, cbor.EncOptions{}) returns error %v", value, err)
}
if b, err := cbor.Marshal(value, cbor.EncOptions{Canonical: true}); err != nil {
t.Errorf("Marshal(%v) returns error %v", value, err)
t.Errorf("Marshal(%v, cbor.EncOptions{Canonical: true}) returns error %v", value, err)
} else if !bytes.Equal(b, tc.cborData) {
t.Errorf("Marshal(%v) = 0x%0x, want 0x%0x", value, b, tc.cborData)
t.Errorf("Marshal(%v, cbor.EncOptions{Canonical: true}) = 0x%0x, want 0x%0x", value, b, tc.cborData)
}
}
}
Expand Down Expand Up @@ -1191,3 +1209,23 @@ func TestMarshalStructTag4(t *testing.T) {
t.Errorf("Marshal(%+v) = %v, want %v", v, b, want)
}
}

func TestMarshalStructLongFieldName(t *testing.T) {
type strc struct {
A string `cbor:"a"`
B string `cbor:"abcdefghijklmnopqrstuvwxyz"`
C string `cbor:"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmn"`
}
v := strc{
A: "A",
B: "B",
C: "C",
}
want := hexDecode("a361616141781a6162636465666768696a6b6c6d6e6f707172737475767778797a614278426162636465666768696a6b6c6d6e6f707172737475767778797a6162636465666768696a6b6c6d6e6f707172737475767778797a6162636465666768696a6b6c6d6e6143") // {"a":"A", "abcdefghijklmnopqrstuvwxyz":"B", "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmn":"C"}

if b, err := cbor.Marshal(v, cbor.EncOptions{Canonical: true}); err != nil {
t.Errorf("Marshal(%+v) returns error %v", v, err)
} else if !bytes.Equal(b, want) {
t.Errorf("Marshal(%+v) = %v, want %v", v, b, want)
}
}
Loading

0 comments on commit 29e05d7

Please sign in to comment.