Skip to content

Commit

Permalink
Merge reimplementation of cbrt into trunk
Browse files Browse the repository at this point in the history
  • Loading branch information
akubera committed Oct 28, 2024
2 parents 3c58a3e + ea53472 commit 1e96272
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 69 deletions.
176 changes: 112 additions & 64 deletions src/arithmetic/cbrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,85 +2,107 @@
use crate::*;
use num_bigint::BigUint;
use rounding::NonDigitRoundingData;
use stdlib::num::NonZeroU64;


/// implementation of cuberoot - always positive
pub(crate) fn impl_cbrt_uint_scale(n: &BigUint, scale: i64, ctx: &Context) -> BigDecimal {
// make guess based on number of bits in the number
let guess = make_cbrt_guess(n.bits(), scale);
pub(crate) fn impl_cbrt_int_scale(n: &BigInt, scale: i64, ctx: &Context) -> BigDecimal {
let rounding_data = NonDigitRoundingData {
sign: n.sign(),
mode: ctx.rounding_mode(),
};

let three = BigInt::from(3);
impl_cbrt_uint_scale((n.magnitude(), scale).into(), ctx.precision(), rounding_data)
}

let n = BigInt::from_biguint(Sign::Plus, n.clone());
/// implementation of cuberoot - always positive
pub(crate) fn impl_cbrt_uint_scale(
n: WithScale<&BigUint>,
precision: NonZeroU64,
// contains sign and rounding mode
rounding_data: NonDigitRoundingData,
) -> BigDecimal {
if n.is_zero() {
let biguint = BigInt::from_biguint(Sign::Plus, n.value.clone());
return BigDecimal::new(biguint, n.scale / 3);
}

let max_precision = ctx.precision().get();
// count number of digits in the decimal
let integer_digit_count = count_decimal_digits_uint(n.value);

let next_iteration = move |r: BigDecimal| {
let sqrd = r.square();
let tmp = impl_division(
n.clone(),
&sqrd.int_val,
scale - sqrd.scale,
max_precision + 1,
);
let tmp = tmp + r.double();
impl_division(tmp.int_val, &three, tmp.scale, max_precision + 3)
};
// extra digits to use for rounding
let extra_rounding_digit_count = 4;

// result initial
let mut running_result = next_iteration(guess);
// required number of digits for precision and rounding
let required_precision = precision.get() + extra_rounding_digit_count;
let required_precision = 3 * required_precision;

let mut prev_result = BigDecimal::one();
let mut result = BigDecimal::zero();
// number of extra zeros to add to end of integer_digits
let mut exp_shift = required_precision.saturating_sub(integer_digit_count as u64);

// TODO: Prove that we don't need to arbitrarily limit iterations
// and that convergence can be calculated
while prev_result != result {
// store current result to test for convergence
prev_result = result;
// effective scale after multiplying by 10^exp_shift
// (we've added that many trailing zeros after)
let shifted_scale = n.scale + exp_shift as i64;

running_result = next_iteration(running_result);
let (mut new_scale, remainder) = shifted_scale.div_rem(&3);

// result has clipped precision, running_result has full precision
result = if running_result.digits() > max_precision {
running_result.with_precision_round(ctx.precision(), ctx.rounding_mode())
} else {
running_result.clone()
};
if remainder > 0 {
new_scale += 1;
exp_shift += (3 - remainder) as u64;
} else if remainder < 0 {
exp_shift += remainder.neg() as u64;
}

return result;
}
// clone-on-write copy of digits
let mut integer_digits = stdlib::borrow::Cow::Borrowed(n.value);

/// Find initial cbrt guess based on number of bits in integer and the scale
///
/// ```math
/// 2^bit_count * 10^-scale <= *n* < 2^(bit_count+1) * 10^-scale
///
/// cbrt(n2^bit_count * 10^-scale)
/// cbrt(2^bit_count * 10^-scale)
/// => Exp2[1/3 * Log2[2^bit_count * 10^-scale]]
/// => Exp2[1/3 * (bit_count - scale * Log2[10]]
/// ```
///
fn make_cbrt_guess(bit_count: u64, scale: i64) -> BigDecimal {
// weight of cube root average above minimum within range: 3/4*2^(4/3)
let magic_guess_scale = 1.1398815748423097_f64;

let bit_count = bit_count as f64;
let scale = scale as f64;

let initial_guess = (bit_count - scale * LOG2_10) / 3.0;
let res = magic_guess_scale * exp2(initial_guess);

match BigDecimal::try_from(res) {
Ok(res) => res,
Err(_) => {
// can't guess with float - just guess magnitude
let scale = (scale - bit_count / LOG2_10).round() as i64;
BigDecimal::new(BigInt::from(1), scale / 3)
}
// add required trailing zeros to integer_digits
if exp_shift > 0 {
arithmetic::multiply_by_ten_to_the_uint(
integer_digits.to_mut(), exp_shift
);
}

let result_digits = integer_digits.nth_root(3);
let result_digits_count = count_decimal_digits_uint(&result_digits);
debug_assert!(result_digits_count >= precision.get() + 1);

let digits_to_trim = result_digits_count - precision.get();
debug_assert_ne!(digits_to_trim, 0);
debug_assert!((result_digits_count as i64 - count_decimal_digits_uint(&integer_digits) as i64 / 3).abs() < 2);

new_scale -= digits_to_trim as i64;

let divisor = ten_to_the_uint(digits_to_trim);
let (mut result_digits, remainder) = result_digits.div_rem(&divisor);

let remainder_digits = remainder.to_radix_le(10);
let insig_digit0;
let trailing_digits;
if remainder_digits.len() < digits_to_trim as usize {
// leading zeros
insig_digit0 = 0;
trailing_digits = remainder_digits.as_slice();
} else {
let (&d, rest) = remainder_digits.split_last().unwrap();
insig_digit0 = d;
trailing_digits = rest;
}

let insig_data = rounding::InsigData::from_digit_and_lazy_trailing_zeros(
rounding_data, insig_digit0, || { trailing_digits.iter().all(Zero::is_zero) }
);

// lowest digit to round
let sig_digit = (&result_digits % 10u8).to_u8().unwrap();
let rounded_digit = insig_data.round_digit(sig_digit);

let rounding_term = rounded_digit - sig_digit;
result_digits += rounding_term;

let result = BigInt::from_biguint(rounding_data.sign, result_digits);

BigDecimal::new(result, new_scale)
}


Expand Down Expand Up @@ -128,6 +150,32 @@ mod test {
impl_test!(case_prec15_down_10; prec=15; round=Down; "10" => "2.15443469003188");
impl_test!(case_prec6_up_0d979970546636727; prec=6; round=Up; "0.979970546636727" => "0.993279");

impl_test!(case_1037d495615705321421375_full; "1037.495615705321421375" => "10.123455");
impl_test!(case_1037d495615705321421375_prec7_halfdown; prec=7; round=HalfDown; "1037.495615705321421375" => "10.12345");
impl_test!(case_1037d495615705321421375_prec7_halfeven; prec=7; round=HalfEven; "1037.495615705321421375" => "10.12346");
impl_test!(case_1037d495615705321421375_prec7_halfup; prec=7; round=HalfUp; "1037.495615705321421375" => "10.12346");

impl_test!(case_0d014313506928855520728400001_full; "0.014313506928855520728400001" => "0.242800001");
impl_test!(case_0d014313506928855520728400001_prec6_down; prec=6; round=Down; "0.014313506928855520728400001" => "0.242800");
impl_test!(case_0d014313506928855520728400001_prec6_up; prec=6; round=Up; "0.014313506928855520728400001" => "0.242801");

impl_test!(case_4151902e20_prec16_halfup; prec=16; round=HalfUp; "4151902e20" => "746017527.6855992");
impl_test!(case_4151902e20_prec16_up; prec=16; round=Up; "4151902e20" => "746017527.6855993");
impl_test!(case_4151902e20_prec17_up; prec=17; round=Up; "4151902e20" => "746017527.68559921");
impl_test!(case_4151902e20_prec18_up; prec=18; round=Up; "4151902e20" => "746017527.685599209");
// impl_test!(case_4151902e20_prec18_up; prec=18; round=Up; "4151902e20" => "746017527.685599209");

impl_test!(case_1850846e201_prec14_up; prec=16; round=Up; "1850846e201" => "1.227788123885769e69");

impl_test!(case_6d3797558642427987505823530913e85_prec16_up; prec=160; round=Up; "6.3797558642427987505823530913E+85" => "3995778017e19");

impl_test!(case_88573536600476899341824_prec20_up; prec=20; round=Up; "88573536600476899341824" => "44576024");
impl_test!(case_88573536600476899341824_prec7_up; prec=7; round=Up; "88573536600476899341824" => "4457603e1");

impl_test!(case_833636d150970875_prec5_up; prec=5; round=Up; "833636.150970875" => "94.115000");
impl_test!(case_833636d150970875_prec5_halfup; prec=5; round=HalfUp; "833636.150970875" => "94.115");
impl_test!(case_833636d150970875_prec4_halfup; prec=4; round=HalfUp; "833636.150970875" => "94.12");
impl_test!(case_833636d150970875_prec20_up; prec=20; round=Up; "833636.150970875" => "94.115000");

#[cfg(property_tests)]
mod prop {
Expand Down
12 changes: 12 additions & 0 deletions src/arithmetic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ pub(crate) fn ten_to_the_uint(pow: u64) -> BigUint {
}
}

pub(crate) fn multiply_by_ten_to_the_uint<T, P>(n: &mut T, pow: P)
where T: MulAssign<u64> + MulAssign<BigUint>,
P: ToPrimitive
{
let pow = pow.to_u64().expect("exponent overflow error");
if pow < 20 {
*n *= 10u64.pow(pow as u32);
} else {
*n *= ten_to_the_uint(pow);
}

}

/// Return number of decimal digits in integer
pub(crate) fn count_decimal_digits(int: &BigInt) -> u64 {
Expand Down
66 changes: 61 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -761,11 +761,7 @@ impl BigDecimal {
return self.clone();
}

let uint = self.int_val.magnitude();
let result = arithmetic::cbrt::impl_cbrt_uint_scale(uint, self.scale, ctx);

// always copy sign
result.take_with_sign(self.sign())
arithmetic::cbrt::impl_cbrt_int_scale(&self.int_val, self.scale, ctx)
}

/// Compute the reciprical of the number: x<sup>-1</sup>
Expand Down Expand Up @@ -1221,6 +1217,19 @@ impl BigDecimalRef<'_> {
count_decimal_digits_uint(self.digits)
}

/// Return the number of trailing zeros in the referenced integer
#[allow(dead_code)]
fn count_trailing_zeroes(&self) -> usize {
if self.digits.is_zero() || self.digits.is_odd() {
return 0;
}

let digit_pairs = self.digits.to_radix_le(100);
let loc = digit_pairs.iter().position(|&d| d != 0).unwrap_or(0);

2 * loc + usize::from(digit_pairs[loc] % 10 == 0)
}

/// Split into components
pub(crate) fn as_parts(&self) -> (Sign, i64, &BigUint) {
(self.sign, self.scale, self.digits)
Expand Down Expand Up @@ -1310,6 +1319,53 @@ impl<'a> From<&'a BigInt> for BigDecimalRef<'a> {
}
}


/// pair i64 'scale' with some other value
#[derive(Clone, Copy)]
struct WithScale<T> {
pub value: T,
pub scale: i64,
}

impl<T: fmt::Debug> fmt::Debug for WithScale<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "(scale={} {:?})", self.scale, self.value)
}
}

impl<T> From<(T, i64)> for WithScale<T> {
fn from(pair: (T, i64)) -> Self {
Self { value: pair.0, scale: pair.1 }
}
}

impl<'a> From<WithScale<&'a BigInt>> for BigDecimalRef<'a> {
fn from(obj: WithScale<&'a BigInt>) -> Self {
Self {
scale: obj.scale,
sign: obj.value.sign(),
digits: obj.value.magnitude(),
}
}
}

impl<'a> From<WithScale<&'a BigUint>> for BigDecimalRef<'a> {
fn from(obj: WithScale<&'a BigUint>) -> Self {
Self {
scale: obj.scale,
sign: Sign::Plus,
digits: obj.value,
}
}
}

impl<T: Zero> WithScale<&T> {
fn is_zero(&self) -> bool {
self.value.is_zero()
}
}


#[rustfmt::skip]
#[cfg(test)]
#[allow(non_snake_case)]
Expand Down

0 comments on commit 1e96272

Please sign in to comment.