Skip to content

Commit

Permalink
feat: optimize to_radix (#8073)
Browse files Browse the repository at this point in the history
- Change the ToRadix gadget/blackbox to emit u8 limbs instead of
fields
- Modify the toradix blackbox in brillig with an output_bits flag, to
emit u1 limbs
- No casting is needed in either case (u8 or u1) saving some emitted
brillig opcodes
- The AVM transpiler, then ignores the output_bits flag, since it'll
output u8s which is what the AVM expects for bits
  • Loading branch information
sirasistant authored Aug 26, 2024
1 parent 717cf3d commit 8baeffd
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 39 deletions.
3 changes: 2 additions & 1 deletion avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,8 @@ fn handle_black_box_function(avm_instrs: &mut Vec<AvmInstruction>, operation: &B
..Default::default()
});
}
BlackBoxOp::ToRadix { input, radix, output } => {
// We ignore the output bits flag since we represent bits as bytes
BlackBoxOp::ToRadix { input, radix, output, output_bits: _ } => {
let num_limbs = output.size as u32;
let input_offset = input.0 as u32;
let output_offset = output.pointer.0 as u32;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ struct BlackBoxOp {
Program::MemoryAddress input;
uint32_t radix;
Program::HeapArray output;
bool output_bits;

friend bool operator==(const ToRadix&, const ToRadix&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -5400,6 +5401,9 @@ inline bool operator==(const BlackBoxOp::ToRadix& lhs, const BlackBoxOp::ToRadix
if (!(lhs.output == rhs.output)) {
return false;
}
if (!(lhs.output_bits == rhs.output_bits)) {
return false;
}
return true;
}

Expand Down Expand Up @@ -5430,6 +5434,7 @@ void serde::Serializable<Program::BlackBoxOp::ToRadix>::serialize(const Program:
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.radix)>::serialize(obj.radix, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
serde::Serializable<decltype(obj.output_bits)>::serialize(obj.output_bits, serializer);
}

template <>
Expand All @@ -5441,6 +5446,7 @@ Program::BlackBoxOp::ToRadix serde::Deserializable<Program::BlackBoxOp::ToRadix>
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.radix = serde::Deserializable<decltype(obj.radix)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
obj.output_bits = serde::Deserializable<decltype(obj.output_bits)>::deserialize(deserializer);
return obj;
}

Expand Down
4 changes: 4 additions & 0 deletions noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ namespace Program {
Program::MemoryAddress input;
uint32_t radix;
Program::HeapArray output;
bool output_bits;

friend bool operator==(const ToRadix&, const ToRadix&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4529,6 +4530,7 @@ namespace Program {
if (!(lhs.input == rhs.input)) { return false; }
if (!(lhs.radix == rhs.radix)) { return false; }
if (!(lhs.output == rhs.output)) { return false; }
if (!(lhs.output_bits == rhs.output_bits)) { return false; }
return true;
}

Expand All @@ -4555,6 +4557,7 @@ void serde::Serializable<Program::BlackBoxOp::ToRadix>::serialize(const Program:
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.radix)>::serialize(obj.radix, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
serde::Serializable<decltype(obj.output_bits)>::serialize(obj.output_bits, serializer);
}

template <>
Expand All @@ -4564,6 +4567,7 @@ Program::BlackBoxOp::ToRadix serde::Deserializable<Program::BlackBoxOp::ToRadix>
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.radix = serde::Deserializable<decltype(obj.radix)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
obj.output_bits = serde::Deserializable<decltype(obj.output_bits)>::deserialize(deserializer);
return obj;
}

Expand Down
1 change: 1 addition & 0 deletions noir/noir-repo/acvm-repo/brillig/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,6 @@ pub enum BlackBoxOp {
input: MemoryAddress,
radix: u32,
output: HeapArray,
output_bits: bool,
},
}
15 changes: 12 additions & 3 deletions noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use acir::brillig::{BlackBoxOp, HeapArray, HeapVector};
use acir::brillig::{BlackBoxOp, HeapArray, HeapVector, IntegerBitSize};
use acir::{AcirField, BlackBoxFunc};
use acvm_blackbox_solver::BigIntSolver;
use acvm_blackbox_solver::{
aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256,
keccakf1600, sha256, sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError,
};
use num_bigint::BigUint;
use num_traits::Zero;

use crate::memory::MemoryValue;
use crate::Memory;
Expand Down Expand Up @@ -366,7 +367,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
memory.write_slice(memory.read_ref(output.pointer), &state);
Ok(())
}
BlackBoxOp::ToRadix { input, radix, output } => {
BlackBoxOp::ToRadix { input, radix, output, output_bits } => {
let input: F = *memory.read(*input).extract_field().expect("ToRadix input not a field");

let mut input = BigUint::from_bytes_be(&input.to_be_bytes());
Expand All @@ -376,7 +377,15 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>

for _ in 0..output.size {
let limb = &input % &radix;
limbs.push(MemoryValue::new_field(F::from_be_bytes_reduce(&limb.to_bytes_be())));
if *output_bits {
limbs.push(MemoryValue::new_integer(
if limb.is_zero() { 0 } else { 1 },
IntegerBitSize::U1,
));
} else {
let limb: u8 = limb.try_into().unwrap();
limbs.push(MemoryValue::new_integer(limb as u128, IntegerBitSize::U8));
};
input /= &radix;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ impl<'block> BrilligBlock<'block> {
radix,
limb_count,
matches!(endianness, Endian::Big),
8,
false,
);
}
Value::Intrinsic(Intrinsic::ToBits(endianness)) => {
Expand Down Expand Up @@ -595,7 +595,7 @@ impl<'block> BrilligBlock<'block> {
2,
limb_count,
matches!(endianness, Endian::Big),
1,
true,
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl<F: AcirField + DebugToString, Registers: RegisterAllocator> BrilligContext<
radix: u32,
limb_count: usize,
big_endian: bool,
limb_bit_size: u32,
output_bits: bool, // If true will generate bit limbs, if false will generate byte limbs
) {
assert!(source_field.bit_size == F::max_num_bits());

Expand All @@ -83,39 +83,9 @@ impl<F: AcirField + DebugToString, Registers: RegisterAllocator> BrilligContext<
input: source_field.address,
radix,
output: HeapArray { pointer: target_vector.pointer, size: limb_count },
output_bits,
});

if limb_bit_size != F::max_num_bits() {
let end_pointer = self.allocate_register();
let temporary_register = self.allocate_register();

self.memory_op_instruction(
target_vector.pointer,
target_vector.size,
end_pointer,
BrilligBinaryOp::Add,
);

self.codegen_for_loop(
Some(target_vector.pointer),
end_pointer,
None,
|ctx, item_pointer| {
ctx.load_instruction(temporary_register, item_pointer.address);

ctx.cast(
SingleAddrVariable::new(temporary_register, limb_bit_size),
SingleAddrVariable::new(temporary_register, F::max_num_bits()),
);

ctx.store_instruction(item_pointer.address, temporary_register);
},
);

self.deallocate_register(end_pointer);
self.deallocate_register(temporary_register);
}

if big_endian {
self.codegen_array_reverse(target_vector.pointer, target_vector.size);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ impl DebugShow {
output
);
}
BlackBoxOp::ToRadix { input, radix, output } => {
BlackBoxOp::ToRadix { input, radix, output, output_bits: _ } => {
debug_println!(
self.enable_debug_trace,
" TO_RADIX {} {} -> {}",
Expand Down

0 comments on commit 8baeffd

Please sign in to comment.