Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recursion limit for dynamic code #358

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions msgp/defs.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ const (
last5 = 0x1f
first3 = 0xe0
last7 = 0x7f

// recursionLimit is the limit of recursive calls.
// This limits the call depth of dynamic code, like Skip and interface conversions.
recursionLimit = 100000
)

func isfixint(b byte) bool {
Expand Down
9 changes: 9 additions & 0 deletions msgp/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ var (
// contain the contents of the message
ErrShortBytes error = errShort{}

// ErrRecursion is returned when the maximum recursion limit is reached for an operation.
// This should only realistically be seen on adversarial data trying to exhaust the stack.
ErrRecursion error = errRecursion{}

// this error is only returned
// if we reach code that should
// be unreachable
Expand Down Expand Up @@ -134,6 +138,11 @@ func (f errFatal) Resumable() bool { return false }

func (f errFatal) withContext(ctx string) error { f.ctx = addCtx(f.ctx, ctx); return f }

type errRecursion struct{}

func (e errRecursion) Error() string { return "msgp: recursion limit reached" }
func (e errRecursion) Resumable() bool { return false }

// ArrayError is an error returned
// when decoding a fix-sized array
// of the wrong size
Expand Down
14 changes: 14 additions & 0 deletions msgp/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ func rwMap(dst jsWriter, src *Reader) (n int, err error) {
return dst.WriteString("{}")
}

// This is potentially a recursive call.
if done, err := src.recursiveCall(); err != nil {
return 0, err
} else {
defer done()
}

err = dst.WriteByte('{')
if err != nil {
return
Expand Down Expand Up @@ -162,6 +169,13 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) {
if err != nil {
return
}
// This is potentially a recursive call.
if done, err := src.recursiveCall(); err != nil {
return 0, err
} else {
defer done()
}

var sz uint32
var nn int
sz, err = src.ReadArrayHeader()
Expand Down
52 changes: 29 additions & 23 deletions msgp/json_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
"time"
)

var unfuns [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error)
var unfuns [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error)

func init() {
// NOTE(pmh): this is best expressed as a jump table,
// but gc doesn't do that yet. revisit post-go1.5.
unfuns = [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error){
unfuns = [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error){
StrType: rwStringBytes,
BinType: rwBytesBytes,
MapType: rwMapBytes,
Expand Down Expand Up @@ -51,15 +51,15 @@ func UnmarshalAsJSON(w io.Writer, msg []byte) ([]byte, error) {
dst = bufio.NewWriterSize(w, 512)
}
for len(msg) > 0 && err == nil {
msg, scratch, err = writeNext(dst, msg, scratch)
msg, scratch, err = writeNext(dst, msg, scratch, 0)
}
if !cast && err == nil {
err = dst.(*bufio.Writer).Flush()
}
return msg, err
}

func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func writeNext(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
if len(msg) < 1 {
return msg, scratch, ErrShortBytes
}
Expand All @@ -76,10 +76,13 @@ func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
t = TimeType
}
}
return unfuns[t](w, msg, scratch)
return unfuns[t](w, msg, scratch, depth)
}

func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwArrayBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
if depth >= recursionLimit {
return msg, scratch, ErrRecursion
}
sz, msg, err := ReadArrayHeaderBytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -95,7 +98,7 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
return msg, scratch, err
}
}
msg, scratch, err = writeNext(w, msg, scratch)
msg, scratch, err = writeNext(w, msg, scratch, depth+1)
if err != nil {
return msg, scratch, err
}
Expand All @@ -104,7 +107,10 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
return msg, scratch, err
}

func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwMapBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
if depth >= recursionLimit {
return msg, scratch, ErrRecursion
}
sz, msg, err := ReadMapHeaderBytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -120,15 +126,15 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}
}
msg, scratch, err = rwMapKeyBytes(w, msg, scratch)
msg, scratch, err = rwMapKeyBytes(w, msg, scratch, depth)
if err != nil {
return msg, scratch, err
}
err = w.WriteByte(':')
if err != nil {
return msg, scratch, err
}
msg, scratch, err = writeNext(w, msg, scratch)
msg, scratch, err = writeNext(w, msg, scratch, depth+1)
if err != nil {
return msg, scratch, err
}
Expand All @@ -137,17 +143,17 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
msg, scratch, err := rwStringBytes(w, msg, scratch)
func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
msg, scratch, err := rwStringBytes(w, msg, scratch, depth)
if err != nil {
if tperr, ok := err.(TypeError); ok && tperr.Encoded == BinType {
return rwBytesBytes(w, msg, scratch)
return rwBytesBytes(w, msg, scratch, depth)
}
}
return msg, scratch, err
}

func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwStringBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
str, msg, err := ReadStringZC(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -156,7 +162,7 @@ func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, erro
return msg, scratch, err
}

func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwBytesBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
bts, msg, err := ReadBytesZC(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -180,7 +186,7 @@ func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error
return msg, scratch, err
}

func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwNullBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
msg, err := ReadNilBytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -189,7 +195,7 @@ func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwBoolBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
b, msg, err := ReadBoolBytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -202,7 +208,7 @@ func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwIntBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
i, msg, err := ReadInt64Bytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -212,7 +218,7 @@ func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwUintBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
u, msg, err := ReadUint64Bytes(msg)
if err != nil {
return msg, scratch, err
Expand All @@ -222,7 +228,7 @@ func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
var f float32
var err error
f, msg, err = ReadFloat32Bytes(msg)
Expand All @@ -234,7 +240,7 @@ func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err
return msg, scratch, err
}

func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
var f float64
var err error
f, msg, err = ReadFloat64Bytes(msg)
Expand All @@ -246,7 +252,7 @@ func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err
return msg, scratch, err
}

func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwTimeBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
var t time.Time
var err error
t, msg, err = ReadTimeBytes(msg)
Expand All @@ -261,7 +267,7 @@ func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error)
return msg, scratch, err
}

func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) {
func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) {
var err error
var et int8
et, err = peekExtension(msg)
Expand Down
43 changes: 40 additions & 3 deletions msgp/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ type Reader struct {
// is stateless; all the
// buffering is done
// within R.
R *fwd.Reader
scratch []byte
R *fwd.Reader
scratch []byte
recursionDepth int
}

// Read implements `io.Reader`
Expand Down Expand Up @@ -190,6 +191,11 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) {
return n, io.ErrShortWrite
}

if done, err := m.recursiveCall(); err != nil {
return n, err
} else {
defer done()
}
// for maps and slices, read elements
for x := uintptr(0); x < o; x++ {
var n2 int64
Expand All @@ -202,6 +208,18 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) {
return n, nil
}

// recursiveCall will increment the recursion depth and return an error if it is exceeded.
// If a nil error is returned, done must be called to decrement the counter.
func (m *Reader) recursiveCall() (done func(), err error) {
if m.recursionDepth >= recursionLimit {
return func() {}, ErrRecursion
}
m.recursionDepth++
return func() {
m.recursionDepth--
}, nil
}

// ReadFull implements `io.ReadFull`
func (m *Reader) ReadFull(p []byte) (int, error) {
return m.R.ReadFull(p)
Expand Down Expand Up @@ -332,7 +350,12 @@ func (m *Reader) Skip() error {
return err
}

// for maps and slices, skip elements
// for maps and slices, skip elements with recursive call
if done, err := m.recursiveCall(); err != nil {
return err
} else {
defer done()
}
for x := uintptr(0); x < o; x++ {
err = m.Skip()
if err != nil {
Expand Down Expand Up @@ -1333,6 +1356,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) {
return

case MapType:
// This can call back here, so treat as recursive call.
if done, err := m.recursiveCall(); err != nil {
return nil, err
} else {
defer done()
}

mp := make(map[string]interface{})
err = m.ReadMapStrIntf(mp)
i = mp
Expand All @@ -1358,6 +1388,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) {
if err != nil {
return
}

if done, err := m.recursiveCall(); err != nil {
return nil, err
} else {
defer done()
}

out := make([]interface{}, int(sz))
for j := range out {
out[j], err = m.ReadIntf()
Expand Down
Loading
Loading