From 91dffedc3955a6407ac6644a0d797215b4aaabc0 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 1 Feb 2018 14:30:17 -0800 Subject: [PATCH] plugins/gRPC: fix issues with reserved keywords in response data (#3881) * plugins/gRPC: fix issues with reserved keywords in response data * Add the path raw file for mock plugin * Fix panic when special paths is nil * Add tests for Listing and raw requests from plugins * Add json.Number case when decoding the status * Bump the version required for gRPC defaults * Fix test for gRPC version check --- helper/pluginutil/version.go | 4 +- helper/pluginutil/version_test.go | 12 +- http/logical.go | 33 +++++- http/plugin_test.go | 153 ++++++++++++++++++++++++++ logical/plugin/grpc_backend_server.go | 5 + logical/plugin/mock/backend.go | 1 + logical/plugin/mock/path_raw.go | 29 +++++ logical/response_util.go | 19 +++- 8 files changed, 245 insertions(+), 11 deletions(-) create mode 100644 http/plugin_test.go create mode 100644 logical/plugin/mock/path_raw.go diff --git a/helper/pluginutil/version.go b/helper/pluginutil/version.go index ec2336761b0d..e1537a697683 100644 --- a/helper/pluginutil/version.go +++ b/helper/pluginutil/version.go @@ -28,7 +28,9 @@ func GRPCSupport() bool { return true } - constraint, err := version.NewConstraint(">= 0.9.2") + // Due to some regressions on 0.9.2 & 0.9.3 we now require version 0.9.4 + // to allow the plugin framework to default to gRPC. + constraint, err := version.NewConstraint(">= 0.9.4") if err != nil { return true } diff --git a/helper/pluginutil/version_test.go b/helper/pluginutil/version_test.go index 8921f84b922d..1d04b327524e 100644 --- a/helper/pluginutil/version_test.go +++ b/helper/pluginutil/version_test.go @@ -16,18 +16,22 @@ func TestGRPCSupport(t *testing.T) { }, { "0.9.2", - true, + false, }, { - "0.9.2+ent", + "0.9.3", + false, + }, + { + "0.9.4+ent", true, }, { - "0.9.2-beta", + "0.9.4-beta", false, }, { - "0.9.3", + "0.9.4", true, }, { diff --git a/http/logical.go b/http/logical.go index bcc5ba902b84..2b09414e54a3 100644 --- a/http/logical.go +++ b/http/logical.go @@ -1,6 +1,8 @@ package http import ( + "encoding/base64" + "encoding/json" "io" "net" "net/http" @@ -200,8 +202,21 @@ func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response) retErr(w, "no status code given") return } - status, ok := statusRaw.(int) - if !ok { + + var status int + switch statusRaw.(type) { + case int: + status = statusRaw.(int) + case float64: + status = int(statusRaw.(float64)) + case json.Number: + s64, err := statusRaw.(json.Number).Float64() + if err != nil { + retErr(w, "cannot decode status code") + return + } + status = int(s64) + default: retErr(w, "cannot decode status code") return } @@ -232,8 +247,18 @@ func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response) retErr(w, "no body given") return } - body, ok = bodyRaw.([]byte) - if !ok { + + switch bodyRaw.(type) { + case string: + var err error + body, err = base64.StdEncoding.DecodeString(bodyRaw.(string)) + if err != nil { + retErr(w, "cannot decode body") + return + } + case []byte: + body = bodyRaw.([]byte) + default: retErr(w, "cannot decode body") return } diff --git a/http/plugin_test.go b/http/plugin_test.go new file mode 100644 index 000000000000..b96e8d484020 --- /dev/null +++ b/http/plugin_test.go @@ -0,0 +1,153 @@ +package http + +import ( + "io/ioutil" + "os" + "sync" + "testing" + + hclog "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + bplugin "github.com/hashicorp/vault/builtin/plugin" + "github.com/hashicorp/vault/helper/logbridge" + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/plugin" + "github.com/hashicorp/vault/logical/plugin/mock" + "github.com/hashicorp/vault/physical/inmem" + "github.com/hashicorp/vault/vault" +) + +func getPluginClusterAndCore(t testing.TB, logger *logbridge.Logger) (*vault.TestCluster, *vault.TestClusterCore) { + inmha, err := inmem.NewInmemHA(nil, logger.LogxiLogger()) + if err != nil { + t.Fatal(err) + } + + coreConfig := &vault.CoreConfig{ + Physical: inmha, + LogicalBackends: map[string]logical.Factory{ + "plugin": bplugin.Factory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: Handler, + RawLogger: logger, + }) + cluster.Start() + + cores := cluster.Cores + core := cores[0] + + os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) + + vault.TestWaitActive(t, core.Core) + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestPlugin_PluginMain") + + // Mount the mock plugin + err = core.Client.Sys().Mount("mock", &api.MountInput{ + Type: "plugin", + PluginName: "mock-plugin", + }) + if err != nil { + t.Fatal(err) + } + + return cluster, core +} + +func TestPlugin_PluginMain(t *testing.T) { + if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + + args := []string{"--ca-cert=" + caPEM} + + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + tlsConfig := apiClientMeta.GetTLSConfig() + tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) + + factoryFunc := mock.FactoryType(logical.TypeLogical) + + err := plugin.Serve(&plugin.ServeOpts{ + BackendFactoryFunc: factoryFunc, + TLSProviderFunc: tlsProviderFunc, + }) + if err != nil { + t.Fatal(err) + } + t.Fatal("Why are we here") +} + +func TestPlugin_MockList(t *testing.T) { + logger := logbridge.NewLogger(hclog.New(&hclog.LoggerOptions{ + Mutex: &sync.Mutex{}, + })) + cluster, core := getPluginClusterAndCore(t, logger) + defer cluster.Cleanup() + + _, err := core.Client.Logical().Write("mock/kv/foo", map[string]interface{}{ + "bar": "baz", + }) + if err != nil { + t.Fatal(err) + } + + keys, err := core.Client.Logical().List("mock/kv/") + if err != nil { + t.Fatal(err) + } + if keys.Data["keys"].([]interface{})[0].(string) != "foo" { + t.Fatal(keys) + } + + _, err = core.Client.Logical().Write("mock/kv/zoo", map[string]interface{}{ + "bar": "baz", + }) + if err != nil { + t.Fatal(err) + } + + keys, err = core.Client.Logical().List("mock/kv/") + if err != nil { + t.Fatal(err) + } + if keys.Data["keys"].([]interface{})[0].(string) != "foo" || keys.Data["keys"].([]interface{})[1].(string) != "zoo" { + t.Fatal(keys) + } +} + +func TestPlugin_MockRawResponse(t *testing.T) { + logger := logbridge.NewLogger(hclog.New(&hclog.LoggerOptions{ + Mutex: &sync.Mutex{}, + })) + cluster, core := getPluginClusterAndCore(t, logger) + defer cluster.Cleanup() + + resp, err := core.Client.RawRequest(core.Client.NewRequest("GET", "/v1/mock/raw")) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body[:]) != "Response" { + t.Fatal("bad body") + } + + if resp.StatusCode != 200 { + t.Fatal("bad status") + } + +} diff --git a/logical/plugin/grpc_backend_server.go b/logical/plugin/grpc_backend_server.go index e0940225a0e9..3e902e03301b 100644 --- a/logical/plugin/grpc_backend_server.go +++ b/logical/plugin/grpc_backend_server.go @@ -81,6 +81,11 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha func (b *backendGRPCPluginServer) SpecialPaths(ctx context.Context, args *pb.Empty) (*pb.SpecialPathsReply, error) { paths := b.backend.SpecialPaths() + if paths == nil { + return &pb.SpecialPathsReply{ + Paths: nil, + }, nil + } return &pb.SpecialPathsReply{ Paths: &pb.Paths{ diff --git a/logical/plugin/mock/backend.go b/logical/plugin/mock/backend.go index 82e578101633..45c72a148c16 100644 --- a/logical/plugin/mock/backend.go +++ b/logical/plugin/mock/backend.go @@ -46,6 +46,7 @@ func Backend() *backend { []*framework.Path{ pathInternal(&b), pathSpecial(&b), + pathRaw(&b), }, ), PathsSpecial: &logical.Paths{ diff --git a/logical/plugin/mock/path_raw.go b/logical/plugin/mock/path_raw.go new file mode 100644 index 000000000000..132155629505 --- /dev/null +++ b/logical/plugin/mock/path_raw.go @@ -0,0 +1,29 @@ +package mock + +import ( + "context" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +// pathRaw is used to test raw responses. +func pathRaw(b *backend) *framework.Path { + return &framework.Path{ + Pattern: "raw", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRawRead, + }, + } +} + +func (b *backend) pathRawRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + return &logical.Response{ + Data: map[string]interface{}{ + logical.HTTPContentType: "text/plain", + logical.HTTPRawBody: []byte("Response"), + logical.HTTPStatusCode: 200, + }, + }, nil + +} diff --git a/logical/response_util.go b/logical/response_util.go index a3fd2bfd193d..41e617a8b47f 100644 --- a/logical/response_util.go +++ b/logical/response_util.go @@ -31,10 +31,25 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) { if !ok || keysRaw == nil { return http.StatusNotFound, nil } - keys, ok := keysRaw.([]string) - if !ok { + + var keys []string + switch keysRaw.(type) { + case []interface{}: + keys = make([]string, len(keysRaw.([]interface{}))) + for i, el := range keysRaw.([]interface{}) { + s, ok := el.(string) + if !ok { + return http.StatusInternalServerError, nil + } + keys[i] = s + } + + case []string: + keys = keysRaw.([]string) + default: return http.StatusInternalServerError, nil } + if len(keys) == 0 { return http.StatusNotFound, nil }