From 501899d078f8cdf2314ba00e70d875c4c7f3cd59 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Wed, 18 Dec 2024 17:11:43 +0000 Subject: [PATCH] Support DepthToSpace operator In PyTorch this corresponds to the `nn.PixelShuffle` module. --- rten-convert/rten_convert/converter.py | 5 + rten-convert/rten_convert/schema_generated.py | 101 ++++++- src/model.rs | 7 +- src/model_builder.rs | 24 +- src/op_registry.rs | 18 +- src/ops/layout.rs | 137 +++++++++ src/ops/mod.rs | 4 +- src/schema.fbs | 14 + src/schema_generated.rs | 267 +++++++++++++++++- 9 files changed, 559 insertions(+), 18 deletions(-) diff --git a/rten-convert/rten_convert/converter.py b/rten-convert/rten_convert/converter.py index e30d8728..8f606e55 100644 --- a/rten-convert/rten_convert/converter.py +++ b/rten-convert/rten_convert/converter.py @@ -765,6 +765,11 @@ def op_node_from_onnx_operator( attrs = sg.DequantizeLinearAttrsT() attrs.axis = op_reader.get_attr("axis", "int", 1) + case "DepthToSpace": + attrs = sg.DepthToSpaceAttrsT() + attrs.blockSize = op_reader.require_attr("blocksize", "int") + attrs.mode = op_reader.get_enum_attr("mode", sg.DepthToSpaceMode, "dcr") + case "Einsum": attrs = sg.EinsumAttrsT() attrs.equation = op_reader.require_attr("equation", "string") diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py index 19d3c1ae..bed08145 100644 --- a/rten-convert/rten_convert/schema_generated.py +++ b/rten-convert/rten_convert/schema_generated.py @@ -116,6 +116,7 @@ class OperatorType(object): QuantizeLinear = 106 DynamicQuantizeLinear = 107 MatMulInteger = 108 + DepthToSpace = 109 class RNNDirection(object): @@ -198,6 +199,7 @@ class OperatorAttrs(object): PadAttrs = 40 DequantizeLinearAttrs = 41 QuantizeLinearAttrs = 42 + DepthToSpaceAttrs = 43 def OperatorAttrsCreator(unionType, table): from flatbuffers.table import Table @@ -287,9 +289,16 @@ def OperatorAttrsCreator(unionType, table): return DequantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos) if unionType == OperatorAttrs().QuantizeLinearAttrs: return QuantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == OperatorAttrs().DepthToSpaceAttrs: + return DepthToSpaceAttrsT.InitFromBuf(table.Bytes, table.Pos) return None +class DepthToSpaceMode(object): + DCR = 0 + CRD = 1 + + class Scalar(object): NONE = 0 IntScalar = 1 @@ -940,6 +949,96 @@ def Pack(self, builder): return concatAttrs +class DepthToSpaceAttrs(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = DepthToSpaceAttrs() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsDepthToSpaceAttrs(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def DepthToSpaceAttrsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed) + + # DepthToSpaceAttrs + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # DepthToSpaceAttrs + def Mode(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos) + return 0 + + # DepthToSpaceAttrs + def BlockSize(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) + return 0 + +def DepthToSpaceAttrsStart(builder): + builder.StartObject(2) + +def DepthToSpaceAttrsAddMode(builder, mode): + builder.PrependUint8Slot(0, mode, 0) + +def DepthToSpaceAttrsAddBlockSize(builder, blockSize): + builder.PrependUint32Slot(1, blockSize, 0) + +def DepthToSpaceAttrsEnd(builder): + return builder.EndObject() + + + +class DepthToSpaceAttrsT(object): + + # DepthToSpaceAttrsT + def __init__(self): + self.mode = 0 # type: int + self.blockSize = 0 # type: int + + @classmethod + def InitFromBuf(cls, buf, pos): + depthToSpaceAttrs = DepthToSpaceAttrs() + depthToSpaceAttrs.Init(buf, pos) + return cls.InitFromObj(depthToSpaceAttrs) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos+n) + + @classmethod + def InitFromObj(cls, depthToSpaceAttrs): + x = DepthToSpaceAttrsT() + x._UnPack(depthToSpaceAttrs) + return x + + # DepthToSpaceAttrsT + def _UnPack(self, depthToSpaceAttrs): + if depthToSpaceAttrs is None: + return + self.mode = depthToSpaceAttrs.Mode() + self.blockSize = depthToSpaceAttrs.BlockSize() + + # DepthToSpaceAttrsT + def Pack(self, builder): + DepthToSpaceAttrsStart(builder) + DepthToSpaceAttrsAddMode(builder, self.mode) + DepthToSpaceAttrsAddBlockSize(builder, self.blockSize) + depthToSpaceAttrs = DepthToSpaceAttrsEnd(builder) + return depthToSpaceAttrs + + class IntScalar(object): __slots__ = ['_tab'] @@ -5053,7 +5152,7 @@ class OperatorNodeT(object): def __init__(self): self.type = 0 # type: int self.attrsType = 0 # type: int - self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT] + self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT, DepthToSpaceAttrsT] self.inputs = None # type: List[int] self.outputs = None # type: List[int] diff --git a/src/model.rs b/src/model.rs index b03a1918..e428ad5e 100644 --- a/src/model.rs +++ b/src/model.rs @@ -859,7 +859,8 @@ mod tests { }; use crate::ops; use crate::ops::{ - BoxOrder, CoordTransformMode, DataType, NearestMode, OpError, Output, ResizeMode, Scalar, + BoxOrder, CoordTransformMode, DataType, DepthToSpaceMode, NearestMode, OpError, Output, + ResizeMode, Scalar, }; use crate::{ModelLoadError, OpRegistry, ReadOpError}; @@ -1307,6 +1308,10 @@ mod tests { add_operator!(DequantizeLinear, [const_u8, scale, zero_point], { axis: 0, }); + add_operator!(DepthToSpace, [input_node], { + mode: DepthToSpaceMode::DepthColumnRow, + block_size: 1, + }); add_operator!(QuantizeLinear, [const_f32, scale, zero_point], { axis: 0, output_dtype: None, diff --git a/src/model_builder.rs b/src/model_builder.rs index f315be47..b66b7251 100644 --- a/src/model_builder.rs +++ b/src/model_builder.rs @@ -7,12 +7,12 @@ use crate::header::Header; use crate::number::LeBytes; use crate::ops::{ ArgMax, ArgMin, AveragePool, BatchNormalization, BoxOrder, Cast, Concat, ConstantOfShape, Conv, - ConvTranspose, CoordTransformMode, DataType, DequantizeLinear, Einsum, Elu, Flatten, Gather, - GatherElements, GatherND, Gelu, Gemm, HardSigmoid, InstanceNormalization, LayerNormalization, - LeakyRelu, LogSoftmax, MaxPool, Mod, NearestMode, NonMaxSuppression, OneHot, Padding, - QuantizeLinear, ReduceMax, ReduceMean, ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare, - Reshape, Resize, ResizeMode, Scalar, ScatterElements, ScatterReduction, Softmax, Split, TopK, - Transpose, Trilu, + ConvTranspose, CoordTransformMode, DataType, DepthToSpace, DepthToSpaceMode, DequantizeLinear, + Einsum, Elu, Flatten, Gather, GatherElements, GatherND, Gelu, Gemm, HardSigmoid, + InstanceNormalization, LayerNormalization, LeakyRelu, LogSoftmax, MaxPool, Mod, NearestMode, + NonMaxSuppression, OneHot, Padding, QuantizeLinear, ReduceMax, ReduceMean, ReduceMin, + ReduceProd, ReduceSum, ReduceSumSquare, Reshape, Resize, ResizeMode, Scalar, ScatterElements, + ScatterReduction, Softmax, Split, TopK, Transpose, Trilu, }; use crate::schema_generated as sg; @@ -47,6 +47,7 @@ pub enum OpType<'a> { ConvTranspose(ConvTranspose), Cos, DequantizeLinear(DequantizeLinear), + DepthToSpace(DepthToSpace), Div, DynamicQuantizeLinear, Einsum(Einsum), @@ -514,6 +515,17 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { axis: args.axis as i32, } ), + OpType::DepthToSpace(args) => op_with_attrs!( + DepthToSpace, + DepthToSpaceAttrs, + sg::DepthToSpaceAttrsArgs { + block_size: args.block_size, + mode: match args.mode { + DepthToSpaceMode::DepthColumnRow => sg::DepthToSpaceMode::DCR, + DepthToSpaceMode::ColumnRowDepth => sg::DepthToSpaceMode::CRD, + } + } + ), OpType::Div => op!(Div), OpType::DynamicQuantizeLinear => op!(DynamicQuantizeLinear), OpType::Einsum(args) => { diff --git a/src/op_registry.rs b/src/op_registry.rs index 12441cde..d6fbf540 100644 --- a/src/op_registry.rs +++ b/src/op_registry.rs @@ -7,8 +7,8 @@ use smallvec::smallvec; use crate::graph::Graph; use crate::ops; use crate::ops::{ - BoxOrder, CoordTransformMode, DataType, Direction, NearestMode, Operator, PadMode, Padding, - ResizeMode, Scalar, ScatterReduction, + BoxOrder, CoordTransformMode, DataType, DepthToSpaceMode, Direction, NearestMode, Operator, + PadMode, Padding, ResizeMode, Scalar, ScatterReduction, }; use crate::schema_generated as sg; use crate::schema_generated::{AutoPad, OperatorNode, OperatorType}; @@ -101,6 +101,7 @@ impl OpRegistry { register_op!(Cos); register_op!(CumSum); register_op!(DequantizeLinear); + register_op!(DepthToSpace); register_op!(Div); register_op!(DynamicQuantizeLinear); register_op!(Einsum); @@ -477,6 +478,19 @@ impl_read_op!( impl_read_op!(Cos); impl_read_op!(CumSum); impl_read_op!(DequantizeLinear, attrs_as_dequantize_linear_attrs, axis); +impl_read_op!( + DepthToSpace, + attrs_as_depth_to_space_attrs, + |attrs: sg::DepthToSpaceAttrs| { + let mode = match attrs.mode() { + sg::DepthToSpaceMode::DCR => DepthToSpaceMode::DepthColumnRow, + sg::DepthToSpaceMode::CRD => DepthToSpaceMode::ColumnRowDepth, + _ => return Err(ReadOpError::AttrError)?, + }; + let block_size = attrs.block_size(); + Ok(ops::DepthToSpace { mode, block_size }) + } +); impl_read_op!(Div); impl_read_op!(DynamicQuantizeLinear); impl_read_op!(Einsum, attrs_as_einsum_attrs, |attrs: sg::EinsumAttrs| { diff --git a/src/ops/layout.rs b/src/ops/layout.rs index 85cffea6..d326dc70 100644 --- a/src/ops/layout.rs +++ b/src/ops/layout.rs @@ -12,6 +12,69 @@ use crate::ops::{ }; use crate::tensor_pool::TensorPool; +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum DepthToSpaceMode { + DepthColumnRow, + ColumnRowDepth, +} + +pub fn depth_to_space( + pool: &TensorPool, + input: TensorView, + block_size: u32, + mode: DepthToSpaceMode, +) -> Result, OpError> { + if block_size == 0 { + return Err(OpError::InvalidValue("`block_size` must be > 0")); + } + + let input = static_dims!(input, 4, "NCHW")?; + let [n, c, h, w] = input.shape(); + let block_size = block_size as usize; + + if c % (block_size * block_size) != 0 { + return Err(OpError::InvalidValue( + "input channels must be a multiple of `block_size` squared", + )); + } + + let new_c = c / (block_size * block_size); + let new_shape = [n, new_c, h * block_size, w * block_size]; + + // Reshape following steps in `DepthToSpace` ONNX spec. + // See https://onnx.ai/onnx/operators/onnx__DepthToSpace.html#summary + let tmp = input.to_contiguous_in(pool); + let tmp = match mode { + DepthToSpaceMode::DepthColumnRow => tmp + .reshaped([n, block_size, block_size, new_c, h, w]) + .permuted([0, 3, 4, 1, 5, 2]), + DepthToSpaceMode::ColumnRowDepth => tmp + .reshaped([n, new_c, block_size, block_size, h, w]) + .permuted([0, 1, 4, 2, 5, 3]), + }; + let mut tmp = tmp.to_tensor_in(pool).into_dyn(); + tmp.reshape(&new_shape); + + Ok(tmp) +} + +#[derive(Debug)] +pub struct DepthToSpace { + pub block_size: u32, + pub mode: DepthToSpaceMode, +} + +impl Operator for DepthToSpace { + fn name(&self) -> &str { + "DepthToSpace" + } + + fn run(&self, pool: &TensorPool, inputs: InputList) -> Result { + let input = inputs.require_as(0)?; + depth_to_space::(pool, input, self.block_size, self.mode).into_op_result() + } +} + /// Return the tensor shape resulting from broadcasting `input_shape` with `shape`. fn expand_output_shape( input_shape: &[usize], @@ -641,6 +704,7 @@ mod tests { use rten_tensor::test_util::expect_equal; use rten_tensor::{NdTensor, Tensor}; + use super::{depth_to_space, DepthToSpaceMode}; use crate::ops::layout::{ expand, flatten, reshape, reshape_in_place, squeeze, squeeze_in_place, transpose, unsqueeze, Reshape, Shape, Size, @@ -648,6 +712,79 @@ mod tests { use crate::ops::tests::new_pool; use crate::ops::{OpError, Operator}; + #[test] + fn test_depth_to_space() { + struct Case { + input: NdTensor, + block_size: u32, + mode: DepthToSpaceMode, + expected: Result, + } + + let input = NdTensor::from([ + [[1.0]], + [[2.0]], + [[3.0]], + [[4.0]], + [[5.0]], + [[6.0]], + [[7.0]], + [[8.0]], + ]) + .into_shape([1, 8, 1, 1]); + + let cases = [ + // DepthColumnRow (DCR) mode + Case { + input: input.clone(), + block_size: 2, + mode: DepthToSpaceMode::DepthColumnRow, + expected: Ok( + NdTensor::from([[[1.0, 3.0], [5.0, 7.0]], [[2.0, 4.0], [6.0, 8.0]]]) + .into_shape([1, 2, 2, 2].as_slice()), + ), + }, + // ColumnRowDepth (CRD) mode + Case { + input: input.clone(), + block_size: 2, + mode: DepthToSpaceMode::ColumnRowDepth, + expected: Ok( + NdTensor::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + .into_shape([1, 2, 2, 2].as_slice()), + ), + }, + // C % block_size^2 != 0 + Case { + input: NdTensor::full([1, 16, 2, 2], 1.0), + block_size: 3, + mode: DepthToSpaceMode::ColumnRowDepth, + expected: Err(OpError::InvalidValue( + "input channels must be a multiple of `block_size` squared", + )), + }, + // block_size == 0 + Case { + input: NdTensor::full([1, 16, 2, 2], 1.0), + block_size: 0, + mode: DepthToSpaceMode::ColumnRowDepth, + expected: Err(OpError::InvalidValue("`block_size` must be > 0")), + }, + ]; + + let pool = new_pool(); + for Case { + input, + block_size, + mode, + expected, + } in cases + { + let result = depth_to_space(&pool, input.as_dyn(), block_size, mode); + assert_eq!(result, expected); + } + } + #[test] fn test_expand() { let pool = new_pool(); diff --git a/src/ops/mod.rs b/src/ops/mod.rs index 69ce5ed4..b1ec466b 100644 --- a/src/ops/mod.rs +++ b/src/ops/mod.rs @@ -79,8 +79,8 @@ pub use gather::{ pub use generate::{constant_of_shape, onehot, range, ConstantOfShape, OneHot, Range}; pub use identity::Identity; pub use layout::{ - expand, flatten, reshape, squeeze, squeeze_in_place, Expand, Flatten, Reshape, Shape, Size, - Squeeze, Transpose, Unsqueeze, + depth_to_space, expand, flatten, reshape, squeeze, squeeze_in_place, DepthToSpace, + DepthToSpaceMode, Expand, Flatten, Reshape, Shape, Size, Squeeze, Transpose, Unsqueeze, }; pub use matmul::{gemm_op, matmul, Gemm, MatMul, MatMulAdd, MatMulInteger}; pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression}; diff --git a/src/schema.fbs b/src/schema.fbs index 0b7b3b42..44929602 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -122,6 +122,7 @@ enum OperatorType: ubyte { QuantizeLinear, DynamicQuantizeLinear, MatMulInteger, + DepthToSpace, } enum RNNDirection: ubyte { @@ -215,6 +216,7 @@ union OperatorAttrs { PadAttrs, DequantizeLinearAttrs, QuantizeLinearAttrs, + DepthToSpaceAttrs, } table ArgMaxAttrs { @@ -246,6 +248,18 @@ table ConcatAttrs { axis:int; } +enum DepthToSpaceMode: ubyte { + // Depth-column-row + DCR, + // Column-row-depth + CRD, +} + +table DepthToSpaceAttrs { + mode:DepthToSpaceMode; + block_size:uint; +} + union Scalar { IntScalar, FloatScalar diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 8247ea4b..4880368f 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -18,13 +18,13 @@ pub const ENUM_MIN_OPERATOR_TYPE: u8 = 0; since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] -pub const ENUM_MAX_OPERATOR_TYPE: u8 = 108; +pub const ENUM_MAX_OPERATOR_TYPE: u8 = 109; #[deprecated( since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] #[allow(non_camel_case_types)] -pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 109] = [ +pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 110] = [ OperatorType::Add, OperatorType::ArgMin, OperatorType::ArgMax, @@ -134,6 +134,7 @@ pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 109] = [ OperatorType::QuantizeLinear, OperatorType::DynamicQuantizeLinear, OperatorType::MatMulInteger, + OperatorType::DepthToSpace, ]; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] @@ -250,9 +251,10 @@ impl OperatorType { pub const QuantizeLinear: Self = Self(106); pub const DynamicQuantizeLinear: Self = Self(107); pub const MatMulInteger: Self = Self(108); + pub const DepthToSpace: Self = Self(109); pub const ENUM_MIN: u8 = 0; - pub const ENUM_MAX: u8 = 108; + pub const ENUM_MAX: u8 = 109; pub const ENUM_VALUES: &'static [Self] = &[ Self::Add, Self::ArgMin, @@ -363,6 +365,7 @@ impl OperatorType { Self::QuantizeLinear, Self::DynamicQuantizeLinear, Self::MatMulInteger, + Self::DepthToSpace, ]; /// Returns the variant's name or "" if unknown. pub fn variant_name(self) -> Option<&'static str> { @@ -476,6 +479,7 @@ impl OperatorType { Self::QuantizeLinear => Some("QuantizeLinear"), Self::DynamicQuantizeLinear => Some("DynamicQuantizeLinear"), Self::MatMulInteger => Some("MatMulInteger"), + Self::DepthToSpace => Some("DepthToSpace"), _ => None, } } @@ -1111,13 +1115,13 @@ pub const ENUM_MIN_OPERATOR_ATTRS: u8 = 0; since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] -pub const ENUM_MAX_OPERATOR_ATTRS: u8 = 42; +pub const ENUM_MAX_OPERATOR_ATTRS: u8 = 43; #[deprecated( since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] #[allow(non_camel_case_types)] -pub const ENUM_VALUES_OPERATOR_ATTRS: [OperatorAttrs; 43] = [ +pub const ENUM_VALUES_OPERATOR_ATTRS: [OperatorAttrs; 44] = [ OperatorAttrs::NONE, OperatorAttrs::ArgMaxAttrs, OperatorAttrs::AveragePoolAttrs, @@ -1161,6 +1165,7 @@ pub const ENUM_VALUES_OPERATOR_ATTRS: [OperatorAttrs; 43] = [ OperatorAttrs::PadAttrs, OperatorAttrs::DequantizeLinearAttrs, OperatorAttrs::QuantizeLinearAttrs, + OperatorAttrs::DepthToSpaceAttrs, ]; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] @@ -1211,9 +1216,10 @@ impl OperatorAttrs { pub const PadAttrs: Self = Self(40); pub const DequantizeLinearAttrs: Self = Self(41); pub const QuantizeLinearAttrs: Self = Self(42); + pub const DepthToSpaceAttrs: Self = Self(43); pub const ENUM_MIN: u8 = 0; - pub const ENUM_MAX: u8 = 42; + pub const ENUM_MAX: u8 = 43; pub const ENUM_VALUES: &'static [Self] = &[ Self::NONE, Self::ArgMaxAttrs, @@ -1258,6 +1264,7 @@ impl OperatorAttrs { Self::PadAttrs, Self::DequantizeLinearAttrs, Self::QuantizeLinearAttrs, + Self::DepthToSpaceAttrs, ]; /// Returns the variant's name or "" if unknown. pub fn variant_name(self) -> Option<&'static str> { @@ -1305,6 +1312,7 @@ impl OperatorAttrs { Self::PadAttrs => Some("PadAttrs"), Self::DequantizeLinearAttrs => Some("DequantizeLinearAttrs"), Self::QuantizeLinearAttrs => Some("QuantizeLinearAttrs"), + Self::DepthToSpaceAttrs => Some("DepthToSpaceAttrs"), _ => None, } } @@ -1363,6 +1371,96 @@ impl<'a> flatbuffers::Verifiable for OperatorAttrs { impl flatbuffers::SimpleToVerifyInSlice for OperatorAttrs {} pub struct OperatorAttrsUnionTableOffset {} +#[deprecated( + since = "2.0.0", + note = "Use associated constants instead. This will no longer be generated in 2021." +)] +pub const ENUM_MIN_DEPTH_TO_SPACE_MODE: u8 = 0; +#[deprecated( + since = "2.0.0", + note = "Use associated constants instead. This will no longer be generated in 2021." +)] +pub const ENUM_MAX_DEPTH_TO_SPACE_MODE: u8 = 1; +#[deprecated( + since = "2.0.0", + note = "Use associated constants instead. This will no longer be generated in 2021." +)] +#[allow(non_camel_case_types)] +pub const ENUM_VALUES_DEPTH_TO_SPACE_MODE: [DepthToSpaceMode; 2] = + [DepthToSpaceMode::DCR, DepthToSpaceMode::CRD]; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] +#[repr(transparent)] +pub struct DepthToSpaceMode(pub u8); +#[allow(non_upper_case_globals)] +impl DepthToSpaceMode { + pub const DCR: Self = Self(0); + pub const CRD: Self = Self(1); + + pub const ENUM_MIN: u8 = 0; + pub const ENUM_MAX: u8 = 1; + pub const ENUM_VALUES: &'static [Self] = &[Self::DCR, Self::CRD]; + /// Returns the variant's name or "" if unknown. + pub fn variant_name(self) -> Option<&'static str> { + match self { + Self::DCR => Some("DCR"), + Self::CRD => Some("CRD"), + _ => None, + } + } +} +impl core::fmt::Debug for DepthToSpaceMode { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + if let Some(name) = self.variant_name() { + f.write_str(name) + } else { + f.write_fmt(format_args!("", self.0)) + } + } +} +impl<'a> flatbuffers::Follow<'a> for DepthToSpaceMode { + type Inner = Self; + #[inline] + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + let b = flatbuffers::read_scalar_at::(buf, loc); + Self(b) + } +} + +impl flatbuffers::Push for DepthToSpaceMode { + type Output = DepthToSpaceMode; + #[inline] + unsafe fn push(&self, dst: &mut [u8], _written_len: usize) { + flatbuffers::emplace_scalar::(dst, self.0); + } +} + +impl flatbuffers::EndianScalar for DepthToSpaceMode { + type Scalar = u8; + #[inline] + fn to_little_endian(self) -> u8 { + self.0.to_le() + } + #[inline] + #[allow(clippy::wrong_self_convention)] + fn from_little_endian(v: u8) -> Self { + let b = u8::from_le(v); + Self(b) + } +} + +impl<'a> flatbuffers::Verifiable for DepthToSpaceMode { + #[inline] + fn run_verifier( + v: &mut flatbuffers::Verifier, + pos: usize, + ) -> Result<(), flatbuffers::InvalidFlatbuffer> { + use self::flatbuffers::Verifiable; + u8::run_verifier(v, pos) + } +} + +impl flatbuffers::SimpleToVerifyInSlice for DepthToSpaceMode {} #[deprecated( since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." @@ -2710,6 +2808,137 @@ impl core::fmt::Debug for ConcatAttrs<'_> { ds.finish() } } +pub enum DepthToSpaceAttrsOffset {} +#[derive(Copy, Clone, PartialEq)] + +pub struct DepthToSpaceAttrs<'a> { + pub _tab: flatbuffers::Table<'a>, +} + +impl<'a> flatbuffers::Follow<'a> for DepthToSpaceAttrs<'a> { + type Inner = DepthToSpaceAttrs<'a>; + #[inline] + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + Self { + _tab: flatbuffers::Table::new(buf, loc), + } + } +} + +impl<'a> DepthToSpaceAttrs<'a> { + pub const VT_MODE: flatbuffers::VOffsetT = 4; + pub const VT_BLOCK_SIZE: flatbuffers::VOffsetT = 6; + + #[inline] + pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { + DepthToSpaceAttrs { _tab: table } + } + #[allow(unused_mut)] + pub fn create<'bldr: 'args, 'args: 'mut_bldr, 'mut_bldr, A: flatbuffers::Allocator + 'bldr>( + _fbb: &'mut_bldr mut flatbuffers::FlatBufferBuilder<'bldr, A>, + args: &'args DepthToSpaceAttrsArgs, + ) -> flatbuffers::WIPOffset> { + let mut builder = DepthToSpaceAttrsBuilder::new(_fbb); + builder.add_block_size(args.block_size); + builder.add_mode(args.mode); + builder.finish() + } + + #[inline] + pub fn mode(&self) -> DepthToSpaceMode { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(DepthToSpaceAttrs::VT_MODE, Some(DepthToSpaceMode::DCR)) + .unwrap() + } + } + #[inline] + pub fn block_size(&self) -> u32 { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::(DepthToSpaceAttrs::VT_BLOCK_SIZE, Some(0)) + .unwrap() + } + } +} + +impl flatbuffers::Verifiable for DepthToSpaceAttrs<'_> { + #[inline] + fn run_verifier( + v: &mut flatbuffers::Verifier, + pos: usize, + ) -> Result<(), flatbuffers::InvalidFlatbuffer> { + use self::flatbuffers::Verifiable; + v.visit_table(pos)? + .visit_field::("mode", Self::VT_MODE, false)? + .visit_field::("block_size", Self::VT_BLOCK_SIZE, false)? + .finish(); + Ok(()) + } +} +pub struct DepthToSpaceAttrsArgs { + pub mode: DepthToSpaceMode, + pub block_size: u32, +} +impl<'a> Default for DepthToSpaceAttrsArgs { + #[inline] + fn default() -> Self { + DepthToSpaceAttrsArgs { + mode: DepthToSpaceMode::DCR, + block_size: 0, + } + } +} + +pub struct DepthToSpaceAttrsBuilder<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> { + fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, + start_: flatbuffers::WIPOffset, +} +impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> DepthToSpaceAttrsBuilder<'a, 'b, A> { + #[inline] + pub fn add_mode(&mut self, mode: DepthToSpaceMode) { + self.fbb_.push_slot::( + DepthToSpaceAttrs::VT_MODE, + mode, + DepthToSpaceMode::DCR, + ); + } + #[inline] + pub fn add_block_size(&mut self, block_size: u32) { + self.fbb_ + .push_slot::(DepthToSpaceAttrs::VT_BLOCK_SIZE, block_size, 0); + } + #[inline] + pub fn new( + _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>, + ) -> DepthToSpaceAttrsBuilder<'a, 'b, A> { + let start = _fbb.start_table(); + DepthToSpaceAttrsBuilder { + fbb_: _fbb, + start_: start, + } + } + #[inline] + pub fn finish(self) -> flatbuffers::WIPOffset> { + let o = self.fbb_.end_table(self.start_); + flatbuffers::WIPOffset::new(o.value()) + } +} + +impl core::fmt::Debug for DepthToSpaceAttrs<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mut ds = f.debug_struct("DepthToSpaceAttrs"); + ds.field("mode", &self.mode()); + ds.field("block_size", &self.block_size()); + ds.finish() + } +} pub enum IntScalarOffset {} #[derive(Copy, Clone, PartialEq)] @@ -8478,6 +8707,21 @@ impl<'a> OperatorNode<'a> { None } } + + #[inline] + #[allow(non_snake_case)] + pub fn attrs_as_depth_to_space_attrs(&self) -> Option> { + if self.attrs_type() == OperatorAttrs::DepthToSpaceAttrs { + self.attrs().map(|t| { + // Safety: + // Created from a valid Table for this object + // Which contains a valid union in this slot + unsafe { DepthToSpaceAttrs::init_from_table(t) } + }) + } else { + None + } + } } impl flatbuffers::Verifiable for OperatorNode<'_> { @@ -8533,6 +8777,7 @@ impl flatbuffers::Verifiable for OperatorNode<'_> { OperatorAttrs::PadAttrs => v.verify_union_variant::>("OperatorAttrs::PadAttrs", pos), OperatorAttrs::DequantizeLinearAttrs => v.verify_union_variant::>("OperatorAttrs::DequantizeLinearAttrs", pos), OperatorAttrs::QuantizeLinearAttrs => v.verify_union_variant::>("OperatorAttrs::QuantizeLinearAttrs", pos), + OperatorAttrs::DepthToSpaceAttrs => v.verify_union_variant::>("OperatorAttrs::DepthToSpaceAttrs", pos), _ => Ok(()), } })? @@ -9038,6 +9283,16 @@ impl core::fmt::Debug for OperatorNode<'_> { ) } } + OperatorAttrs::DepthToSpaceAttrs => { + if let Some(x) = self.attrs_as_depth_to_space_attrs() { + ds.field("attrs", &x) + } else { + ds.field( + "attrs", + &"InvalidFlatbuffer: Union discriminant does not match value.", + ) + } + } _ => { let x: Option<()> = None; ds.field("attrs", &x)