From 8a7c1330868d14c34fdcc619a48ee88c14d2eb35 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Fri, 23 Dec 2022 12:27:03 +0000 Subject: [PATCH 01/13] chore: the great documentation push of 2022 --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index ac6c92634..e0a242fe0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#![deny(warnings, unsafe_code)] +#![deny(missing_docs, warnings, unsafe_code)] #![feature(slice_flatten)] /// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit. From 553a2a910c672b3d9d2f8e903cb299f369ac3aa2 Mon Sep 17 00:00:00 2001 From: jason Date: Fri, 23 Dec 2022 09:39:44 -0500 Subject: [PATCH 02/13] Docs for var.rs, rename v1 to max_rot --- src/tensor/var.rs | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/tensor/var.rs b/src/tensor/var.rs index 88db9ac21..25d3a0d89 100644 --- a/src/tensor/var.rs +++ b/src/tensor/var.rs @@ -12,31 +12,51 @@ use std::cmp::min; /// using the `assign` method called on a [ValTensor]. #[derive(Clone, Debug)] pub enum VarTensor { + /// A VarTensor for holding Advice values, which are assigned at proving time. Advice { + /// Vec of Advice columns inner: Vec>, + /// Number of rows available to be used in each column of the storage col_size: usize, + /// Total capacity (number of advice cells), usually inner.len()*col_size capacity: usize, + /// Vector of dimensions of the tensor we are representing using this storage. Note that the shape of the storage and this shape can differ. dims: Vec, }, + /// A VarTensor for holding Fixed values, which are assigned at circuit definition time. Fixed { + /// Vec of Fixed columns inner: Vec>, + /// Number of rows available to be used in each column of the storage col_size: usize, + /// Total capacity (number of advice cells), usually inner.len()*col_size capacity: usize, + /// Vector of dimensions of the tensor we are representing using this storage. Note that the shape of the storage and this shape can differ. dims: Vec, }, } impl VarTensor { + /// Create a new VarTensor::Advice + /// `cs` is the `ConstraintSystem` from which the columns will be allocated. + /// `k` is the log2 number of rows in the matrix, including any system and blinding rows. + /// `capacity` is the number of advice cells for this tensor + /// `dims` is the `Vec` of dimensions of the tensor we are representing. Note that the shape of the storage and this shape can differ. + /// `equality` should be true if we want to enable equality constraints for the columns involved. + /// `max_rot` is the maximum number of rotations that we allow for this VarTensor. Rotations affect performance. pub fn new_advice( cs: &mut ConstraintSystem, k: usize, capacity: usize, dims: Vec, equality: bool, - v1: usize, + max_rot: usize, ) -> Self { let base = 2u32; - let max_rows = min(v1, base.pow(k as u32) as usize - cs.blinding_factors() - 1); + let max_rows = min( + max_rot, + base.pow(k as u32) as usize - cs.blinding_factors() - 1, + ); let modulo = (capacity / max_rows) + 1; let mut advices = vec![]; for _ in 0..modulo { @@ -55,16 +75,26 @@ impl VarTensor { } } + /// Create a new VarTensor::Fixed + /// `cs` is the `ConstraintSystem` from which the columns will be allocated. + /// `k` is the log2 number of rows in the matrix, including any system and blinding rows. + /// `capacity` is the number of fixed cells for this tensor + /// `dims` is the `Vec` of dimensions of the tensor we are representing. Note that the shape of the storage and this shape can differ. + /// `equality` should be true if we want to enable equality constraints for the columns involved. + /// `max_rot` is the maximum number of rotations that we allow for this VarTensor. Rotations affect performance. pub fn new_fixed( cs: &mut ConstraintSystem, k: usize, capacity: usize, dims: Vec, equality: bool, - v1: usize, + max_rot: usize, ) -> Self { let base = 2u32; - let max_rows = min(v1, base.pow(k as u32) as usize - cs.blinding_factors() - 1); + let max_rows = min( + max_rot, + base.pow(k as u32) as usize - cs.blinding_factors() - 1, + ); let modulo = (capacity / max_rows) + 1; let mut fixed = vec![]; for _ in 0..modulo { @@ -126,6 +156,7 @@ impl VarTensor { } } + /// Take a linear coordinate and output the (column, row) position in the storage block. pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize) { match self { VarTensor::Advice { col_size, .. } | VarTensor::Fixed { col_size, .. } => { From 68af5c57477b956bebb1d79faa31926c6ca938b2 Mon Sep 17 00:00:00 2001 From: jason Date: Fri, 23 Dec 2022 09:53:40 -0500 Subject: [PATCH 03/13] Docs for val.rs --- src/tensor/val.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/tensor/val.rs b/src/tensor/val.rs index 73f0741b8..07305d679 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -5,20 +5,32 @@ use halo2_proofs::plonk::Instance; /// For instance can represent pre-trained neural network weights; or a known input to a network. #[derive(Debug, Clone)] pub enum ValTensor { + /// A tensor of Values, each containing a field element Value { + /// Underlying Tensor. inner: Tensor>, + /// Vector of dimensions of the tensor. dims: Vec, }, + /// A tensor of Values, each containing a ratio of field elements, which may be evaluated to produce plain field elements. AssignedValue { + /// Underlying Tensor. inner: Tensor>>, + /// Vector of dimensions of the tensor. dims: Vec, }, + /// A tensor of AssignedCells, with data both a value and the matrix cell to which it is assigned. PrevAssigned { + /// Underlying Tensor. inner: Tensor>, + /// Vector of dimensions of the tensor. dims: Vec, }, + /// A tensor backed by an Instance column Instance { + /// Underlying Tensor. inner: Column, + /// Vector of dimensions of the tensor. dims: Vec, }, } @@ -51,6 +63,7 @@ impl From>> for ValTensor } impl ValTensor { + /// Allocate a new ValTensor::Instance from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`. pub fn new_instance(cs: &mut ConstraintSystem, dims: Vec, equality: bool) -> Self { let col = cs.instance_column(); if equality { @@ -142,6 +155,7 @@ impl ValTensor { | ValTensor::Instance { dims: d, .. } => d, } } + /// A `String` representation of the `ValTensor` for display, for example in showing intermediate values in a computational graph. pub fn show(&self) -> String { match self.clone() { ValTensor::PrevAssigned { inner: v, dims: _ } => { From 715e49db35ddc6725c418be095485a7c8dc84f70 Mon Sep 17 00:00:00 2001 From: jason Date: Fri, 23 Dec 2022 10:18:42 -0500 Subject: [PATCH 04/13] add sumpool doc and doctest --- src/tensor/ops.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 55ea2e417..eb1260ea2 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -529,6 +529,29 @@ pub fn convolution + Add>( output } +/// Applies 2D sum pooling over a 3D tensor of shape C x H x W. +/// # Arguments +/// +/// * `image` - Tensor. +/// * `padding` - Tuple of padding values in x and y directions. +/// * `stride` - Tuple of stride values in x and y directions. +/// * `pool_dims` - Tuple of pooling window size in x and y directions. +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::sumpool; +/// use halo2_proofs::circuit::Value; +/// use halo2_proofs::plonk::Assigned; +/// use halo2curves::pasta::Fp as F; +/// +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 3, 3], +/// ).unwrap(); +/// let pooled = sumpool::(&x, (0, 0), (1, 1), (2, 2)); +/// let expected: Tensor = Tensor::::new(Some(&[11, 8, 8, 10]), &[1, 2, 2]).unwrap(); +/// assert_eq!(pooled, expected); +/// ``` pub fn sumpool + Add>( image: &Tensor, padding: (usize, usize), From 0a3a02831af6bba75f027e26131c324e0b4325c1 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Fri, 23 Dec 2022 20:49:27 +0000 Subject: [PATCH 05/13] `src/circuit` docs --- src/circuit/eltwise.rs | 24 +++++++++++++++++++++++- src/circuit/fused.rs | 2 ++ src/circuit/range.rs | 1 + src/commands.rs | 6 ++++++ src/fieldutils.rs | 3 +++ src/graph/mod.rs | 3 +++ src/graph/model.rs | 11 ++++++++++- src/graph/utilities.rs | 9 +++++++++ src/pfsys/mod.rs | 12 +++++++++--- src/tensor/mod.rs | 3 ++- src/tensor/ops.rs | 20 ++++++++++++++++++++ src/tensor/val.rs | 30 +++++++++++++++--------------- src/tensor/var.rs | 14 ++++++++------ 13 files changed, 111 insertions(+), 27 deletions(-) diff --git a/src/circuit/eltwise.rs b/src/circuit/eltwise.rs index 52b954b33..27decea39 100644 --- a/src/circuit/eltwise.rs +++ b/src/circuit/eltwise.rs @@ -10,7 +10,13 @@ use halo2_proofs::{ use log::error; use std::{cell::RefCell, marker::PhantomData, rc::Rc}; +/// Defines a non-linear layer. pub trait Nonlinearity { + /// Function that defines the non-linearity. + /// Arguments + /// + /// * `x` - input to function + /// * `scales` - additional parameters that may parametrize the function fn nonlinearity(x: i32, scales: &[usize]) -> F; /// a value which is always in the table fn default_pair(scales: &[usize]) -> (F, F) { @@ -18,26 +24,35 @@ pub trait Nonlinearity { } } +/// A 1D non-linearity. #[derive(Clone, Debug)] pub struct Nonlin1d> { + /// Input to the layer as a [ValTensor]. pub input: ValTensor, + /// Input to the layer as a [ValTensor]. pub output: ValTensor, - pub _marker: PhantomData<(F, NL)>, + _marker: PhantomData<(F, NL)>, } /// Halo2 lookup table for element wise non-linearities. // Table that should be reused across all lookups (so no Clone) #[derive(Clone, Debug)] pub struct EltwiseTable> { + /// Input to table. pub table_input: TableColumn, + /// Output of table pub table_output: TableColumn, + /// Flags if table has been previously assigned to. pub is_assigned: bool, + /// Number of bits used in lookup table. pub scaling_params: Vec, + /// Number of bits used in lookup table. pub bits: usize, _marker: PhantomData<(F, NL)>, } impl> EltwiseTable { + /// Configures the table. pub fn configure( cs: &mut ConstraintSystem, bits: usize, @@ -52,6 +67,7 @@ impl> EltwiseTable { _marker: PhantomData, } } + /// Assigns values to the constraints generated when calling `configure`. pub fn layout(&mut self, layouter: &mut impl Layouter) { assert!(!self.is_assigned); let base = 2i32; @@ -100,8 +116,11 @@ impl> EltwiseTable { /// Configuration for element-wise non-linearities. #[derive(Clone, Debug)] pub struct EltwiseConfig> { + /// [VarTensor] input to non-linearity. pub input: VarTensor, + /// [VarTensor] input to non-linearity. pub output: VarTensor, + /// Lookup table used to represent the non-linearity pub table: Rc>>, qlookup: Selector, _marker: PhantomData<(NL, F)>, @@ -260,6 +279,7 @@ impl> EltwiseConfig { @@ -277,6 +297,7 @@ impl Nonlinearity for ReLu { } } +#[allow(missing_docs)] #[derive(Clone, Debug)] pub struct Sigmoid { _marker: PhantomData, @@ -292,6 +313,7 @@ impl Nonlinearity for Sigmoid { } } +#[allow(missing_docs)] #[derive(Clone, Debug)] pub struct DivideBy { _marker: PhantomData, diff --git a/src/circuit/fused.rs b/src/circuit/fused.rs index d66c17381..02aaab1be 100644 --- a/src/circuit/fused.rs +++ b/src/circuit/fused.rs @@ -12,6 +12,7 @@ use log::error; use std::fmt; use std::marker::PhantomData; +#[allow(missing_docs)] /// An enum representing the operations that can be merged into a single circuit gate. #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum FusedOp { @@ -100,6 +101,7 @@ pub struct FusedConfig { nodes: Vec, /// the (currently singular) output of the fused operations. pub output: VarTensor, + /// [Selector] generated when configuring the layer. pub selector: Selector, _marker: PhantomData, } diff --git a/src/circuit/range.rs b/src/circuit/range.rs index 225d1b356..342a4838b 100644 --- a/src/circuit/range.rs +++ b/src/circuit/range.rs @@ -13,6 +13,7 @@ use std::marker::PhantomData; #[derive(Debug, Clone)] pub struct RangeCheckConfig { input: VarTensor, + /// The value we are expecting the output of the circuit to match (within a range) pub expected: VarTensor, selector: Selector, _marker: PhantomData, diff --git a/src/commands.rs b/src/commands.rs index 803dda97b..fc7ef2f65 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -4,10 +4,12 @@ use log::info; use std::io::{stdin, stdout, Write}; use std::path::PathBuf; +#[allow(missing_docs)] #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct Cli { #[command(subcommand)] + #[allow(missing_docs)] pub command: Commands, /// The tolerance for error on model outputs #[arg(short = 'T', long, default_value = "0")] @@ -35,6 +37,7 @@ pub struct Cli { pub max_rotations: usize, } +#[allow(missing_docs)] #[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq)] pub enum ProofSystem { IPA, @@ -49,6 +52,7 @@ impl std::fmt::Display for ProofSystem { } } +#[allow(missing_docs)] #[derive(Debug, Subcommand)] pub enum Commands { /// Loads model and prints model table @@ -121,6 +125,7 @@ pub enum Commands { default_missing_value = "always", value_enum )] + /// The [ProofSystem] we'll be using. pfsys: ProofSystem, // todo, optionally allow supplying proving key }, @@ -156,6 +161,7 @@ pub enum Commands { }, } +/// Loads the path to a path `data` represented as a [String]. If empty queries the user for an input. pub fn data_path(data: String) -> PathBuf { let mut s = String::new(); match data.is_empty() { diff --git a/src/fieldutils.rs b/src/fieldutils.rs index 171152933..5f09beb3e 100644 --- a/src/fieldutils.rs +++ b/src/fieldutils.rs @@ -1,5 +1,7 @@ /// Utilities for converting from Halo2 Field types to integers (and vice-versa). use halo2_proofs::arithmetic::FieldExt; + +/// Converts an i32 to a Field element. pub fn i32_to_felt(x: i32) -> F { if x >= 0 { F::from(x as u64) @@ -12,6 +14,7 @@ fn felt_to_u32(x: F) -> u32 { x.get_lower_32() } +/// Converts a Field element to an i32. pub fn felt_to_i32(x: F) -> i32 { if x > F::from(65536) { -(felt_to_u32(-x) as i32) diff --git a/src/graph/mod.rs b/src/graph/mod.rs index cf9de7dcb..f6e8d76e7 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -18,9 +18,12 @@ pub use node::*; use std::cmp::max; pub use vars::*; +/// Defines the circuit for a computational graph / model loaded from a `.onnx` file. #[derive(Clone, Debug)] pub struct ModelCircuit { + /// Vector of input tensors to the model / graph of computations. pub inputs: Vec>, + /// Represents the Field we are using. pub _marker: PhantomData, } diff --git a/src/graph/model.rs b/src/graph/model.rs index 8f3dd63d3..5f115e397 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -46,13 +46,22 @@ pub struct ModelConfig { /// A struct for loading from an Onnx file and converting a computational graph to a circuit. #[derive(Clone, Debug)] pub struct Model { - pub model: Graph>, // The raw Tract data structure + /// The raw tract [Graph] data structure. + pub model: Graph>, + /// Graph of nodes we are loading from Onnx. pub nodes: NodeGraph, // Wrapped nodes with additional methods and data (e.g. inferred shape, quantization) + /// bits used in lookup tables pub bits: usize, + /// Log rows available in circuit. pub logrows: u32, + /// Exponent used in the fixed point representation. pub scale: i32, + /// The divergence from the expected output (if using public outputs) we can tolerate. This is in absolute value across each dimension. + /// eg. for a tolerance of 1 and for a 2D output we could tolerate at most off by 1 errors for each of the 2 outputs. pub tolerance: usize, + /// The [Mode] we're using the model in. pub mode: Mode, + /// Defines which inputs to the model are public and private (params, inputs, outputs) using [VarVisibility]. pub visibility: VarVisibility, } diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 01b168e47..75a49e899 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -4,6 +4,13 @@ use tract_onnx::prelude::{InferenceFact, Node}; use tract_onnx::tract_hir::internal::InferenceOp; // Warning: currently ignores stride information +/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation. +/// Arguments +/// +/// * `vec` - the vector to quantize. +/// * `dims` - the dimensionality of the resulting [Tensor]. +/// * `shift` - offset used in the fixed point representation. +/// * `scale` - `2^scale` used in the fixed point representation. pub fn vector_to_quantized( vec: &[f32], dims: &[usize], @@ -18,10 +25,12 @@ pub fn vector_to_quantized( Tensor::new(Some(&scaled), dims) } +/// Converts a scale (log base 2) to a fixed point multiplier. pub fn scale_to_multiplier(scale: i32) -> f32 { i32::pow(2, scale as u32) as f32 } +/// Gets the shape of a onnx node's outlets. pub fn node_output_shapes( node: &Node>, ) -> Result>>> { diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index 1c55f5b1d..96e6855d3 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -27,21 +27,28 @@ use std::ops::Deref; use std::path::PathBuf; use std::time::Instant; +/// Defines the input to a model. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelInput { + /// Inputs to the model / computational graph. pub input_data: Vec>, + /// The shape of said inputs. pub input_shapes: Vec>, + /// The expected output of the model (can be empty vectors if outputs are not being constrained). pub output_data: Vec>, } +/// Defines the proof generated by a model / circuit. #[derive(Debug, Deserialize, Serialize)] pub struct Proof { - pub input_shapes: Vec>, + /// Public inputs to the model. pub public_inputs: Vec>, + /// The generated proof, as a vector of bytes. pub proof: Vec, } impl Proof { + /// Saves the Proof to a specified `proof_path`. pub fn save(&self, proof_path: &PathBuf) { let serialized = match serde_json::to_string(&self) { Ok(s) => s, @@ -278,7 +285,6 @@ where info!("Proof took {}", now.elapsed().as_secs()); let checkable_pf = Proof { - input_shapes: circuit.inputs.iter().map(|i| i.dims().to_vec()).collect(), public_inputs: public_inputs .iter() .map(|i| i.clone().into_iter().collect()) @@ -289,7 +295,7 @@ where (checkable_pf, dims) } -/// a wrapper around halo2's verify_proof +/// A wrapper around halo2's verify_proof pub fn verify_proof_model< 'params, F: FieldExt, diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 7427c693b..ed69bac78 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -30,7 +30,7 @@ pub trait TensorType: Clone + Debug + 'static { fn zero() -> Option { None } - + /// Max operator for ordering values. fn tmax(&self, _: &Self) -> Option { None } @@ -173,6 +173,7 @@ impl TensorType for halo2curves::bn256::Fr { } } +/// A wrapper for tensor related errors. #[derive(Debug)] pub struct TensorError(String); diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index eb1260ea2..286b3823d 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -317,6 +317,26 @@ pub fn mult>(t: &Vec>) -> Tensor { output } +/// Elementwise divide a tensor with another tensor. +/// # Arguments +/// +/// * `t` - Tensor +/// * `d` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::div; +/// let x = Tensor::::new( +/// Some(&[4, 1, 4, 1, 1, 4]), +/// &[2, 3], +/// ).unwrap(); +/// let y = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// let result = div(&x, y); +/// let expected = Tensor::::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` pub fn div>(t: Tensor, d: Tensor) -> Tensor { assert_eq!(t.dims(), d.dims()); // calculate value of output diff --git a/src/tensor/val.rs b/src/tensor/val.rs index 07305d679..fec0db994 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -1,34 +1,34 @@ use super::*; use halo2_proofs::plonk::Instance; -/// A wrapper around a tensor where the inner type is one of Halo2's `Value`, `Value>`, `AssignedCell, F>`. +/// A wrapper around a [Tensor] where the inner type is one of Halo2's [Value], [Value>], [AssignedCell, F>]. /// This enum is generally used to assign values to variables / advices already configured in a Halo2 circuit (usually represented as a [VarTensor]). /// For instance can represent pre-trained neural network weights; or a known input to a network. #[derive(Debug, Clone)] pub enum ValTensor { - /// A tensor of Values, each containing a field element + /// A tensor of [Value], each containing a field element Value { - /// Underlying Tensor. + /// Underlying [Tensor]. inner: Tensor>, /// Vector of dimensions of the tensor. dims: Vec, }, - /// A tensor of Values, each containing a ratio of field elements, which may be evaluated to produce plain field elements. + /// A tensor of [Value], each containing a ratio of field elements, which may be evaluated to produce plain field elements. AssignedValue { - /// Underlying Tensor. + /// Underlying [Tensor]. inner: Tensor>>, - /// Vector of dimensions of the tensor. + /// Vector of dimensions of the [Tensor]. dims: Vec, }, /// A tensor of AssignedCells, with data both a value and the matrix cell to which it is assigned. PrevAssigned { - /// Underlying Tensor. + /// Underlying [Tensor]. inner: Tensor>, - /// Vector of dimensions of the tensor. + /// Vector of dimensions of the [Tensor]. dims: Vec, }, - /// A tensor backed by an Instance column + /// A tensor backed by an [Instance] column Instance { - /// Underlying Tensor. + /// [Instance] inner: Column, /// Vector of dimensions of the tensor. dims: Vec, @@ -63,7 +63,7 @@ impl From>> for ValTensor } impl ValTensor { - /// Allocate a new ValTensor::Instance from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`. + /// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`. pub fn new_instance(cs: &mut ConstraintSystem, dims: Vec, equality: bool) -> Self { let col = cs.instance_column(); if equality { @@ -100,7 +100,7 @@ impl ValTensor { } } - /// Sets the `ValTensor`'s shape. + /// Sets the [ValTensor]'s shape. pub fn reshape(&mut self, new_dims: &[usize]) { match self { ValTensor::Value { inner: v, dims: d } => { @@ -125,7 +125,7 @@ impl ValTensor { } } - /// Calls `flatten` on the inner tensor. + /// Calls `flatten` on the inner [Tensor]. pub fn flatten(&mut self) { match self { ValTensor::Value { inner: v, dims: d } => { @@ -146,7 +146,7 @@ impl ValTensor { } } - /// Returns the `dims` attribute of the `ValTensor`. + /// Returns the `dims` attribute of the [ValTensor]. pub fn dims(&self) -> &[usize] { match self { ValTensor::Value { dims: d, .. } @@ -155,7 +155,7 @@ impl ValTensor { | ValTensor::Instance { dims: d, .. } => d, } } - /// A `String` representation of the `ValTensor` for display, for example in showing intermediate values in a computational graph. + /// A [String] representation of the [ValTensor] for display, for example in showing intermediate values in a computational graph. pub fn show(&self) -> String { match self.clone() { ValTensor::PrevAssigned { inner: v, dims: _ } => { diff --git a/src/tensor/var.rs b/src/tensor/var.rs index 25d3a0d89..6e8f80709 100644 --- a/src/tensor/var.rs +++ b/src/tensor/var.rs @@ -38,12 +38,14 @@ pub enum VarTensor { impl VarTensor { /// Create a new VarTensor::Advice - /// `cs` is the `ConstraintSystem` from which the columns will be allocated. - /// `k` is the log2 number of rows in the matrix, including any system and blinding rows. - /// `capacity` is the number of advice cells for this tensor - /// `dims` is the `Vec` of dimensions of the tensor we are representing. Note that the shape of the storage and this shape can differ. - /// `equality` should be true if we want to enable equality constraints for the columns involved. - /// `max_rot` is the maximum number of rotations that we allow for this VarTensor. Rotations affect performance. + /// Arguments + /// + /// * `cs` - `ConstraintSystem` from which the columns will be allocated. + /// * `k` - log2 number of rows in the matrix, including any system and blinding rows. + /// * `capacity` - number of advice cells for this tensor + /// * `dims` - `Vec` of dimensions of the tensor we are representing. Note that the shape of the storage and this shape can differ. + /// * `equality` - true if we want to enable equality constraints for the columns involved. + /// * `max_rot` - maximum number of rotations that we allow for this VarTensor. Rotations affect performance. pub fn new_advice( cs: &mut ConstraintSystem, k: usize, From 8035267519a2463fdd73c3ec69ea549f03329d2e Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Fri, 23 Dec 2022 20:55:57 +0000 Subject: [PATCH 06/13] some of pfsys --- src/pfsys/mod.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index 96e6855d3..b452c9f15 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -197,6 +197,7 @@ pub fn prepare_circuit(data: &ModelInput) -> ModelCircuit { } } +/// Deserializes the required inputs to a model at path `datapath` to a [ModelInput] struct. pub fn prepare_data(datapath: String) -> ModelInput { let mut file = match File::open(data_path(datapath)) { Ok(t) => t, @@ -216,6 +217,7 @@ pub fn prepare_data(datapath: String) -> ModelInput { data } +/// Creates a [VerifyingKey] and [ProvingKey] for a [ModelCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`). pub fn create_keys( circuit: &ModelCircuit, params: &'_ Scheme::ParamsProver, @@ -337,6 +339,7 @@ where result } +/// Loads a [VerifyingKey] at `path`. pub fn load_vk( path: PathBuf, params: &'_ Scheme::ParamsVerifier, @@ -355,6 +358,7 @@ where VerifyingKey::::read::<_, ModelCircuit>(&mut reader, params).unwrap() } +/// Loads the [CommitmentScheme::ParamsVerifier] at `path`. pub fn load_params(path: PathBuf) -> Scheme::ParamsVerifier { info!("loading params from {:?}", path); let f = match File::open(path) { @@ -367,20 +371,19 @@ pub fn load_params(path: PathBuf) -> Scheme::ParamsVer Params::<'_, Scheme::Curve>::read(&mut reader).unwrap() } -pub fn save_vk(vk_path: &PathBuf, vk: &VerifyingKey) { +/// Saves a [VerifyingKey] to `path`. +pub fn save_vk(path: &PathBuf, vk: &VerifyingKey) { info!("saving verification key 💾"); - let f = File::create(vk_path).unwrap(); + let f = File::create(path).unwrap(); let mut writer = BufWriter::new(f); vk.write(&mut writer).unwrap(); writer.flush().unwrap(); } -pub fn save_params( - params_path: &PathBuf, - params: &'_ Scheme::ParamsVerifier, -) { +/// Saves [CommitmentScheme] parameters to `path`. +pub fn save_params(path: &PathBuf, params: &'_ Scheme::ParamsVerifier) { info!("saving parameters 💾"); - let f = File::create(params_path).unwrap(); + let f = File::create(path).unwrap(); let mut writer = BufWriter::new(f); params.write(&mut writer).unwrap(); writer.flush().unwrap(); From 00cb9dbeac4903ccdbfa397b42f69f427d94dbb0 Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 24 Dec 2022 11:10:34 -0500 Subject: [PATCH 07/13] More pfsys --- src/pfsys/mod.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index b452c9f15..a25f56527 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -27,7 +27,8 @@ use std::ops::Deref; use std::path::PathBuf; use std::time::Instant; -/// Defines the input to a model. +/// The input tensor data and shape, and output data for the computational graph (model) as floats. +/// For example, the input might be the image data for a neural network, and the output class scores. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelInput { /// Inputs to the model / computational graph. @@ -38,7 +39,7 @@ pub struct ModelInput { pub output_data: Vec>, } -/// Defines the proof generated by a model / circuit. +/// Defines the proof generated by a model / circuit suitably for serialization/deserialization. #[derive(Debug, Deserialize, Serialize)] pub struct Proof { /// Public inputs to the model. @@ -61,6 +62,7 @@ impl Proof { file.write_all(serialized.as_bytes()).expect("write failed"); } + /// Load a json serialized proof from the provided path. pub fn load(proof_path: &PathBuf) -> Self { let mut file = match File::open(proof_path) { Ok(f) => f, @@ -121,6 +123,7 @@ pub fn parse_prover_errors(f: &VerifyFailure) { } } +/// Initialize the model circuit and quantize the provided float inputs from the provided `ModelInput`. pub fn prepare_circuit_and_public_input( data: &ModelInput, ) -> (ModelCircuit, Vec>) { @@ -175,6 +178,7 @@ pub fn prepare_circuit_and_public_input( (circuit, public_inputs) } +/// Initialize the model circuit pub fn prepare_circuit(data: &ModelInput) -> ModelCircuit { let args = Cli::parse(); From 3e78c571c8c3d1a92e6e55aef99e88c535458dc6 Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 24 Dec 2022 12:00:50 -0500 Subject: [PATCH 08/13] docs for graph/vars.rs --- src/graph/vars.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/graph/vars.rs b/src/graph/vars.rs index 35d0dfb2f..c962d0e5c 100644 --- a/src/graph/vars.rs +++ b/src/graph/vars.rs @@ -10,12 +10,16 @@ use log::error; use serde::Deserialize; use std::{cell::RefCell, rc::Rc}; +/// Label Enum to track whether model input, model parameters, and model output are public or private #[derive(Clone, Debug, Deserialize)] pub enum Visibility { + /// Mark an item as private to the prover (not in the proof submitted for verification) Private, + /// Mark an item as public (sent in the proof submitted for verification) Public, } impl Visibility { + #[allow(missing_docs)] pub fn is_public(&self) -> bool { matches!(&self, Visibility::Public) } @@ -29,10 +33,14 @@ impl std::fmt::Display for Visibility { } } +/// Whether the model input, model parameters, and model output are Public or Private to the prover. #[derive(Clone, Debug, Deserialize)] pub struct VarVisibility { + /// Input to the model or computational graph pub input: Visibility, + /// Parameters, such as weights and biases, in the model pub params: Visibility, + /// Output of the model or computational graph pub output: Visibility, } impl std::fmt::Display for VarVisibility { @@ -46,6 +54,8 @@ impl std::fmt::Display for VarVisibility { } impl VarVisibility { + /// Read from cli args whether the model input, model parameters, and model output are Public or Private to the prover. + /// Place in [VarVIsibility] struct. pub fn from_args() -> Self { let args = Cli::parse(); @@ -75,12 +85,17 @@ impl VarVisibility { } } +/// Lookup tables that will be available for reuse. pub enum TableTypes { + /// Reference to a ReLU table ReLu(Rc>>>), + /// Reference to a DivideBy table DivideBy(Rc>>>), + /// Reference to a Sigmoid table Sigmoid(Rc>>>), } impl TableTypes { + /// Get a reference to a reused ReLU lookup table pub fn get_relu(&self) -> Rc>>> { match self { TableTypes::ReLu(inner) => inner.clone(), @@ -89,6 +104,7 @@ impl TableTypes { } } } + /// Get a reference to a reused DivideBy lookup table pub fn get_div(&self) -> Rc>>> { match self { TableTypes::DivideBy(inner) => inner.clone(), @@ -97,6 +113,7 @@ impl TableTypes { } } } + /// Get a reference to a reused Sigmoid lookup table pub fn get_sig(&self) -> Rc>>> { match self { TableTypes::Sigmoid(inner) => inner.clone(), @@ -107,14 +124,19 @@ impl TableTypes { } } +/// A wrapper for holding all columns that will be assigned to by a model. #[derive(Clone)] pub struct ModelVars { + #[allow(missing_docs)] pub advices: Vec, + #[allow(missing_docs)] pub fixed: Vec, + #[allow(missing_docs)] pub instances: Vec>, } -/// A wrapper for holding all columns that will be assigned to by a model. + impl ModelVars { + /// Allocate all columns that will be assigned to by a model. pub fn new( cs: &mut ConstraintSystem, logrows: usize, From 471f6f49fb79f7c2ab6b470ed122c8f31059593b Mon Sep 17 00:00:00 2001 From: jason Date: Sat, 24 Dec 2022 14:15:18 -0500 Subject: [PATCH 09/13] Partial graph/node.rs docs --- src/graph/node.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/graph/node.rs b/src/graph/node.rs index bfa9f2ee3..d73e5fc10 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -35,18 +35,27 @@ use tract_onnx::tract_hir::{ /// Enum of the different kinds of operations `ezkl` can support. #[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd)] pub enum OpKind { + /// A ReLU nonlinearity ReLU(usize), + /// A Sigmoid nonlinearity Sigmoid(usize), + /// A DivideBy nonlinearity Div(usize), + /// A Constant tensor, such as a parameter or hyperparameter Const, + /// An input to the model Input, + /// A fused op, combining affine layers or other arithmetic Fused(FusedOp), + /// Unable to parse the node type Unknown(String), + #[allow(missing_docs)] #[default] None, } impl OpKind { + /// Produce an OpKind from a `&str` onnx name pub fn new(name: &str) -> Self { match name { "Clip" => OpKind::ReLU(1), @@ -77,10 +86,12 @@ impl OpKind { } } } + /// Identify fused OpKind pub fn is_fused(&self) -> bool { matches!(self, OpKind::Fused(_)) } + /// Identify constant OpKind pub fn is_const(&self) -> bool { matches!(self, OpKind::Const) } @@ -102,6 +113,7 @@ impl fmt::Display for OpKind { } /// Enum of the different kinds of node configurations `ezkl` can support. +#[allow(missing_docs)] #[derive(Clone, Default, Debug)] pub enum NodeConfigTypes { ReLU(EltwiseConfig>, Vec), @@ -119,10 +131,12 @@ pub enum NodeConfigTypes { pub struct NodeGraph(pub BTreeMap, BTreeMap>); impl NodeGraph { + /// Create an empty NodeGraph pub fn new() -> Self { NodeGraph(BTreeMap::new()) } + /// Insert the node with given tract `node_idx` and config at `idx` pub fn insert(&mut self, idx: Option, node_idx: usize, config: Node) { match self.0.entry(idx) { Entry::Vacant(e) => { From e7e1e6443012f985465b97d0865805da99038983 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Tue, 27 Dec 2022 11:48:00 +0000 Subject: [PATCH 10/13] final docs --- src/circuit/eltwise.rs | 3 +- src/graph/mod.rs | 17 +++++---- src/graph/model.rs | 78 ++++++++++++++++++++++++------------------ src/graph/node.rs | 40 +++++++++++++++------- src/lib.rs | 4 ++- 5 files changed, 88 insertions(+), 54 deletions(-) diff --git a/src/circuit/eltwise.rs b/src/circuit/eltwise.rs index 27decea39..ac5a110bf 100644 --- a/src/circuit/eltwise.rs +++ b/src/circuit/eltwise.rs @@ -31,7 +31,8 @@ pub struct Nonlin1d> { pub input: ValTensor, /// Input to the layer as a [ValTensor]. pub output: ValTensor, - _marker: PhantomData<(F, NL)>, + #[allow(missing_docs)] + pub _marker: PhantomData<(F, NL)>, } /// Halo2 lookup table for element wise non-linearities. diff --git a/src/graph/mod.rs b/src/graph/mod.rs index f6e8d76e7..0c745fe77 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -1,3 +1,13 @@ +/// Helper functions +pub mod utilities; +pub use utilities::*; +/// Crate for defining a computational graph and building a ZK-circuit from it. +pub mod model; +/// Inner elements of a computational graph that represent a single operation / constraints. +pub mod node; +/// Representations of a computational graph's variables. +pub mod vars; + use crate::tensor::TensorType; use crate::tensor::{Tensor, ValTensor}; use anyhow::Result; @@ -6,16 +16,11 @@ use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, plonk::{Circuit, ConstraintSystem, Error}, }; -use std::marker::PhantomData; -pub mod utilities; -pub use utilities::*; -pub mod model; -pub mod node; -pub mod vars; use log::{info, trace}; pub use model::*; pub use node::*; use std::cmp::max; +use std::marker::PhantomData; pub use vars::*; /// Defines the circuit for a computational graph / model loaded from a `.onnx` file. diff --git a/src/graph/model.rs b/src/graph/model.rs index 5f115e397..6158fceb1 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -27,19 +27,27 @@ use tract_onnx::tract_hir::internal::InferenceOp; /// Mode we're using the model in. #[derive(Clone, Debug)] pub enum Mode { + /// Initialize the model and display the operations table / graph Table, + /// Initialize the model and generate a mock proof Mock, + /// Initialize the model and generate a proof Prove, + /// Initialize the model, generate a proof, and verify FullProve, + /// Initialize the model and verify an already generated proof Verify, } -/// A circuit configuration for the entirety of a model loaded from an Onnx file. +/// A circuit configuration for a model loaded from an Onnx file. #[derive(Clone)] pub struct ModelConfig { - configs: BTreeMap>, + configs: BTreeMap, Vec)>, + /// The model struct pub model: Model, + /// (optional) range checked outputs of the model graph pub public_outputs: Vec>, + /// A wrapper for holding all columns that will be assigned to by the model pub vars: ModelVars, } @@ -190,13 +198,7 @@ impl Model { if !non_fused_ops.is_empty() { for (i, n) in non_fused_ops.iter() { let config = self.configure_table(n, meta, vars, &mut tables); - results.insert( - **i, - NodeConfig { - config, - onnx_idx: vec![**i], - }, - ); + results.insert(**i, (config, vec![**i])); } } @@ -210,10 +212,7 @@ impl Model { let config = self.fuse_ops(&fused_ops, meta, vars); results.insert( **fused_ops.keys().max().unwrap(), - NodeConfig { - config, - onnx_idx: fused_ops.keys().map(|k| **k).sorted().collect_vec(), - }, + (config, fused_ops.keys().map(|k| **k).sorted().collect_vec()), ); } } @@ -272,7 +271,7 @@ impl Model { nodes: &BTreeMap<&usize, &Node>, meta: &mut ConstraintSystem, vars: &mut ModelVars, - ) -> NodeConfigTypes { + ) -> NodeConfig { let input_nodes: BTreeMap<(&usize, &FusedOp), Vec> = nodes .iter() .map(|(i, e)| { @@ -349,7 +348,7 @@ impl Model { let inputs = inputs_to_layer.iter(); - NodeConfigTypes::Fused( + NodeConfig::Fused( FusedConfig::configure( meta, &inputs.clone().map(|x| x.1.clone()).collect_vec(), @@ -373,7 +372,7 @@ impl Model { meta: &mut ConstraintSystem, vars: &mut ModelVars, tables: &mut BTreeMap>, - ) -> NodeConfigTypes { + ) -> NodeConfig { let input_len = node.in_dims[0].iter().product(); let input = &vars.advices[0].reshape(&[input_len]); let output = &vars.advices[1].reshape(&[input_len]); @@ -385,7 +384,7 @@ impl Model { let table = tables.get(&node.opkind).unwrap(); let conf: EltwiseConfig> = EltwiseConfig::configure_with_table(meta, input, output, table.get_div()); - NodeConfigTypes::Divide(conf, node_inputs) + NodeConfig::Divide(conf, node_inputs) } else { let conf: EltwiseConfig> = EltwiseConfig::configure(meta, input, output, Some(&[self.bits, *s])); @@ -393,7 +392,7 @@ impl Model { node.opkind.clone(), TableTypes::DivideBy(conf.table.clone()), ); - NodeConfigTypes::Divide(conf, node_inputs) + NodeConfig::Divide(conf, node_inputs) } } OpKind::ReLU(s) => { @@ -401,12 +400,12 @@ impl Model { let table = tables.get(&node.opkind).unwrap(); let conf: EltwiseConfig> = EltwiseConfig::configure_with_table(meta, input, output, table.get_relu()); - NodeConfigTypes::ReLU(conf, node_inputs) + NodeConfig::ReLU(conf, node_inputs) } else { let conf: EltwiseConfig> = EltwiseConfig::configure(meta, input, output, Some(&[self.bits, *s])); tables.insert(node.opkind.clone(), TableTypes::ReLu(conf.table.clone())); - NodeConfigTypes::ReLU(conf, node_inputs) + NodeConfig::ReLU(conf, node_inputs) } } OpKind::Sigmoid(s) => { @@ -414,7 +413,7 @@ impl Model { let table = tables.get(&node.opkind).unwrap(); let conf: EltwiseConfig> = EltwiseConfig::configure_with_table(meta, input, output, table.get_sig()); - NodeConfigTypes::Sigmoid(conf, node_inputs) + NodeConfig::Sigmoid(conf, node_inputs) } else { let conf: EltwiseConfig> = EltwiseConfig::configure( meta, @@ -423,18 +422,18 @@ impl Model { Some(&[self.bits, *s, scale_to_multiplier(self.scale) as usize]), ); tables.insert(node.opkind.clone(), TableTypes::Sigmoid(conf.table.clone())); - NodeConfigTypes::Sigmoid(conf, node_inputs) + NodeConfig::Sigmoid(conf, node_inputs) } } OpKind::Const => { // Typically parameters for one or more layers. // Currently this is handled in the consuming node(s), but will be moved here. - NodeConfigTypes::Const + NodeConfig::Const } OpKind::Input => { // This is the input to the model (e.g. the image). // Currently this is handled in the consuming node(s), but will be moved here. - NodeConfigTypes::Input + NodeConfig::Input } OpKind::Fused(s) => { error!("For {:?} call fuse_fused_ops instead", s); @@ -470,9 +469,9 @@ impl Model { results.insert(i.0, i.1.clone()); } } - for (idx, c) in config.configs.iter() { + for (idx, (config, node_idx)) in config.configs.iter() { let mut display: String = "".to_string(); - for (i, idx) in c.onnx_idx[0..].iter().enumerate() { + for (i, idx) in node_idx.iter().enumerate() { let node = &self.nodes.filter(*idx); if i > 0 { display.push_str(&format!( @@ -489,7 +488,7 @@ impl Model { info!("{}", display); - if let Some(vt) = self.layout_config(layouter, &mut results, c)? { + if let Some(vt) = self.layout_config(layouter, &mut results, config)? { // we get the max as for fused nodes this corresponds to the node output results.insert(*idx, vt); //only use with mock prover @@ -541,8 +540,8 @@ impl Model { config: &NodeConfig, ) -> Result>> { // The node kind and the config should be the same. - let res = match config.config.clone() { - NodeConfigTypes::Fused(mut ac, idx) => { + let res = match config.clone() { + NodeConfig::Fused(mut ac, idx) => { let values: Vec> = idx .iter() .map(|i| { @@ -563,21 +562,21 @@ impl Model { Some(ac.layout(layouter, &values)) } - NodeConfigTypes::ReLU(rc, idx) => { + NodeConfig::ReLU(rc, idx) => { assert_eq!(idx.len(), 1); // For activations and elementwise operations, the dimensions are sometimes only in one or the other of input and output. Some(rc.layout(layouter, inputs.get(&idx[0]).unwrap().clone())) } - NodeConfigTypes::Sigmoid(sc, idx) => { + NodeConfig::Sigmoid(sc, idx) => { assert_eq!(idx.len(), 1); Some(sc.layout(layouter, inputs.get(&idx[0]).unwrap().clone())) } - NodeConfigTypes::Divide(dc, idx) => { + NodeConfig::Divide(dc, idx) => { assert_eq!(idx.len(), 1); Some(dc.layout(layouter, inputs.get(&idx[0]).unwrap().clone())) } - NodeConfigTypes::Input => None, - NodeConfigTypes::Const => None, + NodeConfig::Input => None, + NodeConfig::Const => None, c => { panic!("Not a configurable op {:?}", c) } @@ -635,19 +634,23 @@ impl Model { self.model.nodes().to_vec() } + /// Returns the ID of the computational graph's inputs pub fn input_outlets(&self) -> Result> { Ok(self.model.input_outlets()?.to_vec()) } + /// Returns the ID of the computational graph's outputs pub fn output_outlets(&self) -> Result> { Ok(self.model.output_outlets()?.to_vec()) } + /// Returns the number of the computational graph's inputs pub fn num_inputs(&self) -> usize { let input_nodes = self.model.inputs.iter(); input_nodes.len() } + /// Returns shapes of the computational graph's inputs pub fn input_shapes(&self) -> Vec> { self.model .inputs @@ -656,11 +659,13 @@ impl Model { .collect_vec() } + /// Returns the number of the computational graph's outputs pub fn num_outputs(&self) -> usize { let output_nodes = self.model.outputs.iter(); output_nodes.len() } + /// Returns shapes of the computational graph's outputs pub fn output_shapes(&self) -> Vec> { self.model .outputs @@ -669,6 +674,7 @@ impl Model { .collect_vec() } + /// Returns the fixed point scale of the computational graph's outputs pub fn get_output_scales(&self) -> Vec { let output_nodes = self.model.outputs.iter(); output_nodes @@ -676,6 +682,7 @@ impl Model { .collect_vec() } + /// Max number of inlets or outlets to a node pub fn max_node_size(&self) -> usize { max( self.nodes @@ -699,6 +706,7 @@ impl Model { ) } + /// Max number of parameters (i.e trainable weights) across the computational graph pub fn max_node_params(&self) -> usize { let mut maximum_number_inputs = 0; for (_, bucket_nodes) in self.nodes.0.iter() { @@ -722,6 +730,7 @@ impl Model { maximum_number_inputs + 1 } + /// Maximum number of input variables in fused layers pub fn max_node_vars_fused(&self) -> usize { let mut maximum_number_inputs = 0; for (_, bucket_nodes) in self.nodes.0.iter() { @@ -745,6 +754,7 @@ impl Model { maximum_number_inputs + 1 } + /// Maximum number of input variables in non-fused layers pub fn max_node_vars_non_fused(&self) -> usize { let mut maximum_number_inputs = 0; for (_, bucket_nodes) in self.nodes.0.iter() { diff --git a/src/graph/node.rs b/src/graph/node.rs index d73e5fc10..d8b1f3c6e 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -115,7 +115,7 @@ impl fmt::Display for OpKind { /// Enum of the different kinds of node configurations `ezkl` can support. #[allow(missing_docs)] #[derive(Clone, Default, Debug)] -pub enum NodeConfigTypes { +pub enum NodeConfig { ReLU(EltwiseConfig>, Vec), Sigmoid(EltwiseConfig>, Vec), Divide(EltwiseConfig>, Vec), @@ -148,6 +148,7 @@ impl NodeGraph { } } + /// Flattens the inner [BTreeMap] into a [Vec] of [Node]s. pub fn flatten(&self) -> Vec { let a = self .0 @@ -167,6 +168,7 @@ impl NodeGraph { c } + /// Retrieves a node, as specified by idx, from the Graph of bucketed nodes. pub fn filter(&self, idx: usize) -> Node { let a = self.flatten(); let c = &a @@ -178,12 +180,9 @@ impl NodeGraph { } } -/// A circuit configuration for a single self. -#[derive(Clone, Default, Debug)] -pub struct NodeConfig { - pub config: NodeConfigTypes, - pub onnx_idx: Vec, -} +// /// A circuit configuration for a single self. +// #[derive(Clone, Default, Debug)] +// pub struct NodeConfig(pub NodeConfigTypes); fn display_option(o: &Option) -> String { match o { @@ -234,28 +233,45 @@ fn display_tensorf32(o: &Option>) -> String { /// * `bucket` - The execution bucket this node has been assigned to. #[derive(Clone, Debug, Default, Tabled)] pub struct Node { + /// [OpKind] enum, i.e what operation this node represents. pub opkind: OpKind, + /// The inferred maximum value that can appear in the output tensor given previous quantization choices. pub output_max: f32, + /// The denominator in the fixed point representation for the node's input. Tensors of differing scales should not be combined. pub in_scale: i32, + /// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined. pub out_scale: i32, #[tabled(display_with = "display_tensor")] - pub const_value: Option>, // float value * 2^qscale if applicable. + /// The quantized constants potentially associated with this self. + pub const_value: Option>, #[tabled(display_with = "display_tensorf32")] + /// The un-quantized constants potentially associated with this self. pub raw_const_value: Option>, // Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias), // but in_dim is [in], out_dim is [out] #[tabled(display_with = "display_inputs")] + /// The indices of the node's inputs. pub inputs: Vec, #[tabled(display_with = "display_vector")] + /// Dimensions of input. pub in_dims: Vec>, #[tabled(display_with = "display_vector")] + /// Dimensions of output. pub out_dims: Vec, + /// The node's unique identifier. pub idx: usize, #[tabled(display_with = "display_option")] + /// The execution bucket this node has been assigned to. pub bucket: Option, } impl Node { + /// Converts a tract [OnnxNode] into an ezkl [Node]. + /// # Arguments: + /// * `node` - [OnnxNode] + /// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph. + /// * `scale` - The denominator in the fixed point representation. Tensors of differing scales should not be combined. + /// * `idx` - The node's unique identifier. pub fn new( mut node: OnnxNode>, other_nodes: &mut BTreeMap, @@ -955,8 +971,8 @@ impl Node { mn } - /// Ensures all inputs to a node have the same floating point denominator. - pub fn homogenize_input_scales(opkind: OpKind, inputs: Vec) -> OpKind { + /// Ensures all inputs to a node have the same fixed point denominator. + fn homogenize_input_scales(opkind: OpKind, inputs: Vec) -> OpKind { let mut multipliers = vec![1; inputs.len()]; let out_scales = inputs.windows(1).map(|w| w[0].out_scale).collect_vec(); if !out_scales.windows(2).all(|w| w[0] == w[1]) { @@ -990,7 +1006,7 @@ impl Node { } } - pub fn quantize_const_to_scale(&mut self, scale: i32) { + fn quantize_const_to_scale(&mut self, scale: i32) { assert!(matches!(self.opkind, OpKind::Const)); let raw = self.raw_const_value.as_ref().unwrap(); self.out_scale = scale; @@ -1000,7 +1016,7 @@ impl Node { } /// Re-quantizes a constant value node to a new scale. - pub fn scale_up_const_node(node: &mut Node, scale_diff: i32) -> &mut Node { + fn scale_up_const_node(node: &mut Node, scale_diff: i32) -> &mut Node { assert!(matches!(node.opkind, OpKind::Const)); if scale_diff > 0 { if let Some(val) = &node.const_value { diff --git a/src/lib.rs b/src/lib.rs index e0a242fe0..7b83cc258 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,11 @@ #![deny(missing_docs, warnings, unsafe_code)] #![feature(slice_flatten)] +//! A library for turning computational graphs, such as neural networks, into ZK-circuits. +//! /// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit. pub mod circuit; -/// Commands +/// CLI commands. pub mod commands; /// Utilities for converting from Halo2 Field types to integers (and vice-versa). pub mod fieldutils; From 798b9520418cedf1c78f2b5b0abb7cce3205379e Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Tue, 27 Dec 2022 11:57:36 +0000 Subject: [PATCH 11/13] cleanup print --- src/graph/model.rs | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/src/graph/model.rs b/src/graph/model.rs index 6158fceb1..dd303b6d7 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -42,7 +42,7 @@ pub enum Mode { /// A circuit configuration for a model loaded from an Onnx file. #[derive(Clone)] pub struct ModelConfig { - configs: BTreeMap, Vec)>, + configs: BTreeMap>, /// The model struct pub model: Model, /// (optional) range checked outputs of the model graph @@ -198,7 +198,7 @@ impl Model { if !non_fused_ops.is_empty() { for (i, n) in non_fused_ops.iter() { let config = self.configure_table(n, meta, vars, &mut tables); - results.insert(**i, (config, vec![**i])); + results.insert(**i, config); } } @@ -210,10 +210,14 @@ impl Model { // preserves ordering if !fused_ops.is_empty() { let config = self.fuse_ops(&fused_ops, meta, vars); - results.insert( - **fused_ops.keys().max().unwrap(), - (config, fused_ops.keys().map(|k| **k).sorted().collect_vec()), - ); + results.insert(**fused_ops.keys().max().unwrap(), config); + + let mut display: String = "Fused nodes: ".to_string(); + for idx in fused_ops.keys().map(|k| **k).sorted() { + let node = &self.nodes.filter(idx); + display.push_str(&format!("| {} ({:?}) | ", idx, node.opkind)); + } + info!("{}", display); } } @@ -469,25 +473,7 @@ impl Model { results.insert(i.0, i.1.clone()); } } - for (idx, (config, node_idx)) in config.configs.iter() { - let mut display: String = "".to_string(); - for (i, idx) in node_idx.iter().enumerate() { - let node = &self.nodes.filter(*idx); - if i > 0 { - display.push_str(&format!( - "| combined with node {} ({:?}) ", - idx, node.opkind - )); - } else { - display.push_str(&format!( - "------ laying out node {} ({:?}) ", - idx, node.opkind - )); - } - } - - info!("{}", display); - + for (idx, config) in config.configs.iter() { if let Some(vt) = self.layout_config(layouter, &mut results, config)? { // we get the max as for fused nodes this corresponds to the node output results.insert(*idx, vt); From 96491c99238918ee46bea49c396ed391051f0639 Mon Sep 17 00:00:00 2001 From: jason Date: Tue, 27 Dec 2022 15:58:27 -0500 Subject: [PATCH 12/13] fix paren in doctest --- src/tensor/ops.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 286b3823d..29b44e9cb 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -333,7 +333,8 @@ pub fn mult>(t: &Vec>) -> Tensor { /// let y = Tensor::::new( /// Some(&[2, 1, 2, 1, 1, 1]), /// &[2, 3], -/// let result = div(&x, y); +/// ).unwrap(); +/// let result = div(x, y); /// let expected = Tensor::::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` From 5f58e99fd62048cfbca97df9145cf2c575bc16f3 Mon Sep 17 00:00:00 2001 From: jason Date: Tue, 27 Dec 2022 17:15:49 -0500 Subject: [PATCH 13/13] docs for evm aggregator --- src/pfsys/aggregation.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/pfsys/aggregation.rs b/src/pfsys/aggregation.rs index b57ad6ee2..435428c70 100644 --- a/src/pfsys/aggregation.rs +++ b/src/pfsys/aggregation.rs @@ -60,6 +60,7 @@ const LIMBS: usize = 4; const BITS: usize = 68; type Pcs = Kzg; type As = KzgAs; +/// Type for aggregator verification pub type Plonk = verifier::Plonk>; const T: usize = 5; @@ -69,10 +70,13 @@ const R_P: usize = 60; type Svk = KzgSuccinctVerifyingKey; type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; +/// The loader type used in the transcript definition type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; +/// Application snark transcript pub type PoseidonTranscript = system::halo2::transcript::halo2::PoseidonTranscript; +/// An application snark with proof and instance variables ready for aggregation (raw field element) #[derive(Debug)] pub struct Snark { protocol: Protocol, @@ -81,6 +85,7 @@ pub struct Snark { } impl Snark { + /// Create a new application snark from proof and instance variables ready for aggregation pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { Self { protocol, @@ -104,6 +109,7 @@ impl From for SnarkWitness { } } +/// An application snark with proof and instance variables ready for aggregation (wrapped field element) #[derive(Clone)] pub struct SnarkWitness { protocol: Protocol, @@ -129,6 +135,7 @@ impl SnarkWitness { } } +/// Aggregate one or more application snarks of the same shape into a KzgAccumulator pub fn aggregate<'a>( svk: &Svk, loader: &Rc>, @@ -168,6 +175,7 @@ pub fn aggregate<'a>( accumulator } +/// The Halo2 Config for the aggregation circuit #[derive(Clone)] pub struct AggregationConfig { main_gate_config: MainGateConfig, @@ -175,6 +183,7 @@ pub struct AggregationConfig { } impl AggregationConfig { + /// Configure the aggregation circuit pub fn configure( meta: &mut ConstraintSystem, composition_bits: Vec, @@ -189,14 +198,17 @@ impl AggregationConfig { } } + /// Create a MainGate from the aggregation approach pub fn main_gate(&self) -> MainGate { MainGate::new(self.main_gate_config.clone()) } + /// Create a range chip to decompose and range check inputs pub fn range_chip(&self) -> RangeChip { RangeChip::new(self.range_config.clone()) } + /// Create an ecc chip for ec ops pub fn ecc_chip(&self) -> BaseFieldEccChip { BaseFieldEccChip::new(EccConfig::new( self.range_config.clone(), @@ -205,6 +217,7 @@ impl AggregationConfig { } } +/// Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof. #[derive(Clone)] pub struct AggregationCircuit { svk: Svk, @@ -214,6 +227,7 @@ pub struct AggregationCircuit { } impl AggregationCircuit { + /// Create a new Aggregation Circuit with a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), and the instance variables and the resulting aggregation circuit proof. pub fn new(params: &ParamsKZG, snarks: impl IntoIterator) -> Self { let svk = params.get_g()[0].into(); let snarks = snarks.into_iter().collect_vec(); @@ -253,19 +267,22 @@ impl AggregationCircuit { } } + /// Accumulator indices used in generating verifier. pub fn accumulator_indices() -> Vec<(usize, usize)> { (0..4 * LIMBS).map(|idx| (0, idx)).collect() } + /// Number of instance variables for the aggregation circuit, used in generating verifier. pub fn num_instance() -> Vec { vec![4 * LIMBS] } + /// Instance variables for the aggregation circuit, fed to verifier. pub fn instances(&self) -> Vec> { vec![self.instances.clone()] } - pub fn as_proof(&self) -> Value<&[u8]> { + fn as_proof(&self) -> Value<&[u8]> { self.as_proof.as_ref().map(Vec::as_slice) } } @@ -336,6 +353,7 @@ impl Circuit for AggregationCircuit { } } +/// Create proof and instance variables for the application snark pub fn gen_application_snark(params: &ParamsKZG, data: &ModelInput) -> Snark { let (circuit, public_inputs) = prepare_circuit_and_public_input::(data); @@ -362,6 +380,7 @@ pub fn gen_application_snark(params: &ParamsKZG, data: &ModelInput) -> Sn Snark::new(protocol, pi_inner, proof) } +/// Create aggregation EVM verifier bytecode pub fn gen_aggregation_evm_verifier( params: &ParamsKZG, vk: &VerifyingKey, @@ -389,6 +408,7 @@ pub fn gen_aggregation_evm_verifier( evm::compile_yul(&loader.yul_code()) } +/// Verify by executing bytecode with instance variables and proof as input pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { let calldata = encode_calldata(&instances, &proof); let success = { @@ -412,10 +432,12 @@ pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec< assert!(success); } +/// Generate a structured reference string for testing. Not secure, do not use in production. pub fn gen_srs(k: u32) -> ParamsKZG { ParamsKZG::::setup(k, OsRng) } +/// Generate the proving key pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { let vk = keygen_vk(params, circuit).unwrap(); keygen_pk(params, vk, circuit).unwrap()