diff --git a/book.go b/book.go index e2fec03a..4bc2552e 100644 --- a/book.go +++ b/book.go @@ -300,24 +300,39 @@ func (bk *book) parseHTTPRunnerWithDetailed(name string, b []byte) (bool, error) } r.multipartBoundary = c.MultipartBoundary if c.OpenAPI3DocLocation != "" && !strings.HasPrefix(c.OpenAPI3DocLocation, "https://") && !strings.HasPrefix(c.OpenAPI3DocLocation, "http://") && !strings.HasPrefix(c.OpenAPI3DocLocation, "/") { - c.OpenAPI3DocLocation = fp(c.OpenAPI3DocLocation, root) + c.OpenAPI3DocLocation, err = fp(c.OpenAPI3DocLocation, root) + if err != nil { + return false, err + } } if c.CACert != "" { - b, err := readFile(fp(c.CACert, root)) + p, err := fp(c.CACert, root) + if err != nil { + return false, err + } + b, err := readFile(p) if err != nil { return false, err } r.cacert = b } if c.Cert != "" { - b, err := readFile(fp(c.Cert, root)) + p, err := fp(c.Cert, root) + if err != nil { + return false, err + } + b, err := readFile(p) if err != nil { return false, err } r.cert = b } if c.Key != "" { - b, err := readFile(fp(c.Key, root)) + p, err := fp(c.Key, root) + if err != nil { + return false, err + } + b, err := readFile(p) if err != nil { return false, err } @@ -361,7 +376,11 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error) if len(c.cacert) != 0 { r.cacert = c.cacert } else if c.CACert != "" { - b, err := readFile(fp(c.CACert, root)) + p, err := fp(c.CACert, root) + if err != nil { + return false, err + } + b, err := readFile(p) if err != nil { return false, err } @@ -370,7 +389,11 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error) if len(c.cert) != 0 { r.cert = c.cert } else if c.Cert != "" { - b, err := readFile(fp(c.Cert, root)) + p, err := fp(c.Cert, root) + if err != nil { + return false, err + } + b, err := readFile(p) if err != nil { return false, err } @@ -379,7 +402,11 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error) if len(c.key) != 0 { r.key = c.key } else if c.Key != "" { - b, err := readFile(fp(c.Key, root)) + p, err := fp(c.Key, root) + if err != nil { + return false, err + } + b, err := readFile(p) if err != nil { return false, err } @@ -387,19 +414,39 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error) } r.skipVerify = c.SkipVerify for _, p := range c.ImportPaths { - r.importPaths = append(r.importPaths, fp(p, root)) + pp, err := fp(p, root) + if err != nil { + return false, err + } + r.importPaths = append(r.importPaths, pp) } for _, p := range c.Protos { - r.protos = append(r.protos, fp(p, root)) + pp, err := fp(p, root) + if err != nil { + return false, err + } + r.protos = append(r.protos, pp) } for _, p := range c.BufDirs { - r.bufDirs = append(r.bufDirs, fp(p, root)) + pp, err := fp(p, root) + if err != nil { + return false, err + } + r.bufDirs = append(r.bufDirs, pp) } for _, p := range c.BufLocks { - r.bufLocks = append(r.bufLocks, fp(p, root)) + pp, err := fp(p, root) + if err != nil { + return false, err + } + r.bufLocks = append(r.bufLocks, pp) } for _, p := range c.BufConfigs { - r.bufConfigs = append(r.bufConfigs, fp(p, root)) + pp, err := fp(p, root) + if err != nil { + return false, err + } + r.bufConfigs = append(r.bufConfigs, pp) } r.bufModules = c.BufModules r.trace = c.Trace.Enable @@ -447,9 +494,9 @@ func (bk *book) parseSSHRunnerWithDetailed(name string, b []byte) (bool, error) } var opts []sshc.Option if c.SSHConfig != "" { - p := c.SSHConfig - if !filepath.IsAbs(c.SSHConfig) { - p = filepath.Join(root, c.SSHConfig) + p, err := fp(c.SSHConfig, root) + if err != nil { + return false, err } if _, err := os.Stat(p); err != nil { return false, err @@ -466,9 +513,9 @@ func (bk *book) parseSSHRunnerWithDetailed(name string, b []byte) (bool, error) opts = append(opts, sshc.Port(c.Port)) } if c.IdentityFile != "" { - p := c.IdentityFile - if !filepath.IsAbs(c.IdentityFile) { - p = filepath.Join(root, c.IdentityFile) + p, err := fp(c.IdentityFile, root) + if err != nil { + return false, err } b, err := readFile(p) if err != nil { @@ -648,14 +695,6 @@ func detectSSHRunner(v any) bool { return false } -// fp returns the absolute path of root+p. -func fp(p, root string) string { - if filepath.IsAbs(p) { - return p - } - return filepath.Join(root, p) -} - func newBook() *book { return &book{ runners: map[string]any{}, diff --git a/cache.go b/cache.go index 6b67971a..349f6a57 100644 --- a/cache.go +++ b/cache.go @@ -33,7 +33,7 @@ func RemoveCacheDir() error { return os.RemoveAll(globalCacheDir) } -func cacheDir() (string, error) { +func cacheDirOrCreate() (string, error) { if globalCacheDir != "" { if _, err := os.Stat(globalCacheDir); err != nil { if err := os.MkdirAll(globalCacheDir, os.ModePerm); err != nil { @@ -49,3 +49,10 @@ func cacheDir() (string, error) { globalCacheDir = dir return dir, nil } + +func cacheDir() (string, error) { + if globalCacheDir != "" { + return globalCacheDir, nil + } + return "", fmt.Errorf("cache directory is not set") +} diff --git a/cdp.go b/cdp.go index 392045d5..1b3286c0 100644 --- a/cdp.go +++ b/cdp.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "os" - "path/filepath" "reflect" "sync" "time" @@ -222,8 +221,9 @@ func (rnr *cdpRunner) evalAction(ca CDPAction, s *step) ([]chromedp.Action, erro if !ok { return nil, fmt.Errorf("invalid action: %v", ca) } - if !filepath.IsAbs(pp) { - ca.Args["path"] = filepath.Join(o.root, pp) + ca.Args["path"], err = fp(pp, o.root) + if err != nil { + return nil, fmt.Errorf("invalid action: %v: %w", ca, err) } } diff --git a/coverage_test.go b/coverage_test.go index 5c584b27..632279d3 100644 --- a/coverage_test.go +++ b/coverage_test.go @@ -25,7 +25,7 @@ func TestCoverage(t *testing.T) { tt := tt t.Run(tt.book, func(t *testing.T) { t.Parallel() - o, err := New(Book(tt.book)) + o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent)) if err != nil { t.Fatal(err) } diff --git a/debugger_test.go b/debugger_test.go index 8e427bd7..20286f2e 100644 --- a/debugger_test.go +++ b/debugger_test.go @@ -44,7 +44,7 @@ func TestDebugger(t *testing.T) { DBRunner("db", db), Capture(NewDebugger(out)), Var("url", hs.URL), - Scopes(ScopeAllowRunExec), + Scopes(ScopeAllowRunExec, ScopeAllowReadParent), } o, err := New(opts...) if err != nil { diff --git a/hosts_test.go b/hosts_test.go index 1b99e8d0..fd7e51b8 100644 --- a/hosts_test.go +++ b/hosts_test.go @@ -23,7 +23,7 @@ func TestHostRules(t *testing.T) { for _, tt := range tests { t.Run(tt.book, func(t *testing.T) { tr.ClearRequests() - o, err := New(Book(tt.book)) + o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent)) if err != nil { t.Fatal(err) return diff --git a/http.go b/http.go index ae971543..70df6d76 100644 --- a/http.go +++ b/http.go @@ -160,7 +160,11 @@ func (r *httpRequest) encodeBody() (io.Reader, error) { if !ok { return nil, fmt.Errorf("invalid body: %v", r.body) } - b, err := readFile(filepath.Join(r.root, fileName)) + p, err := fp(fileName, r.root) + if err != nil { + return nil, err + } + b, err := readFile(p) if err != nil { return nil, err } @@ -253,7 +257,11 @@ func (r *httpRequest) encodeMultipart() (io.Reader, error) { default: return nil, fmt.Errorf("invalid body: %v", r.body) } - b, err := readFile(filepath.Join(r.root, fileName)) + p, err := fp(fileName, r.root) + if err != nil { + return nil, err + } + b, err := readFile(p) patherr := &fs.PathError{} if err != nil && !errors.Is(err, os.ErrNotExist) && !errors.As(err, &patherr) { return nil, err diff --git a/include.go b/include.go index b0fa99a6..3f326c01 100644 --- a/include.go +++ b/include.go @@ -3,7 +3,6 @@ package runn import ( "context" "errors" - "path/filepath" ) const includeRunnerKey = "include" @@ -65,13 +64,17 @@ func (rnr *includeRunner) Run(ctx context.Context, s *step) error { } rnr.runResults = nil + var err error ipath := rnr.path if ipath == "" { ipath = c.path } // ipath must not be variable expanded. Because it will be impossible to identify the step of the included runbook in case of run failure. if !hasRemotePrefix(ipath) { - ipath = filepath.Join(o.root, ipath) + ipath, err = fp(ipath, o.root) + if err != nil { + return err + } } // Store before record diff --git a/operator.go b/operator.go index cacb0d18..65949b86 100644 --- a/operator.go +++ b/operator.go @@ -10,7 +10,6 @@ import ( "math/rand" "net/http" "os" - "path/filepath" "sort" "strings" "sync" @@ -480,11 +479,19 @@ func New(opts ...Option) (*operator, error) { } op.root = root + var loErr error op.needs = lo.MapEntries(bk.needs, func(key string, path string) (string, *need) { + p, err := fp(path, op.root) + if err != nil { + loErr = errors.Join(loErr, err) + } return key, &need{ - path: filepath.Join(op.root, path), + path: p, } }) + if loErr != nil { + return nil, loErr + } // The host rules specified by the option take precedence. hostRules := append(bk.hostRulesFromOpts, bk.hostRules...) @@ -1766,7 +1773,11 @@ func (opn *operatorN) traverseOperators(op *operator) error { if opn.skipIncluded { for _, s := range op.steps { if s.includeRunner != nil && s.includeConfig != nil { - opn.included = append(opn.included, filepath.Join(op.root, s.includeConfig.path)) + p, err := fp(s.includeConfig.path, op.root) + if err != nil { + return err + } + opn.included = append(opn.included, p) } } } diff --git a/operator_test.go b/operator_test.go index 07f3ef8c..7f8d29d6 100644 --- a/operator_test.go +++ b/operator_test.go @@ -1042,7 +1042,7 @@ func TestHttp(t *testing.T) { t.Run(tt.book, func(t *testing.T) { ts := testutil.HTTPServer(t) t.Setenv("TEST_HTTP_ENDPOINT", ts.URL) - o, err := New(Book(tt.book)) + o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent)) if err != nil { t.Fatal(err) } @@ -1091,7 +1091,7 @@ func TestGrpcWithoutReflection(t *testing.T) { ctx, cancel := donegroup.WithCancel(context.Background()) t.Cleanup(cancel) t.Parallel() - o, err := New(Book(tt.book)) + o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent)) if err != nil { t.Fatal(err) } @@ -1303,6 +1303,7 @@ func TestTrace(t *testing.T) { id := "1234567890" opts := []Option{ Book(tt.book), + Scopes(ScopeAllowReadParent), GrpcRunner("greq", tg.Conn()), Capture(NewDebugger(buf)), Trace(true), diff --git a/option.go b/option.go index b6df3f16..2481c79d 100644 --- a/option.go +++ b/option.go @@ -292,24 +292,39 @@ func HTTPRunner(name, endpoint string, client *http.Client, opts ...httpRunnerOp } r.multipartBoundary = c.MultipartBoundary if c.OpenAPI3DocLocation != "" && !strings.HasPrefix(c.OpenAPI3DocLocation, "https://") && !strings.HasPrefix(c.OpenAPI3DocLocation, "http://") && !strings.HasPrefix(c.OpenAPI3DocLocation, "/") { - c.OpenAPI3DocLocation = fp(c.OpenAPI3DocLocation, root) + c.OpenAPI3DocLocation, err = fp(c.OpenAPI3DocLocation, root) + if err != nil { + return err + } } if c.CACert != "" { - b, err := readFile(fp(c.CACert, root)) + p, err := fp(c.CACert, root) + if err != nil { + return err + } + b, err := readFile(p) if err != nil { return err } r.cacert = b } if c.Cert != "" { - b, err := readFile(fp(c.Cert, root)) + p, err := fp(c.Cert, root) + if err != nil { + return err + } + b, err := readFile(p) if err != nil { return err } r.cert = b } if c.Key != "" { - b, err := readFile(fp(c.Key, root)) + p, err := fp(c.Key, root) + if err != nil { + return err + } + b, err := readFile(p) if err != nil { return err } diff --git a/path.go b/path.go index 5af28c10..c6fb5ca5 100644 --- a/path.go +++ b/path.go @@ -31,6 +31,45 @@ const ( prefixGist = schemeGist + "://" ) +// fp returns the absolute path of root+p. +func fp(p, root string) (string, error) { + if filepath.IsAbs(p) { + cd, err := cacheDir() + if err == nil { + if strings.HasPrefix(p, cd) { + globalScopes.mu.RLock() + if !globalScopes.readRemote { + globalScopes.mu.RUnlock() + return "", fmt.Errorf("scope error: remote file not allowed. 'read:remote' scope is required : %s", p) + } + globalScopes.mu.RUnlock() + return p, nil + } + } + rel, err := filepath.Rel(root, p) + if err != nil || strings.Contains(rel, "..") { + globalScopes.mu.RLock() + if !globalScopes.readParent { + globalScopes.mu.RUnlock() + return "", fmt.Errorf("scope error: parent directory not allowed. 'read:parent' scope is required : %s", p) + } + globalScopes.mu.RUnlock() + } + return p, nil + } + rel, err := filepath.Rel(root, filepath.Join(root, p)) + if err != nil || strings.Contains(rel, "..") { + globalScopes.mu.RLock() + if !globalScopes.readParent { + globalScopes.mu.RUnlock() + return "", fmt.Errorf("scope error: parent directory not allowed. 'read:parent' scope is required : %s", p) + } + globalScopes.mu.RUnlock() + } + + return filepath.Join(root, p), nil +} + // hasRemotePrefix returns true if the path has remote file prefix. func hasRemotePrefix(u string) bool { return strings.HasPrefix(u, prefixHttps) || strings.HasPrefix(u, prefixGitHub) || strings.HasPrefix(u, prefixGist) @@ -60,7 +99,7 @@ func ShortenPath(p string) string { // If the file paths are remote files, it fetches them and returns their local cache paths. func fetchPaths(pathp string) ([]string, error) { var paths []string - listp := splitList(pathp) + listp := splitPathList(pathp) for _, pp := range listp { base, pattern := doublestar.SplitPattern(filepath.ToSlash(pp)) switch { @@ -181,7 +220,8 @@ func fetchPath(path string) (string, error) { func readFile(p string) ([]byte, error) { fi, err := os.Stat(p) if err == nil { - if globalCacheDir != "" && strings.HasPrefix(p, globalCacheDir) { + cd, err := cacheDir() + if err == nil && strings.HasPrefix(p, cd) { // Read cache file globalScopes.mu.RLock() if !globalScopes.readRemote { @@ -219,7 +259,8 @@ func readFile(p string) ([]byte, error) { // Read local file return os.ReadFile(p) } - if globalCacheDir == "" || !strings.HasPrefix(p, globalCacheDir) { + cd, errr := cacheDir() + if errr != nil || !strings.HasPrefix(p, cd) { // Not cache file return nil, err } @@ -232,7 +273,7 @@ func readFile(p string) ([]byte, error) { globalScopes.mu.RUnlock() // Re-fetch remote file and create cache - cachePath, err := filepath.Rel(globalCacheDir, p) + cachePath, err := filepath.Rel(cd, p) if err != nil { return nil, err } @@ -287,11 +328,11 @@ func fetchPathViaHTTPS(urlstr string) (string, error) { return "", err } defer res.Body.Close() - cd, err := cacheDir() + ep, err := urlfilepath.Encode(u) if err != nil { return "", err } - ep, err := urlfilepath.Encode(u) + cd, err := cacheDirOrCreate() if err != nil { return "", err } @@ -312,15 +353,15 @@ func fetchPathViaHTTPS(urlstr string) (string, error) { func fetchPathsViaGitHub(fsys fs.FS, base, pattern string) ([]string, error) { var paths []string - cd, err := cacheDir() + u, err := url.Parse(base) if err != nil { return nil, err } - u, err := url.Parse(base) + ep, err := urlfilepath.Encode(u) if err != nil { return nil, err } - ep, err := urlfilepath.Encode(u) + cd, err := cacheDirOrCreate() if err != nil { return nil, err } @@ -399,7 +440,7 @@ func fetchPathViaGist(urlstr string) (string, error) { return "", fmt.Errorf("invalid filename: %s", filename) } } - cd, err := cacheDir() + cd, err := cacheDirOrCreate() if err != nil { return "", err } @@ -467,8 +508,8 @@ func readFileViaGitHub(urlstr string) ([]byte, error) { return io.ReadAll(f) } -// splitList splits the path list by os.PathListSeparator while keeping schemes. -func splitList(pathp string) []string { +// splitPathList splits the path list by os.PathListSeparator while keeping schemes. +func splitPathList(pathp string) []string { rep := strings.NewReplacer(prefixHttps, repKey(prefixHttps), prefixGitHub, repKey(prefixGitHub), prefixGist, repKey(prefixGist)) per := strings.NewReplacer(repKey(prefixHttps), prefixHttps, repKey(prefixGitHub), prefixGitHub, repKey(prefixGist), prefixGist) var listp []string diff --git a/path_test.go b/path_test.go index 168b3db0..a98ccdf1 100644 --- a/path_test.go +++ b/path_test.go @@ -2,9 +2,126 @@ package runn import ( "os" + "path/filepath" "testing" ) +func TestFp(t *testing.T) { + currentGlobalCacheDir := globalCacheDir + globalCacheDir = t.TempDir() + t.Cleanup(func() { + globalCacheDir = currentGlobalCacheDir + }) + root, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + tests := []struct { + name string + p string + readRemote bool + readParent bool + want string + wantErr bool + }{ + { + "Join root and path", + "path/to/book.yml", + false, + false, + filepath.Join(root, "path/to/book.yml"), + false, + }, + { + "scope `read:parent` error", + "/path/to/book.yml", + false, + false, + "", + true, + }, + { + "allow scope `read:parent`", + "/path/to/book.yml", + false, + true, + "/path/to/book.yml", + false, + }, + { + "Join root and path with relative path", + "path/../book.yml", + false, + false, + filepath.Join(root, "book.yml"), + false, + }, + { + "scope `read:parent` error with relative path", + "../book.yml", + false, + false, + "", + true, + }, + { + "allow scope `read:parent` with relative path", + "../book.yml", + false, + true, + filepath.Join(filepath.Dir(root), "book.yml"), + false, + }, + { + "scope `read:remote` error", + filepath.Join(globalCacheDir, "path/to/book.yml"), + false, + true, + "", + true, + }, + { + "allow scope `read:remote`", + filepath.Join(globalCacheDir, "path/to/book.yml"), + true, + false, + filepath.Join(globalCacheDir, "path/to/book.yml"), + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + globalScopes.mu.Lock() + globalScopes.readParent = tt.readParent + globalScopes.readRemote = tt.readRemote + globalScopes.mu.Unlock() + t.Cleanup(func() { + globalScopes.mu.Lock() + globalScopes.readParent = false + globalScopes.readRemote = false + globalScopes.mu.Unlock() + }) + + got, err := fp(tt.p, root) + if err != nil { + if !tt.wantErr { + t.Errorf("got %v", err) + } + return + } + if tt.wantErr { + t.Errorf("want error") + return + } + + if got != tt.want { + t.Errorf("got %v\nwant %v", got, tt.want) + } + }) + } + +} + func TestFetchPaths(t *testing.T) { tests := []struct { pathp string @@ -38,11 +155,17 @@ func TestFetchPaths(t *testing.T) { {"gist://def6fa739fba3fcf211b018f41630adc/book.yml", 1, false}, }...) } + globalScopes.mu.RLock() + globalScopes.readRemote = true + globalScopes.mu.RUnlock() t.Cleanup(func() { if err := RemoveCacheDir(); err != nil { t.Fatal(err) } + globalScopes.mu.RLock() + globalScopes.readRemote = false + globalScopes.mu.RUnlock() }) for _, tt := range tests { t.Run(tt.pathp, func(t *testing.T) { diff --git a/runbook.go b/runbook.go index a9c43ee5..7ddc89df 100644 --- a/runbook.go +++ b/runbook.go @@ -576,6 +576,9 @@ func detectRunbookAreas(in string) *areas { if err != nil { return a } + if len(parsed.Docs) == 0 { + return a + } m, ok := parsed.Docs[0].Body.(*ast.MappingNode) if !ok { return a