Skip to content

Commit

Permalink
add restriction to keeper
Browse files Browse the repository at this point in the history
  • Loading branch information
hieuvubk committed Sep 26, 2024
1 parent 06cef51 commit 42f4dfd
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 7 deletions.
42 changes: 42 additions & 0 deletions x/bank/v2/depinject.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package bankv2

import (
"fmt"
"maps"
"slices"
"sort"

"cosmossdk.io/core/address"
"cosmossdk.io/core/appmodule"
"cosmossdk.io/depinject"
Expand All @@ -22,6 +27,7 @@ func init() {
appconfig.RegisterModule(
&moduletypes.Module{},
appconfig.Provide(ProvideModule),
appconfig.Invoke(InvokeSetSendRestrictions),
)
}

Expand Down Expand Up @@ -61,3 +67,39 @@ func ProvideModule(in ModuleInputs) ModuleOutputs {
Module: m,
}
}

func InvokeSetSendRestrictions(
config *moduletypes.Module,
keeper keeper.Keeper,
restrictions map[string]types.SendRestrictionFn,
) error {
if config == nil {
return nil
}

modules := slices.Collect(maps.Keys(restrictions))
order := config.RestrictionsOrder
if len(order) == 0 {
order = modules
sort.Strings(order)
}

if len(order) != len(modules) {
return fmt.Errorf("len(restrictions order: %v) != len(restriction modules: %v)", order, modules)
}

if len(modules) == 0 {
return nil
}

for _, module := range order {
restriction, ok := restrictions[module]
if !ok {
return fmt.Errorf("can't find send restriction for module %s", module)
}

keeper.AppendSendRestriction(restriction)
}

return nil
}
35 changes: 28 additions & 7 deletions x/bank/v2/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,21 @@ type Keeper struct {
params collections.Item[types.Params]
balances *collections.IndexedMap[collections.Pair[[]byte, string], math.Int, BalancesIndexes]
supply collections.Map[string, math.Int]

sendRestriction *sendRestriction
}

func NewKeeper(authority []byte, addressCodec address.Codec, env appmodulev2.Environment, cdc codec.BinaryCodec) *Keeper {
sb := collections.NewSchemaBuilder(env.KVStoreService)

k := &Keeper{
Environment: env,
authority: authority,
addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment?
params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)),
supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue),
Environment: env,
authority: authority,
addressCodec: addressCodec, // TODO(@julienrbrt): Should we add address codec to the environment?
params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
balances: collections.NewIndexedMap(sb, types.BalancesPrefix, "balances", collections.PairKeyCodec(collections.BytesKey, collections.StringKey), sdk.IntValue, newBalancesIndexes(sb)),
supply: collections.NewMap(sb, types.SupplyKey, "supply", collections.StringKey, sdk.IntValue),
sendRestriction: newSendRestriction(),
}

schema, err := sb.Build()
Expand Down Expand Up @@ -94,7 +97,10 @@ func (k Keeper) SendCoins(ctx context.Context, from, to []byte, amt sdk.Coins) e
}

var err error
// TODO: Send restriction
to, err = k.sendRestriction.apply(ctx, from, to, amt)
if err != nil {
return err
}

err = k.subUnlockedCoins(ctx, from, amt)
if err != nil {
Expand Down Expand Up @@ -252,3 +258,18 @@ func newBalancesIndexes(sb *collections.SchemaBuilder) BalancesIndexes {
type BalancesIndexes struct {
Denom *indexes.ReversePair[[]byte, string, math.Int]
}

// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions.
func (k Keeper) AppendSendRestriction(restriction types.SendRestrictionFn) {
k.sendRestriction.append(restriction)
}

// PrependSendRestriction adds the provided SendRestrictionFn to run before previously provided restrictions.
func (k Keeper) PrependSendRestriction(restriction types.SendRestrictionFn) {
k.sendRestriction.prepend(restriction)
}

// ClearSendRestriction removes the send restriction (if there is one).
func (k Keeper) ClearSendRestriction() {
k.sendRestriction.clear()
}
47 changes: 47 additions & 0 deletions x/bank/v2/keeper/restriction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package keeper

import (
"context"

"cosmossdk.io/x/bank/v2/types"

sdk "github.com/cosmos/cosmos-sdk/types"
)

// sendRestriction is a struct that houses a SendRestrictionFn.
// It exists so that the SendRestrictionFn can be updated in the SendKeeper without needing to have a pointer receiver.
type sendRestriction struct {
fn types.SendRestrictionFn
}

// newSendRestriction creates a new sendRestriction with nil send restriction.
func newSendRestriction() *sendRestriction {
return &sendRestriction{
fn: nil,
}
}

// append adds the provided restriction to this, to be run after the existing function.
func (r *sendRestriction) append(restriction types.SendRestrictionFn) {
r.fn = r.fn.Then(restriction)
}

// prepend adds the provided restriction to this, to be run before the existing function.
func (r *sendRestriction) prepend(restriction types.SendRestrictionFn) {
r.fn = restriction.Then(r.fn)
}

// clear removes the send restriction (sets it to nil).
func (r *sendRestriction) clear() {
r.fn = nil
}

var _ types.SendRestrictionFn = (*sendRestriction)(nil).apply

// apply applies the send restriction if there is one. If not, it's a no-op.
func (r *sendRestriction) apply(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) {
if r == nil || r.fn == nil {
return toAddr, nil
}
return r.fn(ctx, fromAddr, toAddr, amt)
}
57 changes: 57 additions & 0 deletions x/bank/v2/types/restrictions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package types

import (
"context"

sdk "github.com/cosmos/cosmos-sdk/types"
)

// A SendRestrictionFn can restrict sends and/or provide a new receiver address.
type SendRestrictionFn func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) (newToAddr []byte, err error)

// IsOnePerModuleType implements the depinject.OnePerModuleType interface.
func (SendRestrictionFn) IsOnePerModuleType() {}

var _ SendRestrictionFn = NoOpSendRestrictionFn

// NoOpSendRestrictionFn is a no-op SendRestrictionFn.
func NoOpSendRestrictionFn(_ context.Context, _, toAddr []byte, _ sdk.Coins) ([]byte, error) {
return toAddr, nil
}

// Then creates a composite restriction that runs this one then the provided second one.
func (r SendRestrictionFn) Then(second SendRestrictionFn) SendRestrictionFn {
return ComposeSendRestrictions(r, second)
}

// ComposeSendRestrictions combines multiple SendRestrictionFn into one.
// nil entries are ignored.
// If all entries are nil, nil is returned.
// If exactly one entry is not nil, it is returned.
// Otherwise, a new SendRestrictionFn is returned that runs the non-nil restrictions in the order they are given.
// The composition runs each send restriction until an error is encountered and returns that error,
// otherwise it returns the toAddr of the last send restriction.
func ComposeSendRestrictions(restrictions ...SendRestrictionFn) SendRestrictionFn {
toRun := make([]SendRestrictionFn, 0, len(restrictions))
for _, r := range restrictions {
if r != nil {
toRun = append(toRun, r)
}
}
switch len(toRun) {
case 0:
return nil
case 1:
return toRun[0]
}
return func(ctx context.Context, fromAddr, toAddr []byte, amt sdk.Coins) ([]byte, error) {
var err error
for _, r := range toRun {
toAddr, err = r(ctx, fromAddr, toAddr, amt)
if err != nil {
return toAddr, err
}
}
return toAddr, err
}
}

0 comments on commit 42f4dfd

Please sign in to comment.