diff --git a/pkg/mapfs/file.go b/pkg/mapfs/file.go index d0406372a3df..b19cee041162 100644 --- a/pkg/mapfs/file.go +++ b/pkg/mapfs/file.go @@ -36,8 +36,7 @@ func (f *file) Open(name string) (fs.File, error) { return f.open() } - // TODO: support directory - if sub, err := f.getFile(name); err == nil && !sub.stat.IsDir() { + if sub, err := f.getFile(name); err == nil { return sub.open() } @@ -49,16 +48,26 @@ func (f *file) Open(name string) (fs.File, error) { } func (f *file) open() (fs.File, error) { - // virtual file - if len(f.data) != 0 { + switch { + case f.stat.IsDir(): // Directory + entries, err := f.ReadDir(".") + if err != nil { + return nil, xerrors.Errorf("read dir error: %w", err) + } + return &mapDir{ + path: f.path, + fileStat: f.stat, + entry: entries, + }, nil + case len(f.data) != 0: // Virtual file return &openMapFile{ path: f.stat.name, file: f, offset: 0, }, nil + default: // Real file + return os.Open(f.path) } - // real file - return os.Open(f.path) } func (f *file) Remove(name string) error { @@ -213,6 +222,9 @@ func (f *file) WriteFile(path, underlyingPath string) error { } func (f *file) WriteVirtualFile(path string, data []byte, mode fs.FileMode) error { + if mode&fs.ModeDir != 0 { + return xerrors.Errorf("invalid perm: %v", mode) + } parts := strings.Split(path, separator) if len(parts) == 1 { @@ -287,7 +299,11 @@ func (f *openMapFile) Read(b []byte) (int, error) { return 0, io.EOF } if f.offset < 0 { - return 0, &fs.PathError{Op: "read", Path: f.path, Err: fs.ErrInvalid} + return 0, &fs.PathError{ + Op: "read", + Path: f.path, + Err: fs.ErrInvalid, + } } n := copy(b, f.file.data[f.offset:]) f.offset += int64(n) @@ -304,7 +320,11 @@ func (f *openMapFile) Seek(offset int64, whence int) (int64, error) { offset += int64(len(f.file.data)) } if offset < 0 || offset > int64(len(f.file.data)) { - return 0, &fs.PathError{Op: "seek", Path: f.path, Err: fs.ErrInvalid} + return 0, &fs.PathError{ + Op: "seek", + Path: f.path, + Err: fs.ErrInvalid, + } } f.offset = offset return offset, nil @@ -312,7 +332,11 @@ func (f *openMapFile) Seek(offset int64, whence int) (int64, error) { func (f *openMapFile) ReadAt(b []byte, offset int64) (int, error) { if offset < 0 || offset > int64(len(f.file.data)) { - return 0, &fs.PathError{Op: "read", Path: f.path, Err: fs.ErrInvalid} + return 0, &fs.PathError{ + Op: "read", + Path: f.path, + Err: fs.ErrInvalid, + } } n := copy(b, f.file.data[offset:]) if n < len(b) { @@ -320,3 +344,37 @@ func (f *openMapFile) ReadAt(b []byte, offset int64) (int, error) { } return n, nil } + +// A mapDir is a directory fs.File (so also an fs.ReadDirFile) open for reading. +type mapDir struct { + path string + fileStat + entry []fs.DirEntry + offset int +} + +func (d *mapDir) Stat() (fs.FileInfo, error) { return &d.fileStat, nil } +func (d *mapDir) Close() error { return nil } +func (d *mapDir) Read(_ []byte) (int, error) { + return 0, &fs.PathError{ + Op: "read", + Path: d.path, + Err: fs.ErrInvalid, + } +} + +func (d *mapDir) ReadDir(count int) ([]fs.DirEntry, error) { + n := len(d.entry) - d.offset + if n == 0 && count > 0 { + return nil, io.EOF + } + if count > 0 && n > count { + n = count + } + list := make([]fs.DirEntry, n) + for i := range list { + list[i] = d.entry[d.offset+i] + } + d.offset += n + return list, nil +} diff --git a/pkg/mapfs/fs.go b/pkg/mapfs/fs.go index f997984cbd4e..eea07c162030 100644 --- a/pkg/mapfs/fs.go +++ b/pkg/mapfs/fs.go @@ -5,6 +5,7 @@ import ( "io/fs" "os" "path/filepath" + "strings" "time" "golang.org/x/exp/slices" @@ -188,5 +189,6 @@ func (m *FS) RemoveAll(path string) error { func cleanPath(path string) string { path = filepath.Clean(path) path = filepath.ToSlash(path) + path = strings.TrimLeft(path, "/") // Remove the leading slash return path } diff --git a/pkg/mapfs/fs_test.go b/pkg/mapfs/fs_test.go index 638feed88dc0..22b659d7f387 100644 --- a/pkg/mapfs/fs_test.go +++ b/pkg/mapfs/fs_test.go @@ -281,7 +281,10 @@ func TestFS_Open(t *testing.T) { { name: "dir", filePath: "a/b/c", - wantErr: assert.Error, + want: file{ + fileInfo: cdirFileInfo, + }, + wantErr: assert.NoError, }, { name: "no such file", @@ -306,9 +309,11 @@ func TestFS_Open(t *testing.T) { require.NoError(t, err) assertFileInfo(t, tt.want.fileInfo, fi) - b, err := io.ReadAll(f) - require.NoError(t, err) - assert.Equal(t, tt.want.body, string(b)) + if tt.want.body != "" { + b, err := io.ReadAll(f) + require.NoError(t, err) + assert.Equal(t, tt.want.body, string(b)) + } }) } } diff --git a/pkg/module/memfs.go b/pkg/module/memfs.go index 6dceb4ce65fa..f8cfa1801c06 100644 --- a/pkg/module/memfs.go +++ b/pkg/module/memfs.go @@ -8,13 +8,13 @@ import ( "golang.org/x/xerrors" dio "github.com/aquasecurity/go-dep-parser/pkg/io" - "github.com/aquasecurity/memoryfs" + "github.com/aquasecurity/trivy/pkg/mapfs" ) -// memFS is a wrapper of memoryfs.FS and can change its underlying file system +// memFS is a wrapper of mapfs.FS and can change its underlying file system // at runtime. This implements fs.FS. type memFS struct { - current *memoryfs.FS + current *mapfs.FS } // Open implements fs.FS. @@ -29,18 +29,20 @@ func (m *memFS) Open(name string) (fs.File, error) { // // Note: it is always to safe swap the underlying FS with this API since this is called only at the beginning of // Analyze interface call, which is not concurrently called per module instance. -func (m *memFS) initialize(filePath string, content dio.ReadSeekerAt) (err error) { - memfs := memoryfs.New() - if err = memfs.MkdirAll(filepath.Dir(filePath), fs.ModePerm); err != nil { - return xerrors.Errorf("memory fs mkdir error: %w", err) +func (m *memFS) initialize(filePath string, content dio.ReadSeekerAt) error { + mfs := mapfs.New() + if err := mfs.MkdirAll(filepath.Dir(filePath), fs.ModePerm); err != nil { + return xerrors.Errorf("mapfs mkdir error: %w", err) } - err = memfs.WriteLazyFile(filePath, func() (io.Reader, error) { - return content, nil - }, fs.ModePerm) + b, err := io.ReadAll(content) if err != nil { - return xerrors.Errorf("memory fs write error: %w", err) + return xerrors.Errorf("read error: %w", err) + } + err = mfs.WriteVirtualFile(filePath, b, fs.ModePerm) + if err != nil { + return xerrors.Errorf("mapfs write error: %w", err) } - m.current = memfs - return + m.current = mfs + return nil }