Skip to content

Commit

Permalink
Merge pull request #472 from mindest/batch_space
Browse files Browse the repository at this point in the history
implement onnx ops for SpaceToBatchND and BatchToSpaceND
  • Loading branch information
nbcsm authored Apr 23, 2019
2 parents 27f5e67 + 96ff4fb commit 7538f59
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2043,5 +2043,30 @@ def test_softsign(self):
_ = tf.identity(x_, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: x_val})

def test_batch_to_spacend(self):
block_size = [2, 2]
crop = [[0, 1], [2, 1]]

input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32)
input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHWC
_ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: input_val})

def test_space_to_batchnd(self):
block_size = [2, 2]
pad = [[0, 1], [2, 1]]
input_val = np.random.random_sample([40, 5, 7, 66]).astype(np.float32)
input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHWC
_ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: input_val})

tf.reset_default_graph()

pad = [[0, 0], [1, 2]]
input_val = np.random.random_sample([10, 6, 7, 66]).astype(np.float32)
input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHWC
_ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT)
self._run_test_case([_OUTPUT], {_INPUT: input_val})

# Allow this test module to be executed directly as a script.
if __name__ == '__main__':
    unittest_main()
74 changes: 74 additions & 0 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,3 +793,77 @@ class IsNan:
    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        # No graph rewrite needed for opset 9 — presumably the TF op maps
        # one-to-one onto its ONNX counterpart via the class's @tf_op
        # decorator (inputs/attributes already line up). TODO(review):
        # confirm against the decorator declaration above this view.
        pass


@tf_op("BatchToSpaceND", onnx_op="DepthToSpace")
class BatchToSpace:
    @classmethod
    def version_4(cls, ctx, node, **kwargs):
        """Map TF BatchToSpaceND (4D NHWC only) onto ONNX DepthToSpace.

        T out = BatchToSpaceND(T input, int32 block_shape, int32 crops)

        https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html
        says the input layout is (batch, spatial_shape, remaining_shape); only
        the 4D case is handled here, i.e. NHWC. ONNX "DepthToSpace" performs
        the same redistribution but on "C" of an NCHW tensor, so the data is
        transposed to CNHW first to make it operate on "N", then transposed
        back, and finally the crops are applied with a Slice.
        """
        data = node.inputs[0]
        block_shape = node.inputs[1].get_tensor_value()
        crops = node.inputs[2].get_tensor_value()

        utils.make_sure(len(ctx.get_shape(data.output[0])) == 4, "only supports 4D for now")
        utils.make_sure(len(block_shape) == 2 and block_shape[0] == block_shape[1],
                        "only support same blocksize at different dims")

        ctx.remove_node(node.name)
        # NHWC -> CNHW so DepthToSpace works on "N", matching TF semantics.
        to_cnhw = ctx.make_node("Transpose", data.output, {"perm": [3, 0, 1, 2]})
        spaced = ctx.make_node(node.type, to_cnhw.output, attr={"blocksize": block_shape[0]})
        to_nhwc = ctx.make_node("Transpose", spaced.output, {"perm": [1, 2, 3, 0]})

        # Crop H and W (axes 1, 2 in NHWC). A zero end-crop is encoded as
        # INT32_MAX, because an "ends" value of 0 would select an empty slice.
        top, bottom = crops[0]
        left, right = crops[1]
        ends = [-e if e != 0 else np.iinfo(np.int32).max for e in (bottom, right)]
        ctx.make_node("Slice", to_nhwc.output,
                      attr={"axes": [1, 2], "ends": ends, "starts": [top, left]},
                      name=node.name, outputs=node.output)


@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth")
class SpaceToBatch:
    @classmethod
    def version_4(cls, ctx, node, **kwargs):
        """Map TF SpaceToBatchND (4D NHWC only) onto ONNX SpaceToDepth.

        T out = SpaceToBatchND(T input, int32 block_shape, int32 crops)

        https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd
        says the input layout is (batch, spatial_shape, remaining_shape); only
        the 4D case is handled here, i.e. NHWC. ONNX "SpaceToDepth" performs
        the same redistribution but on "C" of an NCHW tensor, so the padded
        data is transposed to CNHW first to make it operate on "N", then
        transposed back.
        """
        data = node.inputs[0]
        block_shape = node.inputs[1].get_tensor_value()
        paddings = node.inputs[2].get_tensor_value()

        utils.make_sure(len(ctx.get_shape(data.output[0])) == 4, "only supports 4D for now")
        utils.make_sure(len(block_shape) == 2 and block_shape[0] == block_shape[1],
                        "only support same blocksize at different dims")

        ctx.remove_node(node.name)

        # Pad H and W first (NHWC layout). ONNX "pads" is
        # [x1_begin, ..., x4_begin, x1_end, ..., x4_end].
        (top, bottom), (left, right) = paddings
        pads = [0, top, left, 0,
                0, bottom, right, 0]
        padded = ctx.make_node("Pad", data.output, attr={"pads": pads})

        # NHWC -> CNHW so SpaceToDepth works on "N", matching TF semantics.
        to_cnhw = ctx.make_node("Transpose", padded.output, {"perm": [3, 0, 1, 2]})
        spaced = ctx.make_node(node.type, to_cnhw.output, attr={"blocksize": block_shape[0]})
        ctx.make_node("Transpose", spaced.output, {"perm": [1, 2, 3, 0]},
                      name=node.name, outputs=node.output)

0 comments on commit 7538f59

Please sign in to comment.