diff --git a/osmomath/decimal.go b/osmomath/decimal.go index 955b6229f58..e7801b481b5 100644 --- a/osmomath/decimal.go +++ b/osmomath/decimal.go @@ -240,12 +240,20 @@ func (d BigDec) BigInt() *big.Int { // addition func (d BigDec) Add(d2 BigDec) BigDec { - res := new(big.Int).Add(d.i, d2.i) + copy := d.Clone() + copy.AddMut(d2) + return copy +} - if res.BitLen() > maxDecBitLen { +// mutative addition +func (d BigDec) AddMut(d2 BigDec) BigDec { + d.i.Add(d.i, d2.i) + + if d.i.BitLen() > maxDecBitLen { panic("Int overflow") } - return BigDec{res} + + return d } // subtraction @@ -319,19 +327,25 @@ func (d BigDec) MulInt64(i int64) BigDec { // quotient func (d BigDec) Quo(d2 BigDec) BigDec { + copy := d.Clone() + copy.QuoMut(d2) + return copy +} + +// mutative quotient +func (d BigDec) QuoMut(d2 BigDec) BigDec { // multiply precision twice - mul := new(big.Int).Mul(d.i, precisionReuse) - mul.Mul(mul, precisionReuse) + d.i.Mul(d.i, precisionReuse) + d.i.Mul(d.i, precisionReuse) - quo := new(big.Int).Quo(mul, d2.i) - chopped := chopPrecisionAndRound(quo) + d.i.Quo(d.i, d2.i) + chopPrecisionAndRound(d.i) - if chopped.BitLen() > maxDecBitLen { + if d.i.BitLen() > maxDecBitLen { panic("Int overflow") } - return BigDec{chopped} + return d } - func (d BigDec) QuoRaw(d2 int64) BigDec { // multiply precision, so we can chop it later mul := new(big.Int).Mul(d.i, precisionReuse) diff --git a/osmomath/decimal_test.go b/osmomath/decimal_test.go index c1ed64eae1d..d68ec66f291 100644 --- a/osmomath/decimal_test.go +++ b/osmomath/decimal_test.go @@ -37,6 +37,53 @@ func (s *decimalTestSuite) assertMutResult(expectedResult, startValue, mutativeR s.Require().Equal(nonMutativeStartValue, startValue) } +func (s *decimalTestSuite) TestAddMut() { + toAdd := osmomath.MustNewDecFromStr("10") + tests := map[string]struct { + startValue osmomath.BigDec + expectedMutResult osmomath.BigDec + }{ + "0": {osmomath.NewBigDec(0), osmomath.NewBigDec(10)}, + "1": {osmomath.NewBigDec(1), osmomath.NewBigDec(11)}, + "10": {osmomath.NewBigDec(10), osmomath.NewBigDec(20)}, + } + + for name, tc := range tests { + s.Run(name, func() { + startMut := tc.startValue.Clone() + startNonMut := tc.startValue.Clone() + + resultMut := startMut.AddMut(toAdd) + resultNonMut := startNonMut.Add(toAdd) + + s.assertMutResult(tc.expectedMutResult, tc.startValue, resultMut, resultNonMut, startMut, startNonMut) + }) + } +} + +func (s *decimalTestSuite) TestQuoMut() { + quoBy := osmomath.MustNewDecFromStr("2") + tests := map[string]struct { + startValue osmomath.BigDec + expectedMutResult osmomath.BigDec + }{ + "0": {osmomath.NewBigDec(0), osmomath.NewBigDec(0)}, + "1": {osmomath.NewBigDec(1), osmomath.MustNewDecFromStr("0.5")}, + "10": {osmomath.NewBigDec(10), osmomath.NewBigDec(5)}, + } + + for name, tc := range tests { + s.Run(name, func() { + startMut := tc.startValue.Clone() + startNonMut := tc.startValue.Clone() + + resultMut := startMut.QuoMut(quoBy) + resultNonMut := startNonMut.Quo(quoBy) + + s.assertMutResult(tc.expectedMutResult, tc.startValue, resultMut, resultNonMut, startMut, startNonMut) + }) + } +} func TestDecApproxEq(t *testing.T) { // d1 = 0.55, d2 = 0.6, tol = 0.1 d1 := osmomath.NewDecWithPrec(55, 2) diff --git a/osmomath/exp2.go b/osmomath/exp2.go new file mode 100644 index 00000000000..39b2333ebd3 --- /dev/null +++ b/osmomath/exp2.go @@ -0,0 +1,101 @@ +package osmomath + +import "fmt" + +var ( + // Truncated at precision end. + // See scripts/approximations/main.py exponent_approximation_choice function for details. + numeratorCoefficients13Param = []BigDec{ + MustNewDecFromStr("1.000000000000000000000044212244679434"), + MustNewDecFromStr("0.352032455817400196452603772766844426"), + MustNewDecFromStr("0.056507868883666405413116800969512484"), + MustNewDecFromStr("0.005343900728213034434757419480319916"), + MustNewDecFromStr("0.000317708814342353603087543715930732"), + MustNewDecFromStr("0.000011429747507407623028722262874632"), + MustNewDecFromStr("0.000000198381965651614980168744540366"), + } + + // Rounded up at precision end. + // See scripts/approximations/main.py exponent_approximation_choice function for details. + denominatorCoefficients13Param = []BigDec{ + OneDec(), + MustNewDecFromStr("0.341114724742545112949699755780593311").Neg(), + MustNewDecFromStr("0.052724071627342653404436933178482287"), + MustNewDecFromStr("0.004760950735524957576233524801866342").Neg(), + MustNewDecFromStr("0.000267168475410566529819971616894193"), + MustNewDecFromStr("0.000008923715368802211181557353097439").Neg(), + MustNewDecFromStr("0.000000140277233177373698516010555916"), + } + + // maxSupportedExponent = 2^10. The value is chosen by benchmarking + // when the underlying internal functions overflow. + // If needed in the future, Exp2 can be reimplemented to allow for greater exponents. + maxSupportedExponent = MustNewDecFromStr("2").PowerInteger(9) +) + +// Exp2 takes 2 to the power of a given non-negative decimal exponent +// and returns the result. +// The computation is performed by using th following property: +// 2^decimal_exp = 2^{integer_exp + fractional_exp} = 2^integer_exp * 2^fractional_exp +// The max supported exponent is defined by the global maxSupportedExponent. +// If a greater exponent is given, the function panics. +// Panics if the exponent is negative. +// The answer is correct up to a factor of 10^-18. +// Meaning, result = result * k for k in [1 - 10^(-18), 1 + 10^(-18)] +// Note: our Python script plots show accuracy up to a factor of 10^22. +// However, in Go tests we only test up to 10^18. Therefore, this is the guarantee. +func Exp2(exponent BigDec) BigDec { + if exponent.Abs().GT(maxSupportedExponent) { + panic(fmt.Sprintf("integer exponent %s is too large, max (%s)", exponent, maxSupportedExponent)) + } + if exponent.IsNegative() { + panic(fmt.Sprintf("negative exponent %s is not supported", exponent)) + } + + integerExponent := exponent.TruncateDec() + + fractionalExponent := exponent.Sub(integerExponent) + fractionalResult := exp2ChebyshevRationalApprox(fractionalExponent) + + // Left bit shift is equivalent to multiplying by 2^integerExponent. + fractionalResult.i = fractionalResult.i.Lsh(fractionalResult.i, uint(integerExponent.TruncateInt().Uint64())) + + return fractionalResult +} + +// exp2ChebyshevRationalApprox takes 2 to the power of a given decimal exponent. +// The result is approximated by a 13 parameter Chebyshev rational approximation. +// f(x) = h(x) / p(x) (7, 7) terms. We set the first term of p(x) to 1. +// As a result, this ends up being 7 + 6 = 13 parameters. +// The numerator coefficients are truncated at precision end. The denominator +// coefficients are rounded up at precision end. +// See scripts/approximations/README.md for details of the scripts used +// to compute the coefficients. +// CONTRACT: exponent must be in the range [0, 1], panics if not. +// The answer is correct up to a factor of 10^-18. +// Meaning, result = result * k for k in [1 - 10^(-18), 1 + 10^(-18)] +// Note: our Python script plots show accuracy up to a factor of 10^22. +// However, in Go tests we only test up to 10^18. Therefore, this is the guarantee. +func exp2ChebyshevRationalApprox(x BigDec) BigDec { + if x.LT(ZeroDec()) || x.GT(OneDec()) { + panic(fmt.Sprintf("exponent must be in the range [0, 1], got %s", x)) + } + if x.IsZero() { + return OneDec() + } + if x.Equal(OneDec()) { + return twoBigDec + } + + h_x := numeratorCoefficients13Param[0].Clone() + p_x := denominatorCoefficients13Param[0].Clone() + x_exp_i := OneDec() + for i := 1; i < len(numeratorCoefficients13Param); i++ { + x_exp_i.MulMut(x) + + h_x.AddMut(numeratorCoefficients13Param[i].Mul(x_exp_i)) + p_x.AddMut(denominatorCoefficients13Param[i].Mul(x_exp_i)) + } + + return h_x.QuoMut(p_x) +}