-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
extract the buffer logic into its own package
- Loading branch information
Showing
11 changed files
with
763 additions
and
264 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
package buffer | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"sync" | ||
"time" | ||
|
||
"github.com/trufflesecurity/trufflehog/v3/pkg/context" | ||
) | ||
|
||
type poolMetrics struct{} | ||
|
||
func (poolMetrics) recordShrink(amount int) { | ||
shrinkCount.Inc() | ||
shrinkAmount.Add(float64(amount)) | ||
} | ||
|
||
func (poolMetrics) recordBufferRetrival() { | ||
activeBufferCount.Inc() | ||
checkoutCount.Inc() | ||
bufferCount.Inc() | ||
} | ||
|
||
func (poolMetrics) recordBufferReturn(bufCap, bufLen int64) { | ||
activeBufferCount.Dec() | ||
totalBufferSize.Add(float64(bufCap)) | ||
totalBufferLength.Add(float64(bufLen)) | ||
} | ||
|
||
// PoolOpts is a function that configures a BufferPool.
type PoolOpts func(pool *Pool)

// Pool of buffers.
type Pool struct {
	*sync.Pool
	// bufferSize is the capacity, in bytes, given to each newly allocated buffer.
	bufferSize uint32

	// metrics records pool activity: checkouts, returns, and shrinks.
	metrics poolMetrics
}
|
||
const defaultBufferSize = 1 << 12 // 4KB | ||
// NewBufferPool creates a new instance of BufferPool. | ||
func NewBufferPool(opts ...PoolOpts) *Pool { | ||
pool := &Pool{bufferSize: defaultBufferSize} | ||
|
||
for _, opt := range opts { | ||
opt(pool) | ||
} | ||
pool.Pool = &sync.Pool{ | ||
New: func() any { | ||
return &Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, pool.bufferSize))} | ||
}, | ||
} | ||
|
||
return pool | ||
} | ||
|
||
// Get returns a Buffer from the pool. | ||
func (p *Pool) Get(ctx context.Context) *Buffer { | ||
buf, ok := p.Pool.Get().(*Buffer) | ||
if !ok { | ||
ctx.Logger().Error(fmt.Errorf("Buffer pool returned unexpected type"), "using new Buffer") | ||
buf = &Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, p.bufferSize))} | ||
} | ||
p.metrics.recordBufferRetrival() | ||
buf.resetMetric() | ||
|
||
return buf | ||
} | ||
|
||
// Put returns a Buffer to the pool. | ||
func (p *Pool) Put(buf *Buffer) { | ||
p.metrics.recordBufferReturn(int64(buf.Cap()), int64(buf.Len())) | ||
|
||
// If the Buffer is more than twice the default size, replace it with a new Buffer. | ||
// This prevents us from returning very large buffers to the pool. | ||
const maxAllowedCapacity = 2 * defaultBufferSize | ||
if buf.Cap() > maxAllowedCapacity { | ||
p.metrics.recordShrink(buf.Cap() - defaultBufferSize) | ||
buf = &Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, p.bufferSize))} | ||
} else { | ||
// Reset the Buffer to clear any existing data. | ||
buf.Reset() | ||
} | ||
buf.recordMetric() | ||
|
||
p.Pool.Put(buf) | ||
} | ||
|
||
// Buffer is a wrapper around bytes.Buffer that includes a timestamp for tracking Buffer checkout duration.
type Buffer struct {
	*bytes.Buffer
	// checkedOutAt is set by resetMetric when the buffer is handed out and
	// read by recordMetric on return to compute the checkout duration.
	checkedOutAt time.Time
}
|
||
// NewBuffer creates a new instance of Buffer. | ||
func NewBuffer() *Buffer { return &Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, defaultBufferSize))} } | ||
|
||
// Grow increases the buffer's capacity to guarantee space for another size
// bytes, recording the growth in metrics. Prefer this over calling
// b.Buffer.Grow directly so growth is tracked.
func (b *Buffer) Grow(size int) {
	b.Buffer.Grow(size)
	b.recordGrowth(size)
}
|
||
func (b *Buffer) resetMetric() { b.checkedOutAt = time.Now() } | ||
|
||
func (b *Buffer) recordMetric() { | ||
dur := time.Since(b.checkedOutAt) | ||
checkoutDuration.Observe(float64(dur.Microseconds())) | ||
checkoutDurationTotal.Add(float64(dur.Microseconds())) | ||
} | ||
|
||
func (b *Buffer) recordGrowth(size int) { | ||
growCount.Inc() | ||
growAmount.Add(float64(size)) | ||
} | ||
|
||
// Write date to the buffer. | ||
func (b *Buffer) Write(ctx context.Context, data []byte) (int, error) { | ||
if b.Buffer == nil { | ||
// This case should ideally never occur if buffers are properly managed. | ||
ctx.Logger().Error(fmt.Errorf("buffer is nil, initializing a new buffer"), "action", "initializing_new_buffer") | ||
b.Buffer = bytes.NewBuffer(make([]byte, 0, defaultBufferSize)) | ||
b.resetMetric() | ||
} | ||
|
||
size := len(data) | ||
bufferLength := b.Buffer.Len() | ||
totalSizeNeeded := bufferLength + size | ||
// If the total size is within the threshold, write to the buffer. | ||
ctx.Logger().V(4).Info( | ||
"writing to buffer", | ||
"data_size", size, | ||
"content_size", bufferLength, | ||
) | ||
|
||
availableSpace := b.Buffer.Cap() - bufferLength | ||
growSize := totalSizeNeeded - bufferLength | ||
if growSize > availableSpace { | ||
ctx.Logger().V(4).Info( | ||
"buffer size exceeded, growing buffer", | ||
"current_size", bufferLength, | ||
"new_size", totalSizeNeeded, | ||
"available_space", availableSpace, | ||
"grow_size", growSize, | ||
) | ||
// We are manually growing the buffer so we can track the growth via metrics. | ||
// Knowing the exact data size, we directly resize to fit it, rather than exponential growth | ||
// which may require multiple allocations and copies if the size required is much larger | ||
// than double the capacity. Our approach aligns with default behavior when growth sizes | ||
// happen to match current capacity, retaining asymptotic efficiency benefits. | ||
b.Buffer.Grow(growSize) | ||
} | ||
|
||
return b.Buffer.Write(data) | ||
} | ||
|
||
// readCloser is a custom implementation of io.ReadCloser. It wraps a bytes.Reader
// for reading data from an in-memory buffer and includes an onClose callback.
// The onClose callback is used to return the buffer to the pool, ensuring buffer re-usability.
type readCloser struct {
	*bytes.Reader
	onClose func()
}

// ReadCloser creates a new instance of readCloser over data. A nil onClose is
// permitted and makes Close a no-op.
func ReadCloser(data []byte, onClose func()) *readCloser {
	return &readCloser{
		Reader:  bytes.NewReader(data),
		onClose: onClose,
	}
}

// Close implements the io.Closer interface. It calls the onClose callback to
// return the buffer to the pool, enabling buffer reuse. Consumers of
// ReadCloser should call this once they have finished reading the data to
// ensure proper resource management.
func (brc *readCloser) Close() error {
	if brc.onClose != nil {
		brc.onClose() // Return the buffer to the pool.
	}
	return nil
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
package buffer | ||
|
||
import ( | ||
"bytes" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
|
||
"github.com/trufflesecurity/trufflehog/v3/pkg/context" | ||
) | ||
|
||
func TestNewBufferPool(t *testing.T) { | ||
t.Parallel() | ||
tests := []struct { | ||
name string | ||
opts []PoolOpts | ||
expectedBuffSize uint32 | ||
}{ | ||
{name: "Default pool size", expectedBuffSize: defaultBufferSize}, | ||
{ | ||
name: "Custom pool size", | ||
opts: []PoolOpts{func(p *Pool) { p.bufferSize = 8 * 1024 }}, // 8KB | ||
expectedBuffSize: 8 * 1024, | ||
}, | ||
} | ||
|
||
for _, tc := range tests { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
pool := NewBufferPool(tc.opts...) | ||
assert.Equal(t, tc.expectedBuffSize, pool.bufferSize) | ||
}) | ||
} | ||
} | ||
|
||
func TestBufferPoolGetPut(t *testing.T) { | ||
t.Parallel() | ||
tests := []struct { | ||
name string | ||
preparePool func(p *Pool) *Buffer // Prepare the pool and return an initial buffer to put if needed | ||
expectedCapBefore int // Expected capacity before putting it back | ||
expectedCapAfter int // Expected capacity after retrieving it again | ||
}{ | ||
{ | ||
name: "Get new buffer and put back without modification", | ||
preparePool: func(_ *Pool) *Buffer { | ||
return nil // No initial buffer to put | ||
}, | ||
expectedCapBefore: int(defaultBufferSize), | ||
expectedCapAfter: int(defaultBufferSize), | ||
}, | ||
{ | ||
name: "Put oversized buffer, expect shrink", | ||
preparePool: func(p *Pool) *Buffer { | ||
buf := &Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, 3*defaultBufferSize))} | ||
return buf | ||
}, | ||
expectedCapBefore: int(defaultBufferSize), | ||
expectedCapAfter: int(defaultBufferSize), // Should shrink back to default | ||
}, | ||
} | ||
|
||
for _, tc := range tests { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
pool := NewBufferPool() | ||
initialBuf := tc.preparePool(pool) | ||
if initialBuf != nil { | ||
pool.Put(initialBuf) | ||
} | ||
|
||
buf := pool.Get(context.Background()) | ||
assert.Equal(t, tc.expectedCapBefore, buf.Cap()) | ||
|
||
pool.Put(buf) | ||
|
||
bufAfter := pool.Get(context.Background()) | ||
assert.Equal(t, tc.expectedCapAfter, bufAfter.Cap()) | ||
}) | ||
} | ||
} | ||
|
||
func TestBufferWrite(t *testing.T) { | ||
t.Parallel() | ||
tests := []struct { | ||
name string | ||
initialCapacity int | ||
writeDataSequence [][]byte // Sequence of writes to simulate multiple writes | ||
expectedSize int | ||
expectedCap int | ||
}{ | ||
{ | ||
name: "Write to empty buffer", | ||
initialCapacity: defaultBufferSize, | ||
writeDataSequence: [][]byte{ | ||
[]byte("hello"), | ||
}, | ||
expectedSize: 5, | ||
expectedCap: defaultBufferSize, // No growth for small data | ||
}, | ||
{ | ||
name: "Write causing growth", | ||
initialCapacity: 10, // Small initial capacity to force growth | ||
writeDataSequence: [][]byte{ | ||
[]byte("this is a longer string exceeding initial capacity"), | ||
}, | ||
expectedSize: 50, | ||
expectedCap: 50, | ||
}, | ||
{ | ||
name: "Write nil data", | ||
initialCapacity: defaultBufferSize, | ||
writeDataSequence: [][]byte{nil}, | ||
expectedCap: defaultBufferSize, | ||
}, | ||
{ | ||
name: "Repeated writes, cumulative growth", | ||
initialCapacity: 20, // Set an initial capacity to test growth over multiple writes | ||
writeDataSequence: [][]byte{ | ||
[]byte("first write, "), | ||
[]byte("second write, "), | ||
[]byte("third write exceeding the initial capacity."), | ||
}, | ||
expectedSize: 70, | ||
expectedCap: 70, // Expect capacity to grow to accommodate all writes | ||
}, | ||
{ | ||
name: "Write large single data to test significant growth", | ||
initialCapacity: 50, // Set an initial capacity smaller than the data to be written | ||
writeDataSequence: [][]byte{ | ||
bytes.Repeat([]byte("a"), 1024), // 1KB data to significantly exceed initial capacity | ||
}, | ||
expectedSize: 1024, | ||
expectedCap: 1024, | ||
}, | ||
} | ||
|
||
for _, tc := range tests { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
buf := &Buffer{Buffer: bytes.NewBuffer(make([]byte, 0, tc.initialCapacity))} | ||
totalWritten := 0 | ||
for _, data := range tc.writeDataSequence { | ||
n, err := buf.Write(context.Background(), data) | ||
assert.NoError(t, err) | ||
|
||
totalWritten += n | ||
} | ||
assert.Equal(t, tc.expectedSize, totalWritten) | ||
assert.Equal(t, tc.expectedSize, buf.Len()) | ||
assert.GreaterOrEqual(t, buf.Cap(), tc.expectedCap) | ||
}) | ||
} | ||
} | ||
|
||
func TestReadCloserClose(t *testing.T) { | ||
t.Parallel() | ||
onCloseCalled := false | ||
rc := ReadCloser([]byte("data"), func() { onCloseCalled = true }) | ||
|
||
err := rc.Close() | ||
assert.NoError(t, err) | ||
assert.True(t, onCloseCalled, "onClose callback should be called upon Close") | ||
} |
Oops, something went wrong.