Skip to content

Commit

Permalink
plugins/gRPC: fix issues with reserved keywords in response data (#3881)
Browse files Browse the repository at this point in the history
* plugins/gRPC: fix issues with reserved keywords in response data

* Add the path raw file for mock plugin

* Fix panic when special paths is nil

* Add tests for Listing and raw requests from plugins

* Add json.Number case when decoding the status

* Bump the version required for gRPC defaults

* Fix test for gRPC version check
  • Loading branch information
briankassouf authored Feb 1, 2018
1 parent 3d62f76 commit 91dffed
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 11 deletions.
4 changes: 3 additions & 1 deletion helper/pluginutil/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ func GRPCSupport() bool {
return true
}

constraint, err := version.NewConstraint(">= 0.9.2")
// Due to some regressions on 0.9.2 & 0.9.3 we now require version 0.9.4
// to allow the plugin framework to default to gRPC.
constraint, err := version.NewConstraint(">= 0.9.4")
if err != nil {
return true
}
Expand Down
12 changes: 8 additions & 4 deletions helper/pluginutil/version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ func TestGRPCSupport(t *testing.T) {
},
{
"0.9.2",
true,
false,
},
{
"0.9.2+ent",
"0.9.3",
false,
},
{
"0.9.4+ent",
true,
},
{
"0.9.2-beta",
"0.9.4-beta",
false,
},
{
"0.9.3",
"0.9.4",
true,
},
{
Expand Down
33 changes: 29 additions & 4 deletions http/logical.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package http

import (
"encoding/base64"
"encoding/json"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -200,8 +202,21 @@ func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response)
retErr(w, "no status code given")
return
}
status, ok := statusRaw.(int)
if !ok {

var status int
switch statusRaw.(type) {
case int:
status = statusRaw.(int)
case float64:
status = int(statusRaw.(float64))
case json.Number:
s64, err := statusRaw.(json.Number).Float64()
if err != nil {
retErr(w, "cannot decode status code")
return
}
status = int(s64)
default:
retErr(w, "cannot decode status code")
return
}
Expand Down Expand Up @@ -232,8 +247,18 @@ func respondRaw(w http.ResponseWriter, r *http.Request, resp *logical.Response)
retErr(w, "no body given")
return
}
body, ok = bodyRaw.([]byte)
if !ok {

switch bodyRaw.(type) {
case string:
var err error
body, err = base64.StdEncoding.DecodeString(bodyRaw.(string))
if err != nil {
retErr(w, "cannot decode body")
return
}
case []byte:
body = bodyRaw.([]byte)
default:
retErr(w, "cannot decode body")
return
}
Expand Down
153 changes: 153 additions & 0 deletions http/plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package http

import (
"io/ioutil"
"os"
"sync"
"testing"

hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
bplugin "github.com/hashicorp/vault/builtin/plugin"
"github.com/hashicorp/vault/helper/logbridge"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/plugin"
"github.com/hashicorp/vault/logical/plugin/mock"
"github.com/hashicorp/vault/physical/inmem"
"github.com/hashicorp/vault/vault"
)

func getPluginClusterAndCore(t testing.TB, logger *logbridge.Logger) (*vault.TestCluster, *vault.TestClusterCore) {
inmha, err := inmem.NewInmemHA(nil, logger.LogxiLogger())
if err != nil {
t.Fatal(err)
}

coreConfig := &vault.CoreConfig{
Physical: inmha,
LogicalBackends: map[string]logical.Factory{
"plugin": bplugin.Factory,
},
}

cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: Handler,
RawLogger: logger,
})
cluster.Start()

cores := cluster.Cores
core := cores[0]

os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)

vault.TestWaitActive(t, core.Core)
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestPlugin_PluginMain")

// Mount the mock plugin
err = core.Client.Sys().Mount("mock", &api.MountInput{
Type: "plugin",
PluginName: "mock-plugin",
})
if err != nil {
t.Fatal(err)
}

return cluster, core
}

func TestPlugin_PluginMain(t *testing.T) {
if os.Getenv(pluginutil.PluginVaultVersionEnv) == "" {
return
}

caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
if caPEM == "" {
t.Fatal("CA cert not passed in")
}

args := []string{"--ca-cert=" + caPEM}

apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)

tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)

factoryFunc := mock.FactoryType(logical.TypeLogical)

err := plugin.Serve(&plugin.ServeOpts{
BackendFactoryFunc: factoryFunc,
TLSProviderFunc: tlsProviderFunc,
})
if err != nil {
t.Fatal(err)
}
t.Fatal("Why are we here")
}

func TestPlugin_MockList(t *testing.T) {
logger := logbridge.NewLogger(hclog.New(&hclog.LoggerOptions{
Mutex: &sync.Mutex{},
}))
cluster, core := getPluginClusterAndCore(t, logger)
defer cluster.Cleanup()

_, err := core.Client.Logical().Write("mock/kv/foo", map[string]interface{}{
"bar": "baz",
})
if err != nil {
t.Fatal(err)
}

keys, err := core.Client.Logical().List("mock/kv/")
if err != nil {
t.Fatal(err)
}
if keys.Data["keys"].([]interface{})[0].(string) != "foo" {
t.Fatal(keys)
}

_, err = core.Client.Logical().Write("mock/kv/zoo", map[string]interface{}{
"bar": "baz",
})
if err != nil {
t.Fatal(err)
}

keys, err = core.Client.Logical().List("mock/kv/")
if err != nil {
t.Fatal(err)
}
if keys.Data["keys"].([]interface{})[0].(string) != "foo" || keys.Data["keys"].([]interface{})[1].(string) != "zoo" {
t.Fatal(keys)
}
}

func TestPlugin_MockRawResponse(t *testing.T) {
logger := logbridge.NewLogger(hclog.New(&hclog.LoggerOptions{
Mutex: &sync.Mutex{},
}))
cluster, core := getPluginClusterAndCore(t, logger)
defer cluster.Cleanup()

resp, err := core.Client.RawRequest(core.Client.NewRequest("GET", "/v1/mock/raw"))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body[:]) != "Response" {
t.Fatal("bad body")
}

if resp.StatusCode != 200 {
t.Fatal("bad status")
}

}
5 changes: 5 additions & 0 deletions logical/plugin/grpc_backend_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func (b *backendGRPCPluginServer) HandleRequest(ctx context.Context, args *pb.Ha

func (b *backendGRPCPluginServer) SpecialPaths(ctx context.Context, args *pb.Empty) (*pb.SpecialPathsReply, error) {
paths := b.backend.SpecialPaths()
if paths == nil {
return &pb.SpecialPathsReply{
Paths: nil,
}, nil
}

return &pb.SpecialPathsReply{
Paths: &pb.Paths{
Expand Down
1 change: 1 addition & 0 deletions logical/plugin/mock/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func Backend() *backend {
[]*framework.Path{
pathInternal(&b),
pathSpecial(&b),
pathRaw(&b),
},
),
PathsSpecial: &logical.Paths{
Expand Down
29 changes: 29 additions & 0 deletions logical/plugin/mock/path_raw.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package mock

import (
"context"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)

// pathRaw is used to test raw responses.
func pathRaw(b *backend) *framework.Path {
return &framework.Path{
Pattern: "raw",
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.ReadOperation: b.pathRawRead,
},
}
}

func (b *backend) pathRawRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return &logical.Response{
Data: map[string]interface{}{
logical.HTTPContentType: "text/plain",
logical.HTTPRawBody: []byte("Response"),
logical.HTTPStatusCode: 200,
},
}, nil

}
19 changes: 17 additions & 2 deletions logical/response_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,25 @@ func RespondErrorCommon(req *Request, resp *Response, err error) (int, error) {
if !ok || keysRaw == nil {
return http.StatusNotFound, nil
}
keys, ok := keysRaw.([]string)
if !ok {

var keys []string
switch keysRaw.(type) {
case []interface{}:
keys = make([]string, len(keysRaw.([]interface{})))
for i, el := range keysRaw.([]interface{}) {
s, ok := el.(string)
if !ok {
return http.StatusInternalServerError, nil
}
keys[i] = s
}

case []string:
keys = keysRaw.([]string)
default:
return http.StatusInternalServerError, nil
}

if len(keys) == 0 {
return http.StatusNotFound, nil
}
Expand Down

0 comments on commit 91dffed

Please sign in to comment.