diff --git a/types/dec_coin.go b/types/dec_coin.go index b70a1f476592..cc29c921c61b 100644 --- a/types/dec_coin.go +++ b/types/dec_coin.go @@ -201,6 +201,19 @@ func (coins DecCoins) MulDec(d Dec) DecCoins { return res } +// multiply all the coins by a decimal, truncating +func (coins DecCoins) MulDecTruncate(d Dec) DecCoins { + res := make([]DecCoin, len(coins)) + for i, coin := range coins { + product := DecCoin{ + Denom: coin.Denom, + Amount: coin.Amount.MulTruncate(d), + } + res[i] = product + } + return res +} + // divide all the coins by a decimal func (coins DecCoins) QuoDec(d Dec) DecCoins { res := make([]DecCoin, len(coins)) @@ -214,6 +227,19 @@ func (coins DecCoins) QuoDec(d Dec) DecCoins { return res } +// divide all the coins by a decimal, truncating +func (coins DecCoins) QuoDecTruncate(d Dec) DecCoins { + res := make([]DecCoin, len(coins)) + for i, coin := range coins { + quotient := DecCoin{ + Denom: coin.Denom, + Amount: coin.Amount.QuoTruncate(d), + } + res[i] = quotient + } + return res +} + // returns the amount of a denom from deccoins func (coins DecCoins) AmountOf(denom string) Dec { switch len(coins) { diff --git a/types/decimal.go b/types/decimal.go index 49d6b2d4d494..976ddd560e79 100644 --- a/types/decimal.go +++ b/types/decimal.go @@ -228,6 +228,17 @@ func (d Dec) Mul(d2 Dec) Dec { return Dec{chopped} } +// multiplication truncate +func (d Dec) MulTruncate(d2 Dec) Dec { + mul := new(big.Int).Mul(d.Int, d2.Int) + chopped := chopPrecisionAndTruncate(mul) + + if chopped.BitLen() > 255+DecimalPrecisionBits { + panic("Int overflow") + } + return Dec{chopped} +} + // multiplication func (d Dec) MulInt(i Int) Dec { mul := new(big.Int).Mul(d.Int, i.i) @@ -254,6 +265,22 @@ func (d Dec) Quo(d2 Dec) Dec { return Dec{chopped} } +// quotient truncate +func (d Dec) QuoTruncate(d2 Dec) Dec { + + // multiply precision twice + mul := new(big.Int).Mul(d.Int, precisionReuse) + mul.Mul(mul, precisionReuse) + + quo := new(big.Int).Quo(mul, d2.Int) + chopped := chopPrecisionAndTruncate(quo) + + if chopped.BitLen() > 255+DecimalPrecisionBits { + panic("Int overflow") + } + return Dec{chopped} +} + // quotient func (d Dec) QuoInt(i Int) Dec { mul := new(big.Int).Quo(d.Int, i.i) @@ -351,9 +378,6 @@ func chopPrecisionAndRound(d *big.Int) *big.Int { quo, rem := d, big.NewInt(0) quo, rem = quo.QuoRem(d, precisionReuse, rem) - // TODO testing - return quo - if rem.Sign() == 0 { // remainder is zero return quo } diff --git a/x/distribution/keeper/delegation.go b/x/distribution/keeper/delegation.go index 3381ee206f22..a29f417e7605 100644 --- a/x/distribution/keeper/delegation.go +++ b/x/distribution/keeper/delegation.go @@ -16,13 +16,13 @@ func (k Keeper) initializeDelegation(ctx sdk.Context, val sdk.ValAddress, del sd // calculate delegation stake in tokens // we don't store directly, so multiply delegation shares * (tokens per share) - stake := delegation.GetShares().Mul(validator.GetDelegatorShareExRate()) + stake := delegation.GetShares().MulTruncate(validator.GetDelegatorShareExRate()) k.SetDelegatorStartingInfo(ctx, val, del, types.NewDelegatorStartingInfo(previousPeriod, stake, uint64(ctx.BlockHeight()))) } // calculate the rewards accrued by a delegation between two periods func (k Keeper) calculateDelegationRewardsBetween(ctx sdk.Context, val sdk.Validator, - startingPeriod, endingPeriod uint64, staking sdk.Dec) (rewards sdk.DecCoins) { + startingPeriod, endingPeriod uint64, stake sdk.Dec) (rewards sdk.DecCoins) { // sanity check if startingPeriod > endingPeriod { panic("startingPeriod cannot be greater than endingPeriod") @@ -32,7 +32,7 @@ func (k Keeper) calculateDelegationRewardsBetween(ctx sdk.Context, val sdk.Valid starting := k.GetValidatorHistoricalRewards(ctx, val.GetOperator(), startingPeriod) ending := k.GetValidatorHistoricalRewards(ctx, val.GetOperator(), endingPeriod) difference := ending.CumulativeRewardRatio.Minus(starting.CumulativeRewardRatio) - rewards = difference.MulDec(staking) + rewards = difference.MulDecTruncate(stake) return } @@ -54,7 +54,7 @@ func (k Keeper) calculateDelegationRewards(ctx sdk.Context, val sdk.Validator, d func(height uint64, event types.ValidatorSlashEvent) (stop bool) { endingPeriod := event.ValidatorPeriod rewards = rewards.Plus(k.calculateDelegationRewardsBetween(ctx, val, startingPeriod, endingPeriod, stake)) - stake = stake.Mul(sdk.OneDec().Sub(event.Fraction)) + stake = stake.MulTruncate(sdk.OneDec().Sub(event.Fraction)) startingPeriod = endingPeriod return false }, diff --git a/x/distribution/keeper/validator.go b/x/distribution/keeper/validator.go index d7bebb9c7ccb..aac0bbb72a72 100644 --- a/x/distribution/keeper/validator.go +++ b/x/distribution/keeper/validator.go @@ -38,7 +38,7 @@ func (k Keeper) incrementValidatorPeriod(ctx sdk.Context, val sdk.Validator) uin current = sdk.DecCoins{} } else { - current = rewards.Rewards.QuoDec(sdk.NewDecFromInt(val.GetTokens())) + current = rewards.Rewards.QuoDecTruncate(sdk.NewDecFromInt(val.GetTokens())) } // fetch historical rewards for last period