diff --git a/math/uint.go b/math/uint.go index be588b0cd342..97177ea7f0fa 100644 --- a/math/uint.go +++ b/math/uint.go @@ -21,6 +21,14 @@ func (u Uint) BigInt() *big.Int { return new(big.Int).Set(u.i) } +// BigInt converts Uint to big.Int, mutative the input +func (u Uint) BigIntMut() *big.Int { + if u.IsNil() { + return nil + } + return u.i +} + // IsNil returns true if Uint is uninitialized func (u Uint) IsNil() bool { return u.i == nil diff --git a/math/uint_test.go b/math/uint_test.go index b75b7514a584..93abdf886dae 100644 --- a/math/uint_test.go +++ b/math/uint_test.go @@ -99,6 +99,26 @@ func (s *uintTestSuite) TestIsNil() { s.Require().True(sdkmath.Uint{}.IsNil()) } +func (s *uintTestSuite) TestConvertToBigIntMutativeForUint() { + r := big.NewInt(30) + i := sdkmath.NewUintFromBigInt(r) + + // Compare value of BigInt & BigIntMut + s.Require().Equal(i.BigInt(), i.BigIntMut()) + + // Modify BigIntMut() pointer and ensure i.BigIntMut() & i.BigInt() change + p1 := i.BigIntMut() + p1.SetInt64(40) + s.Require().Equal(big.NewInt(40), i.BigIntMut()) + s.Require().Equal(big.NewInt(40), i.BigInt()) + + // Modify big.Int() pointer and ensure i.BigIntMut() & i.BigInt() don't change + p2 := i.BigInt() + p2.SetInt64(50) + s.Require().NotEqual(big.NewInt(50), i.BigIntMut()) + s.Require().NotEqual(big.NewInt(50), i.BigInt()) +} + func (s *uintTestSuite) TestArithUint() { for d := 0; d < 1000; d++ { n1 := uint64(rand.Uint32())