From 6ba111043f815c5cbbbe82c3407b90a602fb30dc Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 7 Nov 2022 16:01:33 -0500 Subject: [PATCH] [Relax][OP] More high-level operators (#18) * relax.cumsum * Legalizer for expand_dims * relax.trilu * relax.cast * Legalizer for batch_norm and flatten * relax.take * relax.full * relax.split * relax.broadcast_to * relax.strided_slice * relax.image.resize2d * relax.nn.max_pool2d * relax.nn.adaptive_avg_pool2d --- include/tvm/relax/op_attr_types.h | 165 +++++ python/tvm/relax/block_builder.py | 4 +- python/tvm/relax/frontend/pytorch_fx.py | 94 +-- python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/image.py | 128 ++++ python/tvm/relax/op/nn/nn.py | 71 +- python/tvm/relax/op/op_attrs.py | 45 ++ python/tvm/relax/op/transform.py | 327 +++++++++ python/tvm/relax/transform/op_legalizer.py | 134 +++- python/tvm/script/ir_builder/relax/ir.py | 22 + python/tvm/script/ir_builder/tir/ir.py | 2 + src/relax/op/image/resize.cc | 148 ++++ src/relax/op/image/resize.h | 40 + src/relax/op/nn/pooling.cc | 91 +++ src/relax/op/nn/pooling.h | 4 + src/relax/op/tensor/transform.cc | 692 ++++++++++++++++++ src/relax/op/tensor/transform.h | 40 + tests/python/relax/test_op_legalizer.py | 557 ++++++++++++-- tests/python/relax/test_relax_image_ops.py | 45 ++ tests/python/relax/test_relax_tensor_ops.py | 16 + .../python/relax/test_relax_transform_ops.py | 214 ++++++ 21 files changed, 2728 insertions(+), 112 deletions(-) create mode 100644 python/tvm/relax/op/image.py create mode 100644 src/relax/op/image/resize.cc create mode 100644 src/relax/op/image/resize.h create mode 100644 tests/python/relax/test_relax_image_ops.py diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 7f4c13399640..58186a64af16 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -403,6 +403,171 @@ struct ReduceAttrs : public tvm::AttrsNode { } }; // struct ReduceAttrs +/*! \brief Attributes used in cumsum operator */ +struct CumsumAttrs : public tvm::AttrsNode { + Optional axis; + + TVM_DECLARE_ATTRS(CumsumAttrs, "relax.attrs.CumsumAttrs") { + TVM_ATTR_FIELD(axis).set_default(Optional{NullOpt}); + } +}; // struct CumsumAttrs + +/*! \brief Attributes used in trilu operator */ +struct TriluAttrs : public tvm::AttrsNode { + int k; + bool is_upper; + + TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") { + TVM_ATTR_FIELD(k).describe( + "The number of diagonals above or below the main diagonal to exclude or include."); + TVM_ATTR_FIELD(is_upper).set_default(true).describe( + "Whether to keep the upper or lower half of the diagonal."); + } +}; // struct TriluAttrs + +/*! \brief Attributes used in cast operator */ +struct CastAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(CastAttrs, "relax.attrs.CastAttrs") { + TVM_ATTR_FIELD(dtype).describe("Target data type"); + } +}; // struct CastAttrs. + +/*! \brief Attributes used in take operator */ +struct TakeAttrs : public tvm::AttrsNode { + Optional axis; + int batch_dims; + String mode; + + TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") { + TVM_ATTR_FIELD(axis) + .set_default(Optional{NullOpt}) + .describe("The axis over which to select values."); + TVM_ATTR_FIELD(batch_dims) + .set_default(0) + .describe("The batch_dims over which to select values."); + TVM_ATTR_FIELD(mode).set_default("clip").describe( + "Specify how out-of-bound indices will behave." 
+ "clip - clip to the range (default)" + "wrap - wrap around the indices" + "fast - no clip or wrap around (user must make sure indices are in-bound)"); + } +}; // struct TakeAttrs + +/*! \brief Attributes used in full operator */ +struct FullAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(FullAttrs, "relax.attrs.FullAttrs") { + TVM_ATTR_FIELD(dtype).describe("Target data type."); + } +}; // struct FullAttrs + +/*! \brief Attributes used in split operator */ +struct SplitAttrs : public tvm::AttrsNode { + ObjectRef indices_or_sections; + int axis; + + TVM_DECLARE_ATTRS(SplitAttrs, "relax.attrs.SplitAttrs") { + TVM_ATTR_FIELD(indices_or_sections) + .describe("The input array of indices or the number of split sections."); + TVM_ATTR_FIELD(axis).describe("The axis to be splitted"); + } +}; // struct SplitAttrs + +/*! \brief Attributes used in strided_slice operator */ +struct StridedSliceAttrs : public tvm::AttrsNode { + Array begin; + Array end; + Optional> strides; + Optional> axes; + String slice_mode; + + TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") { + TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); + TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive"); + TVM_ATTR_FIELD(strides).describe( + "Stride values of the slice, a stride can be negative, which causes a reverse slice."); + TVM_ATTR_FIELD(axes).describe( + "Axes along which slicing is applied. When it is specified, the length of begin, end, " + "strides, and axes must be equal."); + TVM_ATTR_FIELD(slice_mode) + .set_default("end") + .describe( + "The slice mode [end, size]." + "end - The default slice mode, ending indices for the slice." + "size - The input strides will be ignored, input end in this mode indicates the size" + "of a slice starting at the location specified by begin. If end[i] is -1," + "all remaining elements in that dimension are included in the slice"); + } +}; // struct StridedSliceAttrs + +/*! \brief Attributes used in image resize2d operator */ +struct Resize2DAttrs : public tvm::AttrsNode { + Array size; + Array roi; + String layout; + String method; + String coordinate_transformation_mode; + String rounding_method; + double cubic_alpha; + int cubic_exclude; + double extrapolation_value; + + TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") { + TVM_ATTR_FIELD(size).describe("Output image size."); + TVM_ATTR_FIELD(roi).describe( + "Region of Interest for coordinate transformation mode 'tf_crop_and_resize'"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method).set_default("linear").describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "linear - Bilinear Interpolation" + "cubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." 
+ "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(rounding_method) + .set_default("round") + .describe( + "indicates how to find the \"nearest\" pixel in nearest_neighbor method" + "Available options are round, floor, and ceil."); + TVM_ATTR_FIELD(cubic_alpha) + .set_default(-0.5) + .describe("Spline Coefficient for Bicubic Interpolation"); + TVM_ATTR_FIELD(cubic_exclude) + .set_default(0) + .describe("Flag to exclude exterior of the image during bicubic interpolation"); + TVM_ATTR_FIELD(extrapolation_value) + .set_default(0.0) + .describe("Value to return when roi is outside of the image"); + } +}; // struct Resize2dAttrs + +/*! \brief Attributes for 2d adaptive pool operator */ +struct AdaptivePool2DAttrs : public tvm::AttrsNode { + Optional> output_size; + String layout; + + TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relax.attrs.AdaptivePool2DAttrs") { + TVM_ATTR_FIELD(output_size).describe("Output height and width."); + TVM_ATTR_FIELD(layout).describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + } +}; // struct AdaptivePool2DAttrs + } // namespace relax } // namespace tvm #endif // TVM_RELAX_OP_ATTR_TYPES_H_ diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index bbc1f518ef67..85e9e91b2fb2 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -213,7 +213,9 @@ def _convert_te_arg_helper(arg): ), "emit_te only supports dict with string as the key currently" return {k: _convert_te_arg_helper(arg[k]) for k in arg} elif ( - isinstance(arg, (int, float, str, tir.IntImm, tvm.ir.Type, tvm.ir.Attrs)) + isinstance( + arg, (int, float, str, tir.IntImm, tir.FloatImm, tvm.ir.Type, tvm.ir.Attrs) + ) or arg is None ): return arg diff --git a/python/tvm/relax/frontend/pytorch_fx.py b/python/tvm/relax/frontend/pytorch_fx.py index fc7ac0f76f76..899455253c18 100644 --- a/python/tvm/relax/frontend/pytorch_fx.py +++ b/python/tvm/relax/frontend/pytorch_fx.py @@ -204,30 +204,29 @@ def _max_pool2d(self, node: fx.node.Node) -> relax.Var: padding = padding if isinstance(padding, tuple) else (padding, padding, padding, padding) dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) - return self.bb.emit_te( - topi.nn.pool2d, - x, - kernel, - stride, - dilation, - padding, - pool_type="max", + return self.bb.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel, + strides=stride, + padding=padding, + dilation=dilation, + layout="NCHW", + ) ) def _embedding(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] - x = self.bb.emit_te(topi.cast, x, "int32") - return self.bb.emit_te(topi.take, weight, x, axis=0) + x = self.bb.emit(relax.op.cast(x, "int32")) + return self.bb.emit(relax.op.take(weight, x, axis=0)) def _adaptive_avg_pool2d(self, node: fx.node.Node) -> relax.Var: module = self.named_modules[node.target] x = self.env[node.args[0]] - return self.bb.emit_te( - topi.nn.adaptive_pool, x, module.output_size, pool_type="avg", layout="NCHW" - ) + return self.bb.emit(relax.op.nn.adaptive_avg_pool2d(x, module.output_size, layout="NCHW")) def _flatten(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] @@ -272,7 +271,7 @@ def _sub(self, node: 
fx.node.Node) -> relax.Var: def _cumsum(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] axis = node.args[1] - return self.bb.emit_te(topi.cumsum, x, axis) + return self.bb.emit(relax.op.cumsum(x, axis)) def _size(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] @@ -283,7 +282,13 @@ def _size(self, node: fx.node.Node) -> relax.Var: idx = node.args[1] return x.shape[idx].value + def _type(self, node: fx.node.Node) -> relax.Var: + args = self.retrive_args(node) + return self.bb.emit(relax.op.cast(args[0], args[1])) + def _getattr(self, node: fx.node.Node) -> relax.Var: + if isinstance(self.env[node.args[0]], relax.Var) and node.args[1] == "dtype": + return self.env[node.args[0]].checked_type.dtype return getattr(self.env[node.args[0]], node.args[1]) def _getitem(self, node: fx.node.Node) -> relax.Var: @@ -293,8 +298,8 @@ def _getitem(self, node: fx.node.Node) -> relax.Var: elif isinstance(x, relax.Var): if isinstance(x.shape, relax.Tuple): return self.bb.emit(relax.TupleGetItem(x, node.args[1])) - else: - begin = [] + + begin = [] end = [] stride = [] axes = [] @@ -315,6 +320,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var: i = i + 1 elif index is None: expand_dim.append(i) + i = i + 1 else: raise ValueError("Unsupported index type: " + str(type(index))) while i < len(x.shape_): @@ -322,7 +328,7 @@ def _getitem(self, node: fx.node.Node) -> relax.Var: end.append(x.shape_[i]) axes.append(i) i = i + 1 - sliced = self.bb.emit_te(topi.strided_slice, x, begin, end, stride, axes) + sliced = self.bb.emit(relax.op.strided_slice(x, begin, end, stride, axes)) sliced_shape = list(sliced.shape_) for i in expand_dim: sliced_shape.insert(i, 1) @@ -408,21 +414,17 @@ def _interpolate(self, node: fx.node.Node) -> relax.Var: else: coord_trans = "half_pixel" - return self.bb.emit_te( - topi.image.resize2d, - data, - [0.0] * 4, - size, - layout="NCHW", - method=method, - coordinate_transformation_mode=coord_trans, + return self.bb.emit( + relax.op.image.resize2d( + data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) ) def _addmm(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] y = self.env[node.args[1]] z = self.env[node.args[2]] - matmul = self.bb.emit_te(topi.matmul, y, z) + matmul = self.bb.emit(relax.op.nn.matmul(y, z)) return self.bb.emit(relax.op.add(x, matmul)) def _split(self, node: fx.node.Node) -> relax.Var: @@ -432,13 +434,14 @@ def _split(self, node: fx.node.Node) -> relax.Var: dim = node.kwargs["dim"] else: dim = 0 - split_size = x.shape[dim].value // split_size - return self.bb.emit_te(topi.split, x, split_size, dim) + n_section = x.shape[dim].value // split_size + return self.bb.emit(relax.op.split(x, n_section, dim)) def _tril(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] k = node.args[1] if len(node.args) > 1 else 0 - return self.bb.emit_te(topi.trilu, x, tvm.tir.const(k, "int32"), False) + assert isinstance(k, int) + return self.bb.emit(relax.op.trilu(x, k, False)) def _new_ones(self, node: fx.node.Node) -> relax.Var: args = self.retrive_args(node) @@ -446,14 +449,18 @@ def _new_ones(self, node: fx.node.Node) -> relax.Var: size = args[1:] if not iterable(size): size = (size,) - return self.bb.emit_te(topi.full, size, fill_value=1, dtype=self_var.checked_type.dtype) + return self.bb.emit( + relax.op.full( + relax.const(1, self_var.checked_type.dtype), size, self_var.checked_type.dtype + ) + ) def _expand(self, node: fx.node.Node) -> relax.Var: args = self.retrive_args(node) - 
return self.bb.emit_te(topi.broadcast_to, args[0], args[1:]) + return self.bb.emit(relax.op.broadcast_to(args[0], args[1:])) def _float(self, node: fx.node.Node) -> relax.Var: - return self.bb.emit_te(topi.cast, self.env[node.args[0]], "float32") + return self.bb.emit(relax.op.cast(self.env[node.args[0]], "float32")) def _permute(self, node: fx.node.Node) -> relax.Var: args = self.retrive_args(node) @@ -491,17 +498,9 @@ def _softmax(self, node: fx.node.Node) -> relax.Var: def _view(self, node: fx.node.Node) -> relax.Var: args = self.retrive_args(node) - infer_idx = -1 - prod = 1 - new_shape = list(args[1:]) - for i in range(len(new_shape)): - if new_shape[i] == -1: - infer_idx = i - else: - prod *= new_shape[i] - if infer_idx != -1: - new_shape[infer_idx] = np.prod(args[0].shape).value // prod - return self.bb.emit(relax.op.reshape(args[0], new_shape)) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.bb.emit(relax.op.reshape(args[0], tuple(args[1]))) + return self.bb.emit(relax.op.reshape(args[0], args[1:])) def _silu(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] @@ -526,7 +525,7 @@ def _group_norm(self, node: fx.node.Node) -> relax.Var: sub_x = self.bb.emit(relax.op.subtract(grouped_x, mean_x)) square_x = self.bb.emit(relax.op.multiply(sub_x, sub_x)) sum_square_x = self.bb.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True)) - var_x = self.bb.emit_te(topi.divide, sum_square_x, C // num_groups * H * W) + var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value) var_x_eps = self._call_binary_op(relax.op.add, var_x, eps) std_x = self.bb.emit(relax.op.sqrt(var_x_eps)) norm_x = self.bb.emit(relax.op.divide(sub_x, std_x)) @@ -564,7 +563,7 @@ def create_convert_map(self): # call_module nn.Conv2d: self._conv2d, nn.Linear: self._linear, - nn.ReLU: lambda node: self.bb.emit_te(topi.nn.relu, self.env[node.args[0]]), + nn.ReLU: lambda node: self.bb.emit(relax.op.nn.relu(self.env[node.args[0]])), nn.MaxPool2d: self._max_pool2d, nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d, nn.Flatten: self._flatten, @@ -581,8 +580,8 @@ def create_convert_map(self): "flatten": self._flatten, "size": self._size, "cumsum": self._cumsum, - "unsqueeze": lambda node: self.bb.emit_te( - topi.expand_dims, self.env[node.args[0]], node.args[1], 1 + "unsqueeze": lambda node: self.bb.emit( + relax.op.expand_dims(self.env[node.args[0]], node.args[1]) ), "getattr": self._getattr, "getitem": self._getitem, @@ -607,6 +606,7 @@ def create_convert_map(self): "transpose": self._transpose, "softmax": self._softmax, "view": self._view, + "type": self._type, "contiguous": lambda node: self.env[node.args[0]], } diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 942ffc3e55f8..0c963b383918 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -19,6 +19,7 @@ # Operators from .base import * +from .image import * from .nn import * from .op_attrs import * from .reduce import * diff --git a/python/tvm/relax/op/image.py b/python/tvm/relax/op/image.py new file mode 100644 index 000000000000..3533ae71dbee --- /dev/null +++ b/python/tvm/relax/op/image.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Image operators.""" +from typing import List, Optional, Tuple, Union + +import tvm +from tvm import relax +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr + + +PrimExprLike = Union[int, PrimExpr] + + +def resize2d( + data: Expr, + size: Union[PrimExprLike, List[PrimExprLike], Tuple[PrimExprLike]], + roi: Optional[Union[float, List[float], Tuple[float]]] = None, + layout: str = "NCHW", + method: str = "linear", + coordinate_transformation_mode: str = "half_pixel", + rounding_method: str = "round", + cubic_alpha: float = -0.5, + cubic_exclude: int = 0, + extrapolation_value: float = 0.0, +) -> Expr: + """Image resize2d operator. + + This operator takes data as input and does 2D scaling to the given scale factor. + In the default case, where the data_layout is `NCHW` + with data of shape (n, c, h, w) + out will have a shape (n, c, size[0], size[1]) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("linear", "nearest_neighbor", "cubic") + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + size: Union[PrimExprLike, List[PrimExprLike], Tuple[PrimExprLike]] + The out size to which the image will be resized. + + roi: Optional[Union[float, List[float], Tuple[float]]] + The region of interest for cropping the input image. Expected to be of + size 4, and format [start_h, start_w, end_h, end_w]. + Only used if coordinate_transformation_mode is tf_crop_and_resize. + + layout : str + Layout of the input. + + method : str + Scale method to used [nearest_neighbor, linear, cubic]. + + coordinate_transformation_mode : str + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. Defintions can be found + in topi/image/resize.py. + [half_pixel, align_corners, asymmetric, pytorch_half_pixel, + tf_half_pixel_for_nn, and tf_crop_and_resize]. + + rounding_method: str + indicates how to find the "nearest" pixel in nearest_neighbor method + [round, floor, ceil] + + cubic_alpha: float + Spline Coefficient for bicubic interpolation + + cubic_exclude: int + Flag to exclude exterior of the image during bicubic interpolation + + extrapolation_value: float + Fill value to use when roi is outside of the image + + Returns + ------- + result: relay.Expr + The resized result. 
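+
+    Examples
+    --------
+    A minimal usage sketch. The input variable ``x`` and its shape are
+    assumptions for illustration, not part of this patch.
+
+    .. code-block:: python
+
+        from tvm import relax
+
+        # Assume x is a relax.Var holding a (1, 3, 32, 32) float32 tensor.
+        # Bilinearly resize the spatial dimensions to 64x64.
+        y = relax.op.image.resize2d(x, (64, 64), layout="NCHW", method="linear")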
+ """ + if roi is None: + roi = [0.0] * 4 + elif isinstance(roi, float): + roi = [roi] * 4 + + if isinstance(size, (PrimExpr, int)): + size = [size] + if isinstance(size, (tuple, list)): + temp_size = [] + for shape in size: + if isinstance(shape, PrimExpr): + temp_size.append(shape) + elif isinstance(shape, int): + temp_size.append(tvm.tir.const(shape, "int32")) + else: + raise RuntimeError( + f"The input new shape of reshape operator contains unrecognized dimension {shape}" + ) + size = temp_size + + return _ffi_api.resize2d( + data, + size, + roi, + layout, + method, + coordinate_transformation_mode, + rounding_method, + cubic_alpha, + cubic_exclude, + extrapolation_value, + ) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 4ec4a6e809a9..e08ec196649c 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -15,11 +15,16 @@ # specific language governing permissions and limitations # under the License. """Relax Neural Network (NN) operators""" -from typing import List, Union +from typing import List, Optional, Tuple, Union +import tvm +from tvm.ir.expr import PrimExpr from tvm.relay.op.nn.utils import get_pad_tuple2d -from ...expr import Expr + from . import _ffi_api +from ...expr import Expr, ShapeExpr + +PrimExprLike = Union[int, PrimExpr] def dense(data, weight, units=None, out_dtype=""): @@ -556,3 +561,65 @@ def matmul(a: Expr, b: Expr) -> Expr: The result of the matmul. """ return _ffi_api.matmul(a, b) + + +def adaptive_avg_pool2d( + data: Expr, + output_size: Optional[Union[PrimExprLike, Tuple[PrimExprLike], List[PrimExprLike]]] = None, + layout: str = "NCHW", +) -> Expr: + r"""2D adaptive average pooling operator. This operator is experimental. + + This operator takes data as input and does 2D average value calculation + across each window represented by WxH. + + + In the default case, where the data_layout is `NCHW` + a data Tensor with shape `(batch_size, in_channels, height, width)`, + to produce an output Tensor with shape + (batch_size, in_channels, output_height, output_width). + + The pooling kernel and stride sizes are automatically chosen for + desired output sizes. + + For output_size: + If this argument is not provided, input height and width will be used + as output height and width. + + If a single integer is provided for output_size, the output size is + (N x C x output_size x output_size) for any input (NCHW). + + If a tuple of integers (height, width) are provided for output_size, + the output size is (N x C x height x width) for any input (NCHW). + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + output_size : Optional[Union[PrimExprLike, Tuple[PrimExprLike], List[PrimExprLike]]] + Output height and width. + + layout : str + Layout of the input. + + Returns + ------- + result : relax.Expr + The computed result. 
+ """ + if output_size is not None: + if isinstance(output_size, (PrimExpr, int)): + output_size = [output_size] + temp_size = [] + for shape in output_size: + if isinstance(shape, PrimExpr): + temp_size.append(shape) + elif isinstance(shape, int): + temp_size.append(tvm.tir.const(shape, "int32")) + else: + raise RuntimeError( + f"The input new shape of reshape operator contains unrecognized dimension {shape}" + ) + output_size = temp_size + return _ffi_api.adaptive_avg_pool2d(data, output_size, layout) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 700dbd6a1c42..b1e3ef5585d9 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -117,3 +117,48 @@ class LayerNormAttrs(Attrs): @tvm._ffi.register_object("relax.attrs.ReduceAttrs") class ReduceAttrs(Attrs): """Attributes used in reduction operator""" + + +@tvm._ffi.register_object("relax.attrs.CumsumAttrs") +class CumsumAttrs(Attrs): + """Attributes used in cumsum operator""" + + +@tvm._ffi.register_object("relax.attrs.TriluAttrs") +class TriluAttrs(Attrs): + """Attributes used in trilu operator""" + + +@tvm._ffi.register_object("relax.attrs.CastAttrs") +class CastAttrs(Attrs): + """Attributes used in cast operator""" + + +@tvm._ffi.register_object("relax.attrs.TakeAttrs") +class TakeAttrs(Attrs): + """Attributes used in take operator""" + + +@tvm._ffi.register_object("relax.attrs.FullAttrs") +class FullAttrs(Attrs): + """Attributes used in full operator""" + + +@tvm._ffi.register_object("relax.attrs.SplitAttrs") +class SplitAttrs(Attrs): + """Attributes used in split operator""" + + +@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs") +class StridedSliceAttrs(Attrs): + """Attributes used in strided_slice operator""" + + +@tvm._ffi.register_object("relax.attrs.Resize2DAttrs") +class Resize2DAttrs(Attrs): + """Attributes used in image resize2d operator""" + + +@tvm._ffi.register_object("relax.attrs.AdaptivePool2DAttrs") +class AdaptivePool2DAttrs(Attrs): + """Attributes for 2d adaptive pool operator""" diff --git a/python/tvm/relax/op/transform.py b/python/tvm/relax/op/transform.py index 0be48f1390bf..5aa45a22e8fc 100644 --- a/python/tvm/relax/op/transform.py +++ b/python/tvm/relax/op/transform.py @@ -159,3 +159,330 @@ def concatenate(data: Union[Expr, List[Expr], Tuple[Expr]], axis: Optional[int] if isinstance(data, (list, tuple)): data = relax.Tuple(data) return _ffi_api.concatenate(data, axis) + + +def cumsum(data: Expr, axis: Optional[int] = None) -> Expr: + """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along + a given axis. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + axis : Optional[int] + Axis along which the cumulative sum is computed. The default (None) is to compute + the cumsum over the flattened array. + + Returns + ------- + result : relax.Expr + The result has the same size as data, and the same shape as data if axis is not None. + If axis is None, the result is a 1-d array. + + Examples + -------- + .. code-block:: python + + a = [[1,2,3], [4,5,6]] + + cumsum(a) # if axis is not provided, cumsum is done over the flattened input. 
+ -> [ 1, 3, 6, 10, 15, 21] + + cumsum(a, axis=0) # sum over rows for each of the 3 columns + -> [[1, 2, 3], + [5, 7, 9]] + + cumsum(a, axis=1) + -> [[ 1, 3, 6], + [ 4, 9, 15]] + """ + return _ffi_api.cumsum(data, axis) + + +def trilu(data: Expr, k: int = 0, is_upper: bool = True) -> Expr: + """ + Given a 2-D matrix or batches of 2-D matrices, returns the + upper or lower triangular part of the tensor. + + Parameters + ---------- + data: relax.Expr + The tensor that trilu will be applied to. Must be either + a 2D matrix or a tensor of batches of 2D matrices. + + k: int + The number of diagonals above or below the main diagonal + to exclude or include. + + is_upper: bool + If True, only upper triangular values of input are kept, + if False, the lower triangular values are kept. + + Returns + ------- + ret : relax.Expr + The new tensor with appropriate diagonals set to zero. + + Examples + -------- + .. code-block:: python + + x = [[0, 1, 2], + [3, 4, 5], + [6, 7, 8]] + + relay.trilu(x, 0, True) = + [[0, 1, 2], + [0, 4, 5], + [0, 0, 8]] + """ + return _ffi_api.trilu(data, k, is_upper) + + +def cast(data: Expr, dtype: Union[str, tvm.DataType]) -> Expr: + """Cast input tensor to data type. + + Parameters + ---------- + data : relax.Expr + The input data to the operator. + + dtype: Union[str, tvm.DataType] + The target data type + + Returns + ------- + result : relax.Expr + The casted result. + """ + if isinstance(dtype, str): + dtype = tvm.DataType(dtype) + return _ffi_api.cast(data, dtype) + + +def take( + data: Expr, indices: Expr, axis: Optional[int] = None, batch_dims: int = 0, mode: str = "clip" +) -> Expr: + """Take elements from an array along an axis. + + Parameters + ---------- + data : relax.Expr + The source array. + + indices : relax.Expr + The indices of the values to extract. + + axis : Optional[int] + The axis over which to select values. By default, + the flattened input array is used. + + batch_dims : int + The number of batch dimensions. By default is 0. + + mode : str, optional + Specifies how out-of-bound indices will behave [clip, wrap, fast]. + clip: clip to the range (default). + wrap: wrap around the indices. + fast: no clip or wrap around (user must make sure indices are in-bound). + + Returns + ------- + ret : relax.Expr + The computed result. + """ + return _ffi_api.take(data, indices, axis, batch_dims, mode) + + +def full( + fill_value: Expr, + shape: Union[PrimExprLike, List[PrimExprLike], Tuple[PrimExprLike], Expr], + dtype: Optional[Union[str, tvm.DataType]], +) -> Expr: + """Fill array with scalar value. + + Parameters + ---------- + fill_value : relax.Expr + The value to fill. Must be a scalar. + + shape : Union[PrimExprLike, List[PrimExprLike], Tuple[PrimExprLike], Expr] + The shape of the target. + + dtype : Optional[str] + The data type of the target. + + Returns + ------- + result : relax.Expr + The resulting tensor. 
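+
+    Examples
+    --------
+    A minimal usage sketch. The shape and dtype below are assumptions for
+    illustration, not part of this patch.
+
+    .. code-block:: python
+
+        from tvm import relax
+
+        # Fill a (2, 3) float32 tensor with the scalar constant one.
+        ones = relax.op.full(relax.const(1, "float32"), (2, 3), "float32")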
+ """ + if isinstance(shape, (PrimExpr, int)): + shape = [shape] + if isinstance(shape, (tuple, list)): + temp_shape = [] + for shape in shape: + if isinstance(shape, PrimExpr): + temp_shape.append(shape) + elif isinstance(shape, int): + temp_shape.append(tvm.tir.const(shape, "int32")) + else: + raise RuntimeError( + f"The input new shape of reshape operator contains unrecognized dimension {shape}" + ) + shape = relax.ShapeExpr(temp_shape) + + if dtype is None: + dtype = tvm.DataType("void") + elif isinstance(dtype, str): + dtype = tvm.DataType(dtype) + return _ffi_api.full(fill_value, shape, dtype) + + +def split( + data: Expr, + indices_or_sections: Union[int, List[PrimExprLike], Tuple[PrimExprLike]], + axis: int = 0, +) -> Expr: + """Split input tensor along axis by sections or indices. + + If indices_or_sections is an integer, the input will be divided equally + along given axis. If such a split is not possible, an error is raised. + + If indices_or_sections is a tuple of mixture of int or PrimExpr, + the entries indicate the indices where along axis the array is split. + + Parameters + ---------- + data : relax.Expr + The source array. + + indices_or_sections : Union[int, Tuple[PrimExprLike]] + Indices or sections to split into. Accepts an int or a tuple + + axis : int + The axis over which to split. + + Returns + ------- + ret : relax.Expr + The computed result. + """ + if isinstance(indices_or_sections, (tuple, list)): + indices = [] + for idx in indices_or_sections: + if isinstance(idx, PrimExpr): + indices.append(idx) + elif isinstance(idx, int): + indices.append(tvm.tir.const(idx, "int32")) + else: + raise RuntimeError( + f'The input indices of split operator contains unrecognized index "{idx}"' + ) + indices_or_sections = indices + elif isinstance(indices_or_sections, int): + indices_or_sections = tvm.tir.IntImm("int32", indices_or_sections) + else: + raise RuntimeError( + f"The input `indices_or_sections` has unrecognized type {type(indices_or_sections)}" + ) + return _ffi_api.split(data, indices_or_sections, axis) + + +def broadcast_to( + data: Expr, shape: Union[PrimExprLike, List[PrimExprLike], Tuple[PrimExprLike], Expr] +) -> Expr: + """Return a scalar value array with the same type, broadcast to + the provided shape. + + Parameters + ---------- + data : relay.Expr + The input tensor. + + shape : Union[PrimExprLike, List[PrimExprLike], Tuple[PrimExprLike], Expr] + Provide the shape to broadcast to. + + Returns + ------- + result : relay.Expr + The resulting tensor. + """ + if isinstance(shape, (PrimExpr, int)): + shape = [shape] + if isinstance(shape, (tuple, list)): + temp_shape = [] + for shape in shape: + if isinstance(shape, PrimExpr): + temp_shape.append(shape) + elif isinstance(shape, int): + temp_shape.append(tvm.tir.const(shape, "int32")) + else: + raise RuntimeError( + f"The input new shape of reshape operator contains unrecognized dimension {shape}" + ) + shape = relax.ShapeExpr(temp_shape) + + return _ffi_api.broadcast_to(data, shape) + + +def strided_slice( + data: Expr, + begin: Union[List[PrimExprLike], Tuple[PrimExprLike]], + end: Union[List[PrimExprLike], Tuple[PrimExprLike]], + strides: Optional[Union[List[PrimExprLike], Tuple[PrimExprLike]]] = None, + axes: Optional[Union[List[int], Tuple[int]]] = None, + slice_mode: str = "end", +) -> Expr: + """Strided slice of an array. + + Parameters + ---------- + data : relax.Expr + The source array to be sliced. + + begin : Union[List[PrimExprLike], Tuple[PrimExprLike]], + The indices to begin with in the slicing. 
+ + end : Union[List[PrimExprLike], Tuple[PrimExprLike]] + Indices indicating end of the slice. + + strides : Optional[Union[List[PrimExprLike], Tuple[PrimExprLike]]] + Specifies the stride values, it can be negative in that case, + the input tensor will be reversed in that particular axis. + + axes : Optional[Union[List[int], Tuple[int]]] + Axes along which slicing is applied. When it is specified, the length of begin, end, + strides, and axes must be equal. + + slice_mode : str + The slice mode [end, size]. + end: The ending indices for the slice [default]. + size: The input strides will be ignored, input end in this mode indicates + the size of a slice starting at the location specified by begin. If end[i] + is -1, all remaining elements in that dimension are included in the slice. + + Returns + ------- + ret : relax.Expr + The computed result. + """ + + def convert_int(arr): + res = [] + for x in arr: + if isinstance(x, PrimExpr): + res.append(x) + elif isinstance(x, int): + res.append(tvm.tir.const(x, "int32")) + else: + raise RuntimeError( + f"The input of strided_slice operator contains unrecognized value {x}" + ) + return res + + begin = convert_int(begin) + end = convert_int(end) + strides = convert_int(strides) if strides else None + return _ffi_api.strided_slice(data, begin, end, strides, axes, slice_mode) diff --git a/python/tvm/relax/transform/op_legalizer.py b/python/tvm/relax/transform/op_legalizer.py index cd6fefe6c9a2..49427acaf3d4 100644 --- a/python/tvm/relax/transform/op_legalizer.py +++ b/python/tvm/relax/transform/op_legalizer.py @@ -114,6 +114,98 @@ def _concatenate(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: return bb.call_te(topi.concatenate, fields, None if attrs.axis is None else attrs.axis.value) +def _expand_dims(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + output_ndim = len(output_shape) + + def expand_dims(data, axis): + data_dims = [] + for i in range(output_ndim): + if i not in axis and (i - output_ndim) not in axis: + data_dims.append(i) + return te.compute(output_shape, lambda *idx: data(*[idx[dim] for dim in data_dims])) + + return bb.call_te(expand_dims, args[0], attrs.axis) + + +def _cumsum(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te(topi.cumsum, args[0], attrs.axis) + + +def _trilu(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te(topi.trilu, args[0], tvm.tir.const(attrs.k, "int32"), attrs.is_upper) + + +def _cast(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te(topi.cast, args[0], attrs.dtype) + + +def _take(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te(topi.take, args[0], args[1], attrs.axis, attrs.batch_dims, attrs.mode) + + +def _full(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te( + topi.full, + args[1], + attrs.dtype if attrs.dtype is not None else args[0].checked_type.dtype, + args[0], + ) + + +def _split(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + indices_or_sections = ( + attrs.indices_or_sections.value + if isinstance(attrs.indices_or_sections, tvm.tir.IntImm) + else attrs.indices_or_sections + ) + return bb.call_te(topi.split, args[0], indices_or_sections, attrs.axis) + + +def _broadcast_to(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te(topi.broadcast_to, args[0], args[1]) + + +def _strided_slice(bb: 
BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te( + topi.strided_slice, + args[0], + attrs.begin, + attrs.end, + attrs.strides, + attrs.axes, + attrs.slice_mode, + ) + + +def _nn_max_pool2d(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te( + topi.nn.pool2d, + args[0], + kernel=attrs.pool_size, + stride=attrs.strides, + dilation=attrs.dilation, + padding=attrs.padding, + pool_type="max", + ceil_mode=attrs.ceil_mode, + layout=attrs.layout, + ) + + +def _nn_batch_norm(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te( + topi.nn.batch_norm, + data=args[0], + gamma=args[1], + beta=args[2], + moving_mean=args[3], + moving_var=args[4], + axis=attrs.axis, + epsilon=attrs.epsilon, + center=attrs.center, + scale=attrs.scale, + ) + + def _nn_layer_norm(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): def layer_norm(x, gamma, beta, axis, eps): shape_prod = tvm.tir.const(1, "int32") @@ -169,8 +261,6 @@ def multiply_compute(idx_reduce): if not b_appended: b_indices.append(idx_spatial[-1]) - print(a_indices) - print(b_indices) return a(*a_indices) * b(*b_indices) return te.sum(multiply_compute(k), axis=k) @@ -184,6 +274,16 @@ def _nn_softmax(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: return bb.call_te(topi.nn.softmax, args[0], attrs.axis) +def _nn_flatten(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te(topi.nn.flatten, args[0]) + + +def _nn_adaptive_max_pool2d(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te( + topi.nn.adaptive_pool, args[0], attrs.output_size, pool_type="avg", layout=attrs.layout + ) + + def _sum(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): return bb.call_te(topi.sum, args[0], attrs.axis, attrs.keepdims) @@ -197,6 +297,22 @@ def _mean(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): return bb.call_te(topi.divide, sum_var, shape_prod) +def _image_resize2d(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): + return bb.call_te( + topi.image.resize2d, + args[0], + roi=attrs.roi, + size=attrs.size, + layout=attrs.layout, + method=attrs.method, + coordinate_transformation_mode=attrs.coordinate_transformation_mode, + rounding_method=attrs.rounding_method, + bicubic_alpha=attrs.cubic_alpha, + bicubic_exclude=attrs.cubic_exclude, + extrapolation_value=attrs.extrapolation_value, + ) + + op_legalization_map = { ir.Op.get("relax.nn.conv2d"): _nn_conv2d, ir.Op.get("relax.add"): _add, @@ -213,11 +329,25 @@ def _mean(bb: BlockBuilder, args: List[Expr], attrs: Attrs, output_shape: Expr): ir.Op.get("relax.reshape"): _reshape, ir.Op.get("relax.transpose"): _transpose, ir.Op.get("relax.concatenate"): _concatenate, + ir.Op.get("relax.expand_dims"): _expand_dims, + ir.Op.get("relax.cumsum"): _cumsum, + ir.Op.get("relax.trilu"): _trilu, + ir.Op.get("relax.cast"): _cast, + ir.Op.get("relax.take"): _take, + ir.Op.get("relax.full"): _full, + ir.Op.get("relax.split"): _split, + ir.Op.get("relax.strided_slice"): _strided_slice, + ir.Op.get("relax.broadcast_to"): _broadcast_to, + ir.Op.get("relax.nn.max_pool2d"): _nn_max_pool2d, + ir.Op.get("relax.nn.batch_norm"): _nn_batch_norm, ir.Op.get("relax.nn.layer_norm"): _nn_layer_norm, ir.Op.get("relax.nn.matmul"): _nn_matmul, ir.Op.get("relax.nn.softmax"): _nn_softmax, + ir.Op.get("relax.nn.flatten"): _nn_flatten, + ir.Op.get("relax.nn.adaptive_avg_pool2d"): 
_nn_adaptive_max_pool2d, ir.Op.get("relax.sum"): _sum, ir.Op.get("relax.mean"): _mean, + ir.Op.get("relax.image.resize2d"): _image_resize2d, } diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index a92e373e224e..2681caf66447 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -26,39 +26,50 @@ ############################### Operators ############################### from tvm.relax.op import ( + adaptive_avg_pool2d, add, assert_op, + broadcast_to, builtin, call_tir, + cast, concatenate, conv2d, cos, + cumsum, divide, dropout, ewise_fma, expand_dims, floor_divide, + full, gelu, invoke_closure, layer_norm, make_closure, matmul, max, + max_pool2d, mean, min, multiply, print, relu, reshape, + resize2d, shape_of, silu, sin, softmax, + split, sqrt, squeeze, + strided_slice, subtract, sum, + take, transpose, + trilu, unique, variance, ) @@ -395,16 +406,20 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name "Then", "TupleGetItem", "Void", + "adaptive_avg_pool2d", "add", "arg", "assert_op", "builtin", + "broadcast_to", "call_packed", "call_tir", + "cast", "concatenate", "const", "conv2d", "cos", + "cumsum", "dataflow", "divide", "dropout", @@ -413,6 +428,7 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name "emit_match_shape", "ewise_fma", "floor_divide", + "full", "func_attr", "func_name", "func_ret_type", @@ -425,6 +441,7 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name "make_closure", "matmul", "max", + "max_pool2d", "mean", "min", "multiply", @@ -432,14 +449,19 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name "print", "relu", "reshape", + "resize2d", "silu", "sin", "softmax", + "split", "sqrt", "squeeze", + "strided_slice", "subtract", "sum", + "take", "transpose", + "trilu", "unique", "variance", "shape_of", diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 842e21378fd1..98b69307f512 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1504,6 +1504,7 @@ def wrapped(*args, **kwargs): copysign = _op_wrapper(_tir_op.copysign) cos = _op_wrapper(_tir_op.cos) cosh = _op_wrapper(_tir_op.cosh) +div = _op_wrapper(_tir_op.div) erf = _op_wrapper(_tir_op.erf) exp = _op_wrapper(_tir_op.exp) exp2 = _op_wrapper(_tir_op.exp2) @@ -1685,6 +1686,7 @@ def f(): "copysign", "cos", "cosh", + "div", "erf", "exp", "exp2", diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc new file mode 100644 index 000000000000..a66e2ce03543 --- /dev/null +++ b/src/relax/op/image/resize.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file resize.cc + * \brief Image resize operators. + */ + +#include "resize.h" + +namespace tvm { +namespace relax { + +/* relax.resize2d */ +TVM_REGISTER_NODE_TYPE(Resize2DAttrs); + +RELAX_REGISTER_OP("relax.image.resize2d") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attr("FInferShape", InferShapeResize2d) + .set_attr("FInferType", InferTypeResize2d); + +Expr MakeResize2D(Expr data, Array size, Array roi, String layout, + String method, String coordinate_transformation_mode, String rounding_method, + double cubic_alpha, int cubic_exclude, double extrapolation_value) { + ObjectPtr attrs = make_object(); + attrs->size = std::move(size); + attrs->roi = std::move(roi); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode); + attrs->rounding_method = std::move(rounding_method); + attrs->cubic_alpha = cubic_alpha; + attrs->cubic_exclude = cubic_exclude; + attrs->extrapolation_value = extrapolation_value; + + const static Op& op = Op::Get("relax.image.resize2d"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.resize2d").set_body_typed(MakeResize2D); + +Optional InferShapeResize2d(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Resize2d op should have 1 argument"); + } + + const auto* input_shape = call->args[0]->shape().as(); + const auto* attrs = call->attrs.as(); + if (input_shape == nullptr) { + return RuntimeDepShape(); + } + if (input_shape->values.size() != 4) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The resize2d operator expects the input data to be a tensor of 4 " + "dimensions. However, the given data has " + << input_shape->values.size() << " dimensions"); + } + + if (attrs->layout->size != 4) { + diag_ctx.EmitFatal( + Diagnostic::Error(call->span) + << "The resize2d operator expects the input layout to be a string of length 4, containing " + "letters \"N\", \"C\", \"H\", \"W\". However, the given layout is " + << attrs->layout); + } + int batch_axis = -1; + int height_axis = -1; + int width_axis = -1; + int channel_axis = -1; + for (int i = 0; i < 4; ++i) { + char letter = attrs->layout.at(i); + if (letter == 'N') { + batch_axis = i; + } else if (letter == 'H') { + height_axis = i; + } else if (letter == 'W') { + width_axis = i; + } else if (letter == 'C') { + channel_axis = i; + } + } + if (batch_axis == -1 || height_axis == -1 || width_axis == -1 || channel_axis == -1) { + diag_ctx.EmitFatal( + Diagnostic::Error(call->span) + << "The resize2d operator expects the input layout to be a string of length 4, containing " + "letters \"N\", \"C\", \"H\", \"W\". However, the given layout is " + << attrs->layout); + } + + Array size = attrs->size; + if (size.size() != 2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The resize2d operator expects the input size to have exactly two " + "elements. 
However, the given size is " + << size << ", which contains " << size.size() << " elements"); + } + + Array output_shape; + output_shape.resize(4); + output_shape.Set(batch_axis, input_shape->values[batch_axis]); + output_shape.Set(height_axis, size[0]); + output_shape.Set(width_axis, size[1]); + output_shape.Set(channel_axis, input_shape->values[channel_axis]); + return ShapeExpr(output_shape); +} + +Type InferTypeResize2d(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Resize2d op should have 1 argument"); + } + + const auto* input_type = call->args[0]->checked_type().as(); + if (input_type == nullptr) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The op input data should has type DynTensorType, but actually it is " + << call->args[0]->checked_type()->GetTypeKey() + << ". Please make sure the input data has type DynTensorType."); + } + if (!input_type->IsUnknownNdim() && input_type->ndim != 4) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The resize2d operator expects the input data to be a tensor of 4 " + "dimensions. However, the given data has " + << input_type->ndim << " dimensions"); + } + + return DynTensorType(4, input_type->dtype); +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h new file mode 100644 index 000000000000..5bea361b49f2 --- /dev/null +++ b/src/relax/op/image/resize.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef TVM_RELAX_OP_IMAGE_RESIZE_H +#define TVM_RELAX_OP_IMAGE_RESIZE_H + + +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/* relax.resize2d */ +Optional InferShapeResize2d(const Call& call, DiagnosticContext diag_ctx); + +Type InferTypeResize2d(const Call& call, DiagnosticContext diag_ctx); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_IMAGE_RESIZE_H diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 93a94e359506..90bef9bae2fa 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -28,6 +28,7 @@ namespace tvm { namespace relax { +/* relax.nn.max_pool2d */ TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); template @@ -101,5 +102,95 @@ Expr MakeMaxPool2D(Expr data, Array pool_size, Array strides TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(MakeMaxPool2D); +/* relax.nn.adaptive_avg_pool2d */ +TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); + +RELAX_REGISTER_OP("relax.nn.adaptive_avg_pool2d") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attr("FInferShape", InferShapeAdaptiveAvgPool2D) + .set_attr("FInferType", InferTypeUnaryBroadcast); + +Expr MakeAdaptiveAvgPool2D(Expr data, Optional> output_size, String layout) { + ObjectPtr attrs = make_object(); + attrs->output_size = std::move(output_size); + attrs->layout = std::move(layout); + + const static Op& op = Op::Get("relax.nn.adaptive_avg_pool2d"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(MakeAdaptiveAvgPool2D); + +Optional InferShapeAdaptiveAvgPool2D(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "AdaptiveAvgPool2d op should have 1 argument"); + } + + const auto* input_shape = call->args[0]->shape().as(); + const auto* attrs = call->attrs.as(); + if (input_shape == nullptr) { + return RuntimeDepShape(); + } + if (input_shape->values.size() != 4) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The resize2d operator expects the input data to be a tensor of 4 " + "dimensions. However, the given data has " + << input_shape->values.size() << " dimensions"); + } + if (attrs->layout->size != 4) { + diag_ctx.EmitFatal( + Diagnostic::Error(call->span) + << "The resize2d operator expects the input layout to be a string of length 4, containing " + "letters \"N\", \"C\", \"H\", \"W\". However, the given layout is " + << attrs->layout); + } + + if (!attrs->output_size.defined()) { + return GetRef(input_shape); + } + int batch_axis = -1; + int height_axis = -1; + int width_axis = -1; + int channel_axis = -1; + for (int i = 0; i < 4; ++i) { + char letter = attrs->layout.at(i); + if (letter == 'N') { + batch_axis = i; + } else if (letter == 'H') { + height_axis = i; + } else if (letter == 'W') { + width_axis = i; + } else if (letter == 'C') { + channel_axis = i; + } + } + if (batch_axis == -1 || height_axis == -1 || width_axis == -1 || channel_axis == -1) { + diag_ctx.EmitFatal( + Diagnostic::Error(call->span) + << "The adaptive_avg_pool2d operator expects the input layout to be a string of length 4, " + "containing letters \"N\", \"C\", \"H\", \"W\". 
However, the given layout is " + << attrs->layout); + } + + Array output_size = attrs->output_size.value(); + if (output_size.size() != 2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The adaptive_avg_pool2d operator expects the input size to have exactly " + "two elements. However, the given size is " + << output_size << ", which contains " << output_size.size() << " elements"); + } + + Array output_shape; + output_shape.resize(4); + output_shape.Set(batch_axis, input_shape->values[batch_axis]); + output_shape.Set(height_axis, output_size[0]); + output_shape.Set(width_axis, output_size[1]); + output_shape.Set(channel_axis, input_shape->values[channel_axis]); + return ShapeExpr(output_shape); +} + } // namespace relax } // namespace tvm diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h index 3f0d1d45f386..4d373a4d33c6 100644 --- a/src/relax/op/nn/pooling.h +++ b/src/relax/op/nn/pooling.h @@ -27,6 +27,7 @@ namespace tvm { namespace relax { +/* relax.nn.max_pool2d */ Optional InferShapeMaxPool2d(const Call& call, DiagnosticContext diag_ctx) { if (call->args.size() != 1) { diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "MaxPool2d op should have 1 argument"); @@ -57,6 +58,9 @@ Optional InferShapeMaxPool2d(const Call& call, DiagnosticContext diag_ctx) } } +/* relax.nn.adaptive_avg_pool2d */ +Optional InferShapeAdaptiveAvgPool2D(const Call& call, DiagnosticContext diag_ctx); + } // namespace relax } // namespace tvm #endif // TVM_RELAX_OP_NN_POOLING_H_ diff --git a/src/relax/op/tensor/transform.cc b/src/relax/op/tensor/transform.cc index 8176355137f8..79cf023d9209 100644 --- a/src/relax/op/tensor/transform.cc +++ b/src/relax/op/tensor/transform.cc @@ -616,5 +616,697 @@ Type InferTypeConcatenate(const Call& call, DiagnosticContext diag_ctx) { return DynTensorType(output_ndim, dtype); } +/* relax.cumsum */ +TVM_REGISTER_NODE_TYPE(CumsumAttrs); + +RELAX_REGISTER_OP("relax.cumsum") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attr("FInferShape", InferShapeCumsum) + .set_attr("FInferType", InferTypeCumsum); + +Expr MakeCumsum(Expr data, Optional axis) { + ObjectPtr attrs = make_object(); + attrs->axis = axis; + + static const Op& op = Op::Get("relax.cumsum"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.cumsum").set_body_typed(MakeCumsum); + +Optional InferShapeCumsum(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Cumsum op should have 1 argument"); + } + + const auto* shape = call->args[0]->shape().as(); + const auto* attrs = call->attrs.as(); + if (shape == nullptr) { + return RuntimeDepShape(); + } + + if (attrs->axis.defined()) { + return GetRef(shape); + } + + PrimExpr prod = tir::make_const(DataType::Int(32), 1); + for (const PrimExpr& shape_dim : shape->values) { + prod = prod * shape_dim; + } + return ShapeExpr({prod}); +} + +Type InferTypeCumsum(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Cumsum op should have 1 argument"); + } + + const auto* input_type = call->args[0]->checked_type().as(); + if (input_type == nullptr) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The op input should has type DynTensorType, but actually it is " + << call->args[0]->checked_type()->GetTypeKey() + << ". 
Please make sure the input has type DynTensorType."); + } + + const auto* attrs = call->attrs.as(); + if (attrs->axis.defined()) { + return GetRef(input_type); + } else { + return DynTensorType(/*ndim=*/1, input_type->dtype); + } +} + +/* relax.trilu */ +TVM_REGISTER_NODE_TYPE(TriluAttrs); + +RELAX_REGISTER_OP("relax.trilu") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attr("FInferShape", InferShapeTrilu) + .set_attr("FInferType", InferTypeTrilu); + +Expr MakeTrilu(Expr data, int k, bool is_upper) { + auto attrs = make_object(); + attrs->k = k; + attrs->is_upper = is_upper; + + static const Op& op = Op::Get("relax.trilu"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.trilu").set_body_typed(MakeTrilu); + +Optional InferShapeTrilu(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Trilu op should have 1 argument"); + } + + return call->args[0]->shape(); +} + +Type InferTypeTrilu(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Trilu op should have 1 argument"); + } + + const auto* input_type = call->args[0]->checked_type().as(); + if (input_type == nullptr) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Trilu operator requires the input data to have type DynTensorType. " + "However, the type of the given input is " + << call->args[0]->checked_type()->GetTypeKey()); + } + + return GetRef(input_type); +} + +/* relax.cast */ +TVM_REGISTER_NODE_TYPE(CastAttrs); + +RELAX_REGISTER_OP("relax.cast") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor") + .set_attr("FInferShape", InferShapeCast) + .set_attr("FInferType", InferTypeCast); + +Expr MakeCast(Expr data, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.cast"); + return Call(op, {std::move(data)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.cast").set_body_typed(MakeCast); + +Optional InferShapeCast(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Cast op should have 1 argument"); + } + return call->args[0]->shape(); +} + +Type InferTypeCast(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Cast op should have 1 argument"); + } + + const auto* input_type = call->args[0]->checked_type().as(); + if (input_type == nullptr) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The op input should has type DynTensorType, but actually it is " + << call->args[0]->checked_type()->GetTypeKey() + << ". 
+/* relax.take */
+TVM_REGISTER_NODE_TYPE(TakeAttrs);
+
+RELAX_REGISTER_OP("relax.take")
+    .set_attrs_type()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("indices", "Tensor", "The indices tensor.")
+    .set_attr("FInferShape", InferShapeTake)
+    .set_attr("FInferType", InferTypeTake);
+
+Expr MakeTake(Expr data, Expr indices, Optional axis, int batch_dims, String mode) {
+  ObjectPtr attrs = make_object();
+  attrs->axis = std::move(axis);
+  attrs->batch_dims = batch_dims;
+  attrs->mode = std::move(mode);
+
+  static const Op& op = Op::Get("relax.take");
+  return Call(op, {std::move(data), std::move(indices)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(MakeTake);
+
+Optional InferShapeTake(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Take op should have 2 arguments");
+  }
+
+  const auto* data_shape = call->args[0]->shape().as();
+  const auto* indices_shape = call->args[1]->shape().as();
+  const auto* attrs = call->attrs.as();
+
+  if (indices_shape == nullptr) {
+    return RuntimeDepShape();
+  } else if (!attrs->axis.defined()) {
+    return GetRef(indices_shape);
+  } else if (data_shape == nullptr) {
+    return RuntimeDepShape();
+  }
+
+  int axis = attrs->axis.value()->value;
+  int ndim_data = data_shape->values.size();
+  if (axis < 0) {
+    axis = ndim_data + axis;
+  }
+  if (axis < 0 || axis >= ndim_data) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Take operator expects the input axis to be in range [" << -ndim_data
+                       << ", " << ndim_data << "). However, the given axis is "
+                       << attrs->axis.value()->value << ", which is out of range");
+  }
+
+  Array output_shape = data_shape->values;
+  output_shape.erase(output_shape.begin() + axis);
+  output_shape.insert(output_shape.begin() + axis, indices_shape->values.begin(),
+                      indices_shape->values.end());
+  return ShapeExpr(output_shape);
+}
+
+Type InferTypeTake(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Take op should have 2 arguments");
+  }
+
+  const auto* data_type = call->args[0]->checked_type().as();
+  const auto* indices_type = call->args[1]->checked_type().as();
+  const auto* attrs = call->attrs.as();
+  if (data_type == nullptr) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The op input data should have type DynTensorType, but actually it is "
+                       << call->args[0]->checked_type()->GetTypeKey()
+                       << ". Please make sure the input data has type DynTensorType.");
+  }
+  if (indices_type == nullptr) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The op input indices should have type DynTensorType, but actually it is "
+                       << call->args[1]->checked_type()->GetTypeKey()
+                       << ". Please make sure the input indices have type DynTensorType.");
+  }
+  if (!indices_type->IsUnknownDtype() && !indices_type->dtype.is_int()) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Take operator expects the input indices to have integer dtype. However, "
+                          "the given indices have dtype "
+                       << indices_type->dtype);
+  }
+
+  if (indices_type->IsUnknownNdim()) {
+    return DynTensorType(-1, data_type->dtype);
+  } else if (!attrs->axis.defined()) {
+    return DynTensorType(indices_type->ndim, data_type->dtype);
+  } else if (data_type->IsUnknownNdim()) {
+    return DynTensorType(-1, data_type->dtype);
+  } else {
+    return DynTensorType(data_type->ndim - 1 + indices_type->ndim, data_type->dtype);
+  }
+}
+
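For reference, the shape rule implemented by InferShapeTake matches numpy.take: with axis given, that axis is replaced by the full shape of the indices tensor, so the result has data.ndim - 1 + indices.ndim dimensions. A small NumPy check (illustrative only, not part of the patch):

import numpy as np

data = np.random.rand(2, 3, 4).astype("float32")
indices = np.random.randint(0, 3, size=(3, 4, 2))
out = np.take(data, indices, axis=1, mode="clip")  # "clip" mirrors the default TakeAttrs mode
assert out.shape == (2, 3, 4, 2, 4)                # data.shape[:1] + indices.shape + data.shape[2:]
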
However, " + "the given indices has dtype " + << indices_type->dtype); + } + + if (indices_type->IsUnknownNdim()) { + return DynTensorType(-1, data_type->dtype); + } else if (!attrs->axis.defined()) { + return DynTensorType(indices_type->ndim, data_type->dtype); + } else if (data_type->IsUnknownNdim()) { + return DynTensorType(-1, data_type->dtype); + } else { + return DynTensorType(data_type->ndim - 1 + indices_type->ndim, data_type->dtype); + } +} + +/* Initialization operators */ +TVM_REGISTER_NODE_TYPE(FullAttrs); + +/* relax.full */ +RELAX_REGISTER_OP("relax.full") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") + .add_argument("shape", "ShapeExpr", "The shape of the created tensor.") + .set_attr("FInferShape", InferShapeFull) + .set_attr("FInferType", InferTypeFull); + +Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.full"); + return Call(op, {std::move(fill_value), std::move(shape)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.full").set_body_typed(MakeFull); + +Optional InferShapeFull(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Full op should have 2 arguments"); + } + + const auto* fill_value_shape = call->args[0]->shape().as(); + if (fill_value_shape != nullptr && fill_value_shape->values.size() != 0) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Full operator expects the input fill value to be a scalar tensor " + "(0-rank tensor). However, the input fill value has rank " + << fill_value_shape->values.size()); + } + + return call->args[1]; +} + +Type InferTypeFull(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Full op should have 2 arguments"); + } + + const auto* fill_value_type = call->args[0]->checked_type().as(); + const auto* shape_type = call->args[1]->checked_type().as(); + if (fill_value_type == nullptr) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The input fill value should has type DynTensorType, but actually it is " + << call->args[0]->checked_type()->GetTypeKey() + << ". Please make sure the input data has type DynTensorType."); + } + if (!fill_value_type->IsUnknownNdim() && fill_value_type->ndim != 0) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Full operator expects the input fill value to be a scalar tensor " + "(0-rank tensor). However, the input fill value has rank " + << fill_value_type->ndim); + } + if (shape_type == nullptr) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "The input shape should has type ShapeType, but actually it is " + << call->args[1]->checked_type()->GetTypeKey() + << ". Please make sure the input data has type ShapeType."); + } + + // Todo(ruihang): add ndim to ShapeType + int ndim = -1; + const auto* shape = call->args[1].as(); + if (shape != nullptr) { + ndim = shape->values.size(); + } + + const auto* attrs = call->attrs.as(); + return DynTensorType(ndim, attrs->dtype.is_void() ? 
+/* relax.split */
+TVM_REGISTER_NODE_TYPE(SplitAttrs);
+
+RELAX_REGISTER_OP("relax.split")
+    .set_attrs_type()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attr("FInferShape", InferShapeSplit)
+    .set_attr("FInferType", InferTypeSplit);
+
+Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) {
+  ObjectPtr attrs = make_object();
+  attrs->indices_or_sections = indices_or_sections;
+  if (const auto* n_section = indices_or_sections.as()) {
+    CHECK(n_section->value > 0) << "Split operator expects the input number of sections to be a "
+                                   "positive integer. However, the given number of sections is "
+                                << n_section->value;
+  }
+  attrs->axis = axis;
+
+  static const Op& op = Op::Get("relax.split");
+  return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.split").set_body_typed(MakeSplit);
+
+Optional InferShapeSplit(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Split op should have 1 argument");
+  }
+
+  const auto* input_shape = call->args[0]->shape().as();
+  const auto* attrs = call->attrs.as();
+  if (input_shape == nullptr) {
+    return RuntimeDepShape();
+  }
+
+  int ndim = input_shape->values.size();
+  int axis = attrs->axis;
+  if (axis < 0) {
+    axis = ndim + axis;
+  }
+  if (axis < 0 || axis >= ndim) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Split operator expects the input axis to be in range [" << -ndim << ", "
+                       << ndim << "). However, the given axis is " << attrs->axis
+                       << ", which is out of range");
+  }
+
+  Array output_shape;
+  PrimExpr len_axis = input_shape->values[axis];
+  if (const auto* p_indices = attrs->indices_or_sections.as()) {
+    Array indices = GetRef>(p_indices);
+    PrimExpr zero = tir::make_const(DataType::Int(32), 0);
+
+    output_shape.reserve(indices.size() + 1);
+    indices.insert(indices.begin(), zero);
+    indices.insert(indices.end(), len_axis);
+
+    for (int i = 0; i + 1 < static_cast(indices.size()); ++i) {
+      PrimExpr l = tvm::max(zero, indices[i]);
+      PrimExpr r = tvm::min(len_axis, indices[i + 1]);
+      PrimExpr len = tvm::max(zero, r - l);
+      Array shape = input_shape->values;
+      shape.erase(shape.begin() + axis);
+      shape.insert(shape.begin() + axis, len);
+      output_shape.push_back(ShapeExpr(shape));
+    }
+  } else {
+    const auto* p_n_section = attrs->indices_or_sections.as();
+    ICHECK_NOTNULL(p_n_section);
+    int n_section = p_n_section->value;
+    if (n_section <= 0) {
+      diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                         << "Split operator expects the input number of sections to be a positive "
+                            "integer. However, the given number of sections is "
+                         << n_section);
+    }
+    if (const int64_t* len_axis_value = tir::as_const_int(len_axis)) {
+      if (*len_axis_value % n_section != 0) {
+        diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                           << "Split operator expects the length of the input axis to be divisible "
+                              "by the given number of sections. However, the axis has length "
+                           << *len_axis_value << " while the given number of sections is "
+                           << n_section << ", which does not result in an equal division.");
+      }
+    }
+    // Todo(relax-team): need runtime divisibility check for the cases where `len_axis` is symbolic
+
+    PrimExpr n_section_expr = tir::make_const(DataType::Int(32), n_section);
+    Array shape = input_shape->values;
+    shape.erase(shape.begin() + axis);
+    shape.insert(shape.begin() + axis, tvm::floordiv(len_axis, n_section_expr));
+    for (int i = 0; i < n_section; ++i) {
+      output_shape.push_back(ShapeExpr(shape));
+    }
+  }
+  return Tuple(output_shape);
+}
+
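For reference, the divisibility check above mirrors numpy.split, which likewise refuses an unequal division (the behavior test_split_by_n_section_not_divisible below relies on). A minimal NumPy sketch (illustrative only, not part of the patch):

import numpy as np

x = np.random.rand(2, 10, 4).astype("float32")
parts = np.split(x, 5, axis=1)                   # five equal sections of length 2
assert all(p.shape == (2, 2, 4) for p in parts)
# np.split(x, 3, axis=1) would raise, since 10 is not divisible by 3
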
+Type InferTypeSplit(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Split op should have 1 argument");
+  }
+
+  const auto* input_type = call->args[0]->checked_type().as();
+  const auto* attrs = call->attrs.as();
+  if (input_type == nullptr) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The op input data should have type DynTensorType, but actually it is "
+                       << call->args[0]->checked_type()->GetTypeKey()
+                       << ". Please make sure the input data has type DynTensorType.");
+  }
+
+  int n_tensor = -1;
+  if (const auto* p_indices = attrs->indices_or_sections.as()) {
+    n_tensor = p_indices->size() + 1;
+  } else {
+    const auto* p_n_section = attrs->indices_or_sections.as();
+    ICHECK_NOTNULL(p_n_section);
+    n_tensor = p_n_section->value;
+  }
+
+  Array output_type;
+  output_type.reserve(n_tensor);
+  for (int i = 0; i < n_tensor; ++i) {
+    output_type.push_back(GetRef(input_type));
+  }
+  return TupleType(output_type);
+}
+
+/* relax.broadcast_to */
+RELAX_REGISTER_OP("relax.broadcast_to")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("shape", "ShapeExpr", "The shape of the created tensor.")
+    .set_attr("FInferShape", InferShapeBroadcastTo)
+    .set_attr("FInferType", InferTypeBroadcastTo);
+
+Expr MakeBroadcastTo(Expr data, Expr shape) {
+  static const Op& op = Op::Get("relax.broadcast_to");
+  return Call(op, {std::move(data), std::move(shape)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.broadcast_to").set_body_typed(MakeBroadcastTo);
+
+Optional InferShapeBroadcastTo(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "BroadcastTo op should have 2 arguments");
+  }
+
+  const auto* data_shape = call->args[0]->shape().as();
+  const auto* new_shape = call->args[1].as();
+  if (data_shape == nullptr || new_shape == nullptr) {
+    // Todo: need runtime shape broadcast compatibility check
+    return call->args[1];
+  }
+
+  int data_ndim = data_shape->values.size();
+  int new_ndim = new_shape->values.size();
+  if (new_ndim < data_ndim) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The broadcast_to operator expects the input new shape to have at least "
+                          "as many dimensions as the input data. However, the given data has ndim "
+                       << data_ndim << " while the given shape has ndim " << new_ndim);
+  }
+
+  arith::Analyzer ana;
+  for (int i = 1; i <= data_ndim; ++i) {
+    PrimExpr prev_len = data_shape->values[data_ndim - i];
+    PrimExpr new_len = new_shape->values[new_ndim - i];
+    if (tir::is_const_int(prev_len, 1)) {
+      continue;
+    } else if (ana.CanProve(prev_len != new_len)) {
+      diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                         << "The broadcast_to operator expects the input new shape to be broadcast "
+                            "compatible with the shape of the input data. However, at dimension "
+                         << i << " counting from the end, the input data shape has length "
+                         << prev_len << " while the new shape has length " << new_len
+                         << ", which are not compatible");
+    }
+  }
+  return GetRef(new_shape);
+}
+
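For reference, the compatibility loop above checks dimensions from the right, requiring each input dimension to be 1 or equal to the target dimension, exactly as in NumPy broadcasting. A minimal NumPy sketch (illustrative only, not part of the patch):

import numpy as np

x = np.random.rand(2, 1, 3).astype("float32")
y = np.broadcast_to(x, (4, 2, 5, 3))  # aligned from the right: 2->2, 1->5, 3->3; 4 is prepended
assert y.shape == (4, 2, 5, 3)
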
+Type InferTypeBroadcastTo(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 2) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "BroadcastTo op should have 2 arguments");
+  }
+
+  const auto* data_type = call->args[0]->checked_type().as();
+  const auto* shape_type = call->args[1]->checked_type().as();
+  if (data_type == nullptr) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The op input data should have type DynTensorType, but actually it is "
+                       << call->args[0]->checked_type()->GetTypeKey()
+                       << ". Please make sure the input data has type DynTensorType.");
+  }
+  if (shape_type == nullptr) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The op input new shape should have type ShapeType, but actually it is "
+                       << call->args[1]->checked_type()->GetTypeKey()
+                       << ". Please make sure the input new shape has type ShapeType.");
+  }
+
+  // Todo(ruihang): add ndim to ShapeType
+  int ndim = -1;
+  if (const auto* shape = call->args[1].as()) {
+    ndim = shape->values.size();
+  }
+  return DynTensorType(ndim, data_type->dtype);
+}
+
+/* relax.strided_slice */
+TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
+
+RELAX_REGISTER_OP("relax.strided_slice")
+    .set_attrs_type()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attr("FInferShape", InferShapeStridedSlice)
+    .set_attr("FInferType", InferTypeStridedSlice);
+
+Expr MakeStridedSlice(Expr data,                      //
+                      Array begin,                    //
+                      Array end,                      //
+                      Optional> strides,             //
+                      Optional> axes,                //
+                      String slice_mode) {
+  CHECK(slice_mode == "end" || slice_mode == "size")
+      << "Operator strided_slice expects the input `slice_mode` to be either \"end\" or \"size\". "
+         "However, the given `slice_mode` is "
+      << slice_mode;
+
+  ObjectPtr attrs = make_object();
+  attrs->begin = std::move(begin);
+  attrs->end = std::move(end);
+  attrs->strides = std::move(strides);
+  attrs->axes = std::move(axes);
+  attrs->slice_mode = std::move(slice_mode);
+
+  static const Op& op = Op::Get("relax.strided_slice");
+  return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(MakeStridedSlice);
+
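For reference, MakeStridedSlice is exposed to Python as relax.op.transform.strided_slice; a usage sketch in the style of the tests added below (illustrative only):

from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", (8, 9, 10, 10), relax.DynTensorType(4, "float32"))
with bb.function("main", [x]):
    # slice axes 0, 1 and 3; axis 3 runs backwards because its stride is -3
    gv = bb.emit(relax.op.transform.strided_slice(x, [1, 0, 8], [8, 9, 0], [2, 1, -3], [0, 1, 3]))
    bb.emit_func_output(gv)
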
+Optional InferShapeStridedSlice(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "StridedSlice op should have 1 argument");
+  }
+
+  const auto* input_shape = call->args[0]->shape().as();
+  const auto* attrs = call->attrs.as();
+  if (input_shape == nullptr) {
+    return RuntimeDepShape();
+  }
+
+  int ndim = input_shape->values.size();
+  Array axes;
+  if (attrs->axes.defined()) {
+    axes = attrs->axes.value();
+  } else {
+    axes.reserve(ndim);
+    for (int i = 0; i < ndim; ++i) {
+      axes.push_back(Integer(i));
+    }
+  }
+
+  int n_axis = axes.size();
+  Array begins = attrs->begin;
+  Array ends = attrs->end;
+  Array strides;
+  if (attrs->strides.defined()) {
+    strides = attrs->strides.value();
+  } else {
+    strides.reserve(n_axis);
+    for (int i = 0; i < n_axis; ++i) {
+      strides.push_back(tir::make_const(DataType::Int(32), 1));
+    }
+  }
+
+  if (static_cast(begins.size()) != n_axis) {
+    diag_ctx.EmitFatal(
+        Diagnostic::Error(call->span)
+        << "The strided_slice operator expects the input begin values to have the same length as "
+           "the number of input axes. However, the input axes length is "
+        << n_axis << " while the length of begin values is " << begins.size());
+  }
+  if (static_cast(ends.size()) != n_axis) {
+    diag_ctx.EmitFatal(
+        Diagnostic::Error(call->span)
+        << "The strided_slice operator expects the input end values to have the same length as "
+           "the number of input axes. However, the input axes length is "
+        << n_axis << " while the length of end values is " << ends.size());
+  }
+  if (static_cast(strides.size()) != n_axis) {
+    diag_ctx.EmitFatal(
+        Diagnostic::Error(call->span)
+        << "The strided_slice operator expects the input stride values to have the same length as "
+           "the number of input axes. However, the input axes length is "
+        << n_axis << " while the length of stride values is " << strides.size());
+  }
+
+  arith::Analyzer ana;
+  Array output_shape = input_shape->values;
+  std::unordered_set specified_axes;
+  specified_axes.reserve(axes.size());
+  for (int i = 0; i < static_cast(axes.size()); ++i) {
+    int axis = axes[i]->value;
+    if (axis < 0) {
+      axis = ndim + axis;
+    }
+    if (axis < 0 || axis >= ndim) {
+      diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                         << "Operator strided_slice expects the input axis to be in range ["
+                         << -ndim << ", " << ndim << "). However, the given axis " << i << " is "
+                         << axes[i]->value << ", which is out of range");
+    }
+    if (specified_axes.count(axis)) {
+      diag_ctx.EmitFatal(
+          Diagnostic::Error(call->span)
+          << "Operator strided_slice expects the input axes to contain no duplicates. However, axis "
+          << axis << " occurs twice");
+    }
+    specified_axes.insert(axis);
+
+    PrimExpr begin = begins[i];
+    PrimExpr end{nullptr};
+    PrimExpr stride = strides[i];
+
+    if (attrs->slice_mode == "size") {
+      stride = tir::make_const(DataType::Int(32), 1);
+      end = begin + ends[i];
+    } else {
+      if (attrs->slice_mode != "end") {
+        diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                           << "The strided_slice operator expects the input `slice_mode` to be "
+                              "either \"end\" or \"size\". However, the given `slice_mode` is "
+                           << attrs->slice_mode);
+      }
+      end = tvm::min(input_shape->values[axis], ends[i]);
+    }
+    if (ana.CanProveLess(stride, 0)) {
+      output_shape.Set(axis, tvm::ceildiv(begin - end, -stride));
+    } else {
+      output_shape.Set(axis, tvm::ceildiv(end - begin, stride));
+    }
+  }
+
+  return ShapeExpr(output_shape);
+}
+
+Type InferTypeStridedSlice(const Call& call, DiagnosticContext diag_ctx) {
+  if (call->args.size() != 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "StridedSlice op should have 1 argument");
+  }
+
+  const auto* input_type = call->args[0]->checked_type().as();
+  if (input_type == nullptr) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The op input data should have type DynTensorType, but actually it is "
+                       << call->args[0]->checked_type()->GetTypeKey()
+                       << ". Please make sure the input data has type DynTensorType.");
+  }
+
+  return GetRef(input_type);
+}
+
 } // namespace relax
 } // namespace tvm
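For reference, the per-axis output length computed by InferShapeStridedSlice is a ceiling division over the (begin, end, stride) triple, negated for reverse slices. A small Python sketch reproducing the arithmetic for the test case used later (illustrative only):

import math

def slice_len(begin, end, stride):
    # mirrors InferShapeStridedSlice: ceildiv(end - begin, stride), negated when stride < 0
    if stride < 0:
        return math.ceil((begin - end) / -stride)
    return math.ceil((end - begin) / stride)

assert slice_len(1, 8, 2) == 4   # axis 0: length 8 -> 4
assert slice_len(8, 0, -3) == 3  # axis 3: length 10 -> 3 (reverse slice)
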
diff --git a/src/relax/op/tensor/transform.h b/src/relax/op/tensor/transform.h
index c9c018a7b01d..4ee841c021aa 100644
--- a/src/relax/op/tensor/transform.h
+++ b/src/relax/op/tensor/transform.h
@@ -58,6 +58,46 @@ Optional InferShapeConcatenate(const Call& call, DiagnosticContext diag_ct
 
 Type InferTypeConcatenate(const Call& call, DiagnosticContext diag_ctx);
 
+/* relax.cumsum */
+Optional InferShapeCumsum(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeCumsum(const Call& call, DiagnosticContext diag_ctx);
+
+/* relax.trilu */
+Optional InferShapeTrilu(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeTrilu(const Call& call, DiagnosticContext diag_ctx);
+
+/* relax.cast */
+Optional InferShapeCast(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeCast(const Call& call, DiagnosticContext diag_ctx);
+
+/* relax.take */
+Optional InferShapeTake(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeTake(const Call& call, DiagnosticContext diag_ctx);
+
+/* relax.full */
+Optional InferShapeFull(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeFull(const Call& call, DiagnosticContext diag_ctx);
+
+/* relax.split */
+Optional InferShapeSplit(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeSplit(const Call& call, DiagnosticContext diag_ctx);
+
+/* relax.broadcast_to */
+Optional InferShapeBroadcastTo(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeBroadcastTo(const Call& call, DiagnosticContext diag_ctx);
+
+/* relax.strided_slice */
+Optional InferShapeStridedSlice(const Call& call, DiagnosticContext diag_ctx);
+
+Type InferTypeStridedSlice(const Call& call, DiagnosticContext diag_ctx);
+
 } // namespace relax
 } // namespace tvm
diff --git a/tests/python/relax/test_op_legalizer.py b/tests/python/relax/test_op_legalizer.py
index 1e465b8543c8..fb5b144d025f 100644
--- a/tests/python/relax/test_op_legalizer.py
+++ b/tests/python/relax/test_op_legalizer.py
@@ -74,30 +74,6 @@ def conv2d(
     mod = OperatorLegalizer(Conv2d).transform()
     tvm.ir.assert_structural_equal(mod, Expected)
 
-    # dev = tvm.cpu()
-    # ex = relax.vm.build(mod, target="llvm")
-    # vm = relax.VirtualMachine(ex, dev)
-
-    # import numpy as np
-    # import torch
-
-    # x_np = np.random.rand(2, 3, 28, 28).astype("float32")
-    # w_np = np.random.rand(4, 3, 3, 3).astype("float32")
-    # res_torch = torch.nn.functional.conv2d(torch.tensor(x_np), torch.tensor(w_np))
-    # res_relax = vm["main"](tvm.nd.array(x_np, dev), tvm.nd.array(w_np, dev))
-    # tvm.testing.assert_allclose(res_relax.numpy(), res_torch.numpy(), rtol=1e-5, atol=1e-5)
-
-    # print("pass")
-
-    # x = relax.Var("x", [2, 3, 28, 28], relax.DynTensorType(ndim=4, dtype="float32"))
-    # w = relax.Var("w", [4, 3, 3, 3], relax.DynTensorType(ndim=4, dtype="float32"))
-    # bb = relax.BlockBuilder()
-    # with bb.function("main", [x, w]):
-    #     gv = bb.emit(relax.op.conv2d(x, w, 3))
-    #     bb.emit_func_output(gv)
-
-    # print(mod.script())
-
 def test_add():
@@ -135,14 +111,6 @@ def add(
     mod = OperatorLegalizer(Add).transform()
     tvm.ir.assert_structural_equal(mod, Expected)
 
-    # x = relax.Var("x", [2, 3], relax.DynTensorType(ndim=2, dtype="float32"))
-    # y = relax.Var("y", [2, 3], relax.DynTensorType(ndim=2, dtype="float32"))
-    # bb = relax.BlockBuilder()
-    # with bb.function("main", [x, y]):
-    #     gv = bb.emit(relax.op.add(x, y))
-    #     bb.emit_func_output(gv)
-    # 
print(bb.get().script()) - def test_subtract(): @I.ir_module @@ -687,6 +655,403 @@ def concatenate( tvm.ir.assert_structural_equal(mod, Expected) +def test_cumsum(): + @I.ir_module + class Cumsum: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=3): + gv: R.Tensor((2, 3, 4), "float32") = R.cumsum(x, axis=-2) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=3): + gv = R.call_tir(cumsum, (x,), (2, 3, 4), dtype="float32") + return gv + + @T.prim_func + def cumsum( + rxplaceholder: T.Buffer[(2, 3, 4), "float32"], out_buf: T.Buffer[(2, 3, 4), "float32"] + ) -> None: + T.func_attr({"global_symbol": "cumsum", "tir.noalias": True}) + with T.block("cumsum_generic"): + T.reads(rxplaceholder[0:2, 0:3, 0:4]) + T.writes(out_buf[0:2, 0:3, 0:4]) + for fused in T.parallel(8): + out_buf[ + (fused // 4 * 3 * 4 + fused % 4) // 4 // 3, + (fused // 4 * 3 * 4 + fused % 4) // 4 % 3, + (fused // 4 * 3 * 4 + fused % 4) % 4, + ] = rxplaceholder[ + (fused // 4 * 3 * 4 + fused % 4) // 4 // 3, + (fused // 4 * 3 * 4 + fused % 4) // 4 % 3, + (fused // 4 * 3 * 4 + fused % 4) % 4, + ] + for v_k in T.serial(2): + out_buf[ + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) // 4 // 3, + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) // 4 % 3, + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) % 4, + ] = ( + out_buf[ + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1 - 1) * 4) // 4 // 3, + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1 - 1) * 4) // 4 % 3, + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1 - 1) * 4) % 4, + ] + + rxplaceholder[ + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) // 4 // 3, + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) // 4 % 3, + (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) % 4, + ] + ) + + mod = OperatorLegalizer(Cumsum).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cumsum_without_specified_axis(): + @I.ir_module + class Cumsum: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=1): + gv: R.Tensor((24,), "float32") = R.cumsum(x) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=1): + gv = R.call_tir(cumsum, (x,), (24,), dtype="float32") + return gv + + @T.prim_func + def cumsum( + rxplaceholder: T.Buffer[(2, 3, 4), "float32"], out_buf: T.Buffer[24, "float32"] + ) -> None: + T.func_attr({"global_symbol": "cumsum", "tir.noalias": True}) + with T.block("cumsum_generic"): + T.reads(rxplaceholder[0:2, 0:3, 0:4]) + T.writes(out_buf[0:24]) + for fused in T.parallel(1): + out_buf[fused * 24] = rxplaceholder[ + fused * 24 // 4 // 3, fused * 24 // 4 % 3, fused * 24 % 4 + ] + for v_k in T.serial(23): + out_buf[fused * 24 + (v_k + 1)] = ( + out_buf[fused * 24 + (v_k + 1 - 1)] + + rxplaceholder[ + (fused * 24 + (v_k + 1)) // 4 // 3, + (fused * 24 + (v_k + 1)) // 4 % 3, + (fused * 24 + (v_k + 1)) % 4, + ] + ) + + mod = OperatorLegalizer(Cumsum).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_expand_dims(): + @I.ir_module + class ExpandDims: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=8): + gv: R.Tensor((2, 1, 1, 1, 3, 1, 4, 1), "float32") = R.expand_dims( + x, axis=[-1, 1, -6, 3, 5] + ) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=8): + gv = 
R.call_tir(expand_dims, (x,), (2, 1, 1, 1, 3, 1, 4, 1), dtype="float32") + return gv + + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer[(2, 3, 4), "float32"], + compute: T.Buffer[(2, 1, 1, 1, 3, 1, 4, 1), "float32"], + ) -> None: + T.func_attr({"global_symbol": "expand_dims", "tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1): + with T.block("compute"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap( + "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7] + ) + T.reads(rxplaceholder[i0_1, i4_1, i6_1]) + T.writes(compute[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1]) + compute[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[ + i0_1, i4_1, i6_1 + ] + + mod = OperatorLegalizer(ExpandDims).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_trilu(): + @I.ir_module + class Trilu: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=3): + gv: R.Tensor((2, 3, 4), "float32") = R.trilu(x, k=0, is_upper=False) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=3): + gv = R.call_tir(trilu, (x,), (2, 3, 4), dtype="float32") + return gv + + @T.prim_func + def trilu( + rxplaceholder: T.Buffer[(2, 3, 4), "float32"], trilu: T.Buffer[(2, 3, 4), "float32"] + ) -> None: + T.func_attr({"global_symbol": "trilu", "tir.noalias": True}) + for i0, i1, i2 in T.grid(2, 3, 4): + with T.block("trilu"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(trilu[i0_1, i1_1, i2_1]) + trilu[i0_1, i1_1, i2_1] = T.Select( + i2_1 <= i1_1, rxplaceholder[i0_1, i1_1, i2_1], T.float32(0) + ) + + mod = OperatorLegalizer(Trilu).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_cast(): + @I.ir_module + class Cast: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "int32", ndim=3): + gv: R.Tensor((2, 3, 4), "int32") = R.cast(x, "int32") + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "int32", ndim=3): + gv = R.call_tir(cast, (x,), (2, 3, 4), dtype="int32") + return gv + + @T.prim_func + def cast( + rxplaceholder: T.Buffer[(2, 3, 4), "float32"], compute: T.Buffer[(2, 3, 4), "int32"] + ) -> None: + T.func_attr({"global_symbol": "cast", "tir.noalias": True}) + for i0, i1, i2 in T.grid(2, 3, 4): + with T.block("compute"): + i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(rxplaceholder[i0_1, i1_1, i2_1]) + T.writes(compute[i0_1, i1_1, i2_1]) + compute[i0_1, i1_1, i2_1] = T.cast(rxplaceholder[i0_1, i1_1, i2_1], "int32") + + mod = OperatorLegalizer(Cast).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_take(): + @I.ir_module + class Take: + @R.function + def main( + x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((3, 4, 2), "int32") + ) -> R.Tensor(None, "float32", ndim=5): + gv: R.Tensor((2, 3, 4, 2, 4), "float32") = R.take(x, indices, axis=1) + return gv + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((3, 4, 2), "int32") + ) -> R.Tensor(None, "float32", ndim=5): + gv = R.call_tir(take, (x, indices), (2, 3, 4, 2, 4), dtype="float32") + return gv + + @T.prim_func + def take( + rxplaceholder: T.Buffer[(2, 3, 4), "float32"], + rxplaceholder_1: T.Buffer[(3, 4, 2), "int32"], + T_take: T.Buffer[(2, 3, 4, 2, 4), "float32"], + 
) -> None: + T.func_attr({"global_symbol": "take", "tir.noalias": True}) + for i0, i1, i2, i3, i4 in T.grid(2, 3, 4, 2, 4): + with T.block("T_take"): + ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads( + rxplaceholder[ax0, T.min(T.max(0, rxplaceholder_1[ax1, ax2, ax3]), 2), ax4], + rxplaceholder_1[ax1, ax2, ax3], + ) + T.writes(T_take[ax0, ax1, ax2, ax3, ax4]) + T_take[ax0, ax1, ax2, ax3, ax4] = rxplaceholder[ + ax0, T.min(T.max(0, rxplaceholder_1[ax1, ax2, ax3]), 2), ax4 + ] + + mod = OperatorLegalizer(Take).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_full(): + @I.ir_module + class Full: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((2, 3), "float32") = R.full(v, (2, 3), dtype="float32") + return gv + + @I.ir_module + class Expected: + @R.function + def main(v: R.Tensor((), "int32")) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir(full, (v,), (2, 3), dtype="float32") + return gv + + @T.prim_func + def full(rxplaceholder: T.Buffer[(), "int32"], T_full: T.Buffer[(2, 3), "float32"]) -> None: + T.func_attr({"global_symbol": "full", "tir.noalias": True}) + for i0, i1 in T.grid(2, 3): + with T.block("T_full"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads(rxplaceholder[()]) + T.writes(T_full[ax0, ax1]) + T_full[ax0, ax1] = T.cast(rxplaceholder[()], "float32") + + mod = OperatorLegalizer(Full).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_broadcast_to(): + @I.ir_module + class BroadcastTo: + @R.function + def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((4, 2, 5, 3), "float32") = R.broadcast_to(x, (4, 2, 5, 3)) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv = R.call_tir(broadcast_to, (x,), (4, 2, 5, 3), dtype="float32") + return gv + + @T.prim_func + def broadcast_to( + rxplaceholder: T.Buffer[(2, 1, 3), "float32"], + T_broadcast_to: T.Buffer[(4, 2, 5, 3), "float32"], + ) -> None: + T.func_attr({"global_symbol": "broadcast_to", "tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(4, 2, 5, 3): + with T.block("T_broadcast_to"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax1, 0, ax3]) + T.writes(T_broadcast_to[ax0, ax1, ax2, ax3]) + T_broadcast_to[ax0, ax1, ax2, ax3] = rxplaceholder[ax1, 0, ax3] + + mod = OperatorLegalizer(BroadcastTo).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_strided_slice(): + @I.ir_module + class StridedSlice: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice( + x, + begin=[1, 0, 8], + end=[8, 9, 0], + strides=[2, 1, -3], + axes=[0, 1, 3], + slice_mode="end", + ) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv = R.call_tir(strided_slice, (x,), (4, 9, 10, 3), dtype="float32") + return gv + + @T.prim_func + def strided_slice( + rxplaceholder: T.Buffer[(8, 9, 10, 10), "float32"], + T_strided_slice_with_axes: T.Buffer[(4, 9, 10, 3), "float32"], + ) -> None: + T.func_attr({"global_symbol": "strided_slice", "tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(4, 9, 10, 3): + with T.block("T_strided_slice_with_axes"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + 
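+                    # Note (illustrative comment): the T.min(T.max(0, ...), 2) pattern below is the
+                    # lowered form of mode="clip" in TakeAttrs: out-of-bound indices are clamped
+                    # into [0, 2], the valid range of axis 1, whose length is 3.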
T.reads(rxplaceholder[ax0 * 2 + 1, ax1, ax2, 8 - ax3 * 3]) + T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3]) + T_strided_slice_with_axes[ax0, ax1, ax2, ax3] = rxplaceholder[ + ax0 * 2 + 1, ax1, ax2, 8 - ax3 * 3 + ] + + mod = OperatorLegalizer(StridedSlice).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_max_pool2d(): + @I.ir_module + class MaxPool2D: + @R.function + def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((4, 6, 56, 56), "float32") = R.max_pool2d( + x, + pool_size=[3, 3], + strides=[2, 2], + dilation=[1, 1], + padding=[1, 1, 1, 1], + layout="NCHW", + ) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv = R.call_tir(pool2d, (x,), (4, 6, 56, 56), dtype="float32") + return gv + + @T.prim_func + def pool2d( + rxplaceholder: T.Buffer[(4, 6, 112, 112), "float32"], + tensor: T.Buffer[(4, 6, 56, 56), "float32"], + ) -> None: + T.func_attr({"global_symbol": "pool2d", "tir.noalias": True}) + pad_temp = T.alloc_buffer([4, 6, 114, 114], dtype="float32") + for i0, i1, i2, i3 in T.grid(4, 6, 114, 114): + with T.block("pad_temp"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(rxplaceholder[ax0, ax1, ax2 - 1, ax3 - 1]) + T.writes(pad_temp[ax0, ax1, ax2, ax3]) + pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else( + 1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113, + rxplaceholder[ax0, ax1, ax2 - 1, ax3 - 1], + T.float32(-3.4028234663852886e38), + dtype="float32", + ) + for i0, i1, i2, i3, i4, i5 in T.grid(4, 6, 56, 56, 3, 3): + with T.block("tensor"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(pad_temp[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1]) + T.writes(tensor[ax0, ax1, ax2, ax3]) + with T.init(): + tensor[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38) + tensor[ax0, ax1, ax2, ax3] = T.max( + tensor[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1] + ) + + mod = OperatorLegalizer(MaxPool2D).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + def test_layer_norm(): @I.ir_module class LayerNorm: @@ -943,7 +1308,6 @@ class Expected: def main( x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32") ) -> R.Tensor(None, "float32", ndim=0): - # block 0 gv = R.call_tir(matmul, (x, y), (), dtype="float32") return gv @@ -953,10 +1317,7 @@ def matmul( rxplaceholder_1: T.Buffer[4, "float32"], matmul: T.Buffer[(), "float32"], ) -> None: - # function attr dict T.func_attr({"global_symbol": "matmul", "tir.noalias": True}) - # body - # with T.block("root") for i0 in T.serial(4): with T.block("matmul"): k = T.axis.reduce(4, i0) @@ -1089,6 +1450,52 @@ def softmax( tvm.ir.assert_structural_equal(mod, Expected) +def test_adaptive_avg_pool2d(): + @I.ir_module + class AdaptiveAvgPool2D: + @R.function + def main(x: R.Tensor((2, 64, 7, 7), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((2, 64, 1, 1), "float32") = R.adaptive_avg_pool2d(x, output_size=[1, 1]) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 64, 7, 7), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv = R.call_tir(adaptive_pool, (x,), (2, 64, 1, 1), dtype="float32") + return gv + + @T.prim_func + def adaptive_pool( + rxplaceholder: T.Buffer[(2, 64, 7, 7), "float32"], + tensor: T.Buffer[(2, 64, 1, 1), "float32"], + ) -> None: + T.func_attr({"global_symbol": "adaptive_pool", "tir.noalias": 
True}) + tensor_1 = T.alloc_buffer([2, 64, 1, 1], dtype="float32") + for i0, i1, i2, i3, i4, i5 in T.grid(2, 64, 1, 1, 7, 7): + with T.block("tensor"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(rxplaceholder[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]) + T.writes(tensor_1[ax0, ax1, ax2, ax3]) + with T.init(): + tensor_1[ax0, ax1, ax2, ax3] = T.float32(0) + tensor_1[ax0, ax1, ax2, ax3] = ( + tensor_1[ax0, ax1, ax2, ax3] + + rxplaceholder[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1] + ) + for i0, i1, i2, i3 in T.grid(2, 64, 1, 1): + with T.block("tensor_1"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(tensor_1[ax0, ax1, ax2, ax3]) + T.writes(tensor[ax0, ax1, ax2, ax3]) + tensor[ax0, ax1, ax2, ax3] = tensor_1[ax0, ax1, ax2, ax3] * T.float32( + 0.020408163265306121 + ) + + mod = OperatorLegalizer(AdaptiveAvgPool2D).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + def test_sum(): @I.ir_module class Sum: @@ -1174,28 +1581,58 @@ def divide( tvm.ir.assert_structural_equal(mod, Expected) +def test_image_resize2d(): + @I.ir_module + class Resize2D: + @R.function + def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((2, 16, 16, 3), "float32") = R.resize2d( + x, + size=[16, 16], + layout="NHWC", + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + ) + return gv + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv = R.call_tir(resize2d, (x,), (2, 16, 16, 3), dtype="float32") + return gv + + @T.prim_func + def resize2d( + rxplaceholder: T.Buffer[(2, 8, 8, 3), "float32"], + resize: T.Buffer[(2, 16, 16, 3), "float32"], + ) -> None: + T.func_attr({"global_symbol": "resize2d", "tir.noalias": True}) + for i0, i1, i2, i3 in T.grid(2, 16, 16, 3): + with T.block("resize"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads( + rxplaceholder[ + i0_1, + T.max(T.min(T.div(i1_1, 2), 7), 0), + T.max(T.min(T.div(i2_1, 2), 7), 0), + i3_1, + ] + ) + T.writes(resize[i0_1, i1_1, i2_1, i3_1]) + resize[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[ + i0_1, + T.max(T.min(T.div(i1_1, 2), 7), 0), + T.max(T.min(T.div(i2_1, 2), 7), 0), + i3_1, + ] + + mod = OperatorLegalizer(Resize2D).transform() + tvm.ir.assert_structural_equal(mod, Expected) + + if __name__ == "__main__": - test_conv2d() - test_add() - test_subtract() - test_multiply() - test_divide() - test_floor_divide() - test_sin() - test_cos() - test_sqrt() - test_relu() - test_gelu() - test_silu() - test_reshape() - test_reshape_dim_inference() - test_transpose() - test_concatenate() - test_layer_norm() - test_matmul_1_4() - test_matmul_4_1() - test_matmul_1_1() - test_matmul_4_5() - test_softmax() - test_sum() - test_mean() + # Todo: test_split_by_indices + # Todo: test_split_by_n_section + # Todo: test_batch_norm + pytest.main([__file__]) diff --git a/tests/python/relax/test_relax_image_ops.py b/tests/python/relax/test_relax_image_ops.py new file mode 100644 index 000000000000..6eda4be8d961 --- /dev/null +++ b/tests/python/relax/test_relax_image_ops.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations # must import to defer parsing of annotations +import pytest +import tvm +from tvm import relax +from tvm.error import DiagnosticError +from tvm.relax.testing import transform +from tvm.script._parser import relax as R +import tvm.testing + + +def test_resize2d(): + @R.function + def expected(x: R.Tensor((2, 14, 14, 3), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((2, 28, 28, 3), "float32") = R.resize2d(x, size=[28, 28], layout="NHWC") + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", (2, 14, 14, 3), relax.DynTensorType(4, "float32")) + with bb.function("main", [x]): + gv = bb.emit(relax.op.image.resize2d(x, (28, 28), layout="NHWC")) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +if __name__ == "__main__": + test_resize2d() diff --git a/tests/python/relax/test_relax_tensor_ops.py b/tests/python/relax/test_relax_tensor_ops.py index c74df5a397c8..cef709600497 100644 --- a/tests/python/relax/test_relax_tensor_ops.py +++ b/tests/python/relax/test_relax_tensor_ops.py @@ -549,5 +549,21 @@ def test_matmul_fail_on_not_broadcastable(): bb.emit_func_output(gv) +def test_adaptive_avg_pool2d(): + @R.function + def expected(x: R.Tensor((2, 64, 8, 9), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((2, 64, 7, 7), "float32") = R.adaptive_avg_pool2d(x, output_size=[7, 7]) + return gv + + x = relax.Var("x", [2, 64, 8, 9], relax.DynTensorType(ndim=4, dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.nn.adaptive_avg_pool2d(x, output_size=(7, 7))) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relax/test_relax_transform_ops.py b/tests/python/relax/test_relax_transform_ops.py index a077c622fd3e..b8f576aab690 100644 --- a/tests/python/relax/test_relax_transform_ops.py +++ b/tests/python/relax/test_relax_transform_ops.py @@ -288,6 +288,208 @@ def expected( tvm.ir.assert_structural_equal(bb.get()["main"], expected) +def test_cumsum(): + @R.function + def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=3): + gv: R.Tensor((2, 3, 4), "float32") = R.cumsum(x, axis=-2) + return gv + + x = relax.Var("x", [2, 3, 4], relax.DynTensorType(ndim=3, dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.transform.cumsum(x, axis=-2)) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_cumsum_without_specified_axis(): + @R.function + def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=1): + gv: R.Tensor((24,), "float32") = 
R.cumsum(x) + return gv + + x = relax.Var("x", [2, 3, 4], relax.DynTensorType(ndim=3, dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.transform.cumsum(x)) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_trilu(): + @R.function + def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=3): + gv: R.Tensor((2, 3, 4), "float32") = R.trilu(x, k=0, is_upper=False) + return gv + + x = relax.Var("x", [2, 3, 4], relax.DynTensorType(ndim=3, dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.transform.trilu(x, k=0, is_upper=False)) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_cast(): + @R.function + def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "int32", ndim=3): + gv: R.Tensor((2, 3, 4), "int32") = R.cast(x, "int32") + return gv + + x = relax.Var("x", [2, 3, 4], relax.DynTensorType(ndim=3, dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.transform.cast(x, "int32")) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_take(): + @R.function + def expected( + x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((1,), "int32") + ) -> R.Tensor(None, "float32", ndim=1): + gv: R.Tensor((1,), "float32") = R.take(x, indices) + return gv + + x = relax.Var("x", [2, 3, 4], relax.DynTensorType(ndim=3, dtype="float32")) + indices = relax.Var("indices", [1], relax.DynTensorType(ndim=1, dtype="int32")) + bb = relax.BlockBuilder() + with bb.function("main", [x, indices]): + gv = bb.emit(relax.op.transform.take(x, indices)) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_take_high_dim_indices_with_axis(): + @R.function + def expected( + x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((3, 4, 2), "int32") + ) -> R.Tensor(None, "float32", ndim=5): + gv: R.Tensor((2, 3, 4, 2, 4), "float32") = R.take(x, indices, axis=1) + return gv + + x = relax.Var("x", [2, 3, 4], relax.DynTensorType(ndim=3, dtype="float32")) + indices = relax.Var("indices", [3, 4, 2], relax.DynTensorType(ndim=3, dtype="int32")) + bb = relax.BlockBuilder() + with bb.function("main", [x, indices]): + gv = bb.emit(relax.op.transform.take(x, indices, axis=1)) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_full(): + @R.function + def expected(v: R.Tensor((), "int32")) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((2, 3), "float32") = R.full(v, (2, 3), dtype="float32") + return gv + + bb = relax.BlockBuilder() + v = relax.Var("v", (), relax.DynTensorType(0, "int32")) + with bb.function("main", [v]): + gv = bb.emit(relax.op.transform.full(v, (2, 3), "float32")) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_split_by_indices(): + @R.function + def expected(x: R.Tensor((2, 10, 4), "float32")): + gv = R.split(x, indices_or_sections=[-2, 2, 6, 4, 8, 12, 9], axis=1) + return gv + + x = 
relax.Var("x", [2, 10, 4], relax.DynTensorType(ndim=3, dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit( + relax.op.transform.split(x, indices_or_sections=[-2, 2, 6, 4, 8, 12, 9], axis=1) + ) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_split_by_n_section(): + @R.function + def expected(x: R.Tensor((2, 10, 4), "float32")): + gv = R.split(x, indices_or_sections=5, axis=1) + return gv + + x = relax.Var("x", [2, 10, 4], relax.DynTensorType(ndim=3, dtype="float32")) + bb = relax.BlockBuilder() + with bb.function("main", [x]): + gv = bb.emit(relax.op.transform.split(x, indices_or_sections=5, axis=1)) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_split_by_n_section_not_divisible(): + x = relax.Var("x", [2, 10, 4], relax.DynTensorType(ndim=3, dtype="float32")) + bb = relax.BlockBuilder() + with pytest.raises(DiagnosticError): + with bb.function("main", [x]): + gv = bb.emit(relax.op.transform.split(x, indices_or_sections=3, axis=1)) + bb.emit_func_output(gv) + + +def test_broadcast_to(): + @R.function + def expected(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((4, 2, 5, 3), "float32") = R.broadcast_to(x, (4, 2, 5, 3)) + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", (2, 1, 3), relax.DynTensorType(3, "float32")) + with bb.function("main", [x]): + gv = bb.emit(relax.op.transform.broadcast_to(x, (4, 2, 5, 3))) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + +def test_strided_slice(): + @R.function + def expected(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor(None, "float32", ndim=4): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice( + x, + begin=[1, 0, 8], + end=[8, 9, 0], + strides=[2, 1, -3], + axes=[0, 1, -1], + slice_mode="end", + ) + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", (8, 9, 10, 10), relax.DynTensorType(4, "float32")) + with bb.function("main", [x]): + gv = bb.emit( + relax.op.transform.strided_slice(x, [1, 0, 8], [8, 9, 0], [2, 1, -3], [0, 1, -1]) + ) + bb.emit_func_output(gv) + + expected = expected.with_attr("global_symbol", "main") + tvm.ir.assert_structural_equal(bb.get()["main"], expected) + + if __name__ == "__main__": test_transpose() test_transpose_none_arg() @@ -304,3 +506,15 @@ def expected( test_concatenate() test_concatenate_fail_on_incompatible_shape() test_concatenate_without_specified_axis() + test_cumsum() + test_cumsum_without_specified_axis() + test_trilu() + test_cast() + test_take() + test_take_high_dim_indices_with_axis() + test_full() + test_split_by_indices() + test_split_by_n_section() + test_split_by_n_section_not_divisible() + test_broadcast_to() + test_strided_slice()