Skip to content

Commit

Permalink
use buffer pool
Browse files Browse the repository at this point in the history
  • Loading branch information
ahrav committed Jan 18, 2024
1 parent dad92c5 commit fd41709
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 32 deletions.
4 changes: 2 additions & 2 deletions pkg/gitparse/gitparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ const (
defaultDateFormat = "Mon Jan 02 15:04:05 2006 -0700"

// defaultMaxDiffSize is the maximum size for a diff. Larger diffs will be cut off.
defaultMaxDiffSize = 1 * 1024 * 1024 * 1024 // 1GB
defaultMaxDiffSize = 2 * 1024 * 1024 * 1024 // 2GB

// defaultMaxCommitSize is the maximum size for a commit. Larger commits will be cut off.
defaultMaxCommitSize = 1 * 1024 * 1024 * 1024 // 1GB
defaultMaxCommitSize = 2 * 1024 * 1024 * 1024 // 2GB
)

// Commit contains commit header info and diffs.
Expand Down
5 changes: 3 additions & 2 deletions pkg/gitparse/gitparse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,8 +783,9 @@ func TestStagedDiffParsing(t *testing.T) {
IsBinary: false,
},
{
PathB: "trufflehog_3.42.0_linux_arm64.tar.gz",
IsBinary: true,
PathB: "trufflehog_3.42.0_linux_arm64.tar.gz",
contentWriter: createBufferedFileWriterWithContent(nil),
IsBinary: true,
},
{
PathB: "tzu",
Expand Down
107 changes: 86 additions & 21 deletions pkg/writers/buffered_file_writer/bufferedfilewriter.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,34 @@
// Package bufferedfilewriter provides a writer that buffers data in memory until a threshold is exceeded at
// which point it switches to writing to a temporary file.
package bufferedfilewriter

import (
"bytes"
"fmt"
"io"
"os"
"sync"

"github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

// bufferPool holds reusable *bytes.Buffer values so that BufferedFileWriter
// instances can share buffers instead of allocating a fresh one per writer.
// Buffers are Reset before reuse; see the Get/Put sites for the lifecycle.
var bufferPool = sync.Pool{
	// TODO: Consider growing the buffer before returning it if we can find an optimal size.
	// Ideally the size would cover the majority of cases without being too large.
	// This would avoid the need to grow the buffer when writing to it, reducing allocations.
	New: func() any { return new(bytes.Buffer) },
}

// BufferedFileWriter manages a buffer for writing data, flushing to a file when a threshold is exceeded.
type BufferedFileWriter struct {
threshold uint64
buf bytes.Buffer
file *os.File
threshold uint64 // Threshold for switching to file writing.
size uint64 // Total size of the data written.

buf bytes.Buffer // Buffer for storing data under the threshold in memory.
filename string // Name of the temporary file.
file io.WriteCloser // File for storing data over the threshold.
}

// Option is a function that modifies a BufferedFileWriter.
Expand All @@ -26,7 +41,7 @@ func WithThreshold(threshold uint64) Option {

// New creates a new BufferedFileWriter with the given options.
func New(opts ...Option) *BufferedFileWriter {
const defaultThreshold = 20 * 1024 * 1024 // 20MB
const defaultThreshold = 10 * 1024 * 1024 // 10MB
w := &BufferedFileWriter{threshold: defaultThreshold}
for _, opt := range opts {
opt(w)
Expand All @@ -41,40 +56,64 @@ func (w *BufferedFileWriter) Len() int { return w.buf.Len() }
func (w *BufferedFileWriter) String() string { return w.buf.String() }

// Write writes data to the buffer or a file, depending on the size.
func (w *BufferedFileWriter) Write(ctx context.Context, p []byte) (int, error) {
if uint64(w.buf.Len()+len(p)) <= w.threshold {
func (w *BufferedFileWriter) Write(ctx context.Context, data []byte) (int, error) {
size := uint64(len(data))
defer func() {
w.size += size
ctx.Logger().V(4).Info(
"write complete",
"data_size", size,
"content_size", w.buf.Len(),
"total_size", w.size,
)
}()

if w.buf.Len() == 0 {
bufPtr, ok := bufferPool.Get().(*bytes.Buffer)
if !ok {
ctx.Logger().Error(fmt.Errorf("buffer pool returned unexpected type"), "using new buffer")
bufPtr = new(bytes.Buffer)
}
bufPtr.Reset() // Reset the buffer to clear any existing data
w.buf = *bufPtr
}

if uint64(w.buf.Len())+size <= w.threshold {
// If the total size is within the threshold, write to the buffer.
ctx.Logger().V(4).Info(
"writing to buffer",
"data_size", len(p),
"data_size", size,
"content_size", w.buf.Len(),
)
return w.buf.Write(p)
return w.buf.Write(data)
}

// Switch to file writing if threshold is exceeded.
// This helps in managing memory efficiently for large diffs.
// This helps in managing memory efficiently for large content.
if w.file == nil {
var err error
w.file, err = os.CreateTemp(os.TempDir(), cleantemp.MkFilename())
file, err := os.CreateTemp(os.TempDir(), cleantemp.MkFilename())
if err != nil {
return 0, err
}

w.filename = file.Name()
w.file = file

// Transfer existing data in buffer to the file, then clear the buffer.
// This ensures all the diff data is in one place - either entirely in the buffer or the file.
// This ensures all the data is in one place - either entirely in the buffer or the file.
if w.buf.Len() > 0 {
ctx.Logger().V(4).Info("writing buffer to file", "content_size", w.buf.Len())
if _, err := w.file.Write(w.buf.Bytes()); err != nil {
return 0, err
}
// Replace the buffer with a new one to free up memory.
w.buf = bytes.Buffer{}
// Reset the buffer to clear any existing data and return it to the pool.
w.buf.Reset()
bufferPool.Put(&w.buf)
}
}
ctx.Logger().V(4).Info("writing to file", "data_size", len(p))
ctx.Logger().V(4).Info("writing to file", "data_size", size)

return w.file.Write(p)
return w.file.Write(data)
}

// Close flushes any remaining data in the buffer to the file and closes the file if it was created.
Expand All @@ -92,21 +131,27 @@ func (w *BufferedFileWriter) Close() error {
return w.file.Close()
}

// ReadCloser returns an io.ReadCloser to read the written content. If the total content size exceeds the
// predefined threshold, it is stored in a temporary file and a file reader is returned.
// For content under the threshold, it is kept in memory and a bytes reader on the buffer is returned.
// ReadCloser returns an io.ReadCloser to read the written content. It provides a reader
// based on the current storage medium of the data (in-memory buffer or file).
// If the total content size exceeds the predefined threshold, it is stored in a temporary file and a file
// reader is returned. For in-memory data, it returns a custom reader that handles returning
// the buffer to the pool.
// The caller should call Close() on the returned io.Reader when done to ensure files are cleaned up.
func (w *BufferedFileWriter) ReadCloser() (io.ReadCloser, error) {
if w.file != nil {
// Data is in a file, read from the file.
file, err := os.Open(w.file.Name())
file, err := os.Open(w.filename)
if err != nil {
return nil, err
}
return newAutoDeletingFileReader(file), nil
}

// Data is in memory.
return io.NopCloser(bytes.NewReader(w.buf.Bytes())), nil
return &bufferReadCloser{
Reader: bytes.NewReader(w.buf.Bytes()),
onClose: func() { bufferPool.Put(&w.buf) },
}, nil
}

// autoDeletingFileReader wraps an *os.File and deletes the file on Close.
Expand All @@ -122,3 +167,23 @@ func (r *autoDeletingFileReader) Close() error {
defer os.Remove(r.Name()) // Delete the file after closing
return r.File.Close()
}

// bufferReadCloser is a custom implementation of io.ReadCloser. It wraps a bytes.Reader
// for reading data from an in-memory buffer and includes an onClose callback.
// The onClose callback is used to return the buffer to the pool, ensuring buffer re-usability.
type bufferReadCloser struct {
*bytes.Reader
onClose func()
}

// Close implements the io.Closer interface. It calls the onClose callback to return the buffer
// to the pool, enabling buffer reuse. This method should be called by the consumers of ReadCloser
// once they have finished reading the data to ensure proper resource management.
func (brc *bufferReadCloser) Close() error {
if brc.onClose == nil {
return nil
}

brc.onClose() // Return the buffer to the pool
return nil
}
14 changes: 7 additions & 7 deletions pkg/writers/buffered_file_writer/bufferedfilewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ func TestBufferedFileWriterNewThreshold(t *testing.T) {
t.Parallel()

const (
defaultThreshold = 20 * 1024 * 1024 // 20MB
customThreshold = 10 * 1024 * 1024 // 10MB
defaultThreshold = 10 * 1024 * 1024 // 10MB
customThreshold = 20 * 1024 * 1024 // 20MB
)

tests := []struct {
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestBufferedFileWriterWriteExceedsThreshold(t *testing.T) {

assert.NotNil(t, writer.file)
assert.Len(t, writer.buf.Bytes(), 0)
fileContents, err := os.ReadFile(writer.file.Name())
fileContents, err := os.ReadFile(writer.filename)
assert.NoError(t, err)
assert.Equal(t, data, fileContents)
}
Expand All @@ -144,9 +144,9 @@ func TestBufferedFileWriterWriteAfterFlush(t *testing.T) {
defer writer.Close()

// Get the file modification time after the initial write.
initialModTime, err := getFileModTime(t, writer.file.Name())
initialModTime, err := getFileModTime(t, writer.filename)
assert.NoError(t, err)
fileContents, err := os.ReadFile(writer.file.Name())
fileContents, err := os.ReadFile(writer.filename)
assert.NoError(t, err)
assert.Equal(t, initialData, fileContents)

Expand All @@ -155,7 +155,7 @@ func TestBufferedFileWriterWriteAfterFlush(t *testing.T) {
assert.NoError(t, err)

assert.Equal(t, subsequentData, writer.buf.Bytes()) // Check buffer contents
finalModTime, err := getFileModTime(t, writer.file.Name())
finalModTime, err := getFileModTime(t, writer.filename)
assert.NoError(t, err)
assert.Equal(t, initialModTime, finalModTime) // File should not be modified again
}
Expand Down Expand Up @@ -229,7 +229,7 @@ func TestBufferedFileWriterClose(t *testing.T) {
assert.NoError(t, err)

if writer.file != nil {
fileContents, err := os.ReadFile(writer.file.Name())
fileContents, err := os.ReadFile(writer.filename)
assert.NoError(t, err)
assert.Equal(t, tc.expectFileContent, string(fileContents))
return
Expand Down

0 comments on commit fd41709

Please sign in to comment.