From 8b8c01bd1f713a6143e5122cdc8571be177c5573 Mon Sep 17 00:00:00 2001 From: Daisuke Iuchi <42408108+da1suk8@users.noreply.github.com> Date: Wed, 12 Oct 2022 12:05:25 +0900 Subject: [PATCH] fix: fix to prevent accepting file name (#690) * fix: fix to prevent accepting file name * Update a download file name and add test cases (#1) * Update short/long usage and change filename * Add query_test * Add test cases * fix: go.mod * docs: update CHANGELOG.md Co-authored-by: Toshimasa Nasu --- CHANGELOG.md | 3 +- go.mod | 1 + x/wasm/client/cli/query.go | 13 +- x/wasm/client/cli/query_test.go | 449 ++++++++++++++++++++++++++++++++ 4 files changed, 459 insertions(+), 7 deletions(-) create mode 100644 x/wasm/client/cli/query_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d5576f5fa..be30f71e56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,7 +75,8 @@ Ref: https://keepachangelog.com/en/1.0.0/ * (x/foundation) [\#693](https://github.com/line/lbm-sdk/pull/693) add pool to the state of x/foundation * (x/wasm,distribution) [\#696](https://github.com/line/lbm-sdk/pull/696) x/wasm,distribution - add checking a file size before reading it * (x/foundation) [\#698](https://github.com/line/lbm-sdk/pull/698) update x/group relevant logic in x/foundation -* (x) [\#691](https://github.com/line/lbm-sdk/pull/691) change AccAddressFromBech32 to MustAccAddressFromBech32 +* (x/auth,bank,foundation,wasm) [\#691](https://github.com/line/lbm-sdk/pull/691) change AccAddressFromBech32 to MustAccAddressFromBech32 +* (x/wasm) [\#690](https://github.com/line/lbm-sdk/pull/690) fix to prevent accepting file name ### Bug Fixes * (x/wasm) [\#453](https://github.com/line/lbm-sdk/pull/453) modify wasm grpc query api path diff --git a/go.mod b/go.mod index e622e14e2f..2ac68410e7 100644 --- a/go.mod +++ b/go.mod @@ -120,6 +120,7 @@ require ( github.com/sasha-s/go-deadlock v0.2.1-0.20190427202633-1595213edefa // indirect github.com/spf13/afero v1.8.2 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect + github.com/stretchr/objx v0.4.0 // indirect github.com/subosito/gotenv v1.4.1 // indirect github.com/tendermint/tendermint v0.34.19 // indirect github.com/zondax/hid v0.9.0 // indirect diff --git a/x/wasm/client/cli/query.go b/x/wasm/client/cli/query.go index 7ffa2ba185..8f976e95f3 100644 --- a/x/wasm/client/cli/query.go +++ b/x/wasm/client/cli/query.go @@ -146,11 +146,11 @@ func GetCmdListContractByCode() *cobra.Command { // GetCmdQueryCode returns the bytecode for a given contract func GetCmdQueryCode() *cobra.Command { cmd := &cobra.Command{ - Use: "code [code_id] [output filename]", - Short: "Downloads wasm bytecode for given code id", - Long: "Downloads wasm bytecode for given code id", + Use: "code [code_id]", + Short: "Downloads wasm bytecode for given code id to the current directory", + Long: "Downloads wasm bytecode for given code id to the current directory as `contract-[code_id].wasm`", Aliases: []string{"source-code", "source"}, - Args: cobra.ExactArgs(2), + Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { clientCtx, err := client.GetClientQueryContext(cmd) if err != nil { @@ -176,8 +176,9 @@ func GetCmdQueryCode() *cobra.Command { return fmt.Errorf("contract not found") } - fmt.Printf("Downloading wasm code to %s\n", args[1]) - return os.WriteFile(args[1], res.Data, 0600) + fileName := "contract-" + strconv.FormatUint(codeID, 10) + ".wasm" + fmt.Printf("Downloading wasm code to %s\n", fileName) + return os.WriteFile(fileName, res.Data, 0600) }, } flags.AddQueryFlagsToCmd(cmd) diff --git a/x/wasm/client/cli/query_test.go b/x/wasm/client/cli/query_test.go new file mode 100644 index 0000000000..49a16017e3 --- /dev/null +++ b/x/wasm/client/cli/query_test.go @@ -0,0 +1,449 @@ +package cli + +import ( + "context" + "encoding/hex" + "errors" + sdkerrors "github.com/line/lbm-sdk/types/errors" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "net/url" + "os" + "strconv" + "testing" + + "github.com/line/lbm-sdk/client" + "github.com/line/lbm-sdk/codec" + "github.com/line/lbm-sdk/x/wasm/lbmtypes" + "github.com/line/lbm-sdk/x/wasm/types" + ocabcitypes "github.com/line/ostracon/abci/types" + ocrpcmocks "github.com/line/ostracon/rpc/client/mocks" + ocrpctypes "github.com/line/ostracon/rpc/core/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +var ( + codeID = "1" + accAddress = "link1yxfu3fldlgux939t0gwaqs82l4x77v2kasa7jf" + queryJson = `{"a":"b"}` + queryJsonHex = hex.EncodeToString([]byte(queryJson)) + argsWithCodeID = []string{codeID} + argsWithAddr = []string{accAddress} + badStatusError = status.Error(codes.Unknown, "") + invalidRequestFlags = []string{"--page=2", "--offset=1"} + invalidRequestError = sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, + "page and offset cannot be used together") + invalidNodeFlags = []string{"--node=" + string(rune(0))} + invalidControlChar = &url.Error{Op: "parse", URL: string(rune(0)), + Err: errors.New("net/url: invalid control character in URL")} + invalidSyntaxError = &strconv.NumError{Func: "ParseUint", Num: "", Err: strconv.ErrSyntax} + invalidAddrError = errors.New("empty address string is not allowed") + invalidQueryError = errors.New("query data must be json") +) + +type testcase []struct { + name string + want error + ctx context.Context + flags []string + args []string +} + +func TestGetQueryCmd(t *testing.T) { + tests := []struct { + name string + }{ + {"execute success"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetQueryCmd() + assert.NotNilf(t, cmd, "GetQueryCmd()") + }) + } +} + +func TestGetCmdLibVersion(t *testing.T) { + tests := []struct { + name string + want error + }{ + {"execute success", nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdLibVersion() + assert.Equalf(t, tt.want, cmd.RunE(cmd, nil), "GetCmdLibVersion()") + }) + } +} + +func TestGetCmdListCode(t *testing.T) { + res := types.QueryCodesResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, nil}, + {"bad status", badStatusError, ctx, nil, nil}, + {"invalid request", invalidRequestError, ctx, invalidRequestFlags, nil}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdListCode() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdListCode()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdListCode()") + } + }) + } +} + +func TestGetCmdListContractByCode(t *testing.T) { + res := types.QueryContractsByCodeResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, argsWithCodeID}, + {"bad status", badStatusError, ctx, nil, argsWithCodeID}, + {"invalid request", invalidRequestError, ctx, invalidRequestFlags, argsWithCodeID}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, argsWithCodeID}, + {"invalid codeID", invalidSyntaxError, ctx, nil, []string{""}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdListContractByCode() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdListContractByCode()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdListContractByCode()") + } + }) + } +} + +func TestGetCmdQueryCode(t *testing.T) { + res := types.QueryCodeResponse{Data: []byte{0}} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, argsWithCodeID}, + {"bad status", badStatusError, ctx, nil, argsWithCodeID}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, argsWithCodeID}, + {"invalid codeID", invalidSyntaxError, ctx, nil, []string{""}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdQueryCode() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdQueryCode()") + downloaded := "contract-" + codeID + ".wasm" + assert.FileExists(t, downloaded) + assert.NoError(t, os.Remove(downloaded)) + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdQueryCode()") + } + }) + } +} + +func TestGetCmdQueryCodeInfo(t *testing.T) { + res := types.QueryCodeResponse{CodeInfoResponse: &types.CodeInfoResponse{}} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, argsWithCodeID}, + {"bad status", badStatusError, ctx, nil, argsWithCodeID}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, argsWithCodeID}, + {"invalid codeID", invalidSyntaxError, ctx, nil, []string{""}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdQueryCodeInfo() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdQueryCodeInfo()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdQueryCodeInfo()") + } + }) + } +} + +func TestGetCmdGetContractInfo(t *testing.T) { + res := types.QueryContractInfoResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, argsWithAddr}, + {"bad status", badStatusError, ctx, nil, argsWithAddr}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, argsWithAddr}, + {"invalid address", invalidAddrError, ctx, nil, []string{""}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdGetContractInfo() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdGetContractInfo()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdGetContractInfo()") + } + }) + } +} + +func TestGetCmdGetContractState(t *testing.T) { + tests := []struct { + name string + want error + }{ + {"execute success", nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdGetContractState() + assert.Equalf(t, tt.want, cmd.RunE(cmd, nil), "GetCmdGetContractState()") + }) + } +} + +func TestGetCmdGetContractStateAll(t *testing.T) { + res := types.QueryAllContractStateResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, argsWithAddr}, + {"bad status", badStatusError, ctx, nil, argsWithAddr}, + {"invalid request", invalidRequestError, ctx, invalidRequestFlags, argsWithAddr}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, argsWithAddr}, + {"invalid address", invalidAddrError, ctx, nil, []string{""}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdGetContractStateAll() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdGetContractStateAll()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdGetContractStateAll()") + } + }) + } +} + +func TestGetCmdGetContractStateRaw(t *testing.T) { + res := types.QueryRawContractStateResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + args := []string{accAddress, queryJsonHex} + tests := testcase{ + {"execute success", nil, ctx, nil, args}, + {"bad status", badStatusError, ctx, nil, args}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, args}, + {"invalid address", invalidAddrError, ctx, nil, []string{"", "a"}}, + {"invalid key", hex.ErrLength, ctx, nil, []string{accAddress, "a"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdGetContractStateRaw() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdGetContractStateRaw()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdGetContractStateRaw()") + } + }) + } +} + +func TestGetCmdGetContractStateSmart(t *testing.T) { + res := types.QueryRawContractStateResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + args := []string{accAddress, queryJson} + tests := testcase{ + {"execute success", nil, ctx, nil, args}, + {"bad status", badStatusError, ctx, nil, args}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, args}, + {"invalid address", invalidAddrError, ctx, nil, []string{"", "a"}}, + {"invalid query", invalidQueryError, ctx, nil, []string{accAddress, "a"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdGetContractStateSmart() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdGetContractStateSmart()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdGetContractStateSmart()") + } + }) + } +} + +func TestGetCmdGetContractHistory(t *testing.T) { + res := types.QueryContractHistoryResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, argsWithAddr}, + {"bad status", badStatusError, ctx, nil, argsWithAddr}, + {"invalid request", invalidRequestError, ctx, invalidRequestFlags, argsWithAddr}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, argsWithAddr}, + {"invalid address", invalidAddrError, ctx, nil, []string{""}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdGetContractHistory() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdGetContractHistory()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdGetContractHistory()") + } + }) + } +} + +func TestGetCmdListPinnedCode(t *testing.T) { + res := types.QueryPinnedCodesResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, nil}, + {"bad status", badStatusError, ctx, nil, nil}, + {"invalid request", invalidRequestError, ctx, invalidRequestFlags, nil}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdListPinnedCode() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdListPinnedCode()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdListPinnedCode()") + } + }) + } +} + +func TestGetCmdListInactiveContracts(t *testing.T) { + res := lbmtypes.QueryInactiveContractsResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, nil}, + {"bad status", badStatusError, ctx, nil, nil}, + {"invalid request", invalidRequestError, ctx, invalidRequestFlags, nil}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdListInactiveContracts() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdListInactiveContracts()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdListInactiveContracts()") + } + }) + } +} + +func TestGetCmdIsInactiveContract(t *testing.T) { + res := lbmtypes.QueryInactiveContractResponse{} + bz, err := res.Marshal() + require.NoError(t, err) + ctx := makeContext(bz) + tests := testcase{ + {"execute success", nil, ctx, nil, argsWithAddr}, + {"bad status", badStatusError, ctx, nil, argsWithAddr}, + {"invalid url", invalidControlChar, context.Background(), invalidNodeFlags, argsWithAddr}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := GetCmdIsInactiveContract() + err := cmd.ParseFlags(tt.flags) + require.NoError(t, err) + cmd.SetContext(tt.ctx) + actual := cmd.RunE(cmd, tt.args) + if tt.want == nil { + assert.Nilf(t, actual, "GetCmdIsInactiveContract()") + } else { + assert.Equalf(t, tt.want.Error(), actual.Error(), "GetCmdIsInactiveContract()") + } + }) + } +} +func makeContext(bz []byte) context.Context { + result := ocrpctypes.ResultABCIQuery{Response: ocabcitypes.ResponseQuery{Value: bz}} + mockClient := ocrpcmocks.RemoteClient{} + { + // #1 + mockClient.On("ABCIQueryWithOptions", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Once().Return(&result, nil) + } + { + // #2 + failure := result + failure.Response.Code = 1 + mockClient.On("ABCIQueryWithOptions", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, + ).Once().Return(&failure, nil) + } + cli := client.Context{}.WithClient(&mockClient).WithCodec(codec.NewProtoCodec(nil)) + return context.WithValue(context.Background(), client.ClientContextKey, &cli) +}