diff --git a/vacation/archive.go b/vacation/archive.go index 58c651fa..06a7a886 100644 --- a/vacation/archive.go +++ b/vacation/archive.go @@ -64,7 +64,7 @@ func (a Archive) Decompress(destination string) error { case "application/x-bzip2": decompressor = NewTarBzip2Archive(bufferedReader).StripComponents(a.components) case "application/zip": - decompressor = NewZipArchive(bufferedReader) + decompressor = NewZipArchive(bufferedReader).StripComponents(a.components) case "text/plain; charset=utf-8", "application/jar": destination = filepath.Join(destination, a.name) decompressor = NewNopArchive(bufferedReader) diff --git a/vacation/archive_test.go b/vacation/archive_test.go index f6ea0a9e..eca0d5a3 100644 --- a/vacation/archive_test.go +++ b/vacation/archive_test.go @@ -301,6 +301,18 @@ func testArchive(t *testing.T, context spec.G, it spec.S) { _, err = f.Write([]byte("some-file")) Expect(err).NotTo(HaveOccurred()) + _, err = zw.Create("some-dir/") + Expect(err).NotTo(HaveOccurred()) + + header = &zip.FileHeader{Name: filepath.Join("some-dir", "some-nested-file")} + header.SetMode(0644) + + f, err = zw.CreateHeader(header) + Expect(err).NotTo(HaveOccurred()) + + _, err = f.Write([]byte("nested file")) + Expect(err).NotTo(HaveOccurred()) + Expect(zw.Close()).To(Succeed()) archive = vacation.NewArchive(buffer) @@ -318,6 +330,18 @@ func testArchive(t *testing.T, context spec.G, it spec.S) { Expect(err).NotTo(HaveOccurred()) Expect(files).To(ConsistOf([]string{ filepath.Join(tempDir, "some-file"), + filepath.Join(tempDir, "some-dir"), + })) + }) + + it("unpackages the archive into the path but also strips the first component", func() { + err := archive.StripComponents(1).Decompress(tempDir) + Expect(err).NotTo(HaveOccurred()) + + files, err := filepath.Glob(filepath.Join(tempDir, "*")) + Expect(err).NotTo(HaveOccurred()) + Expect(files).To(ConsistOf([]string{ + filepath.Join(tempDir, "some-nested-file"), })) }) }) diff --git a/vacation/example_test.go b/vacation/example_test.go index c75c7132..53b3b2e5 100644 --- a/vacation/example_test.go +++ b/vacation/example_test.go @@ -191,8 +191,7 @@ func ExampleArchive_StripComponents() { // Output: // some-tar-file - // some-zip-dir/some-zip-file - // zip-file + // some-zip-file } func ExampleTarArchive() { diff --git a/vacation/zip_archive.go b/vacation/zip_archive.go index cc504cfb..b2c1ea31 100644 --- a/vacation/zip_archive.go +++ b/vacation/zip_archive.go @@ -7,11 +7,13 @@ import ( "os" "path/filepath" "sort" + "strings" ) // A ZipArchive decompresses zip files from an input stream. type ZipArchive struct { - reader io.Reader + reader io.Reader + components int } // NewZipArchive returns a new ZipArchive that reads from inputReader. @@ -65,7 +67,15 @@ func (z ZipArchive) Decompress(destination string) error { return err } - path := filepath.Join(destination, name) + fileNames := strings.Split(name, "/") + + // Checks to see if file should be written when stripping components + if len(fileNames) <= z.components { + continue + } + + // Constructs the path that conforms to the stripped components. + path := filepath.Join(append([]string{destination}, fileNames[z.components:]...)...) switch { case f.FileInfo().IsDir(): @@ -158,3 +168,10 @@ func (z ZipArchive) Decompress(destination string) error { return nil } + +// StripComponents removes the first n levels from the final decompression +// destination. +func (z ZipArchive) StripComponents(components int) ZipArchive { + z.components = components + return z +} diff --git a/vacation/zip_archive_test.go b/vacation/zip_archive_test.go index 03a1784b..59f1f5d5 100644 --- a/vacation/zip_archive_test.go +++ b/vacation/zip_archive_test.go @@ -113,6 +113,21 @@ func testZipArchive(t *testing.T, context spec.G, it spec.S) { Expect(data).To(Equal([]byte("nested file"))) }) + it("unpackages the archive into the path but also strips the first component", func() { + var err error + err = zipArchive.StripComponents(1).Decompress(tempDir) + Expect(err).ToNot(HaveOccurred()) + + files, err := filepath.Glob(fmt.Sprintf("%s/*", tempDir)) + Expect(err).NotTo(HaveOccurred()) + Expect(files).To(ConsistOf([]string{ + filepath.Join(tempDir, "some-other-dir"), + })) + + Expect(filepath.Join(tempDir, "some-other-dir")).To(BeADirectory()) + Expect(filepath.Join(tempDir, "some-other-dir", "some-file")).To(BeARegularFile()) + }) + context("failure cases", func() { context("when it fails to create a zip reader", func() { it("returns an error", func() {