Skip to content

Commit

Permalink
Implemented a mul_div operation for Uints and reduced overflow risks …
Browse files Browse the repository at this point in the history
…in inflation computations.
  • Loading branch information
murisi committed Oct 20, 2023
1 parent ab20766 commit 8064a89
Show file tree
Hide file tree
Showing 2 changed files with 316 additions and 23 deletions.
37 changes: 17 additions & 20 deletions core/src/ledger/storage/masp_conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,13 @@ where
let noterized_inflation = if total_token_in_masp.is_zero() {
0u128
} else {
crate::types::uint::Uint::try_into(
(inflation * crate::types::uint::Uint::from(precision))
/ total_token_in_masp.raw_amount(),
)
.unwrap()
inflation
.checked_mul_div(
crate::types::uint::Uint::from(precision),
total_token_in_masp.raw_amount(),
)
.0
.map_or(u128::MAX, |x| x.try_into().unwrap_or(u128::MAX))
};

tracing::debug!(
Expand Down Expand Up @@ -159,21 +161,17 @@ where
// but we should make sure the return value's ratio matches
// this new inflation rate in 'update_allowed_conversions',
// otherwise we will have an inaccurate view of inflation
wl_storage
.write(
&token::masp_last_inflation_key(addr),
token::Amount::from_uint(
(total_token_in_masp.raw_amount() / precision)
* crate::types::uint::Uint::from(noterized_inflation),
0,
)
.unwrap(),
wl_storage.write(
&token::masp_last_inflation_key(addr),
token::Amount::from_uint(
(total_token_in_masp.raw_amount() / precision)
* crate::types::uint::Uint::from(noterized_inflation),
0,
)
.expect("unable to encode new inflation rate (Decimal)");
.unwrap(),
)?;

wl_storage
.write(&token::masp_last_locked_ratio_key(addr), locked_ratio)
.expect("unable to encode new locked ratio (Decimal)");
wl_storage.write(&token::masp_last_locked_ratio_key(addr), locked_ratio)?;

Ok((noterized_inflation, precision))
}
Expand Down Expand Up @@ -238,8 +236,7 @@ where
let mut ref_inflation = 0;
// Reward all tokens according to above reward rates
for addr in &masp_reward_keys {
let reward = calculate_masp_rewards(wl_storage, addr)
.expect("Calculating the masp rewards should not fail");
let reward = calculate_masp_rewards(wl_storage, addr)?;
if *addr == native_token {
// The reference inflation is the denominator of the native token
// inflation, which is always a constant
Expand Down
302 changes: 299 additions & 3 deletions core/src/types/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,272 @@ pub const ZERO: Uint = Uint::from_u64(0);
pub const ONE: Uint = Uint::from_u64(1);

impl Uint {
const N_WORDS: usize = 4;

/// Convert a [`u64`] to a [`Uint`].
pub const fn from_u64(x: u64) -> Uint {
Uint([x.to_le(), 0, 0, 0])
}

/// Return the least number of bits needed to represent the number
#[inline]
pub fn bits_512(arr: &[u64; 2 * Self::N_WORDS]) -> usize {
for i in 1..arr.len() {
if arr[arr.len() - i] > 0 {
return (0x40 * (arr.len() - i + 1))
- arr[arr.len() - i].leading_zeros() as usize;
}
}
0x40 - arr[0].leading_zeros() as usize
}

fn div_mod_small_512(
mut slf: [u64; 2 * Self::N_WORDS],
other: u64,
) -> ([u64; 2 * Self::N_WORDS], Self) {
let mut rem = 0u64;
slf.iter_mut().rev().for_each(|d| {
let (q, r) = Self::div_mod_word(rem, *d, other);
*d = q;
rem = r;
});
(slf, rem.into())
}

fn shr_512(
original: [u64; 2 * Self::N_WORDS],
shift: u32,
) -> [u64; 2 * Self::N_WORDS] {
let shift = shift as usize;
let mut ret = [0u64; 2 * Self::N_WORDS];
let word_shift = shift / 64;
let bit_shift = shift % 64;

// shift
for i in word_shift..original.len() {
ret[i - word_shift] = original[i] >> bit_shift;
}

// Carry
if bit_shift > 0 {
for i in word_shift + 1..original.len() {
ret[i - word_shift - 1] += original[i] << (64 - bit_shift);
}
}

ret
}

fn full_shl_512(
slf: [u64; 2 * Self::N_WORDS],
shift: u32,
) -> [u64; 2 * Self::N_WORDS + 1] {
debug_assert!(shift < Self::WORD_BITS as u32);
let mut u = [0u64; 2 * Self::N_WORDS + 1];
let u_lo = slf[0] << shift;
let u_hi = Self::shr_512(slf, Self::WORD_BITS as u32 - shift);
u[0] = u_lo;
u[1..].copy_from_slice(&u_hi[..]);
u
}

fn full_shr_512(
u: [u64; 2 * Self::N_WORDS + 1],
shift: u32,
) -> [u64; 2 * Self::N_WORDS] {
debug_assert!(shift < Self::WORD_BITS as u32);
let mut res = [0; 2 * Self::N_WORDS];
for i in 0..res.len() {
res[i] = u[i] >> shift;
}
// carry
if shift > 0 {
for i in 1..=res.len() {
res[i - 1] |= u[i] << (Self::WORD_BITS as u32 - shift);
}
}
res
}

// See Knuth, TAOCP, Volume 2, section 4.3.1, Algorithm D.
fn div_mod_knuth_512(
slf: [u64; 2 * Self::N_WORDS],
mut v: Self,
n: usize,
m: usize,
) -> ([u64; 2 * Self::N_WORDS], Self) {
debug_assert!(Self::bits_512(&slf) >= v.bits() && !v.fits_word());
debug_assert!(n + m <= slf.len());
// D1.
// Make sure 64th bit in v's highest word is set.
// If we shift both self and v, it won't affect the quotient
// and the remainder will only need to be shifted back.
let shift = v.0[n - 1].leading_zeros();
v <<= shift;
// u will store the remainder (shifted)
let mut u = Self::full_shl_512(slf, shift);

// quotient
let mut q = [0; 2 * Self::N_WORDS];
let v_n_1 = v.0[n - 1];
let v_n_2 = v.0[n - 2];

// D2. D7.
// iterate from m downto 0
for j in (0..=m).rev() {
let u_jn = u[j + n];

// D3.
// q_hat is our guess for the j-th quotient digit
// q_hat = min(b - 1, (u_{j+n} * b + u_{j+n-1}) / v_{n-1})
// b = 1 << WORD_BITS
// Theorem B: q_hat >= q_j >= q_hat - 2
let mut q_hat = if u_jn < v_n_1 {
let (mut q_hat, mut r_hat) =
Self::div_mod_word(u_jn, u[j + n - 1], v_n_1);
// this loop takes at most 2 iterations
loop {
// check if q_hat * v_{n-2} > b * r_hat + u_{j+n-2}
let (hi, lo) =
Self::split_u128(u128::from(q_hat) * u128::from(v_n_2));
if (hi, lo) <= (r_hat, u[j + n - 2]) {
break;
}
// then iterate till it doesn't hold
q_hat -= 1;
let (new_r_hat, overflow) = r_hat.overflowing_add(v_n_1);
r_hat = new_r_hat;
// if r_hat overflowed, we're done
if overflow {
break;
}
}
q_hat
} else {
// here q_hat >= q_j >= q_hat - 1
u64::max_value()
};

// ex. 20:
// since q_hat * v_{n-2} <= b * r_hat + u_{j+n-2},
// either q_hat == q_j, or q_hat == q_j + 1

// D4.
// let's assume optimistically q_hat == q_j
// subtract (q_hat * v) from u[j..]
let q_hat_v = v.full_mul_u64(q_hat);
// u[j..] -= q_hat_v;
let c = Self::sub_slice(&mut u[j..], &q_hat_v[..n + 1]);

// D6.
// actually, q_hat == q_j + 1 and u[j..] has overflowed
// highly unlikely ~ (1 / 2^63)
if c {
q_hat -= 1;
// add v to u[j..]
let c = Self::add_slice(&mut u[j..], &v.0[..n]);
u[j + n] = u[j + n].wrapping_add(u64::from(c));
}

// D5.
q[j] = q_hat;
}

// D8.
let remainder = Self::full_shr_512(u, shift);
// The remainder should never exceed the capacity of Self
debug_assert!(
Self::bits_512(&remainder) <= Self::N_WORDS * Self::WORD_BITS
);
(q, Self(remainder[..Self::N_WORDS].try_into().unwrap()))
}

/// Returns a pair `(self / other, self % other)`.
///
/// # Panics
///
/// Panics if `other` is zero.
pub fn div_mod_512(
slf: [u64; 2 * Self::N_WORDS],
other: Self,
) -> ([u64; 2 * Self::N_WORDS], Self) {
let my_bits = Self::bits_512(&slf);
let your_bits = other.bits();

assert!(your_bits != 0, "division by zero");

// Early return in case we are dividing by a larger number than us
if my_bits < your_bits {
return (
[0; 2 * Self::N_WORDS],
Self(slf[..Self::N_WORDS].try_into().unwrap()),
);
}

if your_bits <= Self::WORD_BITS {
return Self::div_mod_small_512(slf, other.low_u64());
}

let (n, m) = {
let my_words = Self::words(my_bits);
let your_words = Self::words(your_bits);
(your_words, my_words - your_words)
};

Self::div_mod_knuth_512(slf, other, n, m)
}

/// Returns a pair `(Some((self * num) / denom), (self * num) % denom)` if
/// the quotient fits into Self. Otherwise `(None, (self * num) % denom)` is
/// returned.
///
/// # Panics
///
/// Panics if `denom` is zero.
pub fn checked_mul_div(
&self,
num: Self,
denom: Self,
) -> (Option<Self>, Self) {
let prod = uint::uint_full_mul_reg!(Uint, 4, self, num);
let (quotient, remainder) = Self::div_mod_512(prod, denom);
// The compiler WILL NOT inline this if you remove this annotation.
#[inline(always)]
fn any_nonzero(arr: &[u64]) -> bool {
use uint::unroll;
unroll! {
for i in 0..4 {
if arr[i] != 0 {
return true;
}
}
}

false
}
(
if any_nonzero(&quotient[Self::N_WORDS..]) {
None
} else {
Some(Self(quotient[0..Self::N_WORDS].try_into().unwrap()))
},
remainder,
)
}

/// Returns a pair `((self * num) / denom, (self * num) % denom)`.
///
/// # Panics
///
/// Panics if `denom` is zero.
pub fn mul_div(&self, num: Self, denom: Self) -> (Self, Self) {
let prod = uint::uint_full_mul_reg!(Uint, 4, self, num);
let (quotient, remainder) = Self::div_mod_512(prod, denom);
(
Self(quotient[0..Self::N_WORDS].try_into().unwrap()),
remainder,
)
}
}

construct_uint! {
Expand Down Expand Up @@ -171,10 +433,9 @@ impl Uint {
/// * `self` * 10^(`denom`) overflows 256 bits
/// * `other` is zero (`checked_div` will return `None`).
pub fn fixed_precision_div(&self, rhs: &Self, denom: u8) -> Option<Self> {
let lhs = Uint::from(10)
Uint::from(10)
.checked_pow(Uint::from(denom))
.and_then(|res| res.checked_mul(*self))?;
lhs.checked_div(*rhs)
.and_then(|res| res.checked_mul_div(*self, *rhs).0)
}

/// Compute the two's complement of a number.
Expand Down Expand Up @@ -710,4 +971,39 @@ mod test_uint {
let amount: Result<Uint, _> = serde_json::from_str(r#""1000000000.2""#);
assert!(amount.is_err());
}

#[test]
fn test_mul_div() {
use std::str::FromStr;
let a: Uint = Uint::from_str(
"0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
).unwrap();
let b: Uint = Uint::from_str(
"0x8000000000000000000000000000000000000000000000000000000000000000",
).unwrap();
let c: Uint = Uint::from_str(
"0x4000000000000000000000000000000000000000000000000000000000000000",
).unwrap();
let d: Uint = Uint::from_str(
"0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
).unwrap();
let e: Uint = Uint::from_str(
"0x0000000000000000000000000000000000000000000000000000000000000001",
).unwrap();
let f: Uint = Uint::from_str(
"0x0000000000000000000000000000000000000000000000000000000000000000",
).unwrap();
assert_eq!(a.mul_div(a, a), (a, Uint::zero()));
assert_eq!(b.mul_div(c, b), (c, Uint::zero()));
assert_eq!(a.mul_div(c, b), (d, c));
assert_eq!(a.mul_div(e, e), (a, Uint::zero()));
assert_eq!(e.mul_div(c, b), (Uint::zero(), c));
assert_eq!(f.mul_div(a, e), (Uint::zero(), Uint::zero()));
assert_eq!(a.checked_mul_div(a, a), (Some(a), Uint::zero()));
assert_eq!(b.checked_mul_div(c, b), (Some(c), Uint::zero()));
assert_eq!(a.checked_mul_div(c, b), (Some(d), c));
assert_eq!(a.checked_mul_div(e, e), (Some(a), Uint::zero()));
assert_eq!(e.checked_mul_div(c, b), (Some(Uint::zero()), c));
assert_eq!(d.checked_mul_div(a, e), (None, Uint::zero()));
}
}

0 comments on commit 8064a89

Please sign in to comment.