From 2247814f951f5d33257cd123a3bdcba857c9b167 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Tue, 5 Nov 2024 10:31:24 +0100 Subject: [PATCH] fix: right shift is not a regular division (#6400) --- .../src/brillig/brillig_gen/brillig_block.rs | 156 +++++++----------- .../noirc_evaluator/src/brillig/brillig_ir.rs | 82 +++++++++ .../src/ssa/ir/instruction/binary.rs | 10 +- .../src/ssa/opt/remove_bit_shifts.rs | 33 ++-- .../bit_shifts_comptime/src/main.nr | 4 + .../bit_shifts_runtime/Prover.toml | 1 + .../bit_shifts_runtime/src/main.nr | 4 +- 7 files changed, 175 insertions(+), 115 deletions(-) diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index 23c802e3a14..0e9ebd8900d 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -2,6 +2,7 @@ use crate::brillig::brillig_ir::artifact::Label; use crate::brillig::brillig_ir::brillig_variable::{ type_to_heap_value_type, BrilligArray, BrilligVariable, SingleAddrVariable, }; + use crate::brillig::brillig_ir::registers::Stack; use crate::brillig::brillig_ir::{ BrilligBinaryOp, BrilligContext, ReservedRegisters, BRILLIG_MEMORY_ADDRESSING_BIT_SIZE, @@ -1202,7 +1203,7 @@ impl<'block> BrilligBlock<'block> { let brillig_binary_op = match binary.operator { BinaryOp::Div => { if is_signed { - self.convert_signed_division(left, right, result_variable); + self.brillig_context.convert_signed_division(left, right, result_variable); return; } else if is_field { BrilligBinaryOp::FieldDiv @@ -1234,7 +1235,14 @@ impl<'block> BrilligBlock<'block> { BinaryOp::Or => BrilligBinaryOp::Or, BinaryOp::Xor => BrilligBinaryOp::Xor, BinaryOp::Shl => BrilligBinaryOp::Shl, - BinaryOp::Shr => BrilligBinaryOp::Shr, + BinaryOp::Shr => { + if is_signed { + self.convert_signed_shr(left, right, result_variable); + return; + } else { + BrilligBinaryOp::Shr + } + } }; self.brillig_context.binary_instruction(left, right, result_variable, brillig_binary_op); @@ -1250,98 +1258,6 @@ impl<'block> BrilligBlock<'block> { ); } - /// Splits a two's complement signed integer in the sign bit and the absolute value. - /// For example, -6 i8 (11111010) is split to 00000110 (6, absolute value) and 1 (is_negative). - fn absolute_value( - &mut self, - num: SingleAddrVariable, - absolute_value: SingleAddrVariable, - result_is_negative: SingleAddrVariable, - ) { - let max_positive = self - .brillig_context - .make_constant_instruction(((1_u128 << (num.bit_size - 1)) - 1).into(), num.bit_size); - - // Compute if num is negative - self.brillig_context.binary_instruction( - max_positive, - num, - result_is_negative, - BrilligBinaryOp::LessThan, - ); - - // Two's complement of num - let zero = self.brillig_context.make_constant_instruction(0_usize.into(), num.bit_size); - let twos_complement = - SingleAddrVariable::new(self.brillig_context.allocate_register(), num.bit_size); - self.brillig_context.binary_instruction(zero, num, twos_complement, BrilligBinaryOp::Sub); - - // absolute_value = result_is_negative ? twos_complement : num - self.brillig_context.codegen_branch(result_is_negative.address, |ctx, is_negative| { - if is_negative { - ctx.mov_instruction(absolute_value.address, twos_complement.address); - } else { - ctx.mov_instruction(absolute_value.address, num.address); - } - }); - - self.brillig_context.deallocate_single_addr(zero); - self.brillig_context.deallocate_single_addr(max_positive); - self.brillig_context.deallocate_single_addr(twos_complement); - } - - fn convert_signed_division( - &mut self, - left: SingleAddrVariable, - right: SingleAddrVariable, - result: SingleAddrVariable, - ) { - let left_is_negative = SingleAddrVariable::new(self.brillig_context.allocate_register(), 1); - let left_abs_value = - SingleAddrVariable::new(self.brillig_context.allocate_register(), left.bit_size); - - let right_is_negative = - SingleAddrVariable::new(self.brillig_context.allocate_register(), 1); - let right_abs_value = - SingleAddrVariable::new(self.brillig_context.allocate_register(), right.bit_size); - - let result_is_negative = - SingleAddrVariable::new(self.brillig_context.allocate_register(), 1); - - // Compute both absolute values - self.absolute_value(left, left_abs_value, left_is_negative); - self.absolute_value(right, right_abs_value, right_is_negative); - - // Perform the division on the absolute values - self.brillig_context.binary_instruction( - left_abs_value, - right_abs_value, - result, - BrilligBinaryOp::UnsignedDiv, - ); - - // Compute result sign - self.brillig_context.binary_instruction( - left_is_negative, - right_is_negative, - result_is_negative, - BrilligBinaryOp::Xor, - ); - - // If result has to be negative, perform two's complement - self.brillig_context.codegen_if(result_is_negative.address, |ctx| { - let zero = ctx.make_constant_instruction(0_usize.into(), result.bit_size); - ctx.binary_instruction(zero, result, result, BrilligBinaryOp::Sub); - ctx.deallocate_single_addr(zero); - }); - - self.brillig_context.deallocate_single_addr(left_is_negative); - self.brillig_context.deallocate_single_addr(left_abs_value); - self.brillig_context.deallocate_single_addr(right_is_negative); - self.brillig_context.deallocate_single_addr(right_abs_value); - self.brillig_context.deallocate_single_addr(result_is_negative); - } - fn convert_signed_modulo( &mut self, left: SingleAddrVariable, @@ -1354,7 +1270,7 @@ impl<'block> BrilligBlock<'block> { SingleAddrVariable::new(self.brillig_context.allocate_register(), left.bit_size); // i = left / right - self.convert_signed_division(left, right, scratch_var_i); + self.brillig_context.convert_signed_division(left, right, scratch_var_i); // j = i * right self.brillig_context.binary_instruction( @@ -1401,6 +1317,56 @@ impl<'block> BrilligBlock<'block> { self.brillig_context.deallocate_single_addr(bias); } + fn convert_signed_shr( + &mut self, + left: SingleAddrVariable, + right: SingleAddrVariable, + result: SingleAddrVariable, + ) { + // Check if left is negative + let left_is_negative = SingleAddrVariable::new(self.brillig_context.allocate_register(), 1); + let max_positive = self + .brillig_context + .make_constant_instruction(((1_u128 << (left.bit_size - 1)) - 1).into(), left.bit_size); + self.brillig_context.binary_instruction( + max_positive, + left, + left_is_negative, + BrilligBinaryOp::LessThan, + ); + + self.brillig_context.codegen_branch(left_is_negative.address, |ctx, is_negative| { + if is_negative { + let one = ctx.make_constant_instruction(1_u128.into(), left.bit_size); + + // computes 2^right + let two = ctx.make_constant_instruction(2_u128.into(), left.bit_size); + let two_pow = ctx.make_constant_instruction(1_u128.into(), left.bit_size); + let right_u32 = SingleAddrVariable::new(ctx.allocate_register(), 32); + ctx.cast(right_u32, right); + let pow_body = |ctx: &mut BrilligContext<_, _>, _: SingleAddrVariable| { + ctx.binary_instruction(two_pow, two, two_pow, BrilligBinaryOp::Mul); + }; + ctx.codegen_for_loop(None, right_u32.address, None, pow_body); + + // Right shift using division on 1-complement + ctx.binary_instruction(left, one, result, BrilligBinaryOp::Add); + ctx.convert_signed_division(result, two_pow, result); + ctx.binary_instruction(result, one, result, BrilligBinaryOp::Sub); + + // Clean-up + ctx.deallocate_single_addr(one); + ctx.deallocate_single_addr(two); + ctx.deallocate_single_addr(two_pow); + ctx.deallocate_single_addr(right_u32); + } else { + ctx.binary_instruction(left, right, result, BrilligBinaryOp::Shr); + } + }); + + self.brillig_context.deallocate_single_addr(left_is_negative); + } + #[allow(clippy::too_many_arguments)] fn add_overflow_check( &mut self, diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index 3633700a795..66d51b2accc 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -25,6 +25,7 @@ mod entry_point; mod instructions; use artifact::Label; +use brillig_variable::SingleAddrVariable; pub(crate) use instructions::BrilligBinaryOp; use registers::{RegisterAllocator, ScratchSpace}; @@ -109,6 +110,87 @@ impl BrilligContext { can_call_procedures: true, } } + + /// Splits a two's complement signed integer in the sign bit and the absolute value. + /// For example, -6 i8 (11111010) is split to 00000110 (6, absolute value) and 1 (is_negative). + pub(crate) fn absolute_value( + &mut self, + num: SingleAddrVariable, + absolute_value: SingleAddrVariable, + result_is_negative: SingleAddrVariable, + ) { + let max_positive = self + .make_constant_instruction(((1_u128 << (num.bit_size - 1)) - 1).into(), num.bit_size); + + // Compute if num is negative + self.binary_instruction(max_positive, num, result_is_negative, BrilligBinaryOp::LessThan); + + // Two's complement of num + let zero = self.make_constant_instruction(0_usize.into(), num.bit_size); + let twos_complement = SingleAddrVariable::new(self.allocate_register(), num.bit_size); + self.binary_instruction(zero, num, twos_complement, BrilligBinaryOp::Sub); + + // absolute_value = result_is_negative ? twos_complement : num + self.codegen_branch(result_is_negative.address, |ctx, is_negative| { + if is_negative { + ctx.mov_instruction(absolute_value.address, twos_complement.address); + } else { + ctx.mov_instruction(absolute_value.address, num.address); + } + }); + + self.deallocate_single_addr(zero); + self.deallocate_single_addr(max_positive); + self.deallocate_single_addr(twos_complement); + } + + pub(crate) fn convert_signed_division( + &mut self, + left: SingleAddrVariable, + right: SingleAddrVariable, + result: SingleAddrVariable, + ) { + let left_is_negative = SingleAddrVariable::new(self.allocate_register(), 1); + let left_abs_value = SingleAddrVariable::new(self.allocate_register(), left.bit_size); + + let right_is_negative = SingleAddrVariable::new(self.allocate_register(), 1); + let right_abs_value = SingleAddrVariable::new(self.allocate_register(), right.bit_size); + + let result_is_negative = SingleAddrVariable::new(self.allocate_register(), 1); + + // Compute both absolute values + self.absolute_value(left, left_abs_value, left_is_negative); + self.absolute_value(right, right_abs_value, right_is_negative); + + // Perform the division on the absolute values + self.binary_instruction( + left_abs_value, + right_abs_value, + result, + BrilligBinaryOp::UnsignedDiv, + ); + + // Compute result sign + self.binary_instruction( + left_is_negative, + right_is_negative, + result_is_negative, + BrilligBinaryOp::Xor, + ); + + // If result has to be negative, perform two's complement + self.codegen_if(result_is_negative.address, |ctx| { + let zero = ctx.make_constant_instruction(0_usize.into(), result.bit_size); + ctx.binary_instruction(zero, result, result, BrilligBinaryOp::Sub); + ctx.deallocate_single_addr(zero); + }); + + self.deallocate_single_addr(left_is_negative); + self.deallocate_single_addr(left_abs_value); + self.deallocate_single_addr(right_is_negative); + self.deallocate_single_addr(right_abs_value); + self.deallocate_single_addr(result_is_negative); + } } /// Special brillig context to codegen compiler intrinsic shared procedures diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs index 03262be0a06..487370488b9 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs @@ -281,15 +281,7 @@ impl Binary { let zero = dfg.make_constant(FieldElement::zero(), operand_type); return SimplifyResult::SimplifiedTo(zero); } - - // `two_pow_rhs` is limited to be at most `2 ^ {operand_bitsize - 1}` so it fits in `operand_type`. - let two_pow_rhs = FieldElement::from(2u128).pow(&rhs_const); - let two_pow_rhs = dfg.make_constant(two_pow_rhs, operand_type); - return SimplifyResult::SimplifiedToInstruction(Instruction::binary( - BinaryOp::Div, - self.lhs, - two_pow_rhs, - )); + return SimplifyResult::None; } } }; diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index 6f3f2fa14b7..cdbb1043232 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -145,6 +145,8 @@ impl Context<'_> { } /// Insert ssa instructions which computes lhs >> rhs by doing lhs/2^rhs + /// For negative signed integers, we do the division on the 1-complement representation of lhs, + /// before converting back the result to the 2-complement representation. pub(crate) fn insert_shift_right( &mut self, lhs: ValueId, @@ -153,16 +155,27 @@ impl Context<'_> { ) -> ValueId { let lhs_typ = self.function.dfg.type_of_value(lhs); let base = self.field_constant(FieldElement::from(2_u128)); - // we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value - let rhs_unsigned = self.insert_cast(rhs, Type::unsigned(bit_size)); - let pow = self.pow(base, rhs_unsigned); - // We need at least one more bit for the case where rhs == bit_size - let div_type = Type::unsigned(bit_size + 1); - let casted_lhs = self.insert_cast(lhs, div_type.clone()); - let casted_pow = self.insert_cast(pow, div_type); - let div_result = self.insert_binary(casted_lhs, BinaryOp::Div, casted_pow); - // We have to cast back to the original type - self.insert_cast(div_result, lhs_typ) + let pow = self.pow(base, rhs); + if lhs_typ.is_unsigned() { + // unsigned right bit shift is just a normal division + self.insert_binary(lhs, BinaryOp::Div, pow) + } else { + // Get the sign of the operand; positive signed operand will just do a division as well + let zero = self.numeric_constant(FieldElement::zero(), Type::signed(bit_size)); + let lhs_sign = self.insert_binary(lhs, BinaryOp::Lt, zero); + let lhs_sign_as_field = self.insert_cast(lhs_sign, Type::field()); + let lhs_as_field = self.insert_cast(lhs, Type::field()); + // For negative numbers, convert to 1-complement using wrapping addition of a + 1 + let one_complement = self.insert_binary(lhs_sign_as_field, BinaryOp::Add, lhs_as_field); + let one_complement = self.insert_truncate(one_complement, bit_size, bit_size + 1); + let one_complement = self.insert_cast(one_complement, Type::signed(bit_size)); + // Performs the division on the 1-complement (or the operand if positive) + let shifted_complement = self.insert_binary(one_complement, BinaryOp::Div, pow); + // Convert back to 2-complement representation if operand is negative + let lhs_sign_as_int = self.insert_cast(lhs_sign, lhs_typ); + let shifted = self.insert_binary(shifted_complement, BinaryOp::Sub, lhs_sign_as_int); + self.insert_truncate(shifted, bit_size, bit_size + 1) + } } /// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs diff --git a/test_programs/execution_success/bit_shifts_comptime/src/main.nr b/test_programs/execution_success/bit_shifts_comptime/src/main.nr index 6d9736b6abb..a11dae1c716 100644 --- a/test_programs/execution_success/bit_shifts_comptime/src/main.nr +++ b/test_programs/execution_success/bit_shifts_comptime/src/main.nr @@ -15,6 +15,10 @@ fn main(x: u64) { assert(x << 63 == 0); assert_eq((1 as u64) << 32, 0x0100000000); + + //regression for 6201 + let a: i16 = -769; + assert_eq(a >> 3, -97); } fn regression_2250() { diff --git a/test_programs/execution_success/bit_shifts_runtime/Prover.toml b/test_programs/execution_success/bit_shifts_runtime/Prover.toml index 67bf6a6a234..7b56d36c9d5 100644 --- a/test_programs/execution_success/bit_shifts_runtime/Prover.toml +++ b/test_programs/execution_success/bit_shifts_runtime/Prover.toml @@ -1,2 +1,3 @@ x = 64 y = 1 +z = "-769" diff --git a/test_programs/execution_success/bit_shifts_runtime/src/main.nr b/test_programs/execution_success/bit_shifts_runtime/src/main.nr index 059bbe84dac..370bb699048 100644 --- a/test_programs/execution_success/bit_shifts_runtime/src/main.nr +++ b/test_programs/execution_success/bit_shifts_runtime/src/main.nr @@ -1,4 +1,4 @@ -fn main(x: u64, y: u8) { +fn main(x: u64, y: u8, z: i16) { // runtime shifts on compile-time known values assert(64 << y == 128); assert(64 >> y == 32); @@ -17,4 +17,6 @@ fn main(x: u64, y: u8) { assert(a << y == -2); assert(x >> (x as u8) == 0); + + assert_eq(z >> 3, -97); }