Skip to content

Commit

Permalink
types, export major type, write null
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-richards committed Oct 23, 2024
1 parent 8d14d64 commit 1594b7e
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 47 deletions.
46 changes: 23 additions & 23 deletions cbor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,45 @@ type MajorType byte

const (
MajorTypeUInt MajorType = 0 << 5
MajorTypeNInt = 1 << 5
MajorTypeBstr = 2 << 5
MajorTypeTstr = 3 << 5
MajorTypeArray = 4 << 5
MajorTypeMap = 5 << 5
MajorTypeTagged = 6 << 5
MajorTypeSimpleFloat = 7 << 5
MajorTypeNInt MajorType = 1 << 5
MajorTypeBstr MajorType = 2 << 5
MajorTypeTstr MajorType = 3 << 5
MajorTypeArray MajorType = 4 << 5
MajorTypeMap MajorType = 5 << 5
MajorTypeTagged MajorType = 6 << 5
MajorTypeSimpleFloat MajorType = 7 << 5
)

const (
majorTypeMask = 0b111_00000
MajorTypeMask byte = 0b111_00000
)

type Arg byte

const (
Arg8 Arg = 24
Arg16 = 25
Arg32 = 26
Arg64 = 27
Arg16 Arg = 25
Arg32 Arg = 26
Arg64 Arg = 27
// 28..30 reserved
ArgIndefinite = 31
ArgIndefinite Arg = 31
)

const (
argMask = 0b000_11111
argMask byte = 0b000_11111
)

const (
SimpleFalse byte = 20
SimpleTrue = 21
SimpleNull = 22
SimpleUndefined = 23
SimpleUint8 = Arg8 // 24
SimpleFloat16 = Arg16 // 25
SimpleFloat32 = Arg32 // 26
SimpleFloat64 = Arg64 // 27
SimpleFalse Arg = 20
SimpleTrue Arg = 21
SimpleNull Arg = 22
SimpleUndefined Arg = 23
SimpleUint8 = Arg8 // 24
SimpleFloat16 = Arg16 // 25
SimpleFloat32 = Arg32 // 26
SimpleFloat64 = Arg64 // 27
// 28..30 reserved
SimpleBreak = 31
SimpleBreak Arg = 31
)

var (
Expand All @@ -56,7 +56,7 @@ var (
)

const (
valueBreak = MajorTypeSimpleFloat | SimpleBreak
valueBreak = byte(MajorTypeSimpleFloat) | byte(SimpleBreak)
)

const lenSharedBuffer = 64
Expand Down
2 changes: 1 addition & 1 deletion io.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import "io"
type peekReader struct {
r io.Reader // wrapped reader
p byte // peeked byte
pv bool // peeded valid
pv bool // peeked valid
}

func (r *peekReader) Read(out []byte) (int, error) {
Expand Down
37 changes: 37 additions & 0 deletions malicious_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package cbor

import (
"bytes"
"io"
"testing"
)

var malicious = []byte{0x9B, 0x00, 0x00, 0x42, 0xFA, 0x42, 0xFA, 0x42, 0xFA, 0x42}

func Test_Malicious(t *testing.T) {
t.Run("ReadAny", func(t *testing.T) {
in := bytes.NewBuffer(malicious)
v, err := ReadAny(in)
if v != nil {
t.Fatal("unexpected value")
}
if err != io.EOF {
t.Fatal("expected EOF")
}
})
t.Run("ReadOver", func(t *testing.T) {
in := bytes.NewBuffer(malicious)
err := ReadOver(in)
if err != io.EOF {
t.Fatal("expected EOF")
}
})
t.Run("ReadRaw", func(t *testing.T) {
in := bytes.NewBuffer(malicious)
out := bytes.NewBuffer(nil)
err := ReadRaw(in, out)
if err != io.EOF {
t.Fatal("expected EOF")
}
})
}
16 changes: 8 additions & 8 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func readMajorType(in io.Reader) (MajorType, Arg, uint64, error) {
// decodePrefix returns the major type, argument, and the length, in bytes, of
// the remaining part of the header if any.
func decodePrefix(p byte) (MajorType, Arg) {
majorType := MajorType(p & majorTypeMask)
majorType := MajorType(p & MajorTypeMask)
arg := Arg(p & argMask)
return majorType, arg
}
Expand Down Expand Up @@ -219,11 +219,11 @@ func ReadBool(in io.Reader) (bool, error) {

func readBool(majorType MajorType, value uint64) (bool, error) {
if majorType == MajorTypeSimpleFloat {
if byte(value) == SimpleFalse {
if Arg(value) == SimpleFalse {
return false, nil
}

if byte(value) == SimpleTrue {
if Arg(value) == SimpleTrue {
return true, nil
}

Expand Down Expand Up @@ -329,7 +329,7 @@ func readByteChunks(
func ReadArray(
in io.Reader,
readLength func(indefinite bool, length uint64) error,
readItem func(in io.Reader) error,
readItem func(i uint64, in io.Reader) error,
) error {
majorType, arg, value, err := readMajorType(in)
if err != nil {
Expand All @@ -345,7 +345,7 @@ func readArray(
arg Arg,
value uint64,
readLength func(indefinite bool, length uint64) error,
readItem func(in io.Reader) error,
readItem func(i uint64, in io.Reader) error,
) error {
if majorType != MajorTypeArray {
return ErrUnsupportedMajorType
Expand All @@ -360,7 +360,7 @@ func readArray(

if indefinite {
pin := &peekReader{r: in}
for {
for i := uint64(0); ; i++ {
r, err := pin.PeekByte()
if err != nil {
return err
Expand All @@ -369,14 +369,14 @@ func readArray(
break
}

err = readItem(pin)
err = readItem(i, pin)
if err != nil {
return err
}
}
} else {
for i := uint64(0); i < value; i++ {
err = readItem(in)
err = readItem(i, in)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions read_any.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func readAny(in io.Reader, majorType MajorType, arg Arg, value uint64) (any, err
return string(b.Bytes()), nil

case MajorTypeArray:
a := make([]any, value)
a := make([]any, 0, min(value, 16)) // limit pre allocate
if arg == ArgIndefinite {
for {
majorType, arg, value, err := readMajorType(in)
Expand All @@ -103,13 +103,13 @@ func readAny(in io.Reader, majorType MajorType, arg Arg, value uint64) (any, err
if err != nil {
return nil, err
}
a[i] = v
a = append(a, v)
}
}
return a, nil

case MajorTypeMap:
m := make(map[any]any, value)
m := make(map[any]any, min(value, 16)) // limit pre allocate
if arg == ArgIndefinite {
for {
majorType, arg, value, err := readMajorType(in)
Expand Down Expand Up @@ -151,11 +151,11 @@ func readAny(in io.Reader, majorType MajorType, arg Arg, value uint64) (any, err

default: // MajorTypeSimpleFloat:
switch {
case arg == 0 && value < uint64(SimpleFalse):
case arg == 0 && Arg(value) < SimpleFalse:
return readUnsigned[uint8](majorType, arg, value)
case arg == 0 && (value == uint64(SimpleFalse) || value == SimpleTrue):
case arg == 0 && (Arg(value) == SimpleFalse || Arg(value) == SimpleTrue):
return readBool(majorType, value)
case arg == 0 && (value == SimpleNull || value == SimpleUndefined):
case arg == 0 && (Arg(value) == SimpleNull || Arg(value) == SimpleUndefined):
return nil, nil
case arg == SimpleUint8:
return readUnsigned[uint8](majorType, arg, value)
Expand Down
15 changes: 12 additions & 3 deletions read_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ func ReadRaw(
}

if b == valueBreak {
out.Write([]byte{b})
_, err = out.Write([]byte{b})
if err != nil {
return err
}
break
}

Expand All @@ -74,7 +77,10 @@ func ReadRaw(
}

if b == valueBreak {
out.Write([]byte{b})
_, err = out.Write([]byte{b})
if err != nil {
return err
}
break
}

Expand Down Expand Up @@ -104,7 +110,10 @@ func ReadRaw(
}

if b == valueBreak {
out.Write([]byte{b})
_, err = out.Write([]byte{b})
if err != nil {
return err
}
break
}

Expand Down
4 changes: 2 additions & 2 deletions read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ func Test_ReadBool(t *testing.T) {
want = false
case encoded[0] == 0xf5:
want = true
case encoded[0]&majorTypeMask == MajorTypeSimpleFloat:
case encoded[0]&MajorTypeMask == byte(MajorTypeSimpleFloat):
wantErr = ErrUnsupportedValue
default:
wantErr = ErrUnsupportedMajorType
Expand Down Expand Up @@ -633,7 +633,7 @@ func Test_ReadArray(t *testing.T) {
out = make([]int32, 0, length)
return nil
},
func(in io.Reader) error {
func(i uint64, in io.Reader) error {
v, err := ReadSigned[int32](in)
if err != nil {
return err
Expand Down
16 changes: 12 additions & 4 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func WriteSigned[T int8 | int16 | int32 | int64](out io.Writer, value T) (int, e
func WriteFloat[T float16.Float16 | float32 | float64](out io.Writer, value T) (int, error) {
switch v := any(value).(type) {
case float16.Float16:
sharedBuffer[0] = MajorTypeSimpleFloat | SimpleFloat16
sharedBuffer[0] = byte(MajorTypeSimpleFloat) | byte(SimpleFloat16)
shiftBytesFrom(uint16(v), sharedBuffer[1:3])
return out.Write(sharedBuffer[0:3])

Expand All @@ -57,7 +57,7 @@ func WriteFloat[T float16.Float16 | float32 | float64](out io.Writer, value T) (
return WriteFloat(out, float16.Fromfloat32(v))
}

sharedBuffer[0] = MajorTypeSimpleFloat | SimpleFloat32
sharedBuffer[0] = byte(MajorTypeSimpleFloat) | byte(SimpleFloat32)
shiftBytesFrom(math.Float32bits(v), sharedBuffer[1:5])
return out.Write(sharedBuffer[0:5])

Expand All @@ -68,7 +68,7 @@ func WriteFloat[T float16.Float16 | float32 | float64](out io.Writer, value T) (
return WriteFloat(out, v32)
}

sharedBuffer[0] = MajorTypeSimpleFloat | SimpleFloat64
sharedBuffer[0] = byte(MajorTypeSimpleFloat) | byte(SimpleFloat64)
shiftBytesFrom(math.Float64bits(v), sharedBuffer[1:9])
return out.Write(sharedBuffer[0:9])

Expand All @@ -79,11 +79,19 @@ func WriteFloat[T float16.Float16 | float32 | float64](out io.Writer, value T) (

func WriteBool(out io.Writer, value bool) (int, error) {
if value {
return writeMajorType(out, MajorTypeSimpleFloat, SimpleTrue)
return writeMajorType(out, MajorTypeSimpleFloat, uint64(SimpleTrue))
}
return writeMajorType(out, MajorTypeSimpleFloat, uint64(SimpleFalse))
}

func WriteNull(out io.Writer) (int, error) {
return writeMajorType(out, MajorTypeSimpleFloat, uint64(SimpleNull))
}

func WriteBreak(out io.Writer) (int, error) {
return writeMajorType(out, MajorTypeSimpleFloat, uint64(SimpleBreak))
}

func WriteTag(out io.Writer, value uint64) (int, error) {
return writeMajorType(out, MajorTypeTagged, value)
}
Expand Down

0 comments on commit 1594b7e

Please sign in to comment.