diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index b3c6ff13d..0ffc6eb3e 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -293,8 +293,7 @@ pub trait Dataflow: Container { signature: signature.clone(), }; let nodetype = match &input_extensions { - // TODO: Make this NodeType::open_extensions - None => NodeType::pure(op), + None => NodeType::open_extensions(op), Some(rs) => NodeType::new(op, rs.clone()), }; let (dfg_n, _) = add_node_with_wires(self, nodetype, input_wires.into_iter().collect())?; diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 657afd8c2..bb17940af 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -71,10 +71,10 @@ impl CFGBuilder { impl HugrBuilder for CFGBuilder { fn finish_hugr( - self, + mut self, extension_registry: &ExtensionRegistry, ) -> Result { - self.base.validate(extension_registry)?; + self.base.infer_and_validate(extension_registry)?; Ok(self.base) } } diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index 993240bec..8943ed7ee 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -151,8 +151,7 @@ impl HugrBuilder for ConditionalBuilder { mut self, extension_registry: &ExtensionRegistry, ) -> Result { - self.base.infer_extensions()?; - self.base.validate(extension_registry)?; + self.base.infer_and_validate(extension_registry)?; Ok(self.base) } } diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 8250532bd..b536b9e23 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -52,8 +52,7 @@ impl + AsRef> DFGBuilder { base.as_mut().add_node_with_parent( parent, match &input_extensions { - // TODO: Make this NodeType::open_extensions - None => NodeType::pure(input), + None => NodeType::open_extensions(input), Some(rs) => NodeType::new(input, rs.clone()), }, )?; @@ -61,7 +60,7 @@ impl + AsRef> DFGBuilder { parent, match input_extensions.map(|inp| inp.union(&signature.extension_reqs)) { // TODO: Make this NodeType::open_extensions - None => NodeType::new(output, signature.extension_reqs), + None => NodeType::open_extensions(output), Some(rs) => NodeType::new(output, rs), }, )?; @@ -100,9 +99,7 @@ impl HugrBuilder for DFGBuilder { mut self, extension_registry: &ExtensionRegistry, ) -> Result { - let closure = self.base.infer_extensions()?; - self.base - .validate_with_extension_closure(closure, extension_registry)?; + self.base.infer_and_validate(extension_registry)?; Ok(self.base) } } diff --git a/src/builder/module.rs b/src/builder/module.rs index 4cec8a4fd..d2fab41ca 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -56,8 +56,11 @@ impl Default for ModuleBuilder { } impl HugrBuilder for ModuleBuilder { - fn finish_hugr(self, extension_registry: &ExtensionRegistry) -> Result { - self.0.validate(extension_registry)?; + fn finish_hugr( + mut self, + extension_registry: &ExtensionRegistry, + ) -> Result { + self.0.infer_and_validate(extension_registry)?; Ok(self.0) } } diff --git a/src/hugr.rs b/src/hugr.rs index 441217112..92b6713d9 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -24,7 +24,9 @@ use thiserror::Error; use pyo3::prelude::*; pub use self::views::HugrView; -use crate::extension::{infer_extensions, ExtensionSet, ExtensionSolution, InferExtensionError}; +use crate::extension::{ + infer_extensions, ExtensionRegistry, ExtensionSet, ExtensionSolution, InferExtensionError, +}; use crate::ops::{OpTag, OpTrait, OpType}; use crate::types::{FunctionType, Signature}; @@ -196,6 +198,16 @@ impl Hugr { rw.apply(self) } + /// Run resource inference and pass the closure into validation + pub fn infer_and_validate( + &mut self, + extension_registry: &ExtensionRegistry, + ) -> Result<(), ValidationError> { + let closure = self.infer_extensions()?; + self.validate_with_extension_closure(closure, extension_registry)?; + Ok(()) + } + /// Infer extension requirements and add new information to `op_types` field /// /// See [`infer_extensions`] for details on the "closure" value diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 708c7aba6..a4338c380 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -695,10 +695,7 @@ mod test { use cool_asserts::assert_matches; use super::*; - use crate::builder::{ - BuildError, Container, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, - ModuleBuilder, - }; + use crate::builder::{BuildError, Container, Dataflow, DataflowSubContainer, ModuleBuilder}; use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; use crate::extension::{prelude_registry, Extension, ExtensionSet, TypeDefBound, EMPTY_REG}; use crate::hugr::hugrmut::sealed::HugrMutInternals; @@ -1134,7 +1131,7 @@ mod test { let f_handle = f_builder.finish_with_outputs(f_inputs)?; let [f_output] = f_handle.outputs_arr(); main.finish_with_outputs([f_output])?; - let handle = module_builder.finish_prelude_hugr(); + let handle = module_builder.hugr().validate(&prelude_registry()); assert_matches!( handle, @@ -1171,7 +1168,7 @@ mod test { let f_handle = f_builder.finish_with_outputs(f_inputs)?; let [f_output] = f_handle.outputs_arr(); main.finish_with_outputs([f_output])?; - let handle = module_builder.finish_prelude_hugr(); + let handle = module_builder.hugr().validate(&prelude_registry()); assert_matches!( handle, Err(ValidationError::ExtensionError( @@ -1233,7 +1230,7 @@ mod test { let [output] = builder.finish_with_outputs([])?.outputs_arr(); main.finish_with_outputs([output])?; - let handle = module_builder.finish_prelude_hugr(); + let handle = module_builder.hugr().validate(&prelude_registry()); assert_matches!( handle, Err(ValidationError::ExtensionError( @@ -1245,16 +1242,33 @@ mod test { #[test] fn parent_signature_mismatch() -> Result<(), BuildError> { - let main_signature = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&"R".into())); + let rs = ExtensionSet::singleton(&"R".into()); - let mut builder = DFGBuilder::new(main_signature)?; - let [w] = builder.input_wires_arr(); - builder.set_outputs([w])?; - let hugr = builder.base.validate(&prelude_registry()); // finish_hugr_with_outputs([w]); + let main_signature = + FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); + + let mut hugr = Hugr::new(NodeType::pure(ops::DFG { + signature: main_signature, + })); + let input = hugr.add_node_with_parent( + hugr.root(), + NodeType::pure(ops::Input { + types: type_row![NAT], + }), + )?; + let output = hugr.add_node_with_parent( + hugr.root(), + NodeType::new( + ops::Output { + types: type_row![NAT], + }, + rs, + ), + )?; + hugr.connect(input, 0, output, 0)?; assert_matches!( - hugr, + hugr.validate(&prelude_registry()), Err(ValidationError::ExtensionError( ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } ))