Skip to content

Commit

Permalink
feat(Assembler): add support for vendoring compiled libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
paracetamolo committed Jan 30, 2025
1 parent 2afc2be commit f73998a
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 5 deletions.
28 changes: 27 additions & 1 deletion assembly/src/assembler/mast_forest_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use vm_core::{
crypto::hash::RpoDigest,
mast::{
DecoratorFingerprint, DecoratorId, MastForest, MastNode, MastNodeFingerprint, MastNodeId,
Remapping, RootIterator,
},
Decorator, DecoratorList, Operation,
};
Expand Down Expand Up @@ -59,6 +60,8 @@ pub struct MastForestBuilder {
/// used as a candidate set of nodes that may be eliminated if the are not referenced by any
/// other node in the forest and are not a root of any procedure.
merged_basic_block_ids: BTreeSet<MastNodeId>,
vendored_mast: MastForest,
vendored_remapping: Remapping,
}

impl MastForestBuilder {
Expand All @@ -74,6 +77,10 @@ impl MastForestBuilder {

(self.mast_forest, id_remappings)
}

pub fn vendor(&mut self, mast_forest: MastForest) {
self.vendored_mast = mast_forest
}
}

/// Takes the set of MAST node ids (all basic blocks) that were merged as part of the assembly
Expand Down Expand Up @@ -334,6 +341,24 @@ impl MastForestBuilder {

Ok(merged_basic_blocks)
}

fn take_subtree_from_vendored_if_present(
&mut self,
mast_root: &RpoDigest,
) -> Result<MastNodeId, AssemblyError> {
if let Some(root_id) = self.vendored_mast.find_procedure_root(*mast_root) {
for old_id in RootIterator::new(&root_id, &self.vendored_mast.clone()) {
let mut node = self.vendored_mast[old_id].clone();
node.remap(&self.vendored_remapping);
let new_id = self.ensure_node(node)?;
self.vendored_remapping.insert(old_id, new_id);
}
let new_root_id = root_id.remap(&self.vendored_remapping);
Ok(new_root_id)
} else {
Err(AssemblyError::Empty)
}
}
}

// ------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -452,7 +477,8 @@ impl MastForestBuilder {

/// Adds an external node to the forest, and returns the [`MastNodeId`] associated with it.
pub fn ensure_external(&mut self, mast_root: RpoDigest) -> Result<MastNodeId, AssemblyError> {
self.ensure_node(MastNode::new_external(mast_root))
self.take_subtree_from_vendored_if_present(&mast_root)
.or_else(|_| self.ensure_node(MastNode::new_external(mast_root)))
}

pub fn set_before_enter(&mut self, node_id: MastNodeId, decorator_ids: Vec<DecoratorId>) {
Expand Down
40 changes: 36 additions & 4 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use module_graph::{ProcedureWrapper, WrappedModule};
use vm_core::{
crypto::hash::RpoDigest,
debuginfo::SourceSpan,
mast::{DecoratorId, MastNodeId},
mast::{DecoratorId, MastForest, MastNodeId},
DecoratorList, Felt, Kernel, Operation, Program, WORD_SIZE,
};

Expand Down Expand Up @@ -71,6 +71,7 @@ pub struct Assembler {
warnings_as_errors: bool,
/// Whether the assembler enables extra debugging information.
in_debug_mode: bool,
vendored_libraries: BTreeMap<RpoDigest, Library>,
}

impl Default for Assembler {
Expand All @@ -82,6 +83,7 @@ impl Default for Assembler {
module_graph,
warnings_as_errors: false,
in_debug_mode: false,
vendored_libraries: BTreeMap::new(),
}
}
}
Expand All @@ -97,6 +99,7 @@ impl Assembler {
module_graph,
warnings_as_errors: false,
in_debug_mode: false,
vendored_libraries: BTreeMap::new(),
}
}

Expand Down Expand Up @@ -251,6 +254,29 @@ impl Assembler {
self.add_library(library)?;
Ok(self)
}

pub fn add_vendored_library(&mut self, library: impl AsRef<Library>) -> Result<(), Report> {
self.add_library(&library)?;
self.vendored_libraries
.insert(*library.as_ref().digest(), library.as_ref().clone());
Ok(())
}
}

fn vendor_mast(vendored_libraries: &BTreeMap<RpoDigest, Library>) -> Result<MastForest, Report> {
use miette::IntoDiagnostic;
// MastForest::merge when receiveing a single forest does not always return
// it unchanged. So we avoid it entirely.
if vendored_libraries.is_empty() {
return Ok(MastForest::default());
}

let forests: Vec<MastForest> = vendored_libraries
.iter()
.map(|(_, lib)| (**lib.mast_forest()).clone())
.collect::<Vec<MastForest>>();
let (mast_forest, _remapping) = vm_core::mast::MastForest::merge(&forests).into_diagnostic()?;
Ok(mast_forest)
}

// ------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -298,10 +324,13 @@ impl Assembler {
modules: impl IntoIterator<Item = impl Compile>,
options: CompileOptions,
) -> Result<Library, Report> {
let ast_module_indices = self.add_modules_with_options(modules, options)?;

let mut mast_forest_builder = MastForestBuilder::default();

let vendored_mast = vendor_mast(&self.vendored_libraries)?;
mast_forest_builder.vendor(vendored_mast);

let ast_module_indices = self.add_modules_with_options(modules, options)?;

let mut exports = {
let mut exports = BTreeMap::new();

Expand Down Expand Up @@ -392,6 +421,9 @@ impl Assembler {

// Compile the module graph rooted at the entrypoint
let mut mast_forest_builder = MastForestBuilder::default();
let vendored_mast = vendor_mast(&self.vendored_libraries)?;
mast_forest_builder.vendor(vendored_mast);

self.compile_subgraph(entrypoint, &mut mast_forest_builder)?;
let entry_node_id = mast_forest_builder
.get_procedure(entrypoint)
Expand All @@ -400,7 +432,7 @@ impl Assembler {

// in case the node IDs changed, update the entrypoint ID to the new value
let (mast_forest, id_remappings) = mast_forest_builder.build();
let entry_node_id = id_remappings.get(&entry_node_id).unwrap_or(&entry_node_id);
let entry_node_id = *id_remappings.get(&entry_node_id).unwrap_or(&entry_node_id);

Ok(Program::with_kernel(
mast_forest.into(),
Expand Down
28 changes: 28 additions & 0 deletions assembly/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3029,3 +3029,31 @@ fn test_program_serde_with_decorators() {

assert_eq!(original_program, deserialized_program);
}

#[test]
fn vendoring() -> TestResult {
let context = TestContext::new();
let mut mod_parser = ModuleParser::new(ModuleKind::Library);
let vendor_lib = {
let source = source_file!(&context, "export.bar push.1 end export.prune push.2 end");
let mod1 = mod_parser.parse(LibraryPath::new("test::mod1").unwrap(), source).unwrap();
Assembler::default().assemble_library([mod1]).unwrap()
};

let lib = {
let source = source_file!(&context, "export.foo exec.::test::mod1::bar end");
let mod2 = mod_parser.parse(LibraryPath::new("test::mod2").unwrap(), source).unwrap();

let mut assembler = Assembler::default();
assembler.add_vendored_library(vendor_lib)?;
assembler.assemble_library([mod2]).unwrap()
};

let expected_lib = {
let source = source_file!(&context, "export.foo push.1 end");
let mod2 = mod_parser.parse(LibraryPath::new("test::mod2").unwrap(), source).unwrap();
Assembler::default().assemble_library([mod2]).unwrap()
};
assert!(lib == expected_lib);
Ok(())
}

0 comments on commit f73998a

Please sign in to comment.