diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index d52d4ca8c71..85f83dd5216 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -61,6 +61,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "as_slice" => as_slice(interner, arguments, location), "expr_as_array" => expr_as_array(interner, arguments, return_type, location), "expr_as_assert" => expr_as_assert(interner, arguments, return_type, location), + "expr_as_assert_eq" => expr_as_assert_eq(interner, arguments, return_type, location), "expr_as_assign" => expr_as_assign(interner, arguments, return_type, location), "expr_as_binary_op" => expr_as_binary_op(interner, arguments, return_type, location), "expr_as_block" => expr_as_block(interner, arguments, return_type, location), @@ -953,6 +954,43 @@ fn expr_as_assert( }) } +// fn as_assert_eq(self) -> Option<(Expr, Expr, Option)> +fn expr_as_assert_eq( + interner: &NodeInterner, + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(interner, arguments, return_type.clone(), location, |expr| { + if let ExprValue::Statement(StatementKind::Constrain(constrain)) = expr { + if constrain.2 == ConstrainKind::AssertEq { + let ExpressionKind::Infix(infix) = constrain.0.kind else { + panic!("Expected AssertEq constrain statement to have an infix expression"); + }; + + let lhs = Value::expression(infix.lhs.kind); + let rhs = Value::expression(infix.rhs.kind); + + let option_type = extract_option_generic_type(return_type); + let Type::Tuple(mut tuple_types) = option_type else { + panic!("Expected the return type option generic arg to be a tuple"); + }; + assert_eq!(tuple_types.len(), 3); + + let option_type = tuple_types.pop().unwrap(); + let message = constrain.1.map(|message| Value::expression(message.kind)); + let message = option(option_type, message).ok()?; + + Some(Value::Tuple(vec![lhs, rhs, message])) + } else { + None + } + } else { + None + } + }) +} + // fn as_assign(self) -> Option<(Expr, Expr)> fn expr_as_assign( interner: &NodeInterner, diff --git a/docs/docs/noir/standard_library/meta/expr.md b/docs/docs/noir/standard_library/meta/expr.md index 57f0fce24c1..3a3c61b41f5 100644 --- a/docs/docs/noir/standard_library/meta/expr.md +++ b/docs/docs/noir/standard_library/meta/expr.md @@ -18,6 +18,13 @@ If this expression is an array, this returns a slice of each element in the arra If this expression is an assert, this returns the assert expression and the optional message. +### as_assert_eq + +#include_code as_assert_eq noir_stdlib/src/meta/expr.nr rust + +If this expression is an assert_eq, this returns the left-hand-side and right-hand-side +expressions, together with the optional message. + ### as_assign #include_code as_assign noir_stdlib/src/meta/expr.nr rust diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index 9b5eee03229..43638ad791b 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -13,6 +13,11 @@ impl Expr { fn as_assert(self) -> Option<(Expr, Option)> {} // docs:end:as_assert + #[builtin(expr_as_assert_eq)] + // docs:start:as_assert_eq + fn as_assert_eq(self) -> Option<(Expr, Expr, Option)> {} + // docs:end:as_assert_eq + #[builtin(expr_as_assign)] // docs:start:as_assign fn as_assign(self) -> Option<(Expr, Expr)> {} @@ -121,6 +126,7 @@ impl Expr { // docs:end:modify let result = modify_array(self, f); let result = result.or_else(|| modify_assert(self, f)); + let result = result.or_else(|| modify_assert_eq(self, f)); let result = result.or_else(|| modify_assign(self, f)); let result = result.or_else(|| modify_binary_op(self, f)); let result = result.or_else(|| modify_block(self, f)); @@ -178,6 +184,18 @@ fn modify_assert(expr: Expr, f: fn[Env](Expr) -> Option) -> Option(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { + expr.as_assert_eq().map( + |expr: (Expr, Expr, Option)| { + let (lhs, rhs, msg) = expr; + let lhs = lhs.modify(f); + let rhs = rhs.modify(f); + let msg = msg.map(|msg: Expr| msg.modify(f)); + new_assert_eq(lhs, rhs, msg) + } + ) +} + fn modify_assign(expr: Expr, f: fn[Env](Expr) -> Option) -> Option { expr.as_assign().map( |expr: (Expr, Expr)| { @@ -360,6 +378,15 @@ fn new_assert(predicate: Expr, msg: Option) -> Expr { } } +fn new_assert_eq(lhs: Expr, rhs: Expr, msg: Option) -> Expr { + if msg.is_some() { + let msg = msg.unwrap(); + quote { assert_eq($lhs, $rhs, $msg) }.as_expr().unwrap() + } else { + quote { assert_eq($lhs, $rhs) }.as_expr().unwrap() + } +} + fn new_assign(lhs: Expr, rhs: Expr) -> Expr { quote { $lhs = $rhs }.as_expr().unwrap() } diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr index 1488783c72c..c082c1dde33 100644 --- a/test_programs/noir_test_success/comptime_expr/src/main.nr +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -63,6 +63,44 @@ mod tests { } } + #[test] + fn test_expr_as_assert_eq() { + comptime + { + let expr = quote { assert_eq(true, false) }.as_expr().unwrap(); + let (lhs, rhs, msg) = expr.as_assert_eq().unwrap(); + assert_eq(lhs.as_bool().unwrap(), true); + assert_eq(rhs.as_bool().unwrap(), false); + assert(msg.is_none()); + + let expr = quote { assert_eq(false, true, "oops") }.as_expr().unwrap(); + let (lhs, rhs, msg) = expr.as_assert_eq().unwrap(); + assert_eq(lhs.as_bool().unwrap(), false); + assert_eq(rhs.as_bool().unwrap(), true); + assert(msg.is_some()); + } + } + + #[test] + fn test_expr_mutate_for_assert_eq() { + comptime + { + let expr = quote { assert_eq(1, 2) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (lhs, rhs, msg) = expr.as_assert_eq().unwrap(); + assert_eq(lhs.as_integer().unwrap(), (2, false)); + assert_eq(rhs.as_integer().unwrap(), (4, false)); + assert(msg.is_none()); + + let expr = quote { assert_eq(1, 2, 3) }.as_expr().unwrap(); + let expr = expr.modify(times_two); + let (lhs, rhs, msg) = expr.as_assert_eq().unwrap(); + assert_eq(lhs.as_integer().unwrap(), (2, false)); + assert_eq(rhs.as_integer().unwrap(), (4, false)); + assert_eq(msg.unwrap().as_integer().unwrap(), (6, false)); + } + } + #[test] fn test_expr_as_assign() { comptime