From 4d87b66136efca53ad28560ffa643ff0dba23f99 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Tue, 27 Dec 2022 19:50:04 -0500 Subject: [PATCH] chore: the great documentation push of 2022 (#84) Co-authored-by: jason --- src/circuit/eltwise.rs | 23 ++++++++ src/circuit/fused.rs | 2 + src/circuit/range.rs | 1 + src/commands.rs | 6 +++ src/fieldutils.rs | 3 ++ src/graph/mod.rs | 20 ++++--- src/graph/model.rs | 113 ++++++++++++++++++++------------------- src/graph/node.rs | 54 ++++++++++++++----- src/graph/utilities.rs | 9 ++++ src/graph/vars.rs | 24 ++++++++- src/lib.rs | 6 ++- src/pfsys/aggregation.rs | 24 ++++++++- src/pfsys/mod.rs | 33 ++++++++---- src/tensor/mod.rs | 3 +- src/tensor/ops.rs | 44 +++++++++++++++ src/tensor/val.rs | 22 ++++++-- src/tensor/var.rs | 41 ++++++++++++-- 17 files changed, 333 insertions(+), 95 deletions(-) diff --git a/src/circuit/eltwise.rs b/src/circuit/eltwise.rs index 52b954b33..ac5a110bf 100644 --- a/src/circuit/eltwise.rs +++ b/src/circuit/eltwise.rs @@ -10,7 +10,13 @@ use halo2_proofs::{ use log::error; use std::{cell::RefCell, marker::PhantomData, rc::Rc}; +/// Defines a non-linear layer. pub trait Nonlinearity { + /// Function that defines the non-linearity. + /// # Arguments + /// + /// * `x` - input to function + /// * `scales` - additional parameters that may parametrize the function fn nonlinearity(x: i32, scales: &[usize]) -> F; /// a value which is always in the table fn default_pair(scales: &[usize]) -> (F, F) { } } +/// A 1D non-linearity. #[derive(Clone, Debug)] pub struct Nonlin1d> { + /// Input to the layer as a [ValTensor]. pub input: ValTensor, + /// Output of the layer as a [ValTensor]. pub output: ValTensor, + #[allow(missing_docs)] pub _marker: PhantomData<(F, NL)>, } @@ -29,15 +39,21 @@ pub struct Nonlin1d> { // Table that should be reused across all lookups (so no Clone) #[derive(Clone, Debug)] pub struct EltwiseTable> { + /// Input to table. pub table_input: TableColumn, + /// Output of table. pub table_output: TableColumn, + /// Flags if table has been previously assigned to. pub is_assigned: bool, + /// Scaling parameters for the non-linearity. pub scaling_params: Vec, + /// Number of bits used in lookup table. pub bits: usize, _marker: PhantomData<(F, NL)>, } impl> EltwiseTable { + /// Configures the table. pub fn configure( cs: &mut ConstraintSystem, bits: usize, @@ -52,6 +68,7 @@ impl> EltwiseTable { _marker: PhantomData, } } + /// Assigns values to the constraints generated when calling `configure`. pub fn layout(&mut self, layouter: &mut impl Layouter) { assert!(!self.is_assigned); let base = 2i32; @@ -100,8 +117,11 @@ impl> EltwiseTable { /// Configuration for element-wise non-linearities. #[derive(Clone, Debug)] pub struct EltwiseConfig> { + /// [VarTensor] input to non-linearity. pub input: VarTensor, + /// [VarTensor] output of non-linearity.
pub output: VarTensor, + /// Lookup table used to represent the non-linearity pub table: Rc>>, qlookup: Selector, _marker: PhantomData<(NL, F)>, } @@ -260,6 +280,7 @@ impl> EltwiseConfig { @@ -277,6 +298,7 @@ impl Nonlinearity for ReLu { } } +#[allow(missing_docs)] #[derive(Clone, Debug)] pub struct Sigmoid { _marker: PhantomData, } @@ -292,6 +314,7 @@ impl Nonlinearity for Sigmoid { } } +#[allow(missing_docs)] #[derive(Clone, Debug)] pub struct DivideBy { _marker: PhantomData, diff --git a/src/circuit/fused.rs b/src/circuit/fused.rs index d66c17381..02aaab1be 100644 --- a/src/circuit/fused.rs +++ b/src/circuit/fused.rs @@ -12,6 +12,7 @@ use log::error; use std::fmt; use std::marker::PhantomData; +#[allow(missing_docs)] /// An enum representing the operations that can be merged into a single circuit gate. #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum FusedOp { @@ -100,6 +101,7 @@ pub struct FusedConfig { nodes: Vec, /// the (currently singular) output of the fused operations. pub output: VarTensor, + /// [Selector] generated when configuring the layer. pub selector: Selector, _marker: PhantomData, } diff --git a/src/circuit/range.rs b/src/circuit/range.rs index 225d1b356..342a4838b 100644 --- a/src/circuit/range.rs +++ b/src/circuit/range.rs @@ -13,6 +13,7 @@ use std::marker::PhantomData; #[derive(Debug, Clone)] pub struct RangeCheckConfig { input: VarTensor, + /// The value we are expecting the output of the circuit to match (within a range) pub expected: VarTensor, selector: Selector, _marker: PhantomData, diff --git a/src/commands.rs b/src/commands.rs index 803dda97b..fc7ef2f65 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -4,10 +4,12 @@ use log::info; use std::io::{stdin, stdout, Write}; use std::path::PathBuf; +#[allow(missing_docs)] #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct Cli { #[command(subcommand)] + #[allow(missing_docs)] pub command: Commands, /// The tolerance for error on model outputs #[arg(short = 'T', long, default_value = "0")] pub tolerance: usize, @@ -35,6 +37,7 @@ pub struct Cli { pub max_rotations: usize, } +#[allow(missing_docs)] #[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq)] pub enum ProofSystem { IPA, @@ -49,6 +52,7 @@ impl std::fmt::Display for ProofSystem { } } +#[allow(missing_docs)] #[derive(Debug, Subcommand)] pub enum Commands { /// Loads model and prints model table @@ -121,6 +125,7 @@ pub enum Commands { default_missing_value = "always", value_enum )] + /// The [ProofSystem] we'll be using. pfsys: ProofSystem, // todo, optionally allow supplying proving key }, @@ -156,6 +161,7 @@ pub enum Commands { }, } +/// Resolves the path `data`, represented as a [String], into a [PathBuf]. If empty, queries the user for an input. pub fn data_path(data: String) -> PathBuf { let mut s = String::new(); match data.is_empty() { diff --git a/src/fieldutils.rs b/src/fieldutils.rs index 171152933..5f09beb3e 100644 --- a/src/fieldutils.rs +++ b/src/fieldutils.rs @@ -1,5 +1,7 @@ /// Utilities for converting from Halo2 Field types to integers (and vice-versa). use halo2_proofs::arithmetic::FieldExt; + +/// Converts an i32 to a Field element. pub fn i32_to_felt(x: i32) -> F { if x >= 0 { F::from(x as u64) } @@ -12,6 +14,7 @@ fn felt_to_u32(x: F) -> u32 { x.get_lower_32() } +/// Converts a Field element to an i32.
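+/// A round-trip sketch; this doctest assumes `halo2curves` is available, as in the tensor op examples elsewhere in this patch: +/// ``` +/// use halo2curves::pasta::Fp as F; +/// use ezkl::fieldutils::{felt_to_i32, i32_to_felt}; +/// // a negative i32 maps to a large field element and back again +/// let felt: F = i32_to_felt(-15); +/// assert_eq!(felt_to_i32(felt), -15); +/// ```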
pub fn felt_to_i32(x: F) -> i32 { if x > F::from(65536) { -(felt_to_u32(-x) as i32) } else { felt_to_u32(x) as i32 } } diff --git a/src/graph/mod.rs b/src/graph/mod.rs index cf9de7dcb..0c745fe77 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -1,3 +1,13 @@ +/// Helper functions pub mod utilities; pub use utilities::*; +/// Module for defining a computational graph and building a ZK-circuit from it. pub mod model; +/// Inner elements of a computational graph that represent a single operation / constraints. pub mod node; +/// Representations of a computational graph's variables. pub mod vars; + use crate::tensor::TensorType; use crate::tensor::{Tensor, ValTensor}; use anyhow::Result; @@ -6,21 +16,19 @@ use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, plonk::{Circuit, ConstraintSystem, Error}, }; -use std::marker::PhantomData; -pub mod utilities; -pub use utilities::*; -pub mod model; -pub mod node; -pub mod vars; use log::{info, trace}; pub use model::*; pub use node::*; use std::cmp::max; +use std::marker::PhantomData; pub use vars::*; +/// Defines the circuit for a computational graph / model loaded from a `.onnx` file. #[derive(Clone, Debug)] pub struct ModelCircuit { + /// Vector of input tensors to the model / graph of computations. pub inputs: Vec>, + /// Represents the Field we are using. pub _marker: PhantomData, } diff --git a/src/graph/model.rs b/src/graph/model.rs index 8f3dd63d3..dd303b6d7 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -27,32 +27,49 @@ use tract_onnx::tract_hir::internal::InferenceOp; /// Mode we're using the model in. #[derive(Clone, Debug)] pub enum Mode { + /// Initialize the model and display the operations table / graph Table, + /// Initialize the model and generate a mock proof Mock, + /// Initialize the model and generate a proof Prove, + /// Initialize the model, generate a proof, and verify FullProve, + /// Initialize the model and verify an already generated proof Verify, } -/// A circuit configuration for the entirety of a model loaded from an Onnx file. +/// A circuit configuration for a model loaded from an Onnx file. #[derive(Clone)] pub struct ModelConfig { configs: BTreeMap>, + /// The model struct pub model: Model, + /// (Optional) range-checked outputs of the model graph pub public_outputs: Vec>, + /// A wrapper for holding all columns that will be assigned to by the model pub vars: ModelVars, } /// A struct for loading from an Onnx file and converting a computational graph to a circuit. #[derive(Clone, Debug)] pub struct Model { - pub model: Graph>, // The raw Tract data structure + /// The raw tract [Graph] data structure. + pub model: Graph>, + /// Graph of nodes we are loading from Onnx. pub nodes: NodeGraph, // Wrapped nodes with additional methods and data (e.g. inferred shape, quantization) + /// Bits used in lookup tables. pub bits: usize, + /// Log rows available in circuit. pub logrows: u32, + /// Exponent used in the fixed point representation. pub scale: i32, + /// The divergence from the expected output (if using public outputs) we can tolerate. This is in absolute value across each dimension. + /// e.g. for a tolerance of 1 and a 2D output, we can tolerate errors of at most 1 in each of the 2 outputs. pub tolerance: usize, + /// The [Mode] we're using the model in. pub mode: Mode, + /// Defines which inputs to the model are public and private (params, inputs, outputs) using [VarVisibility].
pub visibility: VarVisibility, } @@ -181,13 +198,7 @@ impl Model { if !non_fused_ops.is_empty() { for (i, n) in non_fused_ops.iter() { let config = self.configure_table(n, meta, vars, &mut tables); - results.insert( - **i, - NodeConfig { - config, - onnx_idx: vec![**i], - }, - ); + results.insert(**i, config); } } @@ -199,13 +210,14 @@ impl Model { // preserves ordering if !fused_ops.is_empty() { let config = self.fuse_ops(&fused_ops, meta, vars); - results.insert( - **fused_ops.keys().max().unwrap(), - NodeConfig { - config, - onnx_idx: fused_ops.keys().map(|k| **k).sorted().collect_vec(), - }, - ); + results.insert(**fused_ops.keys().max().unwrap(), config); + + let mut display: String = "Fused nodes: ".to_string(); + for idx in fused_ops.keys().map(|k| **k).sorted() { + let node = &self.nodes.filter(idx); + display.push_str(&format!("| {} ({:?}) | ", idx, node.opkind)); + } + info!("{}", display); } } @@ -263,7 +275,7 @@ impl Model { nodes: &BTreeMap<&usize, &Node>, meta: &mut ConstraintSystem, vars: &mut ModelVars, - ) -> NodeConfigTypes { + ) -> NodeConfig { let input_nodes: BTreeMap<(&usize, &FusedOp), Vec> = nodes .iter() .map(|(i, e)| { @@ -340,7 +352,7 @@ impl Model { let inputs = inputs_to_layer.iter(); - NodeConfigTypes::Fused( + NodeConfig::Fused( FusedConfig::configure( meta, &inputs.clone().map(|x| x.1.clone()).collect_vec(), @@ -364,7 +376,7 @@ impl Model { meta: &mut ConstraintSystem, vars: &mut ModelVars, tables: &mut BTreeMap>, - ) -> NodeConfigTypes { + ) -> NodeConfig { let input_len = node.in_dims[0].iter().product(); let input = &vars.advices[0].reshape(&[input_len]); let output = &vars.advices[1].reshape(&[input_len]); @@ -376,7 +388,7 @@ impl Model { let table = tables.get(&node.opkind).unwrap(); let conf: EltwiseConfig> = EltwiseConfig::configure_with_table(meta, input, output, table.get_div()); - NodeConfigTypes::Divide(conf, node_inputs) + NodeConfig::Divide(conf, node_inputs) } else { let conf: EltwiseConfig> = EltwiseConfig::configure(meta, input, output, Some(&[self.bits, *s])); @@ -384,7 +396,7 @@ impl Model { node.opkind.clone(), TableTypes::DivideBy(conf.table.clone()), ); - NodeConfigTypes::Divide(conf, node_inputs) + NodeConfig::Divide(conf, node_inputs) } } OpKind::ReLU(s) => { @@ -392,12 +404,12 @@ impl Model { let table = tables.get(&node.opkind).unwrap(); let conf: EltwiseConfig> = EltwiseConfig::configure_with_table(meta, input, output, table.get_relu()); - NodeConfigTypes::ReLU(conf, node_inputs) + NodeConfig::ReLU(conf, node_inputs) } else { let conf: EltwiseConfig> = EltwiseConfig::configure(meta, input, output, Some(&[self.bits, *s])); tables.insert(node.opkind.clone(), TableTypes::ReLu(conf.table.clone())); - NodeConfigTypes::ReLU(conf, node_inputs) + NodeConfig::ReLU(conf, node_inputs) } } OpKind::Sigmoid(s) => { @@ -405,7 +417,7 @@ impl Model { let table = tables.get(&node.opkind).unwrap(); let conf: EltwiseConfig> = EltwiseConfig::configure_with_table(meta, input, output, table.get_sig()); - NodeConfigTypes::Sigmoid(conf, node_inputs) + NodeConfig::Sigmoid(conf, node_inputs) } else { let conf: EltwiseConfig> = EltwiseConfig::configure( meta, @@ -414,18 +426,18 @@ impl Model { Some(&[self.bits, *s, scale_to_multiplier(self.scale) as usize]), ); tables.insert(node.opkind.clone(), TableTypes::Sigmoid(conf.table.clone())); - NodeConfigTypes::Sigmoid(conf, node_inputs) + NodeConfig::Sigmoid(conf, node_inputs) } } OpKind::Const => { // Typically parameters for one or more layers. 
// Currently this is handled in the consuming node(s), but will be moved here. - NodeConfigTypes::Const + NodeConfig::Const } OpKind::Input => { // This is the input to the model (e.g. the image). // Currently this is handled in the consuming node(s), but will be moved here. - NodeConfigTypes::Input + NodeConfig::Input } OpKind::Fused(s) => { error!("For {:?} call fuse_fused_ops instead", s); @@ -461,26 +473,8 @@ impl Model { results.insert(i.0, i.1.clone()); } } - for (idx, c) in config.configs.iter() { - let mut display: String = "".to_string(); - for (i, idx) in c.onnx_idx[0..].iter().enumerate() { - let node = &self.nodes.filter(*idx); - if i > 0 { - display.push_str(&format!( - "| combined with node {} ({:?}) ", - idx, node.opkind - )); - } else { - display.push_str(&format!( - "------ laying out node {} ({:?}) ", - idx, node.opkind - )); - } - } - - info!("{}", display); - - if let Some(vt) = self.layout_config(layouter, &mut results, c)? { + for (idx, config) in config.configs.iter() { + if let Some(vt) = self.layout_config(layouter, &mut results, config)? { // we get the max as for fused nodes this corresponds to the node output results.insert(*idx, vt); //only use with mock prover @@ -532,8 +526,8 @@ impl Model { config: &NodeConfig, ) -> Result>> { // The node kind and the config should be the same. - let res = match config.config.clone() { - NodeConfigTypes::Fused(mut ac, idx) => { + let res = match config.clone() { + NodeConfig::Fused(mut ac, idx) => { let values: Vec> = idx .iter() .map(|i| { @@ -554,21 +548,21 @@ impl Model { Some(ac.layout(layouter, &values)) } - NodeConfigTypes::ReLU(rc, idx) => { + NodeConfig::ReLU(rc, idx) => { assert_eq!(idx.len(), 1); // For activations and elementwise operations, the dimensions are sometimes only in one or the other of input and output. 
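+ // Fetch the node's single input from the intermediate results map and pass it through the lookup-backed config.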
Some(rc.layout(layouter, inputs.get(&idx[0]).unwrap().clone())) } - NodeConfigTypes::Sigmoid(sc, idx) => { + NodeConfig::Sigmoid(sc, idx) => { assert_eq!(idx.len(), 1); Some(sc.layout(layouter, inputs.get(&idx[0]).unwrap().clone())) } - NodeConfigTypes::Divide(dc, idx) => { + NodeConfig::Divide(dc, idx) => { assert_eq!(idx.len(), 1); Some(dc.layout(layouter, inputs.get(&idx[0]).unwrap().clone())) } - NodeConfigTypes::Input => None, - NodeConfigTypes::Const => None, + NodeConfig::Input => None, + NodeConfig::Const => None, c => { panic!("Not a configurable op {:?}", c) } @@ -626,19 +620,23 @@ impl Model { self.model.nodes().to_vec() } + /// Returns the IDs of the computational graph's inputs pub fn input_outlets(&self) -> Result> { Ok(self.model.input_outlets()?.to_vec()) } + /// Returns the IDs of the computational graph's outputs pub fn output_outlets(&self) -> Result> { Ok(self.model.output_outlets()?.to_vec()) } + /// Returns the number of the computational graph's inputs pub fn num_inputs(&self) -> usize { let input_nodes = self.model.inputs.iter(); input_nodes.len() } + /// Returns shapes of the computational graph's inputs pub fn input_shapes(&self) -> Vec> { self.model .inputs .iter() .collect_vec() } + /// Returns the number of the computational graph's outputs pub fn num_outputs(&self) -> usize { let output_nodes = self.model.outputs.iter(); output_nodes.len() } + /// Returns shapes of the computational graph's outputs pub fn output_shapes(&self) -> Vec> { self.model .outputs .collect_vec() } + /// Returns the fixed point scale of the computational graph's outputs pub fn get_output_scales(&self) -> Vec { let output_nodes = self.model.outputs.iter(); output_nodes .collect_vec() } + /// Max number of inlets or outlets to a node pub fn max_node_size(&self) -> usize { max( self.nodes ) } + /// Max number of parameters (i.e. trainable weights) across the computational graph pub fn max_node_params(&self) -> usize { let mut maximum_number_inputs = 0; for (_, bucket_nodes) in self.nodes.0.iter() { @@ -713,6 +716,7 @@ impl Model { maximum_number_inputs + 1 } + /// Maximum number of input variables in fused layers pub fn max_node_vars_fused(&self) -> usize { let mut maximum_number_inputs = 0; for (_, bucket_nodes) in self.nodes.0.iter() { @@ -736,6 +740,7 @@ impl Model { maximum_number_inputs + 1 } + /// Maximum number of input variables in non-fused layers pub fn max_node_vars_non_fused(&self) -> usize { let mut maximum_number_inputs = 0; for (_, bucket_nodes) in self.nodes.0.iter() { diff --git a/src/graph/node.rs b/src/graph/node.rs index bfa9f2ee3..d8b1f3c6e 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -35,18 +35,27 @@ use tract_onnx::tract_hir::{ /// Enum of the different kinds of operations `ezkl` can support.
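+/// A minimal sketch of the name-to-kind mapping (the `"Clip"` arm appears in [OpKind::new] below): +/// ``` +/// use ezkl::graph::node::OpKind; +/// // onnx's Clip is treated as a quantized ReLU +/// assert_eq!(OpKind::new("Clip"), OpKind::ReLU(1)); +/// ```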
#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd)] pub enum OpKind { + /// A ReLU nonlinearity ReLU(usize), + /// A Sigmoid nonlinearity Sigmoid(usize), + /// A DivideBy nonlinearity Div(usize), + /// A Constant tensor, such as a parameter or hyperparameter Const, + /// An input to the model Input, + /// A fused op, combining affine layers or other arithmetic Fused(FusedOp), + /// Unable to parse the node type Unknown(String), + #[allow(missing_docs)] #[default] None, } impl OpKind { + /// Produce an OpKind from a `&str` onnx name pub fn new(name: &str) -> Self { match name { "Clip" => OpKind::ReLU(1), @@ -77,10 +86,12 @@ impl OpKind { } } } + /// Identify fused OpKind pub fn is_fused(&self) -> bool { matches!(self, OpKind::Fused(_)) } + /// Identify constant OpKind pub fn is_const(&self) -> bool { matches!(self, OpKind::Const) } @@ -102,8 +113,9 @@ impl fmt::Display for OpKind { } /// Enum of the different kinds of node configurations `ezkl` can support. +#[allow(missing_docs)] #[derive(Clone, Default, Debug)] -pub enum NodeConfigTypes { +pub enum NodeConfig { ReLU(EltwiseConfig>, Vec), Sigmoid(EltwiseConfig>, Vec), Divide(EltwiseConfig>, Vec), @@ -119,10 +131,12 @@ pub enum NodeConfigTypes { pub struct NodeGraph(pub BTreeMap, BTreeMap>); impl NodeGraph { + /// Create an empty NodeGraph pub fn new() -> Self { NodeGraph(BTreeMap::new()) } + /// Insert the node with given tract `node_idx` and config at `idx` pub fn insert(&mut self, idx: Option, node_idx: usize, config: Node) { match self.0.entry(idx) { Entry::Vacant(e) => { @@ -134,6 +148,7 @@ impl NodeGraph { } } + /// Flattens the inner [BTreeMap] into a [Vec] of [Node]s. pub fn flatten(&self) -> Vec { let a = self .0 @@ -153,6 +168,7 @@ impl NodeGraph { c } + /// Retrieves a node, as specified by idx, from the Graph of bucketed nodes. pub fn filter(&self, idx: usize) -> Node { let a = self.flatten(); let c = &a @@ -164,12 +180,9 @@ impl NodeGraph { } } -/// A circuit configuration for a single self. -#[derive(Clone, Default, Debug)] -pub struct NodeConfig { - pub config: NodeConfigTypes, - pub onnx_idx: Vec, -} fn display_option(o: &Option) -> String { match o { @@ -220,28 +233,45 @@ fn display_tensorf32(o: &Option>) -> String { /// * `bucket` - The execution bucket this node has been assigned to. #[derive(Clone, Debug, Default, Tabled)] pub struct Node { + /// [OpKind] enum, i.e. what operation this node represents. pub opkind: OpKind, + /// The inferred maximum value that can appear in the output tensor given previous quantization choices. pub output_max: f32, + /// The denominator in the fixed point representation for the node's input. Tensors of differing scales should not be combined. pub in_scale: i32, + /// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined. pub out_scale: i32, #[tabled(display_with = "display_tensor")] - pub const_value: Option>, // float value * 2^qscale if applicable. + /// The quantized constants potentially associated with this node. + pub const_value: Option>, #[tabled(display_with = "display_tensorf32")] + /// The un-quantized constants potentially associated with this node. pub raw_const_value: Option>, // Usually there is a simple in and out shape of the node as an operator.
For example, an Affine node has three input_shapes (one for the input, weight, and bias), // but in_dim is [in], out_dim is [out] #[tabled(display_with = "display_inputs")] + /// The indices of the node's inputs. pub inputs: Vec, #[tabled(display_with = "display_vector")] + /// Dimensions of input. pub in_dims: Vec>, #[tabled(display_with = "display_vector")] + /// Dimensions of output. pub out_dims: Vec, + /// The node's unique identifier. pub idx: usize, #[tabled(display_with = "display_option")] + /// The execution bucket this node has been assigned to. pub bucket: Option, } impl Node { + /// Converts a tract [OnnxNode] into an ezkl [Node]. + /// # Arguments: + /// * `node` - [OnnxNode] + /// * `other_nodes` - [BTreeMap] of other previously initialized [Node]s in the computational graph. + /// * `scale` - The denominator in the fixed point representation. Tensors of differing scales should not be combined. + /// * `idx` - The node's unique identifier. pub fn new( mut node: OnnxNode>, other_nodes: &mut BTreeMap, scale: i32, idx: usize, ) -> Self { @@ -941,8 +971,8 @@ impl Node { mn } - /// Ensures all inputs to a node have the same floating point denominator. - pub fn homogenize_input_scales(opkind: OpKind, inputs: Vec) -> OpKind { + /// Ensures all inputs to a node have the same fixed point denominator. + fn homogenize_input_scales(opkind: OpKind, inputs: Vec) -> OpKind { let mut multipliers = vec![1; inputs.len()]; let out_scales = inputs.windows(1).map(|w| w[0].out_scale).collect_vec(); if !out_scales.windows(2).all(|w| w[0] == w[1]) { @@ -976,7 +1006,7 @@ impl Node { } } - pub fn quantize_const_to_scale(&mut self, scale: i32) { + fn quantize_const_to_scale(&mut self, scale: i32) { assert!(matches!(self.opkind, OpKind::Const)); let raw = self.raw_const_value.as_ref().unwrap(); self.out_scale = scale; @@ -986,7 +1016,7 @@ impl Node { } /// Re-quantizes a constant value node to a new scale. - pub fn scale_up_const_node(node: &mut Node, scale_diff: i32) -> &mut Node { + fn scale_up_const_node(node: &mut Node, scale_diff: i32) -> &mut Node { assert!(matches!(node.opkind, OpKind::Const)); if scale_diff > 0 { if let Some(val) = &node.const_value { diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 01b168e47..75a49e899 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -4,6 +4,13 @@ use tract_onnx::prelude::{InferenceFact, Node}; use tract_onnx::tract_hir::internal::InferenceOp; // Warning: currently ignores stride information +/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation. +/// # Arguments +/// +/// * `vec` - the vector to quantize. +/// * `dims` - the dimensionality of the resulting [Tensor]. +/// * `shift` - offset used in the fixed point representation. +/// * `scale` - `2^scale` used in the fixed point representation. pub fn vector_to_quantized( vec: &[f32], dims: &[usize], @@ -18,10 +25,12 @@ pub fn vector_to_quantized( Tensor::new(Some(&scaled), dims) } +/// Converts a scale (log base 2) to a fixed point multiplier. pub fn scale_to_multiplier(scale: i32) -> f32 { i32::pow(2, scale as u32) as f32 } +/// Gets the shape of an onnx node's outlets.
pub fn node_output_shapes( node: &Node>, ) -> Result>>> { diff --git a/src/graph/vars.rs b/src/graph/vars.rs index 35d0dfb2f..c962d0e5c 100644 --- a/src/graph/vars.rs +++ b/src/graph/vars.rs @@ -10,12 +10,16 @@ use log::error; use serde::Deserialize; use std::{cell::RefCell, rc::Rc}; +/// Label enum to track whether model input, model parameters, and model output are public or private #[derive(Clone, Debug, Deserialize)] pub enum Visibility { + /// Mark an item as private to the prover (not in the proof submitted for verification) Private, + /// Mark an item as public (sent in the proof submitted for verification) Public, } impl Visibility { + #[allow(missing_docs)] pub fn is_public(&self) -> bool { matches!(&self, Visibility::Public) } @@ -29,10 +33,14 @@ impl std::fmt::Display for Visibility { } } +/// Whether the model input, model parameters, and model output are Public or Private to the prover. #[derive(Clone, Debug, Deserialize)] pub struct VarVisibility { + /// Input to the model or computational graph pub input: Visibility, + /// Parameters, such as weights and biases, in the model pub params: Visibility, + /// Output of the model or computational graph pub output: Visibility, } impl std::fmt::Display for VarVisibility { @@ -46,6 +54,8 @@ impl VarVisibility { } impl VarVisibility { + /// Read from CLI args whether the model input, model parameters, and model output are Public or Private to the prover. + /// Place the result in a [VarVisibility] struct. pub fn from_args() -> Self { let args = Cli::parse(); @@ -75,12 +85,17 @@ impl VarVisibility { } } +/// Lookup tables that will be available for reuse. pub enum TableTypes { + /// Reference to a ReLU table ReLu(Rc>>>), + /// Reference to a DivideBy table DivideBy(Rc>>>), + /// Reference to a Sigmoid table Sigmoid(Rc>>>), } impl TableTypes { + /// Get a reference to a reused ReLU lookup table pub fn get_relu(&self) -> Rc>>> { match self { TableTypes::ReLu(inner) => inner.clone(), @@ -89,6 +104,7 @@ impl TableTypes { } } } + /// Get a reference to a reused DivideBy lookup table pub fn get_div(&self) -> Rc>>> { match self { TableTypes::DivideBy(inner) => inner.clone(), @@ -97,6 +113,7 @@ impl TableTypes { } } } + /// Get a reference to a reused Sigmoid lookup table pub fn get_sig(&self) -> Rc>>> { match self { TableTypes::Sigmoid(inner) => inner.clone(), @@ -107,14 +124,19 @@ impl TableTypes { } } +/// A wrapper for holding all columns that will be assigned to by a model. #[derive(Clone)] pub struct ModelVars { + #[allow(missing_docs)] pub advices: Vec, + #[allow(missing_docs)] pub fixed: Vec, + #[allow(missing_docs)] pub instances: Vec>, } -/// A wrapper for holding all columns that will be assigned to by a model. + impl ModelVars { + /// Allocate all columns that will be assigned to by a model. pub fn new( cs: &mut ConstraintSystem, logrows: usize, diff --git a/src/lib.rs b/src/lib.rs index ac6c92634..7b83cc258 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,11 @@ -#![deny(warnings, unsafe_code)] +#![deny(missing_docs, warnings, unsafe_code)] #![feature(slice_flatten)] +//! A library for turning computational graphs, such as neural networks, into ZK-circuits. +//! /// Methods for configuring tensor operations and assigning values to them in a Halo2 circuit. pub mod circuit; -/// Commands +/// CLI commands. pub mod commands; /// Utilities for converting from Halo2 Field types to integers (and vice-versa).
pub mod fieldutils; diff --git a/src/pfsys/aggregation.rs b/src/pfsys/aggregation.rs index b57ad6ee2..435428c70 100644 --- a/src/pfsys/aggregation.rs +++ b/src/pfsys/aggregation.rs @@ -60,6 +60,7 @@ const LIMBS: usize = 4; const BITS: usize = 68; type Pcs = Kzg; type As = KzgAs; +/// Type for aggregator verification pub type Plonk = verifier::Plonk>; const T: usize = 5; @@ -69,10 +70,13 @@ const R_P: usize = 60; type Svk = KzgSuccinctVerifyingKey; type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; +/// The loader type used in the transcript definition type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; +/// Application snark transcript pub type PoseidonTranscript = system::halo2::transcript::halo2::PoseidonTranscript; +/// An application snark with proof and instance variables ready for aggregation (raw field element) #[derive(Debug)] pub struct Snark { protocol: Protocol, } impl Snark { + /// Create a new application snark from proof and instance variables ready for aggregation pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { Self { protocol, @@ -104,6 +109,7 @@ impl From for SnarkWitness { } } +/// An application snark with proof and instance variables ready for aggregation (wrapped field element) #[derive(Clone)] pub struct SnarkWitness { protocol: Protocol, @@ -129,6 +135,7 @@ impl SnarkWitness { } } +/// Aggregate one or more application snarks of the same shape into a KzgAccumulator pub fn aggregate<'a>( svk: &Svk, loader: &Rc>, @@ -168,6 +175,7 @@ pub fn aggregate<'a>( accumulator } +/// The Halo2 Config for the aggregation circuit #[derive(Clone)] pub struct AggregationConfig { main_gate_config: MainGateConfig, } impl AggregationConfig { + /// Configure the aggregation circuit pub fn configure( meta: &mut ConstraintSystem, composition_bits: Vec, @@ -189,14 +198,17 @@ impl AggregationConfig { } } + /// Create a MainGate from the aggregation config pub fn main_gate(&self) -> MainGate { MainGate::new(self.main_gate_config.clone()) } + /// Create a range chip to decompose and range check inputs pub fn range_chip(&self) -> RangeChip { RangeChip::new(self.range_config.clone()) } + /// Create an ecc chip for ec ops pub fn ecc_chip(&self) -> BaseFieldEccChip { BaseFieldEccChip::new(EccConfig::new( self.range_config.clone(), @@ -205,6 +217,7 @@ impl AggregationConfig { } } +/// Aggregation circuit holding a SuccinctVerifyingKey, application snark witnesses (each with a proof and instance variables), the aggregated instance variables, and the resulting proof. #[derive(Clone)] pub struct AggregationCircuit { svk: Svk, @@ -214,6 +227,7 @@ pub struct AggregationCircuit { } impl AggregationCircuit { + /// Create a new aggregation circuit from a SuccinctVerifyingKey and application snark witnesses (each with a proof and instance variables). pub fn new(params: &ParamsKZG, snarks: impl IntoIterator) -> Self { let svk = params.get_g()[0].into(); let snarks = snarks.into_iter().collect_vec(); @@ -253,19 +267,22 @@ impl AggregationCircuit { } } + /// Accumulator indices used in generating verifier. pub fn accumulator_indices() -> Vec<(usize, usize)> { (0..4 * LIMBS).map(|idx| (0, idx)).collect() } + /// Number of instance variables for the aggregation circuit, used in generating verifier.
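+ /// A quick sanity-check sketch; this doctest assumes the module is exported as `ezkl::pfsys::aggregation` and relies on `LIMBS = 4` above: + /// ``` + /// use ezkl::pfsys::aggregation::AggregationCircuit; + /// // 4 limbs for each of the 4 accumulator coordinates + /// assert_eq!(AggregationCircuit::num_instance(), vec![16]); + /// ```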
pub fn num_instance() -> Vec { vec![4 * LIMBS] } + /// Instance variables for the aggregation circuit, fed to verifier. pub fn instances(&self) -> Vec> { vec![self.instances.clone()] } - pub fn as_proof(&self) -> Value<&[u8]> { + fn as_proof(&self) -> Value<&[u8]> { self.as_proof.as_ref().map(Vec::as_slice) } } @@ -336,6 +353,7 @@ impl Circuit for AggregationCircuit { } } +/// Create proof and instance variables for the application snark pub fn gen_application_snark(params: &ParamsKZG, data: &ModelInput) -> Snark { let (circuit, public_inputs) = prepare_circuit_and_public_input::(data); @@ -362,6 +380,7 @@ pub fn gen_application_snark(params: &ParamsKZG, data: &ModelInput) -> Sn Snark::new(protocol, pi_inner, proof) } +/// Create aggregation EVM verifier bytecode pub fn gen_aggregation_evm_verifier( params: &ParamsKZG, vk: &VerifyingKey, @@ -389,6 +408,7 @@ pub fn gen_aggregation_evm_verifier( evm::compile_yul(&loader.yul_code()) } +/// Verify by executing bytecode with instance variables and proof as input pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { let calldata = encode_calldata(&instances, &proof); let success = { @@ -412,10 +432,12 @@ pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec< assert!(success); } +/// Generate a structured reference string for testing. Not secure, do not use in production. pub fn gen_srs(k: u32) -> ParamsKZG { ParamsKZG::::setup(k, OsRng) } +/// Generate the proving key pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { let vk = keygen_vk(params, circuit).unwrap(); keygen_pk(params, vk, circuit).unwrap() } diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index 1c55f5b1d..a25f56527 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -27,21 +27,29 @@ use std::ops::Deref; use std::path::PathBuf; use std::time::Instant; +/// The input tensor data and shape, and output data for the computational graph (model) as floats. +/// For example, the input might be the image data for a neural network, and the output class scores. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelInput { + /// Inputs to the model / computational graph. pub input_data: Vec>, + /// The shape of said inputs. pub input_shapes: Vec>, + /// The expected output of the model (can be empty vectors if outputs are not being constrained). pub output_data: Vec>, } +/// Defines the proof generated by a model / circuit, suitable for serialization/deserialization. #[derive(Debug, Deserialize, Serialize)] pub struct Proof { - pub input_shapes: Vec>, + /// Public inputs to the model. pub public_inputs: Vec>, + /// The generated proof, as a vector of bytes. pub proof: Vec, } impl Proof { + /// Saves the Proof to a specified `proof_path`. pub fn save(&self, proof_path: &PathBuf) { let serialized = match serde_json::to_string(&self) { Ok(s) => s, @@ -54,6 +62,7 @@ impl Proof { file.write_all(serialized.as_bytes()).expect("write failed"); } + /// Load a json serialized proof from the provided path. pub fn load(proof_path: &PathBuf) -> Self { let mut file = match File::open(proof_path) { Ok(f) => f, @@ -114,6 +123,7 @@ pub fn parse_prover_errors(f: &VerifyFailure) { } } +/// Initialize the model circuit and quantize the float inputs from the provided `ModelInput`.
pub fn prepare_circuit_and_public_input( data: &ModelInput, ) -> (ModelCircuit, Vec>) { @@ -168,6 +178,7 @@ pub fn prepare_circuit_and_public_input( (circuit, public_inputs) } +/// Initialize the model circuit pub fn prepare_circuit(data: &ModelInput) -> ModelCircuit { let args = Cli::parse(); @@ -190,6 +201,7 @@ pub fn prepare_circuit(data: &ModelInput) -> ModelCircuit { } } +/// Deserializes the required inputs to a model at path `datapath` to a [ModelInput] struct. pub fn prepare_data(datapath: String) -> ModelInput { let mut file = match File::open(data_path(datapath)) { Ok(t) => t, @@ -209,6 +221,7 @@ pub fn prepare_data(datapath: String) -> ModelInput { data } +/// Creates a [VerifyingKey] and [ProvingKey] for a [ModelCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`). pub fn create_keys( circuit: &ModelCircuit, params: &'_ Scheme::ParamsProver, @@ -278,7 +291,6 @@ where info!("Proof took {}", now.elapsed().as_secs()); let checkable_pf = Proof { - input_shapes: circuit.inputs.iter().map(|i| i.dims().to_vec()).collect(), public_inputs: public_inputs .iter() .map(|i| i.clone().into_iter().collect()) @@ -289,7 +301,7 @@ where (checkable_pf, dims) } -/// a wrapper around halo2's verify_proof +/// A wrapper around halo2's verify_proof pub fn verify_proof_model< 'params, F: FieldExt, @@ -331,6 +343,7 @@ where result } +/// Loads a [VerifyingKey] at `path`. pub fn load_vk( path: PathBuf, params: &'_ Scheme::ParamsVerifier, @@ -349,6 +362,7 @@ where VerifyingKey::::read::<_, ModelCircuit>(&mut reader, params).unwrap() } +/// Loads the [CommitmentScheme::ParamsVerifier] at `path`. pub fn load_params(path: PathBuf) -> Scheme::ParamsVerifier { info!("loading params from {:?}", path); let f = match File::open(path) { @@ -361,20 +375,19 @@ pub fn load_params(path: PathBuf) -> Scheme::ParamsVer Params::<'_, Scheme::Curve>::read(&mut reader).unwrap() } -pub fn save_vk(vk_path: &PathBuf, vk: &VerifyingKey) { +/// Saves a [VerifyingKey] to `path`. +pub fn save_vk(path: &PathBuf, vk: &VerifyingKey) { info!("saving verification key 💾"); - let f = File::create(vk_path).unwrap(); + let f = File::create(path).unwrap(); let mut writer = BufWriter::new(f); vk.write(&mut writer).unwrap(); writer.flush().unwrap(); } -pub fn save_params( - params_path: &PathBuf, - params: &'_ Scheme::ParamsVerifier, -) { +/// Saves [CommitmentScheme] parameters to `path`. +pub fn save_params(path: &PathBuf, params: &'_ Scheme::ParamsVerifier) { info!("saving parameters 💾"); - let f = File::create(params_path).unwrap(); + let f = File::create(path).unwrap(); let mut writer = BufWriter::new(f); params.write(&mut writer).unwrap(); writer.flush().unwrap(); } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 7427c693b..ed69bac78 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -30,7 +30,7 @@ pub trait TensorType: Clone + Debug + 'static { fn zero() -> Option { None } - + /// Max operator for ordering values. fn tmax(&self, _: &Self) -> Option { None } @@ -173,6 +173,7 @@ impl TensorType for halo2curves::bn256::Fr { } } +/// A wrapper for tensor-related errors. #[derive(Debug)] pub struct TensorError(String); diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 55ea2e417..29b44e9cb 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -317,6 +317,27 @@ pub fn mult>(t: &Vec>) -> Tensor { output } +/// Elementwise divide one tensor by another.
+/// # Arguments +/// +/// * `t` - Tensor +/// * `d` - Tensor +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::div; +/// let x = Tensor::::new( +/// Some(&[4, 1, 4, 1, 1, 4]), +/// &[2, 3], +/// ).unwrap(); +/// let y = Tensor::::new( +/// Some(&[2, 1, 2, 1, 1, 1]), +/// &[2, 3], +/// ).unwrap(); +/// let result = div(x, y); +/// let expected = Tensor::::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` pub fn div>(t: Tensor, d: Tensor) -> Tensor { assert_eq!(t.dims(), d.dims()); // calculate value of output output } @@ -529,6 +550,29 @@ pub fn convolution + Add>( output } +/// Applies 2D sum pooling over a 3D tensor of shape C x H x W. +/// # Arguments +/// +/// * `image` - Tensor. +/// * `padding` - Tuple of padding values in x and y directions. +/// * `stride` - Tuple of stride values in x and y directions. +/// * `pool_dims` - Tuple of pooling window size in x and y directions. +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::sumpool; +/// use halo2_proofs::circuit::Value; +/// use halo2_proofs::plonk::Assigned; +/// use halo2curves::pasta::Fp as F; +/// +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[1, 3, 3], +/// ).unwrap(); +/// let pooled = sumpool::(&x, (0, 0), (1, 1), (2, 2)); +/// let expected: Tensor = Tensor::::new(Some(&[11, 8, 8, 10]), &[1, 2, 2]).unwrap(); +/// assert_eq!(pooled, expected); +/// ``` pub fn sumpool + Add>( image: &Tensor, padding: (usize, usize), diff --git a/src/tensor/val.rs b/src/tensor/val.rs index 73f0741b8..fec0db994 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -1,24 +1,36 @@ use super::*; use halo2_proofs::plonk::Instance; -/// A wrapper around a tensor where the inner type is one of Halo2's `Value`, `Value>`, `AssignedCell, F>`. +/// A wrapper around a [Tensor] where the inner type is one of Halo2's [Value], [Value>], [AssignedCell, F>]. /// This enum is generally used to assign values to variables / advices already configured in a Halo2 circuit (usually represented as a [VarTensor]). /// For instance, it can represent pre-trained neural network weights, or a known input to a network. #[derive(Debug, Clone)] pub enum ValTensor { + /// A tensor of [Value], each containing a field element Value { + /// Underlying [Tensor]. inner: Tensor>, + /// Vector of dimensions of the tensor. dims: Vec, }, + /// A tensor of [Value], each containing a ratio of field elements, which may be evaluated to produce plain field elements. AssignedValue { + /// Underlying [Tensor]. inner: Tensor>>, + /// Vector of dimensions of the [Tensor]. dims: Vec, }, + /// A tensor of AssignedCells, each holding both a value and the matrix cell to which it is assigned. PrevAssigned { + /// Underlying [Tensor]. inner: Tensor>, + /// Vector of dimensions of the [Tensor]. dims: Vec, }, + /// A tensor backed by an [Instance] column Instance { + /// [Instance] inner: Column, + /// Vector of dimensions of the tensor. dims: Vec, }, } @@ -51,6 +63,7 @@ impl From>> for ValTensor } impl ValTensor { + /// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`. pub fn new_instance(cs: &mut ConstraintSystem, dims: Vec, equality: bool) -> Self { let col = cs.instance_column(); if equality { @@ -87,7 +100,7 @@ impl ValTensor { } } - /// Sets the `ValTensor`'s shape. + /// Sets the [ValTensor]'s shape.
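+    /// A minimal sketch; this doctest assumes `Value` wraps a [TensorType], as implemented elsewhere in this crate: +    /// ``` +    /// use ezkl::tensor::{Tensor, ValTensor}; +    /// use halo2_proofs::circuit::Value; +    /// use halo2curves::pasta::Fp as F; +    /// let vals = [Value::known(F::from(1)); 6]; +    /// let t = Tensor::new(Some(&vals), &[2, 3]).unwrap(); +    /// let mut vt = ValTensor::from(t); +    /// // reshape preserves the number of elements while changing `dims` +    /// vt.reshape(&[3, 2]); +    /// assert_eq!(vt.dims(), &[3, 2]); +    /// ```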
pub fn reshape(&mut self, new_dims: &[usize]) { match self { ValTensor::Value { inner: v, dims: d } => { @@ -112,7 +125,7 @@ impl ValTensor { } } - /// Calls `flatten` on the inner tensor. + /// Calls `flatten` on the inner [Tensor]. pub fn flatten(&mut self) { match self { ValTensor::Value { inner: v, dims: d } => { @@ -133,7 +146,7 @@ impl ValTensor { } } - /// Returns the `dims` attribute of the `ValTensor`. + /// Returns the `dims` attribute of the [ValTensor]. pub fn dims(&self) -> &[usize] { match self { ValTensor::Value { dims: d, .. } | ValTensor::Instance { dims: d, .. } => d, } } + /// A [String] representation of the [ValTensor] for display, for example in showing intermediate values in a computational graph. pub fn show(&self) -> String { match self.clone() { ValTensor::PrevAssigned { inner: v, dims: _ } => { diff --git a/src/tensor/var.rs b/src/tensor/var.rs index 88db9ac21..6e8f80709 100644 --- a/src/tensor/var.rs +++ b/src/tensor/var.rs @@ -12,31 +12,53 @@ use std::cmp::min; /// using the `assign` method called on a [ValTensor]. #[derive(Clone, Debug)] pub enum VarTensor { + /// A VarTensor for holding Advice values, which are assigned at proving time. Advice { + /// Vec of Advice columns inner: Vec>, + /// Number of rows available to be used in each column of the storage col_size: usize, + /// Total capacity (number of advice cells), usually inner.len()*col_size capacity: usize, + /// Vector of dimensions of the tensor we are representing using this storage. Note that the shape of the storage and this shape can differ. dims: Vec, }, + /// A VarTensor for holding Fixed values, which are assigned at circuit definition time. Fixed { + /// Vec of Fixed columns inner: Vec>, + /// Number of rows available to be used in each column of the storage col_size: usize, + /// Total capacity (number of fixed cells), usually inner.len()*col_size capacity: usize, + /// Vector of dimensions of the tensor we are representing using this storage. Note that the shape of the storage and this shape can differ. dims: Vec, }, } impl VarTensor { + /// Create a new VarTensor::Advice + /// # Arguments + /// + /// * `cs` - `ConstraintSystem` from which the columns will be allocated. + /// * `k` - log2 number of rows in the matrix, including any system and blinding rows. + /// * `capacity` - number of advice cells for this tensor + /// * `dims` - `Vec` of dimensions of the tensor we are representing. Note that the shape of the storage and this shape can differ. + /// * `equality` - true if we want to enable equality constraints for the columns involved. + /// * `max_rot` - maximum number of rotations that we allow for this VarTensor. Rotations affect performance. pub fn new_advice( cs: &mut ConstraintSystem, k: usize, capacity: usize, dims: Vec, equality: bool, - v1: usize, + max_rot: usize, ) -> Self { let base = 2u32; - let max_rows = min(v1, base.pow(k as u32) as usize - cs.blinding_factors() - 1); + let max_rows = min( + max_rot, + base.pow(k as u32) as usize - cs.blinding_factors() - 1, + ); let modulo = (capacity / max_rows) + 1; let mut advices = vec![]; for _ in 0..modulo { @@ -55,16 +77,26 @@ impl VarTensor { } } + /// Create a new VarTensor::Fixed + /// # Arguments + /// + /// * `cs` - `ConstraintSystem` from which the columns will be allocated. + /// * `k` - log2 number of rows in the matrix, including any system and blinding rows. + /// * `capacity` - number of fixed cells for this tensor + /// * `dims` - `Vec` of dimensions of the tensor we are representing.
Note that the shape of the storage and this shape can differ. + /// * `equality` - true if we want to enable equality constraints for the columns involved. + /// * `max_rot` - maximum number of rotations that we allow for this VarTensor. Rotations affect performance. pub fn new_fixed( cs: &mut ConstraintSystem, k: usize, capacity: usize, dims: Vec, equality: bool, - v1: usize, + max_rot: usize, ) -> Self { let base = 2u32; - let max_rows = min(v1, base.pow(k as u32) as usize - cs.blinding_factors() - 1); + let max_rows = min( + max_rot, + base.pow(k as u32) as usize - cs.blinding_factors() - 1, + ); let modulo = (capacity / max_rows) + 1; let mut fixed = vec![]; for _ in 0..modulo { @@ -126,6 +158,7 @@ impl VarTensor { } } + /// Take a linear coordinate and output the (column, row) position in the storage block. pub fn cartesian_coord(&self, linear_coord: usize) -> (usize, usize) { match self { VarTensor::Advice { col_size, .. } | VarTensor::Fixed { col_size, .. } => {
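+            // Sketch of the intended mapping, assuming column-major packing of the storage block: +            // column = linear_coord / col_size, row = linear_coord % col_size.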