From 06cef518fcc1cfda78a1b9663d3712f742bdb83d Mon Sep 17 00:00:00 2001 From: Hieu Vu <72878483+hieuvubk@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:12:42 +0700 Subject: [PATCH 1/4] add module restriction order --- .../proto/cosmos/bank/module/v2/module.proto | 6 ++ x/bank/v2/types/module/module.pb.go | 77 +++++++++++++++++-- 2 files changed, 75 insertions(+), 8 deletions(-) diff --git a/x/bank/proto/cosmos/bank/module/v2/module.proto b/x/bank/proto/cosmos/bank/module/v2/module.proto index a4c0ed40b52f..4d6db9db5fc6 100644 --- a/x/bank/proto/cosmos/bank/module/v2/module.proto +++ b/x/bank/proto/cosmos/bank/module/v2/module.proto @@ -14,4 +14,10 @@ message Module { // authority defines the custom module authority. If not set, defaults to the governance module. string authority = 1; + + // restrictions_order specifies the order of send restrictions and should be + // a list of module names which provide a send restriction instance. If no + // order is provided, then restrictions will be applied in alphabetical order + // of module names. + repeated string restrictions_order = 2; } diff --git a/x/bank/v2/types/module/module.pb.go b/x/bank/v2/types/module/module.pb.go index 247cbd747e3b..d752f9ecf933 100644 --- a/x/bank/v2/types/module/module.pb.go +++ b/x/bank/v2/types/module/module.pb.go @@ -27,6 +27,11 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type Module struct { // authority defines the custom module authority. If not set, defaults to the governance module. Authority string `protobuf:"bytes,1,opt,name=authority,proto3" json:"authority,omitempty"` + // restrictions_order specifies the order of send restrictions and should be + // a list of module names which provide a send restriction instance. If no + // order is provided, then restrictions will be applied in alphabetical order + // of module names. + RestrictionsOrder []string `protobuf:"bytes,2,rep,name=restrictions_order,json=restrictionsOrder,proto3" json:"restrictions_order,omitempty"` } func (m *Module) Reset() { *m = Module{} } @@ -69,6 +74,13 @@ func (m *Module) GetAuthority() string { return "" } +func (m *Module) GetRestrictionsOrder() []string { + if m != nil { + return m.RestrictionsOrder + } + return nil +} + func init() { proto.RegisterType((*Module)(nil), "cosmos.bank.module.v2.Module") } @@ -78,19 +90,21 @@ func init() { } var fileDescriptor_34a109a905e2a25b = []byte{ - // 184 bytes of a gzipped FileDescriptorProto + // 219 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x52, 0x4a, 0xce, 0x2f, 0xce, 0xcd, 0x2f, 0xd6, 0x4f, 0x4a, 0xcc, 0xcb, 0xd6, 0xcf, 0xcd, 0x4f, 0x29, 0xcd, 0x49, 0xd5, 0x2f, 0x33, 0x82, 0xb2, 0xf4, 0x0a, 0x8a, 0xf2, 0x4b, 0xf2, 0x85, 0x44, 0x21, 0x6a, 0xf4, 0x40, 0x6a, 0xf4, 0xa0, 0x32, 0x65, 0x46, 0x52, 0x0a, 0x50, 0xad, 0x89, 0x05, 0x05, 0xfa, 0x65, 0x86, 0x89, - 0x39, 0x05, 0x19, 0x89, 0x86, 0x28, 0x1a, 0x95, 0xdc, 0xb8, 0xd8, 0x7c, 0xc1, 0x7c, 0x21, 0x19, + 0x39, 0x05, 0x19, 0x89, 0x86, 0x28, 0x1a, 0x95, 0x4a, 0xb9, 0xd8, 0x7c, 0xc1, 0x7c, 0x21, 0x19, 0x2e, 0xce, 0xc4, 0xd2, 0x92, 0x8c, 0xfc, 0xa2, 0xcc, 0x92, 0x4a, 0x09, 0x46, 0x05, 0x46, 0x0d, - 0xce, 0x20, 0x84, 0x80, 0x95, 0xdc, 0xae, 0x03, 0xd3, 0x6e, 0x31, 0x4a, 0x70, 0x89, 0x41, 0x4c, - 0x2c, 0x4e, 0xc9, 0xd6, 0xcb, 0xcc, 0xd7, 0xaf, 0x80, 0x38, 0xaa, 0xcc, 0xc8, 0xc9, 0xf6, 0xc4, - 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 0xe1, - 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x94, 0xb1, 0xeb, 0xd0, 0x2f, 0xa9, 0x2c, - 0x48, 0x2d, 0x86, 0x3a, 0x26, 0x89, 0x0d, 0xec, 0x1a, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, - 0x5e, 0xfc, 0x10, 0x2c, 0xec, 0x00, 0x00, 0x00, + 0xce, 0x20, 0x84, 0x80, 0x90, 0x2e, 0x97, 0x50, 0x51, 0x6a, 0x71, 0x49, 0x51, 0x66, 0x72, 0x49, + 0x66, 0x7e, 0x5e, 0x71, 0x7c, 0x7e, 0x51, 0x4a, 0x6a, 0x91, 0x04, 0x93, 0x02, 0xb3, 0x06, 0x67, + 0x90, 0x20, 0xb2, 0x8c, 0x3f, 0x48, 0xc2, 0x4a, 0x6e, 0xd7, 0x81, 0x69, 0xb7, 0x18, 0x25, 0xb8, + 0xc4, 0x20, 0x0e, 0x28, 0x4e, 0xc9, 0xd6, 0xcb, 0xcc, 0xd7, 0xaf, 0x80, 0xf8, 0xa1, 0xcc, 0xc8, + 0xc9, 0xf6, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, + 0x58, 0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x94, 0xb1, 0xeb, 0xd0, + 0x2f, 0xa9, 0x2c, 0x48, 0x2d, 0x86, 0xba, 0x3d, 0x89, 0x0d, 0xec, 0x78, 0x63, 0x40, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x69, 0x6d, 0xb0, 0x10, 0x1b, 0x01, 0x00, 0x00, } func (m *Module) Marshal() (dAtA []byte, err error) { @@ -113,6 +127,15 @@ func (m *Module) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if len(m.RestrictionsOrder) > 0 { + for iNdEx := len(m.RestrictionsOrder) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.RestrictionsOrder[iNdEx]) + copy(dAtA[i:], m.RestrictionsOrder[iNdEx]) + i = encodeVarintModule(dAtA, i, uint64(len(m.RestrictionsOrder[iNdEx]))) + i-- + dAtA[i] = 0x12 + } + } if len(m.Authority) > 0 { i -= len(m.Authority) copy(dAtA[i:], m.Authority) @@ -144,6 +167,12 @@ func (m *Module) Size() (n int) { if l > 0 { n += 1 + l + sovModule(uint64(l)) } + if len(m.RestrictionsOrder) > 0 { + for _, s := range m.RestrictionsOrder { + l = len(s) + n += 1 + l + sovModule(uint64(l)) + } + } return n } @@ -214,6 +243,38 @@ func (m *Module) Unmarshal(dAtA []byte) error { } m.Authority = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RestrictionsOrder", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowModule + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthModule + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthModule + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.RestrictionsOrder = append(m.RestrictionsOrder, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipModule(dAtA[iNdEx:]) From 42f4dfde70a1e9951ec59de3601967b05bef5f1d Mon Sep 17 00:00:00 2001 From: Hieu Vu <72878483+hieuvubk@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:13:38 +0700 Subject: [PATCH 2/4] add restriction to keeper --- x/bank/v2/depinject.go | 42 ++++++++++++++++++++++++ x/bank/v2/keeper/keeper.go | 35 ++++++++++++++++---- x/bank/v2/keeper/restriction.go | 47 +++++++++++++++++++++++++++ x/bank/v2/types/restrictions.go | 57 +++++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+), 7 deletions(-) create mode 100644 x/bank/v2/keeper/restriction.go create mode 100644 x/bank/v2/types/restrictions.go diff --git a/x/bank/v2/depinject.go b/x/bank/v2/depinject.go index b5465d7e02df..c626d2f10f2d 100644 --- a/x/bank/v2/depinject.go +++ b/x/bank/v2/depinject.go @@ -1,6 +1,11 @@ package bankv2 import ( + "fmt" + "maps" + "slices" + "sort" + "cosmossdk.io/core/address" "cosmossdk.io/core/appmodule" "cosmossdk.io/depinject" @@ -22,6 +27,7 @@ func init() { appconfig.RegisterModule( &moduletypes.Module{}, appconfig.Provide(ProvideModule), + appconfig.Invoke(InvokeSetSendRestrictions), ) } @@ -61,3 +67,39 @@ func ProvideModule(in ModuleInputs) ModuleOutputs { Module: m, } } + +func InvokeSetSendRestrictions( + config *moduletypes.Module, + keeper keeper.Keeper, + restrictions map[string]types.SendRestrictionFn, +) error { + if config == nil { + return nil + } + + modules := slices.Collect(maps.Keys(restrictions)) + order := config.RestrictionsOrder + if len(order) == 0 { + order = modules + sort.Strings(order) + } + + if len(order) != len(modules) { + return fmt.Errorf("len(restrictions order: %v) != len(restriction modules: %v)", order, modules) + } + + if len(modules) == 0 { + return nil + } + + for _, module := range order { + restriction, ok := restrictions[module] + if !ok { + return fmt.Errorf("can't find send restriction for module %s", module) + } + + keeper.AppendSendRestriction(restriction) + } + + return nil +} diff --git a/x/bank/v2/keeper/keeper.go b/x/bank/v2/keeper/keeper.go index f16982f168c4..ff4827238040 100644 --- a/x/bank/v2/keeper/keeper.go +++ b/x/bank/v2/keeper/keeper.go @@ -28,18 +28,21 @@ type Keeper struct { params collections.Item[types.Params] balances *collections.IndexedMap[collections.Pair[[]byte, string], math.Int, BalancesIndexes] supply collections.Map[string, math.Int] + + sendRestriction *sendRestriction } func NewKeeper(authority []byte, addressCodec address.Codec, env appmodulev2.Environment, cdc codec.BinaryCodec) *Keeper { sb := collections.NewSchemaBuilder(env.KVStoreService) k := &Keeper{ - Environment: env, - authority: authority, - addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment? - params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)), - balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)), - supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue), + Environment: env, + authority: authority, + addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment? + params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)), + balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)), + supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue), + sendRestriction: newSendRestriction(), } schema, err := sb.Build() @@ -94,7 +97,10 @@ func (k Keeper) SendCoins(ctx context.Context, from, to []byte, amt sdk.Coins) e } var err error - // TODO: Send restriction + to, err = k.sendRestriction.apply(ctx, from, to, amt) + if err != nil { + return err + } err = k.subUnlockedCoins(ctx, from, amt) if err != nil { @@ -252,3 +258,18 @@ func newBalancesIndexes(sb *collections.SchemaBuilder) BalancesIndexes { type BalancesIndexes struct { Denom *indexes.ReversePair[[]byte, string, math.Int] } + +// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions. +func (k Keeper) AppendSendRestriction(restriction types.SendRestrictionFn) { + k.sendRestriction.append(restriction) +} + +// PrependSendRestriction adds the provided SendRestrictionFn to run before previously provided restrictions. +func (k Keeper) PrependSendRestriction(restriction types.SendRestrictionFn) { + k.sendRestriction.prepend(restriction) +} + +// ClearSendRestriction removes the send restriction (if there is one). +func (k Keeper) ClearSendRestriction() { + k.sendRestriction.clear() +} diff --git a/x/bank/v2/keeper/restriction.go b/x/bank/v2/keeper/restriction.go new file mode 100644 index 000000000000..b2b1910f6649 --- /dev/null +++ b/x/bank/v2/keeper/restriction.go @@ -0,0 +1,47 @@ +package keeper + +import ( + "context" + + "cosmossdk.io/x/bank/v2/types" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// sendRestriction is a struct that houses a SendRestrictionFn. +// It exists so that the SendRestrictionFn can be updated in the SendKeeper without needing to have a pointer receiver. +type sendRestriction struct { + fn types.SendRestrictionFn +} + +// newSendRestriction creates a new sendRestriction with nil send restriction. +func newSendRestriction() *sendRestriction { + return &sendRestriction{ + fn: nil, + } +} + +// append adds the provided restriction to this, to be run after the existing function. +func (r *sendRestriction) append(restriction types.SendRestrictionFn) { + r.fn = r.fn.Then(restriction) +} + +// prepend adds the provided restriction to this, to be run before the existing function. +func (r *sendRestriction) prepend(restriction types.SendRestrictionFn) { + r.fn = restriction.Then(r.fn) +} + +// clear removes the send restriction (sets it to nil). +func (r *sendRestriction) clear() { + r.fn = nil +} + +var _ types.SendRestrictionFn = (*sendRestriction)(nil).apply + +// apply applies the send restriction if there is one. If not, it's a no-op. +func (r *sendRestriction) apply(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) { + if r == nil || r.fn == nil { + return toAddr, nil + } + return r.fn(ctx, fromAddr, toAddr, amt) +} diff --git a/x/bank/v2/types/restrictions.go b/x/bank/v2/types/restrictions.go new file mode 100644 index 000000000000..89c1f7e6cfea --- /dev/null +++ b/x/bank/v2/types/restrictions.go @@ -0,0 +1,57 @@ +package types + +import ( + "context" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// A SendRestrictionFn can restrict sends and/or provide a new receiver address. +type SendRestrictionFn func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) (newToAddr []byte, err error) + +// IsOnePerModuleType implements the depinject.OnePerModuleType interface. +func (SendRestrictionFn) IsOnePerModuleType() {} + +var _ SendRestrictionFn = NoOpSendRestrictionFn + +// NoOpSendRestrictionFn is a no-op SendRestrictionFn. +func NoOpSendRestrictionFn(_ context.Context, _, toAddr []byte, _ sdk.Coins) ([]byte, error) { + return toAddr, nil +} + +// Then creates a composite restriction that runs this one then the provided second one. +func (r SendRestrictionFn) Then(second SendRestrictionFn) SendRestrictionFn { + return ComposeSendRestrictions(r, second) +} + +// ComposeSendRestrictions combines multiple SendRestrictionFn into one. +// nil entries are ignored. +// If all entries are nil, nil is returned. +// If exactly one entry is not nil, it is returned. +// Otherwise, a new SendRestrictionFn is returned that runs the non-nil restrictions in the order they are given. +// The composition runs each send restriction until an error is encountered and returns that error, +// otherwise it returns the toAddr of the last send restriction. +func ComposeSendRestrictions(restrictions ...SendRestrictionFn) SendRestrictionFn { + toRun := make([]SendRestrictionFn, 0, len(restrictions)) + for _, r := range restrictions { + if r != nil { + toRun = append(toRun, r) + } + } + switch len(toRun) { + case 0: + return nil + case 1: + return toRun[0] + } + return func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) { + var err error + for _, r := range toRun { + toAddr, err = r(ctx, fromAddr, toAddr, amt) + if err != nil { + return toAddr, err + } + } + return toAddr, err + } +} From f19996ece56e7d50b6086c209ac7d567f6d00afc Mon Sep 17 00:00:00 2001 From: Hieu Vu <72878483+hieuvubk@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:13:48 +0700 Subject: [PATCH 3/4] test --- x/bank/v2/keeper/keeper_test.go | 52 +++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/x/bank/v2/keeper/keeper_test.go b/x/bank/v2/keeper/keeper_test.go index 32ccbac4d2ef..4da9faed3b5c 100644 --- a/x/bank/v2/keeper/keeper_test.go +++ b/x/bank/v2/keeper/keeper_test.go @@ -1,7 +1,9 @@ package keeper_test import ( + "bytes" "context" + "fmt" "testing" "time" @@ -184,3 +186,53 @@ func (suite *KeeperTestSuite) TestSendCoins_Module_To_Module() { mintBarBalance := suite.bankKeeper.GetBalance(ctx, mintAcc.GetAddress(), barDenom) require.Equal(mintBarBalance.Amount, math.NewInt(0)) } + +func (suite *KeeperTestSuite) TestSendCoins_WithRestriction() { + ctx := suite.ctx + require := suite.Require() + balances := sdk.NewCoins(newFooCoin(100), newBarCoin(50)) + sendAmt := sdk.NewCoins(newFooCoin(10), newBarCoin(10)) + + require.NoError(banktestutil.FundAccount(ctx, suite.bankKeeper, accAddrs[0], balances)) + + // Add first restriction + addrRestrictFunc := func(ctx context.Context, from, to []byte, amount sdk.Coins) ([]byte, error) { + if bytes.Equal(from, to) { + return nil, fmt.Errorf("Can not send to same address") + } + return to, nil + } + suite.bankKeeper.AppendSendRestriction(addrRestrictFunc) + + err := suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[0], sendAmt) + require.Error(err) + require.Contains(err.Error(), "Can not send to same address") + + // Add second restriction + amtRestrictFunc := func(ctx context.Context, from, to []byte, amount sdk.Coins) ([]byte, error) { + if len(amount) > 1 { + return nil, fmt.Errorf("Allow only one denom per one send") + } + return to, nil + } + suite.bankKeeper.AppendSendRestriction(amtRestrictFunc) + + // Pass the 1st but failt at the 2nd + err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sendAmt) + require.Error(err) + require.Contains(err.Error(), "Allow only one denom per one send") + + // Pass both 2 restrictions + err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sdk.NewCoins(newFooCoin(10))) + require.NoError(err) + + // Check balances + acc0FooBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[0], fooDenom) + require.Equal(acc0FooBalance.Amount, math.NewInt(90)) + acc0BarBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[0], barDenom) + require.Equal(acc0BarBalance.Amount, math.NewInt(50)) + acc1FooBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[1], fooDenom) + require.Equal(acc1FooBalance.Amount, math.NewInt(10)) + acc1BarBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[1], barDenom) + require.Equal(acc1BarBalance.Amount, math.ZeroInt()) +} From 3c685f104b0ae5e5302fd1c1940add823cc1134a Mon Sep 17 00:00:00 2001 From: Hieu Vu <72878483+hieuvubk@users.noreply.github.com> Date: Tue, 1 Oct 2024 20:36:50 +0700 Subject: [PATCH 4/4] feedback --- x/bank/v2/depinject.go | 2 +- x/bank/v2/keeper/keeper.go | 15 --------------- x/bank/v2/keeper/keeper_test.go | 4 ++-- x/bank/v2/keeper/restriction.go | 15 +++++++++++++++ 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/x/bank/v2/depinject.go b/x/bank/v2/depinject.go index c626d2f10f2d..6bb04908db2a 100644 --- a/x/bank/v2/depinject.go +++ b/x/bank/v2/depinject.go @@ -98,7 +98,7 @@ func InvokeSetSendRestrictions( return fmt.Errorf("can't find send restriction for module %s", module) } - keeper.AppendSendRestriction(restriction) + keeper.AppendGlobalSendRestriction(restriction) } return nil diff --git a/x/bank/v2/keeper/keeper.go b/x/bank/v2/keeper/keeper.go index ff4827238040..833fb298ef6f 100644 --- a/x/bank/v2/keeper/keeper.go +++ b/x/bank/v2/keeper/keeper.go @@ -258,18 +258,3 @@ func newBalancesIndexes(sb *collections.SchemaBuilder) BalancesIndexes { type BalancesIndexes struct { Denom *indexes.ReversePair[[]byte, string, math.Int] } - -// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions. -func (k Keeper) AppendSendRestriction(restriction types.SendRestrictionFn) { - k.sendRestriction.append(restriction) -} - -// PrependSendRestriction adds the provided SendRestrictionFn to run before previously provided restrictions. -func (k Keeper) PrependSendRestriction(restriction types.SendRestrictionFn) { - k.sendRestriction.prepend(restriction) -} - -// ClearSendRestriction removes the send restriction (if there is one). -func (k Keeper) ClearSendRestriction() { - k.sendRestriction.clear() -} diff --git a/x/bank/v2/keeper/keeper_test.go b/x/bank/v2/keeper/keeper_test.go index 4da9faed3b5c..2920d0aea040 100644 --- a/x/bank/v2/keeper/keeper_test.go +++ b/x/bank/v2/keeper/keeper_test.go @@ -202,7 +202,7 @@ func (suite *KeeperTestSuite) TestSendCoins_WithRestriction() { } return to, nil } - suite.bankKeeper.AppendSendRestriction(addrRestrictFunc) + suite.bankKeeper.AppendGlobalSendRestriction(addrRestrictFunc) err := suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[0], sendAmt) require.Error(err) @@ -215,7 +215,7 @@ func (suite *KeeperTestSuite) TestSendCoins_WithRestriction() { } return to, nil } - suite.bankKeeper.AppendSendRestriction(amtRestrictFunc) + suite.bankKeeper.AppendGlobalSendRestriction(amtRestrictFunc) // Pass the 1st but failt at the 2nd err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sendAmt) diff --git a/x/bank/v2/keeper/restriction.go b/x/bank/v2/keeper/restriction.go index b2b1910f6649..048d090ec7c2 100644 --- a/x/bank/v2/keeper/restriction.go +++ b/x/bank/v2/keeper/restriction.go @@ -45,3 +45,18 @@ func (r *sendRestriction) apply(ctx context.Context, fromAddr, toAddr []byte, am } return r.fn(ctx, fromAddr, toAddr, amt) } + +// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions. +func (k Keeper) AppendGlobalSendRestriction(restriction types.SendRestrictionFn) { + k.sendRestriction.append(restriction) +} + +// PrependSendRestriction adds the provided SendRestrictionFn to run before previously provided restrictions. +func (k Keeper) PrependGlobalSendRestriction(restriction types.SendRestrictionFn) { + k.sendRestriction.prepend(restriction) +} + +// ClearSendRestriction removes the send restriction (if there is one). +func (k Keeper) ClearGlobalSendRestriction() { + k.sendRestriction.clear() +}