Skip to content

Commit

Permalink
Use new types to validate input assumptions (#43)
Browse files Browse the repository at this point in the history
* feat: add new types `ProperUint` and `ProperCrtUint`

To guard around assumptions about big integer representations

* fix: remove unused `FixedAssignedCRTInteger`

* feat: use new types for bigint and field chips

New types now guard for different assumptions on non-native bigint
arithmetic. Distinguish between:
- Overflow CRT integers
- Proper BigUint with native part derived from limbs
- Field elements where inequality < modulus is checked

Also add type to help guard for inequality check in
ec_add_unequal_strict

Rust traits did not play so nicely with references, so I had to switch
many functions to move inputs instead of borrow by reference. However to
avoid writing `clone` everywhere, we allow conversion `From` reference
to the new type via cloning.

* feat: use `ProperUint` for `big_less_than`

* feat(ecc): add fns for assign private witness points

that constrain point to lie on curve

* fix: unnecessary lifetimes

* chore: remove clones
  • Loading branch information
jonathanpwang committed May 23, 2023
1 parent 01a8ac9 commit 8e9032c
Show file tree
Hide file tree
Showing 39 changed files with 1,958 additions and 1,645 deletions.
13 changes: 12 additions & 1 deletion halo2-base/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ use num_traits::{One, Zero};
pub trait BigPrimeField: ScalarField {
/// Converts a slice of [u64] to [BigPrimeField]
/// * `val`: the slice of u64
/// Assumes val.len() <= 4
///
/// # Assumptions
/// * `val` has the correct length for the implementation
/// * The integer value of `val` is already less than the modulus of `Self`
fn from_u64_digits(val: &[u64]) -> Self;
}
#[cfg(feature = "halo2-axiom")]
Expand Down Expand Up @@ -139,6 +142,9 @@ pub fn power_of_two<F: BigPrimeField>(n: usize) -> F {

/// Converts an immutable reference to [BigUint] to a [BigPrimeField].
/// * `e`: immutable reference to [BigUint]
///
/// # Assumptions:
/// * `e` is less than the modulus of `F`
pub fn biguint_to_fe<F: BigPrimeField>(e: &BigUint) -> F {
#[cfg(feature = "halo2-axiom")]
{
Expand All @@ -154,6 +160,9 @@ pub fn biguint_to_fe<F: BigPrimeField>(e: &BigUint) -> F {

/// Converts an immutable reference to [BigInt] to a [BigPrimeField].
/// * `e`: immutable reference to [BigInt]
///
/// # Assumptions:
/// * The absolute value of `e` is less than the modulus of `F`
pub fn bigint_to_fe<F: BigPrimeField>(e: &BigInt) -> F {
#[cfg(feature = "halo2-axiom")]
{
Expand Down Expand Up @@ -240,6 +249,8 @@ pub fn decompose_fe_to_u64_limbs<F: ScalarField>(
/// * `e`: immutable reference to [BigInt] to decompose
/// * `num_limbs`: number of limbs to decompose `e` into
/// * `bit_len`: number of bits in each limb
///
/// Truncates to `num_limbs` limbs if `e` is too large.
pub fn decompose_biguint<F: BigPrimeField>(
e: &BigUint,
num_limbs: usize,
Expand Down
2 changes: 1 addition & 1 deletion halo2-ecc/benches/fp_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ fn fp_mul_bench<F: PrimeField>(
let range = RangeChip::<F>::default(lookup_bits);
let chip = FpChip::<F, Fq>::new(&range, limb_bits, num_limbs);

let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, FpChip::<F, Fq>::fe_to_witness(&x)));
let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x));
for _ in 0..2857 {
chip.mul(ctx, &a, &b);
}
Expand Down
6 changes: 4 additions & 2 deletions halo2-ecc/benches/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ fn msm_bench(
let ctx = builder.main(0);
let scalars_assigned =
scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::<Vec<_>>();
let bases_assigned =
bases.iter().map(|base| ecc_chip.load_private(ctx, (base.x, base.y))).collect::<Vec<_>>();
let bases_assigned = bases
.iter()
.map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y)))
.collect::<Vec<_>>();

ecc_chip.variable_base_msm_in::<G1Affine>(
builder,
Expand Down
22 changes: 11 additions & 11 deletions halo2-ecc/src/bigint/add_no_carry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ use std::cmp::max;
pub fn assign<F: ScalarField>(
gate: &impl GateInstructions<F>,
ctx: &mut Context<F>,
a: &OverflowInteger<F>,
b: &OverflowInteger<F>,
a: OverflowInteger<F>,
b: OverflowInteger<F>,
) -> OverflowInteger<F> {
let out_limbs = a
.limbs
.iter()
.zip_eq(b.limbs.iter())
.map(|(&a_limb, &b_limb)| gate.add(ctx, a_limb, b_limb))
.into_iter()
.zip_eq(b.limbs)
.map(|(a_limb, b_limb)| gate.add(ctx, a_limb, b_limb))
.collect();

OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1)
OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1)
}

/// # Assumptions
Expand All @@ -27,11 +27,11 @@ pub fn assign<F: ScalarField>(
pub fn crt<F: ScalarField>(
gate: &impl GateInstructions<F>,
ctx: &mut Context<F>,
a: &CRTInteger<F>,
b: &CRTInteger<F>,
a: CRTInteger<F>,
b: CRTInteger<F>,
) -> CRTInteger<F> {
let out_trunc = assign::<F>(gate, ctx, &a.truncation, &b.truncation);
let out_trunc = assign(gate, ctx, a.truncation, b.truncation);
let out_native = gate.add(ctx, a.native, b.native);
let out_val = &a.value + &b.value;
CRTInteger::construct(out_trunc, out_native, out_val)
let out_val = a.value + b.value;
CRTInteger::new(out_trunc, out_native, out_val)
}
41 changes: 11 additions & 30 deletions halo2-ecc/src/bigint/big_is_equal.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::{CRTInteger, OverflowInteger};
use super::ProperUint;
use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context};
use itertools::Itertools;

/// Given OverflowInteger<F>'s `a` and `b` of the same shape,
/// Given [`ProperUint`]s `a` and `b` with the same number of limbs,
/// returns whether `a == b`.
///
/// # Assumptions:
Expand All @@ -11,38 +11,19 @@ use itertools::Itertools;
pub fn assign<F: ScalarField>(
gate: &impl GateInstructions<F>,
ctx: &mut Context<F>,
a: &OverflowInteger<F>,
b: &OverflowInteger<F>,
a: impl Into<ProperUint<F>>,
b: impl Into<ProperUint<F>>,
) -> AssignedValue<F> {
debug_assert!(!a.limbs.is_empty());
let a = a.into();
let b = b.into();
debug_assert!(!a.0.is_empty());

let mut a_limbs = a.limbs.iter();
let mut b_limbs = b.limbs.iter();
let mut partial = gate.is_equal(ctx, *a_limbs.next().unwrap(), *b_limbs.next().unwrap());
for (&a_limb, &b_limb) in a_limbs.zip_eq(b_limbs) {
let mut a_limbs = a.0.into_iter();
let mut b_limbs = b.0.into_iter();
let mut partial = gate.is_equal(ctx, a_limbs.next().unwrap(), b_limbs.next().unwrap());
for (a_limb, b_limb) in a_limbs.zip_eq(b_limbs) {
let eq_limb = gate.is_equal(ctx, a_limb, b_limb);
partial = gate.and(ctx, eq_limb, partial);
}
partial
}

pub fn wrapper<F: ScalarField>(
gate: &impl GateInstructions<F>,
ctx: &mut Context<F>,
a: &CRTInteger<F>,
b: &CRTInteger<F>,
) -> AssignedValue<F> {
assign(gate, ctx, &a.truncation, &b.truncation)
}

pub fn crt<F: ScalarField>(
gate: &impl GateInstructions<F>,
ctx: &mut Context<F>,
a: &CRTInteger<F>,
b: &CRTInteger<F>,
) -> AssignedValue<F> {
debug_assert_eq!(a.value, b.value);
let out_trunc = assign::<F>(gate, ctx, &a.truncation, &b.truncation);
let out_native = gate.is_equal(ctx, a.native, b.native);
gate.and(ctx, out_trunc, out_native)
}
34 changes: 19 additions & 15 deletions halo2-ecc/src/bigint/big_is_zero.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{CRTInteger, OverflowInteger};
use super::{OverflowInteger, ProperCrtUint, ProperUint};
use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context};

/// # Assumptions
Expand All @@ -8,42 +8,46 @@ use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Con
pub fn positive<F: ScalarField>(
gate: &impl GateInstructions<F>,
ctx: &mut Context<F>,
a: &OverflowInteger<F>,
a: OverflowInteger<F>,
) -> AssignedValue<F> {
let k = a.limbs.len();
debug_assert_ne!(k, 0);
debug_assert!(a.max_limb_bits as u32 + k.ilog2() < F::CAPACITY);
assert_ne!(k, 0);
assert!(a.max_limb_bits as u32 + k.ilog2() < F::CAPACITY);

let sum = gate.sum(ctx, a.limbs.iter().copied());
let sum = gate.sum(ctx, a.limbs);
gate.is_zero(ctx, sum)
}

/// Given OverflowInteger<F> `a`, returns whether `a == 0`
/// Given ProperUint<F> `a`, returns 1 iff every limb of `a` is zero. Returns 0 otherwise.
///
/// It is almost always more efficient to use [`positive`] instead.
///
/// # Assumptions
/// * `a` has nonzero number of limbs
pub fn assign<F: ScalarField>(
gate: &impl GateInstructions<F>,
ctx: &mut Context<F>,
a: &OverflowInteger<F>,
a: ProperUint<F>,
) -> AssignedValue<F> {
debug_assert!(!a.limbs.is_empty());
assert!(!a.0.is_empty());

let mut a_limbs = a.limbs.iter();
let mut partial = gate.is_zero(ctx, *a_limbs.next().unwrap());
for &a_limb in a_limbs {
let mut a_limbs = a.0.into_iter();
let mut partial = gate.is_zero(ctx, a_limbs.next().unwrap());
for a_limb in a_limbs {
let limb_is_zero = gate.is_zero(ctx, a_limb);
partial = gate.and(ctx, limb_is_zero, partial);
}
partial
}

/// Returns 0 or 1. Returns 1 iff the limbs of `a` are identically zero.
/// This just calls [`assign`] on the limbs.
///
/// It is almost always more efficient to use [`positive`] instead.
pub fn crt<F: ScalarField>(
gate: &impl GateInstructions<F>,
ctx: &mut Context<F>,
a: &CRTInteger<F>,
a: ProperCrtUint<F>,
) -> AssignedValue<F> {
let out_trunc = assign::<F>(gate, ctx, &a.truncation);
let out_native = gate.is_zero(ctx, a.native);
gate.and(ctx, out_trunc, out_native)
assign(gate, ctx, ProperUint(a.0.truncation.limbs))
}
8 changes: 4 additions & 4 deletions halo2-ecc/src/bigint/big_less_than.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use super::OverflowInteger;
use super::ProperUint;
use halo2_base::{gates::RangeInstructions, utils::ScalarField, AssignedValue, Context};

// given OverflowInteger<F>'s `a` and `b` of the same shape,
// returns whether `a < b`
pub fn assign<F: ScalarField>(
range: &impl RangeInstructions<F>,
ctx: &mut Context<F>,
a: &OverflowInteger<F>,
b: &OverflowInteger<F>,
a: impl Into<ProperUint<F>>,
b: impl Into<ProperUint<F>>,
limb_bits: usize,
limb_base: F,
) -> AssignedValue<F> {
// a < b iff a - b has underflow
let (_, underflow) = super::sub::assign::<F>(range, ctx, a, b, limb_bits, limb_base);
let (_, underflow) = super::sub::assign(range, ctx, a, b, limb_bits, limb_base);
underflow
}
41 changes: 19 additions & 22 deletions halo2-ecc/src/bigint/carry_mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{check_carry_to_zero, CRTInteger, OverflowInteger};
use std::{cmp::max, iter};

use halo2_base::{
gates::{range::RangeStrategy, GateInstructions, RangeInstructions},
utils::{decompose_bigint, BigPrimeField},
Expand All @@ -8,7 +9,8 @@ use halo2_base::{
use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::{One, Signed};
use std::{cmp::max, iter};

use super::{check_carry_to_zero, CRTInteger, OverflowInteger, ProperCrtUint, ProperUint};

// Input `a` is `CRTInteger` with `a.truncation` of length `k` with "signed" limbs
// Output is `out = a (mod modulus)` as CRTInteger with
Expand All @@ -29,15 +31,15 @@ pub fn crt<F: BigPrimeField>(
range: &impl RangeInstructions<F>,
// chip: &BigIntConfig<F>,
ctx: &mut Context<F>,
a: &CRTInteger<F>,
a: CRTInteger<F>,
k_bits: usize, // = a.len().bits()
modulus: &BigInt,
mod_vec: &[F],
mod_native: F,
limb_bits: usize,
limb_bases: &[F],
limb_base_big: &BigInt,
) -> CRTInteger<F> {
) -> ProperCrtUint<F> {
let n = limb_bits;
let k = a.truncation.limbs.len();
let trunc_len = n * k;
Expand Down Expand Up @@ -96,8 +98,8 @@ pub fn crt<F: BigPrimeField>(

// strategies where we carry out school-book multiplication in some form:
// BigIntStrategy::Simple => {
for (i, (a_limb, (quot_v, out_v))) in
a.truncation.limbs.iter().zip(quot_vec.into_iter().zip(out_vec.into_iter())).enumerate()
for (i, ((a_limb, quot_v), out_v)) in
a.truncation.limbs.into_iter().zip(quot_vec).zip(out_vec).enumerate()
{
let (prod, new_quot_cell) = range.gate().inner_product_left_last(
ctx,
Expand All @@ -120,7 +122,7 @@ pub fn crt<F: BigPrimeField>(
ctx.assign_region(
[
Constant(-F::one()),
Existing(*a_limb),
Existing(a_limb),
Witness(temp1),
Constant(F::one()),
Witness(out_v),
Expand Down Expand Up @@ -156,7 +158,7 @@ pub fn crt<F: BigPrimeField>(
range.range_check(ctx, quot_shift, limb_bits + 1);
}

let check_overflow_int = OverflowInteger::construct(
let check_overflow_int = OverflowInteger::new(
check_assigned,
max(max(limb_bits, a.truncation.max_limb_bits) + 1, 2 * n + k_bits),
);
Expand All @@ -172,21 +174,12 @@ pub fn crt<F: BigPrimeField>(
);

// Constrain `quot_native = sum_i quot_assigned[i] * 2^{n*i}` in `F`
let quot_native = OverflowInteger::<F>::evaluate(
range.gate(),
ctx,
quot_assigned,
limb_bases.iter().copied(),
);
let quot_native =
OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases);

// Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F`
let out_native = OverflowInteger::<F>::evaluate(
range.gate(),
ctx,
out_assigned.iter().copied(),
limb_bases.iter().copied(),
);

let out_native =
OverflowInteger::evaluate_native(ctx, range.gate(), out_assigned.clone(), limb_bases);
// We save 1 cell by connecting `out_native` computation with the following:

// Check `out + modulus * quotient - a = 0` in native field
Expand All @@ -196,5 +189,9 @@ pub fn crt<F: BigPrimeField>(
[-1], // negative index because -1 relative offset is `out_native` assigned value
);

CRTInteger::construct(OverflowInteger::construct(out_assigned, limb_bits), out_native, out_val)
ProperCrtUint(CRTInteger::new(
ProperUint(out_assigned).into_overflow(limb_bits),
out_native,
out_val,
))
}
Loading

0 comments on commit 8e9032c

Please sign in to comment.