From 23acbb6c536b9c73d4a98c51651d94b716ab989d Mon Sep 17 00:00:00 2001 From: Linh Tran Tuan Date: Fri, 1 Dec 2023 00:37:41 +0900 Subject: [PATCH] =?UTF-8?q?Fix=20WriteBatchIterator:=20doesn=E2=80=99t=20c?= =?UTF-8?q?orrectly=20detect=20invalid=20sequences=20(#132)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- write_batch.go | 19 +++++----------- write_batch_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/write_batch.go b/write_batch.go index 7988628..56fa5cc 100644 --- a/write_batch.go +++ b/write_batch.go @@ -4,6 +4,7 @@ package grocksdb import "C" import ( + "encoding/binary" "errors" "io" ) @@ -342,21 +343,13 @@ func (iter *WriteBatchIterator) decodeRecType() WriteBatchRecordType { } func (iter *WriteBatchIterator) decodeVarint() uint64 { - var n int - var x uint64 - for shift := uint(0); shift < 64 && n < len(iter.data); shift += 7 { - b := uint64(iter.data[n]) - n++ - x |= (b & 0x7F) << shift - if (b & 0x80) == 0 { - iter.data = iter.data[n:] - return x - } - } - if n == len(iter.data) { + v, n := binary.Uvarint(iter.data) + if n > 0 { + iter.data = iter.data[n:] + } else if n == 0 { iter.err = io.ErrShortBuffer } else { iter.err = errors.New("malformed varint") } - return 0 + return v } diff --git a/write_batch_test.go b/write_batch_test.go index 49b3a8b..9384b6c 100644 --- a/write_batch_test.go +++ b/write_batch_test.go @@ -1,6 +1,7 @@ package grocksdb import ( + "math" "testing" "github.com/stretchr/testify/require" @@ -89,3 +90,56 @@ func TestWriteBatchIterator(t *testing.T) { // there shouldn't be any left require.False(t, iter.Next()) } + +func TestDecodeVarint_ISSUE131(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in []byte + wantValue uint64 + expectErr bool + }{ + { + name: "invalid: 10th byte", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, + wantValue: 0, + expectErr: true, + }, + { + name: "valid: math.MaxUint64-40", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01}, + wantValue: math.MaxUint64 - 40, + expectErr: false, + }, + { + name: "invalid: with more than MaxVarintLen64 bytes", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01}, + wantValue: 0, + expectErr: true, + }, + { + name: "invalid: 1000 bytes", + in: func() []byte { + b := make([]byte, 1000) + for i := range b { + b[i] = 0xff + } + b[999] = 0 + return b + }(), + wantValue: 0, + expectErr: true, + }, + } + + for _, test := range tests { + wbi := &WriteBatchIterator{data: test.in} + require.EqualValues(t, test.wantValue, wbi.decodeVarint(), test.name) + if test.expectErr { + require.Error(t, wbi.err, test.name) + } else { + require.NoError(t, wbi.err, test.name) + } + } +}