From 551504aa572d3f9d56b5576d25ce1211296ee488 Mon Sep 17 00:00:00 2001 From: kevaundray Date: Tue, 12 Sep 2023 14:29:50 +0100 Subject: [PATCH] fix: Implements handling of the high limb during fixed base scalar multiplication (#535) Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> Co-authored-by: Tom French --- Cargo.lock | 2 + acir/tests/test_program_serialization.rs | 18 +++---- acvm_js/test/shared/fixed_base_scalar_mul.ts | 13 +++-- blackbox_solver/Cargo.toml | 9 ++-- blackbox_solver/src/barretenberg/wasm/mod.rs | 4 ++ .../src/barretenberg/wasm/scalar_mul.rs | 49 ++++++++++++++++++- 6 files changed, 76 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8bf7a0e8..814ecabd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,8 +56,10 @@ dependencies = [ "blake2", "flate2", "getrandom", + "hex", "js-sys", "k256", + "num-bigint", "p256", "pkg-config", "reqwest", diff --git a/acir/tests/test_program_serialization.rs b/acir/tests/test_program_serialization.rs index e8bd066e..5aa51237 100644 --- a/acir/tests/test_program_serialization.rs +++ b/acir/tests/test_program_serialization.rs @@ -60,16 +60,16 @@ fn addition_circuit() { #[test] fn fixed_base_scalar_mul_circuit() { let fixed_base_scalar_mul = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::FixedBaseScalarMul { - low: FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() }, - high: FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() }, - outputs: (Witness(2), Witness(3)), + low: FunctionInput { witness: Witness(1), num_bits: 128 }, + high: FunctionInput { witness: Witness(2), num_bits: 128 }, + outputs: (Witness(3), Witness(4)), }); let circuit = Circuit { - current_witness_index: 4, + current_witness_index: 5, opcodes: vec![fixed_base_scalar_mul], - private_parameters: BTreeSet::from([Witness(1)]), - return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(3)])), + private_parameters: BTreeSet::from([Witness(1), Witness(2)]), + return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(3), Witness(4)])), ..Circuit::default() }; @@ -77,9 +77,9 @@ fn fixed_base_scalar_mul_circuit() { circuit.write(&mut bytes).unwrap(); let expected_serialization: Vec = vec![ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 202, 65, 10, 0, 64, 8, 2, 64, 183, 246, 212, 255, - 223, 27, 21, 21, 72, 130, 12, 136, 31, 192, 67, 167, 180, 209, 73, 201, 234, 249, 109, 132, - 84, 218, 3, 23, 46, 165, 61, 88, 0, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 77, 138, 91, 10, 0, 48, 12, 194, 178, 215, 207, 78, 189, + 163, 175, 165, 10, 21, 36, 10, 57, 192, 160, 146, 188, 226, 139, 78, 113, 69, 183, 190, 61, + 111, 218, 182, 231, 124, 68, 185, 243, 207, 92, 0, 0, 0, ]; assert_eq!(bytes, expected_serialization) diff --git a/acvm_js/test/shared/fixed_base_scalar_mul.ts b/acvm_js/test/shared/fixed_base_scalar_mul.ts index a1fd36d7..4240b424 100644 --- a/acvm_js/test/shared/fixed_base_scalar_mul.ts +++ b/acvm_js/test/shared/fixed_base_scalar_mul.ts @@ -1,15 +1,18 @@ // See `fixed_base_scalar_mul_circuit` integration test in `acir/tests/test_program_serialization.rs`. export const bytecode = Uint8Array.from([ - 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 202, 65, 10, 0, 64, 8, 2, 64, 183, 246, - 212, 255, 223, 27, 21, 21, 72, 130, 12, 136, 31, 192, 67, 167, 180, 209, 73, - 201, 234, 249, 109, 132, 84, 218, 3, 23, 46, 165, 61, 88, 0, 0, 0, + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 77, 138, 91, 10, 0, 48, 12, 194, 178, 215, + 207, 78, 189, 163, 175, 165, 10, 21, 36, 10, 57, 192, 160, 146, 188, 226, 139, + 78, 113, 69, 183, 190, 61, 111, 218, 182, 231, 124, 68, 185, 243, 207, 92, 0, + 0, 0, ]); export const initialWitnessMap = new Map([ [1, "0x0000000000000000000000000000000000000000000000000000000000000001"], + [2, "0x0000000000000000000000000000000000000000000000000000000000000000"], ]); export const expectedWitnessMap = new Map([ [1, "0x0000000000000000000000000000000000000000000000000000000000000001"], - [2, "0x0000000000000000000000000000000000000000000000000000000000000001"], - [3, "0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c"], + [2, "0x0000000000000000000000000000000000000000000000000000000000000000"], + [3, "0x0000000000000000000000000000000000000000000000000000000000000001"], + [4, "0x0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c"], ]); diff --git a/blackbox_solver/Cargo.toml b/blackbox_solver/Cargo.toml index 275c4dd9..631152d1 100644 --- a/blackbox_solver/Cargo.toml +++ b/blackbox_solver/Cargo.toml @@ -31,6 +31,8 @@ p256 = { version = "0.11.0", features = [ "digest", "arithmetic", ] } +hex = "*" +num-bigint.workspace = true # Barretenberg WASM dependencies rust-embed = { version = "6.6.0", features = [ @@ -38,10 +40,11 @@ rust-embed = { version = "6.6.0", features = [ "interpolate-folder-path", "include-exclude", ] } - [target.'cfg(target_arch = "wasm32")'.dependencies] -wasmer = { version = "3.3", default-features = false, features = [ "js-default" ] } -getrandom = { version = "0.2", features = [ "js" ]} +wasmer = { version = "3.3", default-features = false, features = [ + "js-default", +] } +getrandom = { version = "0.2", features = ["js"] } wasm-bindgen-futures = "0.4.36" js-sys = "0.3.62" diff --git a/blackbox_solver/src/barretenberg/wasm/mod.rs b/blackbox_solver/src/barretenberg/wasm/mod.rs index 995ef673..03d9712d 100644 --- a/blackbox_solver/src/barretenberg/wasm/mod.rs +++ b/blackbox_solver/src/barretenberg/wasm/mod.rs @@ -34,6 +34,10 @@ pub(crate) enum FeatureError { NoValue, #[error("Value expected to be i32")] InvalidI32, + #[error("Value {scalar_as_hex} is not a valid grumpkin scalar")] + InvalidGrumpkinScalar { scalar_as_hex: String }, + #[error("Limb {limb_as_hex} is not less than 2^128")] + InvalidGrumpkinScalarLimb { limb_as_hex: String }, #[error("Could not convert value {value} from i32 to u32")] InvalidU32 { value: i32, source: std::num::TryFromIntError }, #[error("Could not convert value {value} from i32 to usize")] diff --git a/blackbox_solver/src/barretenberg/wasm/scalar_mul.rs b/blackbox_solver/src/barretenberg/wasm/scalar_mul.rs index 008eaa1d..aa333c31 100644 --- a/blackbox_solver/src/barretenberg/wasm/scalar_mul.rs +++ b/blackbox_solver/src/barretenberg/wasm/scalar_mul.rs @@ -1,4 +1,7 @@ use acir::FieldElement; +use num_bigint::BigUint; + +use crate::barretenberg::wasm::FeatureError; use super::{Barretenberg, Error, FIELD_BYTES}; @@ -14,12 +17,40 @@ impl ScalarMul for Barretenberg { fn fixed_base( &self, low: &FieldElement, - _high: &FieldElement, + high: &FieldElement, ) -> Result<(FieldElement, FieldElement), Error> { let lhs_ptr: usize = 0; let result_ptr: usize = lhs_ptr + FIELD_BYTES; - self.transfer_to_heap(&low.to_be_bytes(), lhs_ptr); + let low: u128 = low.try_into_u128().ok_or_else(|| { + Error::FromFeature(FeatureError::InvalidGrumpkinScalarLimb { + limb_as_hex: low.to_hex(), + }) + })?; + + let high: u128 = high.try_into_u128().ok_or_else(|| { + Error::FromFeature(FeatureError::InvalidGrumpkinScalarLimb { + limb_as_hex: high.to_hex(), + }) + })?; + + let mut bytes = high.to_be_bytes().to_vec(); + bytes.extend_from_slice(&low.to_be_bytes()); + + // Check if this is smaller than the grumpkin modulus + let grumpkin_integer = BigUint::from_bytes_be(&bytes); + let grumpkin_modulus = BigUint::from_bytes_be(&[ + 48, 100, 78, 114, 225, 49, 160, 41, 184, 80, 69, 182, 129, 129, 88, 93, 151, 129, 106, + 145, 104, 113, 202, 141, 60, 32, 140, 22, 216, 124, 253, 71, + ]); + + if grumpkin_integer >= grumpkin_modulus { + return Err(Error::FromFeature(FeatureError::InvalidGrumpkinScalar { + scalar_as_hex: hex::encode(grumpkin_integer.to_bytes_be()), + })); + } + + self.transfer_to_heap(&bytes, lhs_ptr); self.call_multiple("compute_public_key", vec![&lhs_ptr.into(), &result_ptr.into()])?; let result_bytes: [u8; 2 * FIELD_BYTES] = self.read_memory(result_ptr); @@ -46,6 +77,20 @@ mod test { let x = "0000000000000000000000000000000000000000000000000000000000000001"; let y = "0000000000000002cf135e7506a45d632d270d45f1181294833fc48d823f272c"; + assert_eq!(x, res.0.to_hex()); + assert_eq!(y, res.1.to_hex()); + Ok(()) + } + #[test] + fn low_high_smoke_test() -> Result<(), Error> { + let barretenberg = Barretenberg::new(); + let low = FieldElement::one(); + let high = FieldElement::from(2u128); + + let res = barretenberg.fixed_base(&low, &high)?; + let x = "0702ab9c7038eeecc179b4f209991bcb68c7cb05bf4c532d804ccac36199c9a9"; + let y = "23f10e9e43a3ae8d75d24154e796aae12ae7af546716e8f81a2564f1b5814130"; + assert_eq!(x, res.0.to_hex()); assert_eq!(y, res.1.to_hex()); Ok(())