From cc3c2c22644f0b5d8369bad2362ea6e9112a0713 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 10 Jan 2024 14:18:57 +0000 Subject: [PATCH] feat: remove truncations which can be seen to be noops using type information (#3953) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … # Description ## Problem\* Resolves ## Summary\* `Truncate`s which occur on outputs of `Binary(BinaryOp::Div)` can sometimes be determined to be noops based on type information. Put simply, if we have `(numerator / constant_denom) as T` then if `type_of(numerator)::max / constant_denom` fits inside `T` then we don't need to truncate the value. ## Additional Context Testing this on noir-rsa takes us from ``` +---------+----------------------+--------------+----------------------+ | Package | Expression Width | ACIR Opcodes | Backend Circuit Size | +---------+----------------------+--------------+----------------------+ | dkim | Bounded { width: 3 } | 1649234 | 3049613 | +---------+----------------------+--------------+----------------------+ ``` to ``` +---------+----------------------+--------------+----------------------+ | Package | Expression Width | ACIR Opcodes | Backend Circuit Size | +---------+----------------------+--------------+----------------------+ | dkim | Bounded { width: 3 } | 1486175 | 2764246 | +---------+----------------------+--------------+----------------------+ ``` ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- .../noirc_evaluator/src/ssa/ir/instruction.rs | 45 +++++-- .../src/ssa/opt/constant_folding.rs | 115 +++++++++++++++++- 2 files changed, 149 insertions(+), 11 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 9691017f04b..25b406dc858 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -533,16 +533,43 @@ impl Instruction { let truncated = numeric_constant.to_u128() % integer_modulus; SimplifiedTo(dfg.make_constant(truncated.into(), typ)) } else if let Value::Instruction { instruction, .. } = &dfg[dfg.resolve(*value)] { - if let Instruction::Truncate { bit_size: src_bit_size, .. } = &dfg[*instruction] - { - // If we're truncating the value to fit into the same or larger bit size then this is a noop. - if src_bit_size <= bit_size && src_bit_size <= max_bit_size { - SimplifiedTo(*value) - } else { - None + match &dfg[*instruction] { + Instruction::Truncate { bit_size: src_bit_size, .. } => { + // If we're truncating the value to fit into the same or larger bit size then this is a noop. + if src_bit_size <= bit_size && src_bit_size <= max_bit_size { + SimplifiedTo(*value) + } else { + None + } } - } else { - None + + Instruction::Binary(Binary { + lhs, rhs, operator: BinaryOp::Div, .. + }) if dfg.is_constant(*rhs) => { + // If we're truncating the result of a division by a constant denominator, we can + // reason about the maximum bit size of the result and whether a truncation is necessary. + + let numerator_type = dfg.type_of_value(*lhs); + let max_numerator_bits = numerator_type.bit_size(); + + let divisor = dfg + .get_numeric_constant(*rhs) + .expect("rhs is checked to be constant."); + let divisor_bits = divisor.num_bits(); + + // 2^{max_quotient_bits} = 2^{max_numerator_bits} / 2^{divisor_bits} + // => max_quotient_bits = max_numerator_bits - divisor_bits + // + // In order for the truncation to be a noop, we then require `max_quotient_bits < bit_size`. + let max_quotient_bits = max_numerator_bits - divisor_bits; + if max_quotient_bits < *bit_size { + SimplifiedTo(*value) + } else { + None + } + } + + _ => None, } } else { None diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 7d345a9a4ab..7e6d841adc1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -184,10 +184,10 @@ mod test { function_builder::FunctionBuilder, ir::{ function::RuntimeType, - instruction::{BinaryOp, Instruction, TerminatorInstruction}, + instruction::{Binary, BinaryOp, Instruction, TerminatorInstruction}, map::Id, types::Type, - value::Value, + value::{Value, ValueId}, }, }; @@ -247,6 +247,117 @@ mod test { } } + #[test] + fn redundant_truncation() { + // fn main f0 { + // b0(v0: u16, v1: u16): + // v2 = div v0, v1 + // v3 = truncate v2 to 8 bits, max_bit_size: 16 + // return v3 + // } + // + // After constructing this IR, we set the value of v1 to 2^8. + // The expected return afterwards should be v2. + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir); + let v0 = builder.add_parameter(Type::unsigned(16)); + let v1 = builder.add_parameter(Type::unsigned(16)); + + // Note that this constant guarantees that `v0/constant < 2^8`. We then do not need to truncate the result. + let constant = 2_u128.pow(8); + let constant = builder.numeric_constant(constant, Type::field()); + + let v2 = builder.insert_binary(v0, BinaryOp::Div, v1); + let v3 = builder.insert_truncate(v2, 8, 16); + builder.terminate_with_return(vec![v3]); + + let mut ssa = builder.finish(); + let main = ssa.main_mut(); + let instructions = main.dfg[main.entry_block()].instructions(); + assert_eq!(instructions.len(), 2); // The final return is not counted + + // Expected output: + // + // fn main f0 { + // b0(Field 2: Field): + // return Field 9 + // } + main.dfg.set_value_from_id(v1, constant); + + let ssa = ssa.fold_constants(); + let main = ssa.main(); + + println!("{ssa}"); + + let instructions = main.dfg[main.entry_block()].instructions(); + assert_eq!(instructions.len(), 1); + let instruction = &main.dfg[instructions[0]]; + + assert_eq!( + instruction, + &Instruction::Binary(Binary { lhs: v0, operator: BinaryOp::Div, rhs: constant }) + ); + } + + #[test] + fn non_redundant_truncation() { + // fn main f0 { + // b0(v0: u16, v1: u16): + // v2 = div v0, v1 + // v3 = truncate v2 to 8 bits, max_bit_size: 16 + // return v3 + // } + // + // After constructing this IR, we set the value of v1 to 2^8 - 1. + // This should not result in the truncation being removed. + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir); + let v0 = builder.add_parameter(Type::unsigned(16)); + let v1 = builder.add_parameter(Type::unsigned(16)); + + // Note that this constant does not guarantee that `v0/constant < 2^8`. We must then truncate the result. + let constant = 2_u128.pow(8) - 1; + let constant = builder.numeric_constant(constant, Type::field()); + + let v2 = builder.insert_binary(v0, BinaryOp::Div, v1); + let v3 = builder.insert_truncate(v2, 8, 16); + builder.terminate_with_return(vec![v3]); + + let mut ssa = builder.finish(); + let main = ssa.main_mut(); + let instructions = main.dfg[main.entry_block()].instructions(); + assert_eq!(instructions.len(), 2); // The final return is not counted + + // Expected output: + // + // fn main f0 { + // b0(v0: u16, Field 255: Field): + // v5 = div v0, Field 255 + // v6 = truncate v5 to 8 bits, max_bit_size: 16 + // return v6 + // } + main.dfg.set_value_from_id(v1, constant); + + let ssa = ssa.fold_constants(); + let main = ssa.main(); + + let instructions = main.dfg[main.entry_block()].instructions(); + assert_eq!(instructions.len(), 2); + + assert_eq!( + &main.dfg[instructions[0]], + &Instruction::Binary(Binary { lhs: v0, operator: BinaryOp::Div, rhs: constant }) + ); + assert_eq!( + &main.dfg[instructions[1]], + &Instruction::Truncate { value: ValueId::test_new(5), bit_size: 8, max_bit_size: 16 } + ); + } + #[test] fn arrays_elements_are_updated() { // fn main f0 {