diff --git a/x/bank/proto/cosmos/bank/module/v2/module.proto b/x/bank/proto/cosmos/bank/module/v2/module.proto index a4c0ed40b52f..4d6db9db5fc6 100644 --- a/x/bank/proto/cosmos/bank/module/v2/module.proto +++ b/x/bank/proto/cosmos/bank/module/v2/module.proto @@ -14,4 +14,10 @@ message Module { // authority defines the custom module authority. If not set, defaults to the governance module. string authority = 1; + + // restrictions_order specifies the order of send restrictions and should be + // a list of module names which provide a send restriction instance. If no + // order is provided, then restrictions will be applied in alphabetical order + // of module names. + repeated string restrictions_order = 2; } diff --git a/x/bank/v2/depinject.go b/x/bank/v2/depinject.go index b5465d7e02df..6bb04908db2a 100644 --- a/x/bank/v2/depinject.go +++ b/x/bank/v2/depinject.go @@ -1,6 +1,11 @@ package bankv2 import ( + "fmt" + "maps" + "slices" + "sort" + "cosmossdk.io/core/address" "cosmossdk.io/core/appmodule" "cosmossdk.io/depinject" @@ -22,6 +27,7 @@ func init() { appconfig.RegisterModule( &moduletypes.Module{}, appconfig.Provide(ProvideModule), + appconfig.Invoke(InvokeSetSendRestrictions), ) } @@ -61,3 +67,39 @@ func ProvideModule(in ModuleInputs) ModuleOutputs { Module: m, } } + +func InvokeSetSendRestrictions( + config *moduletypes.Module, + keeper keeper.Keeper, + restrictions map[string]types.SendRestrictionFn, +) error { + if config == nil { + return nil + } + + modules := slices.Collect(maps.Keys(restrictions)) + order := config.RestrictionsOrder + if len(order) == 0 { + order = modules + sort.Strings(order) + } + + if len(order) != len(modules) { + return fmt.Errorf("len(restrictions order: %v) != len(restriction modules: %v)", order, modules) + } + + if len(modules) == 0 { + return nil + } + + for _, module := range order { + restriction, ok := restrictions[module] + if !ok { + return fmt.Errorf("can't find send restriction for module %s", module) + } + + keeper.AppendGlobalSendRestriction(restriction) + } + + return nil +} diff --git a/x/bank/v2/keeper/keeper.go b/x/bank/v2/keeper/keeper.go index f16982f168c4..833fb298ef6f 100644 --- a/x/bank/v2/keeper/keeper.go +++ b/x/bank/v2/keeper/keeper.go @@ -28,18 +28,21 @@ type Keeper struct { params collections.Item[types.Params] balances *collections.IndexedMap[collections.Pair[[]byte, string], math.Int, BalancesIndexes] supply collections.Map[string, math.Int] + + sendRestriction *sendRestriction } func NewKeeper(authority []byte, addressCodec address.Codec, env appmodulev2.Environment, cdc codec.BinaryCodec) *Keeper { sb := collections.NewSchemaBuilder(env.KVStoreService) k := &Keeper{ - Environment: env, - authority: authority, - addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment? - params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)), - balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)), - supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue), + Environment: env, + authority: authority, + addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment? + params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)), + balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)), + supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue), + sendRestriction: newSendRestriction(), } schema, err := sb.Build() @@ -94,7 +97,10 @@ func (k Keeper) SendCoins(ctx context.Context, from, to []byte, amt sdk.Coins) e } var err error - // TODO: Send restriction + to, err = k.sendRestriction.apply(ctx, from, to, amt) + if err != nil { + return err + } err = k.subUnlockedCoins(ctx, from, amt) if err != nil { diff --git a/x/bank/v2/keeper/keeper_test.go b/x/bank/v2/keeper/keeper_test.go index 32ccbac4d2ef..2920d0aea040 100644 --- a/x/bank/v2/keeper/keeper_test.go +++ b/x/bank/v2/keeper/keeper_test.go @@ -1,7 +1,9 @@ package keeper_test import ( + "bytes" "context" + "fmt" "testing" "time" @@ -184,3 +186,53 @@ func (suite *KeeperTestSuite) TestSendCoins_Module_To_Module() { mintBarBalance := suite.bankKeeper.GetBalance(ctx, mintAcc.GetAddress(), barDenom) require.Equal(mintBarBalance.Amount, math.NewInt(0)) } + +func (suite *KeeperTestSuite) TestSendCoins_WithRestriction() { + ctx := suite.ctx + require := suite.Require() + balances := sdk.NewCoins(newFooCoin(100), newBarCoin(50)) + sendAmt := sdk.NewCoins(newFooCoin(10), newBarCoin(10)) + + require.NoError(banktestutil.FundAccount(ctx, suite.bankKeeper, accAddrs[0], balances)) + + // Add first restriction + addrRestrictFunc := func(ctx context.Context, from, to []byte, amount sdk.Coins) ([]byte, error) { + if bytes.Equal(from, to) { + return nil, fmt.Errorf("Can not send to same address") + } + return to, nil + } + suite.bankKeeper.AppendGlobalSendRestriction(addrRestrictFunc) + + err := suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[0], sendAmt) + require.Error(err) + require.Contains(err.Error(), "Can not send to same address") + + // Add second restriction + amtRestrictFunc := func(ctx context.Context, from, to []byte, amount sdk.Coins) ([]byte, error) { + if len(amount) > 1 { + return nil, fmt.Errorf("Allow only one denom per one send") + } + return to, nil + } + suite.bankKeeper.AppendGlobalSendRestriction(amtRestrictFunc) + + // Pass the 1st but failt at the 2nd + err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sendAmt) + require.Error(err) + require.Contains(err.Error(), "Allow only one denom per one send") + + // Pass both 2 restrictions + err = suite.bankKeeper.SendCoins(ctx, accAddrs[0], accAddrs[1], sdk.NewCoins(newFooCoin(10))) + require.NoError(err) + + // Check balances + acc0FooBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[0], fooDenom) + require.Equal(acc0FooBalance.Amount, math.NewInt(90)) + acc0BarBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[0], barDenom) + require.Equal(acc0BarBalance.Amount, math.NewInt(50)) + acc1FooBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[1], fooDenom) + require.Equal(acc1FooBalance.Amount, math.NewInt(10)) + acc1BarBalance := suite.bankKeeper.GetBalance(ctx, accAddrs[1], barDenom) + require.Equal(acc1BarBalance.Amount, math.ZeroInt()) +} diff --git a/x/bank/v2/keeper/restriction.go b/x/bank/v2/keeper/restriction.go new file mode 100644 index 000000000000..048d090ec7c2 --- /dev/null +++ b/x/bank/v2/keeper/restriction.go @@ -0,0 +1,62 @@ +package keeper + +import ( + "context" + + "cosmossdk.io/x/bank/v2/types" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// sendRestriction is a struct that houses a SendRestrictionFn. +// It exists so that the SendRestrictionFn can be updated in the SendKeeper without needing to have a pointer receiver. +type sendRestriction struct { + fn types.SendRestrictionFn +} + +// newSendRestriction creates a new sendRestriction with nil send restriction. +func newSendRestriction() *sendRestriction { + return &sendRestriction{ + fn: nil, + } +} + +// append adds the provided restriction to this, to be run after the existing function. +func (r *sendRestriction) append(restriction types.SendRestrictionFn) { + r.fn = r.fn.Then(restriction) +} + +// prepend adds the provided restriction to this, to be run before the existing function. +func (r *sendRestriction) prepend(restriction types.SendRestrictionFn) { + r.fn = restriction.Then(r.fn) +} + +// clear removes the send restriction (sets it to nil). +func (r *sendRestriction) clear() { + r.fn = nil +} + +var _ types.SendRestrictionFn = (*sendRestriction)(nil).apply + +// apply applies the send restriction if there is one. If not, it's a no-op. +func (r *sendRestriction) apply(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) { + if r == nil || r.fn == nil { + return toAddr, nil + } + return r.fn(ctx, fromAddr, toAddr, amt) +} + +// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions. +func (k Keeper) AppendGlobalSendRestriction(restriction types.SendRestrictionFn) { + k.sendRestriction.append(restriction) +} + +// PrependSendRestriction adds the provided SendRestrictionFn to run before previously provided restrictions. +func (k Keeper) PrependGlobalSendRestriction(restriction types.SendRestrictionFn) { + k.sendRestriction.prepend(restriction) +} + +// ClearSendRestriction removes the send restriction (if there is one). +func (k Keeper) ClearGlobalSendRestriction() { + k.sendRestriction.clear() +} diff --git a/x/bank/v2/types/module/module.pb.go b/x/bank/v2/types/module/module.pb.go index 247cbd747e3b..d752f9ecf933 100644 --- a/x/bank/v2/types/module/module.pb.go +++ b/x/bank/v2/types/module/module.pb.go @@ -27,6 +27,11 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type Module struct { // authority defines the custom module authority. If not set, defaults to the governance module. Authority string `protobuf:"bytes,1,opt,name=authority,proto3" json:"authority,omitempty"` + // restrictions_order specifies the order of send restrictions and should be + // a list of module names which provide a send restriction instance. If no + // order is provided, then restrictions will be applied in alphabetical order + // of module names. + RestrictionsOrder []string `protobuf:"bytes,2,rep,name=restrictions_order,json=restrictionsOrder,proto3" json:"restrictions_order,omitempty"` } func (m *Module) Reset() { *m = Module{} } @@ -69,6 +74,13 @@ func (m *Module) GetAuthority() string { return "" } +func (m *Module) GetRestrictionsOrder() []string { + if m != nil { + return m.RestrictionsOrder + } + return nil +} + func init() { proto.RegisterType((*Module)(nil), "cosmos.bank.module.v2.Module") } @@ -78,19 +90,21 @@ func init() { } var fileDescriptor_34a109a905e2a25b = []byte{ - // 184 bytes of a gzipped FileDescriptorProto + // 219 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x52, 0x4a, 0xce, 0x2f, 0xce, 0xcd, 0x2f, 0xd6, 0x4f, 0x4a, 0xcc, 0xcb, 0xd6, 0xcf, 0xcd, 0x4f, 0x29, 0xcd, 0x49, 0xd5, 0x2f, 0x33, 0x82, 0xb2, 0xf4, 0x0a, 0x8a, 0xf2, 0x4b, 0xf2, 0x85, 0x44, 0x21, 0x6a, 0xf4, 0x40, 0x6a, 0xf4, 0xa0, 0x32, 0x65, 0x46, 0x52, 0x0a, 0x50, 0xad, 0x89, 0x05, 0x05, 0xfa, 0x65, 0x86, 0x89, - 0x39, 0x05, 0x19, 0x89, 0x86, 0x28, 0x1a, 0x95, 0xdc, 0xb8, 0xd8, 0x7c, 0xc1, 0x7c, 0x21, 0x19, + 0x39, 0x05, 0x19, 0x89, 0x86, 0x28, 0x1a, 0x95, 0x4a, 0xb9, 0xd8, 0x7c, 0xc1, 0x7c, 0x21, 0x19, 0x2e, 0xce, 0xc4, 0xd2, 0x92, 0x8c, 0xfc, 0xa2, 0xcc, 0x92, 0x4a, 0x09, 0x46, 0x05, 0x46, 0x0d, - 0xce, 0x20, 0x84, 0x80, 0x95, 0xdc, 0xae, 0x03, 0xd3, 0x6e, 0x31, 0x4a, 0x70, 0x89, 0x41, 0x4c, - 0x2c, 0x4e, 0xc9, 0xd6, 0xcb, 0xcc, 0xd7, 0xaf, 0x80, 0x38, 0xaa, 0xcc, 0xc8, 0xc9, 0xf6, 0xc4, - 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 0xe1, - 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x94, 0xb1, 0xeb, 0xd0, 0x2f, 0xa9, 0x2c, - 0x48, 0x2d, 0x86, 0x3a, 0x26, 0x89, 0x0d, 0xec, 0x1a, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, - 0x5e, 0xfc, 0x10, 0x2c, 0xec, 0x00, 0x00, 0x00, + 0xce, 0x20, 0x84, 0x80, 0x90, 0x2e, 0x97, 0x50, 0x51, 0x6a, 0x71, 0x49, 0x51, 0x66, 0x72, 0x49, + 0x66, 0x7e, 0x5e, 0x71, 0x7c, 0x7e, 0x51, 0x4a, 0x6a, 0x91, 0x04, 0x93, 0x02, 0xb3, 0x06, 0x67, + 0x90, 0x20, 0xb2, 0x8c, 0x3f, 0x48, 0xc2, 0x4a, 0x6e, 0xd7, 0x81, 0x69, 0xb7, 0x18, 0x25, 0xb8, + 0xc4, 0x20, 0x0e, 0x28, 0x4e, 0xc9, 0xd6, 0xcb, 0xcc, 0xd7, 0xaf, 0x80, 0xf8, 0xa1, 0xcc, 0xc8, + 0xc9, 0xf6, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, + 0x58, 0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x94, 0xb1, 0xeb, 0xd0, + 0x2f, 0xa9, 0x2c, 0x48, 0x2d, 0x86, 0xba, 0x3d, 0x89, 0x0d, 0xec, 0x78, 0x63, 0x40, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x69, 0x6d, 0xb0, 0x10, 0x1b, 0x01, 0x00, 0x00, } func (m *Module) Marshal() (dAtA []byte, err error) { @@ -113,6 +127,15 @@ func (m *Module) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if len(m.RestrictionsOrder) > 0 { + for iNdEx := len(m.RestrictionsOrder) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.RestrictionsOrder[iNdEx]) + copy(dAtA[i:], m.RestrictionsOrder[iNdEx]) + i = encodeVarintModule(dAtA, i, uint64(len(m.RestrictionsOrder[iNdEx]))) + i-- + dAtA[i] = 0x12 + } + } if len(m.Authority) > 0 { i -= len(m.Authority) copy(dAtA[i:], m.Authority) @@ -144,6 +167,12 @@ func (m *Module) Size() (n int) { if l > 0 { n += 1 + l + sovModule(uint64(l)) } + if len(m.RestrictionsOrder) > 0 { + for _, s := range m.RestrictionsOrder { + l = len(s) + n += 1 + l + sovModule(uint64(l)) + } + } return n } @@ -214,6 +243,38 @@ func (m *Module) Unmarshal(dAtA []byte) error { } m.Authority = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RestrictionsOrder", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowModule + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthModule + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthModule + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.RestrictionsOrder = append(m.RestrictionsOrder, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipModule(dAtA[iNdEx:]) diff --git a/x/bank/v2/types/restrictions.go b/x/bank/v2/types/restrictions.go new file mode 100644 index 000000000000..89c1f7e6cfea --- /dev/null +++ b/x/bank/v2/types/restrictions.go @@ -0,0 +1,57 @@ +package types + +import ( + "context" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// A SendRestrictionFn can restrict sends and/or provide a new receiver address. +type SendRestrictionFn func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) (newToAddr []byte, err error) + +// IsOnePerModuleType implements the depinject.OnePerModuleType interface. +func (SendRestrictionFn) IsOnePerModuleType() {} + +var _ SendRestrictionFn = NoOpSendRestrictionFn + +// NoOpSendRestrictionFn is a no-op SendRestrictionFn. +func NoOpSendRestrictionFn(_ context.Context, _, toAddr []byte, _ sdk.Coins) ([]byte, error) { + return toAddr, nil +} + +// Then creates a composite restriction that runs this one then the provided second one. +func (r SendRestrictionFn) Then(second SendRestrictionFn) SendRestrictionFn { + return ComposeSendRestrictions(r, second) +} + +// ComposeSendRestrictions combines multiple SendRestrictionFn into one. +// nil entries are ignored. +// If all entries are nil, nil is returned. +// If exactly one entry is not nil, it is returned. +// Otherwise, a new SendRestrictionFn is returned that runs the non-nil restrictions in the order they are given. +// The composition runs each send restriction until an error is encountered and returns that error, +// otherwise it returns the toAddr of the last send restriction. +func ComposeSendRestrictions(restrictions ...SendRestrictionFn) SendRestrictionFn { + toRun := make([]SendRestrictionFn, 0, len(restrictions)) + for _, r := range restrictions { + if r != nil { + toRun = append(toRun, r) + } + } + switch len(toRun) { + case 0: + return nil + case 1: + return toRun[0] + } + return func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) { + var err error + for _, r := range toRun { + toAddr, err = r(ctx, fromAddr, toAddr, amt) + if err != nil { + return toAddr, err + } + } + return toAddr, err + } +}