diff --git a/indexer/packages/v4-protos/src/codegen/dydxprotocol/affiliates/tx.rpc.msg.ts b/indexer/packages/v4-protos/src/codegen/dydxprotocol/affiliates/tx.rpc.msg.ts index 3a684ef21c..da9e7b4ae3 100644 --- a/indexer/packages/v4-protos/src/codegen/dydxprotocol/affiliates/tx.rpc.msg.ts +++ b/indexer/packages/v4-protos/src/codegen/dydxprotocol/affiliates/tx.rpc.msg.ts @@ -1,6 +1,6 @@ import { Rpc } from "../../helpers"; import * as _m0 from "protobufjs/minimal"; -import { MsgRegisterAffiliate, MsgRegisterAffiliateResponse, MsgUpdateAffiliateTiers, MsgUpdateAffiliateTiersResponse } from "./tx"; +import { MsgRegisterAffiliate, MsgRegisterAffiliateResponse, MsgUpdateAffiliateTiers, MsgUpdateAffiliateTiersResponse, MsgUpdateAffiliateWhitelist, MsgUpdateAffiliateWhitelistResponse } from "./tx"; /** Msg defines the Msg service. */ export interface Msg { @@ -9,6 +9,9 @@ export interface Msg { /** UpdateAffiliateTiers updates affiliate tiers */ updateAffiliateTiers(request: MsgUpdateAffiliateTiers): Promise; + /** UpdateAffiliateWhitelist updates affiliate whitelist */ + + updateAffiliateWhitelist(request: MsgUpdateAffiliateWhitelist): Promise; } export class MsgClientImpl implements Msg { private readonly rpc: Rpc; @@ -17,6 +20,7 @@ export class MsgClientImpl implements Msg { this.rpc = rpc; this.registerAffiliate = this.registerAffiliate.bind(this); this.updateAffiliateTiers = this.updateAffiliateTiers.bind(this); + this.updateAffiliateWhitelist = this.updateAffiliateWhitelist.bind(this); } registerAffiliate(request: MsgRegisterAffiliate): Promise { @@ -31,4 +35,10 @@ export class MsgClientImpl implements Msg { return promise.then(data => MsgUpdateAffiliateTiersResponse.decode(new _m0.Reader(data))); } + updateAffiliateWhitelist(request: MsgUpdateAffiliateWhitelist): Promise { + const data = MsgUpdateAffiliateWhitelist.encode(request).finish(); + const promise = this.rpc.request("dydxprotocol.affiliates.Msg", "UpdateAffiliateWhitelist", data); + return promise.then(data => MsgUpdateAffiliateWhitelistResponse.decode(new _m0.Reader(data))); + } + } \ No newline at end of file diff --git a/indexer/packages/v4-protos/src/codegen/dydxprotocol/affiliates/tx.ts b/indexer/packages/v4-protos/src/codegen/dydxprotocol/affiliates/tx.ts index 48c2c3ea70..1e6721656a 100644 --- a/indexer/packages/v4-protos/src/codegen/dydxprotocol/affiliates/tx.ts +++ b/indexer/packages/v4-protos/src/codegen/dydxprotocol/affiliates/tx.ts @@ -1,4 +1,4 @@ -import { AffiliateTiers, AffiliateTiersSDKType } from "./affiliates"; +import { AffiliateTiers, AffiliateTiersSDKType, AffiliateWhitelist, AffiliateWhitelistSDKType } from "./affiliates"; import * as _m0 from "protobufjs/minimal"; import { DeepPartial } from "../../helpers"; /** Message to register a referee-affiliate relationship */ @@ -49,6 +49,30 @@ export interface MsgUpdateAffiliateTiersResponse {} /** Response to MsgUpdateAffiliateTiers */ export interface MsgUpdateAffiliateTiersResponseSDKType {} +/** Message to update affiliate whitelist */ + +export interface MsgUpdateAffiliateWhitelist { + /** Authority sending this message. Will be sent by gov */ + authority: string; + /** Updated affiliate whitelist information */ + + whitelist?: AffiliateWhitelist; +} +/** Message to update affiliate whitelist */ + +export interface MsgUpdateAffiliateWhitelistSDKType { + /** Authority sending this message. Will be sent by gov */ + authority: string; + /** Updated affiliate whitelist information */ + + whitelist?: AffiliateWhitelistSDKType; +} +/** Response to MsgUpdateAffiliateWhitelist */ + +export interface MsgUpdateAffiliateWhitelistResponse {} +/** Response to MsgUpdateAffiliateWhitelist */ + +export interface MsgUpdateAffiliateWhitelistResponseSDKType {} function createBaseMsgRegisterAffiliate(): MsgRegisterAffiliate { return { @@ -226,4 +250,93 @@ export const MsgUpdateAffiliateTiersResponse = { return message; } +}; + +function createBaseMsgUpdateAffiliateWhitelist(): MsgUpdateAffiliateWhitelist { + return { + authority: "", + whitelist: undefined + }; +} + +export const MsgUpdateAffiliateWhitelist = { + encode(message: MsgUpdateAffiliateWhitelist, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.authority !== "") { + writer.uint32(10).string(message.authority); + } + + if (message.whitelist !== undefined) { + AffiliateWhitelist.encode(message.whitelist, writer.uint32(18).fork()).ldelim(); + } + + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): MsgUpdateAffiliateWhitelist { + const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseMsgUpdateAffiliateWhitelist(); + + while (reader.pos < end) { + const tag = reader.uint32(); + + switch (tag >>> 3) { + case 1: + message.authority = reader.string(); + break; + + case 2: + message.whitelist = AffiliateWhitelist.decode(reader, reader.uint32()); + break; + + default: + reader.skipType(tag & 7); + break; + } + } + + return message; + }, + + fromPartial(object: DeepPartial): MsgUpdateAffiliateWhitelist { + const message = createBaseMsgUpdateAffiliateWhitelist(); + message.authority = object.authority ?? ""; + message.whitelist = object.whitelist !== undefined && object.whitelist !== null ? AffiliateWhitelist.fromPartial(object.whitelist) : undefined; + return message; + } + +}; + +function createBaseMsgUpdateAffiliateWhitelistResponse(): MsgUpdateAffiliateWhitelistResponse { + return {}; +} + +export const MsgUpdateAffiliateWhitelistResponse = { + encode(_: MsgUpdateAffiliateWhitelistResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): MsgUpdateAffiliateWhitelistResponse { + const reader = input instanceof _m0.Reader ? input : new _m0.Reader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseMsgUpdateAffiliateWhitelistResponse(); + + while (reader.pos < end) { + const tag = reader.uint32(); + + switch (tag >>> 3) { + default: + reader.skipType(tag & 7); + break; + } + } + + return message; + }, + + fromPartial(_: DeepPartial): MsgUpdateAffiliateWhitelistResponse { + const message = createBaseMsgUpdateAffiliateWhitelistResponse(); + return message; + } + }; \ No newline at end of file diff --git a/proto/dydxprotocol/affiliates/tx.proto b/proto/dydxprotocol/affiliates/tx.proto index 420f0f7863..d0ed36f7f4 100644 --- a/proto/dydxprotocol/affiliates/tx.proto +++ b/proto/dydxprotocol/affiliates/tx.proto @@ -16,6 +16,9 @@ service Msg { // UpdateAffiliateTiers updates affiliate tiers rpc UpdateAffiliateTiers(MsgUpdateAffiliateTiers) returns (MsgUpdateAffiliateTiersResponse); + // UpdateAffiliateWhitelist updates affiliate whitelist + rpc UpdateAffiliateWhitelist(MsgUpdateAffiliateWhitelist) + returns (MsgUpdateAffiliateWhitelistResponse); } // Message to register a referee-affiliate relationship @@ -43,4 +46,16 @@ message MsgUpdateAffiliateTiers { } // Response to MsgUpdateAffiliateTiers -message MsgUpdateAffiliateTiersResponse {} \ No newline at end of file +message MsgUpdateAffiliateTiersResponse {} + +// Message to update affiliate whitelist +message MsgUpdateAffiliateWhitelist { + option (cosmos.msg.v1.signer) = "authority"; + // Authority sending this message. Will be sent by gov + string authority = 1 [ (cosmos_proto.scalar) = "cosmos.AddressString" ]; + // Updated affiliate whitelist information + AffiliateWhitelist whitelist = 2 [ (gogoproto.nullable) = false ]; +} + +// Response to MsgUpdateAffiliateWhitelist +message MsgUpdateAffiliateWhitelistResponse {} \ No newline at end of file diff --git a/protocol/app/app.go b/protocol/app/app.go index ca7f12c16d..953fc8e99a 100644 --- a/protocol/app/app.go +++ b/protocol/app/app.go @@ -1124,6 +1124,7 @@ func New( app.PricesKeeper, app.StatsKeeper, app.RewardsKeeper, + app.AffiliatesKeeper, app.IndexerEventManager, app.FullNodeStreamingManager, txConfig.TxDecoder(), diff --git a/protocol/app/msgs/all_msgs.go b/protocol/app/msgs/all_msgs.go index b864918dcd..6825f7e259 100644 --- a/protocol/app/msgs/all_msgs.go +++ b/protocol/app/msgs/all_msgs.go @@ -149,11 +149,12 @@ var ( "/cosmos.upgrade.v1beta1.SoftwareUpgradeProposal": {}, // affiliates - "/dydxprotocol.affiliates.MsgRegisterAffiliate": {}, - "/dydxprotocol.affiliates.MsgRegisterAffiliateResponse": {}, - "/dydxprotocol.affiliates.MsgUpdateAffiliateTiers": {}, - "/dydxprotocol.affiliates.MsgUpdateAffiliateTiersResponse": {}, - + "/dydxprotocol.affiliates.MsgRegisterAffiliate": {}, + "/dydxprotocol.affiliates.MsgRegisterAffiliateResponse": {}, + "/dydxprotocol.affiliates.MsgUpdateAffiliateTiers": {}, + "/dydxprotocol.affiliates.MsgUpdateAffiliateTiersResponse": {}, + "/dydxprotocol.affiliates.MsgUpdateAffiliateWhitelist": {}, + "/dydxprotocol.affiliates.MsgUpdateAffiliateWhitelistResponse": {}, // blocktime "/dydxprotocol.blocktime.MsgUpdateDowntimeParams": {}, "/dydxprotocol.blocktime.MsgUpdateDowntimeParamsResponse": {}, diff --git a/protocol/app/msgs/internal_msgs.go b/protocol/app/msgs/internal_msgs.go index 1403318f96..8e030a9437 100644 --- a/protocol/app/msgs/internal_msgs.go +++ b/protocol/app/msgs/internal_msgs.go @@ -108,8 +108,10 @@ var ( InternalMsgSamplesDydxCustom = map[string]sdk.Msg{ // affiliates - "/dydxprotocol.affiliates.MsgUpdateAffiliateTiers": &affiliates.MsgUpdateAffiliateTiers{}, - "/dydxprotocol.affiliates.MsgUpdateAffiliateTiersResponse": nil, + "/dydxprotocol.affiliates.MsgUpdateAffiliateTiers": &affiliates.MsgUpdateAffiliateTiers{}, + "/dydxprotocol.affiliates.MsgUpdateAffiliateTiersResponse": nil, + "/dydxprotocol.affiliates.MsgUpdateAffiliateWhitelist": &affiliates.MsgUpdateAffiliateWhitelist{}, + "/dydxprotocol.affiliates.MsgUpdateAffiliateWhitelistResponse": nil, // blocktime "/dydxprotocol.blocktime.MsgUpdateDowntimeParams": &blocktime.MsgUpdateDowntimeParams{}, diff --git a/protocol/app/msgs/internal_msgs_test.go b/protocol/app/msgs/internal_msgs_test.go index 0b183c962e..316256964c 100644 --- a/protocol/app/msgs/internal_msgs_test.go +++ b/protocol/app/msgs/internal_msgs_test.go @@ -66,6 +66,8 @@ func TestInternalMsgSamples_Gov_Key(t *testing.T) { // affiliates "/dydxprotocol.affiliates.MsgUpdateAffiliateTiers", "/dydxprotocol.affiliates.MsgUpdateAffiliateTiersResponse", + "/dydxprotocol.affiliates.MsgUpdateAffiliateWhitelist", + "/dydxprotocol.affiliates.MsgUpdateAffiliateWhitelistResponse", // blocktime "/dydxprotocol.blocktime.MsgUpdateDowntimeParams", diff --git a/protocol/lib/ante/internal_msg.go b/protocol/lib/ante/internal_msg.go index e58e43d43b..d26333d116 100644 --- a/protocol/lib/ante/internal_msg.go +++ b/protocol/lib/ante/internal_msg.go @@ -145,7 +145,8 @@ func IsInternalMsg(msg sdk.Msg) bool { *ibcconn.MsgUpdateParams, // affiliates - *affiliates.MsgUpdateAffiliateTiers: + *affiliates.MsgUpdateAffiliateTiers, + *affiliates.MsgUpdateAffiliateWhitelist: return true diff --git a/protocol/mocks/ClobKeeper.go b/protocol/mocks/ClobKeeper.go index 5a65322993..cd70f7d85d 100644 --- a/protocol/mocks/ClobKeeper.go +++ b/protocol/mocks/ClobKeeper.go @@ -999,9 +999,9 @@ func (_m *ClobKeeper) ProcessProposerOperations(ctx types.Context, operations [] return r0 } -// ProcessSingleMatch provides a mock function with given fields: ctx, matchWithOrders -func (_m *ClobKeeper) ProcessSingleMatch(ctx types.Context, matchWithOrders *clobtypes.MatchWithOrders) (bool, subaccountstypes.UpdateResult, subaccountstypes.UpdateResult, *big.Int, error) { - ret := _m.Called(ctx, matchWithOrders) +// ProcessSingleMatch provides a mock function with given fields: ctx, matchWithOrders, affiliatesWhitelistMap +func (_m *ClobKeeper) ProcessSingleMatch(ctx types.Context, matchWithOrders *clobtypes.MatchWithOrders, affiliatesWhitelistMap map[string]uint32) (bool, subaccountstypes.UpdateResult, subaccountstypes.UpdateResult, *big.Int, error) { + ret := _m.Called(ctx, matchWithOrders, affiliatesWhitelistMap) if len(ret) == 0 { panic("no return value specified for ProcessSingleMatch") @@ -1012,37 +1012,37 @@ func (_m *ClobKeeper) ProcessSingleMatch(ctx types.Context, matchWithOrders *clo var r2 subaccountstypes.UpdateResult var r3 *big.Int var r4 error - if rf, ok := ret.Get(0).(func(types.Context, *clobtypes.MatchWithOrders) (bool, subaccountstypes.UpdateResult, subaccountstypes.UpdateResult, *big.Int, error)); ok { - return rf(ctx, matchWithOrders) + if rf, ok := ret.Get(0).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) (bool, subaccountstypes.UpdateResult, subaccountstypes.UpdateResult, *big.Int, error)); ok { + return rf(ctx, matchWithOrders, affiliatesWhitelistMap) } - if rf, ok := ret.Get(0).(func(types.Context, *clobtypes.MatchWithOrders) bool); ok { - r0 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(0).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) bool); ok { + r0 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(types.Context, *clobtypes.MatchWithOrders) subaccountstypes.UpdateResult); ok { - r1 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(1).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) subaccountstypes.UpdateResult); ok { + r1 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { r1 = ret.Get(1).(subaccountstypes.UpdateResult) } - if rf, ok := ret.Get(2).(func(types.Context, *clobtypes.MatchWithOrders) subaccountstypes.UpdateResult); ok { - r2 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(2).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) subaccountstypes.UpdateResult); ok { + r2 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { r2 = ret.Get(2).(subaccountstypes.UpdateResult) } - if rf, ok := ret.Get(3).(func(types.Context, *clobtypes.MatchWithOrders) *big.Int); ok { - r3 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(3).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) *big.Int); ok { + r3 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { if ret.Get(3) != nil { r3 = ret.Get(3).(*big.Int) } } - if rf, ok := ret.Get(4).(func(types.Context, *clobtypes.MatchWithOrders) error); ok { - r4 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(4).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) error); ok { + r4 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { r4 = ret.Error(4) } diff --git a/protocol/mocks/MemClobKeeper.go b/protocol/mocks/MemClobKeeper.go index adfdd745f4..12a2f8cff3 100644 --- a/protocol/mocks/MemClobKeeper.go +++ b/protocol/mocks/MemClobKeeper.go @@ -320,9 +320,9 @@ func (_m *MemClobKeeper) OffsetSubaccountPerpetualPosition(ctx types.Context, li return r0, r1 } -// ProcessSingleMatch provides a mock function with given fields: ctx, matchWithOrders -func (_m *MemClobKeeper) ProcessSingleMatch(ctx types.Context, matchWithOrders *clobtypes.MatchWithOrders) (bool, subaccountstypes.UpdateResult, subaccountstypes.UpdateResult, *big.Int, error) { - ret := _m.Called(ctx, matchWithOrders) +// ProcessSingleMatch provides a mock function with given fields: ctx, matchWithOrders, affiliatesWhitelistMap +func (_m *MemClobKeeper) ProcessSingleMatch(ctx types.Context, matchWithOrders *clobtypes.MatchWithOrders, affiliatesWhitelistMap map[string]uint32) (bool, subaccountstypes.UpdateResult, subaccountstypes.UpdateResult, *big.Int, error) { + ret := _m.Called(ctx, matchWithOrders, affiliatesWhitelistMap) if len(ret) == 0 { panic("no return value specified for ProcessSingleMatch") @@ -333,37 +333,37 @@ func (_m *MemClobKeeper) ProcessSingleMatch(ctx types.Context, matchWithOrders * var r2 subaccountstypes.UpdateResult var r3 *big.Int var r4 error - if rf, ok := ret.Get(0).(func(types.Context, *clobtypes.MatchWithOrders) (bool, subaccountstypes.UpdateResult, subaccountstypes.UpdateResult, *big.Int, error)); ok { - return rf(ctx, matchWithOrders) + if rf, ok := ret.Get(0).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) (bool, subaccountstypes.UpdateResult, subaccountstypes.UpdateResult, *big.Int, error)); ok { + return rf(ctx, matchWithOrders, affiliatesWhitelistMap) } - if rf, ok := ret.Get(0).(func(types.Context, *clobtypes.MatchWithOrders) bool); ok { - r0 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(0).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) bool); ok { + r0 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(types.Context, *clobtypes.MatchWithOrders) subaccountstypes.UpdateResult); ok { - r1 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(1).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) subaccountstypes.UpdateResult); ok { + r1 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { r1 = ret.Get(1).(subaccountstypes.UpdateResult) } - if rf, ok := ret.Get(2).(func(types.Context, *clobtypes.MatchWithOrders) subaccountstypes.UpdateResult); ok { - r2 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(2).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) subaccountstypes.UpdateResult); ok { + r2 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { r2 = ret.Get(2).(subaccountstypes.UpdateResult) } - if rf, ok := ret.Get(3).(func(types.Context, *clobtypes.MatchWithOrders) *big.Int); ok { - r3 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(3).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) *big.Int); ok { + r3 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { if ret.Get(3) != nil { r3 = ret.Get(3).(*big.Int) } } - if rf, ok := ret.Get(4).(func(types.Context, *clobtypes.MatchWithOrders) error); ok { - r4 = rf(ctx, matchWithOrders) + if rf, ok := ret.Get(4).(func(types.Context, *clobtypes.MatchWithOrders, map[string]uint32) error); ok { + r4 = rf(ctx, matchWithOrders, affiliatesWhitelistMap) } else { r4 = ret.Error(4) } diff --git a/protocol/testutil/keeper/clob.go b/protocol/testutil/keeper/clob.go index f5bee3fad1..0b75109217 100644 --- a/protocol/testutil/keeper/clob.go +++ b/protocol/testutil/keeper/clob.go @@ -17,6 +17,7 @@ import ( streaming "github.com/dydxprotocol/v4-chain/protocol/streaming" clobtest "github.com/dydxprotocol/v4-chain/protocol/testutil/clob" "github.com/dydxprotocol/v4-chain/protocol/testutil/constants" + affiliateskeeper "github.com/dydxprotocol/v4-chain/protocol/x/affiliates/keeper" asskeeper "github.com/dydxprotocol/v4-chain/protocol/x/assets/keeper" blocktimekeeper "github.com/dydxprotocol/v4-chain/protocol/x/blocktime/keeper" "github.com/dydxprotocol/v4-chain/protocol/x/clob/flags" @@ -48,6 +49,7 @@ type ClobKeepersTestContext struct { StatsKeeper *statskeeper.Keeper RewardsKeeper *rewardskeeper.Keeper SubaccountsKeeper *subkeeper.Keeper + AffiliatesKeeper *affiliateskeeper.Keeper VaultKeeper *vaultkeeper.Keeper StoreKey storetypes.StoreKey MemKey storetypes.StoreKey @@ -105,9 +107,9 @@ func NewClobKeepersTestContextWithUninitializedMemStore( cdc, stakingKeeper, ) - affiliatesKeeper, _ := createAffiliatesKeeper(stateStore, db, cdc, ks.StatsKeeper, + ks.AffiliatesKeeper, _ = createAffiliatesKeeper(stateStore, db, cdc, ks.StatsKeeper, indexerEventsTransientStoreKey, true) - revShareKeeper, _, _ := createRevShareKeeper(stateStore, db, cdc, affiliatesKeeper) + revShareKeeper, _, _ := createRevShareKeeper(stateStore, db, cdc, ks.AffiliatesKeeper) ks.MarketMapKeeper, _ = createMarketMapKeeper(stateStore, db, cdc) ks.PricesKeeper, _, _, mockTimeProvider = createPricesKeeper( stateStore, @@ -146,7 +148,7 @@ func NewClobKeepersTestContextWithUninitializedMemStore( stateStore, ks.StatsKeeper, ks.VaultKeeper, - affiliatesKeeper, + ks.AffiliatesKeeper, db, cdc, ) @@ -185,6 +187,7 @@ func NewClobKeepersTestContextWithUninitializedMemStore( ks.PricesKeeper, ks.StatsKeeper, ks.RewardsKeeper, + ks.AffiliatesKeeper, ks.SubaccountsKeeper, revShareKeeper, indexerEventManager, @@ -224,6 +227,7 @@ func createClobKeeper( pricesKeeper *priceskeeper.Keeper, statsKeeper *statskeeper.Keeper, rewardsKeeper types.RewardsKeeper, + affiliatesKeeper types.AffiliatesKeeper, saKeeper *subkeeper.Keeper, revShareKeeper types.RevShareKeeper, indexerEventManager indexer_manager.IndexerEventManager, @@ -256,6 +260,7 @@ func createClobKeeper( pricesKeeper, statsKeeper, rewardsKeeper, + affiliatesKeeper, indexerEventManager, streaming.NewNoopGrpcStreamingManager(), constants.TestEncodingCfg.TxConfig.TxDecoder(), diff --git a/protocol/testutil/keeper/listing.go b/protocol/testutil/keeper/listing.go index dbf11239f8..0c02b3c05a 100644 --- a/protocol/testutil/keeper/listing.go +++ b/protocol/testutil/keeper/listing.go @@ -150,6 +150,7 @@ func ListingKeepers( pricesKeeper, statsKeeper, rewardsKeeper, + affiliatesKeeper, subaccountsKeeper, revShareKeeper, indexerEventManager, diff --git a/protocol/testutil/memclob/keeper.go b/protocol/testutil/memclob/keeper.go index 137df08835..22ba76acdd 100644 --- a/protocol/testutil/memclob/keeper.go +++ b/protocol/testutil/memclob/keeper.go @@ -327,6 +327,7 @@ func (f *FakeMemClobKeeper) addFakeFillAmount( func (f *FakeMemClobKeeper) ProcessSingleMatch( ctx sdk.Context, matchWithOrders *types.MatchWithOrders, + affiliatesWhitelistMap map[string]uint32, ) ( success bool, takerUpdateResult satypes.UpdateResult, diff --git a/protocol/x/affiliates/client/cli/query.go b/protocol/x/affiliates/client/cli/query.go index 872ba5443b..4b79b11d2f 100644 --- a/protocol/x/affiliates/client/cli/query.go +++ b/protocol/x/affiliates/client/cli/query.go @@ -25,6 +25,7 @@ func GetQueryCmd(queryRoute string) *cobra.Command { GetCmdQueryAffiliateTiers(), GetCmdQueryAffiliateInfo(), GetCmdQueryReferredBy(), + GetCmdQueryAffiliateWhitelist(), ) return cmd } @@ -94,3 +95,23 @@ func GetCmdQueryReferredBy() *cobra.Command { } return cmd } + +func GetCmdQueryAffiliateWhitelist() *cobra.Command { + cmd := &cobra.Command{ + Use: "affiliate-whitelist", + Short: "Query affiliate whitelist", + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx, err := client.GetClientQueryContext(cmd) + if err != nil { + return err + } + queryClient := types.NewQueryClient(clientCtx) + res, err := queryClient.AffiliateWhitelist(context.Background(), &types.AffiliateWhitelistRequest{}) + if err != nil { + return err + } + return clientCtx.PrintProto(res) + }, + } + return cmd +} diff --git a/protocol/x/affiliates/client/cli/query_test.go b/protocol/x/affiliates/client/cli/query_test.go index 036e64a404..d0ab812240 100644 --- a/protocol/x/affiliates/client/cli/query_test.go +++ b/protocol/x/affiliates/client/cli/query_test.go @@ -72,3 +72,13 @@ func TestQueryReferredBy(t *testing.T) { var resp types.ReferredByResponse require.NoError(t, net.Config.Codec.UnmarshalJSON(out.Bytes(), &resp)) } + +func TestQueryAffiliateWhitelist(t *testing.T) { + net, ctx := setupNetwork(t) + + out, err := clitestutil.ExecTestCLICmd(ctx, cli.GetCmdQueryAffiliateWhitelist(), []string{}) + require.NoError(t, err) + + var resp types.AffiliateWhitelistResponse + require.NoError(t, net.Config.Codec.UnmarshalJSON(out.Bytes(), &resp)) +} diff --git a/protocol/x/affiliates/keeper/grpc_query.go b/protocol/x/affiliates/keeper/grpc_query.go index a4a992b271..63b10dca48 100644 --- a/protocol/x/affiliates/keeper/grpc_query.go +++ b/protocol/x/affiliates/keeper/grpc_query.go @@ -2,7 +2,6 @@ package keeper import ( "context" - "errors" errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" @@ -23,10 +22,22 @@ func (k Keeper) AffiliateInfo(c context.Context, req.GetAddress(), err.Error()) } - tierLevel, feeSharePpm, err := k.GetTierForAffiliate(ctx, addr.String()) + affiliateWhitelistMap, err := k.GetAffiliateWhitelistMap(ctx) if err != nil { return nil, err } + tierLevel := uint32(0) + feeSharePpm := uint32(0) + isWhitelisted := false + if _, exists := affiliateWhitelistMap[addr.String()]; exists { + feeSharePpm = affiliateWhitelistMap[addr.String()] + isWhitelisted = true + } else { + tierLevel, feeSharePpm, err = k.GetTierForAffiliate(ctx, addr.String()) + if err != nil { + return nil, err + } + } referredVolume, err := k.GetReferredVolume(ctx, req.GetAddress()) if err != nil { @@ -36,6 +47,7 @@ func (k Keeper) AffiliateInfo(c context.Context, stakedAmount := k.statsKeeper.GetStakedAmount(ctx, req.GetAddress()) return &types.AffiliateInfoResponse{ + IsWhitelisted: isWhitelisted, Tier: tierLevel, FeeSharePpm: feeSharePpm, ReferredVolume: dtypes.NewIntFromBigInt(referredVolume), @@ -73,6 +85,14 @@ func (k Keeper) AllAffiliateTiers(c context.Context, func (k Keeper) AffiliateWhitelist(c context.Context, req *types.AffiliateWhitelistRequest) (*types.AffiliateWhitelistResponse, error) { - // TODO(OTE-791): Implement `AffiliateWhitelist` RPC method. - return nil, errors.New("not implemented") + ctx := sdk.UnwrapSDKContext(c) + + affiliateWhitelist, err := k.GetAffiliateWhitelist(ctx) + if err != nil { + return nil, err + } + + return &types.AffiliateWhitelistResponse{ + Whitelist: affiliateWhitelist, + }, nil } diff --git a/protocol/x/affiliates/keeper/grpc_query_test.go b/protocol/x/affiliates/keeper/grpc_query_test.go index dce8e7c35e..adeedd9269 100644 --- a/protocol/x/affiliates/keeper/grpc_query_test.go +++ b/protocol/x/affiliates/keeper/grpc_query_test.go @@ -29,6 +29,7 @@ func TestAffiliateInfo(t *testing.T) { Address: constants.AliceAccAddress.String(), }, res: &types.AffiliateInfoResponse{ + IsWhitelisted: false, Tier: 0, FeeSharePpm: types.DefaultAffiliateTiers.Tiers[0].TakerFeeSharePpm, ReferredVolume: dtypes.NewIntFromUint64(types.DefaultAffiliateTiers.Tiers[0].ReqReferredVolumeQuoteQuantums), @@ -58,6 +59,7 @@ func TestAffiliateInfo(t *testing.T) { Address: constants.AliceAccAddress.String(), }, res: &types.AffiliateInfoResponse{ + IsWhitelisted: false, Tier: 0, FeeSharePpm: types.DefaultAffiliateTiers.Tiers[0].TakerFeeSharePpm, ReferredVolume: dtypes.NewIntFromUint64(types.DefaultAffiliateTiers.Tiers[0].ReqReferredVolumeQuoteQuantums), @@ -89,6 +91,44 @@ func TestAffiliateInfo(t *testing.T) { setup: func(ctx sdk.Context, k keeper.Keeper, tApp *testapp.TestApp) {}, expectError: types.ErrInvalidAddress, }, + "Whitelisted": { + req: &types.AffiliateInfoRequest{ + Address: constants.AliceAccAddress.String(), + }, + res: &types.AffiliateInfoResponse{ + IsWhitelisted: true, + Tier: 0, + FeeSharePpm: 120_000, + ReferredVolume: dtypes.NewIntFromUint64(0), + StakedAmount: dtypes.NewIntFromUint64(0), + }, + setup: func(ctx sdk.Context, k keeper.Keeper, tApp *testapp.TestApp) { + err := k.RegisterAffiliate(ctx, constants.BobAccAddress.String(), constants.AliceAccAddress.String()) + require.NoError(t, err) + + stakingKeeper := tApp.App.StakingKeeper + + err = stakingKeeper.SetDelegation(ctx, + stakingtypes.NewDelegation(constants.AliceAccAddress.String(), + constants.AliceValAddress.String(), math.LegacyNewDecFromBigInt( + big.NewInt(0), + ), + ), + ) + require.NoError(t, err) + + affiliatesWhitelist := types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 120_000, // 12% + }, + }, + } + err = k.SetAffiliateWhitelist(ctx, affiliatesWhitelist) + require.NoError(t, err) + }, + }, } for name, tc := range testCases { @@ -184,3 +224,26 @@ func TestAllAffiliateTiers(t *testing.T) { require.NotNil(t, res) require.Equal(t, &types.AllAffiliateTiersResponse{Tiers: tiers}, res) } + +func TestAffiliateWhitelist(t *testing.T) { + tApp := testapp.NewTestAppBuilder(t).Build() + ctx := tApp.InitChain() + k := tApp.App.AffiliatesKeeper + + req := &types.AffiliateWhitelistRequest{} + whitelist := types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 100_000, + }, + }, + } + err := k.SetAffiliateWhitelist(ctx, whitelist) + require.NoError(t, err) + + res, err := k.AffiliateWhitelist(ctx, req) + require.NoError(t, err) + require.NotNil(t, res) + require.Equal(t, &types.AffiliateWhitelistResponse{Whitelist: whitelist}, res) +} diff --git a/protocol/x/affiliates/keeper/keeper.go b/protocol/x/affiliates/keeper/keeper.go index e76d6d938e..53cc82f772 100644 --- a/protocol/x/affiliates/keeper/keeper.go +++ b/protocol/x/affiliates/keeper/keeper.go @@ -161,10 +161,12 @@ func (k Keeper) GetAllAffiliateTiers(ctx sdk.Context) (types.AffiliateTiers, err return affiliateTiers, nil } -// GetTakerFeeShare returns the taker fee share for an address. +// GetTakerFeeShare returns the taker fee share for an address based on the affiliate tiers. +// If the address is in the whitelist, the fee share ppm is overridden. func (k Keeper) GetTakerFeeShare( ctx sdk.Context, address string, + affiliatesWhitelistMap map[string]uint32, ) ( affiliateAddress string, feeSharePpm uint32, @@ -175,6 +177,12 @@ func (k Keeper) GetTakerFeeShare( if !exists { return "", 0, false, nil } + // Override fee share ppm if the address is in the whitelist. + if _, exists := affiliatesWhitelistMap[affiliateAddress]; exists { + feeSharePpm = affiliatesWhitelistMap[affiliateAddress] + return affiliateAddress, feeSharePpm, true, nil + } + _, feeSharePpm, err = k.GetTierForAffiliate(ctx, affiliateAddress) if err != nil { return "", 0, false, err @@ -241,6 +249,13 @@ func (k Keeper) UpdateAffiliateTiers(ctx sdk.Context, affiliateTiers types.Affil tiers := affiliateTiers.GetTiers() // start at 1, since 0 is the default tier. for i := 1; i < len(tiers); i++ { + // Check if the taker fee share ppm is greater than the cap. + if tiers[i].TakerFeeSharePpm > types.AffiliatesRevSharePpmCap { + return errorsmod.Wrapf(types.ErrRevShareSafetyViolation, + "taker fee share ppm %d is greater than the cap %d", + tiers[i].TakerFeeSharePpm, types.AffiliatesRevSharePpmCap) + } + // Check if the tiers are strictly increasing. if tiers[i].ReqReferredVolumeQuoteQuantums <= tiers[i-1].ReqReferredVolumeQuoteQuantums || tiers[i].ReqStakedWholeCoins <= tiers[i-1].ReqStakedWholeCoins { return errorsmod.Wrapf(types.ErrInvalidAffiliateTiers, @@ -259,6 +274,59 @@ func (k Keeper) GetIndexerEventManager() indexer_manager.IndexerEventManager { return k.indexerEventManager } +func (k Keeper) GetAffiliateWhitelistMap(ctx sdk.Context) (map[string]uint32, error) { + affiliateWhitelist, err := k.GetAffiliateWhitelist(ctx) + if err != nil { + return nil, err + } + affiliateWhitelistMap := make(map[string]uint32) + for _, tier := range affiliateWhitelist.GetTiers() { + for _, address := range tier.GetAddresses() { + affiliateWhitelistMap[address] = tier.GetTakerFeeSharePpm() + } + } + return affiliateWhitelistMap, nil +} + +func (k Keeper) SetAffiliateWhitelist(ctx sdk.Context, whitelist types.AffiliateWhitelist) error { + store := ctx.KVStore(k.storeKey) + addressSet := make(map[string]bool) + for _, tier := range whitelist.Tiers { + // Check if the taker fee share ppm is greater than the cap. + if tier.TakerFeeSharePpm > types.AffiliatesRevSharePpmCap { + return errorsmod.Wrapf(types.ErrRevShareSafetyViolation, + "taker fee share ppm %d is greater than the cap %d", + tier.TakerFeeSharePpm, types.AffiliatesRevSharePpmCap) + } + // Check for duplicate addresses. + for _, address := range tier.Addresses { + if addressSet[address] { + return errorsmod.Wrapf(types.ErrDuplicateAffiliateAddressForWhitelist, + "address %s is duplicated in affiliate whitelist", address) + } + addressSet[address] = true + } + } + affiliateWhitelistBytes := k.cdc.MustMarshal(&whitelist) + store.Set([]byte(types.AffiliateWhitelistKey), affiliateWhitelistBytes) + return nil +} + +func (k Keeper) GetAffiliateWhitelist(ctx sdk.Context) (types.AffiliateWhitelist, error) { + store := ctx.KVStore(k.storeKey) + affiliateWhitelistBytes := store.Get([]byte(types.AffiliateWhitelistKey)) + if affiliateWhitelistBytes == nil { + return types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{}, + }, nil + } + affiliateWhitelist := types.AffiliateWhitelist{} + err := k.cdc.Unmarshal(affiliateWhitelistBytes, &affiliateWhitelist) + if err != nil { + return types.AffiliateWhitelist{}, err + } + return affiliateWhitelist, nil +} func (k Keeper) AggregateAffiliateReferredVolumeForFills( ctx sdk.Context, ) error { diff --git a/protocol/x/affiliates/keeper/keeper_test.go b/protocol/x/affiliates/keeper/keeper_test.go index 0488c538a6..c2d5805ec6 100644 --- a/protocol/x/affiliates/keeper/keeper_test.go +++ b/protocol/x/affiliates/keeper/keeper_test.go @@ -173,7 +173,7 @@ func TestGetTakerFeeShareViaReferredVolume(t *testing.T) { require.NoError(t, err) // Get taker fee share for referee - affiliateAddr, feeSharePpm, exists, err := k.GetTakerFeeShare(ctx, referee) + affiliateAddr, feeSharePpm, exists, err := k.GetTakerFeeShare(ctx, referee, map[string]uint32{}) require.NoError(t, err) require.True(t, exists) require.Equal(t, affiliate, affiliateAddr) @@ -186,7 +186,7 @@ func TestGetTakerFeeShareViaReferredVolume(t *testing.T) { require.NoError(t, err) // Get updated taker fee share for referee - affiliateAddr, feeSharePpm, exists, err = k.GetTakerFeeShare(ctx, referee) + affiliateAddr, feeSharePpm, exists, err = k.GetTakerFeeShare(ctx, referee, map[string]uint32{}) require.NoError(t, err) require.True(t, exists) require.Equal(t, affiliate, affiliateAddr) @@ -222,7 +222,7 @@ func TestGetTakerFeeShareViaStakedAmount(t *testing.T) { require.NoError(t, err) // Get taker fee share for referee - affiliateAddr, feeSharePpm, exists, err := k.GetTakerFeeShare(ctx, referee) + affiliateAddr, feeSharePpm, exists, err := k.GetTakerFeeShare(ctx, referee, map[string]uint32{}) require.NoError(t, err) require.True(t, exists) require.Equal(t, affiliate, affiliateAddr) @@ -240,7 +240,7 @@ func TestGetTakerFeeShareViaStakedAmount(t *testing.T) { )))) require.NoError(t, err) // Get updated taker fee share for referee - affiliateAddr, feeSharePpm, exists, err = k.GetTakerFeeShare(ctx, referee) + affiliateAddr, feeSharePpm, exists, err = k.GetTakerFeeShare(ctx, referee, map[string]uint32{}) require.NoError(t, err) require.True(t, exists) require.Equal(t, affiliate, affiliateAddr) @@ -334,6 +334,16 @@ func TestUpdateAffiliateTiers(t *testing.T) { }, expectedError: types.ErrInvalidAffiliateTiers, }, + { + name: "Taker fee share ppm greater than cap", + affiliateTiers: types.AffiliateTiers{ + Tiers: []types.AffiliateTiers_Tier{ + {ReqReferredVolumeQuoteQuantums: 1000, ReqStakedWholeCoins: 100, TakerFeeSharePpm: 100}, + {ReqReferredVolumeQuoteQuantums: 2000, ReqStakedWholeCoins: 200, TakerFeeSharePpm: 550_000}, // 55% + }, + }, + expectedError: types.ErrRevShareSafetyViolation, + }, } for _, tc := range tests { @@ -394,6 +404,241 @@ func TestRegisterAffiliateEmitEvent(t *testing.T) { require.Equal(t, expectedEvent, events[0]) } +func TestSetAffiliateWhitelist(t *testing.T) { + ctx, k, _, _ := keepertest.AffiliatesKeepers(t, true) + + testCases := []struct { + name string + whitelist types.AffiliateWhitelist + expectedError error + }{ + { + name: "Single tier with single address", + whitelist: types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 100_000, // 10% + }, + }, + }, + expectedError: nil, + }, + { + name: "Multiple tiers with multiple addresses", + whitelist: types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String(), constants.BobAccAddress.String()}, + TakerFeeSharePpm: 200_000, // 20% + }, + { + Addresses: []string{constants.CarlAccAddress.String()}, + TakerFeeSharePpm: 300_000, // 30% + }, + }, + }, + expectedError: nil, + }, + { + name: "Duplicate address across tiers", + whitelist: types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 250_000, // 25% + }, + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 350_000, // 35% + }, + }, + }, + expectedError: types.ErrDuplicateAffiliateAddressForWhitelist, + }, + { + name: "Taker fee share ppm greater than cap", + whitelist: types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 550_000, // 55% + }, + }, + }, + expectedError: types.ErrRevShareSafetyViolation, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := k.SetAffiliateWhitelist(ctx, tc.whitelist) + if tc.expectedError != nil { + require.ErrorIs(t, err, tc.expectedError) + } else { + require.NoError(t, err) + + storedWhitelist, err := k.GetAffiliateWhitelist(ctx) + require.NoError(t, err) + require.Equal(t, tc.whitelist, storedWhitelist) + } + }) + } +} + +func TestGetAffiliateWhiteListMap(t *testing.T) { + testCases := []struct { + name string + whitelist *types.AffiliateWhitelist + expectedLength int + expectedMap map[string]uint32 + }{ + { + name: "Multiple tiers with multiple addresses", + whitelist: &types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String(), constants.CarlAccAddress.String()}, + TakerFeeSharePpm: 100_000, // 10% + }, + { + Addresses: []string{constants.BobAccAddress.String()}, + TakerFeeSharePpm: 200_000, // 20% + }, + }, + }, + expectedLength: 3, + expectedMap: map[string]uint32{ + constants.AliceAccAddress.String(): 100_000, // 10% + constants.CarlAccAddress.String(): 100_000, // 10% + constants.BobAccAddress.String(): 200_000, // 20% + }, + }, + { + name: "Single tier with single address", + whitelist: &types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 150_000, // 15% + }, + }, + }, + expectedLength: 1, + expectedMap: map[string]uint32{ + constants.AliceAccAddress.String(): 150_000, // 15% + }, + }, + { + name: "Empty tiers", + whitelist: &types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{}, + }, + expectedLength: 0, + expectedMap: map[string]uint32{}, + }, + { + name: "tiers not set", + whitelist: nil, + expectedLength: 0, + expectedMap: map[string]uint32{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, k, _, _ := keepertest.AffiliatesKeepers(t, true) + + if tc.whitelist != nil { + err := k.SetAffiliateWhitelist(ctx, *tc.whitelist) + require.NoError(t, err) + } + + whitelistMap, err := k.GetAffiliateWhitelistMap(ctx) + require.NoError(t, err) + require.Equal(t, tc.expectedLength, len(whitelistMap)) + require.Equal(t, tc.expectedMap, whitelistMap) + }) + } +} + +func TestGetTakerFeeShareViaWhitelist(t *testing.T) { + tiers := types.DefaultAffiliateTiers + + testCases := []struct { + name string + affiliateAddr string + refereeAddr string + whitelist *types.AffiliateWhitelist + expectedFeeSharePpm uint32 + expectedExists bool + }{ + { + name: "Affiliate in whitelist", + affiliateAddr: constants.AliceAccAddress.String(), + refereeAddr: constants.BobAccAddress.String(), + whitelist: &types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 400_000, // 40% + }, + }, + }, + expectedFeeSharePpm: 400_000, // 40% + expectedExists: true, + }, + { + name: "Affiliate not in whitelist", + affiliateAddr: constants.AliceAccAddress.String(), + refereeAddr: constants.BobAccAddress.String(), + whitelist: &types.AffiliateWhitelist{}, + expectedFeeSharePpm: tiers.Tiers[0].TakerFeeSharePpm, + expectedExists: true, + }, + { + name: "Referee not registered", + affiliateAddr: "", + refereeAddr: constants.BobAccAddress.String(), + whitelist: &types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 400_000, // 40% + }, + }, + }, + expectedFeeSharePpm: 0, + expectedExists: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx, k, _, _ := keepertest.AffiliatesKeepers(t, true) + err := k.UpdateAffiliateTiers(ctx, tiers) + require.NoError(t, err) + + if tc.whitelist != nil { + err := k.SetAffiliateWhitelist(ctx, *tc.whitelist) + require.NoError(t, err) + } + if tc.affiliateAddr != "" { + err := k.RegisterAffiliate(ctx, tc.refereeAddr, tc.affiliateAddr) + require.NoError(t, err) + } + affiliateWhitelistMap, err := k.GetAffiliateWhitelistMap(ctx) + require.NoError(t, err) + + affiliateAddr, feeSharePpm, exists, err := k.GetTakerFeeShare(ctx, tc.refereeAddr, affiliateWhitelistMap) + require.NoError(t, err) + require.Equal(t, tc.affiliateAddr, affiliateAddr) + require.Equal(t, tc.expectedFeeSharePpm, feeSharePpm) + require.Equal(t, tc.expectedExists, exists) + }) + } +} + func TestAggregateAffiliateReferredVolumeForFills(t *testing.T) { affiliate := constants.AliceAccAddress.String() referee1 := constants.BobAccAddress.String() diff --git a/protocol/x/affiliates/keeper/msg_server.go b/protocol/x/affiliates/keeper/msg_server.go index 6a308ca37e..49be6ed542 100644 --- a/protocol/x/affiliates/keeper/msg_server.go +++ b/protocol/x/affiliates/keeper/msg_server.go @@ -43,8 +43,13 @@ func (k msgServer) UpdateAffiliateTiers(ctx context.Context, return nil, err } marketMapperRevShareParams := k.revShareKeeper.GetMarketMapperRevenueShareParams(sdkCtx) + affiliateWhitelist, err := k.GetAffiliateWhitelist(sdkCtx) + if err != nil { + return nil, err + } - if !k.revShareKeeper.ValidateRevShareSafety(msg.Tiers, unconditionalRevShareConfig, marketMapperRevShareParams) { + if !k.revShareKeeper.ValidateRevShareSafety(msg.Tiers, + unconditionalRevShareConfig, marketMapperRevShareParams, affiliateWhitelist) { return nil, errorsmod.Wrapf( types.ErrRevShareSafetyViolation, "rev share safety violation", @@ -59,6 +64,40 @@ func (k msgServer) UpdateAffiliateTiers(ctx context.Context, return &types.MsgUpdateAffiliateTiersResponse{}, nil } +func (k msgServer) UpdateAffiliateWhitelist(ctx context.Context, + msg *types.MsgUpdateAffiliateWhitelist) (*types.MsgUpdateAffiliateWhitelistResponse, error) { + if !k.Keeper.HasAuthority(msg.Authority) { + return nil, errors.New("invalid authority") + } + + sdkCtx := sdk.UnwrapSDKContext(ctx) + unconditionalRevShareConfig, err := k.revShareKeeper.GetUnconditionalRevShareConfigParams(sdkCtx) + if err != nil { + return nil, err + } + marketMapperRevShareParams := k.revShareKeeper.GetMarketMapperRevenueShareParams(sdkCtx) + + affiliateTiers, err := k.Keeper.GetAllAffiliateTiers(sdkCtx) + if err != nil { + return nil, err + } + + if !k.revShareKeeper.ValidateRevShareSafety(affiliateTiers, + unconditionalRevShareConfig, marketMapperRevShareParams, msg.Whitelist) { + return nil, errorsmod.Wrapf( + types.ErrRevShareSafetyViolation, + "rev share safety violation", + ) + } + + err = k.Keeper.SetAffiliateWhitelist(sdk.UnwrapSDKContext(ctx), msg.Whitelist) + if err != nil { + return nil, err + } + + return &types.MsgUpdateAffiliateWhitelistResponse{}, nil +} + // NewMsgServerImpl returns an implementation of the MsgServer interface // for the provided Keeper. func NewMsgServerImpl(keeper Keeper) types.MsgServer { diff --git a/protocol/x/affiliates/keeper/msg_server_test.go b/protocol/x/affiliates/keeper/msg_server_test.go index 404b554391..3d7db09e0c 100644 --- a/protocol/x/affiliates/keeper/msg_server_test.go +++ b/protocol/x/affiliates/keeper/msg_server_test.go @@ -132,3 +132,51 @@ func TestMsgServer_UpdateAffiliateTiers(t *testing.T) { }) } } + +func TestMsgServer_UpdateAffiliateWhitelist(t *testing.T) { + whitelist := types.AffiliateWhitelist{ + Tiers: []types.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 100_000, // 10% + }, + }, + } + testCases := []struct { + name string + msg *types.MsgUpdateAffiliateWhitelist + expectErr bool + }{ + { + name: "Gov module updates whitelist", + msg: &types.MsgUpdateAffiliateWhitelist{ + Authority: lib.GovModuleAddress.String(), + Whitelist: whitelist, + }, + }, + { + name: "non-gov module updates whitelist", + msg: &types.MsgUpdateAffiliateWhitelist{ + Authority: constants.BobAccAddress.String(), + Whitelist: whitelist, + }, + expectErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + k, ms, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + _, err := ms.UpdateAffiliateWhitelist(ctx, tc.msg) + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + whitelist, err := k.GetAffiliateWhitelist(sdkCtx) + require.NoError(t, err) + require.Equal(t, tc.msg.Whitelist, whitelist) + } + }) + } +} diff --git a/protocol/x/affiliates/types/constants.go b/protocol/x/affiliates/types/constants.go index 8bc90aa892..788eb5c9e0 100644 --- a/protocol/x/affiliates/types/constants.go +++ b/protocol/x/affiliates/types/constants.go @@ -25,4 +25,6 @@ var ( }, }, } + + AffiliatesRevSharePpmCap = uint32(500_000) // 50% ) diff --git a/protocol/x/affiliates/types/errors.go b/protocol/x/affiliates/types/errors.go index 0888d35da7..0beb21aa6a 100644 --- a/protocol/x/affiliates/types/errors.go +++ b/protocol/x/affiliates/types/errors.go @@ -6,8 +6,12 @@ var ( ErrAffiliateAlreadyExistsForReferee = errorsmod.Register(ModuleName, 1, "Affiliate already exists for referee") ErrAffiliateTiersNotInitialized = errorsmod.Register(ModuleName, 2, "Affiliate tier data not found") ErrInvalidAffiliateTiers = errorsmod.Register(ModuleName, 3, "Invalid affiliate tier data") - ErrUpdatingAffiliateReferredVolume = errorsmod.Register(ModuleName, 4, "Error updating affiliate referred volume") - ErrInvalidAddress = errorsmod.Register(ModuleName, 5, "Invalid address") - ErrAffiliateNotFound = errorsmod.Register(ModuleName, 6, "Affiliate not found") - ErrRevShareSafetyViolation = errorsmod.Register(ModuleName, 7, "Rev share safety violation") + ErrUpdatingAffiliateReferredVolume = errorsmod.Register( + ModuleName, 4, "Error updating affiliate referred volume") + ErrInvalidAddress = errorsmod.Register(ModuleName, 5, "Invalid address") + ErrAffiliateNotFound = errorsmod.Register(ModuleName, 6, "Affiliate not found") + ErrRevShareSafetyViolation = errorsmod.Register( + ModuleName, 7, "Rev share safety violation") + ErrDuplicateAffiliateAddressForWhitelist = errorsmod.Register( + ModuleName, 8, "Duplicate affiliate address for whitelist") ) diff --git a/protocol/x/affiliates/types/expected_keepers.go b/protocol/x/affiliates/types/expected_keepers.go index 1a60b73dfc..eca412e1de 100644 --- a/protocol/x/affiliates/types/expected_keepers.go +++ b/protocol/x/affiliates/types/expected_keepers.go @@ -22,5 +22,6 @@ type RevShareKeeper interface { affiliateTiers AffiliateTiers, unconditionalRevShareConfig revsharetypes.UnconditionalRevShareConfig, marketMapperRevShareParams revsharetypes.MarketMapperRevenueShareParams, + affiliateWhitelist AffiliateWhitelist, ) bool } diff --git a/protocol/x/affiliates/types/keys.go b/protocol/x/affiliates/types/keys.go index 386ec9b717..0965b44e1d 100644 --- a/protocol/x/affiliates/types/keys.go +++ b/protocol/x/affiliates/types/keys.go @@ -11,9 +11,11 @@ const ( // State const ( - ReferredByKeyPrefix = "ReferredBy:" + ReferredByKeyPrefix = "RB:" - ReferredVolumeKeyPrefix = "ReferredVolume:" + ReferredVolumeKeyPrefix = "RV:" - AffiliateTiersKey = "AffiliateTiers" + AffiliateTiersKey = "AT" + + AffiliateWhitelistKey = "AW" ) diff --git a/protocol/x/affiliates/types/tx.pb.go b/protocol/x/affiliates/types/tx.pb.go index f0a9f62a78..c77dc192c2 100644 --- a/protocol/x/affiliates/types/tx.pb.go +++ b/protocol/x/affiliates/types/tx.pb.go @@ -200,17 +200,111 @@ func (m *MsgUpdateAffiliateTiersResponse) XXX_DiscardUnknown() { var xxx_messageInfo_MsgUpdateAffiliateTiersResponse proto.InternalMessageInfo +// Message to update affiliate whitelist +type MsgUpdateAffiliateWhitelist struct { + // Authority sending this message. Will be sent by gov + Authority string `protobuf:"bytes,1,opt,name=authority,proto3" json:"authority,omitempty"` + // Updated affiliate whitelist information + Whitelist AffiliateWhitelist `protobuf:"bytes,2,opt,name=whitelist,proto3" json:"whitelist"` +} + +func (m *MsgUpdateAffiliateWhitelist) Reset() { *m = MsgUpdateAffiliateWhitelist{} } +func (m *MsgUpdateAffiliateWhitelist) String() string { return proto.CompactTextString(m) } +func (*MsgUpdateAffiliateWhitelist) ProtoMessage() {} +func (*MsgUpdateAffiliateWhitelist) Descriptor() ([]byte, []int) { + return fileDescriptor_41c2f092a0ec6d7f, []int{4} +} +func (m *MsgUpdateAffiliateWhitelist) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *MsgUpdateAffiliateWhitelist) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_MsgUpdateAffiliateWhitelist.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *MsgUpdateAffiliateWhitelist) XXX_Merge(src proto.Message) { + xxx_messageInfo_MsgUpdateAffiliateWhitelist.Merge(m, src) +} +func (m *MsgUpdateAffiliateWhitelist) XXX_Size() int { + return m.Size() +} +func (m *MsgUpdateAffiliateWhitelist) XXX_DiscardUnknown() { + xxx_messageInfo_MsgUpdateAffiliateWhitelist.DiscardUnknown(m) +} + +var xxx_messageInfo_MsgUpdateAffiliateWhitelist proto.InternalMessageInfo + +func (m *MsgUpdateAffiliateWhitelist) GetAuthority() string { + if m != nil { + return m.Authority + } + return "" +} + +func (m *MsgUpdateAffiliateWhitelist) GetWhitelist() AffiliateWhitelist { + if m != nil { + return m.Whitelist + } + return AffiliateWhitelist{} +} + +// Response to MsgUpdateAffiliateWhitelist +type MsgUpdateAffiliateWhitelistResponse struct { +} + +func (m *MsgUpdateAffiliateWhitelistResponse) Reset() { *m = MsgUpdateAffiliateWhitelistResponse{} } +func (m *MsgUpdateAffiliateWhitelistResponse) String() string { return proto.CompactTextString(m) } +func (*MsgUpdateAffiliateWhitelistResponse) ProtoMessage() {} +func (*MsgUpdateAffiliateWhitelistResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_41c2f092a0ec6d7f, []int{5} +} +func (m *MsgUpdateAffiliateWhitelistResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *MsgUpdateAffiliateWhitelistResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_MsgUpdateAffiliateWhitelistResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *MsgUpdateAffiliateWhitelistResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_MsgUpdateAffiliateWhitelistResponse.Merge(m, src) +} +func (m *MsgUpdateAffiliateWhitelistResponse) XXX_Size() int { + return m.Size() +} +func (m *MsgUpdateAffiliateWhitelistResponse) XXX_DiscardUnknown() { + xxx_messageInfo_MsgUpdateAffiliateWhitelistResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_MsgUpdateAffiliateWhitelistResponse proto.InternalMessageInfo + func init() { proto.RegisterType((*MsgRegisterAffiliate)(nil), "dydxprotocol.affiliates.MsgRegisterAffiliate") proto.RegisterType((*MsgRegisterAffiliateResponse)(nil), "dydxprotocol.affiliates.MsgRegisterAffiliateResponse") proto.RegisterType((*MsgUpdateAffiliateTiers)(nil), "dydxprotocol.affiliates.MsgUpdateAffiliateTiers") proto.RegisterType((*MsgUpdateAffiliateTiersResponse)(nil), "dydxprotocol.affiliates.MsgUpdateAffiliateTiersResponse") + proto.RegisterType((*MsgUpdateAffiliateWhitelist)(nil), "dydxprotocol.affiliates.MsgUpdateAffiliateWhitelist") + proto.RegisterType((*MsgUpdateAffiliateWhitelistResponse)(nil), "dydxprotocol.affiliates.MsgUpdateAffiliateWhitelistResponse") } func init() { proto.RegisterFile("dydxprotocol/affiliates/tx.proto", fileDescriptor_41c2f092a0ec6d7f) } var fileDescriptor_41c2f092a0ec6d7f = []byte{ - // 407 bytes of a gzipped FileDescriptorProto + // 473 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x52, 0x48, 0xa9, 0x4c, 0xa9, 0x28, 0x28, 0xca, 0x2f, 0xc9, 0x4f, 0xce, 0xcf, 0xd1, 0x4f, 0x4c, 0x4b, 0xcb, 0xcc, 0xc9, 0x4c, 0x2c, 0x49, 0x2d, 0xd6, 0x2f, 0xa9, 0xd0, 0x03, 0x0b, 0x0b, 0x89, 0x23, 0xab, 0xd0, 0x43, 0xa8, @@ -228,15 +322,19 @@ var fileDescriptor_41c2f092a0ec6d7f = []byte{ 0x69, 0x49, 0x46, 0x7e, 0x51, 0x66, 0x49, 0x25, 0x41, 0xf7, 0x23, 0x94, 0x0a, 0x39, 0x73, 0xb1, 0x96, 0x80, 0x0c, 0x00, 0xbb, 0x9e, 0xdb, 0x48, 0x5d, 0x0f, 0x47, 0x5c, 0xe9, 0xa1, 0xda, 0xe7, 0xc4, 0x72, 0xe2, 0x9e, 0x3c, 0x43, 0x10, 0x44, 0xaf, 0x15, 0x1f, 0xc8, 0x1b, 0x08, 0x43, 0x95, - 0x14, 0xb9, 0xe4, 0x71, 0xb8, 0x13, 0xe6, 0x17, 0xa3, 0x56, 0x26, 0x2e, 0x66, 0xdf, 0xe2, 0x74, - 0xa1, 0x4a, 0x2e, 0x41, 0xcc, 0xa8, 0xd0, 0xc5, 0xe9, 0x0a, 0x6c, 0xe1, 0x23, 0x65, 0x4a, 0x92, - 0x72, 0x98, 0x13, 0x84, 0x9a, 0x18, 0xb9, 0x44, 0xb0, 0x86, 0xa5, 0x01, 0x3e, 0xf3, 0xb0, 0xe9, - 0x90, 0xb2, 0x20, 0x55, 0x07, 0xcc, 0x11, 0x4e, 0x61, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, - 0xc7, 0xf8, 0xe0, 0x91, 0x1c, 0xe3, 0x84, 0xc7, 0x72, 0x0c, 0x17, 0x1e, 0xcb, 0x31, 0xdc, 0x78, - 0x2c, 0xc7, 0x10, 0x65, 0x93, 0x9e, 0x59, 0x92, 0x51, 0x9a, 0xa4, 0x97, 0x9c, 0x9f, 0xab, 0x8f, - 0x92, 0xec, 0xcb, 0x4c, 0x74, 0x93, 0x33, 0x12, 0x33, 0xf3, 0xf4, 0xe1, 0x22, 0x15, 0x28, 0xd9, - 0xae, 0xb2, 0x20, 0xb5, 0x38, 0x89, 0x0d, 0x2c, 0x69, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x02, - 0x66, 0x1f, 0x2f, 0x9e, 0x03, 0x00, 0x00, + 0x14, 0xb9, 0xe4, 0x71, 0xb8, 0x13, 0xee, 0x97, 0x6d, 0x8c, 0x5c, 0xd2, 0x98, 0x6a, 0xc2, 0x33, + 0x32, 0x4b, 0x52, 0x73, 0x32, 0x8b, 0x4b, 0xc8, 0xf6, 0x8f, 0x3f, 0x17, 0x67, 0x39, 0xcc, 0x10, + 0xa8, 0x9f, 0xb4, 0x09, 0xfb, 0x09, 0x6e, 0x2f, 0xd4, 0x5f, 0x08, 0x33, 0x30, 0xfc, 0xa6, 0xca, + 0xa5, 0x8c, 0xc7, 0xdd, 0x30, 0xff, 0x19, 0x4d, 0x63, 0xe6, 0x62, 0xf6, 0x2d, 0x4e, 0x17, 0xaa, + 0xe4, 0x12, 0xc4, 0x4c, 0x6a, 0xba, 0x38, 0x5d, 0x84, 0x2d, 0xfe, 0xa5, 0x4c, 0x49, 0x52, 0x0e, + 0x73, 0x82, 0x50, 0x13, 0x23, 0x97, 0x08, 0xd6, 0xb4, 0x62, 0x80, 0xcf, 0x3c, 0x6c, 0x3a, 0xa4, + 0x2c, 0x48, 0xd5, 0x01, 0x77, 0x44, 0x1f, 0x23, 0x97, 0x04, 0xce, 0x48, 0x36, 0x21, 0xc1, 0x58, + 0xb8, 0x2e, 0x29, 0x1b, 0x72, 0x74, 0xc1, 0x1c, 0xe4, 0x14, 0x76, 0xe2, 0x91, 0x1c, 0xe3, 0x85, + 0x47, 0x72, 0x8c, 0x0f, 0x1e, 0xc9, 0x31, 0x4e, 0x78, 0x2c, 0xc7, 0x70, 0xe1, 0xb1, 0x1c, 0xc3, + 0x8d, 0xc7, 0x72, 0x0c, 0x51, 0x36, 0xe9, 0x99, 0x25, 0x19, 0xa5, 0x49, 0x7a, 0xc9, 0xf9, 0xb9, + 0xfa, 0x28, 0xe5, 0x4c, 0x99, 0x89, 0x6e, 0x72, 0x46, 0x62, 0x66, 0x9e, 0x3e, 0x5c, 0xa4, 0x02, + 0xa5, 0x9c, 0xab, 0x2c, 0x48, 0x2d, 0x4e, 0x62, 0x03, 0x4b, 0x1a, 0x03, 0x02, 0x00, 0x00, 0xff, + 0xff, 0x18, 0x7b, 0xfd, 0xff, 0x0f, 0x05, 0x00, 0x00, } // Reference imports to suppress errors if they are not otherwise used. @@ -255,6 +353,8 @@ type MsgClient interface { RegisterAffiliate(ctx context.Context, in *MsgRegisterAffiliate, opts ...grpc.CallOption) (*MsgRegisterAffiliateResponse, error) // UpdateAffiliateTiers updates affiliate tiers UpdateAffiliateTiers(ctx context.Context, in *MsgUpdateAffiliateTiers, opts ...grpc.CallOption) (*MsgUpdateAffiliateTiersResponse, error) + // UpdateAffiliateWhitelist updates affiliate whitelist + UpdateAffiliateWhitelist(ctx context.Context, in *MsgUpdateAffiliateWhitelist, opts ...grpc.CallOption) (*MsgUpdateAffiliateWhitelistResponse, error) } type msgClient struct { @@ -283,12 +383,23 @@ func (c *msgClient) UpdateAffiliateTiers(ctx context.Context, in *MsgUpdateAffil return out, nil } +func (c *msgClient) UpdateAffiliateWhitelist(ctx context.Context, in *MsgUpdateAffiliateWhitelist, opts ...grpc.CallOption) (*MsgUpdateAffiliateWhitelistResponse, error) { + out := new(MsgUpdateAffiliateWhitelistResponse) + err := c.cc.Invoke(ctx, "/dydxprotocol.affiliates.Msg/UpdateAffiliateWhitelist", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // MsgServer is the server API for Msg service. type MsgServer interface { // RegisterAffiliate registers a referee-affiliate relationship RegisterAffiliate(context.Context, *MsgRegisterAffiliate) (*MsgRegisterAffiliateResponse, error) // UpdateAffiliateTiers updates affiliate tiers UpdateAffiliateTiers(context.Context, *MsgUpdateAffiliateTiers) (*MsgUpdateAffiliateTiersResponse, error) + // UpdateAffiliateWhitelist updates affiliate whitelist + UpdateAffiliateWhitelist(context.Context, *MsgUpdateAffiliateWhitelist) (*MsgUpdateAffiliateWhitelistResponse, error) } // UnimplementedMsgServer can be embedded to have forward compatible implementations. @@ -301,6 +412,9 @@ func (*UnimplementedMsgServer) RegisterAffiliate(ctx context.Context, req *MsgRe func (*UnimplementedMsgServer) UpdateAffiliateTiers(ctx context.Context, req *MsgUpdateAffiliateTiers) (*MsgUpdateAffiliateTiersResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method UpdateAffiliateTiers not implemented") } +func (*UnimplementedMsgServer) UpdateAffiliateWhitelist(ctx context.Context, req *MsgUpdateAffiliateWhitelist) (*MsgUpdateAffiliateWhitelistResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method UpdateAffiliateWhitelist not implemented") +} func RegisterMsgServer(s grpc1.Server, srv MsgServer) { s.RegisterService(&_Msg_serviceDesc, srv) @@ -342,6 +456,24 @@ func _Msg_UpdateAffiliateTiers_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } +func _Msg_UpdateAffiliateWhitelist_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(MsgUpdateAffiliateWhitelist) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(MsgServer).UpdateAffiliateWhitelist(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dydxprotocol.affiliates.Msg/UpdateAffiliateWhitelist", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(MsgServer).UpdateAffiliateWhitelist(ctx, req.(*MsgUpdateAffiliateWhitelist)) + } + return interceptor(ctx, in, info, handler) +} + var _Msg_serviceDesc = grpc.ServiceDesc{ ServiceName: "dydxprotocol.affiliates.Msg", HandlerType: (*MsgServer)(nil), @@ -354,6 +486,10 @@ var _Msg_serviceDesc = grpc.ServiceDesc{ MethodName: "UpdateAffiliateTiers", Handler: _Msg_UpdateAffiliateTiers_Handler, }, + { + MethodName: "UpdateAffiliateWhitelist", + Handler: _Msg_UpdateAffiliateWhitelist_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "dydxprotocol/affiliates/tx.proto", @@ -482,6 +618,69 @@ func (m *MsgUpdateAffiliateTiersResponse) MarshalToSizedBuffer(dAtA []byte) (int return len(dAtA) - i, nil } +func (m *MsgUpdateAffiliateWhitelist) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *MsgUpdateAffiliateWhitelist) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *MsgUpdateAffiliateWhitelist) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + { + size, err := m.Whitelist.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintTx(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + if len(m.Authority) > 0 { + i -= len(m.Authority) + copy(dAtA[i:], m.Authority) + i = encodeVarintTx(dAtA, i, uint64(len(m.Authority))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *MsgUpdateAffiliateWhitelistResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *MsgUpdateAffiliateWhitelistResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *MsgUpdateAffiliateWhitelistResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + return len(dAtA) - i, nil +} + func encodeVarintTx(dAtA []byte, offset int, v uint64) int { offset -= sovTx(v) base := offset @@ -543,6 +742,30 @@ func (m *MsgUpdateAffiliateTiersResponse) Size() (n int) { return n } +func (m *MsgUpdateAffiliateWhitelist) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Authority) + if l > 0 { + n += 1 + l + sovTx(uint64(l)) + } + l = m.Whitelist.Size() + n += 1 + l + sovTx(uint64(l)) + return n +} + +func (m *MsgUpdateAffiliateWhitelistResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + return n +} + func sovTx(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 } @@ -878,6 +1101,171 @@ func (m *MsgUpdateAffiliateTiersResponse) Unmarshal(dAtA []byte) error { } return nil } +func (m *MsgUpdateAffiliateWhitelist) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTx + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: MsgUpdateAffiliateWhitelist: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: MsgUpdateAffiliateWhitelist: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Authority", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTx + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthTx + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthTx + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Authority = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Whitelist", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTx + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthTx + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthTx + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if err := m.Whitelist.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipTx(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthTx + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *MsgUpdateAffiliateWhitelistResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTx + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: MsgUpdateAffiliateWhitelistResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: MsgUpdateAffiliateWhitelistResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + default: + iNdEx = preIndex + skippy, err := skipTx(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthTx + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func skipTx(dAtA []byte) (n int, err error) { l := len(dAtA) iNdEx := 0 diff --git a/protocol/x/clob/keeper/keeper.go b/protocol/x/clob/keeper/keeper.go index 7b530258ba..8c39342ca6 100644 --- a/protocol/x/clob/keeper/keeper.go +++ b/protocol/x/clob/keeper/keeper.go @@ -42,6 +42,7 @@ type ( pricesKeeper types.PricesKeeper statsKeeper types.StatsKeeper rewardsKeeper types.RewardsKeeper + affiliatesKeeper types.AffiliatesKeeper revshareKeeper types.RevShareKeeper indexerEventManager indexer_manager.IndexerEventManager @@ -86,6 +87,7 @@ func NewKeeper( pricesKeeper types.PricesKeeper, statsKeeper types.StatsKeeper, rewardsKeeper types.RewardsKeeper, + affiliatesKeeper types.AffiliatesKeeper, indexerEventManager indexer_manager.IndexerEventManager, streamingManager streamingtypes.FullNodeStreamingManager, txDecoder sdk.TxDecoder, @@ -111,6 +113,7 @@ func NewKeeper( pricesKeeper: pricesKeeper, statsKeeper: statsKeeper, rewardsKeeper: rewardsKeeper, + affiliatesKeeper: affiliatesKeeper, indexerEventManager: indexerEventManager, streamingManager: streamingManager, memStoreInitialized: &atomic.Bool{}, // False by default. diff --git a/protocol/x/clob/keeper/process_operations.go b/protocol/x/clob/keeper/process_operations.go index 5dc706e9d0..026d9d316c 100644 --- a/protocol/x/clob/keeper/process_operations.go +++ b/protocol/x/clob/keeper/process_operations.go @@ -136,6 +136,7 @@ func (k Keeper) ProcessInternalOperations( // All short term orders in this map have passed validation. placedShortTermOrders := make(map[types.OrderId]types.Order, 0) + var affiliatesWhitelistMap map[string]uint32 = nil // Write the matches to state if all stateful validation passes. for _, operation := range operations { if err := k.validateInternalOperationAgainstClobPairStatus(ctx, operation); err != nil { @@ -144,8 +145,21 @@ func (k Keeper) ProcessInternalOperations( switch castedOperation := operation.Operation.(type) { case *types.InternalOperation_Match: + // check if affiliate whitelist map is nil and initialize it if it is. + // This is done to avoid getting whitelist map on list of operations + // where there are no matches. + if affiliatesWhitelistMap == nil { + var err error + affiliatesWhitelistMap, err = k.affiliatesKeeper.GetAffiliateWhitelistMap(ctx) + if err != nil { + return errorsmod.Wrapf( + err, + "ProcessInternalOperations: Failed to get affiliates whitelist map", + ) + } + } clobMatch := castedOperation.Match - if err := k.PersistMatchToState(ctx, clobMatch, placedShortTermOrders); err != nil { + if err := k.PersistMatchToState(ctx, clobMatch, placedShortTermOrders, affiliatesWhitelistMap); err != nil { return errorsmod.Wrapf( err, "ProcessInternalOperations: Failed to process clobMatch: %+v", @@ -201,10 +215,11 @@ func (k Keeper) PersistMatchToState( ctx sdk.Context, clobMatch *types.ClobMatch, ordersMap map[types.OrderId]types.Order, + affiliatesWhitelistMap map[string]uint32, ) error { switch castedMatch := clobMatch.Match.(type) { case *types.ClobMatch_MatchOrders: - if err := k.PersistMatchOrdersToState(ctx, castedMatch.MatchOrders, ordersMap); err != nil { + if err := k.PersistMatchOrdersToState(ctx, castedMatch.MatchOrders, ordersMap, affiliatesWhitelistMap); err != nil { return err } case *types.ClobMatch_MatchPerpetualLiquidation: @@ -212,6 +227,7 @@ func (k Keeper) PersistMatchToState( ctx, castedMatch.MatchPerpetualLiquidation, ordersMap, + affiliatesWhitelistMap, ); err != nil { return err } @@ -442,6 +458,7 @@ func (k Keeper) PersistMatchOrdersToState( ctx sdk.Context, matchOrders *types.MatchOrders, ordersMap map[types.OrderId]types.Order, + affiliatesWhitelistMap map[string]uint32, ) error { takerOrderId := matchOrders.GetTakerOrderId() // Fetch the taker order from either short term orders or state @@ -486,7 +503,7 @@ func (k Keeper) PersistMatchOrdersToState( } makerOrders = append(makerOrders, makerOrder) - _, _, _, affiliateRevSharesQuoteQuantums, err := k.ProcessSingleMatch(ctx, &matchWithOrders) + _, _, _, affiliateRevSharesQuoteQuantums, err := k.ProcessSingleMatch(ctx, &matchWithOrders, affiliatesWhitelistMap) if err != nil { return err } @@ -558,6 +575,7 @@ func (k Keeper) PersistMatchLiquidationToState( ctx sdk.Context, matchLiquidation *types.MatchPerpetualLiquidation, ordersMap map[types.OrderId]types.Order, + affiliatesWhitelistMap map[string]uint32, ) error { // If the subaccount is not liquidatable, do nothing. if err := k.EnsureIsLiquidatable(ctx, matchLiquidation.Liquidated); err != nil { @@ -598,6 +616,7 @@ func (k Keeper) PersistMatchLiquidationToState( _, _, _, affiliateRevSharesQuoteQuantums, err := k.ProcessSingleMatch( ctx, &matchWithOrders, + affiliatesWhitelistMap, ) if err != nil { return err diff --git a/protocol/x/clob/keeper/process_single_match.go b/protocol/x/clob/keeper/process_single_match.go index e35edf5f61..509a5ccfc0 100644 --- a/protocol/x/clob/keeper/process_single_match.go +++ b/protocol/x/clob/keeper/process_single_match.go @@ -41,6 +41,7 @@ import ( func (k Keeper) ProcessSingleMatch( ctx sdk.Context, matchWithOrders *types.MatchWithOrders, + affiliatesWhitelistMap map[string]uint32, ) ( success bool, takerUpdateResult satypes.UpdateResult, @@ -224,6 +225,7 @@ func (k Keeper) ProcessSingleMatch( makerFeePpm, bigFillQuoteQuantums, takerInsuranceFundDelta, + affiliatesWhitelistMap, ) if err != nil { @@ -307,6 +309,7 @@ func (k Keeper) persistMatchedOrders( makerFeePpm int32, bigFillQuoteQuantums *big.Int, insuranceFundDelta *big.Int, + affiliatesWhitelistMap map[string]uint32, ) ( takerUpdateResult satypes.UpdateResult, makerUpdateResult satypes.UpdateResult, @@ -460,7 +463,7 @@ func (k Keeper) persistMatchedOrders( // Distribute the fee amount from subacounts module to fee collector and rev share accounts bigTotalFeeQuoteQuantums := new(big.Int).Add(bigTakerFeeQuoteQuantums, bigMakerFeeQuoteQuantums) - revSharesForFill, err := k.revshareKeeper.GetAllRevShares(ctx, fillForProcess) + revSharesForFill, err := k.revshareKeeper.GetAllRevShares(ctx, fillForProcess, affiliatesWhitelistMap) if revSharesForFill.AffiliateRevShare != nil { affiliateRevSharesQuoteQuantums = revSharesForFill.AffiliateRevShare.QuoteQuantums diff --git a/protocol/x/clob/memclob/memclob.go b/protocol/x/clob/memclob/memclob.go index 9aaefaba05..d6d1e08774 100644 --- a/protocol/x/clob/memclob/memclob.go +++ b/protocol/x/clob/memclob/memclob.go @@ -1696,7 +1696,12 @@ func (m *MemClobPriceTimePriority) mustPerformTakerOrderMatching( FillAmount: matchedAmount, } - success, takerUpdateResult, makerUpdateResult, _, err := m.clobKeeper.ProcessSingleMatch(ctx, &matchWithOrders) + // Pass in empty map to avoid reading `AffiliateWhitelist` from state in every `CheckTx`. This deviates + // from `DeliverTx` which accounts for affiliate whitelist correctly. This deviation is ok because rev + // shares/fees are distributed to the recipient’s bank balance and not settled at the subaccount level, + // and won’t affect the collateralization of future operations in the operations queue. + success, takerUpdateResult, makerUpdateResult, _, err := m.clobKeeper.ProcessSingleMatch( + ctx, &matchWithOrders, map[string]uint32{}) if err != nil && !errors.Is(err, satypes.ErrFailedToUpdateSubaccounts) { if errors.Is(err, types.ErrLiquidationExceedsSubaccountMaxInsuranceLost) { // Subaccount has reached max insurance lost block limit. Stop matching. diff --git a/protocol/x/clob/types/clob_keeper.go b/protocol/x/clob/types/clob_keeper.go index 2d3067b7d7..f5bbe110e0 100644 --- a/protocol/x/clob/types/clob_keeper.go +++ b/protocol/x/clob/types/clob_keeper.go @@ -81,6 +81,7 @@ type ClobKeeper interface { ProcessSingleMatch( ctx sdk.Context, matchWithOrders *MatchWithOrders, + affiliatesWhitelistMap map[string]uint32, ) ( success bool, takerUpdateResult satypes.UpdateResult, diff --git a/protocol/x/clob/types/expected_keepers.go b/protocol/x/clob/types/expected_keepers.go index c0f7d7bb22..872daf2905 100644 --- a/protocol/x/clob/types/expected_keepers.go +++ b/protocol/x/clob/types/expected_keepers.go @@ -175,7 +175,12 @@ type RevShareKeeper interface { GetAllRevShares( ctx sdk.Context, fill FillForProcess, + affiliateWhitelistMap map[string]uint32, ) ( revsharetypes.RevSharesForFill, error, ) } + +type AffiliatesKeeper interface { + GetAffiliateWhitelistMap(ctx sdk.Context) (map[string]uint32, error) +} diff --git a/protocol/x/clob/types/mem_clob_keeper.go b/protocol/x/clob/types/mem_clob_keeper.go index a90418ec04..d555367718 100644 --- a/protocol/x/clob/types/mem_clob_keeper.go +++ b/protocol/x/clob/types/mem_clob_keeper.go @@ -21,6 +21,7 @@ type MemClobKeeper interface { ProcessSingleMatch( ctx sdk.Context, matchWithOrders *MatchWithOrders, + affiliatesWhitelistMap map[string]uint32, ) ( success bool, takerUpdateResult satypes.UpdateResult, diff --git a/protocol/x/revshare/keeper/msg_set_marketmapper_revenue_share.go b/protocol/x/revshare/keeper/msg_set_marketmapper_revenue_share.go index d53e516d3a..72d139b62a 100644 --- a/protocol/x/revshare/keeper/msg_set_marketmapper_revenue_share.go +++ b/protocol/x/revshare/keeper/msg_set_marketmapper_revenue_share.go @@ -33,7 +33,12 @@ func (k msgServer) SetMarketMapperRevenueShare( if err != nil { return nil, err } - if !k.ValidateRevShareSafety(affiliateTiers, unconditionalRevShareConfig, msg.Params) { + affiliateWhitelist, err := k.affiliatesKeeper.GetAffiliateWhitelist(ctx) + if err != nil { + return nil, err + } + + if !k.ValidateRevShareSafety(affiliateTiers, unconditionalRevShareConfig, msg.Params, affiliateWhitelist) { return nil, errorsmod.Wrapf( types.ErrRevShareSafetyViolation, "rev share safety violation", diff --git a/protocol/x/revshare/keeper/msg_update_unconditional_revshare_config.go b/protocol/x/revshare/keeper/msg_update_unconditional_revshare_config.go index 0740becf46..538126352c 100644 --- a/protocol/x/revshare/keeper/msg_update_unconditional_revshare_config.go +++ b/protocol/x/revshare/keeper/msg_update_unconditional_revshare_config.go @@ -33,9 +33,12 @@ func (k msgServer) UpdateUnconditionalRevShareConfig( if err != nil { return nil, err } - + affiliateWhitelist, err := k.affiliatesKeeper.GetAffiliateWhitelist(ctx) + if err != nil { + return nil, err + } marketMapperRevShareParams := k.GetMarketMapperRevenueShareParams(ctx) - if !k.ValidateRevShareSafety(affiliateTiers, msg.Config, marketMapperRevShareParams) { + if !k.ValidateRevShareSafety(affiliateTiers, msg.Config, marketMapperRevShareParams, affiliateWhitelist) { return nil, errorsmod.Wrapf( types.ErrRevShareSafetyViolation, "rev share safety violation", diff --git a/protocol/x/revshare/keeper/revshare.go b/protocol/x/revshare/keeper/revshare.go index aa5a1df2c5..5a50f4efcd 100644 --- a/protocol/x/revshare/keeper/revshare.go +++ b/protocol/x/revshare/keeper/revshare.go @@ -132,10 +132,16 @@ func (k Keeper) ValidateRevShareSafety( affiliateTiers affiliatetypes.AffiliateTiers, unconditionalRevShareConfig types.UnconditionalRevShareConfig, marketMapperRevShareParams types.MarketMapperRevenueShareParams, + affiliateWhitelist affiliatetypes.AffiliateWhitelist, ) bool { - highestTierRevSharePpm := uint32(0) + highestAffilliateTierRevSharePpm := uint32(0) if len(affiliateTiers.Tiers) > 0 { - highestTierRevSharePpm = affiliateTiers.Tiers[len(affiliateTiers.Tiers)-1].TakerFeeSharePpm + highestAffilliateTierRevSharePpm = affiliateTiers.Tiers[len(affiliateTiers.Tiers)-1].TakerFeeSharePpm + } + for _, tier := range affiliateWhitelist.Tiers { + if tier.TakerFeeSharePpm > highestAffilliateTierRevSharePpm { + highestAffilliateTierRevSharePpm = tier.TakerFeeSharePpm + } } totalUnconditionalRevSharePpm := uint32(0) for _, recipientConfig := range unconditionalRevShareConfig.Configs { @@ -143,13 +149,14 @@ func (k Keeper) ValidateRevShareSafety( } totalMarketMapperRevSharePpm := marketMapperRevShareParams.RevenueSharePpm - totalRevSharePpm := totalUnconditionalRevSharePpm + totalMarketMapperRevSharePpm + highestTierRevSharePpm + totalRevSharePpm := totalUnconditionalRevSharePpm + totalMarketMapperRevSharePpm + highestAffilliateTierRevSharePpm return totalRevSharePpm < lib.OneMillion } func (k Keeper) GetAllRevShares( ctx sdk.Context, fill clobtypes.FillForProcess, + affiliatesWhitelistMap map[string]uint32, ) (types.RevSharesForFill, error) { revShares := []types.RevShare{} feeSourceToQuoteQuantums := make(map[types.RevShareFeeSource]*big.Int) @@ -164,7 +171,7 @@ func (k Keeper) GetAllRevShares( makerFees := fill.MakerFeeQuoteQuantums netFees := big.NewInt(0).Add(takerFees, makerFees) - affiliateRevShares, err := k.getAffiliateRevShares(ctx, fill) + affiliateRevShares, err := k.getAffiliateRevShares(ctx, fill, affiliatesWhitelistMap) if err != nil { return types.RevSharesForFill{}, err } @@ -216,6 +223,7 @@ func (k Keeper) GetAllRevShares( func (k Keeper) getAffiliateRevShares( ctx sdk.Context, fill clobtypes.FillForProcess, + affiliatesWhitelistMap map[string]uint32, ) ([]types.RevShare, error) { takerAddr := fill.TakerAddr takerFee := fill.TakerFeeQuoteQuantums @@ -223,7 +231,8 @@ func (k Keeper) getAffiliateRevShares( return nil, nil } - takerAffiliateAddr, feeSharePpm, exists, err := k.affiliatesKeeper.GetTakerFeeShare(ctx, takerAddr) + takerAffiliateAddr, feeSharePpm, exists, err := k.affiliatesKeeper.GetTakerFeeShare( + ctx, takerAddr, affiliatesWhitelistMap) if err != nil { return nil, err } diff --git a/protocol/x/revshare/keeper/revshare_test.go b/protocol/x/revshare/keeper/revshare_test.go index 0a8a0cf291..6182ead2a0 100644 --- a/protocol/x/revshare/keeper/revshare_test.go +++ b/protocol/x/revshare/keeper/revshare_test.go @@ -178,6 +178,7 @@ func TestValidateRevShareSafety(t *testing.T) { affiliateTiers affiliatetypes.AffiliateTiers revShareConfig types.UnconditionalRevShareConfig marketMapperRevShareParams types.MarketMapperRevenueShareParams + affiliateWhitelist affiliatetypes.AffiliateWhitelist expectedValid bool }{ "valid rev share config": { @@ -195,6 +196,14 @@ func TestValidateRevShareSafety(t *testing.T) { RevenueSharePpm: 100_000, // 10% ValidDays: 0, }, + affiliateWhitelist: affiliatetypes.AffiliateWhitelist{ + Tiers: []affiliatetypes.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 100_000, // 10% + }, + }, + }, expectedValid: true, }, "invalid rev share config - sum of shares > 100%": { @@ -216,6 +225,9 @@ func TestValidateRevShareSafety(t *testing.T) { RevenueSharePpm: 100_000, // 10% ValidDays: 0, }, + affiliateWhitelist: affiliatetypes.AffiliateWhitelist{ + Tiers: []affiliatetypes.AffiliateWhitelist_Tier{}, + }, expectedValid: false, }, "invalid rev share config - sum of shares + highest tier share > 100%": { @@ -250,6 +262,32 @@ func TestValidateRevShareSafety(t *testing.T) { RevenueSharePpm: 100_000, // 10% ValidDays: 0, }, + affiliateWhitelist: affiliatetypes.AffiliateWhitelist{}, + expectedValid: false, + }, + "invalid rev share config - very high whitelist tier share exceeding 100%": { + affiliateTiers: affiliatetypes.DefaultAffiliateTiers, + revShareConfig: types.UnconditionalRevShareConfig{ + Configs: []types.UnconditionalRevShareConfig_RecipientConfig{ + { + Address: constants.AliceAccAddress.String(), + SharePpm: 200_000, // 20% + }, + }, + }, + marketMapperRevShareParams: types.MarketMapperRevenueShareParams{ + Address: constants.AliceAccAddress.String(), + RevenueSharePpm: 200_000, // 20% + ValidDays: 0, + }, + affiliateWhitelist: affiliatetypes.AffiliateWhitelist{ + Tiers: []affiliatetypes.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.AliceAccAddress.String()}, + TakerFeeSharePpm: 700_000, // 70% + }, + }, + }, expectedValid: false, }, } @@ -260,7 +298,12 @@ func TestValidateRevShareSafety(t *testing.T) { _ = tApp.InitChain() k := tApp.App.RevShareKeeper - valid := k.ValidateRevShareSafety(tc.affiliateTiers, tc.revShareConfig, tc.marketMapperRevShareParams) + valid := k.ValidateRevShareSafety( + tc.affiliateTiers, + tc.revShareConfig, + tc.marketMapperRevShareParams, + tc.affiliateWhitelist, + ) require.Equal(t, tc.expectedValid, valid) }) } @@ -649,6 +692,75 @@ func TestKeeper_GetAllRevShares_Valid(t *testing.T) { require.NoError(t, err) }, }, + { + name: "Valid revenue share with whitelisted affiliate and no unconditional rev shares", + expectedRevSharesForFill: types.RevSharesForFill{ + AllRevShares: []types.RevShare{ + { + Recipient: constants.BobAccAddress.String(), + RevShareFeeSource: types.REV_SHARE_FEE_SOURCE_TAKER_FEE, + RevShareType: types.REV_SHARE_TYPE_AFFILIATE, + QuoteQuantums: big.NewInt(2_500_000), + RevSharePpm: 250_000, // 25% + }, + { + Recipient: constants.AliceAccAddress.String(), + RevShareFeeSource: types.REV_SHARE_FEE_SOURCE_NET_FEE, + RevShareType: types.REV_SHARE_TYPE_MARKET_MAPPER, + QuoteQuantums: big.NewInt(1_200_000), + RevSharePpm: 100_000, // 10% + }, + }, + AffiliateRevShare: &types.RevShare{ + Recipient: constants.BobAccAddress.String(), + RevShareFeeSource: types.REV_SHARE_FEE_SOURCE_TAKER_FEE, + RevShareType: types.REV_SHARE_TYPE_AFFILIATE, + QuoteQuantums: big.NewInt(2_500_000), + RevSharePpm: 250_000, // 25% + }, + FeeSourceToQuoteQuantums: map[types.RevShareFeeSource]*big.Int{ + types.REV_SHARE_FEE_SOURCE_NET_FEE: big.NewInt(1_200_000), + types.REV_SHARE_FEE_SOURCE_TAKER_FEE: big.NewInt(2_500_000), + }, + FeeSourceToRevSharePpm: map[types.RevShareFeeSource]uint32{ + types.REV_SHARE_FEE_SOURCE_NET_FEE: 100_000, // 10% + types.REV_SHARE_FEE_SOURCE_TAKER_FEE: 250_000, // 25% + }, + }, + fill: clobtypes.FillForProcess{ + TakerAddr: constants.AliceAccAddress.String(), + TakerFeeQuoteQuantums: big.NewInt(10_000_000), + MakerAddr: constants.BobAccAddress.String(), + MakerFeeQuoteQuantums: big.NewInt(2_000_000), + FillQuoteQuantums: big.NewInt(100_000_000_000), + ProductId: marketId, + MonthlyRollingTakerVolumeQuantums: 1_000_000_000_000, + }, + setup: func(tApp *testapp.TestApp, ctx sdk.Context, keeper *keeper.Keeper, + affiliatesKeeper *affiliateskeeper.Keeper) { + err := keeper.SetMarketMapperRevenueShareParams(ctx, types.MarketMapperRevenueShareParams{ + Address: constants.AliceAccAddress.String(), + RevenueSharePpm: 100_000, // 10% + ValidDays: 1, + }) + require.NoError(t, err) + + err = affiliatesKeeper.UpdateAffiliateTiers(ctx, affiliatetypes.DefaultAffiliateTiers) + require.NoError(t, err) + err = affiliatesKeeper.RegisterAffiliate(ctx, constants.AliceAccAddress.String(), + constants.BobAccAddress.String()) + require.NoError(t, err) + err = affiliatesKeeper.SetAffiliateWhitelist(ctx, affiliatetypes.AffiliateWhitelist{ + Tiers: []affiliatetypes.AffiliateWhitelist_Tier{ + { + Addresses: []string{constants.BobAccAddress.String()}, + TakerFeeSharePpm: 250_000, // 25% + }, + }, + }) + require.NoError(t, err) + }, + }, { name: "No rev shares", expectedRevSharesForFill: types.RevSharesForFill{ @@ -691,8 +803,10 @@ func TestKeeper_GetAllRevShares_Valid(t *testing.T) { } keeper.CreateNewMarketRevShare(ctx, marketId) + affiliateWhitelistMap, err := affiliatesKeeper.GetAffiliateWhitelistMap(ctx) + require.NoError(t, err) - revSharesForFill, err := keeper.GetAllRevShares(ctx, tc.fill) + revSharesForFill, err := keeper.GetAllRevShares(ctx, tc.fill, affiliateWhitelistMap) require.NoError(t, err) require.Equal(t, tc.expectedRevSharesForFill, revSharesForFill) @@ -838,7 +952,7 @@ func TestKeeper_GetAllRevShares_Invalid(t *testing.T) { keeper.CreateNewMarketRevShare(ctx, marketId) - _, err := keeper.GetAllRevShares(ctx, fill) + _, err := keeper.GetAllRevShares(ctx, fill, map[string]uint32{}) require.ErrorIs(t, err, tc.expectedError) }) diff --git a/protocol/x/subaccounts/keeper/transfer_test.go b/protocol/x/subaccounts/keeper/transfer_test.go index 79ea1c5afd..5bb53a217e 100644 --- a/protocol/x/subaccounts/keeper/transfer_test.go +++ b/protocol/x/subaccounts/keeper/transfer_test.go @@ -1642,7 +1642,9 @@ func TestDistributeFees(t *testing.T) { }, }) } - revSharesForFill, err := revShareKeeper.GetAllRevShares(ctx, tc.fill) + affiliateWhitelistMap, err := affiliatesKeeper.GetAffiliateWhitelistMap(ctx) + require.NoError(t, err) + revSharesForFill, err := revShareKeeper.GetAllRevShares(ctx, tc.fill, affiliateWhitelistMap) require.NoError(t, err) err = keeper.DistributeFees(ctx, tc.asset.Id, revSharesForFill, tc.fill)