Skip to content

Commit

Permalink
Merge pull request #1080 from k1LoW/use-fp
Browse files Browse the repository at this point in the history
Allow absolute paths and enhance scopes detection
  • Loading branch information
k1LoW authored Nov 26, 2024
2 parents 9990ad7 + 4f02d17 commit ecf7547
Show file tree
Hide file tree
Showing 14 changed files with 309 additions and 58 deletions.
91 changes: 65 additions & 26 deletions book.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,24 +300,39 @@ func (bk *book) parseHTTPRunnerWithDetailed(name string, b []byte) (bool, error)
}
r.multipartBoundary = c.MultipartBoundary
if c.OpenAPI3DocLocation != "" && !strings.HasPrefix(c.OpenAPI3DocLocation, "https://") && !strings.HasPrefix(c.OpenAPI3DocLocation, "http://") && !strings.HasPrefix(c.OpenAPI3DocLocation, "/") {
c.OpenAPI3DocLocation = fp(c.OpenAPI3DocLocation, root)
c.OpenAPI3DocLocation, err = fp(c.OpenAPI3DocLocation, root)
if err != nil {
return false, err
}
}
if c.CACert != "" {
b, err := readFile(fp(c.CACert, root))
p, err := fp(c.CACert, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
r.cacert = b
}
if c.Cert != "" {
b, err := readFile(fp(c.Cert, root))
p, err := fp(c.Cert, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
r.cert = b
}
if c.Key != "" {
b, err := readFile(fp(c.Key, root))
p, err := fp(c.Key, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -361,7 +376,11 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error)
if len(c.cacert) != 0 {
r.cacert = c.cacert
} else if c.CACert != "" {
b, err := readFile(fp(c.CACert, root))
p, err := fp(c.CACert, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
Expand All @@ -370,7 +389,11 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error)
if len(c.cert) != 0 {
r.cert = c.cert
} else if c.Cert != "" {
b, err := readFile(fp(c.Cert, root))
p, err := fp(c.Cert, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
Expand All @@ -379,27 +402,51 @@ func (bk *book) parseGRPCRunnerWithDetailed(name string, b []byte) (bool, error)
if len(c.key) != 0 {
r.key = c.key
} else if c.Key != "" {
b, err := readFile(fp(c.Key, root))
p, err := fp(c.Key, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
return false, err
}
r.key = b
}
r.skipVerify = c.SkipVerify
for _, p := range c.ImportPaths {
r.importPaths = append(r.importPaths, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.importPaths = append(r.importPaths, pp)
}
for _, p := range c.Protos {
r.protos = append(r.protos, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.protos = append(r.protos, pp)
}
for _, p := range c.BufDirs {
r.bufDirs = append(r.bufDirs, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.bufDirs = append(r.bufDirs, pp)
}
for _, p := range c.BufLocks {
r.bufLocks = append(r.bufLocks, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.bufLocks = append(r.bufLocks, pp)
}
for _, p := range c.BufConfigs {
r.bufConfigs = append(r.bufConfigs, fp(p, root))
pp, err := fp(p, root)
if err != nil {
return false, err
}
r.bufConfigs = append(r.bufConfigs, pp)
}
r.bufModules = c.BufModules
r.trace = c.Trace.Enable
Expand Down Expand Up @@ -447,9 +494,9 @@ func (bk *book) parseSSHRunnerWithDetailed(name string, b []byte) (bool, error)
}
var opts []sshc.Option
if c.SSHConfig != "" {
p := c.SSHConfig
if !filepath.IsAbs(c.SSHConfig) {
p = filepath.Join(root, c.SSHConfig)
p, err := fp(c.SSHConfig, root)
if err != nil {
return false, err
}
if _, err := os.Stat(p); err != nil {
return false, err
Expand All @@ -466,9 +513,9 @@ func (bk *book) parseSSHRunnerWithDetailed(name string, b []byte) (bool, error)
opts = append(opts, sshc.Port(c.Port))
}
if c.IdentityFile != "" {
p := c.IdentityFile
if !filepath.IsAbs(c.IdentityFile) {
p = filepath.Join(root, c.IdentityFile)
p, err := fp(c.IdentityFile, root)
if err != nil {
return false, err
}
b, err := readFile(p)
if err != nil {
Expand Down Expand Up @@ -648,14 +695,6 @@ func detectSSHRunner(v any) bool {
return false
}

// fp returns the absolute path of root+p.
func fp(p, root string) string {
if filepath.IsAbs(p) {
return p
}
return filepath.Join(root, p)
}

func newBook() *book {
return &book{
runners: map[string]any{},
Expand Down
9 changes: 8 additions & 1 deletion cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func RemoveCacheDir() error {
return os.RemoveAll(globalCacheDir)
}

func cacheDir() (string, error) {
func cacheDirOrCreate() (string, error) {
if globalCacheDir != "" {
if _, err := os.Stat(globalCacheDir); err != nil {
if err := os.MkdirAll(globalCacheDir, os.ModePerm); err != nil {
Expand All @@ -49,3 +49,10 @@ func cacheDir() (string, error) {
globalCacheDir = dir
return dir, nil
}

func cacheDir() (string, error) {
if globalCacheDir != "" {
return globalCacheDir, nil
}
return "", fmt.Errorf("cache directory is not set")
}
6 changes: 3 additions & 3 deletions cdp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"reflect"
"sync"
"time"
Expand Down Expand Up @@ -222,8 +221,9 @@ func (rnr *cdpRunner) evalAction(ca CDPAction, s *step) ([]chromedp.Action, erro
if !ok {
return nil, fmt.Errorf("invalid action: %v", ca)
}
if !filepath.IsAbs(pp) {
ca.Args["path"] = filepath.Join(o.root, pp)
ca.Args["path"], err = fp(pp, o.root)
if err != nil {
return nil, fmt.Errorf("invalid action: %v: %w", ca, err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion coverage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestCoverage(t *testing.T) {
tt := tt
t.Run(tt.book, func(t *testing.T) {
t.Parallel()
o, err := New(Book(tt.book))
o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent))
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion debugger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestDebugger(t *testing.T) {
DBRunner("db", db),
Capture(NewDebugger(out)),
Var("url", hs.URL),
Scopes(ScopeAllowRunExec),
Scopes(ScopeAllowRunExec, ScopeAllowReadParent),
}
o, err := New(opts...)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestHostRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.book, func(t *testing.T) {
tr.ClearRequests()
o, err := New(Book(tt.book))
o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent))
if err != nil {
t.Fatal(err)
return
Expand Down
12 changes: 10 additions & 2 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ func (r *httpRequest) encodeBody() (io.Reader, error) {
if !ok {
return nil, fmt.Errorf("invalid body: %v", r.body)
}
b, err := readFile(filepath.Join(r.root, fileName))
p, err := fp(fileName, r.root)
if err != nil {
return nil, err
}
b, err := readFile(p)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -253,7 +257,11 @@ func (r *httpRequest) encodeMultipart() (io.Reader, error) {
default:
return nil, fmt.Errorf("invalid body: %v", r.body)
}
b, err := readFile(filepath.Join(r.root, fileName))
p, err := fp(fileName, r.root)
if err != nil {
return nil, err
}
b, err := readFile(p)
patherr := &fs.PathError{}
if err != nil && !errors.Is(err, os.ErrNotExist) && !errors.As(err, &patherr) {
return nil, err
Expand Down
7 changes: 5 additions & 2 deletions include.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package runn
import (
"context"
"errors"
"path/filepath"
)

const includeRunnerKey = "include"
Expand Down Expand Up @@ -65,13 +64,17 @@ func (rnr *includeRunner) Run(ctx context.Context, s *step) error {
}
rnr.runResults = nil

var err error
ipath := rnr.path
if ipath == "" {
ipath = c.path
}
// ipath must not be variable expanded. Because it will be impossible to identify the step of the included runbook in case of run failure.
if !hasRemotePrefix(ipath) {
ipath = filepath.Join(o.root, ipath)
ipath, err = fp(ipath, o.root)
if err != nil {
return err
}
}

// Store before record
Expand Down
17 changes: 14 additions & 3 deletions operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"math/rand"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
"sync"
Expand Down Expand Up @@ -480,11 +479,19 @@ func New(opts ...Option) (*operator, error) {
}
op.root = root

var loErr error
op.needs = lo.MapEntries(bk.needs, func(key string, path string) (string, *need) {
p, err := fp(path, op.root)
if err != nil {
loErr = errors.Join(loErr, err)
}
return key, &need{
path: filepath.Join(op.root, path),
path: p,
}
})
if loErr != nil {
return nil, loErr
}

// The host rules specified by the option take precedence.
hostRules := append(bk.hostRulesFromOpts, bk.hostRules...)
Expand Down Expand Up @@ -1766,7 +1773,11 @@ func (opn *operatorN) traverseOperators(op *operator) error {
if opn.skipIncluded {
for _, s := range op.steps {
if s.includeRunner != nil && s.includeConfig != nil {
opn.included = append(opn.included, filepath.Join(op.root, s.includeConfig.path))
p, err := fp(s.includeConfig.path, op.root)
if err != nil {
return err
}
opn.included = append(opn.included, p)
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ func TestHttp(t *testing.T) {
t.Run(tt.book, func(t *testing.T) {
ts := testutil.HTTPServer(t)
t.Setenv("TEST_HTTP_ENDPOINT", ts.URL)
o, err := New(Book(tt.book))
o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1091,7 +1091,7 @@ func TestGrpcWithoutReflection(t *testing.T) {
ctx, cancel := donegroup.WithCancel(context.Background())
t.Cleanup(cancel)
t.Parallel()
o, err := New(Book(tt.book))
o, err := New(Book(tt.book), Scopes(ScopeAllowReadParent))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1303,6 +1303,7 @@ func TestTrace(t *testing.T) {
id := "1234567890"
opts := []Option{
Book(tt.book),
Scopes(ScopeAllowReadParent),
GrpcRunner("greq", tg.Conn()),
Capture(NewDebugger(buf)),
Trace(true),
Expand Down
Loading

0 comments on commit ecf7547

Please sign in to comment.