From 2421b59bebbea4a7703e85725faa5885391bf537 Mon Sep 17 00:00:00 2001 From: Tom French Date: Sat, 11 Jan 2025 20:23:47 +0000 Subject: [PATCH] feat: add `ConstrainNotEqual` instruction --- .../noirc_evaluator/src/acir/acir_variable.rs | 20 ++++++++ compiler/noirc_evaluator/src/acir/mod.rs | 41 +++++++++++++++ .../src/brillig/brillig_gen/brillig_block.rs | 4 ++ .../check_for_underconstrained_values.rs | 4 +- .../noirc_evaluator/src/ssa/ir/instruction.rs | 50 +++++++++++++++++-- .../noirc_evaluator/src/ssa/ir/printer.rs | 8 +++ .../src/ssa/opt/remove_enable_side_effects.rs | 1 + 7 files changed, 123 insertions(+), 5 deletions(-) diff --git a/compiler/noirc_evaluator/src/acir/acir_variable.rs b/compiler/noirc_evaluator/src/acir/acir_variable.rs index cf6b1fcc7f..7878cc806e 100644 --- a/compiler/noirc_evaluator/src/acir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/acir/acir_variable.rs @@ -541,6 +541,26 @@ impl> AcirContext { Ok(()) } + /// Constrains the `lhs` and `rhs` to be non-equal. + pub(crate) fn assert_neq_var( + &mut self, + lhs: AcirVar, + rhs: AcirVar, + assert_message: Option>, + ) -> Result<(), RuntimeError> { + let diff_var = self.sub_var(lhs, rhs)?; + + let one = self.add_constant(F::one()); + let _ = self.inv_var(diff_var, one)?; + if let Some(payload) = assert_message { + self.acir_ir + .assertion_payloads + .insert(self.acir_ir.last_acir_opcode_location(), payload); + } + + Ok(()) + } + pub(crate) fn vars_to_expressions_or_memory( &self, values: &[AcirValue], diff --git a/compiler/noirc_evaluator/src/acir/mod.rs b/compiler/noirc_evaluator/src/acir/mod.rs index a250189d3f..137d0f3c28 100644 --- a/compiler/noirc_evaluator/src/acir/mod.rs +++ b/compiler/noirc_evaluator/src/acir/mod.rs @@ -723,6 +723,47 @@ impl<'a> Context<'a> { self.acir_context.assert_eq_var(lhs, rhs, assert_payload)?; } + Instruction::ConstrainNotEqual(lhs, rhs, assert_message) => { + let lhs = self.convert_numeric_value(*lhs, dfg)?; + let rhs = self.convert_numeric_value(*rhs, dfg)?; + + let assert_payload = if let Some(error) = assert_message { + match error { + ConstrainError::StaticString(string) => Some( + self.acir_context.generate_assertion_message_payload(string.clone()), + ), + ConstrainError::Dynamic(error_selector, is_string_type, values) => { + if let Some(constant_string) = try_to_extract_string_from_error_payload( + *is_string_type, + values, + dfg, + ) { + Some( + self.acir_context + .generate_assertion_message_payload(constant_string), + ) + } else { + let acir_vars: Vec<_> = values + .iter() + .map(|value| self.convert_value(*value, dfg)) + .collect(); + + let expressions_or_memory = + self.acir_context.vars_to_expressions_or_memory(&acir_vars)?; + + Some(AssertionPayload { + error_selector: error_selector.as_u64(), + payload: expressions_or_memory, + }) + } + } + } + } else { + None + }; + + self.acir_context.assert_neq_var(lhs, rhs, assert_payload)?; + } Instruction::Cast(value_id, _) => { let acir_var = self.convert_numeric_value(*value_id, dfg)?; self.define_result_var(dfg, instruction_id, acir_var); diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index e9bc6b127f..ec918c51ff 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -279,6 +279,10 @@ impl<'block> BrilligBlock<'block> { self.brillig_context.deallocate_single_addr(condition); } } + Instruction::ConstrainNotEqual(..) => { + unreachable!("only implemented in ACIR") + } + Instruction::Allocate => { let result_value = dfg.instruction_results(instruction_id)[0]; let pointer = self.variables.define_single_addr_variable( diff --git a/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs b/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs index 80dde5e27f..f44f726bfc 100644 --- a/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs +++ b/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs @@ -267,7 +267,8 @@ impl DependencyContext { } // Check the constrain instruction arguments against those // involved in Brillig calls, remove covered calls - Instruction::Constrain(value_id1, value_id2, _) => { + Instruction::Constrain(value_id1, value_id2, _) + | Instruction::ConstrainNotEqual(value_id1, value_id2, _) => { self.clear_constrained( &[function.dfg.resolve(*value_id1), function.dfg.resolve(*value_id2)], function, @@ -555,6 +556,7 @@ impl Context { | Instruction::Binary(..) | Instruction::Cast(..) | Instruction::Constrain(..) + | Instruction::ConstrainNotEqual(..) | Instruction::IfElse { .. } | Instruction::Load { .. } | Instruction::Not(..) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 786c3671d3..3af1746103 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -256,6 +256,9 @@ pub(crate) enum Instruction { /// Constrains two values to be equal to one another. Constrain(ValueId, ValueId, Option), + /// Constrains two values to not be equal to one another. + ConstrainNotEqual(ValueId, ValueId, Option), + /// Range constrain `value` to `max_bit_size` RangeCheck { value: ValueId, max_bit_size: u32, assert_message: Option }, @@ -364,6 +367,7 @@ impl Instruction { InstructionResultType::Operand(*value) } Instruction::Constrain(..) + | Instruction::ConstrainNotEqual(..) | Instruction::Store { .. } | Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. } @@ -405,7 +409,7 @@ impl Instruction { }, // These can fail. - Constrain(..) | RangeCheck { .. } => true, + Constrain(..) | ConstrainNotEqual(..) | RangeCheck { .. } => true, // This should never be side-effectful MakeArray { .. } | Noop => false, @@ -472,7 +476,7 @@ impl Instruction { }, // We can deduplicate these instructions if we know the predicate is also the same. - Constrain(..) | RangeCheck { .. } => deduplicate_with_predicate, + Constrain(..) | ConstrainNotEqual(..) | RangeCheck { .. } => deduplicate_with_predicate, // Noop instructions can always be deduplicated, although they're more likely to be // removed entirely. @@ -540,6 +544,7 @@ impl Instruction { } Constrain(..) + | ConstrainNotEqual(..) | EnableSideEffectsIf { .. } | IncrementRc { .. } | DecrementRc { .. } @@ -610,6 +615,7 @@ impl Instruction { Instruction::Cast(_, _) | Instruction::Not(_) | Instruction::Truncate { .. } + | Instruction::ConstrainNotEqual(..) | Instruction::Constrain(_, _, _) | Instruction::RangeCheck { .. } | Instruction::Allocate @@ -656,6 +662,22 @@ impl Instruction { }); Instruction::Constrain(lhs, rhs, assert_message) } + Instruction::ConstrainNotEqual(lhs, rhs, assert_message) => { + // Must map the `lhs` and `rhs` first as the value `f` is moved with the closure + let lhs = f(*lhs); + let rhs = f(*rhs); + let assert_message = assert_message.as_ref().map(|error| match error { + ConstrainError::Dynamic(selector, is_string, payload_values) => { + ConstrainError::Dynamic( + *selector, + *is_string, + payload_values.iter().map(|&value| f(value)).collect(), + ) + } + _ => error.clone(), + }); + Instruction::ConstrainNotEqual(lhs, rhs, assert_message) + } Instruction::Call { func, arguments } => Instruction::Call { func: f(*func), arguments: vecmap(arguments.iter().copied(), f), @@ -714,7 +736,8 @@ impl Instruction { Instruction::Truncate { value, bit_size: _, max_bit_size: _ } => { *value = f(*value); } - Instruction::Constrain(lhs, rhs, assert_message) => { + Instruction::Constrain(lhs, rhs, assert_message) + | Instruction::ConstrainNotEqual(lhs, rhs, assert_message) => { *lhs = f(*lhs); *rhs = f(*rhs); if let Some(ConstrainError::Dynamic(_, _, payload_values)) = assert_message { @@ -786,7 +809,8 @@ impl Instruction { | Instruction::Load { address: value } => { f(*value); } - Instruction::Constrain(lhs, rhs, assert_error) => { + Instruction::Constrain(lhs, rhs, assert_error) + | Instruction::ConstrainNotEqual(lhs, rhs, assert_error) => { f(*lhs); f(*rhs); if let Some(ConstrainError::Dynamic(_, _, values)) = assert_error.as_ref() { @@ -871,6 +895,23 @@ impl Instruction { } } Instruction::Constrain(lhs, rhs, msg) => { + if dfg.get_numeric_constant(*rhs).map_or(false, |constant| constant.is_zero()) { + if let Value::Instruction { instruction, .. } = &dfg[dfg.resolve(*lhs)] { + if let Instruction::Binary(Binary { + lhs, + rhs, + operator: BinaryOp::Eq, + .. + }) = &dfg[*instruction] + { + return SimplifiedToInstruction(Instruction::ConstrainNotEqual( + *lhs, + *rhs, + msg.clone(), + )); + } + } + } let constraints = decompose_constrain(*lhs, *rhs, msg, dfg); if constraints.is_empty() { Remove @@ -878,6 +919,7 @@ impl Instruction { SimplifiedToInstructionMultiple(constraints) } } + Instruction::ConstrainNotEqual(..) => None, Instruction::ArrayGet { array, index } => { if let Some(index) = dfg.get_numeric_constant(*index) { try_optimize_array_get_from_previous_set(dfg, *array, index) diff --git a/compiler/noirc_evaluator/src/ssa/ir/printer.rs b/compiler/noirc_evaluator/src/ssa/ir/printer.rs index 7fe12b83ea..ae0f23ad58 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/printer.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/printer.rs @@ -192,6 +192,14 @@ fn display_instruction_inner( writeln!(f) } } + Instruction::ConstrainNotEqual(lhs, rhs, error) => { + write!(f, "constrain {} != {}", show(*lhs), show(*rhs))?; + if let Some(error) = error { + display_constrain_error(dfg, error, f) + } else { + writeln!(f) + } + } Instruction::Call { func, arguments } => { let arguments = value_list(dfg, arguments); writeln!(f, "call {}({}){}", show(*func), arguments, result_types(dfg, results)) diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs index ca9b75643b..942fe67b5d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs @@ -143,6 +143,7 @@ impl Context { | Not(_) | Truncate { .. } | Constrain(..) + | ConstrainNotEqual(..) | RangeCheck { .. } | IfElse { .. } | IncrementRc { .. }