Skip to content

Commit

Permalink
genericize median (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
BrendanChou authored Sep 13, 2023
1 parent cbfd63a commit a0efdc8
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (mte *MarketToExchangePrices) GetValidMedianPrices(
// The number of valid prices must be >= min number of exchanges.
if len(validPrices) >= int(marketParam.MinExchanges) {
// Calculate the median. Returns an error if the input is empty.
median, err := lib.MedianUint64(validPrices)
median, err := lib.Median(validPrices)
if err != nil {
telemetry.IncrCounterWithLabels(
[]string{
Expand Down
36 changes: 6 additions & 30 deletions protocol/lib/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,6 @@ func Min[T constraints.Ordered](x, y T) T {
return x
}

func MaxUint32(x, y uint32) uint32 {
if x < y {
return y
}

return x
}

func Int64MulPpm(x int64, ppm uint32) int64 {
xMulPpm := BigIntMulPpm(big.NewInt(x), ppm)

Expand Down Expand Up @@ -163,25 +155,9 @@ func ChangeRateUint64(originalV uint64, newV uint64) (float32, error) {
return result, nil
}

// MedianUint64 returns the median value of the input slice. Note that if the
// input has an even number of elements, then the returned median is rounded up
// the nearest uint64. For example, 6.5 is rounded up to 7.
func MedianUint64(input []uint64) (uint64, error) {
return medianIntGeneric(input)
}

// MedianInt32 returns the median value of the input slice. Note that if the
// input has an even number of elements, then the returned median is rounded
// towards positive/negative infinity, to the nearest int32. For example,
// 6.5 is rounded to 7 and -4.5 is rounded to -5.
func MedianInt32(input []int32) (int32, error) {
return medianIntGeneric(input)
}

// MustGetMedianInt32 is a wrapper around `MedianInt32` that panics if
// input length is zero.
func MustGetMedianInt32(input []int32) int32 {
ret, err := MedianInt32(input)
// MustGetMedian is a wrapper around `Median` that panics if input length is zero.
func MustGetMedian[V uint64 | uint32 | int64 | int32](input []V) V {
ret, err := Median(input)

if err != nil {
panic(err)
Expand All @@ -190,9 +166,9 @@ func MustGetMedianInt32(input []int32) int32 {
return ret
}

// medianIntGeneric is a generic median calculator.
// It currently supports `uint64`, `int32` and more types can be added.
func medianIntGeneric[V uint64 | int32](input []V) (V, error) {
// Median is a generic median calculator.
// If the input has an even number of elements, then the average of the two middle numbers is rounded away from zero.
func Median[V uint64 | uint32 | int64 | int32](input []V) (V, error) {
l := len(input)
if l == 0 {
return 0, errors.New("input cannot be empty")
Expand Down
51 changes: 8 additions & 43 deletions protocol/lib/math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,41 +261,6 @@ func TestGenericMaxFloat64(t *testing.T) {
}
}

func TestMaxUInt32(t *testing.T) {
tests := map[string]struct {
x uint32
y uint32
expectedResult uint32
}{
"Equal": {
x: 5,
y: 5,
expectedResult: 5,
},
"X is Less": {
x: 4,
y: 5,
expectedResult: 5,
},
"Y is Less": {
x: 5,
y: 4,
expectedResult: 5,
},
"Zero": {
x: 0,
y: 0,
expectedResult: 0,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result := lib.MaxUint32(tc.x, tc.y)
require.Equal(t, tc.expectedResult, result)
})
}
}

func TestInt64MulPpm(t *testing.T) {
tests := map[string]struct {
x int64
Expand Down Expand Up @@ -600,23 +565,23 @@ func TestChangeRateUint64(t *testing.T) {
}
}

func TestMustGetMedianInt32_Failure(t *testing.T) {
func TestMustGetMedian_Failure(t *testing.T) {
require.PanicsWithError(t,
"input cannot be empty",
func() {
lib.MustGetMedianInt32([]int32{})
lib.MustGetMedian([]int32{})
},
)
}

func TestMustGetMedianInt32_Success(t *testing.T) {
func TestMustGetMedian_Success(t *testing.T) {
require.Equal(t,
int32(5),
lib.MustGetMedianInt32([]int32{8, 1, -5, 100, -50, 59}),
lib.MustGetMedian([]int32{8, 1, -5, 100, -50, 59}),
)
}

func TestMedianInt32(t *testing.T) {
func TestMedian_Int32(t *testing.T) {
tests := map[string]struct {
input []int32
expectedResult int32
Expand Down Expand Up @@ -720,7 +685,7 @@ func TestMedianInt32(t *testing.T) {
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result, err := lib.MedianInt32(tc.input)
result, err := lib.Median(tc.input)
require.Equal(t, tc.expectedResult, result)
if tc.expectedError {
require.EqualError(t, err, "input cannot be empty")
Expand All @@ -731,7 +696,7 @@ func TestMedianInt32(t *testing.T) {
}
}

func TestMedianUint64(t *testing.T) {
func TestMedian_Uint64(t *testing.T) {
tests := map[string]struct {
input []uint64
expectedResult uint64
Expand Down Expand Up @@ -770,7 +735,7 @@ func TestMedianUint64(t *testing.T) {
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result, err := lib.MedianUint64(tc.input)
result, err := lib.Median(tc.input)
require.Equal(t, tc.expectedResult, result)
if tc.expectedError {
require.EqualError(t, err, "input cannot be empty")
Expand Down
4 changes: 2 additions & 2 deletions protocol/lib/medianizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type Medianizer interface {
MedianUint64(input []uint64) (uint64, error)
}

// MedianUint64 wraps `lib.MedianUint64` which gets the median of a uint64 slice.
// MedianUint64 wraps `lib.Median` which gets the median of a uint64 slice.
func (r *MedianizerImpl) MedianUint64(input []uint64) (uint64, error) {
return MedianUint64(input)
return Median(input)
}
2 changes: 1 addition & 1 deletion protocol/x/clob/keeper/process_single_match.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ func (k Keeper) setOrderFillAmountsAndPruning(
if !order.IsStatefulOrder() {
// Compute the block at which this state fill amount can be pruned. This is the greater of
// `GoodTilBlock + ShortBlockWindow` and the existing `pruneableBlockHeight`.
pruneableBlockHeight = lib.MaxUint32(
pruneableBlockHeight = lib.Max(
order.GetGoodTilBlock()+types.ShortBlockWindow,
curPruneableBlockHeight,
)
Expand Down
7 changes: 3 additions & 4 deletions protocol/x/perpetuals/keeper/perpetual.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,8 @@ func (k Keeper) processStoredPremiums(
// if block times are longer than expected and hence there were not enough blocks to
// collect votes.
// Note `NumPremiums >= len(marketPremiums.Premiums)`, so `lenPadding >= 0`.
lenPadding := int64(
lib.MaxUint32(premiumStore.NumPremiums,
minNumPremiumsRequired)) - int64(len(marketPremiums.Premiums))
lenPadding := int64(lib.Max(premiumStore.NumPremiums, minNumPremiumsRequired)) -
int64(len(marketPremiums.Premiums))

padding := make([]int32, lenPadding)
paddedPremiums := append(marketPremiums.Premiums, padding...)
Expand Down Expand Up @@ -265,7 +264,7 @@ func (k Keeper) processPremiumVotesIntoSamples(
newFundingSampleEpoch,
types.PremiumVotesKey,
k.GetMinNumVotesPerSample(ctx),
lib.MustGetMedianInt32, // combineFunc
lib.MustGetMedian[int32], // combineFunc
func(input []int32) []int32 { return input }, // filterFunc
)

Expand Down

0 comments on commit a0efdc8

Please sign in to comment.