diff --git a/.gitignore b/.gitignore index 1de56593..f99b23da 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -target \ No newline at end of file +target +.vscode/launch.json diff --git a/src/fns/constrained_ops.nr b/src/fns/constrained_ops.nr index d457cf85..758e23dc 100644 --- a/src/fns/constrained_ops.nr +++ b/src/fns/constrained_ops.nr @@ -64,7 +64,8 @@ pub(crate) fn derive_from_seed(lhs: [Field; N], rhs: [ // a - b = r // p + a - b - r = 0 let (result, carry_flags, borrow_flags) = unsafe { __validate_gt_remainder(lhs, rhs) }; - validate_in_range::<_, MOD_BITS>(result); let borrow_shift = 0x1000000000000000000000000000000; @@ -278,12 +279,14 @@ pub(crate) fn validate_gt(lhs: [Field; N], rhs: [ + (borrow_flags[0] as Field * borrow_shift) - (carry_flags[0] as Field * carry_shift); assert(result_limb == 0); + for i in 1..N - 1 { let result_limb = lhs[i] - rhs[i] + addend[i] - result[i] - borrow_flags[i - 1] as Field + carry_flags[i - 1] as Field + ((borrow_flags[i] as Field - carry_flags[i] as Field) * borrow_shift); assert(result_limb == 0); } + let result_limb = lhs[N - 1] - rhs[N - 1] + addend[N - 1] - result[N - 1] - borrow_flags[N - 2] as Field diff --git a/src/fns/unconstrained_helpers.nr b/src/fns/unconstrained_helpers.nr index d19c2046..0809e4bf 100644 --- a/src/fns/unconstrained_helpers.nr +++ b/src/fns/unconstrained_helpers.nr @@ -68,7 +68,6 @@ pub(crate) unconstrained fn __validate_gt_remainder( borrow_flags[i / 2] = borrow as bool; } } - let result = U60Repr::into(result_u60); (result, carry_flags, borrow_flags) } @@ -208,12 +207,14 @@ pub(crate) unconstrained fn __barrett_reduction( modulus: [Field; N], modulus_u60: U60Repr, ) -> ([Field; N], [Field; N]) { + // for each i in 0..(N + N), adds x[i] * redc_param[j] to mulout[i + j] for each j in 0..N let mut mulout: [Field; 3 * N] = [0; 3 * N]; for i in 0..(N + N) { for j in 0..N { mulout[i + j] += x[i] * redc_param[j]; } } + mulout = split_bits::__normalize_limbs(mulout, 3 * N - 1); let mulout_u60: U60Repr = U60Repr::new(mulout); @@ -259,13 +260,22 @@ pub(crate) unconstrained fn __barrett_reduction( } let quotient_mul_modulus_u60: U60Repr = U60Repr::new(quotient_mul_modulus_normalized); + // convert the input into U60Repr let x_u60: U60Repr = U60Repr::new(x); let mut remainder_u60 = x_u60 - quotient_mul_modulus_u60; - + // barrett reduction is quircky so might need to remove a few modulus_u60 from the remainder if (remainder_u60.gte(modulus_u60)) { remainder_u60 = remainder_u60 - modulus_u60; quotient_u60.increment(); } else {} + if (remainder_u60.gte(modulus_u60)) { + remainder_u60 = remainder_u60 - modulus_u60; + quotient_u60.increment(); + } + if (remainder_u60.gte(modulus_u60)) { + remainder_u60 = remainder_u60 - modulus_u60; + quotient_u60.increment(); + } let q: [Field; N] = U60Repr::into(quotient_u60); let r: [Field; N] = U60Repr::into(remainder_u60); diff --git a/src/fns/unconstrained_ops.nr b/src/fns/unconstrained_ops.nr index f3d6ebf1..68e12df0 100644 --- a/src/fns/unconstrained_ops.nr +++ b/src/fns/unconstrained_ops.nr @@ -1,11 +1,11 @@ -use crate::utils::split_bits; -use crate::utils::u60_representation::U60Repr; - +use crate::fns::constrained_ops::derive_from_seed; use crate::fns::unconstrained_helpers::{ __barrett_reduction, __multiplicative_generator, __primitive_root_log_size, __tonelli_shanks_sqrt_inner_loop_check, }; use crate::params::BigNumParams as P; +use crate::utils::split_bits; +use crate::utils::u60_representation::U60Repr; /** * In this file: @@ -33,86 +33,24 @@ pub(crate) unconstrained fn __one() -> [Field; N] { limbs } +/// Deterministically derives a big_num from a seed value. +/// +/// Takes a seed byte array and generates a big_num in the range [0, modulus-1]. +/// +/// ## Value Parameters +/// +/// - `params`: The BigNum parameters containing modulus and reduction info +/// - `seed`: Input seed bytes to derive from. +/// +/// ## Returns +/// +/// An array of field elements derived from the seed (the limbs of the big_num) pub(crate) unconstrained fn __derive_from_seed( params: P, seed: [u8; SeedBytes], ) -> [Field; N] { - let mut rolling_hash_fields: [Field; (SeedBytes / 31) + 1] = [0; (SeedBytes / 31) + 1]; - let mut seed_ptr = 0; - for i in 0..(SeedBytes / 31) + 1 { - let mut packed: Field = 0; - for _ in 0..31 { - if (seed_ptr < SeedBytes) { - packed *= 256; - packed += seed[seed_ptr] as Field; - seed_ptr += 1; - } - } - rolling_hash_fields[i] = packed; - } - let compressed = - std::hash::poseidon2::Poseidon2::hash(rolling_hash_fields, (SeedBytes / 31) + 1); - let mut rolling_hash: [Field; 2] = [compressed, 0]; - - let mut to_reduce: [Field; 2 * N] = [0; 2 * N]; - - let mut double_modulus_bits = MOD_BITS * 2; - let mut double_modulus_bytes = - (double_modulus_bits) / 8 + (double_modulus_bits % 8 != 0) as u32; - - let mut last_limb_bytes = double_modulus_bytes % 15; - if (last_limb_bytes == 0) { - last_limb_bytes = 15; - } - let mut last_limb_bits = double_modulus_bits % 8; - if (last_limb_bits == 0) { - last_limb_bits = 8; - } - - for i in 0..(N - 1) { - let hash = std::hash::poseidon2::Poseidon2::hash(rolling_hash, 2); - let hash: [u8; 30] = hash.to_le_bytes(); - let mut lo: Field = 0; - let mut hi: Field = 0; - for j in 0..15 { - hi *= 256; - lo *= 256; - - if (i < 2 * N - 2) { - lo += hash[j + 15] as Field; - hi += hash[j] as Field; - } - } - to_reduce[2 * i] = lo; - to_reduce[2 * i + 1] = hi; - rolling_hash[1] += 1; - } - - { - let hash = std::hash::poseidon2::Poseidon2::hash(rolling_hash, 2); - let hash: [u8; 30] = hash.to_le_bytes(); - let mut hi: Field = 0; - for j in 0..(last_limb_bytes - 1) { - hi *= 256; - hi += hash[j] as Field; - } - hi *= 256; - let last_byte = hash[last_limb_bytes - 1]; - let mask = (1 as u64 << (last_limb_bits) as u8) - 1; - let last_bits = last_byte as u64 & mask; - hi += last_bits as Field; - to_reduce[2 * N - 2] = hi; - } - - let (_, remainder) = __barrett_reduction( - to_reduce, - params.redc_param, - MOD_BITS, - params.modulus, - params.modulus_u60_x4, - ); - let result = remainder; - result + let out = derive_from_seed::(params, seed); + out } pub(crate) unconstrained fn __eq(lhs: [Field; N], rhs: [Field; N]) -> bool { @@ -124,6 +62,7 @@ pub(crate) unconstrained fn __is_zero(limbs: [Field; N]) -> bool { for i in 0..N { result = result & (limbs[i] == 0); } + result } diff --git a/src/tests/bignum_test.nr b/src/tests/bignum_test.nr index a728e7e9..46b936c6 100644 --- a/src/tests/bignum_test.nr +++ b/src/tests/bignum_test.nr @@ -624,7 +624,6 @@ type U256 = BN256; fn test_udiv_mod_U256() { let a: U256 = unsafe { BigNum::__derive_from_seed([1, 2, 3, 4]) }; let b: U256 = BigNum::from_array([12, 0, 0]); - let (q, r) = a.udiv_mod(b); // let qb = q.__mul(b); diff --git a/src/tests/runtime_bignum_test.nr b/src/tests/runtime_bignum_test.nr index 36f86758..7c3c8f5e 100644 --- a/src/tests/runtime_bignum_test.nr +++ b/src/tests/runtime_bignum_test.nr @@ -209,24 +209,45 @@ impl BigNumParamsGetter<18, 2048> for Test2048Params { **/ comptime fn make_test(f: StructDefinition, N: Quoted, MOD_BITS: Quoted, typ: Quoted) -> Quoted { let k = f.name(); + let test_add = f"{typ}_{N}_{MOD_BITS}_test_add".quoted_contents(); + let test_sub = f"{typ}_{N}_{MOD_BITS}_test_sub".quoted_contents(); + let test_sub_modulus_limit = f"{typ}_{N}_{MOD_BITS}_test_sub_modulus_limit".quoted_contents(); + let test_sub_modulus_underflow = + f"{typ}_{N}_{MOD_BITS}_test_sub_modulus_underflow".quoted_contents(); + let test_add_modulus_limit = f"{typ}_{N}_{MOD_BITS}_test_add_modulus_limit".quoted_contents(); + let test_add_modulus_overflow = + f"{typ}_{N}_{MOD_BITS}_test_add_modulus_overflow".quoted_contents(); + let test_mul = f"{typ}_{N}_{MOD_BITS}_test_mul".quoted_contents(); + let test_quadratic_expression = + f"{typ}_{N}_{MOD_BITS}_test_quadratic_expression".quoted_contents(); + let assert_is_not_equal = f"{typ}_{N}_{MOD_BITS}_assert_is_not_equal".quoted_contents(); + let assert_is_not_equal_fail = + f"{typ}_{N}_{MOD_BITS}_assert_is_not_equal_fail".quoted_contents(); + let assert_is_not_equal_overloaded_lhs_fail = + f"{typ}_{N}_{MOD_BITS}_assert_is_not_equal_overloaded_lhs_fail".quoted_contents(); + let assert_is_not_equal_overloaded_rhs_fail = + f"{typ}_{N}_{MOD_BITS}_assert_is_not_equal_overloaded_rhs_fail".quoted_contents(); + let assert_is_not_equal_overloaded_fail = + f"{typ}_{N}_{MOD_BITS}_assert_is_not_equal_overloaded_fail".quoted_contents(); + let test_derive = f"{typ}_{N}_{MOD_BITS}_test_derive".quoted_contents(); + let test_eq = f"{typ}_{N}_{MOD_BITS}_test_eq".quoted_contents(); quote { impl $k { #[test] -fn test_add() { +fn $test_add() { let params = $typ ::get_params(); - + let a: RuntimeBigNum<$N, $MOD_BITS> = unsafe{ RuntimeBigNum::__derive_from_seed(params, [1, 2, 3, 4]) }; let b: RuntimeBigNum<$N, $MOD_BITS> = unsafe{ RuntimeBigNum::__derive_from_seed(params, [4, 5, 6, 7]) }; - let one: RuntimeBigNum<$N, $MOD_BITS> = RuntimeBigNum::one(params); - a.validate_in_range(); a.validate_in_field(); b.validate_in_range(); b.validate_in_field(); - let mut c = a + b; + c = c + c; + let d = (a + b) * (one + one); assert(c == d); @@ -238,7 +259,7 @@ fn test_add() { } #[test] -fn test_sub() { +fn $test_sub() { let params = $typ ::get_params(); // 0 - 1 should equal p - 1 @@ -254,7 +275,7 @@ fn test_sub() { #[test] -fn test_sub_modulus_limit() { +fn $test_sub_modulus_limit() { let params = $typ ::get_params(); // if we underflow, maximum result should be ... // 0 - 1 = o-1 @@ -269,7 +290,7 @@ fn test_sub_modulus_limit() { #[test(should_fail_with = "call to assert_max_bit_size")] -fn test_sub_modulus_underflow() { +fn $test_sub_modulus_underflow() { let params = $typ ::get_params(); // 0 - (p + 1) is smaller than p and should produce unsatisfiable constraints @@ -284,7 +305,7 @@ fn test_sub_modulus_underflow() { } #[test] -fn test_add_modulus_limit() { +fn $test_add_modulus_limit() { let params = $typ ::get_params(); // p + 2^{modulus_bits()} - 1 should be the maximum allowed value fed into an add operation @@ -298,30 +319,28 @@ fn test_add_modulus_limit() { let two_pow_modulus_bits_minus_one: U60Repr<$N, 2> = unsafe{ one.shl($MOD_BITS) - one }; let b: RuntimeBigNum<$N, $MOD_BITS> = RuntimeBigNum { limbs: U60Repr::into(two_pow_modulus_bits_minus_one), params }; - let result = a + b; assert(result == b); } #[test(should_fail_with = "call to assert_max_bit_size")] -fn test_add_modulus_overflow() { +fn $test_add_modulus_overflow() { let params = $typ ::get_params(); let p : U60Repr<$N, 2> = U60Repr::from(params.modulus); let one = unsafe{ U60Repr::one() }; let a: RuntimeBigNum<$N, $MOD_BITS> = RuntimeBigNum { limbs: U60Repr::into(p + one), params }; - + let mut two_pow_modulus_bits_minus_one: U60Repr<$N, 2> = unsafe{ one.shl($MOD_BITS) - one }; let b: RuntimeBigNum<$N, $MOD_BITS> = RuntimeBigNum { limbs: U60Repr::into(two_pow_modulus_bits_minus_one), params }; - let result = a + b; assert(result == b); } #[test] -fn test_mul() { +fn $test_mul() { let params = $typ ::get_params(); let a: RuntimeBigNum<$N, $MOD_BITS> = unsafe { @@ -337,7 +356,7 @@ fn test_mul() { } #[test] -fn test_quadratic_expression() { +fn $test_quadratic_expression() { let params = $typ ::get_params(); for i in 0..32 { @@ -385,7 +404,7 @@ fn test_quadratic_expression() { } #[test] -fn assert_is_not_equal() { +fn $assert_is_not_equal() { let params = $typ ::get_params(); let a: RuntimeBigNum<$N, $MOD_BITS> = unsafe { @@ -399,7 +418,7 @@ fn assert_is_not_equal() { } #[test(should_fail_with = "asssert_is_not_equal fail")] -fn assert_is_not_equal_fail() { +fn $assert_is_not_equal_fail() { let params = $typ ::get_params(); let a: RuntimeBigNum<$N, $MOD_BITS> = unsafe { @@ -413,7 +432,7 @@ fn assert_is_not_equal_fail() { } #[test(should_fail_with = "asssert_is_not_equal fail")] -fn assert_is_not_equal_overloaded_lhs_fail() { +fn $assert_is_not_equal_overloaded_lhs_fail() { let params = $typ ::get_params(); let a: RuntimeBigNum<$N, $MOD_BITS> = unsafe { @@ -433,7 +452,7 @@ fn assert_is_not_equal_overloaded_lhs_fail() { } #[test(should_fail_with = "asssert_is_not_equal fail")] -fn assert_is_not_equal_overloaded_rhs_fail() { +fn $assert_is_not_equal_overloaded_rhs_fail() { let params = $typ ::get_params(); let a: RuntimeBigNum<$N, $MOD_BITS> = unsafe { @@ -453,7 +472,7 @@ fn assert_is_not_equal_overloaded_rhs_fail() { } #[test(should_fail_with = "asssert_is_not_equal fail")] -fn assert_is_not_equal_overloaded_fail() { +fn $assert_is_not_equal_overloaded_fail() { let params = $typ ::get_params(); let a: RuntimeBigNum<$N, $MOD_BITS> = unsafe { @@ -476,19 +495,21 @@ fn assert_is_not_equal_overloaded_fail() { } #[test] -fn test_derive() +fn $test_derive() { let params = $typ ::get_params(); - let a: RuntimeBigNum<$N, $MOD_BITS> = RuntimeBigNum::derive_from_seed(params, "hello".as_bytes()); + // let a: RuntimeBigNum<$N, $MOD_BITS> = RuntimeBigNum::derive_from_seed(params, "hello".as_bytes()); + let a: RuntimeBigNum<$N, $MOD_BITS> = RuntimeBigNum::derive_from_seed(params, [1, 2, 3, 4]); let b: RuntimeBigNum<$N, $MOD_BITS> = unsafe { - RuntimeBigNum::__derive_from_seed(params, "hello".as_bytes()) + // RuntimeBigNum::__derive_from_seed(params, "hello".as_bytes()) + RuntimeBigNum::__derive_from_seed(params, [1, 2, 3, 4]) }; - assert(a == b); + assert_eq(a, b); } #[test] -fn test_eq() { +fn $test_eq() { let params = $typ ::get_params(); let a: RuntimeBigNum<$N, $MOD_BITS> = unsafe {