diff --git a/Cargo.lock b/Cargo.lock
index f3f67275f..8263723d2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1353,12 +1353,6 @@ dependencies = [
  "termcolor",
 ]
 
-[[package]]
-name = "eq-float"
-version = "0.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c02b5d1d1e6ba431b960d4bf971c8b0e2d2942b8cbc577d9bdf9c60fca5d41d"
-
 [[package]]
 name = "errno"
 version = "0.2.8"
@@ -1722,6 +1716,7 @@ name = "ezkl-lib"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "bincode",
  "clap 4.1.13",
  "colored",
  "colored_json",
@@ -1729,7 +1724,6 @@ dependencies = [
  "ctor",
  "ecc",
  "env_logger",
- "eq-float",
  "ethereum-types",
  "ethers",
  "ethers-solc",
diff --git a/Cargo.toml b/Cargo.toml
index 178c0c669..c44884d7b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,7 +26,6 @@ serde = { version = "1.0.126", features = ["derive"], optional = true }
 serde_json = { version = "1.0.64", optional = true }
 log = { version = "0.4.17", optional = true }
 tabled = { version = "0.9.0", optional = true}
-eq-float = "0.1.0"
 thiserror = "1.0.38"
 hex = "0.4.3"
 ethereum_types = { package = "ethereum-types", version = "0.14.1", default-features = false, features = ["std"]}
@@ -37,6 +36,7 @@ colored = { version = "2.0.0", optional = true}
 env_logger = { version = "0.10.0", optional = true}
 colored_json = { version = "3.0.1", optional = true}
 tokio = { version = "1.26.0", features = ["macros", "rt"] }
+bincode = "*"
 
 # python binding related deps
 pyo3 = { version = "0.18.2", features = ["extension-module", "abi3-py37"], optional = true }
diff --git a/examples/mlp_4d.rs b/examples/mlp_4d.rs
index 9dca5675e..dfe9a0b41 100644
--- a/examples/mlp_4d.rs
+++ b/examples/mlp_4d.rs
@@ -1,4 +1,3 @@
-use eq_float::F32;
 use ezkl_lib::circuit::{BaseConfig as PolyConfig, CheckMode, LookupOp, Op as PolyOp};
 use ezkl_lib::fieldutils::i32_to_felt;
 use ezkl_lib::tensor::*;
@@ -67,7 +66,7 @@ impl Circuit
                 &output,
                 BITS,
                 &LookupOp::Div {
-                    denom: F32::from(128.),
+                    denom: ezkl_lib::circuit::utils::F32::from(128.),
                 },
             )
             .unwrap();
@@ -152,7 +151,7 @@ impl Circuit
                 &[x.unwrap()],
                 &mut offset,
                 LookupOp::Div {
-                    denom: F32::from(128.),
+                    denom: ezkl_lib::circuit::utils::F32::from(128.),
                 }
                 .into(),
             )
diff --git a/src/circuit/mod.rs b/src/circuit/mod.rs
index 2672673a4..b1be4a3c0 100644
--- a/src/circuit/mod.rs
+++ b/src/circuit/mod.rs
@@ -172,10 +172,10 @@ impl fmt::Display for BaseOp {
 
 #[allow(missing_docs)]
 /// An enum representing the operations that can be used to express more complex operations via accumulation
-#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
 pub enum LookupOp {
     Div {
-        denom: eq_float::F32,
+        denom: utils::F32,
     },
     ReLU {
         scale: usize,
@@ -185,11 +185,11 @@ pub enum LookupOp {
     },
     LeakyReLU {
         scale: usize,
-        slope: eq_float::F32,
+        slope: utils::F32,
     },
     PReLU {
         scale: usize,
-        slopes: Vec<eq_float::F32>,
+        slopes: Vec<utils::F32>,
     },
     Sigmoid {
         scales: (usize, usize),
     },
@@ -251,7 +251,7 @@ impl LookupOp {
 
 #[allow(missing_docs)]
 /// An enum representing the operations that can be used to express more complex operations via accumulation
-#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
 pub enum Op {
     Dot,
     Matmul,
@@ -504,7 +504,7 @@ impl fmt::Display for Op {
 // Eventually, though, we probably want to keep them and treat them directly (layouting and configuring
 // at each type of node)
 /// Enum of the different kinds of operations `ezkl` can support.
-#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd)]
+#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd, Deserialize, Serialize)]
 pub enum OpKind {
     /// A nonlinearity
     Lookup(LookupOp),
@@ -544,13 +544,13 @@ impl OpKind {
             }),
             "LeakyRelu" => OpKind::Lookup(LookupOp::LeakyReLU {
                 scale: 1,
-                slope: eq_float::F32(0.0),
+                slope: utils::F32(0.0),
             }),
             "Sigmoid" => OpKind::Lookup(LookupOp::Sigmoid { scales: (1, 1) }),
             "Sqrt" => OpKind::Lookup(LookupOp::Sqrt { scales: (1, 1) }),
             "Tanh" => OpKind::Lookup(LookupOp::Tanh { scales: (1, 1) }),
             "Div" => OpKind::Lookup(LookupOp::Div {
-                denom: eq_float::F32(1.0),
+                denom: utils::F32(1.0),
             }),
             "Const" => OpKind::Const,
             "Source" => OpKind::Input,
diff --git a/src/circuit/utils.rs b/src/circuit/utils.rs
index 885492eb3..e880d5caf 100644
--- a/src/circuit/utils.rs
+++ b/src/circuit/utils.rs
@@ -24,3 +24,141 @@ pub fn value_muxer(
         _ => unimplemented!(),
     }
 }
+
+// --------------------------------------------------------------------------------------------
+//
+// Float Utils to enable the usage of f32s as the keys of HashMaps
+// This section is taken from the `eq_float` crate verbatim -- but we also implement deserialization methods
+//
+//
+
+use std::cmp::Ordering;
+use std::fmt;
+use std::hash::{Hash, Hasher};
+
+#[derive(Debug, Default, Clone, Copy)]
+/// f32 wrapper
+pub struct F32(pub f32);
+
+impl<'de> Deserialize<'de> for F32 {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        let float = f32::deserialize(deserializer)?;
+        Ok(F32(float))
+    }
+}
+
+impl Serialize for F32 {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        f32::serialize(&self.0, serializer)
+    }
+}
+
+/// This works like `PartialEq` on `f32`, except that `NAN == NAN` is true.
+impl PartialEq for F32 {
+    fn eq(&self, other: &Self) -> bool {
+        if self.0.is_nan() && other.0.is_nan() {
+            true
+        } else {
+            self.0 == other.0
+        }
+    }
+}
+
+impl Eq for F32 {}
+
+/// This works like `PartialOrd` on `f32`, except that `NAN` sorts below all other floats
+/// (and is equal to another NAN). This always returns a `Some`.
+impl PartialOrd for F32 {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+/// This works like `PartialOrd` on `f32`, except that `NAN` sorts below all other floats
+/// (and is equal to another NAN).
+impl Ord for F32 {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.0.partial_cmp(&other.0).unwrap_or_else(|| {
+            if self.0.is_nan() && !other.0.is_nan() {
+                Ordering::Less
+            } else if !self.0.is_nan() && other.0.is_nan() {
+                Ordering::Greater
+            } else {
+                Ordering::Equal
+            }
+        })
+    }
+}
+
+impl Hash for F32 {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        if self.0.is_nan() {
+            0x7fc00000u32.hash(state); // a particular bit representation for NAN
+        } else if self.0 == 0.0 {
+            // catches both positive and negative zero
+            0u32.hash(state);
+        } else {
+            self.0.to_bits().hash(state);
+        }
+    }
+}
+
+impl From<F32> for f32 {
+    fn from(f: F32) -> Self {
+        f.0
+    }
+}
+
+impl From<f32> for F32 {
+    fn from(f: f32) -> Self {
+        F32(f)
+    }
+}
+
+impl fmt::Display for F32 {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::hash_map::DefaultHasher;
+    use std::hash::{Hash, Hasher};
+
+    use super::F32;
+
+    fn calculate_hash<T: Hash>(t: &T) -> u64 {
+        let mut s = DefaultHasher::new();
+        t.hash(&mut s);
+        s.finish()
+    }
+
+    #[test]
+    fn f32_eq() {
+        assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
+        assert!(F32(std::f32::NAN) != F32(5.0));
+        assert!(F32(5.0) != F32(std::f32::NAN));
+        assert!(F32(0.0) == F32(-0.0));
+    }
+
+    #[test]
+    fn f32_cmp() {
+        assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
+        assert!(F32(std::f32::NAN) < F32(5.0));
+        assert!(F32(5.0) > F32(std::f32::NAN));
+        assert!(F32(0.0) == F32(-0.0));
+    }
+
+    #[test]
+    fn f32_hash() {
+        assert!(calculate_hash(&F32(0.0)) == calculate_hash(&F32(-0.0)));
+        assert!(calculate_hash(&F32(std::f32::NAN)) == calculate_hash(&F32(-std::f32::NAN)));
+    }
+}
diff --git a/src/execute.rs b/src/execute.rs
index f67c86020..1b7a85cec 100644
--- a/src/execute.rs
+++ b/src/execute.rs
@@ -11,13 +11,8 @@ use crate::pfsys::evm::aggregation::{AggregationCircuit, PoseidonTranscript};
 use crate::pfsys::evm::{aggregation::gen_aggregation_evm_verifier, single::gen_evm_verifier};
 #[cfg(not(target_arch = "wasm32"))]
 use crate::pfsys::evm::{evm_verify, DeploymentCode};
-#[cfg(feature = "render")]
-use crate::pfsys::prepare_model_circuit;
 use crate::pfsys::{create_keys, load_params, load_vk, save_params, Snark};
-use crate::pfsys::{
-    create_proof_circuit, gen_srs, prepare_data, prepare_model_circuit_and_public_input, save_vk,
-    verify_proof_circuit,
-};
+use crate::pfsys::{create_proof_circuit, gen_srs, prepare_data, save_vk, verify_proof_circuit};
 #[cfg(not(target_arch = "wasm32"))]
 use ethers::providers::Middleware;
 use halo2_proofs::dev::VerifyFailure;
@@ -217,7 +212,8 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
             ref output,
         } => {
             let data = prepare_data(data.to_string())?;
-            let circuit = prepare_model_circuit::<Fr>(&data, &cli.args)?;
+            let model = Model::from_arg()?;
+            let circuit = ModelCircuit::<Fr>::new(&data, model)?;
             info!("Rendering circuit");
 
             // Create the area we want to draw on.
@@ -256,8 +252,10 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
         }
         Commands::Mock { ref data, model: _ } => {
             let data = prepare_data(data.to_string())?;
-            let (circuit, public_inputs) =
-                prepare_model_circuit_and_public_input::<Fr>(&data, &cli)?;
+            let model = Model::from_arg()?;
+            let circuit = ModelCircuit::<Fr>::new(&data, model)?;
+            let public_inputs = circuit.prepare_public_inputs(&data)?;
+
             info!("Mock proof");
 
             let prover = MockProver::run(cli.args.logrows, &circuit, public_inputs)
@@ -278,7 +276,9 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
         } => {
             let data = prepare_data(data.to_string())?;
 
-            let (_, public_inputs) = prepare_model_circuit_and_public_input::<Fr>(&data, &cli)?;
+            let model = Model::from_arg()?;
+            let circuit = ModelCircuit::<Fr>::new(&data, model)?;
+            let public_inputs = circuit.prepare_public_inputs(&data)?;
             let num_instance = public_inputs.iter().map(|x| x.len()).collect();
             let mut params: ParamsKZG<Bn256> =
                 load_params::<KZGCommitmentScheme<Bn256>>(params_path.to_path_buf())?;
@@ -333,7 +333,10 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
         } => {
             let data = prepare_data(data.to_string())?;
 
-            let (circuit, public_inputs) = prepare_model_circuit_and_public_input(&data, &cli)?;
+            let model = Model::from_arg()?;
+            let circuit = ModelCircuit::<Fr>::new(&data, model)?;
+            let public_inputs = circuit.prepare_public_inputs(&data)?;
+
             let mut params: ParamsKZG<Bn256> =
                 load_params::<KZGCommitmentScheme<Bn256>>(params_path.to_path_buf())?;
             info!("downsizing params to {} logrows", cli.args.logrows);
diff --git a/src/graph/mod.rs b/src/graph/mod.rs
index 9272b2eaf..23531a4c1 100644
--- a/src/graph/mod.rs
+++ b/src/graph/mod.rs
@@ -1,5 +1,6 @@
 /// Helper functions
 pub mod utilities;
+use serde::{Deserialize, Serialize};
 pub use utilities::*;
 /// Crate for defining a computational graph and building a ZK-circuit from it.
 pub mod model;
@@ -9,6 +10,10 @@ pub mod node;
 pub mod vars;
 
 use crate::circuit::OpKind;
+use crate::commands::Cli;
+use crate::fieldutils::i128_to_felt;
+use crate::pfsys::ModelInput;
+use crate::tensor::ops::pack;
 use crate::tensor::TensorType;
 use crate::tensor::{Tensor, ValTensor};
 use anyhow::Result;
@@ -20,7 +25,10 @@ use halo2_proofs::{
 use log::{info, trace};
 pub use model::*;
 pub use node::*;
+use std::fs::File;
+use std::io::{BufReader, BufWriter, Read, Write};
 use std::marker::PhantomData;
+use std::path::PathBuf;
 use thiserror::Error;
 pub use vars::*;
 
@@ -66,17 +74,137 @@ pub enum GraphError {
     /// Error when attempting to load a model
     #[error("failed to load model")]
     ModelLoad,
+    /// Packing exponent is too large
+    #[error("largest packing exponent exceeds max. try reducing the scale")]
+    PackingExponent,
 }
 
 /// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ModelCircuit<F: FieldExt + TensorType> {
     /// Vector of input tensors to the model / graph of computations.
     pub inputs: Vec<Tensor<i128>>,
+    ///
+    pub model: Model,
     /// Represents the Field we are using.
     pub _marker: PhantomData<F>,
 }
 
+impl<F: FieldExt + TensorType> ModelCircuit<F> {
+    ///
+    pub fn new(
+        data: &ModelInput,
+        model: Model,
+    ) -> Result<ModelCircuit<F>, Box<dyn Error>> {
+        // quantize the supplied data using the provided scale.
+        let mut inputs: Vec<Tensor<i128>> = vec![];
+        for (input, shape) in data.input_data.iter().zip(data.input_shapes.clone()) {
+            let t = vector_to_quantized(input, &shape, 0.0, model.run_args.scale)?;
+            inputs.push(t);
+        }
+
+        Ok(ModelCircuit::<F> {
+            inputs,
+            model,
+            _marker: PhantomData,
+        })
+    }
+
+    ///
+    pub fn write(
+        &self,
+        mut writer: BufWriter<File>,
+    ) -> Result<(), Box<dyn Error>> {
+        let circuit_bytes = bincode::serialize(&self)?;
+        writer.write(&circuit_bytes)?;
+        writer.flush()?;
+        Ok(())
+    }
+
+    ///
+    pub fn write_to_file(&self, path: PathBuf) -> Result<(), Box<dyn Error>> {
+        let fs = File::create(path)?;
+        let buffer = BufWriter::new(fs);
+        self.write(buffer)
+    }
+
+    ///
+    pub fn read(mut reader: BufReader<File>) -> Result<Self, Box<dyn Error>> {
+        let buffer: &mut Vec<u8> = &mut vec![];
+        reader.read_to_end(buffer)?;
+
+        let circuit = bincode::deserialize(&buffer)?;
+        Ok(circuit)
+    }
+    ///
+    pub fn read_from_file(path: PathBuf) -> Result<Self, Box<dyn Error>> {
+        let f = File::open(path)?;
+        let reader = BufReader::new(f);
+        Self::read(reader)
+    }
+
+    ///
+    pub fn from_arg(data: &ModelInput) -> Result<Self, Box<dyn Error>> {
+        let cli = Cli::create()?;
+        let model = Model::from_ezkl_conf(cli)?;
+        Self::new(data, model)
+    }
+
+    ///
+    pub fn prepare_public_inputs(
+        &self,
+        data: &ModelInput,
+    ) -> Result<Vec<Vec<F>>, Box<dyn Error>> {
+        let out_scales = self.model.get_output_scales();
+
+        // quantize the supplied data using the provided scale.
+        // the ordering here is important, we want the inputs to come before the outputs
+        // as they are configured in that order as Column<Instance>
+        let mut public_inputs = vec![];
+        if self.model.visibility.input.is_public() {
+            for v in data.input_data.iter() {
+                let t =
+                    vector_to_quantized(v, &Vec::from([v.len()]), 0.0, self.model.run_args.scale)?;
+                public_inputs.push(t);
+            }
+        }
+        if self.model.visibility.output.is_public() {
+            for (idx, v) in data.output_data.iter().enumerate() {
+                let mut t = vector_to_quantized(v, &Vec::from([v.len()]), 0.0, out_scales[idx])?;
+                let len = t.len();
+                if self.model.run_args.pack_base > 1 {
+                    let max_exponent =
+                        (((len - 1) as u32) * (self.model.run_args.scale + 1)) as f64;
+                    if max_exponent > (i128::MAX as f64).log(self.model.run_args.pack_base as f64) {
+                        return Err(Box::new(GraphError::PackingExponent));
+                    }
+                    t = pack(
+                        &t,
+                        self.model.run_args.pack_base as i128,
+                        self.model.run_args.scale,
+                    )?;
+                }
+                public_inputs.push(t);
+            }
+        }
+        info!(
+            "public inputs lengths: {:?}",
+            public_inputs
+                .iter()
+                .map(|i| i.len())
+                .collect::<Vec<_>>()
+        );
+        trace!("{:?}", public_inputs);
+
+        let pi_inner: Vec<Vec<F>> = public_inputs
+            .iter()
+            .map(|i| i.iter().map(|e| i128_to_felt::<F>(*e)).collect::<Vec<F>>())
+            .collect::<Vec<Vec<F>>>();
+
+        Ok(pi_inner)
+    }
+}
+
 impl<F: FieldExt + TensorType> Circuit<F> for ModelCircuit<F> {
     type Config = ModelConfig<F>;
     type FloorPlanner = SimpleFloorPlanner;
@@ -86,7 +214,7 @@ impl<F: FieldExt + TensorType> Circuit<F> for ModelCircuit<F> {
     }
 
     fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
-        let model = Model::from_arg().expect("model should load from args");
+        let model = Model::from_arg().expect("model should load");
 
         // for now the number of instances corresponds to the number of graph / model outputs
         let instance_shapes = model.instance_shapes();
@@ -133,3 +261,5 @@ impl<F: FieldExt + TensorType> Circuit<F> for ModelCircuit<F> {
         Ok(())
     }
 }
+
+////////////////////////
diff --git a/src/graph/model.rs b/src/graph/model.rs
index 27b7f5d9d..529001fe3 100644
--- a/src/graph/model.rs
+++ b/src/graph/model.rs
@@ -12,8 +12,9 @@ use crate::graph::scale_to_multiplier;
 use crate::tensor::TensorType;
 use crate::tensor::{Tensor, ValTensor, VarTensor};
 use anyhow::Context;
+use serde::Deserialize;
+use serde::Serialize;
 //use clap::Parser;
-use anyhow::Error as AnyError;
 use core::panic;
 use halo2_proofs::circuit::Region;
 use halo2_proofs::{
@@ -28,14 +29,16 @@ use std::cell::RefCell;
 use std::cmp::max;
 use std::collections::BTreeMap;
 use std::error::Error;
+use std::fs::File;
+use std::io::{BufReader, BufWriter, Read, Write};
 use std::path::Path;
+use std::path::PathBuf;
 use std::rc::Rc;
 use tabled::Table;
 use tract_onnx;
-use tract_onnx::prelude::{Framework, Graph, InferenceFact, Node as OnnxNode, OutletId};
-use tract_onnx::tract_hir::internal::InferenceOp;
+use tract_onnx::prelude::Framework;
 
 /// Mode we're using the model in.
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 pub enum Mode {
     /// Initialize the model and display the operations table / graph
     Table,
@@ -64,10 +67,12 @@ pub struct ModelConfig<F: FieldExt + TensorType> {
 }
 
 /// A struct for loading from an Onnx file and converting a computational graph to a circuit.
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct Model {
-    /// The raw tract [Graph] data structure.
-    pub model: Graph<InferenceFact, Box<dyn InferenceOp>>,
+    /// input indices
+    pub inputs: Vec<usize>,
+    /// output indices
+    pub outputs: Vec<usize>,
     /// Graph of nodes we are loading from Onnx.
     pub nodes: NodeGraph, // Wrapped nodes with additional methods and data (e.g. inferred shape, quantization)
     /// The [RunArgs] being used
@@ -103,7 +108,8 @@ impl Model {
             nodes.insert(i, n);
         }
         let om = Model {
-            model: model.clone(),
+            inputs: model.inputs.iter().map(|o| o.node).collect(),
+            outputs: model.outputs.iter().map(|o| o.node).collect(),
             run_args,
             nodes,
             mode,
@@ -142,9 +148,9 @@ impl Model {
         for (i, n) in nodes.iter() {
             let mut inputs = vec![];
             for i in n.inputs.iter() {
-                match results.get(&i.node) {
+                match results.get(&i) {
                     Some(value) => inputs.push(value.clone()),
-                    None => return Err(Box::new(GraphError::MissingNode(i.node))),
+                    None => return Err(Box::new(GraphError::MissingNode(*i))),
                 }
             }
             match &n.opkind {
@@ -213,10 +219,40 @@ impl Model {
         }
     }
 
+    ///
+    pub fn write(&self, mut writer: BufWriter<File>) -> Result<(), Box<dyn Error>> {
+        let circuit_bytes = bincode::serialize(&self)?;
+        writer.write(&circuit_bytes)?;
+        writer.flush()?;
+        Ok(())
+    }
+
+    ///
+    pub fn write_to_file(&self, path: PathBuf) -> Result<(), Box<dyn Error>> {
+        let fs = File::create(path)?;
+        let buffer = BufWriter::new(fs);
+        self.write(buffer)
+    }
+
+    ///
+    pub fn read(mut reader: BufReader<File>) -> Result<Self, Box<dyn Error>> {
+        let buffer: &mut Vec<u8> = &mut vec![];
+        reader.read_to_end(buffer)?;
+
+        let circuit = bincode::deserialize(&buffer)?;
+        Ok(circuit)
+    }
+    ///
+    pub fn read_from_file(path: PathBuf) -> Result<Self, Box<dyn Error>> {
+        let f = File::open(path)?;
+        let reader = BufReader::new(f);
+        Self::read(reader)
+    }
+
     /// Creates a `Model` based on CLI arguments
     pub fn from_arg() -> Result<Self, Box<dyn Error>> {
-        let args = Cli::create()?;
-        Self::from_ezkl_conf(args)
+        let conf = Cli::create()?;
+        Self::from_ezkl_conf(conf)
     }
 
     /// Configures an `Model`. Does so one execution `bucket` at a time. Each bucket holds either:
@@ -385,7 +421,7 @@ impl Model {
             let input_nodes = node
                 .inputs
                 .iter()
-                .map(|i| self.nodes.get(&i.node).unwrap())
+                .map(|i| self.nodes.get(&i).unwrap())
                 .collect_vec();
             let input_idx = input_nodes.iter().map(|f| f.idx).collect_vec();
 
@@ -448,7 +484,7 @@ impl Model {
             let input_nodes = node
                 .inputs
                 .iter()
-                .map(|i| self.nodes.get(&i.node).unwrap())
+                .map(|i| self.nodes.get(&i).unwrap())
                 .collect_vec();
             let input_idx = input_nodes.iter().map(|f| f.idx).collect_vec();
 
@@ -539,13 +575,10 @@ impl Model {
             }
         }
 
-        let output_nodes = self.model.outputs.iter();
-        info!(
-            "model outputs are nodes: {:?}",
-            output_nodes.clone().map(|o| o.node).collect_vec()
-        );
+        let output_nodes = self.outputs.iter();
+        info!("model outputs are nodes: {:?}", output_nodes);
         let mut outputs = output_nodes
-            .map(|o| results.get(&o.node).unwrap().clone())
+            .map(|o| results.get(&o).unwrap().clone())
             .collect_vec();
 
         // pack outputs if need be
@@ -646,64 +679,39 @@ impl Model {
         Ok(res)
     }
 
-    /// Get a linear extension of the model (an evaluation order), for example to feed to circuit construction.
-    /// Note that this order is not stable over multiple reloads of the model. For example, it will freely
-    /// interchange the order of evaluation of fixed parameters. For example weight could have id 1 on one load,
-    /// and bias id 2, and vice versa on the next load of the same file. The ids are also not stable.
-    pub fn eval_order(&self) -> Result<Vec<usize>, AnyError> {
-        self.model.eval_order()
-    }
-
-    /// Note that this order is not stable.
-    pub fn nodes(&self) -> Vec<OnnxNode<InferenceFact, Box<dyn InferenceOp>>> {
-        self.model.nodes().to_vec()
-    }
-
-    /// Returns the ID of the computational graph's inputs
-    pub fn input_outlets(&self) -> Result<Vec<OutletId>, Box<dyn Error>> {
-        Ok(self.model.input_outlets()?.to_vec())
-    }
-
-    /// Returns the ID of the computational graph's outputs
-    pub fn output_outlets(&self) -> Result<Vec<OutletId>, Box<dyn Error>> {
-        Ok(self.model.output_outlets()?.to_vec())
-    }
-
     /// Returns the number of the computational graph's inputs
     pub fn num_inputs(&self) -> usize {
-        let input_nodes = self.model.inputs.iter();
+        let input_nodes = self.inputs.iter();
         input_nodes.len()
     }
 
     /// Returns shapes of the computational graph's inputs
    pub fn input_shapes(&self) -> Vec<Vec<usize>> {
-        self.model
-            .inputs
+        self.inputs
            .iter()
-            .map(|o| self.nodes.get(&o.node).unwrap().out_dims.clone())
+            .map(|o| self.nodes.get(&o).unwrap().out_dims.clone())
            .collect_vec()
     }
 
     /// Returns the number of the computational graph's outputs
     pub fn num_outputs(&self) -> usize {
-        let output_nodes = self.model.outputs.iter();
+        let output_nodes = self.outputs.iter();
         output_nodes.len()
     }
 
     /// Returns shapes of the computational graph's outputs
     pub fn output_shapes(&self) -> Vec<Vec<usize>> {
-        self.model
-            .outputs
+        self.outputs
             .iter()
-            .map(|o| self.nodes.get(&o.node).unwrap().out_dims.clone())
+            .map(|o| self.nodes.get(&o).unwrap().out_dims.clone())
             .collect_vec()
     }
 
     /// Returns the fixed point scale of the computational graph's outputs
     pub fn get_output_scales(&self) -> Vec<u32> {
-        let output_nodes = self.model.outputs.iter();
+        let output_nodes = self.outputs.iter();
         output_nodes
-            .map(|o| self.nodes.get(&o.node).unwrap().out_scale)
+            .map(|o| self.nodes.get(&o).unwrap().out_scale)
             .collect_vec()
     }
 
@@ -759,7 +767,7 @@ impl Model {
                 let in_dims = n
                     .inputs
                     .iter()
-                    .map(|i| self.nodes.get(&i.node).unwrap().out_dims.clone());
+                    .map(|i| self.nodes.get(&i).unwrap().out_dims.clone());
                 let layout_shape = p.circuit_shapes(in_dims.collect_vec());
                 maximum_var_len += layout_shape.last().unwrap();
             }
diff --git
a/src/graph/node.rs b/src/graph/node.rs index 0fc92ab3f..76474adf7 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -7,10 +7,10 @@ use crate::graph::GraphError; use crate::tensor::Tensor; use crate::tensor::TensorType; use anyhow::Result; -use eq_float::F32; use halo2_proofs::arithmetic::FieldExt; use itertools::Itertools; use log::{info, trace, warn}; +use serde::{Deserialize, Serialize}; use std::cell::RefCell; use std::collections::BTreeMap; use std::error::Error; @@ -19,7 +19,7 @@ use std::ops::Deref; use std::rc::Rc; use tabled::Tabled; use tract_onnx; -use tract_onnx::prelude::{DatumType, InferenceFact, Node as OnnxNode, OutletId}; +use tract_onnx::prelude::{DatumType, InferenceFact, Node as OnnxNode}; use tract_onnx::tract_hir::{ infer::Factoid, internal::InferenceOp, @@ -60,18 +60,10 @@ fn display_option(o: &Option) -> String { } fn display_vector(v: &Vec) -> String { - format!("{:?}", v) -} - -fn display_inputs(o: &Vec) -> String { - if !o.is_empty() { - let mut nodes = vec![]; - for id in o.iter() { - nodes.push(id.node); - } - format!("{:?}", nodes) + if v.len() > 0 { + format!("{:?}", v) } else { - String::new() + format!("") } } @@ -99,7 +91,7 @@ fn display_tensorf32(o: &Option>) -> String { /// * `const_value` - The constants potentially associated with this self. /// * `idx` - The node's unique identifier. /// * `bucket` - The execution bucket this node has been assigned to. -#[derive(Clone, Debug, Default, Tabled)] +#[derive(Clone, Debug, Default, Tabled, Serialize, Deserialize)] pub struct Node { /// [OpKind] enum, i.e what operation this node represents. pub opkind: OpKind, @@ -117,9 +109,9 @@ pub struct Node { pub raw_const_value: Option>, // Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias), // but in_dim is [in], out_dim is [out] - #[tabled(display_with = "display_inputs")] + #[tabled(display_with = "display_vector")] /// The indices of the node's inputs. - pub inputs: Vec, + pub inputs: Vec, #[tabled(display_with = "display_vector")] /// Dimensions of input. 
pub in_dims: Vec>, @@ -183,7 +175,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: input_node.out_dims.clone(), in_scale: input_node.out_scale, @@ -210,7 +202,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: input_node.out_dims.clone(), in_scale: input_node.out_scale, @@ -237,7 +229,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: input_node.out_dims.clone(), in_scale: input_node.out_scale, @@ -262,7 +254,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: input_node.out_dims.clone(), in_scale: input_node.out_scale, @@ -302,13 +294,13 @@ impl Node { opkind = OpKind::Lookup(LookupOp::LeakyReLU { scale: layer_scale, - slope: F32(leaky_op.0), + slope: crate::circuit::utils::F32(leaky_op.0), }); // now the input will be scaled down to match Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: input_node.out_dims.clone(), in_scale: input_node.out_scale, @@ -329,7 +321,7 @@ impl Node { .unwrap() .deref() .iter() - .map(|value| F32(*value)) + .map(|value| crate::circuit::utils::F32(*value)) .collect_vec(); // node.inputs.pop(); @@ -349,7 +341,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs, + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: input_node.out_dims.clone(), in_scale: input_node.out_scale, @@ -377,18 +369,20 @@ impl Node { if scale_diff > 0 { let mult = scale_to_multiplier(scale_diff); opkind = OpKind::Lookup(LookupOp::Div { - denom: F32(denom * mult), + denom: crate::circuit::utils::F32(denom * mult), }); // now the input will be scaled down to match output_max = input_node.output_max / (denom * mult); } else { - opkind = OpKind::Lookup(LookupOp::Div { denom: F32(denom) }); // now the input will be scaled down to match + opkind = OpKind::Lookup(LookupOp::Div { + denom: crate::circuit::utils::F32(denom), + }); // now the input will be scaled down to match output_max = input_node.output_max / (denom); } Node { idx, opkind, - inputs: input_outlets, + inputs: input_outlets.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: input_node.out_dims.clone(), // in scale is the same as the input @@ -460,7 +454,7 @@ impl Node { Node { idx, opkind: OpKind::Poly(PolyOp::Pad(padding_h, padding_w)), - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: vec![input_channels, out_height, out_width], in_scale: input_node.out_scale, @@ -548,7 +542,7 @@ impl Node { padding: (padding_h, padding_w), stride: (stride_h, stride_w), }), - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: vec![out_channels, out_height, out_width], in_scale: input_node.out_scale, @@ -615,7 +609,7 @@ impl Node { stride: (stride_h, stride_w), kernel_shape: (kernel_height, kernel_width), }), - inputs: node.inputs.clone(), + inputs: 
node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: vec![input_channels, out_height, out_width], in_scale: input_node.out_scale, @@ -651,7 +645,7 @@ impl Node { stride: (stride_h, stride_w), kernel_shape: (kernel_height, kernel_width), }), - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![input_node.out_dims.clone()], out_dims: vec![input_channels, out_height, out_width], in_scale: input_node.out_scale, @@ -676,7 +670,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![vec![in_dim]], out_dims: dims.clone(), in_scale: a_node.out_scale, @@ -703,7 +697,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: vec![out_dim], in_scale: input_node.out_scale, @@ -744,7 +738,7 @@ impl Node { Node { idx, opkind: OpKind::Poly(PolyOp::ScaleAndShift), - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: inputs[0].out_dims.clone(), in_scale, @@ -776,7 +770,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: inputs[0].out_dims.clone(), in_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), @@ -793,7 +787,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: vec![1], in_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), @@ -823,7 +817,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: inputs[0].out_dims.clone(), in_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), @@ -838,7 +832,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: inputs[0].out_dims.clone(), in_scale: input_node.out_scale, @@ -867,7 +861,7 @@ impl Node { Node { idx, opkind: OpKind::Poly(PolyOp::Pow(pow as u32)), - inputs: node.inputs, + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: input_node.out_dims.clone(), in_scale: input_node.out_scale, @@ -891,7 +885,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: input_node.out_dims.clone(), in_scale: input_node.out_scale, @@ -907,7 +901,7 @@ impl Node { Node { idx, opkind: OpKind::Poly(PolyOp::Flatten(new_dims.clone())), - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: new_dims, in_scale: input_node.out_scale, @@ -969,7 +963,7 @@ impl Node { Node { idx, opkind: OpKind::Poly(PolyOp::Reshape(new_dims.clone())), - inputs: node.inputs[0..1].to_vec(), + inputs: node.inputs[0..1].iter().map(|i| 
i.node).collect(), in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), out_dims: new_dims, in_scale: input_node.out_scale, @@ -1004,7 +998,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![dims.clone()], out_dims: dims, in_scale: scale, @@ -1025,7 +1019,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![dims.clone()], out_dims: dims, in_scale: scale, @@ -1068,7 +1062,7 @@ impl Node { Node { idx, opkind, - inputs: node.inputs.clone(), + inputs: node.inputs.iter().map(|i| i.node).collect(), in_dims: vec![out_dims.clone()], out_dims, in_scale: scale, diff --git a/src/graph/vars.rs b/src/graph/vars.rs index f43b54eff..a357d086f 100644 --- a/src/graph/vars.rs +++ b/src/graph/vars.rs @@ -5,12 +5,12 @@ use crate::tensor::TensorType; use crate::tensor::{ValTensor, VarTensor}; use halo2_proofs::{arithmetic::FieldExt, plonk::ConstraintSystem}; use itertools::Itertools; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use super::GraphError; /// Label Enum to track whether model input, model parameters, and model output are public or private -#[derive(Clone, Debug, Deserialize, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] pub enum Visibility { /// Mark an item as private to the prover (not in the proof submitted for verification) Private, @@ -33,7 +33,7 @@ impl std::fmt::Display for Visibility { } /// Whether the model input, model parameters, and model output are Public or Private to the prover. -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct VarVisibility { /// Input to the model or computational graph pub input: Visibility, diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index 39b8a2086..0fe8d3b32 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -2,12 +2,9 @@ pub mod evm; use crate::circuit::CheckMode; -use crate::commands::{data_path, Cli, RunArgs}; +use crate::commands::data_path; use crate::execute::ExecutionError; -use crate::fieldutils::i128_to_felt; -use crate::graph::{utilities::vector_to_quantized, Model, ModelCircuit}; -use crate::tensor::ops::pack; -use crate::tensor::{Tensor, TensorType}; +use crate::tensor::TensorType; use halo2_proofs::arithmetic::FieldExt; use halo2_proofs::circuit::Value; use halo2_proofs::dev::MockProver; @@ -28,7 +25,6 @@ use snark_verifier::verifier::plonk::PlonkProtocol; use std::error::Error; use std::fs::File; use std::io::{self, BufReader, BufWriter, Cursor, Read, Write}; -use std::marker::PhantomData; use std::ops::Deref; use std::path::PathBuf; use std::time::Instant; @@ -120,9 +116,9 @@ impl Snark { } } - /// Saves the Proof to a specified `proof_path`. - pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box> { - let self_i128 = Snarkbytes { + /// Converts to Snarkbytes + pub fn to_bytes(&self) -> Snarkbytes { + Snarkbytes { num_instance: self.protocol.as_ref().unwrap().num_instance.clone(), instances: self .instances @@ -130,7 +126,12 @@ impl Snark { .map(|i| i.iter().map(|e| e.to_raw_bytes()).collect::>>()) .collect::>>>(), proof: self.proof.clone(), - }; + } + } + + /// Saves the Proof to a specified `proof_path`. 
+    pub fn save(&self, proof_path: &PathBuf) -> Result<(), Box<dyn Error>> {
+        let self_i128 = self.to_bytes();
 
         let serialized = serde_json::to_string(&self_i128).map_err(Box::<dyn Error>::from)?;
 
@@ -226,76 +227,6 @@ impl From> for SnarkWitne
     }
 }
 
-type CircuitInputs<F> = (ModelCircuit<F>, Vec<Vec<F>>);
-
-/// Initialize the model circuit and quantize the provided float inputs from the provided `ModelInput`.
-pub fn prepare_model_circuit_and_public_input<F: FieldExt + TensorType>(
-    data: &ModelInput,
-    cli: &Cli,
-) -> Result<CircuitInputs<F>, Box<dyn Error>> {
-    let model = Model::from_ezkl_conf(cli.clone())?;
-    let out_scales = model.get_output_scales();
-    let circuit = prepare_model_circuit(data, &cli.args)?;
-
-    // quantize the supplied data using the provided scale.
-    // the ordering here is important, we want the inputs to come before the outputs
-    // as they are configured in that order as Column<Instance>
-    let mut public_inputs = vec![];
-    if model.visibility.input.is_public() {
-        for v in data.input_data.iter() {
-            let t = vector_to_quantized(v, &Vec::from([v.len()]), 0.0, model.run_args.scale)?;
-            public_inputs.push(t);
-        }
-    }
-    if model.visibility.output.is_public() {
-        for (idx, v) in data.output_data.iter().enumerate() {
-            let mut t = vector_to_quantized(v, &Vec::from([v.len()]), 0.0, out_scales[idx])?;
-            let len = t.len();
-            if cli.args.pack_base > 1 {
-                let max_exponent = (((len - 1) as u32) * (cli.args.scale + 1)) as f64;
-                if max_exponent > (i128::MAX as f64).log(cli.args.pack_base as f64) {
-                    return Err(Box::new(PfSysError::PackingExponent));
-                }
-                t = pack(&t, cli.args.pack_base as i128, cli.args.scale)?;
-            }
-            public_inputs.push(t);
-        }
-    }
-    info!(
-        "public inputs lengths: {:?}",
-        public_inputs
-            .iter()
-            .map(|i| i.len())
-            .collect::<Vec<_>>()
-    );
-    trace!("{:?}", public_inputs);
-
-    let pi_inner: Vec<Vec<F>> = public_inputs
-        .iter()
-        .map(|i| i.iter().map(|e| i128_to_felt::<F>(*e)).collect::<Vec<F>>())
-        .collect::<Vec<Vec<F>>>();
-
-    Ok((circuit, pi_inner))
-}
-
-/// Initialize the model circuit
-pub fn prepare_model_circuit<F: FieldExt + TensorType>(
-    data: &ModelInput,
-    args: &RunArgs,
-) -> Result<ModelCircuit<F>, Box<dyn Error>> {
-    // quantize the supplied data using the provided scale.
-    let mut inputs: Vec<Tensor<i128>> = vec![];
-    for (input, shape) in data.input_data.iter().zip(data.input_shapes.clone()) {
-        let t = vector_to_quantized(input, &shape, 0.0, args.scale)?;
-        inputs.push(t);
-    }
-
-    Ok(ModelCircuit::<F> {
-        inputs,
-        _marker: PhantomData,
-    })
-}
-
 /// Deserializes the required inputs to a model at path `datapath` to a [ModelInput] struct.
 pub fn prepare_data(datapath: String) -> Result<ModelInput, Box<dyn Error>> {
     let mut file = File::open(data_path(datapath)).map_err(Box::<dyn Error>::from)?;
diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs
index 52610b060..aa9310864 100644
--- a/src/tensor/mod.rs
+++ b/src/tensor/mod.rs
@@ -5,6 +5,7 @@ pub mod val;
 /// A wrapper around a tensor of Halo2 Value types.
 pub mod var;
 
+use serde::{Deserialize, Serialize};
 pub use val::*;
 pub use var::*;
 
@@ -222,7 +223,7 @@ impl TensorType for halo2curves::bn256::Fr {
 /// A generic multi-dimensional array representation of a Tensor.
 /// The `inner` attribute contains a vector of values whereas `dims` corresponds to the dimensionality of the array
 /// and as such determines how we index, query for values, or slice a Tensor.
-#[derive(Clone, Debug, Eq)]
+#[derive(Clone, Debug, Eq, Serialize, Deserialize)]
 pub struct Tensor<T: TensorType> {
     inner: Vec<T>,
     dims: Vec<usize>,
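Note on the change above: the in-crate `utils::F32` wrapper and the new `bincode` dependency are what let `LookupOp`, `Model`, and `ModelCircuit` derive `Serialize`/`Deserialize`. The short sketch below is illustrative only and not part of the patch; it assumes a hypothetical downstream crate that depends on `ezkl-lib` (imported as `ezkl_lib`, the same path the `mlp_4d.rs` example uses) and on `bincode`, and it exercises the wrapper's NaN/zero-aware `Eq`/`Hash` semantics plus the bincode round-trip the new read/write helpers rely on.

// Illustrative sketch only -- not part of this patch.
use ezkl_lib::circuit::utils::F32;
use std::collections::HashMap;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Eq/Hash treat NAN == NAN and 0.0 == -0.0, so F32 can key a HashMap.
    let mut denoms: HashMap<F32, &str> = HashMap::new();
    denoms.insert(F32::from(128.0), "div rescale");
    denoms.insert(F32(0.0), "zero");
    denoms.insert(F32(f32::NAN), "sentinel");
    assert_eq!(denoms.get(&F32(-0.0)), Some(&"zero"));
    assert_eq!(denoms.get(&F32(f32::NAN)), Some(&"sentinel"));

    // Serialize/Deserialize pass straight through to the inner f32, so the
    // wrapper survives the same bincode round-trip used by Model/ModelCircuit.
    let bytes = bincode::serialize(&F32::from(0.5))?;
    let back: F32 = bincode::deserialize(&bytes)?;
    assert_eq!(back, F32::from(0.5));
    Ok(())
}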