Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make circuit serializable / deserializable #186

Merged
merged 9 commits into from
Apr 13, 2023
Merged
8 changes: 1 addition & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.64", optional = true }
log = { version = "0.4.17", optional = true }
tabled = { version = "0.9.0", optional = true}
eq-float = "0.1.0"
thiserror = "1.0.38"
hex = "0.4.3"
ethereum_types = { package = "ethereum-types", version = "0.14.1", default-features = false, features = ["std"]}
Expand All @@ -37,6 +36,7 @@ colored = { version = "2.0.0", optional = true}
env_logger = { version = "0.10.0", optional = true}
colored_json = { version = "3.0.1", optional = true}
tokio = { version = "1.26.0", features = ["macros", "rt"] }
bincode = "*"

# python binding related deps
pyo3 = { version = "0.18.2", features = ["extension-module", "abi3-py37"], optional = true }
Expand Down
5 changes: 2 additions & 3 deletions examples/mlp_4d.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use eq_float::F32;
use ezkl_lib::circuit::{BaseConfig as PolyConfig, CheckMode, LookupOp, Op as PolyOp};
use ezkl_lib::fieldutils::i32_to_felt;
use ezkl_lib::tensor::*;
Expand Down Expand Up @@ -67,7 +66,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
&output,
BITS,
&LookupOp::Div {
denom: F32::from(128.),
denom: ezkl_lib::circuit::utils::F32::from(128.),
},
)
.unwrap();
Expand Down Expand Up @@ -152,7 +151,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
&[x.unwrap()],
&mut offset,
LookupOp::Div {
denom: F32::from(128.),
denom: ezkl_lib::circuit::utils::F32::from(128.),
}
.into(),
)
Expand Down
16 changes: 8 additions & 8 deletions src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ impl fmt::Display for BaseOp {

#[allow(missing_docs)]
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Div {
denom: eq_float::F32,
denom: utils::F32,
},
ReLU {
scale: usize,
Expand All @@ -185,11 +185,11 @@ pub enum LookupOp {
},
LeakyReLU {
scale: usize,
slope: eq_float::F32,
slope: utils::F32,
},
PReLU {
scale: usize,
slopes: Vec<eq_float::F32>,
slopes: Vec<utils::F32>,
},
Sigmoid {
scales: (usize, usize),
Expand Down Expand Up @@ -251,7 +251,7 @@ impl LookupOp {

#[allow(missing_docs)]
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum Op {
Dot,
Matmul,
Expand Down Expand Up @@ -504,7 +504,7 @@ impl fmt::Display for Op {
// Eventually, though, we probably want to keep them and treat them directly (layouting and configuring
// at each type of node)
/// Enum of the different kinds of operations `ezkl` can support.
#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd)]
#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd, Deserialize, Serialize)]
pub enum OpKind {
/// A nonlinearity
Lookup(LookupOp),
Expand Down Expand Up @@ -544,13 +544,13 @@ impl OpKind {
}),
"LeakyRelu" => OpKind::Lookup(LookupOp::LeakyReLU {
scale: 1,
slope: eq_float::F32(0.0),
slope: utils::F32(0.0),
}),
"Sigmoid" => OpKind::Lookup(LookupOp::Sigmoid { scales: (1, 1) }),
"Sqrt" => OpKind::Lookup(LookupOp::Sqrt { scales: (1, 1) }),
"Tanh" => OpKind::Lookup(LookupOp::Tanh { scales: (1, 1) }),
"Div" => OpKind::Lookup(LookupOp::Div {
denom: eq_float::F32(1.0),
denom: utils::F32(1.0),
}),
"Const" => OpKind::Const,
"Source" => OpKind::Input,
Expand Down
138 changes: 138 additions & 0 deletions src/circuit/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,141 @@ pub fn value_muxer<F: FieldExt + TensorType>(
_ => unimplemented!(),
}
}

// --------------------------------------------------------------------------------------------
//
// Float Utils to enable the usage of f32s as the keys of HashMaps
// This section is taken from the `eq_float` crate verbatim -- but we also implement deserialization methods
//
//

use std::cmp::Ordering;
use std::fmt;
use std::hash::{Hash, Hasher};

#[derive(Debug, Default, Clone, Copy)]
/// f32 wrapper
pub struct F32(pub f32);

impl<'de> Deserialize<'de> for F32 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let float = f32::deserialize(deserializer)?;
Ok(F32(float))
}
}

impl Serialize for F32 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
f32::serialize(&self.0, serializer)
}
}

/// This works like `PartialEq` on `f32`, except that `NAN == NAN` is true.
impl PartialEq for F32 {
fn eq(&self, other: &Self) -> bool {
if self.0.is_nan() && other.0.is_nan() {
true
} else {
self.0 == other.0
}
}
}

impl Eq for F32 {}

/// This works like `PartialOrd` on `f32`, except that `NAN` sorts below all other floats
/// (and is equal to another NAN). This always returns a `Some`.
impl PartialOrd for F32 {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

/// This works like `PartialOrd` on `f32`, except that `NAN` sorts below all other floats
/// (and is equal to another NAN).
impl Ord for F32 {
fn cmp(&self, other: &Self) -> Ordering {
self.0.partial_cmp(&other.0).unwrap_or_else(|| {
if self.0.is_nan() && !other.0.is_nan() {
Ordering::Less
} else if !self.0.is_nan() && other.0.is_nan() {
Ordering::Greater
} else {
Ordering::Equal
}
})
}
}

impl Hash for F32 {
fn hash<H: Hasher>(&self, state: &mut H) {
if self.0.is_nan() {
0x7fc00000u32.hash(state); // a particular bit representation for NAN
} else if self.0 == 0.0 {
// catches both positive and negative zero
0u32.hash(state);
} else {
self.0.to_bits().hash(state);
}
}
}

impl From<F32> for f32 {
fn from(f: F32) -> Self {
f.0
}
}

impl From<f32> for F32 {
fn from(f: f32) -> Self {
F32(f)
}
}

impl fmt::Display for F32 {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}

#[cfg(test)]
mod tests {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use super::F32;

fn calculate_hash<T: Hash>(t: &T) -> u64 {
let mut s = DefaultHasher::new();
t.hash(&mut s);
s.finish()
}

#[test]
fn f32_eq() {
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
assert!(F32(std::f32::NAN) != F32(5.0));
assert!(F32(5.0) != F32(std::f32::NAN));
assert!(F32(0.0) == F32(-0.0));
}

#[test]
fn f32_cmp() {
assert!(F32(std::f32::NAN) == F32(std::f32::NAN));
assert!(F32(std::f32::NAN) < F32(5.0));
assert!(F32(5.0) > F32(std::f32::NAN));
assert!(F32(0.0) == F32(-0.0));
}

#[test]
fn f32_hash() {
assert!(calculate_hash(&F32(0.0)) == calculate_hash(&F32(-0.0)));
assert!(calculate_hash(&F32(std::f32::NAN)) == calculate_hash(&F32(-std::f32::NAN)));
}
}
25 changes: 14 additions & 11 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,8 @@ use crate::pfsys::evm::aggregation::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::evm::{aggregation::gen_aggregation_evm_verifier, single::gen_evm_verifier};
#[cfg(not(target_arch = "wasm32"))]
use crate::pfsys::evm::{evm_verify, DeploymentCode};
#[cfg(feature = "render")]
use crate::pfsys::prepare_model_circuit;
use crate::pfsys::{create_keys, load_params, load_vk, save_params, Snark};
use crate::pfsys::{
create_proof_circuit, gen_srs, prepare_data, prepare_model_circuit_and_public_input, save_vk,
verify_proof_circuit,
};
use crate::pfsys::{create_proof_circuit, gen_srs, prepare_data, save_vk, verify_proof_circuit};
#[cfg(not(target_arch = "wasm32"))]
use ethers::providers::Middleware;
use halo2_proofs::dev::VerifyFailure;
Expand Down Expand Up @@ -217,7 +212,8 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
ref output,
} => {
let data = prepare_data(data.to_string())?;
let circuit = prepare_model_circuit::<Fr>(&data, &cli.args)?;
let model = Model::from_arg()?;
let circuit = ModelCircuit::<Fr>::new(&data, model)?;
info!("Rendering circuit");

// Create the area we want to draw on.
Expand Down Expand Up @@ -256,8 +252,10 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
}
Commands::Mock { ref data, model: _ } => {
let data = prepare_data(data.to_string())?;
let (circuit, public_inputs) =
prepare_model_circuit_and_public_input::<Fr>(&data, &cli)?;
let model = Model::from_arg()?;
let circuit = ModelCircuit::<Fr>::new(&data, model)?;
let public_inputs = circuit.prepare_public_inputs(&data)?;

info!("Mock proof");

let prover = MockProver::run(cli.args.logrows, &circuit, public_inputs)
Expand All @@ -278,7 +276,9 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
} => {
let data = prepare_data(data.to_string())?;

let (_, public_inputs) = prepare_model_circuit_and_public_input::<Fr>(&data, &cli)?;
let model = Model::from_arg()?;
let circuit = ModelCircuit::<Fr>::new(&data, model)?;
let public_inputs = circuit.prepare_public_inputs(&data)?;
let num_instance = public_inputs.iter().map(|x| x.len()).collect();
let mut params: ParamsKZG<Bn256> =
load_params::<KZGCommitmentScheme<Bn256>>(params_path.to_path_buf())?;
Expand Down Expand Up @@ -333,7 +333,10 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
} => {
let data = prepare_data(data.to_string())?;

let (circuit, public_inputs) = prepare_model_circuit_and_public_input(&data, &cli)?;
let model = Model::from_arg()?;
let circuit = ModelCircuit::<Fr>::new(&data, model)?;
let public_inputs = circuit.prepare_public_inputs(&data)?;

let mut params: ParamsKZG<Bn256> =
load_params::<KZGCommitmentScheme<Bn256>>(params_path.to_path_buf())?;
info!("downsizing params to {} logrows", cli.args.logrows);
Expand Down
Loading