Skip to content

Commit

Permalink
feat: use strict ec ops more often (#45)
Browse files Browse the repository at this point in the history
* `msm` implementations now always use `ec_{add,sub}_unequal` in strict
mode for safety
* Add docs to `scalar_multiply` and a flag to specify when it's safe to
  turn off some strict assumptions
  • Loading branch information
jonathanpwang committed May 23, 2023
1 parent 2c276b4 commit 805a21c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 16 deletions.
2 changes: 1 addition & 1 deletion halo2-ecc/src/ecc/ecdsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ where
u2.limbs().to_vec(),
base_chip.limb_bits,
var_window_bits,
true, // we can call it with scalar_is_safe = true because of the u2_small check below
);

// check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey
Expand All @@ -77,7 +78,6 @@ where
let x1 = scalar_chip.enforce_less_than(ctx, sum.x);
let equal_check = big_is_equal::assign(base_chip.gate(), ctx, x1.0, r);

// TODO: maybe the big_less_than is optional?
let u1_small = big_less_than::assign(
base_chip.range(),
ctx,
Expand Down
34 changes: 23 additions & 11 deletions halo2-ecc/src/ecc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ where
// y_3 = lambda (x - x_3) - y (mod p)
/// # Assumptions
/// * `P.y != 0`
/// * `P` is not the point at infinity
/// * `P` is not the point at infinity (undefined behavior otherwise)
pub fn ec_double<F: PrimeField, FC: FieldChip<F>>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -463,14 +463,15 @@ where
StrictEcPoint::new(x, y)
}

// computes [scalar] * P on short Weierstrass curve `y^2 = x^3 + b`
// - `scalar` is represented as a reference array of `AssignedValue`s
// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
// assumes:
/// Computes `[scalar] * P` on short Weierstrass curve `y^2 = x^3 + b`
/// - `scalar` is represented as a reference array of `AssignedValue`s
/// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
///
/// # Assumptions
/// * `P` is not the point at infinity
/// * `scalar` is less than the order of `P`
/// * `scalar > 0`
/// * If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`)
/// * `scalar_i < 2^{max_bits} for all i`
/// * `max_bits <= modulus::<F>.bits()`, and equality only allowed when the order of `P` equals the modulus of `F`
pub fn scalar_multiply<F: PrimeField, FC>(
Expand All @@ -480,6 +481,7 @@ pub fn scalar_multiply<F: PrimeField, FC>(
scalar: Vec<AssignedValue<F>>,
max_bits: usize,
window_bits: usize,
scalar_is_safe: bool,
) -> EcPoint<F, FC::FieldPoint>
where
FC: FieldChip<F> + Selectable<F, FC::FieldPoint>,
Expand Down Expand Up @@ -530,7 +532,7 @@ where
let double = ec_double(chip, ctx, &P);
cached_points.push(double);
} else {
let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false);
let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, !scalar_is_safe);
cached_points.push(new_point);
}
}
Expand All @@ -555,7 +557,7 @@ where
&rounded_bits
[rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx],
);
let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, false);
let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, !scalar_is_safe);
let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]);

curr_point =
Expand Down Expand Up @@ -688,7 +690,7 @@ where
ctx,
&rand_start_vec[idx],
&rand_start_vec[idx + window_bits],
false,
true, // not necessary if we assume (2^w - 1) * A != +- A, but put in for safety
);
let point = into_strict_point(chip, ctx, point.clone());
let neg_mult_rand_start = into_strict_point(chip, ctx, neg_mult_rand_start);
Expand Down Expand Up @@ -1002,15 +1004,25 @@ where
ec_select(self.field_chip, ctx, P, Q, condition)
}

/// See [`scalar_multiply`] for more details.
pub fn scalar_mult(
&self,
ctx: &mut Context<F>,
P: EcPoint<F, FC::FieldPoint>,
scalar: Vec<AssignedValue<F>>,
max_bits: usize,
window_bits: usize,
scalar_is_safe: bool,
) -> EcPoint<F, FC::FieldPoint> {
scalar_multiply::<F, FC>(self.field_chip, ctx, P, scalar, max_bits, window_bits)
scalar_multiply::<F, FC>(
self.field_chip,
ctx,
P,
scalar,
max_bits,
window_bits,
scalar_is_safe,
)
}

// default for most purposes
Expand Down
6 changes: 2 additions & 4 deletions halo2-ecc/src/ecc/pippenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ where
/// * `scalars[i].len() == scalars[j].len()` for all `i, j`
/// * `points` are all on the curve or the point at infinity
/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point)
/// * `2^max_scalar_bits != +-1 mod modulus::<F>()` where `max_scalar_bits = max_scalar_bits_per_cell * scalars[0].len()`
/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point
pub fn multi_exp_par<F: PrimeField, FC, C>(
chip: &FC,
Expand Down Expand Up @@ -337,7 +336,7 @@ where
// let any_point = (2^num_rounds - 1) * any_base
// TODO: can we remove all these random point operations somehow?
let mut any_point = ec_double(chip, ctx, any_points.last().unwrap());
any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], false);
any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], true);

// compute sum_{k=0..scalar_bits} agg[k] * 2^k - (sum_{k=0..scalar_bits} 2^k) * rand_point
// (sum_{k=0..scalar_bits} 2^k) = (2^scalar_bits - 1)
Expand All @@ -351,8 +350,7 @@ where
}

any_sum = ec_double(chip, ctx, any_sum);
// assume 2^scalar_bits != +-1 mod modulus::<F>()
any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false);
any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, true);

ec_sub_strict(chip, ctx, sum, any_sum)
}

0 comments on commit 805a21c

Please sign in to comment.