Skip to content

Commit

Permalink
Refactor plugin catalog set functions (#22666)
Browse files Browse the repository at this point in the history
Use a struct arg instead of a long list of args. Plugins running in containers
will require even more args and it's getting difficult to maintain.
  • Loading branch information
tomhjp authored Aug 31, 2023
1 parent 1acd0c6 commit 3e55447
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 62 deletions.
14 changes: 14 additions & 0 deletions sdk/helper/pluginutil/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ type PluginRunner struct {
BuiltinFactory func() (interface{}, error) `json:"-" structs:"-"`
}

// SetPluginInput is only used as input for the plugin catalog's set methods.
// We don't use the very similar PluginRunner struct to avoid confusion about
// what's settable, which does not include the builtin fields.
type SetPluginInput struct {
Name string
Type consts.PluginType
Version string
Command string
OCIImage string
Args []string
Env []string
Sha256 []byte
}

// Run takes a wrapper RunnerUtil instance along with the go-plugin parameters and
// returns a configured plugin.Client with TLS Configured and a wrapping token set
// on PluginUnwrapTokenEnv for plugin process consumption.
Expand Down
10 changes: 9 additions & 1 deletion vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,15 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica
return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err
}

err = b.Core.pluginCatalog.Set(ctx, pluginName, pluginType, pluginVersion, parts[0], args, env, sha256Bytes)
err = b.Core.pluginCatalog.Set(ctx, pluginutil.SetPluginInput{
Name: pluginName,
Type: pluginType,
Version: pluginVersion,
Command: parts[0],
Args: args,
Env: env,
Sha256: sha256Bytes,
})
if err != nil {
if errors.Is(err, ErrPluginNotFound) || strings.HasPrefix(err.Error(), "plugin version mismatch") {
return logical.ErrorResponse(err.Error()), nil
Expand Down
20 changes: 18 additions & 2 deletions vault/logical_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2189,7 +2189,15 @@ func TestSystemBackend_tuneAuth(t *testing.T) {
if err := file.Close(); err != nil {
t.Fatal(err)
}
err = c.pluginCatalog.Set(context.Background(), "token", consts.PluginTypeCredential, "v1.0.0", "foo", []string{}, []string{}, []byte{})
err = c.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "token",
Type: consts.PluginTypeCredential,
Version: "v1.0.0",
Command: "foo",
Args: []string{},
Env: []string{},
Sha256: []byte{},
})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -5742,7 +5750,15 @@ func TestValidateVersion_HelpfulErrorWhenBuiltinOverridden(t *testing.T) {
defer file.Close()

command := filepath.Base(file.Name())
err = core.pluginCatalog.Set(context.Background(), "kubernetes", consts.PluginTypeCredential, "", command, nil, nil, nil)
err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "kubernetes",
Type: consts.PluginTypeCredential,
Version: "",
Command: command,
Args: nil,
Env: nil,
Sha256: nil,
})
if err != nil {
t.Fatal(err)
}
Expand Down
75 changes: 42 additions & 33 deletions vault/plugin_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,15 @@ func (c *PluginCatalog) UpgradePlugins(ctx context.Context, logger log.Logger) e
plugin.Command = filepath.Join(c.directory, plugin.Command)

// Upgrade the storage. At this point we don't know what type of plugin this is so pass in the unknown type.
runner, err := c.setInternal(ctx, pluginName, consts.PluginTypeUnknown, plugin.Version, cmdOld, plugin.Args, plugin.Env, plugin.Sha256)
runner, err := c.setInternal(ctx, pluginutil.SetPluginInput{
Name: pluginName,
Type: consts.PluginTypeUnknown,
Version: plugin.Version,
Command: cmdOld,
Args: plugin.Args,
Env: plugin.Env,
Sha256: plugin.Sha256,
})
if err != nil {
if errors.Is(err, ErrPluginBadType) {
retErr = multierror.Append(retErr, fmt.Errorf("could not upgrade plugin %s: plugin of unknown type", pluginName))
Expand Down Expand Up @@ -868,29 +876,29 @@ func (c *PluginCatalog) get(ctx context.Context, name string, pluginType consts.

// Set registers a new external plugin with the catalog, or updates an existing
// external plugin. It takes the name, command and SHA256 of the plugin.
func (c *PluginCatalog) Set(ctx context.Context, name string, pluginType consts.PluginType, version string, command string, args []string, env []string, sha256 []byte) error {
func (c *PluginCatalog) Set(ctx context.Context, plugin pluginutil.SetPluginInput) error {
if c.directory == "" {
return ErrDirectoryNotConfigured
}

switch {
case strings.Contains(name, ".."):
case strings.Contains(plugin.Name, ".."):
fallthrough
case strings.Contains(command, ".."):
case strings.Contains(plugin.Command, ".."):
return consts.ErrPathContainsParentReferences
}

c.lock.Lock()
defer c.lock.Unlock()

_, err := c.setInternal(ctx, name, pluginType, version, command, args, env, sha256)
_, err := c.setInternal(ctx, plugin)
return err
}

func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType consts.PluginType, version string, command string, args []string, env []string, sha256 []byte) (*pluginutil.PluginRunner, error) {
func (c *PluginCatalog) setInternal(ctx context.Context, plugin pluginutil.SetPluginInput) (*pluginutil.PluginRunner, error) {
// Best effort check to make sure the command isn't breaking out of the
// configured plugin directory.
commandFull := filepath.Join(c.directory, command)
commandFull := filepath.Join(c.directory, plugin.Command)
sym, err := filepath.EvalSymlinks(commandFull)
if err != nil {
return nil, fmt.Errorf("error while validating the command path: %w", err)
Expand All @@ -907,57 +915,58 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
// entryTmp should only be used for the below type and version checks, it uses the
// full command instead of the relative command.
entryTmp := &pluginutil.PluginRunner{
Name: name,
Name: plugin.Name,
Command: commandFull,
Args: args,
Env: env,
Sha256: sha256,
Args: plugin.Args,
Env: plugin.Env,
Sha256: plugin.Sha256,
Builtin: false,
}
// If the plugin type is unknown, we want to attempt to determine the type
if pluginType == consts.PluginTypeUnknown {
pluginType, err = c.getPluginTypeFromUnknown(ctx, entryTmp)
if plugin.Type == consts.PluginTypeUnknown {
var err error
plugin.Type, err = c.getPluginTypeFromUnknown(ctx, entryTmp)
if err != nil {
return nil, err
}
if pluginType == consts.PluginTypeUnknown {
if plugin.Type == consts.PluginTypeUnknown {
return nil, ErrPluginBadType
}
}

// getting the plugin version is best-effort, so errors are not fatal
runningVersion := logical.EmptyPluginVersion
var versionErr error
switch pluginType {
switch plugin.Type {
case consts.PluginTypeSecrets, consts.PluginTypeCredential:
runningVersion, versionErr = c.getBackendRunningVersion(ctx, entryTmp)
case consts.PluginTypeDatabase:
runningVersion, versionErr = c.getDatabaseRunningVersion(ctx, entryTmp)
default:
return nil, fmt.Errorf("unknown plugin type: %v", pluginType)
return nil, fmt.Errorf("unknown plugin type: %v", plugin.Type)
}
if versionErr != nil {
c.logger.Warn("Error determining plugin version", "error", versionErr)
} else if version != "" && runningVersion.Version != "" && version != runningVersion.Version {
c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", name, "requestedVersion", version, "reportedVersion", runningVersion.Version)
return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", name, runningVersion.Version, version)
} else if version == "" && runningVersion.Version != "" {
version = runningVersion.Version
_, err := semver.NewVersion(version)
} else if plugin.Version != "" && runningVersion.Version != "" && plugin.Version != runningVersion.Version {
c.logger.Warn("Plugin self-reported version did not match requested version", "plugin", plugin.Name, "requestedVersion", plugin.Version, "reportedVersion", runningVersion.Version)
return nil, fmt.Errorf("plugin version mismatch: %s reported version (%s) did not match requested version (%s)", plugin.Name, runningVersion.Version, plugin.Version)
} else if plugin.Version == "" && runningVersion.Version != "" {
plugin.Version = runningVersion.Version
_, err := semver.NewVersion(plugin.Version)
if err != nil {
return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", version, err)
return nil, fmt.Errorf("plugin self-reported version %q is not a valid semantic version: %w", plugin.Version, err)
}

}

entry := &pluginutil.PluginRunner{
Name: name,
Type: pluginType,
Version: version,
Command: command,
Args: args,
Env: env,
Sha256: sha256,
Name: plugin.Name,
Type: plugin.Type,
Version: plugin.Version,
Command: plugin.Command,
Args: plugin.Args,
Env: plugin.Env,
Sha256: plugin.Sha256,
Builtin: false,
}

Expand All @@ -966,9 +975,9 @@ func (c *PluginCatalog) setInternal(ctx context.Context, name string, pluginType
return nil, fmt.Errorf("failed to encode plugin entry: %w", err)
}

storageKey := path.Join(pluginType.String(), name)
if version != "" {
storageKey = path.Join(storageKey, version)
storageKey := path.Join(plugin.Type.String(), plugin.Name)
if plugin.Version != "" {
storageKey = path.Join(storageKey, plugin.Version)
}
logicalEntry := logical.StorageEntry{
Key: storageKey,
Expand Down
88 changes: 63 additions & 25 deletions vault/plugin_catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,15 @@ func TestPluginCatalog_CRUD(t *testing.T) {
defer file.Close()

command := filepath.Base(file.Name())
err = core.pluginCatalog.Set(context.Background(), pluginName, consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'})
err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: pluginName,
Type: consts.PluginTypeDatabase,
Version: "",
Command: command,
Args: []string{"--test"},
Env: []string{"FOO=BAR"},
Sha256: []byte{'1'},
})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -163,7 +171,15 @@ func TestPluginCatalog_VersionedCRUD(t *testing.T) {
const name = "mysql-database-plugin"
const version = "1.0.0"
command := fmt.Sprintf("%s", filepath.Base(file.Name()))
err = core.pluginCatalog.Set(context.Background(), name, consts.PluginTypeDatabase, version, command, []string{"--test"}, []string{"FOO=BAR"}, []byte{'1'})
err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: name,
Type: consts.PluginTypeDatabase,
Version: version,
Command: command,
Args: []string{"--test"},
Env: []string{"FOO=BAR"},
Sha256: []byte{'1'},
})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -270,13 +286,29 @@ func TestPluginCatalog_List(t *testing.T) {
defer file.Close()

command := filepath.Base(file.Name())
err = core.pluginCatalog.Set(context.Background(), "mysql-database-plugin", consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{}, []byte{'1'})
err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "mysql-database-plugin",
Type: consts.PluginTypeDatabase,
Version: "",
Command: command,
Args: []string{"--test"},
Env: []string{},
Sha256: []byte{'1'},
})
if err != nil {
t.Fatal(err)
}

// Set another plugin
err = core.pluginCatalog.Set(context.Background(), "aaaaaaa", consts.PluginTypeDatabase, "", command, []string{"--test"}, []string{}, []byte{'1'})
err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "aaaaaaa",
Type: consts.PluginTypeDatabase,
Version: "",
Command: command,
Args: []string{"--test"},
Env: []string{},
Sha256: []byte{'1'},
})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -341,31 +373,29 @@ func TestPluginCatalog_ListVersionedPlugins(t *testing.T) {
defer file.Close()

command := filepath.Base(file.Name())
err = core.pluginCatalog.Set(
context.Background(),
"mysql-database-plugin",
consts.PluginTypeDatabase,
"",
command,
[]string{"--test"},
[]string{},
[]byte{'1'},
)
err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "mysql-database-plugin",
Type: consts.PluginTypeDatabase,
Version: "",
Command: command,
Args: []string{"--test"},
Env: []string{},
Sha256: []byte{'1'},
})
if err != nil {
t.Fatal(err)
}

// Set another plugin, with version information
err = core.pluginCatalog.Set(
context.Background(),
"aaaaaaa",
consts.PluginTypeDatabase,
"1.1.0",
command,
[]string{"--test"},
[]string{},
[]byte{'1'},
)
err = core.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: "aaaaaaa",
Type: consts.PluginTypeDatabase,
Version: "1.1.0",
Command: command,
Args: []string{"--test"},
Env: []string{},
Sha256: []byte{'1'},
})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -458,7 +488,15 @@ func TestPluginCatalog_ListHandlesPluginNamesWithSlashes(t *testing.T) {
},
}
for _, entry := range pluginsToRegister {
err = core.pluginCatalog.Set(ctx, entry.Name, consts.PluginTypeCredential, entry.Version, command, nil, nil, nil)
err = core.pluginCatalog.Set(ctx, pluginutil.SetPluginInput{
Name: entry.Name,
Type: consts.PluginTypeCredential,
Version: entry.Version,
Command: command,
Args: nil,
Env: nil,
Sha256: nil,
})
if err != nil {
t.Fatal(err)
}
Expand Down
10 changes: 9 additions & 1 deletion vault/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,15 @@ func TestAddTestPlugin(t testing.T, c *Core, name string, pluginType consts.Plug
c.pluginCatalog.directory = fullPath

args := []string{fmt.Sprintf("--test.run=%s", testFunc)}
err = c.pluginCatalog.Set(context.Background(), name, pluginType, version, fileName, args, env, sum)
err = c.pluginCatalog.Set(context.Background(), pluginutil.SetPluginInput{
Name: name,
Type: pluginType,
Version: version,
Command: fileName,
Args: args,
Env: env,
Sha256: sum,
})
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 3e55447

Please sign in to comment.