Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MastForest::advice_map for the data required in the advice provider before execution #1574

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
- [BREAKING] `Process` no longer takes ownership of the `Host` (#1571)
- [BREAKING] `ProcessState` was converted from a trait to a struct (#1571)

#### Enhancements
- Added `miden_core::mast::MastForest::advice_map` to load it into the advice provider before the `MastForest` execution (#1574).

## 0.11.0 (2024-11-04)

#### Enhancements
Expand Down
53 changes: 48 additions & 5 deletions processor/src/host/advice/map.rs → core/src/advice/map.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use alloc::{
boxed::Box,
collections::{btree_map::IntoIter, BTreeMap},
vec::Vec,
};

use vm_core::{
use miden_crypto::{utils::collections::KvMap, Felt};

use crate::{
crypto::hash::RpoDigest,
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};

use super::Felt;

// ADVICE MAP
// ================================================================================================

Expand Down Expand Up @@ -38,8 +39,18 @@ impl AdviceMap {
}

/// Removes the value associated with the key and returns the removed element.
pub fn remove(&mut self, key: RpoDigest) -> Option<Vec<Felt>> {
self.0.remove(&key)
pub fn remove(&mut self, key: &RpoDigest) -> Option<Vec<Felt>> {
self.0.remove(key)
}

/// Returns the number of key value pairs in the advice map.
pub fn len(&self) -> usize {
self.0.len()
}

/// Returns true if the advice map is empty.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

Expand All @@ -58,6 +69,38 @@ impl IntoIterator for AdviceMap {
}
}

impl FromIterator<(RpoDigest, Vec<Felt>)> for AdviceMap {
fn from_iter<T: IntoIterator<Item = (RpoDigest, Vec<Felt>)>>(iter: T) -> Self {
iter.into_iter().collect::<BTreeMap<RpoDigest, Vec<Felt>>>().into()
}
}

impl KvMap<RpoDigest, Vec<Felt>> for AdviceMap {
fn get(&self, key: &RpoDigest) -> Option<&Vec<Felt>> {
self.0.get(key)
}

fn contains_key(&self, key: &RpoDigest) -> bool {
self.0.contains_key(key)
}

fn len(&self) -> usize {
self.len()
}

fn insert(&mut self, key: RpoDigest, value: Vec<Felt>) -> Option<Vec<Felt>> {
self.insert(key, value)
}

fn remove(&mut self, key: &RpoDigest) -> Option<Vec<Felt>> {
self.remove(key)
}

fn iter(&self) -> Box<dyn Iterator<Item = (&RpoDigest, &Vec<Felt>)> + '_> {
Box::new(self.0.iter())
}
}

impl Extend<(RpoDigest, Vec<Felt>)> for AdviceMap {
fn extend<T: IntoIterator<Item = (RpoDigest, Vec<Felt>)>>(&mut self, iter: T) {
self.0.extend(iter)
Expand Down
1 change: 1 addition & 0 deletions core/src/advice/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(super) mod map;
3 changes: 3 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,7 @@ pub use operations::{
pub mod stack;
pub use stack::{StackInputs, StackOutputs};

mod advice;
pub use advice::map::AdviceMap;

pub mod utils;
25 changes: 21 additions & 4 deletions core/src/mast/merger/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::{collections::BTreeMap, vec::Vec};

use miden_crypto::hash::blake::Blake3Digest;
use miden_crypto::{hash::blake::Blake3Digest, utils::collections::KvMap};

use crate::mast::{
DecoratorId, MastForest, MastForestError, MastNode, MastNodeFingerprint, MastNodeId,
Expand Down Expand Up @@ -65,10 +65,11 @@ impl MastForestMerger {
///
/// It does this in three steps:
///
/// 1. Merge all decorators, which is a case of deduplication and creating a decorator id
/// 1. Merge all advice maps, checking for key collisions.
/// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
/// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
/// merged forest.
/// 2. Merge all nodes of forests.
/// 3. Merge all nodes of forests.
/// - Similar to decorators, node indices might move during merging, so the merger keeps a
/// node id mapping as it merges nodes.
/// - This is a depth-first traversal over all forests to ensure all children are processed
Expand All @@ -90,10 +91,13 @@ impl MastForestMerger {
/// `replacement` node. Now we can simply add a mapping from the external node to the
/// `replacement` node in our node id mapping which means all nodes that referenced the
/// external node will point to the `replacement` instead.
/// 3. Finally, we merge all roots of all forests. Here we map the existing root indices to
/// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
/// their potentially new indices in the merged forest and add them to the forest,
/// deduplicating in the process, too.
fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
for other_forest in forests.iter() {
self.merge_advice_map(other_forest)?;
}
for other_forest in forests.iter() {
self.merge_decorators(other_forest)?;
}
Expand Down Expand Up @@ -163,6 +167,19 @@ impl MastForestMerger {
Ok(())
}

fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
for (digest, values) in other_forest.advice_map.iter() {
if let Some(stored_values) = self.mast_forest.advice_map().get(digest) {
if stored_values != values {
return Err(MastForestError::AdviceMapKeyCollisionOnMerge(*digest));
}
} else {
self.mast_forest.advice_map_mut().insert(*digest, values.clone());
}
}
Ok(())
}

fn merge_node(
&mut self,
forest_idx: usize,
Expand Down
53 changes: 52 additions & 1 deletion core/src/mast/merger/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use miden_crypto::{hash::rpo::RpoDigest, ONE};
use miden_crypto::{hash::rpo::RpoDigest, Felt, ONE};

use super::*;
use crate::{Decorator, Operation};
Expand Down Expand Up @@ -794,3 +794,54 @@ fn mast_forest_merge_invalid_decorator_index() {
let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err();
assert_matches!(err, MastForestError::DecoratorIdOverflow(_, _));
}

/// Tests that forest's advice maps are merged correctly.
#[test]
fn mast_forest_merge_advice_maps_merged() {
let mut forest_a = MastForest::new();
let id_foo = forest_a.add_node(block_foo()).unwrap();
let id_call_a = forest_a.add_call(id_foo).unwrap();
forest_a.make_root(id_call_a);
let key_a = RpoDigest::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]);
let value_a = vec![ONE, ONE];
forest_a.advice_map_mut().insert(key_a, value_a.clone());

let mut forest_b = MastForest::new();
let id_bar = forest_b.add_node(block_bar()).unwrap();
let id_call_b = forest_b.add_call(id_bar).unwrap();
forest_b.make_root(id_call_b);
let key_b = RpoDigest::new([Felt::new(1), Felt::new(3), Felt::new(2), Felt::new(1)]);
let value_b = vec![Felt::new(2), Felt::new(2)];
forest_b.advice_map_mut().insert(key_b, value_b.clone());

let (merged, _root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap();

let merged_advice_map = merged.advice_map();
assert_eq!(merged_advice_map.len(), 2);
assert_eq!(merged_advice_map.get(&key_a).unwrap(), &value_a);
assert_eq!(merged_advice_map.get(&key_b).unwrap(), &value_b);
}

/// Tests that an error is returned when advice maps have a key collision.
#[test]
fn mast_forest_merge_advice_maps_collision() {
let mut forest_a = MastForest::new();
let id_foo = forest_a.add_node(block_foo()).unwrap();
let id_call_a = forest_a.add_call(id_foo).unwrap();
forest_a.make_root(id_call_a);
let key_a = RpoDigest::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]);
let value_a = vec![ONE, ONE];
forest_a.advice_map_mut().insert(key_a, value_a.clone());

let mut forest_b = MastForest::new();
let id_bar = forest_b.add_node(block_bar()).unwrap();
let id_call_b = forest_b.add_call(id_bar).unwrap();
forest_b.make_root(id_call_b);
// The key collides with key_a in the forest_a.
let key_b = key_a;
let value_b = vec![Felt::new(2), Felt::new(2)];
forest_b.advice_map_mut().insert(key_b, value_b.clone());

let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err();
assert_matches!(err, MastForestError::AdviceMapKeyCollisionOnMerge(_));
}
15 changes: 14 additions & 1 deletion core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use node::{
};
use winter_utils::{ByteWriter, DeserializationError, Serializable};

use crate::{Decorator, DecoratorList, Operation};
use crate::{AdviceMap, Decorator, DecoratorList, Operation};

mod serialization;

Expand Down Expand Up @@ -50,6 +50,9 @@ pub struct MastForest {

/// All the decorators included in the MAST forest.
decorators: Vec<Decorator>,

/// Advice map to be loaded into the VM prior to executing procedures from this MAST forest.
advice_map: AdviceMap,
}

// ------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -463,6 +466,14 @@ impl MastForest {
pub fn nodes(&self) -> &[MastNode] {
&self.nodes
}

pub fn advice_map(&self) -> &AdviceMap {
&self.advice_map
}

pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
&mut self.advice_map
}
}

impl Index<MastNodeId> for MastForest {
Expand Down Expand Up @@ -689,4 +700,6 @@ pub enum MastForestError {
EmptyBasicBlock,
#[error("decorator root of child with node id {0} is missing but required for fingerprint computation")]
ChildFingerprintMissing(MastNodeId),
#[error("advice map key already exists when merging forests: {0}")]
AdviceMapKeyCollisionOnMerge(RpoDigest),
}
6 changes: 6 additions & 0 deletions core/src/mast/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use string_table::{StringTable, StringTableBuilder};
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

use super::{DecoratorId, MastForest, MastNode, MastNodeId};
use crate::AdviceMap;

mod decorator;

Expand Down Expand Up @@ -149,6 +150,8 @@ impl Serializable for MastForest {
node_data.write_into(target);
string_table.write_into(target);

self.advice_map.write_into(target);

// Write decorator and node infos
for decorator_info in decorator_infos {
decorator_info.write_into(target);
Expand Down Expand Up @@ -187,6 +190,7 @@ impl Deserializable for MastForest {
let decorator_data: Vec<u8> = Deserializable::read_from(source)?;
let node_data: Vec<u8> = Deserializable::read_from(source)?;
let string_table: StringTable = Deserializable::read_from(source)?;
let advice_map = AdviceMap::read_from(source)?;

let mut mast_forest = {
let mut mast_forest = MastForest::new();
Expand Down Expand Up @@ -229,6 +233,8 @@ impl Deserializable for MastForest {
mast_forest.make_root(root);
}

mast_forest.advice_map = advice_map;

mast_forest
};

Expand Down
21 changes: 20 additions & 1 deletion core/src/mast/serialization/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::{string::ToString, sync::Arc};

use miden_crypto::{hash::rpo::RpoDigest, Felt};
use miden_crypto::{hash::rpo::RpoDigest, Felt, ONE};

use super::*;
use crate::{
Expand Down Expand Up @@ -435,3 +435,22 @@ fn mast_forest_invalid_node_id() {
// Validate normal operations
forest.add_join(first, second).unwrap();
}

/// Test `MastForest::advice_map` serialization and deserialization.
#[test]
fn mast_forest_serialize_deserialize_advice_map() {
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap();
let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap();
forest.add_join(first, second).unwrap();

let key = RpoDigest::new([ONE, ONE, ONE, ONE]);
let value = vec![ONE, ONE];

forest.advice_map_mut().insert(key, value);

let parsed = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
assert_eq!(forest.advice_map, parsed.advice_map);
}
2 changes: 1 addition & 1 deletion miden/benches/program_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn program_execution(c: &mut Criterion) {

let stdlib = StdLibrary::default();
let mut host = DefaultHost::default();
host.load_mast_forest(stdlib.as_ref().mast_forest().clone());
host.load_mast_forest(stdlib.as_ref().mast_forest().clone()).unwrap();

group.bench_function("sha256", |bench| {
let source = "
Expand Down
2 changes: 1 addition & 1 deletion miden/src/examples/blake3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub fn get_example(n: usize) -> Example<DefaultHost<MemAdviceProvider>> {
);

let mut host = DefaultHost::default();
host.load_mast_forest(StdLibrary::default().mast_forest().clone());
host.load_mast_forest(StdLibrary::default().mast_forest().clone()).unwrap();

let stack_inputs =
StackInputs::try_from_ints(INITIAL_HASH_VALUE.iter().map(|&v| v as u64)).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion miden/src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ fn execute(
let stack_inputs = StackInputs::default();
let mut host = DefaultHost::default();
for library in provided_libraries {
host.load_mast_forest(library.mast_forest().clone());
host.load_mast_forest(library.mast_forest().clone())
.map_err(|err| format!("{err}"))?;
}

let state_iter = processor::execute_iter(&program, stack_inputs, &mut host);
Expand Down
3 changes: 2 additions & 1 deletion miden/src/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ impl Analyze {
// fetch the stack and program inputs from the arguments
let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?;
let mut host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?);
host.load_mast_forest(StdLibrary::default().mast_forest().clone());
host.load_mast_forest(StdLibrary::default().mast_forest().clone())
.into_diagnostic()?;

let execution_details: ExecutionDetails = analyze(program.as_str(), stack_inputs, host)
.expect("Could not retrieve execution details");
Expand Down
Loading
Loading