Skip to content

Commit

Permalink
Unmarshal rewrite
Browse files Browse the repository at this point in the history
* faster, fewer allocations
* better error checking on malformed input, known panics are now errors
* error messages show path to faulty field
  • Loading branch information
earwin committed Oct 16, 2024
1 parent 54f6adf commit 4250d2b
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 161 deletions.
88 changes: 88 additions & 0 deletions bytez.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package tlv

import (
"encoding/binary"
"io"
)

// bytez is a zero-copy, read-only byte buffer.
//
// Its Read* methods take a pointer receiver: each one returns a value taken
// from the front of the slice and re-slices the receiver past the consumed
// bytes. When fewer bytes remain than a read needs, the method returns io.EOF
// and leaves the buffer unchanged. Multi-byte integers are read as big-endian.
//
// The name was selected to avoid clashing with the standard library's bytes package.
type bytez []byte

// ReadUint8 consumes one byte from the front of the buffer and returns it.
// It returns io.EOF, leaving the buffer untouched, when the buffer is empty.
func (b *bytez) ReadUint8() (v uint8, err error) {
	buf := *b
	if len(buf) == 0 {
		return 0, io.EOF
	}
	*b = buf[1:]
	return buf[0], nil
}

// ReadUint16 consumes two bytes from the front of the buffer and returns them
// as a big-endian uint16. It returns io.EOF, leaving the buffer untouched,
// when fewer than two bytes remain.
func (b *bytez) ReadUint16() (v uint16, err error) {
	const size = 2
	buf := *b
	if len(buf) < size {
		return 0, io.EOF
	}
	*b = buf[size:]
	return binary.BigEndian.Uint16(buf), nil
}

// ReadUint32 consumes four bytes from the front of the buffer and returns them
// as a big-endian uint32. It returns io.EOF, leaving the buffer untouched,
// when fewer than four bytes remain.
func (b *bytez) ReadUint32() (v uint32, err error) {
	const size = 4
	buf := *b
	if len(buf) < size {
		return 0, io.EOF
	}
	*b = buf[size:]
	return binary.BigEndian.Uint32(buf), nil
}

// ReadUint64 consumes eight bytes from the front of the buffer and returns
// them as a big-endian uint64. It returns io.EOF, leaving the buffer
// untouched, when fewer than eight bytes remain.
func (b *bytez) ReadUint64() (v uint64, err error) {
	const size = 8
	buf := *b
	if len(buf) < size {
		return 0, io.EOF
	}
	*b = buf[size:]
	return binary.BigEndian.Uint64(buf), nil
}

// ReadInt8 consumes one byte from the front of the buffer and returns it
// converted to int8. It returns io.EOF, leaving the buffer untouched, when
// the buffer is empty.
func (b *bytez) ReadInt8() (v int8, err error) {
	if len(*b) == 0 {
		return 0, io.EOF
	}
	first := (*b)[0]
	*b = (*b)[1:]
	return int8(first), nil
}

// ReadInt16 consumes two bytes from the front of the buffer and returns the
// big-endian value converted to int16. It returns io.EOF, leaving the buffer
// untouched, when fewer than two bytes remain.
func (b *bytez) ReadInt16() (v int16, err error) {
	if len(*b) < 2 {
		return 0, io.EOF
	}
	raw := binary.BigEndian.Uint16(*b)
	*b = (*b)[2:]
	return int16(raw), nil
}

// ReadInt32 consumes four bytes from the front of the buffer and returns the
// big-endian value converted to int32. It returns io.EOF, leaving the buffer
// untouched, when fewer than four bytes remain.
func (b *bytez) ReadInt32() (v int32, err error) {
	if len(*b) < 4 {
		return 0, io.EOF
	}
	raw := binary.BigEndian.Uint32(*b)
	*b = (*b)[4:]
	return int32(raw), nil
}

// ReadInt64 consumes eight bytes from the front of the buffer and returns the
// big-endian value converted to int64. It returns io.EOF, leaving the buffer
// untouched, when fewer than eight bytes remain.
func (b *bytez) ReadInt64() (v int64, err error) {
	if len(*b) < 8 {
		return 0, io.EOF
	}
	raw := binary.BigEndian.Uint64(*b)
	*b = (*b)[8:]
	return int64(raw), nil
}

// ReadSlice consumes the next n bytes from the front of the buffer and returns
// them without copying: the result shares memory with the buffer.
//
// It returns io.EOF, leaving the buffer untouched, when fewer than n bytes
// remain. A negative n is also reported as io.EOF rather than being allowed
// to panic in the slice expression.
//
// The capacity of the returned slice is limited to n, so appending to it
// allocates a new array instead of overwriting the bytes that follow it in
// the buffer.
func (b *bytez) ReadSlice(n int) (v bytez, err error) {
	// n < 0 would pass the length comparison and then panic when slicing.
	if n < 0 || len(*b) < n {
		return nil, io.EOF
	}
	// Full slice expression: cap(v) == n keeps appends from clobbering the rest.
	v, *b = (*b)[:n:n], (*b)[n:]
	return v, nil
}

// ReadString consumes the next n bytes via ReadSlice and returns them as a
// string. The conversion copies the bytes, so the result does not alias the
// buffer. Any error from ReadSlice is returned together with an empty string.
func (b *bytez) ReadString(n int) (v string, err error) {
	var raw bytez
	if raw, err = b.ReadSlice(n); err != nil {
		return "", err
	}
	return string(raw), nil
}
257 changes: 143 additions & 114 deletions decode.go
Original file line number Diff line number Diff line change
@@ -1,151 +1,180 @@
package tlv

import (
"bytes"
"encoding"
"encoding/binary"
"errors"
"fmt"
"maps"
"math"
"reflect"
"slices"
"strconv"
"sync/atomic"
)

func Unmarshal(b []byte, v interface{}) error {
// Unmarshal decodes the TLV-encoded bytes in b into v.
//
// The input is cloned before decoding, so decoded values never share memory
// with the caller's slice. The clone costs one allocation and a copy of the
// input; UnmarshalZeroCopy skips it.
func Unmarshal(b []byte, v any) error {
// decodeValue shares the bytes of input slice whenever possible
// clone protects the client code that is not ready for this
b = slices.Clone(b)
return decodeValue(b, v)
}

func decodeValue(b []byte, v interface{}) error {
value := reflect.ValueOf(v)
// UnmarshalZeroCopy decodes the TLV-encoded bytes in b into v by handing b
// straight to decodeValue, without the defensive clone that Unmarshal makes.
//
// Because decodeValue shares the bytes of the input slice whenever possible,
// decoded values may alias b. Callers should not modify b while those values
// are still in use.
func UnmarshalZeroCopy(b []byte, v any) error {
return decodeValue(b, v)
}

func decodeValue(b bytez, v any) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("%v > %w", reflect.TypeOf(v), err)
}
}()

if unmarshaler, ok := value.Interface().(encoding.BinaryUnmarshaler); ok {
err := unmarshaler.UnmarshalBinary(b)
return err
if v == nil {
return errors.New("nil")
}

value = reflect.Indirect(value)
valueType := reflect.TypeOf(value.Interface())
switch value.Kind() {
case reflect.Int8:
tmp := int64(int8(b[0]))
value.SetInt(tmp)
case reflect.Int16:
tmp := int64(int16(binary.BigEndian.Uint16(b)))
value.SetInt(tmp)
case reflect.Int32:
tmp := int64(int32(binary.BigEndian.Uint32(b)))
value.SetInt(tmp)
case reflect.Int64:
tmp := int64(binary.BigEndian.Uint64(b))
value.SetInt(tmp)
case reflect.Int:
tmp := int64(binary.BigEndian.Uint64(b))
value.SetInt(tmp)
case reflect.Uint8:
tmp := uint64(b[0])
value.SetUint(tmp)
case reflect.Uint16:
tmp := uint64(binary.BigEndian.Uint16(b))
value.SetUint(tmp)
case reflect.Uint32:
tmp := uint64(binary.BigEndian.Uint32(b))
value.SetUint(tmp)
case reflect.Uint64:
tmp := binary.BigEndian.Uint64(b)
value.SetUint(tmp)
case reflect.Uint:
tmp := binary.BigEndian.Uint64(b)
value.SetUint(tmp)
case reflect.String:
value.SetString(string(b))
case reflect.Ptr:
if value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
}
if err := decodeValue(b, value.Interface()); err != nil {
return err
}
case reflect.Struct:
var tlvFragment fragments
if tlvFragmentTmp, err := parseTLV(b); err != nil {
return err
} else {
tlvFragment = tlvFragmentTmp
}
for i := 0; i < value.NumField(); i++ {
fieldValue := value.Field(i)
fieldType := valueType.Field(i)
if v, ok := v.(encoding.BinaryUnmarshaler); ok {
return v.UnmarshalBinary(b)
}

tag, hasTLV := fieldType.Tag.Lookup("tlv")
if !hasTLV {
return errors.New("field " + fieldType.Name + " need tag `tlv`")
switch v := v.(type) {
case *int8:
*v, err = b.ReadInt8()
case *int16:
*v, err = b.ReadInt16()
case *int32:
*v, err = b.ReadInt32()
case *int64:
*v, err = b.ReadInt64()
case *uint8:
*v, err = b.ReadUint8()
case *uint16:
*v, err = b.ReadUint16()
case *uint32:
*v, err = b.ReadUint32()
case *uint64:
*v, err = b.ReadUint64()
case *string:
*v, err = b.ReadString(len(b))
case *[]byte:
*v, err = b.ReadSlice(len(b))
default:
val := reflect.ValueOf(v)
if val.Kind() != reflect.Ptr {
return errors.New("not a pointer")
}
val = val.Elem()
typ := val.Type()
switch val.Kind() {
case reflect.Ptr:
if val.IsNil() {
val.Set(reflect.New(typ.Elem()))
}

tagVal, err := strconv.Atoi(tag)
if err = decodeValue(b, val.Interface()); err != nil {
return err
}
case reflect.Struct:
tagMap, err := tagMapFor(typ)
if err != nil {
return fmt.Errorf("invalid tlv tag \"%s\", need to be decimal number", tag)
return err
}

if len(tlvFragment[tagVal]) == 0 {
continue
}
for len(b) > 0 {
tag, bytes, err := readTLV(&b)
if err != nil {
return err
}

if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
} else if fieldValue.Kind() == reflect.Slice && fieldValue.IsNil() {
fieldValue.Set(reflect.MakeSlice(fieldValue.Type(), 0, 1))
}
idx, ok := tagMap[tag]
if !ok {
// return fmt.Errorf("no field has `tlv:%d` tag", tag)
continue // skip unexpected TLVs
}

f := val.Field(idx)

for _, buf := range tlvFragment[tagVal] {
if fieldValue.Kind() != reflect.Ptr {
fieldValue = fieldValue.Addr()
if f.Kind() == reflect.Ptr && f.IsNil() {
f.Set(reflect.New(f.Type().Elem()))
}
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
err = decodeValue(buf, fieldValue.Interface())

err = decodeValue(bytes, f.Interface())
if err != nil {
return err
return fmt.Errorf("field %s %w", val.Type().Field(idx).Name, err)
}
}
}
case reflect.Slice:
if value.IsNil() {
value.Set(reflect.MakeSlice(value.Type(), 0, 1))
}
if valueType.Elem().Kind() == reflect.Uint8 {
value.SetBytes(b)
} else if valueType.Elem().Kind() == reflect.Ptr || valueType.Elem().Kind() == reflect.Struct ||
isNumber(valueType.Elem()) {
elemValue := reflect.New(valueType.Elem())
if err := decodeValue(b, elemValue.Interface()); err != nil {
return err
case reflect.Slice:
if typ.Elem().Kind() == reflect.Ptr || typ.Elem().Kind() == reflect.Struct || isNumber(typ.Elem()) {
elemValue := reflect.New(typ.Elem())
if err := decodeValue(b, elemValue.Interface()); err != nil {
return err
}
val.Set(reflect.Append(val, elemValue.Elem()))
} else {
return errors.New("unsupported slice type")
}
value.Set(reflect.Append(value, elemValue.Elem()))
} else {
return errors.New("value type `Slice of " + valueType.String() + "` is not support decode")
default:
return errors.New("unsupported type")
}
}

return nil
}

func parseTLV(b []byte) (fragments, error) {
tlvFragment := make(fragments)
buffer := bytes.NewBuffer(b)

var tag uint16
func readTLV(b *bytez) (tag uint16, value bytez, err error) {
if tag, err = b.ReadUint16(); err != nil {
err = fmt.Errorf("%w @ tag", err)
return
}
var length uint16
for {
if err := binary.Read(buffer, binary.BigEndian, &tag); err != nil {
fmt.Printf("Binary Read error: %v", err)
}
if err := binary.Read(buffer, binary.BigEndian, &length); err != nil {
fmt.Printf("Binary Read error: %v", err)
}
value := make([]byte, length)
if _, err := buffer.Read(value); err != nil {
return nil, err
}
tlvFragment.Add(int(tag), value)
if buffer.Len() == 0 {
break
if length, err = b.ReadUint16(); err != nil {
err = fmt.Errorf("%w @ len", err)
return
}
if value, err = b.ReadSlice(int(length)); err != nil {
err = fmt.Errorf("%w @ value", err)
return
}
return
}

// tagMap maps a TLV tag number to the index of the struct field that carries
// that tag in its `tlv:"..."` struct tag.
type tagMap map[uint16]int

// typeCache holds the tag maps already built for struct types:
// type -> tag -> field index.
// It is kept behind an atomic pointer; tagMapFor loads it for lookups and,
// when it adds a type, stores a pointer to an updated clone of the map.
var typeCache atomic.Pointer[map[reflect.Type]tagMap]

func tagMapFor(typ reflect.Type) (tagMap, error) {
// Cache handling is intentionally racy to keep the cost of access to a single atomic read
// since the contents are strictly deterministic, it's okay if we rebuild them a few times

tc := typeCache.Load()
if tc == nil {
tc = &map[reflect.Type]tagMap{}
typeCache.Store(tc)
}

t := (*tc)[typ]
if t == nil {
t = make(tagMap)
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
if !field.IsExported() {
return nil, fmt.Errorf("field %s is not exported", field.Name)
} else if tlv, ok := field.Tag.Lookup("tlv"); !ok {
return nil, fmt.Errorf("field %s has no `tlv:...` tag", field.Name)
} else if tlvI, err := strconv.Atoi(tlv); err != nil || tlvI < 0 || tlvI > math.MaxUint16 {
return nil, fmt.Errorf("field %s has invalid `tlv:%s` tag", field.Name, tlv)
} else {
t[uint16(tlvI)] = i
}
}
tc = ref(maps.Clone(*tc))
(*tc)[typ] = t
typeCache.Store(tc)
}
return tlvFragment, nil

return t, nil
}
Loading

0 comments on commit 4250d2b

Please sign in to comment.