Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement hint on uint256_mul_div_mod #957

Merged
merged 17 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,26 @@

#### Upcoming Changes

* Implement hint on `uint256_mul_div_mod`[#957](https://github.com/lambdaclass/cairo-rs/pull/957)

`BuiltinHintProcessor` now supports the following hint:

```python
a = (ids.a.high << 128) + ids.a.low
b = (ids.b.high << 128) + ids.b.low
div = (ids.div.high << 128) + ids.div.low
quotient, remainder = divmod(a * b, div)

ids.quotient_low.low = quotient & ((1 << 128) - 1)
ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)
ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)
ids.quotient_high.high = quotient >> 384
ids.remainder.low = remainder & ((1 << 128) - 1)
ids.remainder.high = remainder >> 128"
```

Used by the common library function `uint256_mul_div_mod`

* Move `Memory` into `MemorySegmentManager` [#830](https://github.com/lambdaclass/cairo-rs/pull/830)
* Structural changes:
* Remove `memory: Memory` field from `VirtualMachine`
Expand Down
19 changes: 19 additions & 0 deletions cairo_programs/uint256.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from starkware.cairo.common.uint256 import (
uint256_signed_nn,
uint256_unsigned_div_rem,
uint256_mul,
uint256_mul_div_mod
)
from starkware.cairo.common.alloc import alloc

Expand Down Expand Up @@ -57,6 +58,24 @@ func main{range_check_ptr: felt}() {
assert b_quotient = Uint256(1, 0);
assert b_remainder = Uint256(340282366920938463463374607431768211377, 0);

let (a_quotient_low, a_quotient_high, a_remainder) = uint256_mul_div_mod(
Uint256(89, 72),
Uint256(3, 7),
Uint256(107, 114),
);
assert a_quotient_low = Uint256(143276786071974089879315624181797141668, 4);
assert a_quotient_high = Uint256(0, 0);
assert a_remainder = Uint256(322372768661941702228460154409043568767, 101);

let (b_quotient_low, b_quotient_high, b_remainder) = uint256_mul_div_mod(
Uint256(-3618502788666131213697322783095070105282824848410658236509717448704103809099, 2),
Uint256(1, 1),
Uint256(5, 2),
);
assert b_quotient_low = Uint256(170141183460469231731687303715884105688, 1);
assert b_quotient_high = Uint256(0, 0);
assert b_remainder = Uint256(170141183460469231731687303715884105854, 1);

let (mult_low_a, mult_high_a) = uint256_mul(Uint256(59, 2), Uint256(10, 0));
assert mult_low_a = Uint256(590, 20);
assert mult_high_a = Uint256(0, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ use crate::{
squash_dict_inner_used_accesses_assert,
},
uint256_utils::{
split_64, uint256_add, uint256_signed_nn, uint256_sqrt, uint256_unsigned_div_rem,
split_64, uint256_add, uint256_mul_div_mod, uint256_signed_nn, uint256_sqrt,
uint256_unsigned_div_rem,
},
usort::{
usort_body, usort_enter_scope, verify_multiplicity_assert,
Expand Down Expand Up @@ -452,6 +453,9 @@ impl HintProcessor for BuiltinHintProcessor {
chained_ec_op_random_ec_point_hint(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::RECOVER_Y => recover_y_hint(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::UINT256_MUL_DIV_MOD => {
uint256_mul_div_mod(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
#[cfg(feature = "skip_next_instruction_hint")]
hint_code::SKIP_NEXT_INSTRUCTION => skip_next_instruction(vm),
code => Err(HintError::UnknownHint(code.to_string())),
Expand Down
12 changes: 12 additions & 0 deletions src/hint_processor/builtin_hint_processor/hint_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,18 @@ ids.quotient.high = quotient >> 128
ids.remainder.low = remainder & ((1 << 128) - 1)
ids.remainder.high = remainder >> 128"#;

pub(crate) const UINT256_MUL_DIV_MOD: &str = r#"a = (ids.a.high << 128) + ids.a.low
b = (ids.b.high << 128) + ids.b.low
div = (ids.div.high << 128) + ids.div.low
quotient, remainder = divmod(a * b, div)

ids.quotient_low.low = quotient & ((1 << 128) - 1)
ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)
ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)
ids.quotient_high.high = quotient >> 384
ids.remainder.low = remainder & ((1 << 128) - 1)
ids.remainder.high = remainder >> 128"#;

pub(crate) const USORT_ENTER_SCOPE: &str =
"vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))";
pub(crate) const USORT_BODY: &str = r#"from collections import defaultdict
Expand Down
197 changes: 196 additions & 1 deletion src/hint_processor/builtin_hint_processor/uint256_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use crate::{
vm::{errors::hint_errors::HintError, vm_core::VirtualMachine},
};
use felt::Felt252;
use num_integer::div_rem;
use num_bigint::BigUint;
use num_integer::{div_rem, Integer};
use num_traits::{One, Signed, Zero};
/*
Implements hint:
Expand Down Expand Up @@ -217,9 +218,90 @@ pub fn uint256_unsigned_div_rem(
Ok(())
}

/* Implements Hint:
%{
a = (ids.a.high << 128) + ids.a.low
b = (ids.b.high << 128) + ids.b.low
div = (ids.div.high << 128) + ids.div.low
quotient, remainder = divmod(a * b, div)

ids.quotient_low.low = quotient & ((1 << 128) - 1)
ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)
ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)
ids.quotient_high.high = quotient >> 384
ids.remainder.low = remainder & ((1 << 128) - 1)
ids.remainder.high = remainder >> 128
%}
*/
pub fn uint256_mul_div_mod(
vm: &mut VirtualMachine,
ids_data: &HashMap<String, HintReference>,
ap_tracking: &ApTracking,
) -> Result<(), HintError> {
// Extract variables
let a_addr = get_relocatable_from_var_name("a", vm, ids_data, ap_tracking)?;
let b_addr = get_relocatable_from_var_name("b", vm, ids_data, ap_tracking)?;
let div_addr = get_relocatable_from_var_name("div", vm, ids_data, ap_tracking)?;
let quotient_low_addr =
get_relocatable_from_var_name("quotient_low", vm, ids_data, ap_tracking)?;
let quotient_high_addr =
get_relocatable_from_var_name("quotient_high", vm, ids_data, ap_tracking)?;
let remainder_addr = get_relocatable_from_var_name("remainder", vm, ids_data, ap_tracking)?;

let a_low = vm.get_integer(a_addr)?;
let a_high = vm.get_integer((a_addr + 1_usize)?)?;
let b_low = vm.get_integer(b_addr)?;
let b_high = vm.get_integer((b_addr + 1_usize)?)?;
let div_low = vm.get_integer(div_addr)?;
let div_high = vm.get_integer((div_addr + 1_usize)?)?;
Comment on lines +242 to +256
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll need a getter for this kind of structured data eventually, if possible generic so they're still refs.

let a_low = a_low.as_ref();
let a_high = a_high.as_ref();
let b_low = b_low.as_ref();
let b_high = b_high.as_ref();
let div_low = div_low.as_ref();
let div_high = div_high.as_ref();

// Main Logic
let a = a_high.shl(128_usize) + a_low;
let b = b_high.shl(128_usize) + b_low;
let div = div_high.shl(128_usize) + div_low;
let (quotient, remainder) = (a.to_biguint() * b.to_biguint()).div_mod_floor(&div.to_biguint());

// ids.quotient_low.low
vm.insert_value(
quotient_low_addr,
Felt252::from(&quotient & &BigUint::from(u128::MAX)),
)?;
// ids.quotient_low.high
vm.insert_value(
(quotient_low_addr + 1)?,
Felt252::from((&quotient).shr(128_u32) & &BigUint::from(u128::MAX)),
)?;
// ids.quotient_high.low
vm.insert_value(
quotient_high_addr,
Felt252::from((&quotient).shr(256_u32) & &BigUint::from(u128::MAX)),
)?;
// ids.quotient_high.high
vm.insert_value(
(quotient_high_addr + 1)?,
Felt252::from((&quotient).shr(384_u32)),
)?;
//ids.remainder.low
vm.insert_value(
remainder_addr,
Felt252::from(&remainder & &BigUint::from(u128::MAX)),
)?;
//ids.remainder.high
vm.insert_value((remainder_addr + 1)?, Felt252::from(remainder.shr(128_u32)))?;

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::hint_processor::builtin_hint_processor::hint_code;
use crate::vm::vm_memory::memory_segments::MemorySegmentManager;
use crate::{
any_box,
Expand Down Expand Up @@ -573,4 +655,117 @@ mod tests {
z == MaybeRelocatable::from(Felt252::new(10))
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_unsigned_div_rem_invalid_memory_insert_2() {
let hint_code = "a = (ids.a.high << 128) + ids.a.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a, div)\n\nids.quotient.low = quotient & ((1 << 128) - 1)\nids.quotient.high = quotient >> 128\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128";
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 10;
//Create hint_data
let ids_data =
non_continuous_ids_data![("a", -6), ("div", -4), ("quotient", 0), ("remainder", 2)];
//Insert ids into memory
vm.segments = segments![
((1, 4), 89),
((1, 5), 72),
((1, 6), 3),
((1, 7), 7),
((1, 11), 1)
];
//Execute the hint
assert_matches!(
run_hint!(vm, ids_data, hint_code),
Err(HintError::Memory(
MemoryError::InconsistentMemory(
x,
y,
z,
)
)) if x == Relocatable::from((1, 11)) &&
y == MaybeRelocatable::from(Felt252::one()) &&
z == MaybeRelocatable::from(Felt252::zero())
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_mul_div_mod_ok() {
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 10;
//Create hint_data
let ids_data = non_continuous_ids_data![
("a", -8),
("b", -6),
("div", -4),
("quotient_low", 0),
("quotient_high", 2),
("remainder", 4)
];
//Insert ids into memory
vm.segments = segments![
((1, 2), 89),
((1, 3), 72),
((1, 4), 3),
((1, 5), 7),
((1, 6), 107),
((1, 7), 114)
];
//Execute the hint
assert_matches!(
run_hint!(vm, ids_data, hint_code::UINT256_MUL_DIV_MOD),
Ok(())
);
//Check hint memory inserts
//ids.quotient.low, ids.quotient.high, ids.remainder.low, ids.remainder.high
check_memory![
vm.segments.memory,
((1, 10), 143276786071974089879315624181797141668),
((1, 11), 4),
((1, 12), 0),
((1, 13), 0),
//((1, 14), 322372768661941702228460154409043568767),
((1, 15), 101)
];
assert_eq!(
vm.segments
.memory
.get_integer((1, 14).into())
.unwrap()
.as_ref(),
&felt_str!("322372768661941702228460154409043568767")
)
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_mul_div_mod_missing_ids() {
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 10;
//Create hint_data
let ids_data = non_continuous_ids_data![
("a", -8),
("b", -6),
("div", -4),
("quotient", 0),
("remainder", 2)
];
//Insert ids into memory
vm.segments = segments![
((1, 2), 89),
((1, 3), 72),
((1, 4), 3),
((1, 5), 7),
((1, 6), 107),
((1, 7), 114)
];
//Execute the hint
assert_matches!(
run_hint!(vm, ids_data, hint_code::UINT256_MUL_DIV_MOD),
Err(HintError::UnknownIdentifier(s)) if s == "quotient_low"
);
}
}