diff --git a/brillig_bytecode/src/lib.rs b/brillig_bytecode/src/lib.rs index 1c3267cef..70f5b6673 100644 --- a/brillig_bytecode/src/lib.rs +++ b/brillig_bytecode/src/lib.rs @@ -170,7 +170,7 @@ fn test_jmpif_opcode() { let equal_cmp_opcode = Opcode::BinaryOp { result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Equal), + op: BinaryOp::Cmp(Comparison::Eq), lhs: RegisterMemIndex::Register(RegisterIndex(0)), rhs: RegisterMemIndex::Register(RegisterIndex(1)), result: RegisterIndex(2), @@ -207,7 +207,7 @@ fn test_jmpifnot_opcode() { let not_equal_cmp_opcode = Opcode::BinaryOp { result_type: Typ::Field, - op: BinaryOp::Cmp(Comparison::Equal), + op: BinaryOp::Cmp(Comparison::Eq), lhs: RegisterMemIndex::Register(RegisterIndex(0)), rhs: RegisterMemIndex::Register(RegisterIndex(1)), result: RegisterIndex(2), @@ -277,3 +277,77 @@ fn test_mov_opcode() { let source_value = registers.get(RegisterMemIndex::Register(RegisterIndex(0))); assert_eq!(source_value, Value::from(1u128)); } + +#[test] +fn test_cmp_binary_ops() { + let input_registers = Registers::load(vec![ + Value::from(2u128), + Value::from(2u128), + Value::from(0u128), + Value::from(5u128), + Value::from(6u128), + ]); + + let equal_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(1)), + result: RegisterIndex(2), + }; + + let not_equal_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Eq), + lhs: RegisterMemIndex::Register(RegisterIndex(0)), + rhs: RegisterMemIndex::Register(RegisterIndex(3)), + result: RegisterIndex(2), + }; + + let less_than_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Lt), + lhs: RegisterMemIndex::Register(RegisterIndex(3)), + rhs: RegisterMemIndex::Register(RegisterIndex(4)), + result: RegisterIndex(2), + }; + + let less_than_equal_opcode = Opcode::BinaryOp { + result_type: Typ::Field, + op: BinaryOp::Cmp(Comparison::Lte), + lhs: RegisterMemIndex::Register(RegisterIndex(3)), + rhs: RegisterMemIndex::Register(RegisterIndex(4)), + result: RegisterIndex(2), + }; + + let mut vm = VM::new( + input_registers, + vec![equal_opcode, not_equal_opcode, less_than_opcode, less_than_equal_opcode], + ); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_eq_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_eq_value, Value::from(true)); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let output_neq_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(output_neq_value, Value::from(false)); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::InProgress); + + let lt_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(lt_value, Value::from(true)); + + let status = vm.process_opcode(); + assert_eq!(status, VMStatus::Halted); + + let lte_value = vm.registers.get(RegisterMemIndex::Register(RegisterIndex(2))); + assert_eq!(lte_value, Value::from(true)); + + vm.finish(); +} diff --git a/brillig_bytecode/src/opcodes.rs b/brillig_bytecode/src/opcodes.rs index e8e054401..fca46f4ee 100644 --- a/brillig_bytecode/src/opcodes.rs +++ b/brillig_bytecode/src/opcodes.rs @@ -79,8 +79,9 @@ pub enum BinaryOp { #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum Comparison { - NotEqual, - Equal, + Eq, //(==) equal + Lt, //(<) field less + Lte, //(<=) field less or equal } impl BinaryOp { @@ -90,8 +91,11 @@ impl BinaryOp { BinaryOp::Sub => |a: Value, b: Value| a - b, BinaryOp::Mul => |a: Value, b: Value| a * b, BinaryOp::Div => |a: Value, b: Value| a / b, - // TODO: only support equal and not equal, need less than, greater than, etc. - BinaryOp::Cmp(_) => |a: Value, b: Value| (a == b).into(), + BinaryOp::Cmp(comparison) => match comparison { + Comparison::Eq => |a: Value, b: Value| (a == b).into(), + Comparison::Lt => |a: Value, b: Value| (a.inner < b.inner).into(), + Comparison::Lte => |a: Value, b: Value| (a.inner <= b.inner).into(), + }, } } } diff --git a/brillig_bytecode/src/value.rs b/brillig_bytecode/src/value.rs index 84484dd86..d5bd3d9fb 100644 --- a/brillig_bytecode/src/value.rs +++ b/brillig_bytecode/src/value.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use std::ops::{Add, Div, Mul, Neg, Sub}; /// Types of values allowed in the VM -#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub enum Typ { Field, Unsigned { bit_size: u32 }, @@ -11,7 +11,7 @@ pub enum Typ { } /// Value represents a Value in the VM -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct Value { pub typ: Typ, pub inner: FieldElement,