From 805a21cfa374db2de8d24ff30270b33f81397469 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 22 May 2023 13:45:15 -0500 Subject: [PATCH] feat: use strict ec ops more often (#45) * `msm` implementations now always use `ec_{add,sub}_unequal` in strict mode for safety * Add docs to `scalar_multiply` and a flag to specify when it's safe to turn off some strict assumptions --- halo2-ecc/src/ecc/ecdsa.rs | 2 +- halo2-ecc/src/ecc/mod.rs | 34 +++++++++++++++++++++++----------- halo2-ecc/src/ecc/pippenger.rs | 6 ++---- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index 9bdf8454..0f37a71a 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -57,6 +57,7 @@ where u2.limbs().to_vec(), base_chip.limb_bits, var_window_bits, + true, // we can call it with scalar_is_safe = true because of the u2_small check below ); // check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey @@ -77,7 +78,6 @@ where let x1 = scalar_chip.enforce_less_than(ctx, sum.x); let equal_check = big_is_equal::assign(base_chip.gate(), ctx, x1.0, r); - // TODO: maybe the big_less_than is optional? let u1_small = big_less_than::assign( base_chip.range(), ctx, diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index 0886db1c..8b3895f1 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -287,7 +287,7 @@ where // y_3 = lambda (x - x_3) - y (mod p) /// # Assumptions /// * `P.y != 0` -/// * `P` is not the point at infinity +/// * `P` is not the point at infinity (undefined behavior otherwise) pub fn ec_double>( chip: &FC, ctx: &mut Context, @@ -463,14 +463,15 @@ where StrictEcPoint::new(x, y) } -// computes [scalar] * P on short Weierstrass curve `y^2 = x^3 + b` -// - `scalar` is represented as a reference array of `AssignedValue`s -// - `scalar = sum_i scalar_i * 2^{max_bits * i}` -// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` -// assumes: +/// Computes `[scalar] * P` on short Weierstrass curve `y^2 = x^3 + b` +/// - `scalar` is represented as a reference array of `AssignedValue`s +/// - `scalar = sum_i scalar_i * 2^{max_bits * i}` +/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` +/// /// # Assumptions /// * `P` is not the point at infinity -/// * `scalar` is less than the order of `P` +/// * `scalar > 0` +/// * If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) /// * `scalar_i < 2^{max_bits} for all i` /// * `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` pub fn scalar_multiply( @@ -480,6 +481,7 @@ pub fn scalar_multiply( scalar: Vec>, max_bits: usize, window_bits: usize, + scalar_is_safe: bool, ) -> EcPoint where FC: FieldChip + Selectable, @@ -530,7 +532,7 @@ where let double = ec_double(chip, ctx, &P); cached_points.push(double); } else { - let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false); + let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, !scalar_is_safe); cached_points.push(new_point); } } @@ -555,7 +557,7 @@ where &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, false); + let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, !scalar_is_safe); let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]); curr_point = @@ -688,7 +690,7 @@ where ctx, &rand_start_vec[idx], &rand_start_vec[idx + window_bits], - false, + true, // not necessary if we assume (2^w - 1) * A != +- A, but put in for safety ); let point = into_strict_point(chip, ctx, point.clone()); let neg_mult_rand_start = into_strict_point(chip, ctx, neg_mult_rand_start); @@ -1002,6 +1004,7 @@ where ec_select(self.field_chip, ctx, P, Q, condition) } + /// See [`scalar_multiply`] for more details. pub fn scalar_mult( &self, ctx: &mut Context, @@ -1009,8 +1012,17 @@ where scalar: Vec>, max_bits: usize, window_bits: usize, + scalar_is_safe: bool, ) -> EcPoint { - scalar_multiply::(self.field_chip, ctx, P, scalar, max_bits, window_bits) + scalar_multiply::( + self.field_chip, + ctx, + P, + scalar, + max_bits, + window_bits, + scalar_is_safe, + ) } // default for most purposes diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 88d22868..58e7c739 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -213,7 +213,6 @@ where /// * `scalars[i].len() == scalars[j].len()` for all `i, j` /// * `points` are all on the curve or the point at infinity /// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) -/// * `2^max_scalar_bits != +-1 mod modulus::()` where `max_scalar_bits = max_scalar_bits_per_cell * scalars[0].len()` /// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point pub fn multi_exp_par( chip: &FC, @@ -337,7 +336,7 @@ where // let any_point = (2^num_rounds - 1) * any_base // TODO: can we remove all these random point operations somehow? let mut any_point = ec_double(chip, ctx, any_points.last().unwrap()); - any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], false); + any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], true); // compute sum_{k=0..scalar_bits} agg[k] * 2^k - (sum_{k=0..scalar_bits} 2^k) * rand_point // (sum_{k=0..scalar_bits} 2^k) = (2^scalar_bits - 1) @@ -351,8 +350,7 @@ where } any_sum = ec_double(chip, ctx, any_sum); - // assume 2^scalar_bits != +-1 mod modulus::() - any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false); + any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, true); ec_sub_strict(chip, ctx, sum, any_sum) }