diff --git a/core/src/ledger/storage/masp_conversions.rs b/core/src/ledger/storage/masp_conversions.rs index 7a4b6b7701a..50f37256fdb 100644 --- a/core/src/ledger/storage/masp_conversions.rs +++ b/core/src/ledger/storage/masp_conversions.rs @@ -16,6 +16,7 @@ use crate::types::address::Address; use crate::types::dec::Dec; use crate::types::storage::Epoch; use crate::types::token::MaspDenom; +use crate::types::uint::Uint; use crate::types::{address, token}; /// A representation of the conversion state @@ -126,11 +127,22 @@ where let noterized_inflation = if total_token_in_masp.is_zero() { 0u128 } else { - crate::types::uint::Uint::try_into( - (inflation * crate::types::uint::Uint::from(precision)) - / total_token_in_masp.raw_amount(), - ) - .unwrap() + inflation + .checked_mul_div( + Uint::from(precision), + total_token_in_masp.raw_amount(), + ) + .0 + .and_then(|x| x.try_into().ok()) + .unwrap_or_else(|| { + tracing::warn!( + "MASP inflation for {} assumed to be 0 because the \ + computed value is too large. Please check the inflation \ + parameters.", + *addr + ); + 0u128 + }) }; tracing::debug!( @@ -159,21 +171,17 @@ where // but we should make sure the return value's ratio matches // this new inflation rate in 'update_allowed_conversions', // otherwise we will have an inaccurate view of inflation - wl_storage - .write( - &token::masp_last_inflation_key(addr), - token::Amount::from_uint( - (total_token_in_masp.raw_amount() / precision) - * crate::types::uint::Uint::from(noterized_inflation), - 0, - ) - .unwrap(), + wl_storage.write( + &token::masp_last_inflation_key(addr), + token::Amount::from_uint( + (total_token_in_masp.raw_amount() / precision) + * Uint::from(noterized_inflation), + 0, ) - .expect("unable to encode new inflation rate (Decimal)"); + .unwrap(), + )?; - wl_storage - .write(&token::masp_last_locked_ratio_key(addr), locked_ratio) - .expect("unable to encode new locked ratio (Decimal)"); + wl_storage.write(&token::masp_last_locked_ratio_key(addr), locked_ratio)?; Ok((noterized_inflation, precision)) } @@ -238,8 +246,7 @@ where let mut ref_inflation = 0; // Reward all tokens according to above reward rates for addr in &masp_reward_keys { - let reward = calculate_masp_rewards(wl_storage, addr) - .expect("Calculating the masp rewards should not fail"); + let reward = calculate_masp_rewards(wl_storage, addr)?; if *addr == native_token { // The reference inflation is the denominator of the native token // inflation, which is always a constant @@ -273,8 +280,21 @@ where // The amount that will be given of the new native token for // every amount of the native token given in the // previous epoch - let new_normed_inflation = *normed_inflation - + (*normed_inflation * reward.0) / reward.1; + let new_normed_inflation = Uint::from(*normed_inflation) + .checked_add( + (Uint::from(*normed_inflation) * Uint::from(reward.0)) + / reward.1, + ) + .and_then(|x| x.try_into().ok()) + .unwrap_or_else(|| { + tracing::warn!( + "MASP reward for {} assumed to be 0 because the \ + computed value is too large. Please check the \ + inflation parameters.", + *addr + ); + *normed_inflation + }); // The conversion is computed such that if consecutive // conversions are added together, the // intermediate native tokens cancel/ @@ -308,8 +328,19 @@ where // Express the inflation reward in real terms, that is, with // respect to the native asset in the zeroth // epoch - let real_reward = - (reward.0 * ref_inflation) / *normed_inflation; + let real_reward = ((Uint::from(reward.0) + * Uint::from(ref_inflation)) + / *normed_inflation) + .try_into() + .unwrap_or_else(|_| { + tracing::warn!( + "MASP reward for {} assumed to be 0 because the \ + computed value is too large. Please check the \ + inflation parameters.", + *addr + ); + 0u128 + }); // The conversion is computed such that if consecutive // conversions are added together, the // intermediate tokens cancel/ telescope out diff --git a/core/src/types/uint.rs b/core/src/types/uint.rs index ee14e67ad11..8cb9a7db791 100644 --- a/core/src/types/uint.rs +++ b/core/src/types/uint.rs @@ -21,10 +21,272 @@ pub const ZERO: Uint = Uint::from_u64(0); pub const ONE: Uint = Uint::from_u64(1); impl Uint { + const N_WORDS: usize = 4; + /// Convert a [`u64`] to a [`Uint`]. pub const fn from_u64(x: u64) -> Uint { Uint([x.to_le(), 0, 0, 0]) } + + /// Return the least number of bits needed to represent the number + #[inline] + pub fn bits_512(arr: &[u64; 2 * Self::N_WORDS]) -> usize { + for i in 1..arr.len() { + if arr[arr.len() - i] > 0 { + return (0x40 * (arr.len() - i + 1)) + - arr[arr.len() - i].leading_zeros() as usize; + } + } + 0x40 - arr[0].leading_zeros() as usize + } + + fn div_mod_small_512( + mut slf: [u64; 2 * Self::N_WORDS], + other: u64, + ) -> ([u64; 2 * Self::N_WORDS], Self) { + let mut rem = 0u64; + slf.iter_mut().rev().for_each(|d| { + let (q, r) = Self::div_mod_word(rem, *d, other); + *d = q; + rem = r; + }); + (slf, rem.into()) + } + + fn shr_512( + original: [u64; 2 * Self::N_WORDS], + shift: u32, + ) -> [u64; 2 * Self::N_WORDS] { + let shift = shift as usize; + let mut ret = [0u64; 2 * Self::N_WORDS]; + let word_shift = shift / 64; + let bit_shift = shift % 64; + + // shift + for i in word_shift..original.len() { + ret[i - word_shift] = original[i] >> bit_shift; + } + + // Carry + if bit_shift > 0 { + for i in word_shift + 1..original.len() { + ret[i - word_shift - 1] += original[i] << (64 - bit_shift); + } + } + + ret + } + + fn full_shl_512( + slf: [u64; 2 * Self::N_WORDS], + shift: u32, + ) -> [u64; 2 * Self::N_WORDS + 1] { + debug_assert!(shift < Self::WORD_BITS as u32); + let mut u = [0u64; 2 * Self::N_WORDS + 1]; + let u_lo = slf[0] << shift; + let u_hi = Self::shr_512(slf, Self::WORD_BITS as u32 - shift); + u[0] = u_lo; + u[1..].copy_from_slice(&u_hi[..]); + u + } + + fn full_shr_512( + u: [u64; 2 * Self::N_WORDS + 1], + shift: u32, + ) -> [u64; 2 * Self::N_WORDS] { + debug_assert!(shift < Self::WORD_BITS as u32); + let mut res = [0; 2 * Self::N_WORDS]; + for i in 0..res.len() { + res[i] = u[i] >> shift; + } + // carry + if shift > 0 { + for i in 1..=res.len() { + res[i - 1] |= u[i] << (Self::WORD_BITS as u32 - shift); + } + } + res + } + + // See Knuth, TAOCP, Volume 2, section 4.3.1, Algorithm D. + fn div_mod_knuth_512( + slf: [u64; 2 * Self::N_WORDS], + mut v: Self, + n: usize, + m: usize, + ) -> ([u64; 2 * Self::N_WORDS], Self) { + debug_assert!(Self::bits_512(&slf) >= v.bits() && !v.fits_word()); + debug_assert!(n + m <= slf.len()); + // D1. + // Make sure 64th bit in v's highest word is set. + // If we shift both self and v, it won't affect the quotient + // and the remainder will only need to be shifted back. + let shift = v.0[n - 1].leading_zeros(); + v <<= shift; + // u will store the remainder (shifted) + let mut u = Self::full_shl_512(slf, shift); + + // quotient + let mut q = [0; 2 * Self::N_WORDS]; + let v_n_1 = v.0[n - 1]; + let v_n_2 = v.0[n - 2]; + + // D2. D7. + // iterate from m downto 0 + for j in (0..=m).rev() { + let u_jn = u[j + n]; + + // D3. + // q_hat is our guess for the j-th quotient digit + // q_hat = min(b - 1, (u_{j+n} * b + u_{j+n-1}) / v_{n-1}) + // b = 1 << WORD_BITS + // Theorem B: q_hat >= q_j >= q_hat - 2 + let mut q_hat = if u_jn < v_n_1 { + let (mut q_hat, mut r_hat) = + Self::div_mod_word(u_jn, u[j + n - 1], v_n_1); + // this loop takes at most 2 iterations + loop { + // check if q_hat * v_{n-2} > b * r_hat + u_{j+n-2} + let (hi, lo) = + Self::split_u128(u128::from(q_hat) * u128::from(v_n_2)); + if (hi, lo) <= (r_hat, u[j + n - 2]) { + break; + } + // then iterate till it doesn't hold + q_hat -= 1; + let (new_r_hat, overflow) = r_hat.overflowing_add(v_n_1); + r_hat = new_r_hat; + // if r_hat overflowed, we're done + if overflow { + break; + } + } + q_hat + } else { + // here q_hat >= q_j >= q_hat - 1 + u64::max_value() + }; + + // ex. 20: + // since q_hat * v_{n-2} <= b * r_hat + u_{j+n-2}, + // either q_hat == q_j, or q_hat == q_j + 1 + + // D4. + // let's assume optimistically q_hat == q_j + // subtract (q_hat * v) from u[j..] + let q_hat_v = v.full_mul_u64(q_hat); + // u[j..] -= q_hat_v; + let c = Self::sub_slice(&mut u[j..], &q_hat_v[..n + 1]); + + // D6. + // actually, q_hat == q_j + 1 and u[j..] has overflowed + // highly unlikely ~ (1 / 2^63) + if c { + q_hat -= 1; + // add v to u[j..] + let c = Self::add_slice(&mut u[j..], &v.0[..n]); + u[j + n] = u[j + n].wrapping_add(u64::from(c)); + } + + // D5. + q[j] = q_hat; + } + + // D8. + let remainder = Self::full_shr_512(u, shift); + // The remainder should never exceed the capacity of Self + debug_assert!( + Self::bits_512(&remainder) <= Self::N_WORDS * Self::WORD_BITS + ); + (q, Self(remainder[..Self::N_WORDS].try_into().unwrap())) + } + + /// Returns a pair `(self / other, self % other)`. + /// + /// # Panics + /// + /// Panics if `other` is zero. + pub fn div_mod_512( + slf: [u64; 2 * Self::N_WORDS], + other: Self, + ) -> ([u64; 2 * Self::N_WORDS], Self) { + let my_bits = Self::bits_512(&slf); + let your_bits = other.bits(); + + assert!(your_bits != 0, "division by zero"); + + // Early return in case we are dividing by a larger number than us + if my_bits < your_bits { + return ( + [0; 2 * Self::N_WORDS], + Self(slf[..Self::N_WORDS].try_into().unwrap()), + ); + } + + if your_bits <= Self::WORD_BITS { + return Self::div_mod_small_512(slf, other.low_u64()); + } + + let (n, m) = { + let my_words = Self::words(my_bits); + let your_words = Self::words(your_bits); + (your_words, my_words - your_words) + }; + + Self::div_mod_knuth_512(slf, other, n, m) + } + + /// Returns a pair `(Some((self * num) / denom), (self * num) % denom)` if + /// the quotient fits into Self. Otherwise `(None, (self * num) % denom)` is + /// returned. + /// + /// # Panics + /// + /// Panics if `denom` is zero. + pub fn checked_mul_div( + &self, + num: Self, + denom: Self, + ) -> (Option, Self) { + let prod = uint::uint_full_mul_reg!(Uint, 4, self, num); + let (quotient, remainder) = Self::div_mod_512(prod, denom); + // The compiler WILL NOT inline this if you remove this annotation. + #[inline(always)] + fn any_nonzero(arr: &[u64]) -> bool { + use uint::unroll; + unroll! { + for i in 0..4 { + if arr[i] != 0 { + return true; + } + } + } + + false + } + ( + if any_nonzero("ient[Self::N_WORDS..]) { + None + } else { + Some(Self(quotient[0..Self::N_WORDS].try_into().unwrap())) + }, + remainder, + ) + } + + /// Returns a pair `((self * num) / denom, (self * num) % denom)`. + /// + /// # Panics + /// + /// Panics if `denom` is zero. + pub fn mul_div(&self, num: Self, denom: Self) -> (Self, Self) { + let prod = uint::uint_full_mul_reg!(Uint, 4, self, num); + let (quotient, remainder) = Self::div_mod_512(prod, denom); + ( + Self(quotient[0..Self::N_WORDS].try_into().unwrap()), + remainder, + ) + } } construct_uint! { @@ -171,10 +433,9 @@ impl Uint { /// * `self` * 10^(`denom`) overflows 256 bits /// * `other` is zero (`checked_div` will return `None`). pub fn fixed_precision_div(&self, rhs: &Self, denom: u8) -> Option { - let lhs = Uint::from(10) + Uint::from(10) .checked_pow(Uint::from(denom)) - .and_then(|res| res.checked_mul(*self))?; - lhs.checked_div(*rhs) + .and_then(|res| res.checked_mul_div(*self, *rhs).0) } /// Compute the two's complement of a number. @@ -710,4 +971,39 @@ mod test_uint { let amount: Result = serde_json::from_str(r#""1000000000.2""#); assert!(amount.is_err()); } + + #[test] + fn test_mul_div() { + use std::str::FromStr; + let a: Uint = Uint::from_str( + "0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + ).unwrap(); + let b: Uint = Uint::from_str( + "0x8000000000000000000000000000000000000000000000000000000000000000", + ).unwrap(); + let c: Uint = Uint::from_str( + "0x4000000000000000000000000000000000000000000000000000000000000000", + ).unwrap(); + let d: Uint = Uint::from_str( + "0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + ).unwrap(); + let e: Uint = Uint::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000001", + ).unwrap(); + let f: Uint = Uint::from_str( + "0x0000000000000000000000000000000000000000000000000000000000000000", + ).unwrap(); + assert_eq!(a.mul_div(a, a), (a, Uint::zero())); + assert_eq!(b.mul_div(c, b), (c, Uint::zero())); + assert_eq!(a.mul_div(c, b), (d, c)); + assert_eq!(a.mul_div(e, e), (a, Uint::zero())); + assert_eq!(e.mul_div(c, b), (Uint::zero(), c)); + assert_eq!(f.mul_div(a, e), (Uint::zero(), Uint::zero())); + assert_eq!(a.checked_mul_div(a, a), (Some(a), Uint::zero())); + assert_eq!(b.checked_mul_div(c, b), (Some(c), Uint::zero())); + assert_eq!(a.checked_mul_div(c, b), (Some(d), c)); + assert_eq!(a.checked_mul_div(e, e), (Some(a), Uint::zero())); + assert_eq!(e.checked_mul_div(c, b), (Some(Uint::zero()), c)); + assert_eq!(d.checked_mul_div(a, e), (None, Uint::zero())); + } }