Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Nested array equality #4903

Merged
merged 7 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 5 additions & 103 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,22 +566,12 @@ impl<'a> FunctionContext<'a> {
mut rhs: ValueId,
location: Location,
) -> Values {
let result_type = self.builder.type_of_value(lhs);
let mut result = match operator {
BinaryOpKind::Equal | BinaryOpKind::NotEqual
if matches!(result_type, Type::Array(..)) =>
{
return self.insert_array_equality(lhs, operator, rhs, location)
}
_ => {
let op = convert_operator(operator);
if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}
let op = convert_operator(operator);
if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}

self.builder.set_location(location).insert_binary(lhs, op, rhs)
}
};
let mut result = self.builder.set_location(location).insert_binary(lhs, op, rhs);

// Check for integer overflow
if matches!(
Expand All @@ -600,94 +590,6 @@ impl<'a> FunctionContext<'a> {
result.into()
}

/// The frontend claims to support equality (==) on arrays, so we must support it in SSA here.
/// The actual BinaryOp::Eq in SSA is meant only for primitive numeric types so we encode an
/// entire equality loop on each array element. The generated IR is as follows:
///
/// ...
/// result_alloc = allocate
/// store u1 1 in result_alloc
/// jmp loop_start(0)
/// loop_start(i: Field):
/// v0 = lt i, array_len
/// jmpif v0, then: loop_body, else: loop_end
/// loop_body():
/// v1 = array_get lhs, index i
/// v2 = array_get rhs, index i
/// v3 = eq v1, v2
/// v4 = load result_alloc
/// v5 = and v4, v3
/// store v5 in result_alloc
/// v6 = add i, Field 1
/// jmp loop_start(v6)
/// loop_end():
/// result = load result_alloc
fn insert_array_equality(
&mut self,
lhs: ValueId,
operator: BinaryOpKind,
rhs: ValueId,
location: Location,
) -> Values {
let lhs_type = self.builder.type_of_value(lhs);
let rhs_type = self.builder.type_of_value(rhs);

let (array_length, element_type) = match (lhs_type, rhs_type) {
(
Type::Array(lhs_composite_type, lhs_length),
Type::Array(rhs_composite_type, rhs_length),
) => {
assert!(
lhs_composite_type.len() == 1 && rhs_composite_type.len() == 1,
"== is unimplemented for arrays of structs"
);
assert_eq!(lhs_composite_type[0], rhs_composite_type[0]);
assert_eq!(lhs_length, rhs_length, "Expected two arrays of equal length");
(lhs_length, lhs_composite_type[0].clone())
}
_ => unreachable!("Expected two array values"),
};

let loop_start = self.builder.insert_block();
let loop_body = self.builder.insert_block();
let loop_end = self.builder.insert_block();

// pre-loop
let result_alloc = self.builder.set_location(location).insert_allocate(Type::bool());
let true_value = self.builder.numeric_constant(1u128, Type::bool());
self.builder.insert_store(result_alloc, true_value);
let zero = self.builder.length_constant(0u128);
self.builder.terminate_with_jmp(loop_start, vec![zero]);

// loop_start
self.builder.switch_to_block(loop_start);
let i = self.builder.add_block_parameter(loop_start, Type::length_type());
let array_length = self.builder.length_constant(array_length as u128);
let v0 = self.builder.insert_binary(i, BinaryOp::Lt, array_length);
self.builder.terminate_with_jmpif(v0, loop_body, loop_end);

// loop body
self.builder.switch_to_block(loop_body);
let v1 = self.builder.insert_array_get(lhs, i, element_type.clone());
let v2 = self.builder.insert_array_get(rhs, i, element_type);
let v3 = self.builder.insert_binary(v1, BinaryOp::Eq, v2);
let v4 = self.builder.insert_load(result_alloc, Type::bool());
let v5 = self.builder.insert_binary(v4, BinaryOp::And, v3);
self.builder.insert_store(result_alloc, v5);
let one = self.builder.length_constant(1u128);
let v6 = self.builder.insert_binary(i, BinaryOp::Add, one);
self.builder.terminate_with_jmp(loop_start, vec![v6]);

// loop end
self.builder.switch_to_block(loop_end);
let mut result = self.builder.insert_load(result_alloc, Type::bool());

if operator_requires_not(operator) {
result = self.builder.insert_not(result);
}
result.into()
}

/// Inserts a call instruction at the end of the current block and returns the results
/// of the call.
///
Expand Down
30 changes: 0 additions & 30 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -890,36 +890,6 @@ impl<'interner> TypeChecker<'interner> {
// <= and friends are technically valid for booleans, just not very useful
(Bool, Bool) => Ok((Bool, false)),

// Special-case == and != for arrays
(Array(x_size, x_type), Array(y_size, y_type))
if matches!(op.kind, BinaryOpKind::Equal | BinaryOpKind::NotEqual) =>
{
self.unify(x_size, y_size, || TypeCheckError::TypeMismatchWithSource {
expected: lhs_type.clone(),
actual: rhs_type.clone(),
source: Source::ArrayLen,
span: op.location.span,
});

let (_, use_impl) = self.comparator_operand_type_rules(x_type, y_type, op, span)?;

// If the size is not constant, we must fall back to a user-provided impl for
// equality on slices.
let size = x_size.follow_bindings();
let use_impl = use_impl || size.evaluate_to_u64().is_none();
Ok((Bool, use_impl))
}

(String(x_size), String(y_size)) => {
self.unify(x_size, y_size, || TypeCheckError::TypeMismatchWithSource {
expected: *x_size.clone(),
actual: *y_size.clone(),
span: op.location.span,
source: Source::StringLen,
});

Ok((Bool, false))
}
(lhs, rhs) => {
self.unify(lhs, rhs, || TypeCheckError::TypeMismatchWithSource {
expected: lhs.clone(),
Expand Down
7 changes: 7 additions & 0 deletions test_programs/execution_success/regression_4383/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "regression_4383"
type = "bin"
authors = [""]
compiler_version = ">=0.26.0"

[dependencies]
3 changes: 3 additions & 0 deletions test_programs/execution_success/regression_4383/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn main() {
assert([[1]] == [[1]]);
}
Loading