From d0b7ae26f0121835670ea07e70cad2a8adc7f4bd Mon Sep 17 00:00:00 2001
From: "dcbuilder.eth"
Date: Thu, 29 Dec 2022 15:02:57 +0100
Subject: [PATCH] add: leaky_relu non-linearity (#72)

Co-authored-by: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
---
 .gitignore                    |   1 +
 Cargo.lock                    |   7 +++
 Cargo.toml                    |   2 +-
 benches/relu.rs               |   8 +--
 examples/conv2d_mnist/main.rs |  10 +--
 examples/mlp_4d.rs            |  12 ++--
 examples/onnx                 |   2 +-
 src/circuit/eltwise.rs        | 113 +++++++++++++++++++++++-----------
 src/graph/model.rs            |  51 ++++++++++++---
 src/graph/node.rs             |  67 +++++++++++++++-----
 src/graph/vars.rs             |  20 ++++--
 tests/integration_tests.rs    |  10 +--
 12 files changed, 220 insertions(+), 83 deletions(-)

diff --git a/.gitignore b/.gitignore
index f578af220..66874259c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ data
 *.params
 *~
 \#*\#
+.DS_Store
diff --git a/Cargo.lock b/Cargo.lock
index dd0c21132..ec00ba5ad 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1464,6 +1464,12 @@ dependencies = [
  "termcolor",
 ]

+[[package]]
+name = "eq-float"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3c02b5d1d1e6ba431b960d4bf971c8b0e2d2942b8cbc577d9bdf9c60fca5d41d"
+
 [[package]]
 name = "errno"
 version = "0.2.8"
@@ -1815,6 +1821,7 @@ dependencies = [
  "criterion",
  "ctor",
  "ecc",
+ "eq-float",
  "ethereum-types",
  "foundry-evm",
  "halo2_proofs",
diff --git a/Cargo.toml b/Cargo.toml
index 8aa4cb9c8..1a92811f0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,7 +26,7 @@ foundry_evm = { git = "https://github.com/foundry-rs/foundry", package = "foundr
 halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", package = "ecc", branch = "master", optional=true}
 plonk_verifier = { git = "https://github.com/zkonduit/plonk-verifier", branch = "main"}
 colog = { version = "1.1.0", optional = true }
-
+eq-float = "0.1.0"

 [dev-dependencies]
 criterion = {version = "0.3", features = ["html_reports"]}
diff --git a/benches/relu.rs b/benches/relu.rs
index 432ea20f9..d7b2251ca 100644
--- a/benches/relu.rs
+++ b/benches/relu.rs
@@ -1,5 +1,5 @@
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
-use ezkl::circuit::eltwise::{EltwiseConfig, Nonlin1d, Nonlinearity, ReLu};
+use ezkl::circuit::eltwise::{EltwiseConfig, Nonlin1d, Nonlinearity, ReLU};
 use ezkl::tensor::*;
 use halo2_proofs::dev::MockProver;
 use halo2_proofs::{
@@ -37,7 +37,7 @@ impl<F: FieldExt + TensorType, NL: 'static + Nonlinearity<F> + Clone> Circuit<F>
             .map(|_| VarTensor::new_advice(cs, K, LEN, vec![LEN], true, 512))
             .collect::<Vec<_>>();

-        Self::Config::configure(cs, &advices[0], &advices[1], Some(&[BITS, 128]))
+        Self::Config::configure(cs, &advices[0], &advices[1], BITS, &[128], None)
     }
 }

@@ -66,13 +66,13 @@ fn runrelu(c: &mut Criterion) {
     let input: Tensor<Value<F>> =
         Tensor::<i32>::from((0..len).map(|_| rng.gen_range(0..10))).into();

-    let assigned: Nonlin1d<F, ReLu<F>> = Nonlin1d {
+    let assigned: Nonlin1d<F, ReLU<F>> = Nonlin1d {
         input: ValTensor::from(input.clone()),
         output: ValTensor::from(input),
         _marker: PhantomData,
     };

-    let circuit = NLCircuit::<F, ReLu<F>> {
+    let circuit = NLCircuit::<F, ReLU<F>> {
         assigned,
         _marker: PhantomData,
     };
diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs
index f034a95fd..53b88e645 100644
--- a/examples/conv2d_mnist/main.rs
+++ b/examples/conv2d_mnist/main.rs
@@ -1,4 +1,4 @@
-use ezkl::circuit::eltwise::{EltwiseConfig, ReLu};
+use ezkl::circuit::eltwise::{EltwiseConfig, ReLU};
 use ezkl::circuit::fused::*;
 use ezkl::fieldutils;
 use ezkl::fieldutils::i32_to_felt;
@@ -54,7 +54,7 @@ struct Config<
 {
     // this will be a conv layer
     l0: FusedConfig<F>,
-    l1: EltwiseConfig<F, ReLu<F>>,
+    l1: EltwiseConfig<F, ReLU<F>>,
     // this will be an affine layer
     l2: FusedConfig<F>,
     public_output: Column<Instance>,
@@ -200,8 +200,8 @@ where
         let input = input.reshape(&[LEN]);
         let output = output.reshape(&[LEN]);

-        let l1: EltwiseConfig<F, ReLu<F>> =
-            EltwiseConfig::configure(cs, &input, &output, Some(&[BITS, 32]));
+        let l1: EltwiseConfig<F, ReLU<F>> =
+            EltwiseConfig::configure(cs, &input, &output, BITS, &[32], None);

         // tells the config layer to add an affine op to the circuit gate
         let affine_node = FusedNode {
@@ -384,7 +384,7 @@ pub fn runconv() {
     let root = BitMapBackend::new("conv2dmnist-layout.png", (2048, 7680)).into_drawing_area();
     root.fill(&WHITE).unwrap();
     let root = root
-        .titled("Conv -> ReLu -> Affine -> Relu", ("sans-serif", 60))
+        .titled("Conv -> ReLU -> Affine -> ReLU", ("sans-serif", 60))
         .unwrap();

     halo2_proofs::dev::CircuitLayout::default()
diff --git a/examples/mlp_4d.rs b/examples/mlp_4d.rs
index 7fde20598..f41366d06 100644
--- a/examples/mlp_4d.rs
+++ b/examples/mlp_4d.rs
@@ -1,4 +1,4 @@
-use ezkl::circuit::eltwise::{DivideBy, EltwiseConfig, ReLu};
+use ezkl::circuit::eltwise::{DivideBy, EltwiseConfig, ReLU};
 use ezkl::circuit::fused::*;
 use ezkl::fieldutils::i32_to_felt;
 use ezkl::tensor::*;
@@ -16,9 +16,9 @@ const K: usize = 15;
 #[derive(Clone)]
 struct MyConfig<F: FieldExt + TensorType> {
     l0: FusedConfig<F>,
-    l1: EltwiseConfig<F, ReLu<F>>,
+    l1: EltwiseConfig<F, ReLU<F>>,
     l2: FusedConfig<F>,
-    l3: EltwiseConfig<F, ReLu<F>>,
+    l3: EltwiseConfig<F, ReLU<F>>,
     l4: EltwiseConfig<F, DivideBy<F>>,
     public_output: Column<Instance>,
 }
@@ -79,12 +79,12 @@ impl<F: FieldExt + TensorType> Circuit<F>
         );

         // sets up a new ReLU table and resuses it for l1 and l3 non linearities
-        let [l1, l3]: [EltwiseConfig<F, ReLu<F>>; 2] =
-            EltwiseConfig::configure_multiple(cs, &input, &output, Some(&[BITS, 1]));
+        let [l1, l3]: [EltwiseConfig<F, ReLU<F>>; 2] =
+            EltwiseConfig::configure_multiple(cs, &input, &output, BITS, &[1], None);

         // sets up a new Divide by table
         let l4: EltwiseConfig<F, DivideBy<F>> =
-            EltwiseConfig::configure(cs, &input, &output, Some(&[BITS, 128]));
+            EltwiseConfig::configure(cs, &input, &output, BITS, &[128], None);

         let public_output: Column<Instance> = cs.instance_column();
         cs.enable_equality(public_output);
diff --git a/examples/onnx b/examples/onnx
index 051260d89..5209851de 160000
--- a/examples/onnx
+++ b/examples/onnx
@@ -1 +1 @@
-Subproject commit 051260d898a1e340c9da76749494c2a96019c001
+Subproject commit 5209851defdf1df6399efbe4f7ab30a5d5cbdb1d
diff --git a/src/circuit/eltwise.rs b/src/circuit/eltwise.rs
index ac5a110bf..4317d9730 100644
--- a/src/circuit/eltwise.rs
+++ b/src/circuit/eltwise.rs
@@ -17,10 +17,10 @@ pub trait Nonlinearity<F: FieldExt> {
     ///
     /// * `x` - input to function
     /// * `scales` - additional parameters that may parametrize the function
-    fn nonlinearity(x: i32, scales: &[usize]) -> F;
+    fn nonlinearity(x: i32, scales: &[usize], param: Option<f32>) -> F;
     /// a value which is always in the table
-    fn default_pair(scales: &[usize]) -> (F, F) {
-        (F::zero(), Self::nonlinearity(0, scales))
+    fn default_pair(scales: &[usize], param: Option<f32>) -> (F, F) {
+        (F::zero(), Self::nonlinearity(0, scales, param))
     }
 }

@@ -45,8 +45,10 @@ pub struct EltwiseTable<F: FieldExt, NL: Nonlinearity<F>> {
     pub table_output: TableColumn,
     /// Flags if table has been previously assigned to.
     pub is_assigned: bool,
-    /// Number of bits used in lookup table.
+    /// Scaling of the table's inputs.
     pub scaling_params: Vec<usize>,
+    /// Parameters related to the eltwise function being represented
+    pub eltwise_params: Option<f32>,
     /// Number of bits used in lookup table.
     pub bits: usize,
     _marker: PhantomData<(F, NL)>,
@@ -58,12 +60,14 @@ impl<F: FieldExt, NL: Nonlinearity<F>> EltwiseTable<F, NL> {
         cs: &mut ConstraintSystem<F>,
         bits: usize,
         scaling_params: &[usize],
+        eltwise_params: Option<f32>,
     ) -> EltwiseTable<F, NL> {
         EltwiseTable {
             table_input: cs.lookup_table_column(),
             table_output: cs.lookup_table_column(),
             is_assigned: false,
             scaling_params: scaling_params.to_vec(),
+            eltwise_params,
             bits,
             _marker: PhantomData,
         }
@@ -94,7 +98,13 @@ impl<F: FieldExt, NL: Nonlinearity<F>> EltwiseTable<F, NL> {
                     || format!("nl_o_col row {}", row_offset),
                     self.table_output,
                     row_offset,
-                    || Value::known(NL::nonlinearity(int_input, &self.scaling_params)),
+                    || {
+                        Value::known(NL::nonlinearity(
+                            int_input,
+                            &self.scaling_params,
+                            self.eltwise_params,
+                        ))
+                    },
                 ) {
                     Ok(a) => a,
                     Err(e) => {
@@ -133,13 +143,17 @@ impl<F: FieldExt + TensorType, NL: 'static + Nonlinearity<F>> EltwiseConfig<F, NL> {
         cs: &mut ConstraintSystem<F>,
         input: &VarTensor,
         output: &VarTensor,
-        eltwise_params: Option<&[usize]>,
+        bits: usize,
+        scaling_params: &[usize],
+        eltwise_params: Option<f32>,
     ) -> [Self; NUM] {
         let mut table: Option<Rc<RefCell<EltwiseTable<F, NL>>>> = None;
         let configs = (0..NUM)
             .map(|_| {
                 let l = match &table {
-                    None => Self::configure(cs, input, output, eltwise_params),
+                    None => {
+                        Self::configure(cs, input, output, bits, scaling_params, eltwise_params)
+                    }
                     Some(t) => Self::configure_with_table(cs, input, output, t.clone()),
                 };
                 table = Some(l.table.clone());
@@ -168,8 +182,10 @@ impl<F: FieldExt + TensorType, NL: 'static + Nonlinearity<F>> EltwiseConfig<F, NL> {
         cs: &mut ConstraintSystem<F>,
         input: &VarTensor,
         output: &VarTensor,
-        eltwise_params: Option<&[usize]>,
+        bits: usize,
+        scaling_params: &[usize],
+        eltwise_params: Option<f32>,
     ) -> Self {
-        // will fail if not supplied
-        let params = match eltwise_params {
-            Some(p) => p,
-            None => {
-                panic!("failed to supply eltwise parameters")
-            }
-        };
-        let bits = params[0];
         let table = Rc::new(RefCell::new(EltwiseTable::<F, NL>::configure(
             cs,
             bits,
-            &params[1..],
+            scaling_params,
+            eltwise_params,
         )));
         Self::configure_with_table(cs, input, output, table)
     }
@@ -259,6 +270,7 @@ impl<F: FieldExt + TensorType, NL: 'static + Nonlinearity<F>> EltwiseConfig<F, NL> {
                             <NL as Nonlinearity<F>>::nonlinearity(
                                 felt_to_i32(f.evaluate()),
                                 &self.table.borrow().scaling_params,
+                                self.table.borrow().eltwise_params,
                             )
                         })
                     }));
@@ -283,11 +295,11 @@ impl<F: FieldExt + TensorType, NL: 'static + Nonlinearity<F>> EltwiseConfig<F, NL> {
-pub struct ReLu<F> {
+pub struct ReLU<F> {
     _marker: PhantomData<F>,
 }
-impl<F: FieldExt> Nonlinearity<F> for ReLu<F> {
-    fn nonlinearity(x: i32, scale: &[usize]) -> F {
+impl<F: FieldExt> Nonlinearity<F> for ReLU<F> {
+    fn nonlinearity(x: i32, scale: &[usize], _: Option<f32>) -> F {
         if x < 0 {
             F::zero()
         } else {
@@ -298,6 +310,24 @@ impl<F: FieldExt> Nonlinearity<F> for ReLU<F> {
     }
 }

+#[allow(missing_docs)]
+#[derive(Clone, Debug)]
+pub struct LeakyReLU<F> {
+    _marker: PhantomData<F>,
+}
+
+impl<F: FieldExt> Nonlinearity<F> for LeakyReLU<F> {
+    fn nonlinearity(x: i32, scale: &[usize], slope: Option<f32>) -> F {
+        if x < 0 {
+            let d_inv_x = slope.unwrap() * (x as f32) / (scale[0] as f32);
+            let rounded = d_inv_x.round();
+            fieldutils::i32_to_felt(rounded as i32)
+        } else {
+            fieldutils::i32_to_felt(x)
+        }
+    }
+}
+
 #[allow(missing_docs)]
 #[derive(Clone, Debug)]
 pub struct Sigmoid<F> {
@@ -306,7 +336,7 @@ pub struct Sigmoid<F> {
 // L is our implicit or explicit denominator (fixed point d)
 // Usually want K=L
 impl<F: FieldExt> Nonlinearity<F> for Sigmoid<F> {
-    fn nonlinearity(x: i32, scale: &[usize]) -> F {
+    fn nonlinearity(x: i32, scale: &[usize], _: Option<f32>) -> F {
         let kix = (x as f32) / (scale[0] as f32);
         let fout = (scale[1] as f32) / (1.0 + (-kix).exp());
         let rounded = fout.round();
@@ -320,7 +350,7 @@ pub struct DivideBy<F> {
     _marker: PhantomData<F>,
 }
 impl<F: FieldExt> Nonlinearity<F> for DivideBy<F> {
-    fn nonlinearity(x: i32, scale: &[usize]) -> F {
+    fn nonlinearity(x: i32, scale: &[usize], _: Option<f32>) -> F {
         let d_inv_x = (x as f32) / (scale[0] as f32);
         let rounded = d_inv_x.round();
         fieldutils::i32_to_felt(rounded as i32)
@@ -359,7 +389,7 @@ mod tests {
             .map(|_| VarTensor::new_advice(cs, 4, 1, vec![1], true, 512))
             .collect::<Vec<_>>();

-        Self::Config::configure(cs, &advices[0], &advices[1], Some(&[2, 1]))
+        Self::Config::configure(cs, &advices[0], &advices[1], 2, &[1], None)
     }

     fn synthesize(
@@ -376,11 +406,24 @@ mod tests {
     #[test]
     fn test_eltrelunl() {
         for i in -127..127 {
-            let r = <ReLu<F> as Nonlinearity<F>>::nonlinearity(i, &[1]);
+            let r = <ReLU<F> as Nonlinearity<F>>::nonlinearity(i, &[1], None);
+            if i <= 0 {
+                assert_eq!(r, F::from(0_u64))
+            } else {
+                assert_eq!(r, F::from(i as u64))
+            }
+        }
+    }
+
+    #[test]
+    fn test_eltleakyrelunl() {
+        for i in -127..127 {
+            let r = <LeakyReLU<F> as Nonlinearity<F>>::nonlinearity(i, &[1], Some(0.05));
             if i <= 0 {
-                assert!(r == F::from(0_u64))
+                println!("{:?}", (0.05 * i as f32));
+                assert_eq!(r, -F::from(-(0.05 * i as f32).round() as u64))
             } else {
-                assert!(r == F::from(i as u64))
+                assert_eq!(r, F::from(i as u64))
             }
         }
     }
@@ -388,21 +431,21 @@ mod tests {
     #[test]
     fn test_eltsigmoid() {
         for i in -127..127 {
-            let r = <Sigmoid<F> as Nonlinearity<F>>::nonlinearity(i, &[1, 1]);
+            let r = <Sigmoid<F> as Nonlinearity<F>>::nonlinearity(i, &[1, 1], None);
             let exp_sig = (1.0 / (1.0 + (-i as f32).exp())).round();
-            assert!(r == F::from(exp_sig as u64))
+            assert_eq!(r, F::from(exp_sig as u64))
         }
     }

     #[test]
     fn test_eltdivide() {
         for i in -127..127 {
-            let r = <DivideBy<F> as Nonlinearity<F>>::nonlinearity(i, &[1]);
+            let r = <DivideBy<F> as Nonlinearity<F>>::nonlinearity(i, &[1], None);
             println!("{:?}, {:?}, {:?}", i, r, F::from(-i as u64));
             if i <= 0 {
-                assert!(r == -F::from(-i as u64))
+                assert_eq!(r, -F::from(-i as u64))
             } else {
-                assert!(r == F::from(i as u64))
+                assert_eq!(r, F::from(i as u64))
             }
         }
     }
@@ -411,13 +454,13 @@ mod tests {
     fn relucircuit() {
         let input: Tensor<Value<F>> =
             Tensor::new(Some(&[Value::<F>::known(F::from(1_u64))]), &[1]).unwrap();

-        let assigned: Nonlin1d<F, ReLu<F>> = Nonlin1d {
+        let assigned: Nonlin1d<F, ReLU<F>> = Nonlin1d {
             input: ValTensor::from(input.clone()),
             output: ValTensor::from(input),
             _marker: PhantomData,
         };

-        let circuit = NLCircuit::<F, ReLu<F>> {
+        let circuit = NLCircuit::<F, ReLU<F>> {
             assigned,
             _marker: PhantomData,
         };
diff --git a/src/graph/model.rs b/src/graph/model.rs
index dd303b6d7..32dff0394 100644
--- a/src/graph/model.rs
+++ b/src/graph/model.rs
@@ -1,7 +1,7 @@
 use super::node::*;
 use super::utilities::scale_to_multiplier;
 use super::vars::*;
-use crate::circuit::eltwise::{DivideBy, EltwiseConfig, ReLu, Sigmoid};
+use crate::circuit::eltwise::{DivideBy, EltwiseConfig, LeakyReLU, ReLU, Sigmoid};
 use crate::circuit::fused::*;
 use crate::circuit::range::*;
 use crate::commands::{Cli, Commands};
@@ -39,7 +39,7 @@ pub enum Mode {
     Verify,
 }

-/// A circuit configuration for a model loaded from an Onnx file.
+/// A circuit configuration for the entirety of a model loaded from an Onnx file.
 #[derive(Clone)]
 pub struct ModelConfig<F: FieldExt + TensorType> {
     configs: BTreeMap<usize, NodeConfig<F>>,
@@ -391,7 +391,7 @@ impl Model {
                     NodeConfig::Divide(conf, node_inputs)
                 } else {
                     let conf: EltwiseConfig<F, DivideBy<F>> =
-                        EltwiseConfig::configure(meta, input, output, Some(&[self.bits, *s]));
+                        EltwiseConfig::configure(meta, input, output, self.bits, &[*s], None);
                     tables.insert(
                         node.opkind.clone(),
                         TableTypes::DivideBy(conf.table.clone()),
@@ -401,17 +401,43 @@ impl Model {
             }
             OpKind::ReLU(s) => {
                 if tables.contains_key(&node.opkind) {
-                    let table = tables.get(&node.opkind).unwrap();
-                    let conf: EltwiseConfig<F, ReLu<F>> =
+                    let table = tables.get(&node.opkind).unwrap().clone();
+                    let conf: EltwiseConfig<F, ReLU<F>> =
                         EltwiseConfig::configure_with_table(meta, input, output, table.get_relu());
                     NodeConfig::ReLU(conf, node_inputs)
                 } else {
-                    let conf: EltwiseConfig<F, ReLu<F>> =
-                        EltwiseConfig::configure(meta, input, output, Some(&[self.bits, *s]));
-                    tables.insert(node.opkind.clone(), TableTypes::ReLu(conf.table.clone()));
+                    let conf: EltwiseConfig<F, ReLU<F>> =
+                        EltwiseConfig::configure(meta, input, output, self.bits, &[*s], None);
+                    tables.insert(node.opkind.clone(), TableTypes::ReLU(conf.table.clone()));
                     NodeConfig::ReLU(conf, node_inputs)
                 }
             }
+            OpKind::LeakyReLU((scale, slope)) => {
+                if tables.contains_key(&node.opkind) {
+                    let table = tables.get(&node.opkind).unwrap().clone();
+                    let conf: EltwiseConfig<F, LeakyReLU<F>> = EltwiseConfig::configure_with_table(
+                        meta,
+                        input,
+                        output,
+                        table.get_leakyrelu(),
+                    );
+                    NodeConfig::LeakyReLU(conf, node_inputs)
+                } else {
+                    let conf: EltwiseConfig<F, LeakyReLU<F>> = EltwiseConfig::configure(
+                        meta,
+                        input,
+                        output,
+                        self.bits,
+                        &[*scale],
+                        Some(slope.0),
+                    );
+                    tables.insert(
+                        node.opkind.clone(),
+                        TableTypes::LeakyReLU(conf.table.clone()),
+                    );
+                    NodeConfig::LeakyReLU(conf, node_inputs)
+                }
+            }
             OpKind::Sigmoid(s) => {
                 if tables.contains_key(&node.opkind) {
                     let table = tables.get(&node.opkind).unwrap();
@@ -423,7 +449,9 @@ impl Model {
                         meta,
                         input,
                         output,
-                        Some(&[self.bits, *s, scale_to_multiplier(self.scale) as usize]),
+                        self.bits,
+                        &[self.bits, *s, scale_to_multiplier(self.scale) as usize],
+                        None,
                     );
                     tables.insert(node.opkind.clone(), TableTypes::Sigmoid(conf.table.clone()));
                     NodeConfig::Sigmoid(conf, node_inputs)
@@ -548,6 +576,11 @@ impl Model {
                 Some(ac.layout(layouter, &values))
             }
+            NodeConfig::LeakyReLU(rc, idx) => {
+                assert_eq!(idx.len(), 1);
+                // For activations and elementwise operations, the dimensions are sometimes only in one or the other of input and output.
+                Some(rc.layout(layouter, inputs.get(&idx[0]).unwrap().clone()))
+            }
             NodeConfig::ReLU(rc, idx) => {
                 assert_eq!(idx.len(), 1);
                 // For activations and elementwise operations, the dimensions are sometimes only in one or the other of input and output.
diff --git a/src/graph/node.rs b/src/graph/node.rs
index 0ebade307..370a6a452 100644
--- a/src/graph/node.rs
+++ b/src/graph/node.rs
@@ -1,13 +1,11 @@
 use super::utilities::{node_output_shapes, scale_to_multiplier, vector_to_quantized};
-use crate::circuit::eltwise::{DivideBy, EltwiseConfig, ReLu, Sigmoid};
-use crate::circuit::fused::*;
-
 use crate::abort;
+use crate::circuit::eltwise::{DivideBy, EltwiseConfig, LeakyReLU, ReLU, Sigmoid};
+use crate::circuit::fused::*;
 use crate::tensor::ops::{add, const_mult, div, mult};
 use crate::tensor::Tensor;
 use crate::tensor::TensorType;
 use anyhow::Result;
-
 use halo2_proofs::arithmetic::FieldExt;
 use itertools::Itertools;
 use log::{error, info, trace, warn};
@@ -19,6 +17,7 @@ use tract_onnx::prelude::{DatumType, InferenceFact, Node as OnnxNode, OutletId};
 use tract_onnx::tract_hir::{
     infer::Factoid,
     internal::InferenceOp,
+    ops::activations::LeakyRelu,
     ops::cnn::{Conv, PoolSpec, SumPool}, //MaxPool,},
     ops::expandable::Expansion,
     ops::nn::DataFormat,
@@ -37,6 +36,8 @@ use tract_onnx::tract_hir::{
 pub enum OpKind {
     /// A ReLU nonlinearity
     ReLU(usize),
+    /// A Leaky ReLU nonlinearity with slope parameter
+    LeakyReLU((usize, eq_float::F32)),
     /// A Sigmoid nonlinearity
     Sigmoid(usize),
     /// A DivideBy nonlinearity
@@ -59,7 +60,8 @@ impl OpKind {
     pub fn new(name: &str) -> Self {
         match name {
             "Clip" => OpKind::ReLU(1),
-            "Prelu" => OpKind::ReLU(1),
+            "Prelu" => OpKind::LeakyReLU((1, eq_float::F32(0.0))),
+            "LeakyRelu" => OpKind::LeakyReLU((1, eq_float::F32(0.0))),
             "Sigmoid" => OpKind::Sigmoid(1),
             "Div" => OpKind::Div(1),
             "Const" => OpKind::Const,
@@ -101,6 +103,7 @@ impl fmt::Display for OpKind {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
             OpKind::ReLU(s) => write!(f, "relu w/ scaling: {}", s),
+            OpKind::LeakyReLU(s) => write!(f, "leaky relu w/ scaling: {} and slope {}", s.0, s.1),
             OpKind::Div(s) => write!(f, "div w/ scaling: {}", s),
             OpKind::Sigmoid(s) => write!(f, "sigmoid w/ scaling: {}", s),
             OpKind::Const => write!(f, "const"),
@@ -116,7 +119,8 @@ impl fmt::Display for OpKind {
 #[allow(missing_docs)]
 #[derive(Clone, Default, Debug)]
 pub enum NodeConfig<F: FieldExt + TensorType> {
-    ReLU(EltwiseConfig<F, ReLu<F>>, Vec<usize>),
+    ReLU(EltwiseConfig<F, ReLU<F>>, Vec<usize>),
+    LeakyReLU(EltwiseConfig<F, LeakyReLU<F>>, Vec<usize>),
     Sigmoid(EltwiseConfig<F, Sigmoid<F>>, Vec<usize>),
     Divide(EltwiseConfig<F, DivideBy<F>>, Vec<usize>),
     Fused(FusedConfig<F>, Vec<usize>),
@@ -180,10 +184,6 @@ impl NodeGraph {
     }
 }

-// /// A circuit configuration for a single self.
-// #[derive(Clone, Default, Debug)]
-// pub struct NodeConfig(pub NodeConfigTypes);
-
 fn display_option<T: fmt::Debug>(o: &Option<T>) -> String {
     match o {
         Some(s) => format!("{:?}", s),
@@ -279,6 +279,7 @@ impl Node {
         idx: usize,
     ) -> Self {
         trace!("Create {:?}", node);
+        trace!("Create op {:?}", node.op);
         let output_shapes = match node_output_shapes(&node) {
             Ok(s) => Some(s),
             _ => None,
@@ -344,6 +345,46 @@ impl Node {
                     ..Default::default()
                 }
             }
+            OpKind::LeakyReLU((mut layer_scale, _)) => {
+                let input_node = &inputs[0];
+
+                // Extract the slope layer hyperparams
+                let op = Box::new(node.op());
+
+                let leaky_op: &LeakyRelu = match op.downcast_ref::<Box<dyn Expansion>>() {
+                    Some(b) => match (*b).as_any().downcast_ref() {
+                        Some(b) => b,
+                        None => {
+                            panic!("not a leaky relu!");
+                        }
+                    },
+                    None => {
+                        panic!("op is not a Tract Expansion!");
+                    }
+                };
+
+                let scale_diff = input_node.out_scale - scale;
+                // We can also consider adjusting the scale of all inputs and the output in a more custom way.
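+                // If the input node arrives at a higher fixed-point scale than the target `scale`,
+                // fold the difference into `layer_scale` (the multiplier handed to the lookup as its
+                // scaling parameter) and shrink `output_max` by the same factor.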
+                let mut output_max = input_node.output_max;
+                if scale_diff > 0 {
+                    layer_scale = scale_to_multiplier(scale_diff) as usize;
+                    output_max = input_node.output_max / (layer_scale as f32);
+                }
+
+                opkind = OpKind::LeakyReLU((layer_scale as usize, eq_float::F32(leaky_op.0))); // now the input will be scaled down to match
+
+                Node {
+                    idx,
+                    opkind,
+                    inputs: node.inputs.clone(),
+                    in_dims: vec![input_node.out_dims.clone()],
+                    out_dims: input_node.out_dims.clone(),
+                    in_scale: input_node.out_scale,
+                    out_scale: scale,
+                    output_max,
+                    ..Default::default()
+                }
+            }
             OpKind::Div(_) => {
                 if inputs[1].out_dims.clone() != [1] {
                     abort!("ezkl currently only supports division by a constant");
@@ -394,13 +435,11 @@ impl Node {
                     Some(b) => match (*b).as_any().downcast_ref() {
                         Some(b) => b,
                         None => {
-                            error!("not a conv!");
-                            panic!()
+                            panic!("not a conv!");
                         }
                     },
                     None => {
-                        error!("op is not a Tract Expansion!");
-                        panic!()
+                        panic!("op is not a Tract Expansion!");
                     }
                 };

diff --git a/src/graph/vars.rs b/src/graph/vars.rs
index 785228816..6a4f7daff 100644
--- a/src/graph/vars.rs
+++ b/src/graph/vars.rs
@@ -1,5 +1,5 @@
 use crate::abort;
-use crate::circuit::eltwise::{DivideBy, EltwiseTable, ReLu, Sigmoid};
+use crate::circuit::eltwise::{DivideBy, EltwiseTable, LeakyReLU, ReLU, Sigmoid};
 use crate::commands::Cli;
 use crate::tensor::TensorType;
 use crate::tensor::{ValTensor, VarTensor};
@@ -88,7 +88,9 @@ impl VarVisibility {
 /// Lookup tables that will be available for reuse.
 pub enum TableTypes<F: FieldExt + TensorType> {
     /// Reference to a ReLU table
-    ReLu(Rc<RefCell<EltwiseTable<F, ReLu<F>>>>),
+    ReLU(Rc<RefCell<EltwiseTable<F, ReLU<F>>>>),
+    /// Reference to a leaky ReLU table
+    LeakyReLU(Rc<RefCell<EltwiseTable<F, LeakyReLU<F>>>>),
     /// Reference to a DivideBy table
     DivideBy(Rc<RefCell<EltwiseTable<F, DivideBy<F>>>>),
     /// Reference to a Sigmoid table
@@ -96,9 +98,9 @@ pub enum TableTypes<F: FieldExt + TensorType> {
 }
 impl<F: FieldExt + TensorType> TableTypes<F> {
     /// Get a reference to a reused ReLU lookup table
-    pub fn get_relu(&self) -> Rc<RefCell<EltwiseTable<F, ReLu<F>>>> {
+    pub fn get_relu(&self) -> Rc<RefCell<EltwiseTable<F, ReLU<F>>>> {
         match self {
-            TableTypes::ReLu(inner) => inner.clone(),
+            TableTypes::ReLU(inner) => inner.clone(),
             _ => {
                 abort!("fetching wrong table type");
             }
@@ -122,6 +124,16 @@ impl<F: FieldExt + TensorType> TableTypes<F> {
             }
         }
     }
+
+    /// Get a reference to a reused leaky ReLU lookup table
+    pub fn get_leakyrelu(&self) -> Rc<RefCell<EltwiseTable<F, LeakyReLU<F>>>> {
+        match self {
+            TableTypes::LeakyReLU(inner) => inner.clone(),
+            _ => {
+                abort!("fetching wrong table type");
+            }
+        }
+    }
 }

 /// A wrapper for holding all columns that will be assigned to by a model.
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index f0fa09e91..26c1e919e 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -13,12 +13,13 @@ fn init() {
     build_ezkl();
 }

-const TESTS: [&str; 11] = [
+const TESTS: [&str; 12] = [
     "1l_mlp",
     "1l_flatten",
     "1l_average",
     "1l_reshape",
     "1l_sigmoid",
+    "1l_leakyrelu",
     "1l_relu",
     "2l_relu_sigmoid_small",
     "2l_relu_small",
@@ -27,12 +28,13 @@ const TESTS: [&str; 11] = [
     "2l_relu_sigmoid_conv",
 ];

-const TESTS_EVM: [&str; 8] = [
+const TESTS_EVM: [&str; 9] = [
     "1l_mlp",
     "1l_flatten",
     "1l_average",
     "1l_reshape",
     "1l_sigmoid",
+    "1l_leakyrelu",
     "1l_relu",
     "2l_relu_sigmoid_small",
     "2l_relu_small",
@@ -52,7 +54,7 @@ macro_rules! test_func {
             use crate::ipa_prove_and_verify;
             use crate::kzg_fullprove;
             use crate::kzg_prove_and_verify;
-            seq!(N in 0..=10 {
+            seq!(N in 0..=11 {
             #(#[test_case(TESTS[N])])*
             fn mock_public_outputs_(test: &str) {
                 mock(test.to_string());
@@ -100,7 +102,7 @@ macro_rules! test_func_evm {
             use crate::TESTS_EVM;
             use test_case::test_case;
             use crate::kzg_evm_fullprove;
-            seq!(N in 0..=7 {
+            seq!(N in 0..=8 {
             // these take a particularly long time to run
             #(#[test_case(TESTS_EVM[N])])*
             fn kzg_evm_fullprove_(test: &str) {