Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

feat!: Handle result type of Binary Ops in Brillig #202

Merged
merged 5 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions brillig_bytecode/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2021"
[dependencies]
acir_field.workspace = true
serde.workspace = true
num-bigint = "0.4"

[features]
bn254 = ["acir_field/bn254"]
Expand Down
5 changes: 3 additions & 2 deletions brillig_bytecode/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod value;
use std::collections::BTreeMap;

use acir_field::FieldElement;
use num_bigint::{BigInt, Sign};
pub use opcodes::RegisterMemIndex;
pub use opcodes::{BinaryOp, Comparison, Opcode, OracleData, OracleInput};
pub use registers::{RegisterIndex, Registers};
Expand Down Expand Up @@ -210,12 +211,12 @@ impl VM {
lhs: RegisterMemIndex,
rhs: RegisterMemIndex,
result: RegisterIndex,
_result_type: Typ,
result_type: Typ,
) {
let lhs_value = self.registers.get(lhs);
let rhs_value = self.registers.get(rhs);

let result_value = op.function()(lhs_value, rhs_value);
let result_value = op.evaluate(lhs_value, rhs_value, result_type);

self.registers.set(result, result_value)
}
Expand Down
87 changes: 79 additions & 8 deletions brillig_bytecode/src/opcodes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::{Add, Mul, Sub};

use crate::{
memory::ArrayIndex,
value::{Typ, Value},
Expand Down Expand Up @@ -140,17 +142,86 @@ pub enum Comparison {
}

impl BinaryOp {
pub fn function(&self) -> fn(Value, Value) -> Value {
pub fn evaluate(&self, a: Value, b: Value, res_type: Typ) -> Value {
match self {
BinaryOp::Add => |a: Value, b: Value| a + b,
BinaryOp::Sub => |a: Value, b: Value| a - b,
BinaryOp::Mul => |a: Value, b: Value| a * b,
BinaryOp::Div => |a: Value, b: Value| a / b,
BinaryOp::Add => {
let res_inner = self.wrapping(a.inner, b.inner, res_type, u128::add, Add::add);
Value { typ: res_type, inner: res_inner }
}
BinaryOp::Sub => {
let res_inner =
self.wrapping(a.inner, b.inner, res_type, u128::wrapping_sub, Sub::sub);
Value { typ: res_type, inner: res_inner }
}
BinaryOp::Mul => {
let res_inner = self.wrapping(a.inner, b.inner, res_type, u128::mul, Mul::mul);
Value { typ: res_type, inner: res_inner }
}
BinaryOp::Div => match res_type {
Typ::Field => a / b,
Typ::Unsigned { bit_size } => {
let lhs = a.inner.to_u128() % (1_u128 << bit_size);
let rhs = b.inner.to_u128() % (1_u128 << bit_size);
Value { typ: res_type, inner: FieldElement::from(lhs / rhs) }
}
Typ::Signed { bit_size } => {
let a = field_to_signed(a.inner, bit_size);
let b = field_to_signed(b.inner, bit_size);
let res_inner = signed_to_field(a / b, bit_size);
Value { typ: res_type, inner: res_inner }
}
},
BinaryOp::Cmp(comparison) => match comparison {
Comparison::Eq => |a: Value, b: Value| (a == b).into(),
Comparison::Lt => |a: Value, b: Value| (a.inner < b.inner).into(),
Comparison::Lte => |a: Value, b: Value| (a.inner <= b.inner).into(),
Comparison::Eq => (a == b).into(),
Comparison::Lt => (a.inner < b.inner).into(),
Comparison::Lte => (a.inner <= b.inner).into(),
},
}
}

/// Perform the given numeric operation and modulo the result by the max value for the given bit count
/// if the res_type is not a FieldElement.
fn wrapping(
&self,
lhs: FieldElement,
rhs: FieldElement,
res_type: Typ,
u128_op: impl FnOnce(u128, u128) -> u128,
field_op: impl FnOnce(FieldElement, FieldElement) -> FieldElement,
) -> FieldElement {
match res_type {
Typ::Field => field_op(lhs, rhs),
Typ::Unsigned { bit_size } | Typ::Signed { bit_size } => {
let type_modulo = 1_u128 << bit_size;
let lhs = lhs.to_u128() % type_modulo;
let rhs = rhs.to_u128() % type_modulo;
let mut x = u128_op(lhs, rhs);
x %= type_modulo;
FieldElement::from(x)
}
}
}
}

fn field_to_signed(f: FieldElement, n: u32) -> i128 {
assert!(n < 127);
let a = f.to_u128();
let pow_2 = 2_u128.pow(n);
if a < pow_2 {
a as i128
} else {
(a - 2 * pow_2) as i128
}
}

fn signed_to_field(a: i128, n: u32) -> FieldElement {
if n >= 126 {
panic!("ICE: cannot convert signed {n} bit size into field");
}
if a >= 0 {
FieldElement::from(a)
} else {
let b = (a + 2_i128.pow(n + 1)) as u128;
FieldElement::from(b)
}
}