diff --git a/pkg/images/pull.go b/pkg/images/pull.go index 55a6cdac..7440bcb2 100644 --- a/pkg/images/pull.go +++ b/pkg/images/pull.go @@ -5,6 +5,7 @@ package images import ( "archive/tar" + "bufio" "context" "errors" "fmt" @@ -121,41 +122,66 @@ func handleTarObject(ctx context.Context, tr *tar.Reader, hdr *tar.Header, conf } case tar.TypeReg: compressed := strings.HasSuffix(dstPath, ".zst") - if compressed { - image = strings.TrimSuffix(dstPath, ".zst") - if _, err := os.Stat(image); err == nil { - return image, os.ErrExist - } - - cmd := exec.CommandContext(ctx, "zstd", "-d", "-", "-o", image) - cmd.Stdin = tr - if _, err := cmd.Output(); err != nil { - var e *exec.ExitError - if errors.As(err, &e) { - fmt.Fprintf(os.Stderr, string(e.Stderr)) - } - return image, fmt.Errorf("failed during zst decompression of %s: %w", hdr.Name, err) - } + // Copy the target file out to the host (compressed or not). + // On failure, clean up the temporary file; otherwise move it + // to the target destination. + tmpFile, err := os.CreateTemp("", filepath.Base(hdr.Name)) + if err != nil { + return image, fmt.Errorf("failed to open temporary file for %s: %w", hdr.Name, err) } - if conf.Cache || !compressed { - dst, err := os.Create(dstPath) - if err != nil { - return image, fmt.Errorf("failed to create file %s: %w", dstPath, err) + defer func() { + tmpPath := tmpFile.Name() + tmpFile.Close() + if err != nil || (!conf.Cache && compressed) { + os.Remove(tmpPath) + return } - defer dst.Close() - - n, err := io.CopyN(dst, tr, hdr.Size) - if err != nil { - return image, fmt.Errorf("failed to copy %s from container %s: %w", dstPath, containerID, err) + if err = os.Rename(tmpFile.Name(), dstPath); err != nil { + fmt.Fprintf(os.Stderr, "Failed to move %s to %s: %s", tmpPath, dstPath, err) } - if n != hdr.Size { - return image, fmt.Errorf("tar header reports file %s size %d, but only %d bytes were pulled", hdr.Name, hdr.Size, n) + }() + + n, err := io.CopyN(tmpFile, tr, hdr.Size) + if err != nil { + return image, fmt.Errorf("failed to copy %s from container %s: %w", dstPath, containerID, err) + } + if n != hdr.Size { + return image, fmt.Errorf("tar header reports file %s size %d, but only %d bytes were pulled", hdr.Name, hdr.Size, n) + } + + if compressed { + if _, err = tmpFile.Seek(0, 0); err != nil { + return image, fmt.Errorf("cannot seek to the start of the compressed target file %s: %w", dstPath, err) } + compressedTarget := bufio.NewReader(tmpFile) + dstImagePath := strings.TrimSuffix(dstPath, ".zst") + + return extractZst(ctx, compressedTarget, dstImagePath) } + default: return image, fmt.Errorf("unexpected tar header type %d", hdr.Typeflag) } return image, nil } + +func extractZst(ctx context.Context, reader io.Reader, dstPath string) (image string, err error) { + if _, err := os.Stat(dstPath); err == nil { + return dstPath, os.ErrExist + } + + cmd := exec.CommandContext(ctx, "zstd", "-d", "-", "-o", dstPath) + cmd.Stdin = reader + + if _, err := cmd.Output(); err != nil { + var e *exec.ExitError + if errors.As(err, &e) { + fmt.Fprintf(os.Stderr, string(e.Stderr)) + } + return dstPath, fmt.Errorf("failed during zst decompression to %s: %w", dstPath, err) + } + + return dstPath, nil +}