
Commit
chore: graph error bubbling (partial)
alexander-camuto committed Jan 11, 2023
1 parent 35ba356 commit 0af1afb
Showing 5 changed files with 171 additions and 94 deletions.
7 changes: 1 addition & 6 deletions src/circuit/mod.rs
@@ -10,7 +10,7 @@ pub mod utils;

use thiserror::Error;

/// A wrapper for tensor related errors.
/// circuit related errors.
#[derive(Debug, Error)]
pub enum CircuitError {
/// Shape mismatch in circuit construction
@@ -19,16 +19,11 @@ pub enum CircuitError {
LookupInstantiation,
/// A lookup table was already assigned
TableAlreadyAssigned,
/// A val/var tensor combination is not yet implemented
VariableComb,
}

impl std::fmt::Display for CircuitError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
CircuitError::VariableComb => {
write!(f, "var/val tensor combination is not yet implemented",)
}
CircuitError::DimMismatch(op) => {
write!(f, "dimension mismatch in circuit construction: {}", op)
}
63 changes: 63 additions & 0 deletions src/graph/mod.rs
@@ -21,8 +21,71 @@ pub use model::*;
pub use node::*;
use std::cmp::max;
use std::marker::PhantomData;
use thiserror::Error;
pub use vars::*;

/// graph related errors.
#[derive(Debug, Error)]
pub enum GraphError {
/// Shape mismatch in circuit construction
DimMismatch(String),
/// Wrong method was called to configure an op
WrongMethod(OpKind),
/// A requested node is missing in the graph
MissingNode(usize),
/// A node was configured with the wrong kind of operation
OpMismatch(OpKind),
/// An operation is not supported in the graph
UnsupportedOp,
/// A node is missing required parameters
MissingParams(String),
/// Error in the configuration of the visibility of variables
Visibility,
/// ezkl only supports division by a constant
NonConstantDiv,
/// ezkl only supports constant exponents
NonConstantPower,
/// Failed to rescale the inputs to an operation
RescalingError(OpKind),
}

impl std::fmt::Display for GraphError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
GraphError::DimMismatch(op) => {
write!(f, "dimension mismatch in circuit construction: {}", op)
}
GraphError::WrongMethod(op) => {
write!(f, "wrong method was called to configure: {}", op)
}
GraphError::MissingNode(id) => {
write!(f, "a requested node is missing in the graph: {}", id)
}
GraphError::OpMismatch(op) => {
write!(f, "a node was configured with the wrong kind of operation: {}", op)
}
GraphError::UnsupportedOp => {
write!(f, "unsupported operation in graph")
}
GraphError::MissingParams(id) => {
write!(f, "a node is missing required parameters: {}", id)
}
GraphError::Visibility => {
write!(f, "there should be at least 1 set of public variables")
}
GraphError::NonConstantDiv => {
write!(f, "ezkl currently only supports division by a constant")
}
GraphError::NonConstantPower => {
write!(f, "ezkl currently only supports constant exponents")
}
GraphError::RescalingError(op) => {
write!(f, "failed to rescale inputs to {}", op)
}
}
}
}

/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
#[derive(Clone, Debug)]
pub struct ModelCircuit<F: FieldExt> {
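For context on the error bubbling this commit starts to introduce, here is a minimal standalone sketch of how a typed GraphError propagates through the Box<dyn Error> boundary used by the model code. The lookup_node and build_graph functions below are hypothetical and not part of the repository:

use std::error::Error;

// Hypothetical helper that fails with a typed graph error.
fn lookup_node(id: usize) -> Result<(), GraphError> {
    Err(GraphError::MissingNode(id))
}

// Callers that return Box<dyn Error>, as the Model methods in this commit do,
// can bubble the typed error up with `?`; the manual Display impl above
// supplies the human-readable message.
fn build_graph() -> Result<(), Box<dyn Error>> {
    lookup_node(42)?;
    Ok(())
}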
83 changes: 48 additions & 35 deletions src/graph/model.rs
@@ -1,5 +1,6 @@
use super::node::*;
use super::vars::*;
use super::GraphError;
use crate::circuit::lookup::Config as LookupConfig;
use crate::circuit::lookup::Op as LookupOp;
use crate::circuit::lookup::Table as LookupTable;
@@ -134,7 +135,7 @@ impl Model {

/// Creates a `Model` from parsed CLI arguments
pub fn from_ezkl_conf(args: Cli) -> Result<Self, Box<dyn Error>> {
let visibility = VarVisibility::from_args(args.clone());
let visibility = VarVisibility::from_args(args.clone())?;
match args.command {
Commands::Table { model } => Model::new(
model,
Expand Down Expand Up @@ -219,7 +220,7 @@ impl Model {
.collect();
if !non_op_nodes.is_empty() {
for (i, node) in non_op_nodes {
let config = self.conf_non_op_node(&node);
let config = self.conf_non_op_node(&node)?;
results.insert(*i, config);
}
}
@@ -231,7 +232,7 @@ impl Model {

if !lookup_ops.is_empty() {
for (i, node) in lookup_ops {
let config = self.conf_table(node, meta, vars, &mut tables);
let config = self.conf_table(node, meta, vars, &mut tables)?;
results.insert(*i, config);
}
}
@@ -243,7 +244,7 @@ impl Model {
.collect();
// preserves ordering
if !poly_ops.is_empty() {
let config = self.conf_poly_ops(&poly_ops, meta, vars);
let config = self.conf_poly_ops(&poly_ops, meta, vars)?;
results.insert(**poly_ops.keys().max().unwrap(), config);

let mut display: String = "Poly nodes: ".to_string();
@@ -296,23 +297,26 @@ impl Model {
configs
}
/// Configures non op related nodes (eg. representing an input or const value)
pub fn conf_non_op_node<F: FieldExt + TensorType>(&self, node: &Node) -> NodeConfig<F> {
pub fn conf_non_op_node<F: FieldExt + TensorType>(
&self,
node: &Node,
) -> Result<NodeConfig<F>, Box<dyn Error>> {
match &node.opkind {
OpKind::Const => {
// Typically parameters for one or more layers.
// Currently this is handled in the consuming node(s), but will be moved here.
NodeConfig::Const
Ok(NodeConfig::Const)
}
OpKind::Input => {
// This is the input to the model (e.g. the image).
// Currently this is handled in the consuming node(s), but will be moved here.
NodeConfig::Input
Ok(NodeConfig::Input)
}
OpKind::Unknown(_c) => {
unimplemented!()
}
c => {
panic!("wrong method called for {}", c)
return Err(Box::new(GraphError::WrongMethod(c.clone())));
}
}
}
@@ -330,25 +334,27 @@ impl Model {
nodes: &BTreeMap<&usize, &Node>,
meta: &mut ConstraintSystem<F>,
vars: &mut ModelVars<F>,
) -> NodeConfig<F> {
let input_nodes: BTreeMap<(&usize, &PolyOp), Vec<Node>> = nodes
.iter()
.map(|(i, e)| {
(
(
*i,
match &e.opkind {
OpKind::Poly(f) => f,
_ => panic!(),
},
),
e.inputs
.iter()
.map(|i| self.nodes.filter(i.node))
.collect_vec(),
)
})
.collect();
) -> Result<NodeConfig<F>, Box<dyn Error>> {
let mut input_nodes: BTreeMap<(&usize, &PolyOp), Vec<Node>> = BTreeMap::new();

for (i, e) in nodes.iter() {
let key = (
*i,
match &e.opkind {
OpKind::Poly(f) => f,
_ => {
return Err(Box::new(GraphError::WrongMethod(e.opkind.clone())));
}
},
);
let value = e
.inputs
.iter()
.map(|i| self.nodes.filter(i.node))
.collect_vec();
input_nodes.insert(key, value);
}

// This works because retain only keeps items for which the predicate returns true, and
// insert only returns true if the item was not previously present in the set.
// Since the vector is traversed in order, we end up keeping just the first occurrence of each item.
@@ -407,15 +413,16 @@ impl Model {

let inputs = inputs_to_layer.iter();

NodeConfig::Poly(
let config = NodeConfig::Poly(
PolyConfig::configure(
meta,
&inputs.clone().map(|x| x.1.clone()).collect_vec(),
output,
&fused_nodes,
),
inputs.map(|x| x.0).collect_vec(),
)
);
Ok(config)
}

/// Configures a lookup table based operation. These correspond to operations that are represented in
@@ -431,18 +438,20 @@ impl Model {
meta: &mut ConstraintSystem<F>,
vars: &mut ModelVars<F>,
tables: &mut BTreeMap<Vec<LookupOp>, Rc<RefCell<LookupTable<F>>>>,
) -> NodeConfig<F> {
) -> Result<NodeConfig<F>, Box<dyn Error>> {
let input_len = node.in_dims[0].iter().product();
let input = &vars.advices[0].reshape(&[input_len]);
let output = &vars.advices[1].reshape(&[input_len]);
let node_inputs = node.inputs.iter().map(|e| e.node).collect();

let op = match &node.opkind {
OpKind::Lookup(l) => l,
c => panic!("wrong method called for {}", c),
c => {
return Err(Box::new(GraphError::WrongMethod(c.clone())));
}
};

if tables.contains_key(&vec![op.clone()]) {
let config = if tables.contains_key(&vec![op.clone()]) {
let table = tables.get(&vec![op.clone()]).unwrap();
let conf: LookupConfig<F> =
LookupConfig::configure_with_table(meta, input, output, table.clone());
@@ -452,7 +461,8 @@ impl Model {
LookupConfig::configure(meta, input, output, self.bits, &[op.clone()]);
tables.insert(vec![op.clone()], conf.table.clone());
NodeConfig::Lookup(conf, node_inputs)
}
};
Ok(config)
}

/// Assigns values to the regions created when calling `configure`.
@@ -554,13 +564,16 @@ impl Model {
}
NodeConfig::Lookup(rc, idx) => {
assert_eq!(idx.len(), 1);
if idx.len() != 1 {
return Err(Box::new(GraphError::DimMismatch("lookup".to_string())));
}
// For activations and elementwise operations, the dimensions are sometimes only in one or the other of input and output.
Some(rc.layout(layouter, inputs.get(&idx[0]).unwrap())?)
}
NodeConfig::Input => None,
NodeConfig::Const => None,
c => {
panic!("Not a configurable op {:?}", c)
_ => {
return Err(Box::new(GraphError::UnsupportedOp));
}
};
Ok(res)
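As an aside on the dedup comment inside conf_poly_ops above ("retain only keeps items for which the predicate returns true ..."), the idiom it describes can be shown in isolation. This sketch is illustrative only; the ops vector is hypothetical:

use std::collections::HashSet;

fn main() {
    let mut seen = HashSet::new();
    let mut ops = vec!["add", "mul", "add", "sub", "mul"];
    // HashSet::insert returns true only for values not seen before, so
    // retain keeps just the first occurrence of each item, in order.
    ops.retain(|op| seen.insert(*op));
    assert_eq!(ops, vec!["add", "mul", "sub"]);
}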