Skip to content

Commit

Permalink
fix: allow mapfs to open dirs (#3867)
Browse files Browse the repository at this point in the history
  • Loading branch information
knqyf263 authored Mar 19, 2023
1 parent 09fd299 commit 9e4b57f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 26 deletions.
76 changes: 67 additions & 9 deletions pkg/mapfs/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ func (f *file) Open(name string) (fs.File, error) {
return f.open()
}

// TODO: support directory
if sub, err := f.getFile(name); err == nil && !sub.stat.IsDir() {
if sub, err := f.getFile(name); err == nil {
return sub.open()
}

Expand All @@ -49,16 +48,26 @@ func (f *file) Open(name string) (fs.File, error) {
}

func (f *file) open() (fs.File, error) {
// virtual file
if len(f.data) != 0 {
switch {
case f.stat.IsDir(): // Directory
entries, err := f.ReadDir(".")
if err != nil {
return nil, xerrors.Errorf("read dir error: %w", err)
}
return &mapDir{
path: f.path,
fileStat: f.stat,
entry: entries,
}, nil
case len(f.data) != 0: // Virtual file
return &openMapFile{
path: f.stat.name,
file: f,
offset: 0,
}, nil
default: // Real file
return os.Open(f.path)
}
// real file
return os.Open(f.path)
}

func (f *file) Remove(name string) error {
Expand Down Expand Up @@ -213,6 +222,9 @@ func (f *file) WriteFile(path, underlyingPath string) error {
}

func (f *file) WriteVirtualFile(path string, data []byte, mode fs.FileMode) error {
if mode&fs.ModeDir != 0 {
return xerrors.Errorf("invalid perm: %v", mode)
}
parts := strings.Split(path, separator)

if len(parts) == 1 {
Expand Down Expand Up @@ -287,7 +299,11 @@ func (f *openMapFile) Read(b []byte) (int, error) {
return 0, io.EOF
}
if f.offset < 0 {
return 0, &fs.PathError{Op: "read", Path: f.path, Err: fs.ErrInvalid}
return 0, &fs.PathError{
Op: "read",
Path: f.path,
Err: fs.ErrInvalid,
}
}
n := copy(b, f.file.data[f.offset:])
f.offset += int64(n)
Expand All @@ -304,19 +320,61 @@ func (f *openMapFile) Seek(offset int64, whence int) (int64, error) {
offset += int64(len(f.file.data))
}
if offset < 0 || offset > int64(len(f.file.data)) {
return 0, &fs.PathError{Op: "seek", Path: f.path, Err: fs.ErrInvalid}
return 0, &fs.PathError{
Op: "seek",
Path: f.path,
Err: fs.ErrInvalid,
}
}
f.offset = offset
return offset, nil
}

func (f *openMapFile) ReadAt(b []byte, offset int64) (int, error) {
if offset < 0 || offset > int64(len(f.file.data)) {
return 0, &fs.PathError{Op: "read", Path: f.path, Err: fs.ErrInvalid}
return 0, &fs.PathError{
Op: "read",
Path: f.path,
Err: fs.ErrInvalid,
}
}
n := copy(b, f.file.data[offset:])
if n < len(b) {
return n, io.EOF
}
return n, nil
}

// A mapDir is a directory fs.File (so also an fs.ReadDirFile) open for reading.
type mapDir struct {
path string
fileStat
entry []fs.DirEntry
offset int
}

func (d *mapDir) Stat() (fs.FileInfo, error) { return &d.fileStat, nil }
func (d *mapDir) Close() error { return nil }
func (d *mapDir) Read(_ []byte) (int, error) {
return 0, &fs.PathError{
Op: "read",
Path: d.path,
Err: fs.ErrInvalid,
}
}

func (d *mapDir) ReadDir(count int) ([]fs.DirEntry, error) {
n := len(d.entry) - d.offset
if n == 0 && count > 0 {
return nil, io.EOF
}
if count > 0 && n > count {
n = count
}
list := make([]fs.DirEntry, n)
for i := range list {
list[i] = d.entry[d.offset+i]
}
d.offset += n
return list, nil
}
2 changes: 2 additions & 0 deletions pkg/mapfs/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io/fs"
"os"
"path/filepath"
"strings"
"time"

"golang.org/x/exp/slices"
Expand Down Expand Up @@ -188,5 +189,6 @@ func (m *FS) RemoveAll(path string) error {
func cleanPath(path string) string {
path = filepath.Clean(path)
path = filepath.ToSlash(path)
path = strings.TrimLeft(path, "/") // Remove the leading slash
return path
}
13 changes: 9 additions & 4 deletions pkg/mapfs/fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,10 @@ func TestFS_Open(t *testing.T) {
{
name: "dir",
filePath: "a/b/c",
wantErr: assert.Error,
want: file{
fileInfo: cdirFileInfo,
},
wantErr: assert.NoError,
},
{
name: "no such file",
Expand All @@ -306,9 +309,11 @@ func TestFS_Open(t *testing.T) {
require.NoError(t, err)
assertFileInfo(t, tt.want.fileInfo, fi)

b, err := io.ReadAll(f)
require.NoError(t, err)
assert.Equal(t, tt.want.body, string(b))
if tt.want.body != "" {
b, err := io.ReadAll(f)
require.NoError(t, err)
assert.Equal(t, tt.want.body, string(b))
}
})
}
}
Expand Down
28 changes: 15 additions & 13 deletions pkg/module/memfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import (
"golang.org/x/xerrors"

dio "github.com/aquasecurity/go-dep-parser/pkg/io"
"github.com/aquasecurity/memoryfs"
"github.com/aquasecurity/trivy/pkg/mapfs"
)

// memFS is a wrapper of memoryfs.FS and can change its underlying file system
// memFS is a wrapper of mapfs.FS and can change its underlying file system
// at runtime. This implements fs.FS.
type memFS struct {
current *memoryfs.FS
current *mapfs.FS
}

// Open implements fs.FS.
Expand All @@ -29,18 +29,20 @@ func (m *memFS) Open(name string) (fs.File, error) {
//
// Note: it is always to safe swap the underlying FS with this API since this is called only at the beginning of
// Analyze interface call, which is not concurrently called per module instance.
func (m *memFS) initialize(filePath string, content dio.ReadSeekerAt) (err error) {
memfs := memoryfs.New()
if err = memfs.MkdirAll(filepath.Dir(filePath), fs.ModePerm); err != nil {
return xerrors.Errorf("memory fs mkdir error: %w", err)
func (m *memFS) initialize(filePath string, content dio.ReadSeekerAt) error {
mfs := mapfs.New()
if err := mfs.MkdirAll(filepath.Dir(filePath), fs.ModePerm); err != nil {
return xerrors.Errorf("mapfs mkdir error: %w", err)
}
err = memfs.WriteLazyFile(filePath, func() (io.Reader, error) {
return content, nil
}, fs.ModePerm)
b, err := io.ReadAll(content)
if err != nil {
return xerrors.Errorf("memory fs write error: %w", err)
return xerrors.Errorf("read error: %w", err)
}
err = mfs.WriteVirtualFile(filePath, b, fs.ModePerm)
if err != nil {
return xerrors.Errorf("mapfs write error: %w", err)
}

m.current = memfs
return
m.current = mfs
return nil
}

0 comments on commit 9e4b57f

Please sign in to comment.