Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow absolute paths and enhance scopes detection #1080

Merged
merged 6 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 65 additions & 26 deletions book.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,24 +300,39 @@ func (bk *book) parseHTTPRunnerWithDetailed(name string, b []byte) (bool, error)
}
r.multipartBoundary = c.MultipartBoundary
if c.OpenAPI3DocLocation != "" && !strings.HasPrefix(c.OpenAPI3DocLocation, "https://") && !strings.HasPrefix(c.OpenAPI3DocLocation, "http://") && !strings.HasPrefix(c.OpenAPI3DocLocation, "/") {
c.OpenAPI3DocLocation = fp(c.OpenAPI3DocLocation, root)
c.OpenAPI3DocLocation, err = fp(c.OpenAPI3DocLocation, root)
if err != nil {
return false, err
}
}
if c.CACert != "" {
b, err := readFile(fp(c.CACert, root))
p, err := fp(c.CACert, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
r.cacert = b
}
if c.Cert != "" {
b, err := readFile(fp(c.Cert, root))
p, err := fp(c.Cert, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
r.cert = b
}
if c.Key != "" {
b, err := readFile(fp(c.Key, root))
p, err := fp(c.Key, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -361,7 +376,11 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error)
if len(c.cacert) != 0 {
r.cacert = c.cacert
} else if c.CACert != "" {
b, err := readFile(fp(c.CACert, root))
p, err := fp(c.CACert, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
Expand All @@ -370,7 +389,11 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error)
if len(c.cert) != 0 {
r.cert = c.cert
} else if c.Cert != "" {
b, err := readFile(fp(c.Cert, root))
p, err := fp(c.Cert, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
Expand All @@ -379,27 +402,51 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error)
if len(c.key) != 0 {
r.key = c.key
} else if c.Key != "" {
b, err := readFile(fp(c.Key, root))
p, err := fp(c.Key, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
r.key = b
}
r.skipVerify = c.SkipVerify
for _, p := range c.ImportPaths {
r.importPaths = append(r.importPaths, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.importPaths = append(r.importPaths, pp)
}
for _, p := range c.Protos {
r.protos = append(r.protos, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.protos = append(r.protos, pp)
}
for _, p := range c.BufDirs {
r.bufDirs = append(r.bufDirs, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.bufDirs = append(r.bufDirs, pp)
}
for _, p := range c.BufLocks {
r.bufLocks = append(r.bufLocks, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.bufLocks = append(r.bufLocks, pp)
}
for _, p := range c.BufConfigs {
r.bufConfigs = append(r.bufConfigs, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.bufConfigs = append(r.bufConfigs, pp)
}
r.bufModules = c.BufModules
r.trace = c.Trace.Enable
Expand Down Expand Up @@ -447,9 +494,9 @@ func (bk *book) parseSSHRunnerWithDetailed(name string, b []byte) (bool, error)
}
var opts []sshc.Option
if c.SSHConfig != "" {
p := c.SSHConfig
if !filepath.IsAbs(c.SSHConfig) {
p = filepath.Join(root, c.SSHConfig)
p, err := fp(c.SSHConfig, root)
if err != nil {
return false, err
}
if _, err := os.Stat(p); err != nil {
return false, err
Expand All @@ -466,9 +513,9 @@ func (bk *book) parseSSHRunnerWithDetailed(name string, b []byte) (bool, error)
opts = append(opts, sshc.Port(c.Port))
}
if c.IdentityFile != "" {
p := c.IdentityFile
if !filepath.IsAbs(c.IdentityFile) {
p = filepath.Join(root, c.IdentityFile)
p, err := fp(c.IdentityFile, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
Expand Down Expand Up @@ -648,14 +695,6 @@ func detectSSHRunner(v any) bool {
return false
}

// fp returns the absolute path of root+p.
func fp(p, root string) string {
if filepath.IsAbs(p) {
return p
}
return filepath.Join(root, p)
}

func newBook() *book {
return &book{
runners: map[string]any{},
Expand Down
9 changes: 8 additions & 1 deletion cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func RemoveCacheDir() error {
return os.RemoveAll(globalCacheDir)
}

func cacheDir() (string, error) {
func cacheDirOrCreate() (string, error) {
if globalCacheDir != "" {
if _, err := os.Stat(globalCacheDir); err != nil {
if err := os.MkdirAll(globalCacheDir, os.ModePerm); err != nil {
Expand All @@ -49,3 +49,10 @@ func cacheDir() (string, error) {
globalCacheDir = dir
return dir, nil
}

func cacheDir() (string, error) {
if globalCacheDir != "" {
return globalCacheDir, nil
}
return "", fmt.Errorf("cache directory is not set")
}
6 changes: 3 additions & 3 deletions cdp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"reflect"
"sync"
"time"
Expand Down Expand Up @@ -222,8 +221,9 @@ func (rnr *cdpRunner) evalAction(ca CDPAction, s *step) ([]chromedp.Action, erro
if !ok {
return nil, fmt.Errorf("invalid action: %v", ca)
}
if !filepath.IsAbs(pp) {
ca.Args["path"] = filepath.Join(o.root, pp)
ca.Args["path"], err = fp(pp, o.root)
if err != nil {
return nil, fmt.Errorf("invalid action: %v: %w", ca, err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestCoverage(t *testing.T) {
tt := tt
t.Run(tt.book, func(t *testing.T) {
t.Parallel()
o, err := New(Book(tt.book))
o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent))
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion debugger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestDebugger(t *testing.T) {
DBRunner("db", db),
Capture(NewDebugger(out)),
Var("url", hs.URL),
Scopes(ScopeAllowRunExec),
Scopes(ScopeAllowRunExec, ScopeAllowReadParent),
}
o, err := New(opts...)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestHostRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.book, func(t *testing.T) {
tr.ClearRequests()
o, err := New(Book(tt.book))
o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent))
if err != nil {
t.Fatal(err)
return
Expand Down
12 changes: 10 additions & 2 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ func (r *httpRequest) encodeBody() (io.Reader, error) {
if !ok {
return nil, fmt.Errorf("invalid body: %v", r.body)
}
b, err := readFile(filepath.Join(r.root, fileName))
p, err := fp(fileName, r.root)
if err != nil {
return nil, err
}
b, err := readFile(p)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -253,7 +257,11 @@ func (r *httpRequest) encodeMultipart() (io.Reader, error) {
default:
return nil, fmt.Errorf("invalid body: %v", r.body)
}
b, err := readFile(filepath.Join(r.root, fileName))
p, err := fp(fileName, r.root)
if err != nil {
return nil, err
}
b, err := readFile(p)
patherr := &fs.PathError{}
if err != nil && !errors.Is(err, os.ErrNotExist) && !errors.As(err, &patherr) {
return nil, err
Expand Down
7 changes: 5 additions & 2 deletions include.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package runn
import (
"context"
"errors"
"path/filepath"
)

const includeRunnerKey = "include"
Expand Down Expand Up @@ -65,13 +64,17 @@ func (rnr *includeRunner) Run(ctx context.Context, s *step) error {
}
rnr.runResults = nil

var err error
ipath := rnr.path
if ipath == "" {
ipath = c.path
}
// ipath must not be variable expanded. Because it will be impossible to identify the step of the included runbook in case of run failure.
if !hasRemotePrefix(ipath) {
ipath = filepath.Join(o.root, ipath)
ipath, err = fp(ipath, o.root)
if err != nil {
return err
}
}

// Store before record
Expand Down
17 changes: 14 additions & 3 deletions operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"math/rand"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
Expand Down Expand Up @@ -480,11 +479,19 @@ func New(opts ...Option) (*operator, error) {
}
op.root = root

var loErr error
op.needs = lo.MapEntries(bk.needs, func(key string, path string) (string, *need) {
p, err := fp(path, op.root)
if err != nil {
loErr = errors.Join(loErr, err)
}
return key, &need{
path: filepath.Join(op.root, path),
path: p,
}
})
if loErr != nil {
return nil, loErr
}

// The host rules specified by the option take precedence.
hostRules := append(bk.hostRulesFromOpts, bk.hostRules...)
Expand Down Expand Up @@ -1766,7 +1773,11 @@ func (opn *operatorN) traverseOperators(op *operator) error {
if opn.skipIncluded {
for _, s := range op.steps {
if s.includeRunner != nil && s.includeConfig != nil {
opn.included = append(opn.included, filepath.Join(op.root, s.includeConfig.path))
p, err := fp(s.includeConfig.path, op.root)
if err != nil {
return err
}
opn.included = append(opn.included, p)
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ func TestHttp(t *testing.T) {
t.Run(tt.book, func(t *testing.T) {
ts := testutil.HTTPServer(t)
t.Setenv("TEST_HTTP_ENDPOINT", ts.URL)
o, err := New(Book(tt.book))
o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1091,7 +1091,7 @@ func TestGrpcWithoutReflection(t *testing.T) {
ctx, cancel := donegroup.WithCancel(context.Background())
t.Cleanup(cancel)
t.Parallel()
o, err := New(Book(tt.book))
o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1303,6 +1303,7 @@ func TestTrace(t *testing.T) {
id := "1234567890"
opts := []Option{
Book(tt.book),
Scopes(ScopeAllowReadParent),
GrpcRunner("greq", tg.Conn()),
Capture(NewDebugger(buf)),
Trace(true),
Expand Down
Loading
Loading