Skip to content

Commit

Permalink
Support DepthToSpace operator
Browse files Browse the repository at this point in the history
In PyTorch this corresponds to the `nn.PixelShuffle` module.
  • Loading branch information
robertknight committed Dec 18, 2024
1 parent 31ed84d commit ba1a623
Show file tree
Hide file tree
Showing 9 changed files with 559 additions and 18 deletions.
5 changes: 5 additions & 0 deletions rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,11 @@ def op_node_from_onnx_operator(
attrs = sg.DequantizeLinearAttrsT()
attrs.axis = op_reader.get_attr("axis", "int", 1)

case "DepthToSpace":
attrs = sg.DepthToSpaceAttrsT()
attrs.blockSize = op_reader.require_attr("blocksize", "int")
attrs.mode = op_reader.get_enum_attr("mode", sg.DepthToSpaceMode, "dcr")

case "Einsum":
attrs = sg.EinsumAttrsT()
attrs.equation = op_reader.require_attr("equation", "string")
Expand Down
101 changes: 100 additions & 1 deletion rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class OperatorType(object):
QuantizeLinear = 106
DynamicQuantizeLinear = 107
MatMulInteger = 108
DepthToSpace = 109


class RNNDirection(object):
Expand Down Expand Up @@ -198,6 +199,7 @@ class OperatorAttrs(object):
PadAttrs = 40
DequantizeLinearAttrs = 41
QuantizeLinearAttrs = 42
DepthToSpaceAttrs = 43

def OperatorAttrsCreator(unionType, table):
from flatbuffers.table import Table
Expand Down Expand Up @@ -287,9 +289,16 @@ def OperatorAttrsCreator(unionType, table):
return DequantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().QuantizeLinearAttrs:
return QuantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().DepthToSpaceAttrs:
return DepthToSpaceAttrsT.InitFromBuf(table.Bytes, table.Pos)
return None


class DepthToSpaceMode(object):
DCR = 0
CRD = 1


class Scalar(object):
NONE = 0
IntScalar = 1
Expand Down Expand Up @@ -940,6 +949,96 @@ def Pack(self, builder):
return concatAttrs


class DepthToSpaceAttrs(object):
__slots__ = ['_tab']

@classmethod
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = DepthToSpaceAttrs()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsDepthToSpaceAttrs(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
@classmethod
def DepthToSpaceAttrsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed)

# DepthToSpaceAttrs
def Init(self, buf, pos):
self._tab = flatbuffers.table.Table(buf, pos)

# DepthToSpaceAttrs
def Mode(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
return 0

# DepthToSpaceAttrs
def BlockSize(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
return 0

def DepthToSpaceAttrsStart(builder):
builder.StartObject(2)

def DepthToSpaceAttrsAddMode(builder, mode):
builder.PrependUint8Slot(0, mode, 0)

def DepthToSpaceAttrsAddBlockSize(builder, blockSize):
builder.PrependUint32Slot(1, blockSize, 0)

def DepthToSpaceAttrsEnd(builder):
return builder.EndObject()



class DepthToSpaceAttrsT(object):

# DepthToSpaceAttrsT
def __init__(self):
self.mode = 0 # type: int
self.blockSize = 0 # type: int

@classmethod
def InitFromBuf(cls, buf, pos):
depthToSpaceAttrs = DepthToSpaceAttrs()
depthToSpaceAttrs.Init(buf, pos)
return cls.InitFromObj(depthToSpaceAttrs)

@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)

@classmethod
def InitFromObj(cls, depthToSpaceAttrs):
x = DepthToSpaceAttrsT()
x._UnPack(depthToSpaceAttrs)
return x

# DepthToSpaceAttrsT
def _UnPack(self, depthToSpaceAttrs):
if depthToSpaceAttrs is None:
return
self.mode = depthToSpaceAttrs.Mode()
self.blockSize = depthToSpaceAttrs.BlockSize()

# DepthToSpaceAttrsT
def Pack(self, builder):
DepthToSpaceAttrsStart(builder)
DepthToSpaceAttrsAddMode(builder, self.mode)
DepthToSpaceAttrsAddBlockSize(builder, self.blockSize)
depthToSpaceAttrs = DepthToSpaceAttrsEnd(builder)
return depthToSpaceAttrs


class IntScalar(object):
__slots__ = ['_tab']

Expand Down Expand Up @@ -5053,7 +5152,7 @@ class OperatorNodeT(object):
def __init__(self):
self.type = 0 # type: int
self.attrsType = 0 # type: int
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT]
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT, DepthToSpaceAttrsT]
self.inputs = None # type: List[int]
self.outputs = None # type: List[int]

Expand Down
7 changes: 6 additions & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,8 @@ mod tests {
};
use crate::ops;
use crate::ops::{
BoxOrder, CoordTransformMode, DataType, NearestMode, OpError, Output, ResizeMode, Scalar,
BoxOrder, CoordTransformMode, DataType, DepthToSpaceMode, NearestMode, OpError, Output,
ResizeMode, Scalar,
};
use crate::{ModelLoadError, OpRegistry, ReadOpError};

Expand Down Expand Up @@ -1307,6 +1308,10 @@ mod tests {
add_operator!(DequantizeLinear, [const_u8, scale, zero_point], {
axis: 0,
});
add_operator!(DepthToSpace, [input_node], {
mode: DepthToSpaceMode::DepthColumnRow,
block_size: 1,
});
add_operator!(QuantizeLinear, [const_f32, scale, zero_point], {
axis: 0,
output_dtype: None,
Expand Down
24 changes: 18 additions & 6 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ use crate::header::Header;
use crate::number::LeBytes;
use crate::ops::{
ArgMax, ArgMin, AveragePool, BatchNormalization, BoxOrder, Cast, Concat, ConstantOfShape, Conv,
ConvTranspose, CoordTransformMode, DataType, DequantizeLinear, Einsum, Elu, Flatten, Gather,
GatherElements, GatherND, Gelu, Gemm, HardSigmoid, InstanceNormalization, LayerNormalization,
LeakyRelu, LogSoftmax, MaxPool, Mod, NearestMode, NonMaxSuppression, OneHot, Padding,
QuantizeLinear, ReduceMax, ReduceMean, ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare,
Reshape, Resize, ResizeMode, Scalar, ScatterElements, ScatterReduction, Softmax, Split, TopK,
Transpose, Trilu,
ConvTranspose, CoordTransformMode, DataType, DepthToSpace, DepthToSpaceMode, DequantizeLinear,
Einsum, Elu, Flatten, Gather, GatherElements, GatherND, Gelu, Gemm, HardSigmoid,
InstanceNormalization, LayerNormalization, LeakyRelu, LogSoftmax, MaxPool, Mod, NearestMode,
NonMaxSuppression, OneHot, Padding, QuantizeLinear, ReduceMax, ReduceMean, ReduceMin,
ReduceProd, ReduceSum, ReduceSumSquare, Reshape, Resize, ResizeMode, Scalar, ScatterElements,
ScatterReduction, Softmax, Split, TopK, Transpose, Trilu,
};
use crate::schema_generated as sg;

Expand Down Expand Up @@ -47,6 +47,7 @@ pub enum OpType<'a> {
ConvTranspose(ConvTranspose),
Cos,
DequantizeLinear(DequantizeLinear),
DepthToSpace(DepthToSpace),
Div,
DynamicQuantizeLinear,
Einsum(Einsum),
Expand Down Expand Up @@ -514,6 +515,17 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
axis: args.axis as i32,
}
),
OpType::DepthToSpace(args) => op_with_attrs!(
DepthToSpace,
DepthToSpaceAttrs,
sg::DepthToSpaceAttrsArgs {
block_size: args.block_size,
mode: match args.mode {
DepthToSpaceMode::DepthColumnRow => sg::DepthToSpaceMode::DCR,
DepthToSpaceMode::ColumnRowDepth => sg::DepthToSpaceMode::CRD,
}
}
),
OpType::Div => op!(Div),
OpType::DynamicQuantizeLinear => op!(DynamicQuantizeLinear),
OpType::Einsum(args) => {
Expand Down
18 changes: 16 additions & 2 deletions src/op_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use smallvec::smallvec;
use crate::graph::Graph;
use crate::ops;
use crate::ops::{
BoxOrder, CoordTransformMode, DataType, Direction, NearestMode, Operator, PadMode, Padding,
ResizeMode, Scalar, ScatterReduction,
BoxOrder, CoordTransformMode, DataType, DepthToSpaceMode, Direction, NearestMode, Operator,
PadMode, Padding, ResizeMode, Scalar, ScatterReduction,
};
use crate::schema_generated as sg;
use crate::schema_generated::{AutoPad, OperatorNode, OperatorType};
Expand Down Expand Up @@ -101,6 +101,7 @@ impl OpRegistry {
register_op!(Cos);
register_op!(CumSum);
register_op!(DequantizeLinear);
register_op!(DepthToSpace);
register_op!(Div);
register_op!(DynamicQuantizeLinear);
register_op!(Einsum);
Expand Down Expand Up @@ -477,6 +478,19 @@ impl_read_op!(
impl_read_op!(Cos);
impl_read_op!(CumSum);
impl_read_op!(DequantizeLinear, attrs_as_dequantize_linear_attrs, axis);
impl_read_op!(
DepthToSpace,
attrs_as_depth_to_space_attrs,
|attrs: sg::DepthToSpaceAttrs| {
let mode = match attrs.mode() {
sg::DepthToSpaceMode::DCR => DepthToSpaceMode::DepthColumnRow,
sg::DepthToSpaceMode::CRD => DepthToSpaceMode::ColumnRowDepth,
_ => return Err(ReadOpError::AttrError)?,
};
let block_size = attrs.block_size();
Ok(ops::DepthToSpace { mode, block_size })
}
);
impl_read_op!(Div);
impl_read_op!(DynamicQuantizeLinear);
impl_read_op!(Einsum, attrs_as_einsum_attrs, |attrs: sg::EinsumAttrs| {
Expand Down
Loading

0 comments on commit ba1a623

Please sign in to comment.