Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

feat: Comparison Binary Ops #167

Merged
merged 6 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions brillig_bytecode/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ fn test_jmpif_opcode() {

let equal_cmp_opcode = Opcode::BinaryOp {
result_type: Typ::Field,
op: BinaryOp::Cmp(Comparison::Equal),
op: BinaryOp::Cmp(Comparison::Eq),
lhs: RegisterMemIndex::Register(RegisterIndex(0)),
rhs: RegisterMemIndex::Register(RegisterIndex(1)),
result: RegisterIndex(2),
Expand Down Expand Up @@ -207,7 +207,7 @@ fn test_jmpifnot_opcode() {

let not_equal_cmp_opcode = Opcode::BinaryOp {
result_type: Typ::Field,
op: BinaryOp::Cmp(Comparison::Equal),
op: BinaryOp::Cmp(Comparison::Eq),
lhs: RegisterMemIndex::Register(RegisterIndex(0)),
rhs: RegisterMemIndex::Register(RegisterIndex(1)),
result: RegisterIndex(2),
Expand Down Expand Up @@ -277,3 +277,77 @@ fn test_mov_opcode() {
let source_value = registers.get(RegisterMemIndex::Register(RegisterIndex(0)));
assert_eq!(source_value, Value::from(1u128));
}

#[test]
fn test_cmp_binary_ops() {
let input_registers = Registers::load(vec![
Value::from(2u128),
Value::from(2u128),
Value::from(0u128),
Value::from(5u128),
Value::from(6u128),
]);

let equal_opcode = Opcode::BinaryOp {
result_type: Typ::Field,
op: BinaryOp::Cmp(Comparison::Eq),
lhs: RegisterMemIndex::Register(RegisterIndex(0)),
rhs: RegisterMemIndex::Register(RegisterIndex(1)),
result: RegisterIndex(2),
};

let not_equal_opcode = Opcode::BinaryOp {
result_type: Typ::Field,
op: BinaryOp::Cmp(Comparison::Eq),
lhs: RegisterMemIndex::Register(RegisterIndex(0)),
rhs: RegisterMemIndex::Register(RegisterIndex(3)),
result: RegisterIndex(2),
};

let less_than_opcode = Opcode::BinaryOp {
result_type: Typ::Field,
op: BinaryOp::Cmp(Comparison::Lt),
lhs: RegisterMemIndex::Register(RegisterIndex(3)),
rhs: RegisterMemIndex::Register(RegisterIndex(4)),
result: RegisterIndex(2),
};

let less_than_equal_opcode = Opcode::BinaryOp {
result_type: Typ::Field,
op: BinaryOp::Cmp(Comparison::Lte),
lhs: RegisterMemIndex::Register(RegisterIndex(3)),
rhs: RegisterMemIndex::Register(RegisterIndex(4)),
result: RegisterIndex(2),
};

let mut vm = VM::new(
input_registers,
vec![equal_opcode, not_equal_opcode, less_than_opcode, less_than_equal_opcode],
);

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let output_eq_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2)));
assert_eq!(output_eq_value, Value::from(true));

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let output_neq_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2)));
assert_eq!(output_neq_value, Value::from(false));

let status = vm.process_opcode();
assert_eq!(status, VMStatus::InProgress);

let lt_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2)));
assert_eq!(lt_value, Value::from(true));

let status = vm.process_opcode();
assert_eq!(status, VMStatus::Halted);

let lte_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2)));
assert_eq!(lte_value, Value::from(true));

vm.finish();
}
12 changes: 8 additions & 4 deletions brillig_bytecode/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ pub enum BinaryOp {

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Comparison {
NotEqual,
Equal,
Eq, //(==) equal
Lt, //(<) field less
Lte, //(<=) field less or equal
}

impl BinaryOp {
Expand All @@ -90,8 +91,11 @@ impl BinaryOp {
BinaryOp::Sub => |a: Value, b: Value| a - b,
BinaryOp::Mul => |a: Value, b: Value| a * b,
BinaryOp::Div => |a: Value, b: Value| a / b,
// TODO: only support equal and not equal, need less than, greater than, etc.
BinaryOp::Cmp(_) => |a: Value, b: Value| (a == b).into(),
BinaryOp::Cmp(comparison) => match comparison {
Comparison::Eq => |a: Value, b: Value| (a == b).into(),
Comparison::Lt => |a: Value, b: Value| (a.inner < b.inner).into(),
Comparison::Lte => |a: Value, b: Value| (a.inner <= b.inner).into(),
},
}
}
}
4 changes: 2 additions & 2 deletions brillig_bytecode/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ use serde::{Deserialize, Serialize};
use std::ops::{Add, Div, Mul, Neg, Sub};

/// Types of values allowed in the VM
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub enum Typ {
Field,
Unsigned { bit_size: u32 },
Signed { bit_size: u32 },
}

/// Value represents a Value in the VM
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Value {
pub typ: Typ,
pub inner: FieldElement,
Expand Down