diff --git a/src/extension/infer.rs b/src/extension/infer.rs index ffab42585..1b0b7f890 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -21,9 +21,9 @@ use std::collections::{HashMap, HashSet}; use thiserror::Error; -/// A mapping from locations on the hugr to extension requirement sets which -/// have been inferred for them -pub type ExtensionSolution = HashMap<(Node, Direction), ExtensionSet>; +/// A mapping from nodes on the hugr to extension requirement sets which have +/// been inferred for their inputs. +pub type ExtensionSolution = HashMap; /// Infer extensions for a hugr. This is the main API exposed by this module /// @@ -38,9 +38,9 @@ pub fn infer_extensions( let solution = ctx.main_loop()?; ctx.instantiate_variables(); let closed_solution = ctx.main_loop()?; - let closure: HashMap<(Node, Direction), ExtensionSet> = closed_solution + let closure: ExtensionSolution = closed_solution .into_iter() - .filter(|(loc, _)| !solution.contains_key(loc)) + .filter(|(node, _)| !solution.contains_key(node)) .collect(); Ok((solution, closure)) } @@ -536,7 +536,9 @@ impl UnificationContext { } } }?; - results.insert(*loc, rs); + if loc.1 == Direction::Incoming { + results.insert(loc.0, rs); + } } debug_assert!(self.live_metas().is_empty()); Ok(results) @@ -735,22 +737,11 @@ mod test { let (_, closure) = infer_extensions(&hugr)?; let empty = ExtensionSet::new(); let ab = ExtensionSet::from_iter(["A".into(), "B".into()]); - let abc = ExtensionSet::from_iter(["A".into(), "B".into(), "C".into()]); + assert_eq!(*closure.get(&(hugr.root())).unwrap(), empty); + assert_eq!(*closure.get(&(mult_c)).unwrap(), ab); + assert_eq!(*closure.get(&(add_ab)).unwrap(), empty); assert_eq!( - *closure.get(&(hugr.root(), Direction::Incoming)).unwrap(), - empty - ); - assert_eq!( - *closure.get(&(hugr.root(), Direction::Outgoing)).unwrap(), - abc - ); - assert_eq!(*closure.get(&(mult_c, Direction::Incoming)).unwrap(), ab); - assert_eq!(*closure.get(&(mult_c, Direction::Outgoing)).unwrap(), abc); - assert_eq!(*closure.get(&(add_ab, Direction::Incoming)).unwrap(), empty); - assert_eq!(*closure.get(&(add_ab, Direction::Outgoing)).unwrap(), ab); - assert_eq!(*closure.get(&(add_ab, Direction::Incoming)).unwrap(), empty); - assert_eq!( - *closure.get(&(add_b, Direction::Incoming)).unwrap(), + *closure.get(&add_b).unwrap(), ExtensionSet::singleton(&"A".into()) ); Ok(()) @@ -837,9 +828,9 @@ mod test { ctx.add_constraint(ab, Constraint::Plus("A".into(), b)); ctx.add_constraint(ab, Constraint::Plus("B".into(), a)); let solution = ctx.main_loop()?; - // We'll only find concrete solutions for the Incoming/Outgoing sides of + // We'll only find concrete solutions for the Incoming extension reqs of // the main node created by `Hugr::default` - assert_eq!(solution.len(), 2); + assert_eq!(solution.len(), 1); Ok(()) } @@ -983,11 +974,14 @@ mod test { hugr.connect(lift_node, 0, ochild, 0)?; hugr.connect(child, 0, output, 0)?; - let (sol, _) = infer_extensions(&hugr)?; + hugr.infer_extensions()?; // The solution for the const node should be {A, B}! assert_eq!( - *sol.get(&(const_node, Direction::Outgoing)).unwrap(), + hugr.get_nodetype(const_node) + .signature() + .unwrap() + .output_extensions(), ExtensionSet::from_iter(["A".into(), "B".into()]) ); diff --git a/src/extension/validate.rs b/src/extension/validate.rs index a7322f4fb..c026c2519 100644 --- a/src/extension/validate.rs +++ b/src/extension/validate.rs @@ -5,11 +5,10 @@ use std::collections::HashMap; use thiserror::Error; +use super::{ExtensionSet, ExtensionSolution}; use crate::hugr::NodeType; use crate::{Direction, Hugr, HugrView, Node, Port}; -use super::ExtensionSet; - /// Context for validating the extension requirements defined in a Hugr. #[derive(Debug, Clone, Default)] pub struct ExtensionValidator { @@ -23,10 +22,17 @@ impl ExtensionValidator { /// /// The `closure` argument is a set of extensions which doesn't actually /// live on the graph, but is used to close the graph for validation - pub fn new(hugr: &Hugr, closure: HashMap<(Node, Direction), ExtensionSet>) -> Self { - let mut validator = ExtensionValidator { - extensions: closure, - }; + pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self { + let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new(); + for (node, incoming_sol) in closure.into_iter() { + let op_signature = hugr.get_nodetype(node).op_signature(); + let outgoing_sol = op_signature.extension_reqs.union(&incoming_sol); + + extensions.insert((node, Direction::Incoming), incoming_sol); + extensions.insert((node, Direction::Outgoing), outgoing_sol); + } + + let mut validator = ExtensionValidator { extensions }; for node in hugr.nodes() { validator.gather_extensions(&node, hugr.get_nodetype(node)); diff --git a/src/hugr.rs b/src/hugr.rs index 3c0c970a6..374077529 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -221,10 +221,7 @@ impl Hugr { fn instantiate_extensions(&mut self, solution: ExtensionSolution) { // We only care about inferred _input_ extensions, because `NodeType` // uses those to infer the output extensions - for ((node, _), input_extensions) in solution - .iter() - .filter(|((_, dir), _)| *dir == Direction::Incoming) - { + for (node, input_extensions) in solution.iter() { let nodetype = self.op_types.try_get_mut(node.index).unwrap(); match &nodetype.input_extensions { None => nodetype.input_extensions = Some(input_extensions.clone()), @@ -528,16 +525,14 @@ mod test { hugr.infer_extensions()?; assert_eq!( - hugr.op_types - .get(lift.index) + hugr.get_nodetype(lift) .signature() .unwrap() .input_extensions, ExtensionSet::new() ); assert_eq!( - hugr.op_types - .get(output.index) + hugr.get_nodetype(output) .signature() .unwrap() .input_extensions, diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 83b087ff3..061bad2a4 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -15,7 +15,7 @@ use pyo3::prelude::*; use crate::extension::SignatureError; use crate::extension::{ validate::{ExtensionError, ExtensionValidator}, - ExtensionRegistry, ExtensionSet, InferExtensionError, + ExtensionRegistry, ExtensionSolution, InferExtensionError, }; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; use crate::ops::{OpTag, OpTrait, OpType, ValidateOp}; @@ -53,7 +53,7 @@ impl Hugr { /// free extension variables pub fn validate_with_extension_closure( &self, - closure: HashMap<(Node, Direction), ExtensionSet>, + closure: ExtensionSolution, extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { let mut validator = ValidationContext::new(self, closure, extension_registry); @@ -65,7 +65,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { /// Create a new validation context. pub fn new( hugr: &'a Hugr, - extension_closure: HashMap<(Node, Direction), ExtensionSet>, + extension_closure: ExtensionSolution, extension_registry: &'b ExtensionRegistry, ) -> Self { Self {