-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* faster, less allocations * better error checking on malformed input, known panics are now errors * error messages show path to faulty field
- Loading branch information
Showing
4 changed files
with
236 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package tlv | ||
|
||
import ( | ||
"encoding/binary" | ||
"io" | ||
) | ||
|
||
// bytez is a zero-copy read-only byte buffer, | ||
// name selected to avoid clashing with the standard library's bytes package | ||
type bytez []byte | ||
|
||
func (b *bytez) ReadUint8() (v uint8, err error) { | ||
if len(*b) < 1 { | ||
return 0, io.EOF | ||
} | ||
v, *b = (*b)[0], (*b)[1:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadUint16() (v uint16, err error) { | ||
if len(*b) < 2 { | ||
return 0, io.EOF | ||
} | ||
v, *b = binary.BigEndian.Uint16(*b), (*b)[2:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadUint32() (v uint32, err error) { | ||
if len(*b) < 4 { | ||
return 0, io.EOF | ||
} | ||
v, *b = binary.BigEndian.Uint32(*b), (*b)[4:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadUint64() (v uint64, err error) { | ||
if len(*b) < 8 { | ||
return 0, io.EOF | ||
} | ||
v, *b = binary.BigEndian.Uint64(*b), (*b)[8:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadInt8() (v int8, err error) { | ||
if len(*b) < 1 { | ||
return 0, io.EOF | ||
} | ||
v, *b = int8((*b)[0]), (*b)[1:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadInt16() (v int16, err error) { | ||
if len(*b) < 2 { | ||
return 0, io.EOF | ||
} | ||
v, *b = int16(binary.BigEndian.Uint16(*b)), (*b)[2:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadInt32() (v int32, err error) { | ||
if len(*b) < 4 { | ||
return 0, io.EOF | ||
} | ||
v, *b = int32(binary.BigEndian.Uint32(*b)), (*b)[4:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadInt64() (v int64, err error) { | ||
if len(*b) < 8 { | ||
return 0, io.EOF | ||
} | ||
v, *b = int64(binary.BigEndian.Uint64(*b)), (*b)[8:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadSlice(n int) (v bytez, err error) { | ||
if len(*b) < n { | ||
return nil, io.EOF | ||
} | ||
v, *b = (*b)[:n], (*b)[n:] | ||
return | ||
} | ||
|
||
func (b *bytez) ReadString(n int) (v string, err error) { | ||
s, err := b.ReadSlice(n) | ||
v = string(s) | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,151 +1,180 @@ | ||
package tlv | ||
|
||
import ( | ||
"bytes" | ||
"encoding" | ||
"encoding/binary" | ||
"errors" | ||
"fmt" | ||
"maps" | ||
"math" | ||
"reflect" | ||
"slices" | ||
"strconv" | ||
"sync/atomic" | ||
) | ||
|
||
func Unmarshal(b []byte, v interface{}) error { | ||
func Unmarshal(b []byte, v any) error { | ||
// decodeValue shares the bytes of input slice whenever possible | ||
// clone protects the client code that is not ready for this | ||
b = slices.Clone(b) | ||
return decodeValue(b, v) | ||
} | ||
|
||
func decodeValue(b []byte, v interface{}) error { | ||
value := reflect.ValueOf(v) | ||
func UnmarshalZeroCopy(b []byte, v any) error { | ||
return decodeValue(b, v) | ||
} | ||
|
||
func decodeValue(b bytez, v any) (err error) { | ||
defer func() { | ||
if err != nil { | ||
err = fmt.Errorf("%v > %w", reflect.TypeOf(v), err) | ||
} | ||
}() | ||
|
||
if unmarshaler, ok := value.Interface().(encoding.BinaryUnmarshaler); ok { | ||
err := unmarshaler.UnmarshalBinary(b) | ||
return err | ||
if v == nil { | ||
return errors.New("nil") | ||
} | ||
|
||
value = reflect.Indirect(value) | ||
valueType := reflect.TypeOf(value.Interface()) | ||
switch value.Kind() { | ||
case reflect.Int8: | ||
tmp := int64(int8(b[0])) | ||
value.SetInt(tmp) | ||
case reflect.Int16: | ||
tmp := int64(int16(binary.BigEndian.Uint16(b))) | ||
value.SetInt(tmp) | ||
case reflect.Int32: | ||
tmp := int64(int32(binary.BigEndian.Uint32(b))) | ||
value.SetInt(tmp) | ||
case reflect.Int64: | ||
tmp := int64(binary.BigEndian.Uint64(b)) | ||
value.SetInt(tmp) | ||
case reflect.Int: | ||
tmp := int64(binary.BigEndian.Uint64(b)) | ||
value.SetInt(tmp) | ||
case reflect.Uint8: | ||
tmp := uint64(b[0]) | ||
value.SetUint(tmp) | ||
case reflect.Uint16: | ||
tmp := uint64(binary.BigEndian.Uint16(b)) | ||
value.SetUint(tmp) | ||
case reflect.Uint32: | ||
tmp := uint64(binary.BigEndian.Uint32(b)) | ||
value.SetUint(tmp) | ||
case reflect.Uint64: | ||
tmp := binary.BigEndian.Uint64(b) | ||
value.SetUint(tmp) | ||
case reflect.Uint: | ||
tmp := binary.BigEndian.Uint64(b) | ||
value.SetUint(tmp) | ||
case reflect.String: | ||
value.SetString(string(b)) | ||
case reflect.Ptr: | ||
if value.IsNil() { | ||
value.Set(reflect.New(value.Type().Elem())) | ||
} | ||
if err := decodeValue(b, value.Interface()); err != nil { | ||
return err | ||
} | ||
case reflect.Struct: | ||
var tlvFragment fragments | ||
if tlvFragmentTmp, err := parseTLV(b); err != nil { | ||
return err | ||
} else { | ||
tlvFragment = tlvFragmentTmp | ||
} | ||
for i := 0; i < value.NumField(); i++ { | ||
fieldValue := value.Field(i) | ||
fieldType := valueType.Field(i) | ||
if v, ok := v.(encoding.BinaryUnmarshaler); ok { | ||
return v.UnmarshalBinary(b) | ||
} | ||
|
||
tag, hasTLV := fieldType.Tag.Lookup("tlv") | ||
if !hasTLV { | ||
return errors.New("field " + fieldType.Name + " need tag `tlv`") | ||
switch v := v.(type) { | ||
case *int8: | ||
*v, err = b.ReadInt8() | ||
case *int16: | ||
*v, err = b.ReadInt16() | ||
case *int32: | ||
*v, err = b.ReadInt32() | ||
case *int64: | ||
*v, err = b.ReadInt64() | ||
case *uint8: | ||
*v, err = b.ReadUint8() | ||
case *uint16: | ||
*v, err = b.ReadUint16() | ||
case *uint32: | ||
*v, err = b.ReadUint32() | ||
case *uint64: | ||
*v, err = b.ReadUint64() | ||
case *string: | ||
*v, err = b.ReadString(len(b)) | ||
case *[]byte: | ||
*v, err = b.ReadSlice(len(b)) | ||
default: | ||
val := reflect.ValueOf(v) | ||
if val.Kind() != reflect.Ptr { | ||
return errors.New("not a pointer") | ||
} | ||
val = val.Elem() | ||
typ := val.Type() | ||
switch val.Kind() { | ||
case reflect.Ptr: | ||
if val.IsNil() { | ||
val.Set(reflect.New(typ.Elem())) | ||
} | ||
|
||
tagVal, err := strconv.Atoi(tag) | ||
if err = decodeValue(b, val.Interface()); err != nil { | ||
return err | ||
} | ||
case reflect.Struct: | ||
tagMap, err := tagMapFor(typ) | ||
if err != nil { | ||
return fmt.Errorf("invalid tlv tag \"%s\", need to be decimal number", tag) | ||
return err | ||
} | ||
|
||
if len(tlvFragment[tagVal]) == 0 { | ||
continue | ||
} | ||
for len(b) > 0 { | ||
tag, bytes, err := readTLV(&b) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { | ||
fieldValue.Set(reflect.New(fieldValue.Type().Elem())) | ||
} else if fieldValue.Kind() == reflect.Slice && fieldValue.IsNil() { | ||
fieldValue.Set(reflect.MakeSlice(fieldValue.Type(), 0, 1)) | ||
} | ||
idx, ok := tagMap[tag] | ||
if !ok { | ||
// return fmt.Errorf("no field has `tlv:%d` tag", tag) | ||
continue // skip unexpected TLVs | ||
} | ||
|
||
f := val.Field(idx) | ||
|
||
for _, buf := range tlvFragment[tagVal] { | ||
if fieldValue.Kind() != reflect.Ptr { | ||
fieldValue = fieldValue.Addr() | ||
if f.Kind() == reflect.Ptr && f.IsNil() { | ||
f.Set(reflect.New(f.Type().Elem())) | ||
} | ||
if f.Kind() != reflect.Ptr { | ||
f = f.Addr() | ||
} | ||
err = decodeValue(buf, fieldValue.Interface()) | ||
|
||
err = decodeValue(bytes, f.Interface()) | ||
if err != nil { | ||
return err | ||
return fmt.Errorf("field %s %w", val.Type().Field(idx).Name, err) | ||
} | ||
} | ||
} | ||
case reflect.Slice: | ||
if value.IsNil() { | ||
value.Set(reflect.MakeSlice(value.Type(), 0, 1)) | ||
} | ||
if valueType.Elem().Kind() == reflect.Uint8 { | ||
value.SetBytes(b) | ||
} else if valueType.Elem().Kind() == reflect.Ptr || valueType.Elem().Kind() == reflect.Struct || | ||
isNumber(valueType.Elem()) { | ||
elemValue := reflect.New(valueType.Elem()) | ||
if err := decodeValue(b, elemValue.Interface()); err != nil { | ||
return err | ||
case reflect.Slice: | ||
if typ.Elem().Kind() == reflect.Ptr || typ.Elem().Kind() == reflect.Struct || isNumber(typ.Elem()) { | ||
elemValue := reflect.New(typ.Elem()) | ||
if err := decodeValue(b, elemValue.Interface()); err != nil { | ||
return err | ||
} | ||
val.Set(reflect.Append(val, elemValue.Elem())) | ||
} else { | ||
return errors.New("unsupported slice type") | ||
} | ||
value.Set(reflect.Append(value, elemValue.Elem())) | ||
} else { | ||
return errors.New("value type `Slice of " + valueType.String() + "` is not support decode") | ||
default: | ||
return errors.New("unsupported type") | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func parseTLV(b []byte) (fragments, error) { | ||
tlvFragment := make(fragments) | ||
buffer := bytes.NewBuffer(b) | ||
|
||
var tag uint16 | ||
func readTLV(b *bytez) (tag uint16, value bytez, err error) { | ||
if tag, err = b.ReadUint16(); err != nil { | ||
err = fmt.Errorf("%w @ tag", err) | ||
return | ||
} | ||
var length uint16 | ||
for { | ||
if err := binary.Read(buffer, binary.BigEndian, &tag); err != nil { | ||
fmt.Printf("Binary Read error: %v", err) | ||
} | ||
if err := binary.Read(buffer, binary.BigEndian, &length); err != nil { | ||
fmt.Printf("Binary Read error: %v", err) | ||
} | ||
value := make([]byte, length) | ||
if _, err := buffer.Read(value); err != nil { | ||
return nil, err | ||
} | ||
tlvFragment.Add(int(tag), value) | ||
if buffer.Len() == 0 { | ||
break | ||
if length, err = b.ReadUint16(); err != nil { | ||
err = fmt.Errorf("%w @ len", err) | ||
return | ||
} | ||
if value, err = b.ReadSlice(int(length)); err != nil { | ||
err = fmt.Errorf("%w @ value", err) | ||
return | ||
} | ||
return | ||
} | ||
|
||
type tagMap map[uint16]int | ||
|
||
// type -> tag -> field index | ||
var typeCache atomic.Pointer[map[reflect.Type]tagMap] | ||
|
||
func tagMapFor(typ reflect.Type) (tagMap, error) { | ||
// Cache handling is intentionally racy to keep the cost of access to a single atomic read | ||
// since the contents are strictly determenistic, it's okay if we rebuild them a few times | ||
|
||
tc := typeCache.Load() | ||
if tc == nil { | ||
tc = &map[reflect.Type]tagMap{} | ||
typeCache.Store(tc) | ||
} | ||
|
||
t := (*tc)[typ] | ||
if t == nil { | ||
t = make(tagMap) | ||
for i := 0; i < typ.NumField(); i++ { | ||
field := typ.Field(i) | ||
if !field.IsExported() { | ||
return nil, fmt.Errorf("field %s is not exported", field.Name) | ||
} else if tlv, ok := field.Tag.Lookup("tlv"); !ok { | ||
return nil, fmt.Errorf("field %s has no `tlv:...` tag", field.Name) | ||
} else if tlvI, err := strconv.Atoi(tlv); err != nil || tlvI < 0 || tlvI > math.MaxUint16 { | ||
return nil, fmt.Errorf("field %s has invalid `tlv:%s` tag", field.Name, tlv) | ||
} else { | ||
t[uint16(tlvI)] = i | ||
} | ||
} | ||
tc = ref(maps.Clone(*tc)) | ||
(*tc)[typ] = t | ||
typeCache.Store(tc) | ||
} | ||
return tlvFragment, nil | ||
|
||
return t, nil | ||
} |
Oops, something went wrong.