Skip to content

Commit

Permalink
feat: make circuit serializable / deserializable (zkonduit#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Apr 13, 2023
1 parent c03a242 commit bba10d1
Show file tree
Hide file tree
Showing 12 changed files with 413 additions and 215 deletions.
8 changes: 1 addition & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.64", optional = true }
log = { version = "0.4.17", optional = true }
tabled = { version = "0.9.0", optional = true}
eq-float = "0.1.0"
thiserror = "1.0.38"
hex = "0.4.3"
ethereum_types = { package = "ethereum-types", version = "0.14.1", default-features = false, features = ["std"]}
Expand All @@ -37,6 +36,7 @@ colored = { version = "2.0.0", optional = true}
env_logger = { version = "0.10.0", optional = true}
colored_json = { version = "3.0.1", optional = true}
tokio = { version = "1.26.0", features = ["macros", "rt"] }
bincode = "*"

# python binding related deps
pyo3 = { version = "0.18.2", features = ["extension-module", "abi3-py37"], optional = true }
Expand Down
5 changes: 2 additions & 3 deletions examples/mlp_4d.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use eq_float::F32;
use ezkl_lib::circuit::{BaseConfig as PolyConfig, CheckMode, LookupOp, Op as PolyOp};
use ezkl_lib::fieldutils::i32_to_felt;
use ezkl_lib::tensor::*;
Expand Down Expand Up @@ -67,7 +66,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
&output,
BITS,
&LookupOp::Div {
denom: F32::from(128.),
denom: ezkl_lib::circuit::utils::F32::from(128.),
},
)
.unwrap();
Expand Down Expand Up @@ -152,7 +151,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
&[x.unwrap()],
&mut offset,
LookupOp::Div {
denom: F32::from(128.),
denom: ezkl_lib::circuit::utils::F32::from(128.),
}
.into(),
)
Expand Down
16 changes: 8 additions & 8 deletions src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ impl fmt::Display for BaseOp {

#[allow(missing_docs)]
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Div {
denom: eq_float::F32,
denom: utils::F32,
},
ReLU {
scale: usize,
Expand All @@ -185,11 +185,11 @@ pub enum LookupOp {
},
LeakyReLU {
scale: usize,
slope: eq_float::F32,
slope: utils::F32,
},
PReLU {
scale: usize,
slopes: Vec<eq_float::F32>,
slopes: Vec<utils::F32>,
},
Sigmoid {
scales: (usize, usize),
Expand Down Expand Up @@ -251,7 +251,7 @@ impl LookupOp {

#[allow(missing_docs)]
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum Op {
Dot,
Matmul,
Expand Down Expand Up @@ -504,7 +504,7 @@ impl fmt::Display for Op {
// Eventually, though, we probably want to keep them and treat them directly (layouting and configuring
// at each type of node)
/// Enum of the different kinds of operations `ezkl` can support.
#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd)]
#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd, Deserialize, Serialize)]
pub enum OpKind {
/// A nonlinearity
Lookup(LookupOp),
Expand Down Expand Up @@ -544,13 +544,13 @@ impl OpKind {
}),
"LeakyRelu" => OpKind::Lookup(LookupOp::LeakyReLU {
scale: 1,
slope: eq_float::F32(0.0),
slope: utils::F32(0.0),
}),
"Sigmoid" => OpKind::Lookup(LookupOp::Sigmoid { scales: (1, 1) }),
"Sqrt" => OpKind::Lookup(LookupOp::Sqrt { scales: (1, 1) }),
"Tanh" => OpKind::Lookup(LookupOp::Tanh { scales: (1, 1) }),
"Div" => OpKind::Lookup(LookupOp::Div {
denom: eq_float::F32(1.0),
denom: utils::F32(1.0),
}),
"Const" => OpKind::Const,
"Source" => OpKind::Input,
Expand Down
138 changes: 138 additions & 0 deletions src/circuit/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,141 @@ pub fn value_muxer<F: FieldExt + TensorType>(
_ => unimplemented!(),
}
}

// --------------------------------------------------------------------------------------------
//
// Float Utils to enable the usage of f32s as the keys of HashMaps
// This section is taken from the `eq_float` crate verbatim -- but we also implement deserialization methods
//
//

use std::cmp::Ordering;
use std::fmt;
use std::hash::{Hash, Hasher};

#[derive(Debug, Default, Clone, Copy)]
/// f32 wrapper
pub struct F32(pub f32);

impl<'de> Deserialize<'de> for F32 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let float = f32::deserialize(deserializer)?;
Ok(F32(float))
}
}

impl Serialize for F32 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
f32::serialize(&self.0, serializer)
}
}

/// This works like `PartialEq` on `f32`, except that `NAN == NAN` is true.
impl PartialEq for F32 {
fn eq(&self, other: &Self) -> bool {
if self.0.is_nan() && other.0.is_nan() {
true
} else {
self.0 == other.0
}
}
}

impl Eq for F32 {}

/// This works like `PartialOrd` on `f32`, except that `NAN` sorts below all other floats
/// (and is equal to another NAN). This always returns a `Some`.
impl PartialOrd for F32 {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

/// This works like `PartialOrd` on `f32`, except that `NAN` sorts below all other floats
/// (and is equal to another NAN).
impl Ord for F32 {
fn cmp(&self, other: &Self) -> Ordering {
self.0.partial_cmp(&other.0).unwrap_or_else(|| {
if self.0.is_nan() && !other.0.is_nan() {
Ordering::Less
} else if !self.0.is_nan() && other.0.is_nan() {
Ordering::Greater
} else {
Ordering::Equal
}
})
}
}

impl Hash for F32 {
fn hash<H: Hasher>(&self, state: &mut H) {
if self.0.is_nan() {
0x7fc00000u32.hash(state); // a particular bit representation for NAN
} else if self.0 == 0.0 {
// catches both positive and negative zero
0u32.hash(state);
} else {
self.0.to_bits().hash(state);
}
}
}

impl From<F32> for f32 {
fn from(f: F32) -> Self {
f.0
}
}

impl From<f32> for F32 {
fn from(f: f32) -> Self {
F32(f)
}
}

impl fmt::Display for F32 {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}

#[cfg(test)]
mod tests {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use super::F32;

fn calculate_hash<T: Hash>(t: &T) -> u64 {
let mut s = DefaultHasher::new();
t.hash(&mut s);
s.finish()
}

#[test]
fn f32_eq() {
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
assert!(F32(std::f32::NAN) != F32(5.0));
assert!(F32(5.0) != F32(std::f32::NAN));
assert!(F32(0.0) == F32(-0.0));
}

#[test]
fn f32_cmp() {
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
assert!(F32(std::f32::NAN) < F32(5.0));
assert!(F32(5.0) > F32(std::f32::NAN));
assert!(F32(0.0) == F32(-0.0));
}

#[test]
fn f32_hash() {
assert!(calculate_hash(&F32(0.0)) == calculate_hash(&F32(-0.0)));
assert!(calculate_hash(&F32(std::f32::NAN)) == calculate_hash(&F32(-std::f32::NAN)));
}
}
25 changes: 14 additions & 11 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,8 @@ use crate::pfsys::evm::aggregation::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::evm::{aggregation::gen_aggregation_evm_verifier, single::gen_evm_verifier};
#[cfg(not(target_arch = "wasm32"))]
use crate::pfsys::evm::{evm_verify, DeploymentCode};
#[cfg(feature = "render")]
use crate::pfsys::prepare_model_circuit;
use crate::pfsys::{create_keys, load_params, load_vk, save_params, Snark};
use crate::pfsys::{
create_proof_circuit, gen_srs, prepare_data, prepare_model_circuit_and_public_input, save_vk,
verify_proof_circuit,
};
use crate::pfsys::{create_proof_circuit, gen_srs, prepare_data, save_vk, verify_proof_circuit};
#[cfg(not(target_arch = "wasm32"))]
use ethers::providers::Middleware;
use halo2_proofs::dev::VerifyFailure;
Expand Down Expand Up @@ -217,7 +212,8 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
ref output,
} => {
let data = prepare_data(data.to_string())?;
let circuit = prepare_model_circuit::<Fr>(&data, &cli.args)?;
let model = Model::from_arg()?;
let circuit = ModelCircuit::<Fr>::new(&data, model)?;
info!("Rendering circuit");

// Create the area we want to draw on.
Expand Down Expand Up @@ -256,8 +252,10 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
}
Commands::Mock { ref data, model: _ } => {
let data = prepare_data(data.to_string())?;
let (circuit, public_inputs) =
prepare_model_circuit_and_public_input::<Fr>(&data, &cli)?;
let model = Model::from_arg()?;
let circuit = ModelCircuit::<Fr>::new(&data, model)?;
let public_inputs = circuit.prepare_public_inputs(&data)?;

info!("Mock proof");

let prover = MockProver::run(cli.args.logrows, &circuit, public_inputs)
Expand All @@ -278,7 +276,9 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
} => {
let data = prepare_data(data.to_string())?;

let (_, public_inputs) = prepare_model_circuit_and_public_input::<Fr>(&data, &cli)?;
let model = Model::from_arg()?;
let circuit = ModelCircuit::<Fr>::new(&data, model)?;
let public_inputs = circuit.prepare_public_inputs(&data)?;
let num_instance = public_inputs.iter().map(|x| x.len()).collect();
let mut params: ParamsKZG<Bn256> =
load_params::<KZGCommitmentScheme<Bn256>>(params_path.to_path_buf())?;
Expand Down Expand Up @@ -333,7 +333,10 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
} => {
let data = prepare_data(data.to_string())?;

let (circuit, public_inputs) = prepare_model_circuit_and_public_input(&data, &cli)?;
let model = Model::from_arg()?;
let circuit = ModelCircuit::<Fr>::new(&data, model)?;
let public_inputs = circuit.prepare_public_inputs(&data)?;

let mut params: ParamsKZG<Bn256> =
load_params::<KZGCommitmentScheme<Bn256>>(params_path.to_path_buf())?;
info!("downsizing params to {} logrows", cli.args.logrows);
Expand Down
Loading

0 comments on commit bba10d1

Please sign in to comment.