From 7d5f73b8c589110108c06b677582337ad4fc7b72 Mon Sep 17 00:00:00 2001 From: linxGnu Date: Fri, 1 Dec 2023 00:05:35 +0900 Subject: [PATCH 1/3] =?UTF-8?q?Fix=20(*WriteBatchIterator).decodeVarint:?= =?UTF-8?q?=20doesn=E2=80=99t=20correctly=20detect=20invalid=20varint=20se?= =?UTF-8?q?quences?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- write_batch.go | 19 +++++-------------- write_batch_test.go | 7 +++++++ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/write_batch.go b/write_batch.go index 7988628..d7f0a2a 100644 --- a/write_batch.go +++ b/write_batch.go @@ -4,6 +4,7 @@ package grocksdb import "C" import ( + "encoding/binary" "errors" "io" ) @@ -342,21 +343,11 @@ func (iter *WriteBatchIterator) decodeRecType() WriteBatchRecordType { } func (iter *WriteBatchIterator) decodeVarint() uint64 { - var n int - var x uint64 - for shift := uint(0); shift < 64 && n < len(iter.data); shift += 7 { - b := uint64(iter.data[n]) - n++ - x |= (b & 0x7F) << shift - if (b & 0x80) == 0 { - iter.data = iter.data[n:] - return x - } - } - if n == len(iter.data) { + v, n := binary.Uvarint(iter.data) + if n == 0 { iter.err = io.ErrShortBuffer - } else { + } else if n < 0 { iter.err = errors.New("malformed varint") } - return 0 + return v } diff --git a/write_batch_test.go b/write_batch_test.go index 49b3a8b..f2f7ad8 100644 --- a/write_batch_test.go +++ b/write_batch_test.go @@ -89,3 +89,10 @@ func TestWriteBatchIterator(t *testing.T) { // there shouldn't be any left require.False(t, iter.Next()) } + +func TestDecodeVarint(t *testing.T) { + t.Parallel() + + wbi := &WriteBatchIterator{data: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}} + require.EqualValues(t, 0, wbi.decodeVarint()) +} From c5cb56a13d43ea7e6efec72449b35c46e5b38464 Mon Sep 17 00:00:00 2001 From: linxGnu Date: Fri, 1 Dec 2023 00:16:54 +0900 Subject: [PATCH 2/3] Fix --- write_batch.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/write_batch.go b/write_batch.go index d7f0a2a..56fa5cc 100644 --- a/write_batch.go +++ b/write_batch.go @@ -344,9 +344,11 @@ func (iter *WriteBatchIterator) decodeRecType() WriteBatchRecordType { func (iter *WriteBatchIterator) decodeVarint() uint64 { v, n := binary.Uvarint(iter.data) - if n == 0 { + if n > 0 { + iter.data = iter.data[n:] + } else if n == 0 { iter.err = io.ErrShortBuffer - } else if n < 0 { + } else { iter.err = errors.New("malformed varint") } return v From 7432cdeba3816a88082a27337d5fcde478f5fad2 Mon Sep 17 00:00:00 2001 From: linxGnu Date: Fri, 1 Dec 2023 00:33:23 +0900 Subject: [PATCH 3/3] Update tests --- write_batch_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/write_batch_test.go b/write_batch_test.go index f2f7ad8..9384b6c 100644 --- a/write_batch_test.go +++ b/write_batch_test.go @@ -1,6 +1,7 @@ package grocksdb import ( + "math" "testing" "github.com/stretchr/testify/require" @@ -90,9 +91,55 @@ func TestWriteBatchIterator(t *testing.T) { require.False(t, iter.Next()) } -func TestDecodeVarint(t *testing.T) { +func TestDecodeVarint_ISSUE131(t *testing.T) { t.Parallel() - wbi := &WriteBatchIterator{data: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}} - require.EqualValues(t, 0, wbi.decodeVarint()) + tests := []struct { + name string + in []byte + wantValue uint64 + expectErr bool + }{ + { + name: "invalid: 10th byte", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, + wantValue: 0, + expectErr: true, + }, + { + name: "valid: math.MaxUint64-40", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01}, + wantValue: math.MaxUint64 - 40, + expectErr: false, + }, + { + name: "invalid: with more than MaxVarintLen64 bytes", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01}, + wantValue: 0, + expectErr: true, + }, + { + name: "invalid: 1000 bytes", + in: func() []byte { + b := make([]byte, 1000) + for i := range b { + b[i] = 0xff + } + b[999] = 0 + return b + }(), + wantValue: 0, + expectErr: true, + }, + } + + for _, test := range tests { + wbi := &WriteBatchIterator{data: test.in} + require.EqualValues(t, test.wantValue, wbi.decodeVarint(), test.name) + if test.expectErr { + require.Error(t, wbi.err, test.name) + } else { + require.NoError(t, wbi.err, test.name) + } + } }