Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: insert_subgraph just return HashMap, make InsertionResult new_root compulsory #609

Merged
merged 1 commit into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr(hugr)?.new_root.unwrap();
let node = self.add_hugr(hugr)?.new_root;

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand All @@ -252,7 +252,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr_view(hugr)?.new_root.unwrap();
let node = self.add_hugr_view(hugr)?.new_root;

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand Down
50 changes: 20 additions & 30 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ pub trait HugrMut: HugrMutInternals {
///
/// Sibling order is not preserved.
///
/// The returned `InsertionResult` does not contain a `new_root` value, since
/// a subgraph may not have a defined root.
/// The return value is a map from indices in `other` to the indices of the
/// corresponding new nodes in `self`.
//
// TODO: Try to preserve the order when possible? We cannot always ensure
// it, since the subgraph may have arbitrary nodes without including their
Expand All @@ -175,7 +175,7 @@ pub trait HugrMut: HugrMutInternals {
root: Node,
other: &impl HugrView,
subgraph: &SiblingSubgraph,
) -> Result<InsertionResult, HugrError> {
) -> Result<HashMap<Node, Node>, HugrError> {
self.valid_node(root)?;
self.hugr_mut().insert_subgraph(root, other, subgraph)
}
Expand All @@ -195,24 +195,14 @@ pub struct InsertionResult {
/// The node, after insertion, that was the root of the inserted Hugr.
///
/// That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root]
///
/// When inserting a subgraph, this value is `None`.
pub new_root: Option<Node>,
pub new_root: Node,
/// Map from nodes in the Hugr/view that was inserted, to their new
/// positions in the Hugr into which said was inserted.
pub node_map: HashMap<Node, Node>,
}

impl InsertionResult {
fn translating_indices(
new_root: Option<Node>,
node_map: HashMap<NodeIndex, NodeIndex>,
) -> Self {
Self {
new_root,
node_map: HashMap::from_iter(node_map.into_iter().map(|(k, v)| (k.into(), v.into()))),
}
}
fn translate_indices(node_map: HashMap<NodeIndex, NodeIndex>) -> HashMap<Node, Node> {
HashMap::from_iter(node_map.into_iter().map(|(k, v)| (k.into(), v.into())))
}

/// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr.
Expand Down Expand Up @@ -298,27 +288,27 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
}

fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> Result<InsertionResult, HugrError> {
let (other_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other)?;
let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other)?;
// Update the optypes and metadata, taking them from the other graph.
for (&node, &new_node) in node_map.iter() {
let optype = other.op_types.take(node);
self.as_mut().op_types.set(new_node, optype);
let meta = other.metadata.take(node);
self.as_mut().set_metadata(new_node.into(), meta).unwrap();
}
debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index));
Ok(InsertionResult::translating_indices(
Some(other_root),
node_map,
))
debug_assert_eq!(Some(&new_root.index), node_map.get(&other.root().index));
Ok(InsertionResult {
new_root,
node_map: translate_indices(node_map),
})
}

fn insert_from_view(
&mut self,
root: Node,
other: &impl HugrView,
) -> Result<InsertionResult, HugrError> {
let (other_root, node_map) = insert_hugr_internal(self.as_mut(), root, other)?;
let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other)?;
// Update the optypes and metadata, copying them from the other graph.
for (&node, &new_node) in node_map.iter() {
let nodetype = other.get_nodetype(node.into());
Expand All @@ -328,19 +318,19 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
.set_metadata(new_node.into(), meta.clone())
.unwrap();
}
debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index));
Ok(InsertionResult::translating_indices(
Some(other_root),
node_map,
))
debug_assert_eq!(Some(&new_root.index), node_map.get(&other.root().index));
Ok(InsertionResult {
new_root,
node_map: translate_indices(node_map),
})
}

fn insert_subgraph(
&mut self,
root: Node,
other: &impl HugrView,
subgraph: &SiblingSubgraph,
) -> Result<InsertionResult, HugrError> {
) -> Result<HashMap<Node, Node>, HugrError> {
// Create a portgraph view with the explicit list of nodes defined by the subgraph.
let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> =
NodeFiltered::new_node_filtered(
Expand All @@ -358,7 +348,7 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
.set_metadata(new_node.into(), meta.clone())
.unwrap();
}
Ok(InsertionResult::translating_indices(None, node_map))
Ok(translate_indices(node_map))
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ impl Rewrite for OutlineCfg {
.insert_hugr(outer_cfg, new_block_bldr.hugr().clone())
.unwrap();
(
ins_res.new_root.unwrap(),
ins_res.new_root,
*ins_res.node_map.get(&cfg.node()).unwrap(),
)
};
Expand Down
4 changes: 1 addition & 3 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,7 @@ impl SiblingSubgraph {
// Take the unfinished Hugr from the builder, to avoid unnecessary
// validation checks that require connecting the inputs and outputs.
let mut extracted = mem::take(builder.hugr_mut());
let node_map = extracted
.insert_subgraph(extracted.root(), hugr, self)?
.node_map;
let node_map = extracted.insert_subgraph(extracted.root(), hugr, self)?;

// Connect the inserted nodes in-between the input and output nodes.
let [inp, out] = extracted.get_io(extracted.root()).unwrap();
Expand Down