From c224e4025435ec888bc2ee04fcda38b22dbd0bbd Mon Sep 17 00:00:00 2001
From: Adrian Müller
Date: Mon, 27 May 2024 00:34:18 +0200
Subject: [PATCH] Feat: burn-import implement ONNX ConstantOfShape

---
 crates/burn-import/SUPPORTED-ONNX-OPS.md      |   2 +-
 crates/burn-import/onnx-tests/build.rs        |   1 +
 .../constant_of_shape/constant_of_shape.onnx  | Bin 0 -> 169 bytes
 .../constant_of_shape/constant_of_shape.py    |  53 ++++++
 .../onnx-tests/tests/onnx_tests.rs            |  16 +-
 crates/burn-import/src/burn/node/base.rs      |  14 +-
 .../src/burn/node/constant_of_shape.rs        | 171 ++++++++++++++++++
 crates/burn-import/src/burn/node/mod.rs       |   1 +
 crates/burn-import/src/onnx/dim_inference.rs  |  19 ++
 crates/burn-import/src/onnx/to_burn.rs        |  71 +++++---
 10 files changed, 317 insertions(+), 31 deletions(-)
 create mode 100644 crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.onnx
 create mode 100644 crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.py
 create mode 100644 crates/burn-import/src/burn/node/constant_of_shape.rs

diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md
index ed4115fbd1c..7172a261503 100644
--- a/crates/burn-import/SUPPORTED-ONNX-OPS.md
+++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -38,7 +38,7 @@ represent the corresponding Burn Op.
 | [Concat][30]             | ✅ | ✅ |
 | [ConcatFromSequence][31] | ❌ | ❌ |
 | [Constant][32]           | ✅ | ✅ |
-| [ConstantOfShape][33]    | ❌ | ❌ |
+| [ConstantOfShape][33]    | ✅ | ✅ |
 | [Conv1d][34]             | ✅ | ✅ |
 | [Conv2d][34]             | ✅ | ✅ |
 | [ConvInteger][37]        | ❌ | ❌ |
diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs
index 07eb6a20264..f512a5d0135 100644
--- a/crates/burn-import/onnx-tests/build.rs
+++ b/crates/burn-import/onnx-tests/build.rs
@@ -72,6 +72,7 @@ fn main() {
         .input("tests/mask_where/mask_where.onnx")
         .input("tests/squeeze/squeeze_opset16.onnx")
         .input("tests/squeeze/squeeze_opset13.onnx")
+        .input("tests/constant_of_shape/constant_of_shape.onnx")
         .out_dir("model/")
         .run_from_script();
diff --git a/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.onnx b/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..1ee68a8fd98f2cdef5c1c5c82beceff9507ea01b
GIT binary patch
literal 169
[base85 payload not preserved in this extract]
diff --git a/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.py b/crates/burn-import/onnx-tests/tests/constant_of_shape/constant_of_shape.py
new file mode 100644
[53-line generator script not preserved in this extract; a sketch follows after this patch section]
diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
--- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs
+++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ ... @@ include_models!(
+    constant_of_shape,
@@ ... @@
+    #[test]
+    fn constant_of_shape() {
+        let device = Default::default();
+        let model = constant_of_shape::Model::<Backend>::new(&device);
+        let shape = Shape::from([2, 3, 2]);
+        let expected = Tensor::<Backend, 3>::full(shape.clone(), 1.125f64, &device).to_data();
+
+        let input = Tensor::ones(shape, &device);
+        let output = model.forward(input);
+
+        output.to_data().assert_approx_eq(&expected, 3);
+    }
 }
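The generator script itself was lost from this extract. As a rough sketch only, not the original 53-line file: a graph consistent with the test above would route the input through a Shape node into ConstantOfShape with a fill value of 1.125. All names below are illustrative, and it assumes the `onnx` Python package.

#!/usr/bin/env python3
# Illustrative sketch, not the original script: a Shape node feeds
# ConstantOfShape, which fills that shape with 1.125.
import onnx
from onnx import TensorProto, helper


def main():
    value = helper.make_tensor("value", TensorProto.FLOAT, dims=[1], vals=[1.125])
    shape_node = helper.make_node("Shape", inputs=["input"], outputs=["shape"])
    fill_node = helper.make_node(
        "ConstantOfShape", inputs=["shape"], outputs=["output"], value=value
    )
    graph = helper.make_graph(
        [shape_node, fill_node],
        "constant_of_shape",
        inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 3, 2])],
        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, 2])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
    onnx.checker.check_model(model)
    onnx.save(model, "constant_of_shape.onnx")


if __name__ == "__main__":
    main()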
diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs
index b90bec69606..afe8f65514b 100644
--- a/crates/burn-import/src/burn/node/base.rs
+++ b/crates/burn-import/src/burn/node/base.rs
@@ -1,11 +1,12 @@
 use super::{
     avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
     binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
-    conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
-    dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
-    layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
-    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
-    squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
+    constant_of_shape::ConstantOfShapeNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
+    conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
+    global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
+    mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
+    max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, squeeze::SqueezeNode,
+    unary::UnaryNode, unsqueeze::UnsqueezeNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::backend::NdArray;
@@ -99,6 +100,7 @@ pub enum Node<PS: PrecisionSettings> {
     Unary(UnaryNode),
     Unsqueeze(UnsqueezeNode),
     Where(WhereNode),
+    ConstantOfShape(ConstantOfShapeNode),
 }
 
 macro_rules! match_all {
@@ -129,6 +131,7 @@ macro_rules! match_all {
             Node::Unary(node) => $func(node),
             Node::Unsqueeze(node) => $func(node),
             Node::Where(node) => $func(node),
+            Node::ConstantOfShape(node) => $func(node),
         }
     }};
 }
@@ -169,6 +172,7 @@ impl<PS: PrecisionSettings> Node<PS> {
             Node::Unary(unary) => unary.kind.as_str(),
             Node::Unsqueeze(_) => "unsqueeze",
             Node::Where(_) => "where",
+            Node::ConstantOfShape(_) => "constant_of_shape",
         }
     }
 }
diff --git a/crates/burn-import/src/burn/node/constant_of_shape.rs b/crates/burn-import/src/burn/node/constant_of_shape.rs
new file mode 100644
index 00000000000..a86b6f2f9cb
--- /dev/null
+++ b/crates/burn-import/src/burn/node/constant_of_shape.rs
@@ -0,0 +1,171 @@
+use super::{Node, NodeCodegen};
+use crate::burn::{Scope, Type};
+use burn::record::PrecisionSettings;
+use proc_macro2::TokenStream;
+use quote::quote;
+
+/// Node for the ONNX ConstantOfShape operator.
+#[derive(Debug, Clone)]
+pub struct ConstantOfShapeNode {
+    pub input: Type,
+    pub output: Type,
+    pub value: ConstantValue,
+}
+
+#[derive(Debug, Clone)]
+pub enum ConstantValue {
+    /// Float constant.
+    Float32(f32),
+    Float64(f64),
+
+    /// Integer constant.
+    Int32(i32),
+    Int64(i64),
+
+    /// Boolean constant.
+    Bool(bool),
+}
+
+impl ConstantOfShapeNode {
+    pub fn new(input: Type, output: Type, value: ConstantValue) -> Self {
+        assert!(
+            matches!(input, Type::Tensor(_)),
+            "ConstantOfShape input needs to be a Tensor!"
+        );
+        assert!(
+            matches!(output, Type::Tensor(_)),
+            "ConstantOfShape output needs to be a Tensor!"
+        );
+        Self {
+            input,
+            output,
+            value,
+        }
+    }
+}
+
+impl ConstantValue {
+    pub fn val_tokens(&self) -> TokenStream {
+        match self {
+            Self::Float32(val) => quote! { #val },
+            Self::Float64(val) => quote! { #val },
+            Self::Int32(val) => quote! { #val },
+            Self::Int64(val) => quote! { #val },
+            Self::Bool(val) => quote! { #val },
+        }
+    }
+
+    pub fn from_vec<T: Into<Self> + Copy>(mut source: Vec<T>) -> Self {
+        assert_eq!(
+            source.len(),
+            1,
+            "ConstantOfShape value from a vec needs to have exactly 1 element!"
+        );
+        source.drain(..).next().unwrap().into()
+    }
+}
+
+impl From<f32> for ConstantValue {
+    fn from(value: f32) -> Self {
+        Self::Float32(value)
+    }
+}
+impl From<f64> for ConstantValue {
+    fn from(value: f64) -> Self {
+        Self::Float64(value)
+    }
+}
+impl From<i32> for ConstantValue {
+    fn from(value: i32) -> Self {
+        Self::Int32(value)
+    }
+}
+impl From<i64> for ConstantValue {
+    fn from(value: i64) -> Self {
+        Self::Int64(value)
+    }
+}
+impl From<bool> for ConstantValue {
+    fn from(value: bool) -> Self {
+        Self::Bool(value)
+    }
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantOfShapeNode {
+    fn input_types(&self) -> Vec<Type> {
+        vec![self.input.clone()]
+    }
+
+    fn output_types(&self) -> Vec<Type> {
+        vec![self.output.clone()]
+    }
+
+    fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
+        let output = self.output.name();
+        let input = self.input.name();
+        let value = self.value.val_tokens();
+        quote! {
+            let #output = Tensor::full(#input.shape(), #value, &*self.device);
+        }
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::ConstantOfShape(self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use burn::record::FullPrecisionSettings;
+
+    use super::*;
+    use crate::burn::{
+        graph::BurnGraph,
+        node::{constant_of_shape::ConstantOfShapeNode, test::assert_tokens},
+        TensorType,
+    };
+
+    #[test]
+    fn test_codegen_nodes() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(ConstantOfShapeNode::new(
+            Type::Tensor(TensorType::new_float("tensor1", 4)),
+            Type::Tensor(TensorType::new_float("tensor2", 4)),
+            ConstantValue::Float32(1.25f32),
+        ));
+
+        graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
+
+        let expected = quote! {
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    Self {
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
+                    let tensor2 = Tensor::full(tensor1.shape(), 1.25f32, &*self.device);
+
+                    tensor2
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}
diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs
index b22876d8fd8..9baef5fcafc 100644
--- a/crates/burn-import/src/burn/node/mod.rs
+++ b/crates/burn-import/src/burn/node/mod.rs
@@ -7,6 +7,7 @@ pub(crate) mod binary;
 pub(crate) mod clip;
 pub(crate) mod concat;
 pub(crate) mod constant;
+pub(crate) mod constant_of_shape;
 pub(crate) mod conv1d;
 pub(crate) mod conv2d;
 pub(crate) mod conv_transpose_2d;
diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs
index 331868227e1..ac03972554e 100644
--- a/crates/burn-import/src/onnx/dim_inference.rs
+++ b/crates/burn-import/src/onnx/dim_inference.rs
@@ -71,6 +71,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::PRelu => same_as_input(node),
         NodeType::Where => where_update_outputs(node),
         NodeType::Squeeze => squeeze_update_output(node),
+        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
         // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
         _ => temporary_pass_through_stub(node),
     }
@@ -120,6 +121,24 @@ fn constant_update_outputs(node: &mut Node) {
     };
 }
 
+fn constant_of_shape_update_output(node: &mut Node) {
+    let value_type = node
+        .attrs
+        .get("value")
+        .map(|v| v.clone().into_tensor().elem_type)
+        .unwrap_or(ElementType::Float32); // If not given, defaults to 0 as float32
+
+    if let ArgType::Tensor(input_type) = &node.inputs[0].ty {
+        node.outputs[0].ty = ArgType::Tensor(TensorType {
+            elem_type: value_type,
+            dim: input_type.dim,
+            shape: None,
+        });
+    } else {
+        panic!("ConstantOfShape node must have a Tensor type input");
+    }
+}
+
 /// Infer the shape of the output tensor of a Conv2d node
 fn linear_update_outputs(node: &mut Node) {
     // Extract the configuration of the linear layer (inputs are known)
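As `constant_of_shape_update_output` above encodes, the output element type of ConstantOfShape is dictated solely by its single-element `value` attribute, with float32 as the ONNX default when the attribute is absent. A standalone check of that spec behavior, illustrative only and not part of the patch, assuming the `onnx`, `numpy`, and `onnxruntime` packages:

# The output dtype follows the `value` attribute: an int64 value tensor
# yields an int64 output filled with that value.
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

value = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[7])
node = helper.make_node("ConstantOfShape", ["shape"], ["out"], value=value)
graph = helper.make_graph(
    [node],
    "cos_dtype",
    inputs=[helper.make_tensor_value_info("shape", TensorProto.INT64, [2])],
    outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, None)],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
session = ort.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
out = session.run(None, {"shape": np.array([4, 5], dtype=np.int64)})[0]
assert out.dtype == np.int64 and out.shape == (4, 5) and (out == 7).all()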
diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs
index d89b3744805..f6465e98bf7 100644
--- a/crates/burn-import/src/onnx/to_burn.rs
+++ b/crates/burn-import/src/onnx/to_burn.rs
@@ -14,30 +14,14 @@ use crate::{
     burn::{
         graph::BurnGraph,
         node::{
-            avg_pool1d::AvgPool1dNode,
-            avg_pool2d::AvgPool2dNode,
-            batch_norm::BatchNormNode,
-            binary::BinaryNode,
-            clip::ClipNode,
-            concat::ConcatNode,
-            constant::{ConstantNode, ConstantValue, TensorValue},
-            conv1d::Conv1dNode,
-            conv2d::Conv2dNode,
-            conv_transpose_2d::ConvTranspose2dNode,
-            dropout::DropoutNode,
-            gather::GatherNode,
-            global_avg_pool::GlobalAvgPoolNode,
-            layer_norm::LayerNormNode,
-            linear::LinearNode,
-            mask_where::WhereNode,
-            matmul::MatmulNode,
-            max_pool1d::MaxPool1dNode,
-            max_pool2d::MaxPool2dNode,
-            prelu::PReluNode,
-            reshape::ReshapeNode,
-            squeeze::SqueezeNode,
-            unary::UnaryNode,
-            unsqueeze::UnsqueezeNode,
+            avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
+            binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
+            constant_of_shape::ConstantOfShapeNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
+            conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
+            global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
+            mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
+            max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
+            squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode,
         },
         ScalarKind, ScalarType, TensorKind, TensorType, Type,
     },
@@ -297,6 +281,9 @@ impl OnnxGraph {
                 NodeType::Where => graph.register(Self::where_conversion(node)),
                 NodeType::Sign => graph.register(Self::sign_conversion(node)),
                 NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
+                NodeType::ConstantOfShape => {
+                    graph.register(Self::constant_of_shape_conversion(node))
+                }
                 node_type => unsupported_ops.push(node_type),
             }
         }
@@ -324,6 +311,9 @@ impl OnnxGraph {
     }
 
     fn constant_conversion<PS: PrecisionSettings>(node: Node) -> ConstantNode<PS> {
+        // Additional types needed for Constant:
+        use crate::burn::node::constant::{ConstantValue, TensorValue};
+
         let output = node.outputs.first().unwrap();
         let attr = convert_constant_value(&node);
@@ -374,6 +364,39 @@ impl OnnxGraph {
         ConstantNode::new(node.name.clone(), const_value, output.to_type())
     }
 
+    pub(crate) fn constant_of_shape_conversion(node: Node) -> ConstantOfShapeNode {
+        // Additional types needed for ConstantOfShape:
+        use crate::burn::node::constant_of_shape::ConstantValue;
+
+        let input = node
+            .inputs
+            .first()
+            .expect("ConstantOfShape requires an input tensor");
+        let output = node.outputs.first().unwrap();
+
+        let value = node
+            .attrs
+            .get("value")
+            .and_then(|val| val.clone().into_tensor().data)
+            .map(|val_data| match val_data {
+                // TODO: Handle Float16
+                Data::Float32(val) => val.into(),
+                Data::Float32s(vals) => ConstantValue::from_vec(vals),
+                Data::Float64(val) => val.into(),
+                Data::Float64s(vals) => ConstantValue::from_vec(vals),
+                Data::Int32(val) => val.into(),
+                Data::Int32s(vals) => ConstantValue::from_vec(vals),
+                Data::Int64(val) => val.into(),
+                Data::Int64s(vals) => ConstantValue::from_vec(vals),
+                Data::Bool(val) => val.into(),
+                Data::Bools(vals) => ConstantValue::from_vec(vals),
+                _ => panic!("Unsupported value type for ConstantOfShape!"),
+            })
+            .unwrap_or(ConstantValue::Float32(0.0f32));
+
+        ConstantOfShapeNode::new(input.to_type(), output.to_type(), value)
+    }
+
     fn add_conversion(node: Node) -> BinaryNode {
         let lhs = node.inputs.first().unwrap().to_type();
         let rhs = node.inputs.get(1).unwrap().to_type();
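The `unwrap_or(ConstantValue::Float32(0.0f32))` fallback above mirrors the other half of the spec default: with no `value` attribute at all, ConstantOfShape emits float32 zeros. Again an illustrative sketch rather than part of the patch, with the same package assumptions as before:

# Without a `value` attribute, ConstantOfShape produces float32 zeros.
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

node = helper.make_node("ConstantOfShape", ["shape"], ["out"])
graph = helper.make_graph(
    [node],
    "cos_default",
    inputs=[helper.make_tensor_value_info("shape", TensorProto.INT64, [3])],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, None)],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
session = ort.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
out = session.run(None, {"shape": np.array([2, 3, 2], dtype=np.int64)})[0]
assert out.dtype == np.float32 and out.shape == (2, 3, 2) and not out.any()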