Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: combine Program.hints and Program.hints_ranges into custom collection #1366

Merged
merged 6 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions vm/src/serde/deserialize_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,11 @@ pub fn parse_program_json(
}
}

let (hints, hints_ranges) =
Program::flatten_hints(&program_json.hints, program_json.data.len())?;
let hints_collection = Program::flatten_hints(&program_json.hints, program_json.data.len())?;

let shared_program_data = SharedProgramData {
data: program_json.data,
hints,
hints_ranges,
hints_collection,
main: entrypoint_pc,
start,
end,
Expand Down
42 changes: 27 additions & 15 deletions vm/src/types/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ use arbitrary::{Arbitrary, Unstructured};
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub(crate) struct SharedProgramData {
pub(crate) data: Vec<MaybeRelocatable>,
pub(crate) hints: Vec<HintParams>,
/// This maps a PC to the range of hints in `hints` that correspond to it.
pub(crate) hints_ranges: Vec<HintRange>,
pub(crate) hints_collection: HintsCollection,
pub(crate) main: Option<usize>,
//start and end labels will only be used in proof-mode
pub(crate) start: Option<usize>,
Expand Down Expand Up @@ -81,12 +79,11 @@ impl<'a> Arbitrary<'a> for SharedProgramData {
}

let raw_hints = BTreeMap::<usize, Vec<HintParams>>::arbitrary(u)?;
let (hints, hints_ranges) = Program::flatten_hints(&raw_hints, data.len())
let hints_collection = Program::flatten_hints(&raw_hints, data.len())
.map_err(|_| arbitrary::Error::IncorrectFormat)?;
Ok(SharedProgramData {
data,
hints,
hints_ranges,
hints_collection,
main: Option::<usize>::arbitrary(u)?,
start: Option::<usize>::arbitrary(u)?,
end: Option::<usize>::arbitrary(u)?,
Expand All @@ -98,6 +95,12 @@ impl<'a> Arbitrary<'a> for SharedProgramData {
}
}

#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct HintsCollection {
pub hints: Vec<HintParams>,
pub hints_ranges: Vec<HintRange>,
}
Copy link
Contributor

@Oppen Oppen Aug 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest the following:

  • The fields to be private;
  • The struct to be pub(crate);
  • Use a constructor taking a HashMap;
  • Use a getter taking a pc and returning a slice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! Thanks for the suggestions! Just a short question.
The getter you are describing works fine for this case:

let hint_data = self 
     .program 
     .shared_program_data 
     .hints_ranges 
     .get(vm.run_context.pc.offset) 
     .and_then(|r| r.and_then(|(s, l)| hint_data.get(s..s + l.get()))) 
     .unwrap_or(&[]); 

But cases like this:

self.program
            .shared_program_data
            .hints_collection
            .hints
            .iter()
            .map(|hint| {
                hint_executor
                    .compile_hint(
                        &hint.code,
                        &hint.flow_tracking_data.ap_tracking,
                        &hint.flow_tracking_data.reference_ids,
                        references,
                    )
                    .map_err(|_| VirtualMachineError::CompileHintFail(hint.code.clone().into()))
            })
            .collect()

It seems a getter that returns all the hints is required. Let me know if I am wrong :)

Copy link
Contributor

@Oppen Oppen Aug 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can make a method that returns the iterator instead:

fn iter(&self) -> impl Iterator<Item = Option<(pc, &[HintData])>> {
    self.hint_ranges.iter().enumerate().filter_map(|(pc, range)| {
        let Some(range) = range else {
              return None;
        }
        Some((pc, self.hint_data[range.start..range.start+range.len]))
    })
}

Consider this pseudo-code, you'll probably need to make some changes to it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also implement some fn iter_hints(&self) -> impl Iterator<Item = &HintData> if it makes this easier.

Copy link
Contributor Author

@PanGan21 PanGan21 Aug 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunatelly I was forced to implement both iter_hints and iter. The iter function is used only for testing in the helper get_hints_as_map thus it has #[allow(dead_code)]. Let me know your thoughts

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can avoid the #[allow(dead_code)] by using a separate impl block that is gated by the #[cfg(test)] annotation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the new impl block! Thanks!


/// Represents a range of hints corresponding to a PC.
///
/// Is [`None`] if the range is empty, and it is [`Some`] tuple `(start, length)` otherwise.
Expand Down Expand Up @@ -135,15 +138,14 @@ impl Program {
}
let hints: BTreeMap<_, _> = hints.into_iter().collect();

let (hints, hints_ranges) = Self::flatten_hints(&hints, data.len())?;
let hints_collection = Self::flatten_hints(&hints, data.len())?;

let shared_program_data = SharedProgramData {
data,
main,
start: None,
end: None,
hints,
hints_ranges,
hints_collection,
error_message_attributes,
instruction_locations,
identifiers,
Expand All @@ -159,14 +161,15 @@ impl Program {
pub(crate) fn flatten_hints(
hints: &BTreeMap<usize, Vec<HintParams>>,
program_length: usize,
) -> Result<(Vec<HintParams>, Vec<HintRange>), ProgramError> {
) -> Result<HintsCollection, ProgramError> {
let bounds = hints
.iter()
.map(|(pc, hs)| (*pc, hs.len()))
.reduce(|(max_hint_pc, full_len), (pc, len)| (max_hint_pc.max(pc), full_len + len));

let Some((max_hint_pc, full_len)) = bounds else {
return Ok((Vec::new(), Vec::new()));
let (max_hint_pc, full_len) = match bounds {
Some(bounds) => bounds,
None => return Ok(HintsCollection::default()),
};

if max_hint_pc >= program_length {
Expand All @@ -185,7 +188,10 @@ impl Program {
hints_values.extend_from_slice(&hs[..]);
}

Ok((hints_values, hints_ranges))
Ok(HintsCollection {
hints: hints_values,
hints_ranges,
})
}

#[cfg(feature = "std")]
Expand Down Expand Up @@ -312,6 +318,8 @@ impl TryFrom<CasmContractClass> for Program {

#[cfg(test)]
mod tests {
use core::hint;

use super::*;
use crate::serde::deserialize_program::{ApTracking, FlowTrackingData};
use crate::utils::test_utils::*;
Expand Down Expand Up @@ -1020,10 +1028,14 @@ mod tests {
#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn default_program() {
let shared_program_data = SharedProgramData {
data: Vec::new(),
let hints_collection = HintsCollection {
hints: Vec::new(),
hints_ranges: Vec::new(),
};

let shared_program_data = SharedProgramData {
data: Vec::new(),
hints_collection,
main: None,
start: None,
end: None,
Expand Down
29 changes: 18 additions & 11 deletions vm/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,10 @@ pub mod test_utils {
( $( $builtin_name: expr ),* ) => {{
let shared_program_data = SharedProgramData {
data: crate::stdlib::vec::Vec::new(),
hints: crate::stdlib::vec::Vec::new(),
hints_ranges: crate::stdlib::vec::Vec::new(),
hints_collection: HintsCollection{
hints: crate::stdlib::vec::Vec::new(),
hints_ranges: crate::stdlib::vec::Vec::new(),
},
main: None,
start: None,
end: None,
Expand Down Expand Up @@ -344,13 +346,12 @@ pub mod test_utils {
impl From<ProgramFlat> for Program {
fn from(val: ProgramFlat) -> Self {
// NOTE: panics if hints have PCs higher than the program length
let (hints, hints_ranges) =
let hints_collection =
Program::flatten_hints(&val.hints, val.data.len()).expect("hints are valid");
Program {
shared_program_data: Arc::new(SharedProgramData {
data: val.data,
hints,
hints_ranges,
hints_collection,
main: val.main,
start: val.start,
end: val.end,
Expand Down Expand Up @@ -925,8 +926,10 @@ mod test {
fn program_macro() {
let shared_data = SharedProgramData {
data: Vec::new(),
hints: Vec::new(),
hints_ranges: Vec::new(),
hints_collection: HintsCollection {
hints: Vec::new(),
hints_ranges: Vec::new(),
},
main: None,
start: None,
end: None,
Expand All @@ -950,8 +953,10 @@ mod test {
fn program_macro_with_builtin() {
let shared_data = SharedProgramData {
data: Vec::new(),
hints: Vec::new(),
hints_ranges: Vec::new(),
hints_collection: HintsCollection {
hints: Vec::new(),
hints_ranges: Vec::new(),
},
main: None,
start: None,
end: None,
Expand All @@ -976,8 +981,10 @@ mod test {
fn program_macro_custom_definition() {
let shared_data = SharedProgramData {
data: Vec::new(),
hints: Vec::new(),
hints_ranges: Vec::new(),
hints_collection: HintsCollection {
hints: Vec::new(),
hints_ranges: Vec::new(),
},
main: Some(2),
start: None,
end: None,
Expand Down
3 changes: 3 additions & 0 deletions vm/src/vm/runners/cairo_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ impl CairoRunner {
) -> Result<Vec<Box<dyn Any>>, VirtualMachineError> {
self.program
.shared_program_data
.hints_collection
.hints
.iter()
.map(|hint| {
Expand Down Expand Up @@ -552,6 +553,7 @@ impl CairoRunner {
let hint_data = self
.program
.shared_program_data
.hints_collection
.hints_ranges
.get(vm.run_context.pc.offset)
.and_then(|r| r.and_then(|(s, l)| hint_data.get(s..s + l.get())))
Expand Down Expand Up @@ -590,6 +592,7 @@ impl CairoRunner {
let hint_data = self
.program
.shared_program_data
.hints_collection
.hints_ranges
.get(vm.run_context.pc.offset)
.and_then(|r| r.and_then(|(s, l)| hint_data.get(s..s + l.get())))
Expand Down