Feat: burn-import implement ONNX ConstantOfShape
hexd0t committed May 26, 2024
1 parent c7ad25a commit c224e40
Showing 10 changed files with 317 additions and 31 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -38,7 +38,7 @@ represent the corresponding Burn Op.
| [Concat][30] | ✅ | ✅ |
| [ConcatFromSequence][31] | ❌ | ❌ |
| [Constant][32] | ✅ | ✅ |
| [ConstantOfShape][33] | ❌ | ❌ |
| [ConstantOfShape][33] | ✅ | ✅ |
| [Conv1d][34] | ✅ | ✅ |
| [Conv2d][34] | ✅ | ✅ |
| [ConvInteger][37] | ❌ | ❌ |
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -72,6 +72,7 @@ fn main() {
.input("tests/mask_where/mask_where.onnx")
.input("tests/squeeze/squeeze_opset16.onnx")
.input("tests/squeeze/squeeze_opset13.onnx")
.input("tests/constant_of_shape/constant_of_shape.onnx")
.out_dir("model/")
.run_from_script();

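For context, the line added above sits inside burn-import's ModelGen builder chain, which turns each registered ONNX file into generated Rust source at build time. A minimal standalone build script in the same spirit (a sketch, assuming only the ModelGen calls visible in this hunk) could look like:

use burn_import::onnx::ModelGen;

fn main() {
    // Generate a Rust model from the ONNX file at build time.
    ModelGen::new()
        .input("tests/constant_of_shape/constant_of_shape.onnx")
        .out_dir("model/")
        .run_from_script();
}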
Binary file not shown.
@@ -0,0 +1,53 @@
#!/usr/bin/env python3

# Used to generate model: constant_of_shape.onnx

# torch simplifies simple use cases where it can statically determine the shape of the
# constant, emitting plain ONNX Constant nodes instead of ConstantOfShape.
# Hence this model is built with the onnx helper API directly.

import onnx
import onnx.helper


def build_model():
    return onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 16)],
        graph=onnx.helper.make_graph(
            name="main_graph",
            nodes=[
                onnx.helper.make_node(
                    "ConstantOfShape",
                    inputs=["input1"],
                    outputs=["output1"],
                    name="/ConstantOfShape",
                    value=onnx.helper.make_tensor(
                        "value", data_type=onnx.TensorProto.FLOAT, dims=[1], vals=[1.125]
                    ),
                ),
            ],
            inputs=[
                onnx.helper.make_value_info(
                    name="input1",
                    type_proto=onnx.helper.make_tensor_type_proto(
                        elem_type=onnx.TensorProto.INT64, shape=[2, 3, 2]
                    ),
                )
            ],
            outputs=[
                onnx.helper.make_value_info(
                    name="output1",
                    type_proto=onnx.helper.make_tensor_type_proto(
                        elem_type=onnx.TensorProto.FLOAT, shape=[2, 3, 2]
                    ),
                )
            ],
        ),
    )


def main():
    onnx_model = build_model()
    file_name = "constant_of_shape.onnx"

    onnx.save(onnx_model, file_name)


if __name__ == "__main__":
    main()
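The graph above fills a [2, 3, 2] float tensor with the value 1.125, whatever the int64 input holds. For orientation, a hand-written Burn equivalent of the code the importer is expected to emit boils down to a single Tensor::full call; a minimal sketch, assuming the NdArray backend:

use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // Mirrors the ONNX model above: a [2, 3, 2] tensor filled with 1.125.
    let out: Tensor<NdArray, 3> = Tensor::full([2, 3, 2], 1.125, &device);
    println!("{}", out);
}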
16 changes: 15 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -80,7 +80,8 @@ include_models!(
    unsqueeze_opset16,
    unsqueeze_opset11,
    squeeze_opset16,
    squeeze_opset13,
    constant_of_shape
);

#[cfg(test)]
@@ -1410,4 +1411,17 @@ mod tests {
        let output = model.forward(input);
        assert_eq!(expected_shape, output.shape());
    }

    #[test]
    fn constant_of_shape() {
        let device = Default::default();
        let model = constant_of_shape::Model::<Backend>::new(&device);
        let shape = Shape::from([2, 3, 2]);
        let expected = Tensor::<Backend, 3>::full(shape.clone(), 1.125f64, &device).to_data();

        let input = Tensor::ones(shape, &device);
        let output = model.forward(input);

        output.to_data().assert_approx_eq(&expected, 3);
    }
}
14 changes: 9 additions & 5 deletions crates/burn-import/src/burn/node/base.rs
@@ -1,11 +1,12 @@
use super::{
    avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
    binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
    conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
    dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
    layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
    squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
    constant_of_shape::ConstantOfShapeNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
    conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
    global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
    mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
    max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, squeeze::SqueezeNode,
    unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
@@ -99,6 +100,7 @@ pub enum Node<PS: PrecisionSettings> {
    Unary(UnaryNode),
    Unsqueeze(UnsqueezeNode),
    Where(WhereNode),
    ConstantOfShape(ConstantOfShapeNode),
}

macro_rules! match_all {
@@ -129,6 +131,7 @@ macro_rules! match_all {
            Node::Unary(node) => $func(node),
            Node::Unsqueeze(node) => $func(node),
            Node::Where(node) => $func(node),
            Node::ConstantOfShape(node) => $func(node),
        }
    }};
}
@@ -169,6 +172,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::Unary(unary) => unary.kind.as_str(),
            Node::Unsqueeze(_) => "unsqueeze",
            Node::Where(_) => "where",
            Node::ConstantOfShape(_) => "constant_of_shape",
        }
    }
}
171 changes: 171 additions & 0 deletions crates/burn-import/src/burn/node/constant_of_shape.rs
@@ -0,0 +1,171 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

/// Node for the ONNX ConstantOfShape operator.
#[derive(Debug, Clone)]
pub struct ConstantOfShapeNode {
    pub input: Type,
    pub output: Type,
    pub value: ConstantValue,
}

#[derive(Debug, Clone)]
pub enum ConstantValue {
    /// Float constants.
    Float32(f32),
    Float64(f64),

    /// Integer constants.
    Int32(i32),
    Int64(i64),

    /// Boolean constant.
    Bool(bool),
}

impl ConstantOfShapeNode {
    pub fn new(input: Type, output: Type, value: ConstantValue) -> Self {
        assert!(
            matches!(input, Type::Tensor(_)),
            "ConstantOfShape input needs to be a Tensor!"
        );
        assert!(
            matches!(output, Type::Tensor(_)),
            "ConstantOfShape output needs to be a Tensor!"
        );
        Self {
            input,
            output,
            value,
        }
    }
}

impl ConstantValue {
    pub fn val_tokens(&self) -> TokenStream {
        match self {
            Self::Float32(val) => quote! { #val },
            Self::Float64(val) => quote! { #val },
            Self::Int32(val) => quote! { #val },
            Self::Int64(val) => quote! { #val },
            Self::Bool(val) => quote! { #val },
        }
    }

    pub fn from_vec<T: Into<Self> + Copy>(mut source: Vec<T>) -> Self {
        assert_eq!(
            source.len(),
            1,
            "ConstantOfShape value from a vec needs to have exactly 1 element!"
        );
        source.drain(..).next().unwrap().into()
    }
}
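
// Usage sketch (not part of this commit): the ONNX `value` attribute arrives as a
// single-element tensor, so the importer can do e.g.
//     let value = ConstantValue::from_vec(vec![1.125f32]); // -> ConstantValue::Float32(1.125)
//     let tokens = value.val_tokens();                     // renders `1.125f32` into generated code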

impl From<f32> for ConstantValue {
    fn from(value: f32) -> Self {
        Self::Float32(value)
    }
}
impl From<f64> for ConstantValue {
    fn from(value: f64) -> Self {
        Self::Float64(value)
    }
}
impl From<i32> for ConstantValue {
    fn from(value: i32) -> Self {
        Self::Int32(value)
    }
}
impl From<i64> for ConstantValue {
    fn from(value: i64) -> Self {
        Self::Int64(value)
    }
}
impl From<bool> for ConstantValue {
    fn from(value: bool) -> Self {
        Self::Bool(value)
    }
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantOfShapeNode {
    fn input_types(&self) -> Vec<Type> {
        vec![self.input.clone()]
    }

    fn output_types(&self) -> Vec<Type> {
        vec![self.output.clone()]
    }

    fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
        let output = self.output.name();
        let input = self.input.name();
        let value = self.value.val_tokens();
        quote! {
            let #output = Tensor::full(#input.shape(), #value, &*self.device);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::ConstantOfShape(self)
    }
}

#[cfg(test)]
mod tests {
    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{
        graph::BurnGraph,
        node::{constant_of_shape::ConstantOfShapeNode, test::assert_tokens},
        TensorType,
    };

    #[test]
    fn test_codegen_nodes() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(ConstantOfShapeNode::new(
            Type::Tensor(TensorType::new_float("tensor1", 4)),
            Type::Tensor(TensorType::new_float("tensor2", 4)),
            ConstantValue::Float32(1.25f32),
        ));

        graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
                    let tensor2 = Tensor::full(tensor1.shape(), 1.25f32, &*self.device);

                    tensor2
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -7,6 +7,7 @@ pub(crate) mod binary;
pub(crate) mod clip;
pub(crate) mod concat;
pub(crate) mod constant;
pub(crate) mod constant_of_shape;
pub(crate) mod conv1d;
pub(crate) mod conv2d;
pub(crate) mod conv_transpose_2d;
19 changes: 19 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -71,6 +71,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
        NodeType::PRelu => same_as_input(node),
        NodeType::Where => where_update_outputs(node),
        NodeType::Squeeze => squeeze_update_output(node),
        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
        // Intentionally leave outputs unchanged, but issue a warning so the IR file can still be generated.
        _ => temporary_pass_through_stub(node),
@@ -120,6 +121,24 @@ fn constant_update_outputs(node: &mut Node) {
Expand Down Expand Up @@ -120,6 +121,24 @@ fn constant_update_outputs(node: &mut Node) {
};
}

fn constant_of_shape_update_output(node: &mut Node) {
    let value_type = node
        .attrs
        .get("value")
        .map(|v| v.clone().into_tensor().elem_type)
        .unwrap_or(ElementType::Float32); // If not given, defaults to 0 as float32

    if let ArgType::Tensor(input_type) = &node.inputs[0].ty {
        node.outputs[0].ty = ArgType::Tensor(TensorType {
            elem_type: value_type,
            dim: input_type.dim,
            shape: None,
        });
    } else {
        panic!("ConstantOfShape node must have a Tensor type input");
    }
}
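
// Worked example (a sketch, using the test model above):
//   input "input1":    ArgType::Tensor { elem_type: Int64, dim: 3, .. }
//   attribute "value": float32 tensor [1.125]
//   output "output1":  ArgType::Tensor { elem_type: Float32, dim: 3, shape: None }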

/// Infer the shape of the output tensor of a Linear node
fn linear_update_outputs(node: &mut Node) {
    // Extract the configuration of the linear layer (inputs are known)