Skip to content

Commit

Permalink
refactor: ExtensionSolution only consists of input extensions (#480)
Browse files Browse the repository at this point in the history
Only input extensions need to be inferred. From the values of the input
extensions, the outputs can always be calculated
  • Loading branch information
croyzor authored Sep 1, 2023
1 parent 7a43175 commit 4099995
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 42 deletions.
44 changes: 19 additions & 25 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use std::collections::{HashMap, HashSet};

use thiserror::Error;

/// A mapping from locations on the hugr to extension requirement sets which
/// have been inferred for them
pub type ExtensionSolution = HashMap<(Node, Direction), ExtensionSet>;
/// A mapping from nodes on the hugr to extension requirement sets which have
/// been inferred for their inputs.
pub type ExtensionSolution = HashMap<Node, ExtensionSet>;

/// Infer extensions for a hugr. This is the main API exposed by this module
///
Expand All @@ -38,9 +38,9 @@ pub fn infer_extensions(
let solution = ctx.main_loop()?;
ctx.instantiate_variables();
let closed_solution = ctx.main_loop()?;
let closure: HashMap<(Node, Direction), ExtensionSet> = closed_solution
let closure: ExtensionSolution = closed_solution
.into_iter()
.filter(|(loc, _)| !solution.contains_key(loc))
.filter(|(node, _)| !solution.contains_key(node))
.collect();
Ok((solution, closure))
}
Expand Down Expand Up @@ -536,7 +536,9 @@ impl UnificationContext {
}
}
}?;
results.insert(*loc, rs);
if loc.1 == Direction::Incoming {
results.insert(loc.0, rs);
}
}
debug_assert!(self.live_metas().is_empty());
Ok(results)
Expand Down Expand Up @@ -735,22 +737,11 @@ mod test {
let (_, closure) = infer_extensions(&hugr)?;
let empty = ExtensionSet::new();
let ab = ExtensionSet::from_iter(["A".into(), "B".into()]);
let abc = ExtensionSet::from_iter(["A".into(), "B".into(), "C".into()]);
assert_eq!(*closure.get(&(hugr.root())).unwrap(), empty);
assert_eq!(*closure.get(&(mult_c)).unwrap(), ab);
assert_eq!(*closure.get(&(add_ab)).unwrap(), empty);
assert_eq!(
*closure.get(&(hugr.root(), Direction::Incoming)).unwrap(),
empty
);
assert_eq!(
*closure.get(&(hugr.root(), Direction::Outgoing)).unwrap(),
abc
);
assert_eq!(*closure.get(&(mult_c, Direction::Incoming)).unwrap(), ab);
assert_eq!(*closure.get(&(mult_c, Direction::Outgoing)).unwrap(), abc);
assert_eq!(*closure.get(&(add_ab, Direction::Incoming)).unwrap(), empty);
assert_eq!(*closure.get(&(add_ab, Direction::Outgoing)).unwrap(), ab);
assert_eq!(*closure.get(&(add_ab, Direction::Incoming)).unwrap(), empty);
assert_eq!(
*closure.get(&(add_b, Direction::Incoming)).unwrap(),
*closure.get(&add_b).unwrap(),
ExtensionSet::singleton(&"A".into())
);
Ok(())
Expand Down Expand Up @@ -837,9 +828,9 @@ mod test {
ctx.add_constraint(ab, Constraint::Plus("A".into(), b));
ctx.add_constraint(ab, Constraint::Plus("B".into(), a));
let solution = ctx.main_loop()?;
// We'll only find concrete solutions for the Incoming/Outgoing sides of
// We'll only find concrete solutions for the Incoming extension reqs of
// the main node created by `Hugr::default`
assert_eq!(solution.len(), 2);
assert_eq!(solution.len(), 1);
Ok(())
}

Expand Down Expand Up @@ -983,11 +974,14 @@ mod test {
hugr.connect(lift_node, 0, ochild, 0)?;
hugr.connect(child, 0, output, 0)?;

let (sol, _) = infer_extensions(&hugr)?;
hugr.infer_extensions()?;

// The solution for the const node should be {A, B}!
assert_eq!(
*sol.get(&(const_node, Direction::Outgoing)).unwrap(),
hugr.get_nodetype(const_node)
.signature()
.unwrap()
.output_extensions(),
ExtensionSet::from_iter(["A".into(), "B".into()])
);

Expand Down
18 changes: 12 additions & 6 deletions src/extension/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ use std::collections::HashMap;

use thiserror::Error;

use super::{ExtensionSet, ExtensionSolution};
use crate::hugr::NodeType;
use crate::{Direction, Hugr, HugrView, Node, Port};

use super::ExtensionSet;

/// Context for validating the extension requirements defined in a Hugr.
#[derive(Debug, Clone, Default)]
pub struct ExtensionValidator {
Expand All @@ -23,10 +22,17 @@ impl ExtensionValidator {
///
/// The `closure` argument is a set of extensions which doesn't actually
/// live on the graph, but is used to close the graph for validation
pub fn new(hugr: &Hugr, closure: HashMap<(Node, Direction), ExtensionSet>) -> Self {
let mut validator = ExtensionValidator {
extensions: closure,
};
pub fn new(hugr: &Hugr, closure: ExtensionSolution) -> Self {
let mut extensions: HashMap<(Node, Direction), ExtensionSet> = HashMap::new();
for (node, incoming_sol) in closure.into_iter() {
let op_signature = hugr.get_nodetype(node).op_signature();
let outgoing_sol = op_signature.extension_reqs.union(&incoming_sol);

extensions.insert((node, Direction::Incoming), incoming_sol);
extensions.insert((node, Direction::Outgoing), outgoing_sol);
}

let mut validator = ExtensionValidator { extensions };

for node in hugr.nodes() {
validator.gather_extensions(&node, hugr.get_nodetype(node));
Expand Down
11 changes: 3 additions & 8 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,7 @@ impl Hugr {
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
// uses those to infer the output extensions
for ((node, _), input_extensions) in solution
.iter()
.filter(|((_, dir), _)| *dir == Direction::Incoming)
{
for (node, input_extensions) in solution.iter() {
let nodetype = self.op_types.try_get_mut(node.index).unwrap();
match &nodetype.input_extensions {
None => nodetype.input_extensions = Some(input_extensions.clone()),
Expand Down Expand Up @@ -528,16 +525,14 @@ mod test {
hugr.infer_extensions()?;

assert_eq!(
hugr.op_types
.get(lift.index)
hugr.get_nodetype(lift)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
assert_eq!(
hugr.op_types
.get(output.index)
hugr.get_nodetype(output)
.signature()
.unwrap()
.input_extensions,
Expand Down
6 changes: 3 additions & 3 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use pyo3::prelude::*;
use crate::extension::SignatureError;
use crate::extension::{
validate::{ExtensionError, ExtensionValidator},
ExtensionRegistry, ExtensionSet, InferExtensionError,
ExtensionRegistry, ExtensionSolution, InferExtensionError,
};
use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError};
use crate::ops::{OpTag, OpTrait, OpType, ValidateOp};
Expand Down Expand Up @@ -53,7 +53,7 @@ impl Hugr {
/// free extension variables
pub fn validate_with_extension_closure(
&self,
closure: HashMap<(Node, Direction), ExtensionSet>,
closure: ExtensionSolution,
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
let mut validator = ValidationContext::new(self, closure, extension_registry);
Expand All @@ -65,7 +65,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
/// Create a new validation context.
pub fn new(
hugr: &'a Hugr,
extension_closure: HashMap<(Node, Direction), ExtensionSet>,
extension_closure: ExtensionSolution,
extension_registry: &'b ExtensionRegistry,
) -> Self {
Self {
Expand Down

0 comments on commit 4099995

Please sign in to comment.