Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to plug in custom mounts provider #198

Merged
merged 3 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ type Client struct {

// absolutePath if the absolute path is desired instead of the relative path.
absolutePath bool

// mountProvider provides a list of all mounts.
mountProvider mountProvider
}

// ClientInterface exports the interface for the full Vaku client.
Expand Down Expand Up @@ -153,13 +156,32 @@ func (o withAbsolutePath) apply(c *Client) error {
return nil
}

// WithMountProvider makes it possible to inject a custom method for listing mounts.
// The default method uses the sys/mounts endpoint. This requires a level of privilege that
// not all users may have.
func WithMountProvider(p mountProvider) Option {
return withMountProvider{provider: p}
}

type withMountProvider struct {
provider mountProvider
}

func (o withMountProvider) apply(c *Client) error {
c.mountProvider = o.provider
return nil
}

// NewClient returns a new Vaku Client based on the Vault API config.
func NewClient(opts ...Option) (*Client, error) {
// set defaults
client := &Client{
workers: defaultWorkers,
}
client.dc = client
client.mountProvider = defaultMountProvider{
client: client,
}

// apply options
for _, opt := range opts {
Expand Down
7 changes: 7 additions & 0 deletions api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ func TestNewClient(t *testing.T) {
WithVaultSrcClient(newDefaultVaultClient(t)),
WithVaultDstClient(newDefaultVaultClient(t)),
WithAbsolutePath(true),
WithMountProvider(&defaultMountProvider{
client: &Client{},
}),
},
want: &Client{
vc: newDefaultVaultClient(t),
Expand Down Expand Up @@ -325,12 +328,16 @@ func assertClientsEqual(t *testing.T, expected *Client, actual *Client) {
// zero out clients and assert equal
expected.vc = nil
expected.vl = nil
expected.mountProvider = nil
expected.dc.vc = nil
expected.dc.vl = nil
expected.dc.mountProvider = nil
actual.vc = nil
actual.vl = nil
actual.mountProvider = nil
actual.dc.vc = nil
actual.dc.vl = nil
actual.dc.mountProvider = nil

if expected.dc.dc != expected {
expected.dc.dc = expected
Expand Down
38 changes: 38 additions & 0 deletions api/mount_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package vaku

// Mount is a high level representation of selected fields of a
// vault mount that are relevant to vaku.
type Mount struct {
Path string
Type string
Version string
}

// mountProvider is used to get a list of all mounts that the user has access to.
type mountProvider interface {
ListMounts() ([]Mount, error)
}

// defaultMountProvider is used if no other mountProvider is supplied.
type defaultMountProvider struct {
client *Client
}

// ListMounts lists mounts using the sys/mounts endpoint.
func (p defaultMountProvider) ListMounts() ([]Mount, error) {
mounts, err := p.client.vc.Sys().ListMounts()
if err != nil {
return nil, newWrapErr("", ErrMountInfo, newWrapErr(err.Error(), ErrListMounts, nil))
}

result := make([]Mount, 0)
for mountPath, data := range mounts {
mount := Mount{
Path: mountPath,
Type: data.Type,
Version: data.Options["version"],
}
result = append(result, mount)
}
return result, nil
}
15 changes: 7 additions & 8 deletions api/mounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,20 @@ const (

// mountInfo takes a path and returns the mount path and version.
func (c *Client) mountInfo(p string) (string, mountVersion, error) {
mounts, err := c.vc.Sys().ListMounts()
mounts, err := c.mountProvider.ListMounts()
if err != nil {
return "", mv0, newWrapErr(p, ErrMountInfo, newWrapErr(err.Error(), ErrListMounts, nil))
}

for mount, data := range mounts {
for _, mount := range mounts {
// Ensure '/' so that no match on foo/bar/ when actual path is foo/barbar/
mount = EnsureFolder(mount)
if strings.HasPrefix(p, mount) {
version, ok := data.Options["version"]
if !ok {
return mount, mv0, nil
mount.Path = EnsureFolder(mount.Path)
if strings.HasPrefix(p, mount.Path) {
if mount.Version == "" {
return mount.Path, mv0, nil
}

return mount, mountStringToVersion(version), nil
return mount.Path, mountStringToVersion(mount.Version), nil
}
}

Expand Down