Skip to content

Commit

Permalink
Implement ONNX pad
Browse files Browse the repository at this point in the history
  • Loading branch information
JC committed Jul 19, 2024
1 parent 35345de commit 2859976
Show file tree
Hide file tree
Showing 11 changed files with 369 additions and 4 deletions.
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ fn main() {
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/pad/pad.onnx")
.input("tests/expand/expand.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
Expand Down
21 changes: 21 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ include_models!(
mul,
neg,
not,
pad,
greater,
greater_or_equal,
less,
Expand Down Expand Up @@ -1406,6 +1407,26 @@ mod tests {
output.assert_eq(&expected, true);
}

#[test]
fn pad() {
let device = Default::default();
let model: pad::Model<Backend> = pad::Model::new(&device);

let input = Tensor::<Backend, 2>::from_floats([[1., 2.], [3., 4.], [5., 6.]], &device);
let output = model.forward(input).to_data();
let expected = TensorData::from([
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 1., 2., 0., 0., 0., 0.],
[0.0_f32, 0., 3., 4., 0., 0., 0., 0.],
[0.0_f32, 0., 5., 6., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
[0.0_f32, 0., 0., 0., 0., 0., 0., 0.],
]);

output.assert_eq(&expected, true);
}

#[test]
fn greater() {
let device = Default::default();
Expand Down
Binary file added crates/burn-import/onnx-tests/tests/pad/pad.onnx
Binary file not shown.
158 changes: 158 additions & 0 deletions crates/burn-import/onnx-tests/tests/pad/pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/pad/pad.onnx

### Helper Functions ###
from pathlib import Path
from typing import Any
import numpy
from numpy.core.multiarray import dtype
import onnx
from onnx import ModelProto, TensorProto, ValueInfoProto
from onnx.reference import ReferenceEvaluator
from onnx.checker import check_model
from onnx.helper import (
make_model,
make_node,
make_graph,
)


def build_test_save(
name: str,
inputs: list[ValueInfoProto],
outputs: list[ValueInfoProto],
initializers: list[TensorProto] = [],
attributes: dict[str, Any] = {},
) -> None:
node_inputs = [input.name for input in inputs + initializers]
node_outputs = [output.name for output in outputs]

node = make_node(
name.capitalize(),
inputs=node_inputs,
outputs=node_outputs,
**attributes,
)

graph = make_graph(
nodes=[node],
name=f"{name.capitalize()}Graph",
inputs=inputs,
outputs=outputs,
initializer=initializers,
)

onnx_model = make_model(graph)
check_model(onnx_model)

run_tests(onnx_model)

onnx.save(onnx_model, Path(__file__).with_name(f"{name}.onnx"))


class TestCase:
def __init__(
self, name: str, feeds: dict[str, numpy.ndarray], expected: numpy.ndarray
):
self.name = name
self.feeds = feeds
self.expected = expected

def test_model(self, model: ModelProto):
sess = ReferenceEvaluator(model)

result = numpy.array(sess.run(None, self.feeds))

if not numpy.array_equal(result, self.expected):
print(
f"""{self.name}
Expected result: {self.expected}
Got: {result}"""
)
raise Exception("Test failed")


def test_positive_pads(model: ModelProto) -> None:
input_tensor = numpy.arange(1, 7, dtype="float32").reshape(3, 2)
pads = numpy.array([1, 2, 3, 4], dtype="int")
constant_value = 0.0
feeds = {
"input_tensor": input_tensor,
"pads": pads,
"constant_value": constant_value,
}
expected = numpy.array(
[
[
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]
]
)

TestCase("test_positive_constant_pads", feeds, expected).test_model(model)


def test_1d_input(model: ModelProto) -> None:
input_tensor = numpy.arange(1, 5, dtype="float32")
pads = numpy.array([1, 2], dtype="int")
constant_value = 0.0
feeds = {
"input_tensor": input_tensor,
"pads": pads,
"constant_value": constant_value,
}
expected = numpy.array([[0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0]])

TestCase("test_1d_input", feeds, expected).test_model(model)


def run_tests(model: ModelProto) -> None:
test_positive_pads(model)
test_1d_input(model)
# TODO: test_negative_pads
# TODO: support other modes: reflect, edge, wrap


### Helper Functions End ###

import numpy
from onnx import TensorProto, numpy_helper
from onnx.helper import make_tensor_value_info


def get_initializers() -> list[TensorProto]:
pads = numpy_helper.from_array(
numpy.array([1, 2, 3, 4]).astype(numpy.int64), name="pads"
)
constant_value = numpy_helper.from_array(
numpy.array([0.0]).astype(numpy.float32), name="constant_value"
)

return [pads, constant_value]


def main() -> None:
name = "pad"

inputs = [make_tensor_value_info("input_tensor", TensorProto.FLOAT, [None, None])]
outputs = [make_tensor_value_info("output", TensorProto.FLOAT, [None, None])]
initializers = get_initializers()

build_test_save(
name=name,
inputs=inputs,
outputs=outputs,
initializers=initializers,
attributes={"mode": "constant"},
)


if __name__ == "__main__":
main()
5 changes: 4 additions & 1 deletion crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::{
conv_transpose_3d::ConvTranspose3dNode, dropout::DropoutNode, expand::ExpandNode,
gather::GatherNode, gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, pad::PadNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
Expand Down Expand Up @@ -105,6 +105,7 @@ pub enum Node<PS: PrecisionSettings> {
Matmul(MatmulNode),
MaxPool1d(MaxPool1dNode),
MaxPool2d(MaxPool2dNode),
Pad(PadNode),
Range(RangeNode),
Reshape(ReshapeNode),
Resize(ResizeNode),
Expand Down Expand Up @@ -150,6 +151,7 @@ macro_rules! match_all {
Node::Matmul(node) => $func(node),
Node::MaxPool1d(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
Node::Pad(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Resize(node) => $func(node),
Expand Down Expand Up @@ -203,6 +205,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Matmul(_) => "matmul",
Node::MaxPool1d(_) => "max_pool1d",
Node::MaxPool2d(_) => "max_pool2d",
Node::Pad(_) => "pad",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Node::Resize(_) => "resize",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub(crate) mod mask_where;
pub(crate) mod matmul;
pub(crate) mod max_pool1d;
pub(crate) mod max_pool2d;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_uniform;
Expand Down
104 changes: 104 additions & 0 deletions crates/burn-import/src/burn/node/pad.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
use std::str::FromStr;

use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::config::Config;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Config, Debug)]
pub struct PadConfig {
pub pads: Vec<usize>,
pub constant_value: f32,
}

#[derive(Debug, Clone, new)]
pub struct PadNode {
pub input: TensorType,
pub output: TensorType,
pub config: PadConfig,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for PadNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;

let pads = self.config.pads.iter().map(|p| p.to_tokens());
let constant_value_string = format!("{}_f32.elem()", self.config.constant_value);
let constant_value = TokenStream::from_str(&constant_value_string).unwrap();

quote! {
let #output = #input.pad((#(#pads),*), #constant_value);
}
}
fn into_node(self) -> Node<PS> {
Node::Pad(self)
}

fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::tensor::ElementConversion");
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{pad::PadNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_pad() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
let config = PadConfig::new(vec![1, 2, 3, 4], -1.0);
graph.register(PadNode::new(
TensorType::new_float("input", 2),
TensorType::new_float("output", 2),
config,
));
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::tensor::ElementConversion;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let output = input.pad((1, 2, 3, 4), -1_f32.elem());
output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
Loading

0 comments on commit 2859976

Please sign in to comment.