diff --git a/protocol/x/affiliates/keeper/keeper.go b/protocol/x/affiliates/keeper/keeper.go index ee068a7265..92554b8f1c 100644 --- a/protocol/x/affiliates/keeper/keeper.go +++ b/protocol/x/affiliates/keeper/keeper.go @@ -343,34 +343,42 @@ func (k Keeper) GetAffiliateWhitelist(ctx sdk.Context) (types.AffiliateWhitelist } return affiliateWhitelist, nil } + func (k Keeper) AggregateAffiliateReferredVolumeForFills( ctx sdk.Context, ) error { blockStats := k.statsKeeper.GetBlockStats(ctx) referredByCache := make(map[string]string) + for _, fill := range blockStats.Fills { - // Add taker's referred volume to the cache - if _, ok := referredByCache[fill.Taker]; !ok { - referredByAddrTaker, found := k.GetReferredBy(ctx, fill.Taker) - if !found { - continue + // Process taker's referred volume + referredByAddrTaker, cached := referredByCache[fill.Taker] + if !cached { + var found bool + referredByAddrTaker, found = k.GetReferredBy(ctx, fill.Taker) + if found { + referredByCache[fill.Taker] = referredByAddrTaker } - referredByCache[fill.Taker] = referredByAddrTaker } - if err := k.AddReferredVolume(ctx, referredByCache[fill.Taker], lib.BigU(fill.Notional)); err != nil { - return err + if referredByAddrTaker != "" { + if err := k.AddReferredVolume(ctx, referredByAddrTaker, lib.BigU(fill.Notional)); err != nil { + return err + } } - // Add maker's referred volume to the cache - if _, ok := referredByCache[fill.Maker]; !ok { - referredByAddrMaker, found := k.GetReferredBy(ctx, fill.Maker) - if !found { - continue + // Process maker's referred volume + referredByAddrMaker, cached := referredByCache[fill.Maker] + if !cached { + var found bool + referredByAddrMaker, found = k.GetReferredBy(ctx, fill.Maker) + if found { + referredByCache[fill.Maker] = referredByAddrMaker } - referredByCache[fill.Maker] = referredByAddrMaker } - if err := k.AddReferredVolume(ctx, referredByCache[fill.Maker], lib.BigU(fill.Notional)); err != nil { - return err + if referredByAddrMaker != "" { + if err := k.AddReferredVolume(ctx, referredByAddrMaker, lib.BigU(fill.Notional)); err != nil { + return err + } } } return nil diff --git a/protocol/x/affiliates/keeper/keeper_test.go b/protocol/x/affiliates/keeper/keeper_test.go index aeb9efd139..7a417cbdd6 100644 --- a/protocol/x/affiliates/keeper/keeper_test.go +++ b/protocol/x/affiliates/keeper/keeper_test.go @@ -781,6 +781,29 @@ func TestAggregateAffiliateReferredVolumeForFills(t *testing.T) { }) }, }, + { + name: "2 referrals, takers not referred, maker referred", + referrals: 2, + expectedVolume: big.NewInt(300_000_000_000), + setup: func(t *testing.T, ctx sdk.Context, k *keeper.Keeper, statsKeeper *statskeeper.Keeper) { + err := k.RegisterAffiliate(ctx, maker, affiliate) + require.NoError(t, err) + statsKeeper.SetBlockStats(ctx, &statstypes.BlockStats{ + Fills: []*statstypes.BlockStats_Fill{ + { + Taker: referee1, + Maker: maker, + Notional: 100_000_000_000, + }, + { + Taker: referee2, + Maker: maker, + Notional: 200_000_000_000, + }, + }, + }) + }, + }, } for _, tc := range testCases {