diff --git a/tests/test_backend.py b/tests/test_backend.py index 4a0b1199e..76948ea79 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2043,5 +2043,30 @@ def test_softsign(self): _ = tf.identity(x_, name=_TFOUTPUT) self._run_test_case([_OUTPUT], {_INPUT: x_val}) + def test_batch_to_spacend(self): + block_size = [2, 2] + crop = [[0, 1], [2, 1]] + + input_val = np.random.random_sample([40, 3, 5, 100]).astype(np.float32) + input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHWC + _ = tf.batch_to_space_nd(input_x, block_size, crop, name=_TFOUTPUT) + self._run_test_case([_OUTPUT], {_INPUT: input_val}) + + def test_space_to_batchnd(self): + block_size = [2, 2] + pad = [[0, 1], [2, 1]] + input_val = np.random.random_sample([40, 5, 7, 66]).astype(np.float32) + input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHWC + _ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT) + self._run_test_case([_OUTPUT], {_INPUT: input_val}) + + tf.reset_default_graph() + + pad = [[0, 0], [1, 2]] + input_val = np.random.random_sample([10, 6, 7, 66]).astype(np.float32) + input_x = tf.placeholder(dtype=tf.float32, shape=input_val.shape, name=_TFINPUT) # NHWC + _ = tf.space_to_batch_nd(input_x, block_size, pad, name=_TFOUTPUT) + self._run_test_case([_OUTPUT], {_INPUT: input_val}) + if __name__ == '__main__': unittest_main() diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 90424cb8a..ae4b27124 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -793,3 +793,77 @@ class IsNan: @classmethod def version_9(cls, ctx, node, **kwargs): pass + + +@tf_op("BatchToSpaceND", onnx_op="DepthToSpace") +class BatchToSpace: + @classmethod + def version_4(cls, ctx, node, **kwargs): + # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d.html + # the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape) + # and we only support 4D here, so the data format is NHWC + # onnx op "DepthToSpace" does the same work on input tensor except that it works on "C", + # and it only supports NCHW + # T out = BatchToSpaceND(T input, int32 block_shape, int32 crops) + input_tensor = node.inputs[0] + blocksize = node.inputs[1].get_tensor_value() + crops = node.inputs[2].get_tensor_value() + + utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now") + utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1], + "only support same blocksize at different dims") + + ctx.remove_node(node.name) + # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow + trans1 = ctx.make_node("Transpose", input_tensor.output, {"perm": [3, 0, 1, 2]}) + reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]}) + trans2 = ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}) + + # implement crop logic, the data format is NHWC + slice_axis = [1, 2] + top, bottom = crops[0] + left, right = crops[1] + starts = [top, left] + ends = [] + for end in [bottom, right]: + if end != 0: + ends.append(-end) + else: + ends.append(np.iinfo(np.int32).max) + + ctx.make_node("Slice", trans2.output, attr={"axes": slice_axis, "ends": ends, "starts": starts}, + name=node.name, outputs=node.output) + + +@tf_op("SpaceToBatchND", onnx_op="SpaceToDepth") +class SpaceToBatch: + @classmethod + def version_4(cls, ctx, node, **kwargs): + # https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd + # the above link says the data format of input tensor should be (batch, spatial_shape, remaining_shape) + # and we only support 4D here, so the data format is NHWC + # onnx op "SpaceToDepth" does the same work on input tensor except that it works on "C", + # and it only supports NCHW + # T out = SpaceToBatchND(T input, int32 block_shape, int32 crops) + input_tensor = node.inputs[0] + blocksize = node.inputs[1].get_tensor_value() + paddings = node.inputs[2].get_tensor_value() + + utils.make_sure(len(ctx.get_shape(input_tensor.output[0])) == 4, "only supports 4D for now") + utils.make_sure(len(blocksize) == 2 and blocksize[0] == blocksize[1], + "only support same blocksize at different dims") + + ctx.remove_node(node.name) + + # implement pads logic, the data format is NHWC + top, bottom = paddings[0] + left, right = paddings[1] + pads = [0, top, left, 0, + 0, bottom, right, 0] + + pad_op = ctx.make_node("Pad", input_tensor.output, attr={"pads": pads}) + + # NHWC TO CNHW, so onnx op will work on "N" which is the same as tensorflow + trans1 = ctx.make_node("Transpose", pad_op.output, {"perm": [3, 0, 1, 2]}) + reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]}) + ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output)