Skip to content

Commit

Permalink
feat: allow for single channel conv kernels (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Apr 24, 2023
1 parent b5e05d3 commit 2e1e756
Show file tree
Hide file tree
Showing 35 changed files with 1,618 additions and 850 deletions.
305 changes: 259 additions & 46 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ env_logger = { version = "0.10.0", optional = true}
colored_json = { version = "3.0.1", optional = true}
tokio = { version = "1.26.0", features = ["macros", "rt"] }
puruspe = "0.2.0"
rayon = "*"
serde_traitobject = "0.2.8"
bincode = "*"


# python binding related deps
pyo3 = { version = "0.18.2", features = ["extension-module", "abi3-py37"], optional = true }
pyo3-log = { version = "0.8.1", optional = true }
Expand Down
2 changes: 1 addition & 1 deletion benches/accum_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl Circuit<Fr> for MyCircuit {
Some(&mut region),
&self.inputs,
&mut 0,
Box::new(PolyOp::Sum),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.unwrap();
Ok(())
Expand Down
18 changes: 1 addition & 17 deletions examples/onnx/1l_div/input.json
Original file line number Diff line number Diff line change
@@ -1,17 +1 @@
{
"input_data": [
[
0.05301234
]
],
"input_shapes": [
[
1
]
],
"output_data": [
[
0.0048828125
]
]
}
{"input_data":[[0.05301234]],"input_shapes":[[1]],"output_data":[[0.005554199]]}
14 changes: 14 additions & 0 deletions examples/onnx/1l_eltwise_div/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from torch import nn
from ezkl import export


class Circuit(nn.Module):
def __init__(self, inplace=False):
super(Circuit, self).__init__()

def forward(self, x):
return x/(2*x)


circuit = Circuit()
export(circuit, input_shape=[1])
1 change: 1 addition & 0 deletions examples/onnx/1l_eltwise_div/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_shapes": [[1]], "input_data": [[0.009540927596390247]], "output_data": [[0.5]]}
Binary file added examples/onnx/1l_eltwise_div/network.onnx
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/onnx/1l_gelu_noappx/input.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"input_data":[[0.61017877,0.21496391,0.8960367]],"input_shapes":[[3]],"output_data":[[0.44274902,0.12817383,0.72998047]]}
{"input_data":[[0.61017877,0.21496391,0.8960367]],"input_shapes":[[3]],"output_data":[[0.44274902,0.12817383,0.73349]]}
9 changes: 4 additions & 5 deletions examples/onnx/1l_instance_norm/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
from ezkl import export
import torch


class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()

self.layer = nn.InstanceNorm2d(3).eval()

def forward(self, x):
return [self.layer(x)]

circuit = MyModel()
export(circuit, input_shape = [3,2,2])


circuit = MyModel()
export(circuit, input_shape=[3, 2, 2])
2 changes: 1 addition & 1 deletion examples/onnx/1l_instance_norm/input.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"input_data":[[0.012968451,0.07738311,0.026791424,0.030387182,0.0008226812,0.0008844912,0.050431013,0.0018652737,0.022439813,0.058267362,0.032021005,0.015621644]],"input_shapes":[[3,2,2]],"output_data":[[-0.70703125,0.0,-0.61865234,-0.53027344,-0.88378906,-0.88378906,-0.35351562,-0.88378906,-0.61865234,-0.26513672,-0.53027344,-0.70703125]]}
{"input_data":[[0.0008132875,0.061658032,0.06964847,0.0367831]],"input_shapes":[[1,2,2]],"output_data":[[-1.5609741,0.7095337,0.99334717,-0.14190674]]}
Binary file modified examples/onnx/1l_instance_norm/network.onnx
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/onnx/1l_var/input.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"input_data":[[0.049622517,0.09307936,0.07327823,0.03837678,0.04325245,0.06161573,0.07465642,0.00059551,0.058175515]],"input_shapes":[[3,3]],"output_data":[[0.0]]}
{"input_data":[[0.05369061,0.053044915,0.04497401,0.05943901,0.09750504,0.056985468,0.08049235,0.064583704,0.07241523]],"input_shapes":[[3,3]],"output_data":[[0.00018494483]]}
Binary file modified examples/onnx/1l_var/network.onnx
Binary file not shown.
1 change: 1 addition & 0 deletions examples/onnx/mobilenet/input.json

Large diffs are not rendered by default.

Binary file added examples/onnx/mobilenet/network.onnx
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/onnx/tutorial/input.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"input_data":[[0.01252346,0.044559505,0.03341365,0.09370565,0.045316238,0.08961995,0.017141312,0.026875867,0.0531207,0.021451747,0.06615644,0.047222037],[0.08351613,0.022880394,0.017713035,0.0059509515,0.017137295,0.081684746,0.0034027814,0.09778482,0.06083488,0.029793901,0.06421044,0.0028769136],[0.018726094,0.02259075,0.033056613,0.06125168,0.08233156,0.08846317,0.04373407,0.07020334,0.022839189,0.0063982666,0.084016055,0.0008106947]],"input_shapes":[[3,2,2],[3,2,2],[3,2,2]],"output_data":[[2.515625,2.515625,2.515625,2.515625,2.515625,2.515625,2.515625,2.5234375,2.5234375,2.515625,2.515625,2.515625,2.5234375,2.53125,2.515625,2.515625,2.515625,2.515625,2.5234375,2.515625,2.515625,2.515625,2.515625,2.515625,2.515625,2.5546875,2.5546875,2.5546875,2.5546875,2.5546875,2.5546875,2.5625,2.5546875,2.5546875,2.5546875,2.5546875,2.5546875,2.5625,2.5546875,2.5546875,2.5546875,2.546875,2.5546875,2.5546875,2.5546875,2.5546875,2.5546875,2.5546875,2.5546875,2.5546875,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125,2.4375,2.4375,2.4453125,2.4453125,2.4453125,2.4453125,2.4375,2.4375,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125,2.4453125],[0.0078125,0.0078125,0.0078125,0.0234375,0.03125,0.03125,0.015625,0.0234375,0.0078125,0.0,0.03125,0.0]]}
{"input_data":[[0.01252346,0.044559505,0.03341365,0.09370565,0.045316238,0.08961995,0.017141312,0.026875867,0.0531207,0.021451747,0.06615644,0.047222037],[0.08351613,0.022880394,0.017713035,0.0059509515,0.017137295,0.081684746,0.0034027814,0.09778482,0.06083488,0.029793901,0.06421044,0.0028769136],[0.018726094,0.02259075,0.033056613,0.06125168,0.08233156,0.08846317,0.04373407,0.07020334,0.022839189,0.0063982666,0.084016055,0.0008106947]],"input_shapes":[[3,2,2],[3,2,2],[3,2,2]],"output_data":[[2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5625,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.5625,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375,2.4375],[0.0,0.0,0.01953125,0.01953125,0.01953125,0.01953125,0.01953125,0.01953125,0.0,0.0,0.01953125,0.0]]}
23 changes: 20 additions & 3 deletions src/bin/ezkl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ pub fn level_color(level: &log::Level, msg: &str) -> String {
.to_string()
}

pub fn level_text_color(level: &log::Level, msg: &str) -> String {
match level {
Level::Error => msg.red(),
Level::Warn => msg.yellow(),
Level::Info => msg.white(),
Level::Debug => msg.white(),
Level::Trace => msg.white(),
}
.bold()
.to_string()
}

fn level_token(level: &Level) -> &str {
match *level {
Level::Error => "E",
Expand All @@ -45,11 +57,12 @@ fn prefix_token(level: &Level) -> String {

pub fn format(buf: &mut Formatter, record: &Record<'_>) -> Result<(), std::fmt::Error> {
let sep = format!("\n{} ", " | ".white().bold());
let level = record.level();
writeln!(
buf,
"{} {}",
prefix_token(&record.level()),
format!("{}", record.args()).replace('\n', &sep),
prefix_token(&level),
format!("{}", level_color(&level, record.args().as_str().unwrap())).replace('\n', &sep),
)
}

Expand All @@ -64,7 +77,11 @@ pub fn init_logger() {
prefix_token(&record.level()),
start.elapsed().as_secs(),
record.metadata().target(),
format!("{}", record.args()).replace('\n', &format!("\n{} ", " | ".white().bold())),
format!(
"{}",
level_text_color(&record.level(), &format!("{}", record.args()))
)
.replace('\n', &format!("\n{} ", " | ".white().bold()))
)
});
builder.target(env_logger::Target::Stdout);
Expand Down
137 changes: 24 additions & 113 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use halo2_proofs::circuit::Region;
use halo2curves::FieldExt;
use itertools::Itertools;

use crate::{
circuit::{layouts, utils},
graph::scale_to_multiplier,
circuit::layouts,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};

Expand All @@ -12,118 +11,52 @@ use super::{lookup::LookupOp, Op};
#[allow(missing_docs)]
/// An enum representing the operations that can be used to express more complex operations via accumulation
#[derive(Clone, Debug)]
pub enum HybridOp<F: FieldExt + TensorType> {
Mean {
scale: usize,
num_inputs: usize,
},
Max,
EltWiseMax {
a: Option<ValTensor<F>>,
pub enum HybridOp {
Max {
axes: Vec<usize>,
},
MaxPool2d {
padding: (usize, usize),
stride: (usize, usize),
pool_dims: (usize, usize),
},
Min,
PReLU {
scale: usize,
slopes: Vec<crate::circuit::utils::F32>,
},
Greater {
a: Option<ValTensor<F>>,
Min {
axes: Vec<usize>,
},
}

impl<F: FieldExt + TensorType> Op<F> for HybridOp<F> {
impl<F: FieldExt + TensorType> Op<F> for HybridOp {
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<i128>]) -> Result<Tensor<i128>, TensorError> {
match &self {
HybridOp::Mean { scale, .. } => {
Ok(tensor::ops::nonlinearities::mean(&inputs[0], *scale))
}
HybridOp::Greater { .. } => Ok(inputs[0]
.iter()
.zip(inputs[1].iter())
.map(|(a, b)| if a > b { 1 } else { 0 })
.collect_vec()
.into_iter()
.into()),
HybridOp::Max => Ok(Tensor::new(
Some(&[inputs[0].clone().into_iter().max().unwrap()]),
&[1],
)?),

HybridOp::EltWiseMax { .. } => Ok(Tensor::new(
Some(&[inputs[0].clone().into_iter().max().unwrap()]),
&inputs[0].dims(),
)?),
HybridOp::Max { axes, .. } => Ok(tensor::ops::max_axes(&inputs[0], axes)?),

HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => tensor::ops::max_pool2d(&inputs[0], padding, stride, pool_dims),
HybridOp::Min => Ok(Tensor::new(
Some(&[inputs[0].clone().into_iter().min().unwrap()]),
&[1],
)?),
HybridOp::PReLU { scale, slopes } => Ok(tensor::ops::nonlinearities::prelu(
&inputs[0],
*scale,
&slopes.iter().map(|e| e.0).collect_vec(),
)),
HybridOp::Min { axes, .. } => Ok(tensor::ops::min_axes(&inputs[0], axes)?),
}
}

fn as_str(&self) -> &'static str {
match &self {
HybridOp::EltWiseMax { .. } => "ELTWISEMAX",
HybridOp::Mean { .. } => "MEAN",
HybridOp::Max => "MAX",
HybridOp::Greater { .. } => "GREATER",
HybridOp::Max { .. } => "MAX",
HybridOp::MaxPool2d { .. } => "MAXPOOL2D",
HybridOp::Min => "MIN",
HybridOp::PReLU { .. } => "PRELU",
HybridOp::Min { .. } => "MIN",
}
}

fn layout(
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: Option<&mut halo2_proofs::circuit::Region<F>>,
region: Option<&mut Region<F>>,
values: &[ValTensor<F>],
offset: &mut usize,
) -> Result<Option<ValTensor<F>>, Box<dyn std::error::Error>> {
let mut values = values.to_vec();
Ok(match self {
HybridOp::PReLU { scale, .. } => Some(layouts::prelu(
config,
region,
values[..].try_into()?,
*scale,
offset,
)?),
HybridOp::EltWiseMax { a } => {
if let Some(a) = a {
values.push(a.clone());
}
todo!("EltWiseMax")
}
HybridOp::Greater { a } => {
if let Some(a) = a {
values.push(a.clone());
}
todo!()
}
HybridOp::Mean { scale, .. } => Some(layouts::mean(
config,
region,
values[..].try_into()?,
*scale,
offset,
)?),
HybridOp::MaxPool2d {
padding,
stride,
Expand All @@ -137,16 +70,18 @@ impl<F: FieldExt + TensorType> Op<F> for HybridOp<F> {
*pool_dims,
offset,
)?),
HybridOp::Max => Some(layouts::max(
HybridOp::Max { axes } => Some(layouts::max_axes(
config,
region,
values[..].try_into()?,
axes,
offset,
)?),
HybridOp::Min => Some(layouts::min(
HybridOp::Min { axes } => Some(layouts::min_axes(
config,
region,
values[..].try_into()?,
axes,
offset,
)?),
})
Expand All @@ -156,39 +91,15 @@ impl<F: FieldExt + TensorType> Op<F> for HybridOp<F> {
in_scales[0]
}

fn has_3d_input(&self) -> bool {
matches!(self, HybridOp::MaxPool2d { .. })
}

fn rescale(&self, inputs_scale: Vec<u32>, global_scale: u32) -> Box<dyn Op<F>> {
let mult = scale_to_multiplier(inputs_scale[0] - global_scale);
match self {
HybridOp::PReLU { scale: _, slopes } => Box::new(HybridOp::PReLU {
scale: mult as usize,
slopes: slopes.to_vec(),
}),
HybridOp::Mean {
scale: _,
num_inputs,
} => Box::new(HybridOp::Mean {
scale: mult as usize,
num_inputs: *num_inputs,
}),
_ => Box::new(self.clone()),
}
fn rescale(&self, _: Vec<u32>, _: u32) -> Box<dyn Op<F>> {
Box::new(self.clone())
}

fn required_lookup(&self) -> Option<LookupOp> {
fn required_lookups(&self) -> Vec<LookupOp> {
match self {
HybridOp::PReLU { scale, .. } => Some(LookupOp::ReLU { scale: *scale }),
HybridOp::Max
| HybridOp::Min
| HybridOp::MaxPool2d { .. }
| HybridOp::Greater { .. }
| HybridOp::EltWiseMax { .. } => Some(LookupOp::ReLU { scale: 1 }),
HybridOp::Mean { scale, num_inputs } => Some(LookupOp::Div {
denom: utils::F32((*scale * *num_inputs) as f32),
}),
HybridOp::Max { .. } | HybridOp::Min { .. } | HybridOp::MaxPool2d { .. } => {
Op::<F>::required_lookups(&LookupOp::ReLU { scale: 1 })
}
}
}

Expand Down
Loading

0 comments on commit 2e1e756

Please sign in to comment.