Skip to content

Commit

Permalink
division that's constant-time in both arguments
Browse files Browse the repository at this point in the history
This makes `Uint::sqrt` constant-time as well.
  • Loading branch information
HastD committed Sep 20, 2023
1 parent bbce847 commit 5086d13
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
83 changes: 76 additions & 7 deletions src/uint/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,42 @@ impl<const LIMBS: usize> Uint<LIMBS> {
(quo, rem)
}

/// Computes `self` / `rhs`, returns the quotient (q), remainder (r)
/// and the truthy value for is_some or the falsy value for is_none.
///
/// NOTE: Use only if you need to access const fn. Otherwise use [`Self::div_rem`] because
/// the value for is_some needs to be checked before using `q` and `r`.
///
/// This function is constant-time with respect to both `self` and `rhs`.
pub(crate) const fn const_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {
let mb = rhs.bits();
let mut rem = *self;
let mut quo = Self::ZERO;
let mut c = rhs.shl(Self::BITS - mb);

let mut i = Self::BITS;
let mut done = CtChoice::FALSE;
loop {
let (mut r, borrow) = rem.sbb(&c, Limb::ZERO);
rem = Self::ct_select(&r, &rem, CtChoice::from_mask(borrow.0).or(done));
r = quo.bitor(&Self::ONE);
quo = Self::ct_select(&r, &quo, CtChoice::from_mask(borrow.0).or(done));
if i == 0 {
break;
}
i -= 1;
// when `i < mb`, the computation is actually done, so we ensure `quo` and `rem`
// aren't modified further (but do the remaining iterations anyway to be constant-time)
done = Limb::ct_lt(Limb(i as Word), Limb(mb as Word));
c = c.shr_vartime(1);
quo = Self::ct_select(&quo.shl_vartime(1), &quo, done);
}

let is_some = Limb(mb as Word).ct_is_nonzero();
quo = Self::ct_select(&Self::ZERO, &quo, is_some);
(quo, rem, is_some)
}

/// Computes `self` / `rhs`, returns the quotient (q), remainder (r)
/// and the truthy value for is_some or the falsy value for is_none.
///
Expand All @@ -51,7 +87,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
///
/// When used with a fixed `rhs`, this function is constant-time with respect
/// to `self`.
pub(crate) const fn ct_div_rem(&self, rhs: &Self) -> (Self, Self, CtChoice) {
pub(crate) const fn const_div_rem_vartime(&self, rhs: &Self) -> (Self, Self, CtChoice) {
let mb = rhs.bits_vartime();
let mut bd = Self::BITS - mb;
let mut rem = *self;
Expand Down Expand Up @@ -168,7 +204,14 @@ impl<const LIMBS: usize> Uint<LIMBS> {
/// Computes self / rhs, returns the quotient, remainder.
pub fn div_rem(&self, rhs: &NonZero<Self>) -> (Self, Self) {
// Since `rhs` is nonzero, this should always hold.
let (q, r, _c) = self.ct_div_rem(rhs);
let (q, r, _c) = self.const_div_rem(rhs);
(q, r)
}

/// Computes self / rhs, returns the quotient, remainder. Constant-time only for fixed `rhs`.
pub fn div_rem_vartime(&self, rhs: &NonZero<Self>) -> (Self, Self) {
// Since `rhs` is nonzero, this should always hold.
let (q, r, _c) = self.const_div_rem_vartime(rhs);
(q, r)
}

Expand All @@ -185,7 +228,18 @@ impl<const LIMBS: usize> Uint<LIMBS> {
///
/// Panics if `rhs == 0`.
pub const fn wrapping_div(&self, rhs: &Self) -> Self {
let (q, _, c) = self.ct_div_rem(rhs);
let (q, _, c) = self.const_div_rem(rhs);
assert!(c.is_true_vartime(), "divide by zero");
q
}

/// Wrapped division is just normal division i.e. `self` / `rhs`
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
///
/// Panics if `rhs == 0`. Constant-time only for fixed `rhs`.
pub const fn wrapping_div_vartime(&self, rhs: &Self) -> Self {
let (q, _, c) = self.const_div_rem_vartime(rhs);
assert!(c.is_true_vartime(), "divide by zero");
q
}
Expand Down Expand Up @@ -609,7 +663,11 @@ mod tests {
] {
let lhs = U256::from(*n);
let rhs = U256::from(*d);
let (q, r, is_some) = lhs.ct_div_rem(&rhs);
let (q, r, is_some) = lhs.const_div_rem(&rhs);
assert!(is_some.is_true_vartime());
assert_eq!(U256::from(*e), q);
assert_eq!(U256::from(*ee), r);
let (q, r, is_some) = lhs.const_div_rem_vartime(&rhs);
assert!(is_some.is_true_vartime());
assert_eq!(U256::from(*e), q);
assert_eq!(U256::from(*ee), r);
Expand All @@ -625,7 +683,10 @@ mod tests {
let den = U256::random(&mut rng).shr_vartime(128);
let n = num.checked_mul(&den);
if n.is_some().into() {
let (q, _, is_some) = n.unwrap().ct_div_rem(&den);
let (q, _, is_some) = n.unwrap().const_div_rem(&den);
assert!(is_some.is_true_vartime());
assert_eq!(q, num);
let (q, _, is_some) = n.unwrap().const_div_rem_vartime(&den);
assert!(is_some.is_true_vartime());
assert_eq!(q, num);
}
Expand All @@ -647,15 +708,23 @@ mod tests {

#[test]
fn div_zero() {
let (q, r, is_some) = U256::ONE.ct_div_rem(&U256::ZERO);
let (q, r, is_some) = U256::ONE.const_div_rem(&U256::ZERO);
assert!(!is_some.is_true_vartime());
assert_eq!(q, U256::ZERO);
assert_eq!(r, U256::ONE);
let (q, r, is_some) = U256::ONE.const_div_rem_vartime(&U256::ZERO);
assert!(!is_some.is_true_vartime());
assert_eq!(q, U256::ZERO);
assert_eq!(r, U256::ONE);
}

#[test]
fn div_one() {
let (q, r, is_some) = U256::from(10u8).ct_div_rem(&U256::ONE);
let (q, r, is_some) = U256::from(10u8).const_div_rem(&U256::ONE);
assert!(is_some.is_true_vartime());
assert_eq!(q, U256::from(10u8));
assert_eq!(r, U256::ZERO);
let (q, r, is_some) = U256::from(10u8).const_div_rem_vartime(&U256::ONE);
assert!(is_some.is_true_vartime());
assert_eq!(q, U256::from(10u8));
assert_eq!(r, U256::ZERO);
Expand Down
6 changes: 3 additions & 3 deletions src/uint/sqrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
while i < usize::BITS - Self::BITS.leading_zeros() {
guess = xn;
xn = {
let (q, _, is_some) = self.ct_div_rem(&guess);
let (q, _, is_some) = self.const_div_rem(&guess);
let q = Self::ct_select(&Self::ZERO, &q, is_some);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
Expand All @@ -45,7 +45,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let cap = Self::ONE.shl_vartime(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div(&guess);
let q = self.wrapping_div_vartime(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
Expand All @@ -55,7 +55,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
while Uint::ct_gt(&guess, &xn).is_true_vartime() && xn.ct_is_nonzero().is_true_vartime() {
guess = xn;
xn = {
let q = self.wrapping_div(&guess);
let q = self.wrapping_div_vartime(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
Expand Down
3 changes: 2 additions & 1 deletion tests/proptests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ proptest! {
if !b_bi.is_zero() {
let expected = to_uint(a_bi / b_bi);
let actual = a.wrapping_div(&b);

assert_eq!(expected, actual);
let actual_vartime = a.wrapping_div_vartime(&b);
assert_eq!(expected, actual_vartime);
}
}

Expand Down

0 comments on commit 5086d13

Please sign in to comment.