diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go
index 5ebcefce1c..a2525c9675 100644
--- a/pkg/archive/archive.go
+++ b/pkg/archive/archive.go
@@ -169,10 +169,17 @@ func DetectCompression(source []byte) Compression {
 }
 
 // DecompressStream decompresses the archive and returns a ReaderCloser with the decompressed archive.
-func DecompressStream(archive io.Reader) (io.ReadCloser, error) {
+func DecompressStream(archive io.Reader) (_ io.ReadCloser, Err error) {
 	p := pools.BufioReader32KPool
 	buf := p.Get(archive)
 	bs, err := buf.Peek(10)
+
+	defer func() {
+		if Err != nil {
+			p.Put(buf)
+		}
+	}()
+
 	if err != nil && err != io.EOF {
 		// Note: we'll ignore any io.EOF error because there are some odd
 		// cases where the layer.tar file will be empty (zero bytes) and
@@ -189,6 +196,12 @@ func DecompressStream(archive io.Reader) (io.ReadCloser, error) {
 		readBufWrapper := p.NewReadCloserWrapper(buf, buf)
 		return readBufWrapper, nil
 	case Gzip:
+		cleanup := func() {
+			p.Put(buf)
+		}
+		if rc, canUse := tryProcFilter([]string{"pigz", "-d"}, buf, cleanup); canUse {
+			return rc, nil
+		}
 		gzReader, err := gzip.NewReader(buf)
 		if err != nil {
 			return nil, err
@@ -207,6 +220,12 @@ func DecompressStream(archive io.Reader) (io.ReadCloser, error) {
 		readBufWrapper := p.NewReadCloserWrapper(buf, xzReader)
 		return readBufWrapper, nil
 	case Zstd:
+		cleanup := func() {
+			p.Put(buf)
+		}
+		if rc, canUse := tryProcFilter([]string{"zstd", "-d"}, buf, cleanup); canUse {
+			return rc, nil
+		}
 		return zstdReader(buf)
 	default:
 		return nil, fmt.Errorf("unsupported compression format %s", (&compression).Extension())
@@ -214,9 +233,16 @@ func DecompressStream(archive io.Reader) (io.ReadCloser, error) {
 }
 
 // CompressStream compresses the dest with specified compression algorithm.
-func CompressStream(dest io.Writer, compression Compression) (io.WriteCloser, error) {
+func CompressStream(dest io.Writer, compression Compression) (_ io.WriteCloser, Err error) {
 	p := pools.BufioWriter32KPool
 	buf := p.Get(dest)
+
+	defer func() {
+		if Err != nil {
+			p.Put(buf)
+		}
+	}()
+
 	switch compression {
 	case Uncompressed:
 		writeBufWrapper := p.NewWriteCloserWrapper(buf, buf)
diff --git a/pkg/archive/filter.go b/pkg/archive/filter.go
new file mode 100644
index 0000000000..e63d72e6aa
--- /dev/null
+++ b/pkg/archive/filter.go
@@ -0,0 +1,55 @@
+package archive
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"os/exec"
+	"strings"
+	"sync"
+)
+
+var filterPath sync.Map
+
+func getFilterPath(name string) string {
+	path, ok := filterPath.Load(name)
+	if ok {
+		return path.(string)
+	}
+
+	path, err := exec.LookPath(name)
+	if err != nil {
+		path = ""
+	}
+
+	filterPath.Store(name, path)
+	return path.(string)
+}
+
+// tryProcFilter tries to run the command specified in args, passing input to its stdin and returning its stdout.
+// cleanup() is a caller provided function that will be called when the command finishes running, regardless of
+// whether it succeeds or fails.
+// If the command is not found, it returns (nil, false) and the cleanup function is not called.
+func tryProcFilter(args []string, input io.Reader, cleanup func()) (io.ReadCloser, bool) {
+	path := getFilterPath(args[0])
+	if path == "" {
+		return nil, false
+	}
+
+	var stderrBuf bytes.Buffer
+
+	r, w := io.Pipe()
+	cmd := exec.Command(path, args[1:]...)
+	cmd.Stdin = input
+	cmd.Stdout = w
+	cmd.Stderr = &stderrBuf
+	go func() {
+		err := cmd.Run()
+		if err != nil && stderrBuf.Len() > 0 {
+			err = fmt.Errorf("%s: %w", strings.TrimRight(stderrBuf.String(), "\n"), err)
+		}
+		w.CloseWithError(err) // CloseWithError(nil) == Close()
+		cleanup()
+	}()
+	return r, true
+}
diff --git a/pkg/archive/filter_test.go b/pkg/archive/filter_test.go
new file mode 100644
index 0000000000..57524fef64
--- /dev/null
+++ b/pkg/archive/filter_test.go
@@ -0,0 +1,58 @@
+package archive
+
+import (
+	"bufio"
+	"bytes"
+	"io"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestTryProcFilter(t *testing.T) {
+	t.Run("Invalid filter path", func(t *testing.T) {
+		args := []string{"does-not-exist"}
+		input := bufio.NewReader(bytes.NewBufferString("foo"))
+		result, ok := tryProcFilter(args, input, func() {})
+		assert.Nil(t, result)
+		assert.False(t, ok)
+	})
+
+	t.Run("Valid filter path", func(t *testing.T) {
+		inputData := "input data"
+
+		args := []string{"cat", "-"}
+		input := bufio.NewReader(bytes.NewBufferString(inputData))
+
+		result, ok := tryProcFilter(args, input, func() {})
+		assert.NotNil(t, result)
+		assert.True(t, ok)
+
+		output, err := io.ReadAll(result)
+		require.NoError(t, err)
+		assert.Equal(t, inputData, string(output))
+	})
+
+	t.Run("Filter fails with error", func(t *testing.T) {
+		inputData := "input data"
+
+		var cleanedUp atomic.Bool
+
+		args := []string{"sh", "-c", "echo 'oh no' 1>&2; exit 21"}
+		input := bufio.NewReader(bytes.NewBufferString(inputData))
+
+		result, ok := tryProcFilter(args, input, func() { cleanedUp.Store(true) })
+		assert.NotNil(t, result)
+		assert.True(t, ok)
+
+		_, err := io.ReadAll(result)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "oh no: exit status 21")
+		assert.Eventually(t, func() bool {
+			return cleanedUp.Load()
+		}, 5*time.Second, 10*time.Millisecond, "clean up function was not called")
+	})
+}
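
Not part of the patch: a minimal, self-contained sketch of the proc-filter pattern that the new filter.go relies on, an external process wired up through io.Pipe with its exit status propagated to the reader via CloseWithError. The helper name procFilter and the use of the system gzip binary are illustrative assumptions, not code from this change.

// procfilter_sketch.go: illustrative only, assumes a "gzip" binary on PATH.
package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"os/exec"
)

// procFilter mirrors the shape of tryProcFilter: the child's stdout is exposed
// as an io.ReadCloser, and a process failure surfaces as a read error.
func procFilter(path string, args []string, input io.Reader) io.ReadCloser {
	r, w := io.Pipe()
	cmd := exec.Command(path, args...)
	cmd.Stdin = input
	cmd.Stdout = w
	go func() {
		// CloseWithError(nil) behaves like Close(), so on success the
		// reader simply sees EOF once the child exits.
		w.CloseWithError(cmd.Run())
	}()
	return r
}

func main() {
	// Gzip some data in-process...
	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)
	_, _ = zw.Write([]byte("hello, filter"))
	_ = zw.Close()

	// ...then decompress it through an external "gzip -d" process,
	// the same way the patch shells out to "pigz -d" or "zstd -d".
	out, err := io.ReadAll(procFilter("gzip", []string{"-d"}, &compressed))
	if err != nil {
		fmt.Println("filter failed:", err)
		return
	}
	fmt.Println(string(out)) // hello, filter
}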