Skip to content

Commit

Permalink
Decode/encode CBOR maps with int keys to Go struct
Browse files Browse the repository at this point in the history
This simplifies decoding COSE to struct since COSE uses map with
integer keys.  Decoding to struct is 42.125% faster than to
map[int]interface{} and 42.105% less B/op.  This also benefits
WebAuthn since it uses COSE.

Closes: #15

Struct field name is treated as integer if it has "keyasint" option in
its format string.  The format string must specify an integer as its
field name.

```
type T struct {
    F1 int `cbor:"1,keyasint"`
    F2 int `cbor:"2"`
    F3 int `cbor:"-1,keyasint"`
}
v := T{F1: 1, F2: 2, F3: 3}
```

Marshal encodes v as {1: 1, "2": 2, -1: 3}.  Note that field F2 is
encoded as "2", not 2, because "keyasint" is not specified.

Benchmarks of decoding COSE to map[int]interface{} vs. struct:

```
BenchmarkDecodeCOSE/COSE_to_Go_map[int]interface_{}-2                         	 1000000	      1600 ns/op	     456 B/op	      13 allocs/op
BenchmarkDecodeCOSE/COSE_to_Go_cbor_test.coseKey-2                            	 2000000	       926 ns/op	     264 B/op	      13 allocs/op
```
  • Loading branch information
fxamacker committed Nov 14, 2019
1 parent b198edc commit 3cbdc26
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 31 deletions.
18 changes: 16 additions & 2 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,27 @@ func BenchmarkEncodeStream(b *testing.B) {
}

func BenchmarkDecodeCOSE(b *testing.B) {
cborData := hexDecode("a50102032620012158205af8047e9085ef79ec321280c7b95844d707d7fe4d73cd648f044c619ee74f6b22582036bb8c00768e90858012dc3831e15a389072bbdbe7e2e19155db9e1197655edf")
type coseKey struct {
Kty int `cbor:"1,keyasint"`
Alg int `cbor:"3,keyasint"`
CrvOrN cbor.RawMessage `cbor:"-1,keyasint"`
XOrE cbor.RawMessage `cbor:"-2,keyasint"`
Y cbor.RawMessage `cbor:"-3,keyasint"`
}

// Data from https://www.w3.org/TR/webauthn/#sec-attested-credential-data example 2, COSE_Key-encoded ES256 (P-256 curve) in EC2 format.
cborData := hexDecode("a501020326200121582065eda5a12577c2bae829437fe338701a10aaa375e1bb5b5de108de439c08551d2258201e52ed75701163f7f9e40ddf9f341b3dc9ba860af7e0ca7ca7e9eecd0084d19c")
for _, ctor := range []func() interface{}{
func() interface{} { return new(interface{}) },
func() interface{} { return new(map[interface{}]interface{}) },
func() interface{} { return new(map[int]interface{}) },
func() interface{} { return new(coseKey) },
} {
name := "COSE to Go " + reflect.TypeOf(ctor()).Elem().String()
t := reflect.TypeOf(ctor()).Elem()
name := "COSE to Go " + t.String()
if t.Kind() == reflect.Struct {
name = "COSE to Go " + t.Kind().String()
}
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
v := ctor()
Expand Down
48 changes: 36 additions & 12 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"reflect"
"sort"
"strconv"
"sync"
)

Expand Down Expand Up @@ -50,23 +51,46 @@ func getEncodingStructType(t reflect.Type) encodingStructType {
var err error
es := getEncodeState()
for i := 0; i < len(flds); i++ {
ef := getEncodeFunc(flds[i].typ)
if ef == nil {
if err == nil {
err = &UnsupportedTypeError{t}
}
// Get field's encodeFunc
flds[i].ef = getEncodeFunc(flds[i].typ)
if flds[i].ef == nil {
err = &UnsupportedTypeError{t}
break
}
flds[i].ef = ef

encodeTypeAndAdditionalValue(es, byte(cborTypeTextString), uint64(len(flds[i].name)))
flds[i].cborName = make([]byte, es.Len()+len(flds[i].name))
copy(flds[i].cborName, es.Bytes())
copy(flds[i].cborName[es.Len():], flds[i].name)

es.Reset()
// Encode field name
if flds[i].keyasint {
nameAsInt, numErr := strconv.Atoi(flds[i].name)
if numErr != nil {
err = numErr
break
}
if nameAsInt >= 0 {
encodeTypeAndAdditionalValue(es, byte(cborTypePositiveInt), uint64(nameAsInt))
} else {
n := nameAsInt*(-1) - 1
encodeTypeAndAdditionalValue(es, byte(cborTypeNegativeInt), uint64(n))
}
flds[i].cborName = make([]byte, es.Len())
copy(flds[i].cborName, es.Bytes())
es.Reset()
} else {
encodeTypeAndAdditionalValue(es, byte(cborTypeTextString), uint64(len(flds[i].name)))
flds[i].cborName = make([]byte, es.Len()+len(flds[i].name))
n := copy(flds[i].cborName, es.Bytes())
copy(flds[i].cborName[n:], flds[i].name)
es.Reset()
}
}
putEncodeState(es)

if err != nil {
structType := encodingStructType{err: err}
encodingStructTypeCache.Store(t, structType)
return structType
}

// Sort fields by canonical order
canonicalFields := make(fields, len(flds))
copy(canonicalFields, flds)
sort.Sort(byCanonicalRule{canonicalFields})
Expand Down
36 changes: 27 additions & 9 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -583,23 +583,41 @@ func (d *decodeState) parseStruct(t cborType, count int, v reflect.Value) error

hasSize := count >= 0
for i := 0; (hasSize && i < count) || (!hasSize && !d.foundBreak()); i++ {
var keyBytes []byte
var err error
t := cborType(d.data[d.offset] & 0xE0)
if t != cborTypeTextString {
if t == cborTypeTextString {
keyBytes, err = d.parseTextString()
if err != nil {
if d.err == nil {
d.err = err
}
d.skip() // skip value
continue
}
} else if t == cborTypePositiveInt || t == cborTypeNegativeInt {
v, err := d.parseInterface()
if err != nil {
if d.err == nil {
d.err = err
}
d.skip() // skip value
continue
}
switch n := v.(type) {
case int64:
keyBytes = []byte(strconv.Itoa(int(n)))
case uint64:
keyBytes = []byte(strconv.Itoa(int(n)))
}
} else {
if d.err == nil {
d.err = &UnmarshalTypeError{Value: t.String(), Type: reflect.TypeOf(""), errMsg: "map key is of type " + t.String() + " and cannot be used to match struct " + v.Type().String() + " field name"}
}
d.skip() // skip key
d.skip() // skip value
continue
}
keyBytes, err := d.parseTextString()
if err != nil {
if d.err == nil {
d.err = err
}
d.skip() // skip value
continue
}
keyLen := len(keyBytes)

var f *field
Expand Down
25 changes: 20 additions & 5 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1072,22 +1072,22 @@ func TestUnmarshalStructError2(t *testing.T) {
}

// Unmarshal returns first error encountered, which is *cbor.UnmarshalTypeError (failed to unmarshal int into Go string)
cborData := hexDecode("a301026161614161fe6142") // {1:2, "a":"A", 0xfe: B}
cborData := hexDecode("a3fa47c35000026161614161fe6142") // {100000.0:2, "a":"A", 0xfe: B}
v := strc{}
if err := cbor.Unmarshal(cborData, &v); err == nil {
t.Errorf("Unmarshal(0x%0x) doesn't return an error", cborData)
} else {
if typeError, ok := err.(*cbor.UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%0x) returns wrong type of error %T, want (*cbor.UnmarshalTypeError)", cborData, err)
} else {
if typeError.Value != "positive integer" {
t.Errorf("Unmarshal(0x%0x) (*cbor.UnmarshalTypeError).Value %s, want %s", cborData, typeError.Value, "positive integer")
if typeError.Value != "primitives" {
t.Errorf("Unmarshal(0x%0x) (*cbor.UnmarshalTypeError).Value %s, want %s", cborData, typeError.Value, "primitives")
}
if typeError.Type != typeString {
t.Errorf("Unmarshal(0x%0x) (*cbor.UnmarshalTypeError).Type %s, want %s", cborData, typeError.Type, typeString)
}
if !strings.Contains(err.Error(), "cannot unmarshal positive integer into Go value of type string") {
t.Errorf("Unmarshal(0x%0x) returns error %s, want error containing %q", cborData, err.Error(), "cannot unmarshal positive integer into Go value of type string")
if !strings.Contains(err.Error(), "cannot unmarshal primitives into Go value of type string") {
t.Errorf("Unmarshal(0x%0x) returns error %s, want error containing %q", cborData, err.Error(), "cannot unmarshal float into Go value of type string")
}
}
}
Expand Down Expand Up @@ -1787,3 +1787,18 @@ func TestCOSEExamples(t *testing.T) {
}
}
}

func TestUnmarshalStructKeyAsIntError(t *testing.T) {
type T1 struct {
F1 int `cbor:"1,keyasint"`
}
cborData := hexDecode("a13bffffffffffffffff01") // {1: -18446744073709551616}
var v T1
if err := cbor.Unmarshal(cborData, &v); err == nil {
t.Errorf("Unmarshal(0x%0x) returns no error, %v (%T), want (*cbor.UnmarshalTypeError)", cborData, v, v)
} else if _, ok := err.(*cbor.UnmarshalTypeError); !ok {
t.Errorf("Unmarshal(0x%0x) returns wrong error %T, want (*cbor.UnmarshalTypeError)", cborData, err)
} else if !strings.Contains(err.Error(), "cannot unmarshal") {
t.Errorf("Unmarshal(0x%0x) returns error %s, want error containing %q", cborData, err, "cannot unmarshal")
}
}
66 changes: 66 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"math"
"reflect"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -1390,3 +1391,68 @@ func TestCyclicDataStructure(t *testing.T) {
t.Errorf("Unmarshal(0x%0x) returns %+v, want %+v", cborData, v1, v)
}
}

func TestMarshalUnmarshalStructKeyAsInt(t *testing.T) {
type T struct {
F1 int `cbor:"1,keyasint"`
F2 int `cbor:"2"`
F3 int `cbor:"-3,keyasint"`
}
v1 := T{F1: 1, F2: 2, F3: 3}
want := hexDecode("a301012203613202") // {1: 1, -3: 3, "2": 2}

b, err := cbor.Marshal(v1, cbor.EncOptions{Canonical: true})
if err != nil {
t.Errorf("Marshal(%+v) returns error %s", v1, err)
}
if !bytes.Equal(b, want) {
t.Errorf("Marshal(%+v) = 0x%0x, want 0x%0x", v1, b, want)
}

var v2 T
if err := cbor.Unmarshal(b, &v2); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %s", b, err)
}
if !reflect.DeepEqual(v1, v2) {
t.Errorf("Unmarshal(0x%0x) returns %+v, want %+v", b, v2, v1)
}
}

func TestMarshalStructKeyAsIntNumError(t *testing.T) {
type T1 struct {
F1 int `cbor:"2.0,keyasint"`
}
type T2 struct {
F1 int `cbor:"-18446744073709551616,keyasint"`
}
testCases := []struct {
name string
obj interface{}
wantErrorMsg string
}{
{
name: "float as key",
obj: T1{},
wantErrorMsg: "parsing \"2.0\": invalid syntax",
},
{
name: "out of range int as key",
obj: T2{},
wantErrorMsg: "parsing \"-18446744073709551616\": value out of range",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
b, err := cbor.Marshal(tc.obj, cbor.EncOptions{})
if err == nil {
t.Errorf("Marshal(%+v, cbor.EncOptions{}) doesn't return an error, want error %q", tc.obj, tc.wantErrorMsg)
} else if _, ok := err.(*strconv.NumError); !ok {
t.Errorf("Marshal(%v, cbor.EncOptions{}) error type %T, want *strconv.NumError", tc.obj, err)
} else if !strings.Contains(err.Error(), tc.wantErrorMsg) {
t.Errorf("Marshal(%v, cbor.EncOptions{}) error %s, want %s", tc.obj, err, tc.wantErrorMsg)
} else if b != nil {
t.Errorf("Marshal(%v, cbor.EncOptions{}) = 0x%0x, want nil", tc.obj, b)
}
})
}
}
9 changes: 6 additions & 3 deletions structfields.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type field struct {
isUnmarshaler bool
tagged bool // used to choose dominant field (at the same level tagged fields dominate untagged fields)
omitempty bool // used to skip empty field
keyasint bool // used to encode/decode field name as int
}

type fields []field
Expand Down Expand Up @@ -91,7 +92,7 @@ func (s byNameLevelAndTag) Less(i, j int) bool {
return i < j // Field i and j have the same name, depth, and tagged status. Nothing else matters.
}

func getFieldNameAndOptionsFromTag(tag string) (name string, omitEmpty bool) {
func getFieldNameAndOptionsFromTag(tag string) (name string, omitEmpty bool, keyAsInt bool) {
if len(tag) == 0 {
return
}
Expand All @@ -100,6 +101,8 @@ func getFieldNameAndOptionsFromTag(tag string) (name string, omitEmpty bool) {
for _, s := range tokens[1:] {
if s == "omitempty" {
omitEmpty = true
} else if s == "keyasint" {
keyAsInt = true
}
}
return
Expand Down Expand Up @@ -171,15 +174,15 @@ func getFields(typ reflect.Type) fields {
idx[len(fieldIdx)] = i

tagged := len(tag) > 0
tagFieldName, omitempty := getFieldNameAndOptionsFromTag(tag)
tagFieldName, omitempty, keyasint := getFieldNameAndOptionsFromTag(tag)

fieldName := tagFieldName
if tagFieldName == "" {
fieldName = f.Name
}

if !f.Anonymous || ft.Kind() != reflect.Struct || len(tagFieldName) > 0 {
flds = append(flds, field{name: fieldName, idx: idx, typ: f.Type, tagged: tagged, omitempty: omitempty})
flds = append(flds, field{name: fieldName, idx: idx, typ: f.Type, tagged: tagged, omitempty: omitempty, keyasint: keyasint})
continue
}

Expand Down

0 comments on commit 3cbdc26

Please sign in to comment.