From a0efdc86de983e5455ce9d5b81cf746a11ed1b53 Mon Sep 17 00:00:00 2001 From: Brendan Chou <3680392+BrendanChou@users.noreply.github.com> Date: Wed, 13 Sep 2023 16:34:09 -0400 Subject: [PATCH] genericize median (#235) --- .../pricefeed/market_to_exchange_prices.go | 2 +- protocol/lib/math.go | 36 +++---------- protocol/lib/math_test.go | 51 +++---------------- protocol/lib/medianizer.go | 4 +- .../x/clob/keeper/process_single_match.go | 2 +- protocol/x/perpetuals/keeper/perpetual.go | 7 ++- 6 files changed, 21 insertions(+), 81 deletions(-) diff --git a/protocol/daemons/server/types/pricefeed/market_to_exchange_prices.go b/protocol/daemons/server/types/pricefeed/market_to_exchange_prices.go index ce65880e50..459aac3a75 100644 --- a/protocol/daemons/server/types/pricefeed/market_to_exchange_prices.go +++ b/protocol/daemons/server/types/pricefeed/market_to_exchange_prices.go @@ -104,7 +104,7 @@ func (mte *MarketToExchangePrices) GetValidMedianPrices( // The number of valid prices must be >= min number of exchanges. if len(validPrices) >= int(marketParam.MinExchanges) { // Calculate the median. Returns an error if the input is empty. - median, err := lib.MedianUint64(validPrices) + median, err := lib.Median(validPrices) if err != nil { telemetry.IncrCounterWithLabels( []string{ diff --git a/protocol/lib/math.go b/protocol/lib/math.go index 5808f32311..60fd5ce596 100644 --- a/protocol/lib/math.go +++ b/protocol/lib/math.go @@ -63,14 +63,6 @@ func Min[T constraints.Ordered](x, y T) T { return x } -func MaxUint32(x, y uint32) uint32 { - if x < y { - return y - } - - return x -} - func Int64MulPpm(x int64, ppm uint32) int64 { xMulPpm := BigIntMulPpm(big.NewInt(x), ppm) @@ -163,25 +155,9 @@ func ChangeRateUint64(originalV uint64, newV uint64) (float32, error) { return result, nil } -// MedianUint64 returns the median value of the input slice. Note that if the -// input has an even number of elements, then the returned median is rounded up -// the nearest uint64. For example, 6.5 is rounded up to 7. -func MedianUint64(input []uint64) (uint64, error) { - return medianIntGeneric(input) -} - -// MedianInt32 returns the median value of the input slice. Note that if the -// input has an even number of elements, then the returned median is rounded -// towards positive/negative infinity, to the nearest int32. For example, -// 6.5 is rounded to 7 and -4.5 is rounded to -5. -func MedianInt32(input []int32) (int32, error) { - return medianIntGeneric(input) -} - -// MustGetMedianInt32 is a wrapper around `MedianInt32` that panics if -// input length is zero. -func MustGetMedianInt32(input []int32) int32 { - ret, err := MedianInt32(input) +// MustGetMedian is a wrapper around `Median` that panics if input length is zero. +func MustGetMedian[V uint64 | uint32 | int64 | int32](input []V) V { + ret, err := Median(input) if err != nil { panic(err) @@ -190,9 +166,9 @@ func MustGetMedianInt32(input []int32) int32 { return ret } -// medianIntGeneric is a generic median calculator. -// It currently supports `uint64`, `int32` and more types can be added. -func medianIntGeneric[V uint64 | int32](input []V) (V, error) { +// Median is a generic median calculator. +// If the input has an even number of elements, then the average of the two middle numbers is rounded away from zero. +func Median[V uint64 | uint32 | int64 | int32](input []V) (V, error) { l := len(input) if l == 0 { return 0, errors.New("input cannot be empty") diff --git a/protocol/lib/math_test.go b/protocol/lib/math_test.go index d77560d799..41822aa4b3 100644 --- a/protocol/lib/math_test.go +++ b/protocol/lib/math_test.go @@ -261,41 +261,6 @@ func TestGenericMaxFloat64(t *testing.T) { } } -func TestMaxUInt32(t *testing.T) { - tests := map[string]struct { - x uint32 - y uint32 - expectedResult uint32 - }{ - "Equal": { - x: 5, - y: 5, - expectedResult: 5, - }, - "X is Less": { - x: 4, - y: 5, - expectedResult: 5, - }, - "Y is Less": { - x: 5, - y: 4, - expectedResult: 5, - }, - "Zero": { - x: 0, - y: 0, - expectedResult: 0, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - result := lib.MaxUint32(tc.x, tc.y) - require.Equal(t, tc.expectedResult, result) - }) - } -} - func TestInt64MulPpm(t *testing.T) { tests := map[string]struct { x int64 @@ -600,23 +565,23 @@ func TestChangeRateUint64(t *testing.T) { } } -func TestMustGetMedianInt32_Failure(t *testing.T) { +func TestMustGetMedian_Failure(t *testing.T) { require.PanicsWithError(t, "input cannot be empty", func() { - lib.MustGetMedianInt32([]int32{}) + lib.MustGetMedian([]int32{}) }, ) } -func TestMustGetMedianInt32_Success(t *testing.T) { +func TestMustGetMedian_Success(t *testing.T) { require.Equal(t, int32(5), - lib.MustGetMedianInt32([]int32{8, 1, -5, 100, -50, 59}), + lib.MustGetMedian([]int32{8, 1, -5, 100, -50, 59}), ) } -func TestMedianInt32(t *testing.T) { +func TestMedian_Int32(t *testing.T) { tests := map[string]struct { input []int32 expectedResult int32 @@ -720,7 +685,7 @@ func TestMedianInt32(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - result, err := lib.MedianInt32(tc.input) + result, err := lib.Median(tc.input) require.Equal(t, tc.expectedResult, result) if tc.expectedError { require.EqualError(t, err, "input cannot be empty") @@ -731,7 +696,7 @@ func TestMedianInt32(t *testing.T) { } } -func TestMedianUint64(t *testing.T) { +func TestMedian_Uint64(t *testing.T) { tests := map[string]struct { input []uint64 expectedResult uint64 @@ -770,7 +735,7 @@ func TestMedianUint64(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - result, err := lib.MedianUint64(tc.input) + result, err := lib.Median(tc.input) require.Equal(t, tc.expectedResult, result) if tc.expectedError { require.EqualError(t, err, "input cannot be empty") diff --git a/protocol/lib/medianizer.go b/protocol/lib/medianizer.go index e7ef8f9b6f..05c4875620 100644 --- a/protocol/lib/medianizer.go +++ b/protocol/lib/medianizer.go @@ -11,7 +11,7 @@ type Medianizer interface { MedianUint64(input []uint64) (uint64, error) } -// MedianUint64 wraps `lib.MedianUint64` which gets the median of a uint64 slice. +// MedianUint64 wraps `lib.Median` which gets the median of a uint64 slice. func (r *MedianizerImpl) MedianUint64(input []uint64) (uint64, error) { - return MedianUint64(input) + return Median(input) } diff --git a/protocol/x/clob/keeper/process_single_match.go b/protocol/x/clob/keeper/process_single_match.go index 9d6f7f8955..a1d3beb6ff 100644 --- a/protocol/x/clob/keeper/process_single_match.go +++ b/protocol/x/clob/keeper/process_single_match.go @@ -472,7 +472,7 @@ func (k Keeper) setOrderFillAmountsAndPruning( if !order.IsStatefulOrder() { // Compute the block at which this state fill amount can be pruned. This is the greater of // `GoodTilBlock + ShortBlockWindow` and the existing `pruneableBlockHeight`. - pruneableBlockHeight = lib.MaxUint32( + pruneableBlockHeight = lib.Max( order.GetGoodTilBlock()+types.ShortBlockWindow, curPruneableBlockHeight, ) diff --git a/protocol/x/perpetuals/keeper/perpetual.go b/protocol/x/perpetuals/keeper/perpetual.go index c2220d3021..7c344201af 100644 --- a/protocol/x/perpetuals/keeper/perpetual.go +++ b/protocol/x/perpetuals/keeper/perpetual.go @@ -235,9 +235,8 @@ func (k Keeper) processStoredPremiums( // if block times are longer than expected and hence there were not enough blocks to // collect votes. // Note `NumPremiums >= len(marketPremiums.Premiums)`, so `lenPadding >= 0`. - lenPadding := int64( - lib.MaxUint32(premiumStore.NumPremiums, - minNumPremiumsRequired)) - int64(len(marketPremiums.Premiums)) + lenPadding := int64(lib.Max(premiumStore.NumPremiums, minNumPremiumsRequired)) - + int64(len(marketPremiums.Premiums)) padding := make([]int32, lenPadding) paddedPremiums := append(marketPremiums.Premiums, padding...) @@ -265,7 +264,7 @@ func (k Keeper) processPremiumVotesIntoSamples( newFundingSampleEpoch, types.PremiumVotesKey, k.GetMinNumVotesPerSample(ctx), - lib.MustGetMedianInt32, // combineFunc + lib.MustGetMedian[int32], // combineFunc func(input []int32) []int32 { return input }, // filterFunc )