diff --git a/buffer.go b/buffer.go
index 0012655..6527114 100644
--- a/buffer.go
+++ b/buffer.go
@@ -1,6 +1,7 @@
 package parquet
 
 import (
+	"math/bits"
 	"sort"
 	"sync"
 	"sync/atomic"
@@ -306,45 +307,47 @@ func (b *buffer) unref() {
 	}
 }
 
-func (b *buffer) reset() {
-	b.data = b.data[:0]
-}
-
-func (b *buffer) resize(size int) {
-	if cap(b.data) < size {
-		const pageSize = 4096
-		minSize := 2 * cap(b.data)
-		bufferSize := ((size + (pageSize - 1)) / pageSize) * pageSize
-		if bufferSize < minSize {
-			bufferSize = minSize
-		}
-		b.data = make([]byte, size, bufferSize)
-	} else {
-		b.data = b.data[:size]
-	}
-}
-
-func (b *buffer) clone() (clone *buffer) {
-	if b.pool != nil {
-		clone = b.pool.get()
-	} else {
-		clone = &buffer{refc: 1}
-	}
-	clone.data = append(clone.data, b.data...)
-	return clone
-}
+// bufferPool holds an array of sync.Pools used for levelled buffering.
+// The table below shows the pools used for different buffer sizes when both getting
+// and putting a buffer. When allocating a new buffer from a given pool we always choose the
+// min of the put range to guarantee that all gets will have an adequately sized buffer.
+//
+// [pool] : get range     : put range     : alloc size
+// [0]    : 0 -> 1023     : 1024 -> 2047  : 1024
+// [1]    : 1024 -> 2047  : 2048 -> 4095  : 2048
+// [2]    : 2048 -> 4095  : 4096 -> 8191  : 4096
+// ...
+const numPoolBuckets = 16
+const basePoolIncrement = 1024
 
 type bufferPool struct {
-	pool sync.Pool
+	pool [numPoolBuckets]sync.Pool
 }
 
-func (p *bufferPool) get() *buffer {
-	b, _ := p.pool.Get().(*buffer)
+// get returns a buffer from the levelled buffer pool. sz is used to choose the appropriate pool.
+func (p *bufferPool) get(sz int) *buffer {
+	i := levelledPoolIndex(sz)
+	b, _ := p.pool[i].Get().(*buffer)
 	if b == nil {
-		b = &buffer{pool: p}
-	} else {
-		b.reset()
+		// align the allocation to the pool's buffer size
+		poolSize := basePoolIncrement << i
+		if sz > poolSize { // this can occur when the requested buffer is larger than the largest pool
+			poolSize = sz
+		}
+		b = &buffer{
+			data: make([]byte, 0, poolSize),
+			pool: p,
+		}
 	}
+	// if the buffer comes from the largest pool it may not be big enough
+	if cap(b.data) < sz {
+		p.pool[i].Put(b)
+		b = &buffer{
+			data: make([]byte, 0, sz),
+			pool: p,
+		}
+	}
+	b.data = b.data[:sz]
 	b.ref()
 	return b
 }
@@ -353,7 +356,27 @@ func (p *bufferPool) put(b *buffer) {
 	if b.pool != p {
 		panic("BUG: buffer returned to a different pool than the one it was allocated from")
 	}
-	p.pool.Put(b)
+	// if this slice is somehow less than our min pool size, just drop it
+	sz := cap(b.data)
+	if sz < basePoolIncrement {
+		return
+	}
+	i := levelledPoolIndex(sz / 2) // divide by 2 to put the buffer in the level below so it will always be large enough
+	p.pool[i].Put(b)
+}
+
+// levelledPoolIndex returns the index of the pool to use for a buffer of size sz.
+// It never returns an index that is out of range for the pool array.
+func levelledPoolIndex(sz int) int {
+	i := sz / basePoolIncrement
+	i = 32 - bits.LeadingZeros32(uint32(i)) // log2
+	if i >= numPoolBuckets {
+		i = numPoolBuckets - 1
+	}
+	if i < 0 {
+		i = 0
+	}
+	return i
 }
 
 var (
diff --git a/buffer_internal_test.go b/buffer_internal_test.go
new file mode 100644
index 0000000..e2a0f42
--- /dev/null
+++ b/buffer_internal_test.go
@@ -0,0 +1,57 @@
+package parquet
+
+import (
+	"math"
+	"math/rand"
+	"testing"
+)
+
+func TestBufferAlwaysCorrectSize(t *testing.T) {
+	var p bufferPool
+	for i := 0; i < 1000; i++ {
+		sz := rand.Intn(1024 * 1024)
+		buff := p.get(sz)
+		if len(buff.data) != sz {
+			t.Errorf("Expected buffer of size %d, got %d", sz, len(buff.data))
+		}
+		p.put(buff)
+	}
+}
+
+func TestLevelledPoolIndex(t *testing.T) {
+	tcs := []struct {
+		sz       int
+		expected int
+	}{
+		{
+			sz:       1023,
+			expected: 0,
+		},
+		{
+			sz:       1024,
+			expected: 1,
+		},
+		{
+			sz:       -1,
+			expected: 0,
+		},
+		{
+			sz:       16*1024*1024 - 1,
+			expected: 14,
+		},
+		{
+			sz:       16 * 1024 * 1024,
+			expected: 15,
+		},
+		{
+			sz:       math.MaxInt,
+			expected: 15,
+		},
+	}
+
+	for _, tc := range tcs {
+		if actual := levelledPoolIndex(tc.sz); actual != tc.expected {
+			t.Errorf("Expected index %d for size %d, got %d", tc.expected, tc.sz, actual)
+		}
+	}
+}
diff --git a/column.go b/column.go
index 41b1fd0..74c09ea 100644
--- a/column.go
+++ b/column.go
@@ -493,10 +493,7 @@ func schemaRepetitionTypeOf(s *format.SchemaElement) format.FieldRepetitionType
 }
 
 func (c *Column) decompress(compressedPageData []byte, uncompressedPageSize int32) (page *buffer, err error) {
-	page = uncompressedPageBufferPool.get()
-	if uncompressedPageSize > 0 {
-		page.resize(int(uncompressedPageSize))
-	}
+	page = uncompressedPageBufferPool.get(int(uncompressedPageSize))
 	page.data, err = c.compression.Decode(page.data, compressedPageData)
 	if err != nil {
 		page.unref()
@@ -633,15 +630,13 @@ func (c *Column) decodeDataPage(header DataPageHeader, numValues int, repetition
 	pageKind := pageType.Kind()
 
 	if pageKind >= 0 && int(pageKind) < len(pageValuesBufferPool) {
-		vbuf = pageValuesBufferPool[pageKind].get()
+		vbuf = pageValuesBufferPool[pageKind].get(int(pageType.EstimateSize(numValues)))
 		defer vbuf.unref()
-		vbuf.resize(int(pageType.EstimateSize(numValues)))
 		pageValues = vbuf.data
 	}
 
 	if pageKind == ByteArray {
-		obuf = pageOffsetsBufferPool.get()
+		obuf = pageOffsetsBufferPool.get(4 * (numValues + 1))
 		defer obuf.unref()
-		obuf.resize(4 * (numValues + 1))
 		pageOffsets = unsafecast.BytesToUint32(obuf.data)
 	}
@@ -712,7 +707,7 @@ func decodeLevelsV2(enc encoding.Encoding, numValues int, data []byte, length in
 }
 
 func decodeLevels(enc encoding.Encoding, numValues int, data []byte) (levels *buffer, err error) {
-	levels = levelsBufferPool.get()
+	levels = levelsBufferPool.get(numValues)
 	levels.data, err = enc.DecodeLevels(levels.data, data)
 	if err != nil {
 		levels.unref()
diff --git a/file.go b/file.go
index f19a814..7a47570 100644
--- a/file.go
+++ b/file.go
@@ -567,11 +567,9 @@ func (f *filePages) readDictionary() error {
 		return err
 	}
 
-	page := compressedPageBufferPool.get()
+	page := compressedPageBufferPool.get(int(header.CompressedPageSize))
 	defer page.unref()
 
-	page.resize(int(header.CompressedPageSize))
-
 	if _, err := io.ReadFull(rbuf, page.data); err != nil {
 		return err
 	}
@@ -619,11 +617,9 @@ func (f *filePages) readDataPageV2(header *format.PageHeader, page *buffer) (Pag
 }
 
 func (f *filePages) readPage(header *format.PageHeader, reader *bufio.Reader) (*buffer, error) {
-	page := compressedPageBufferPool.get()
+	page := compressedPageBufferPool.get(int(header.CompressedPageSize))
 	defer page.unref()
 
-	page.resize(int(header.CompressedPageSize))
-
 	if _, err := io.ReadFull(reader, page.data); err != nil {
 		return nil, err
 	}
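
The put path is what makes the bucketing sound: a buffer of capacity c goes into pool[levelledPoolIndex(c/2)], one level below the pool that get(c) would consult, so anything stored in pool[i] has capacity of at least basePoolIncrement<<i, which is no smaller than any request routed to pool[i] (for every bucket but the last). Below is a standalone sketch of that invariant; the constants and levelledPoolIndex are copied from the patch, while the main driver and its printed table are illustrative only, not part of the change.

package main

import (
	"fmt"
	"math/bits"
)

// Copied from the patch above.
const numPoolBuckets = 16
const basePoolIncrement = 1024

func levelledPoolIndex(sz int) int {
	i := sz / basePoolIncrement
	i = 32 - bits.LeadingZeros32(uint32(i)) // log2
	if i >= numPoolBuckets {
		i = numPoolBuckets - 1
	}
	if i < 0 {
		i = 0
	}
	return i
}

func main() {
	// put drops capacities below basePoolIncrement, so only sizes >= 1024
	// are interesting here.
	for _, sz := range []int{1024, 2047, 2048, 4095, 4096, 8191} {
		getIdx := levelledPoolIndex(sz)     // pool consulted by get(sz)
		putIdx := levelledPoolIndex(sz / 2) // pool a capacity-sz buffer is returned to
		fmt.Printf("size %4d: get -> pool[%d], put -> pool[%d]\n", sz, getIdx, putIdx)
	}
}

Running this reproduces the get and put columns of the comment table in buffer.go. The cap(b.data) < sz re-check in get is still needed, but only because of the top bucket: pool[numPoolBuckets-1] absorbs every oversized buffer and therefore holds mixed capacities.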