From 5e4fe463434717ddfe7d2684e3c19dd38b76de56 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 13 Feb 2023 11:12:10 +0300
Subject: [PATCH 01/19] update paddlepaddle

---
 python/tvm/relay/frontend/paddlepaddle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index 0842cd55dae2..75fecf217851 100755
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -1743,7 +1743,7 @@ def convert_scatter(g, op, block):
     if overwrite:
         out = _op.scatter(x, index, updates, axis=0)
     else:
-        out = _op.scatter_add(_op.zeros_like(x), index, updates, axis=0)
+        out = _op.scatter_elements(_op.zeros_like(x), index, updates, axis=0, reduction="add")
         out += _op.scatter(x, index, _op.zeros_like(updates), axis=0)
     g.add_node(op.output("Out")[0], out)

From 16b2c5c4c40eabfde681eede63c46618a9796751 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 13 Feb 2023 11:21:00 +0300
Subject: [PATCH 02/19] update tensorflow

---
 python/tvm/relay/frontend/tensorflow_ops.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index ab773f9a2a8a..e8da60a1afb5 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -2918,7 +2918,7 @@ def _impl(inputs, attr, params, mod):
         counts_shape = _op.reshape(size, [1])
         counts = _op.zeros(counts_shape, out_dtype)
-        out = _op.scatter_add(counts, input, updates, axis=0)
+        out = _op.scatter_elements(counts, input, updates, axis=0, reduction="add")
         return out

     return _impl
@@ -2965,11 +2965,11 @@ def _impl(inputs, attr, params, mod):
             size_arr = _op.reshape(size, [1])
             counts_shape = _op.concatenate([batch_arr, size_arr], axis=0)
             counts = _op.zeros(counts_shape, out_dtype)
-            out = _op.scatter_add(counts, input, updates, axis=1)
+            out = _op.scatter_elements(counts, input, updates, axis=1, reduction="add")
         else:
             counts_shape = _op.reshape(size, [1])
             counts = _op.zeros(counts_shape, out_dtype)
-            out = _op.scatter_add(counts, input, updates, axis=0)
+            out = _op.scatter_elements(counts, input, updates, axis=0, reduction="add")

         if attr["binary_output"]:
             out = _op.cast(_op.cast(out, "bool"), out_dtype)

From 52236a057ca4c1ce19d5231d68bc81c385c4f366 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 13 Feb 2023 11:22:37 +0300
Subject: [PATCH 03/19] update pytorch

---
 python/tvm/relay/frontend/pytorch.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 2ea764872c27..61fe4769fc43 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2671,18 +2671,19 @@ def bincount(self, inputs, input_types):
         updates = _op.ones_like(data)
         counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
-        out = _op.scatter_add(counts, data, updates, axis=0)
+        out = _op.scatter_elements(counts, data, updates, axis=0, reduction="add")
         if input_type == "int32":
             # Torch always outputs int64 results for bincount
             return _op.cast(out, "int64")
         return out

     def scatter_add(self, inputs, input_types):
+        # TODO(vvchernov): need some check?
         data = inputs[0]
         axis = inputs[1]
         index = inputs[2]
         src = inputs[3]
-        return _op.scatter_add(data, index, src, axis=axis)
+        return _op.scatter_elements(data, index, src, axis=axis, reduction="add")

     def scatter_reduce(self, inputs, input_types):
         assert len(inputs) == 5 or len(inputs) == 6, (

From 43e5373b04873343857a5450c9b4317820701970 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 13 Feb 2023 11:24:33 +0300
Subject: [PATCH 04/19] remove scatter_add strategy

---
 python/tvm/relay/op/_transform.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index e40179ed2d03..f28c28ce62a6 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -195,15 +195,6 @@ def stft_shape_func(attrs, inputs, _):
 _reg.register_strategy("trilu", strategy.trilu_strategy)


-# scatter_add
-@_reg.register_compute("scatter_add")
-def compute_scatter_add(attrs, inputs, output_type):
-    """Compute definition of scatter_add"""
-    return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]
-
-
-_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)
-
 # scatter_elements
 @_reg.register_compute("scatter_elements")
 def compute_scatter_elements(attrs, inputs, output_type):
@@ -687,7 +678,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims):

 _reg.register_shape_func("scatter", False, elemwise_shape_func)
-_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
 _reg.register_shape_func("scatter_elements", False, elemwise_shape_func)
 _reg.register_shape_func("scatter_nd", False, elemwise_shape_func)

From 233a49875a2ace4d4da15fab36ec0ff5ebc483e3 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 13 Feb 2023 11:25:34 +0300
Subject: [PATCH 05/19] remove ScatterAddAttrs registration

---
 python/tvm/relay/op/op_attrs.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index ea7c415b511f..0214ae8a46b6 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -634,11 +634,6 @@ class DynExpandDimsAttrs(Attrs):
     """Attributes used in dynamic expand_dims operators"""


-@tvm._ffi.register_object("relay.attrs.ScatterAddAttrs")
-class ScatterAddAttrs(Attrs):
-    """Attributes used in scatter_add operators"""
-
-
 @tvm._ffi.register_object("relay.attrs.ScatterElementsAttrs")
 class ScatterElementsAttrs(Attrs):
     """Attributes used in scatter_elements operators"""

From b92bf03bdb0711f7458fc07399e18e0f797a0135 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 13 Feb 2023 11:28:47 +0300
Subject: [PATCH 06/19] remove scatter_add from front-end

---
 python/tvm/relay/op/transform.py | 27 +--------------------------
 1 file changed, 1 insertion(+), 26 deletions(-)

diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 3df13da04426..833d14eb5897 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -378,31 +378,6 @@ def scatter(data, indices, updates, axis):
     return _make.scatter(data, indices, updates, axis)


-def scatter_add(data, indices, updates, axis):
-    """Update data by adding values in updates at positions defined by indices.
-
-    Parameters
-    ----------
-    data : relay.Expr
-        The input data to the operator.
-
-    indices : relay.Expr
-        The index locations to update.
-
-    updates : relay.Expr
-        The values to add.
-
-    axis : int
-        The axis to scatter_add on.
-
-    Returns
-    -------
-    ret : relay.Expr
-        The computed result.
-    """
-    return _make.scatter_add(data, indices, updates, axis)
-
-
 def scatter_elements(data, indices, updates, axis=0, reduction="update"):
     """Scatter elements with updating data by reduction of values in updates at
     positions defined by indices.
@@ -1717,7 +1692,7 @@ def segment_sum(data, segment_ids, num_segments=None):
     expanded_segment_ids = tile(segment_ids, segment_ids_tiled_shape)
     scatter_add_segment_ids = transpose(expanded_segment_ids)
     src = cast_like(_dyn_make.zeros(new_shape, "float64"), data)
-    return scatter_add(src, scatter_add_segment_ids, data, axis=0)
+    return scatter_elements(src, scatter_add_segment_ids, data, axis=0, reduction="add")


 def cumsum(data, axis=None, dtype=None, exclusive=None):

From b4e44bb8f21f53e60976912b1dd6f91e87c3d7fc Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Mon, 13 Feb 2023 12:04:35 +0300
Subject: [PATCH 07/19] remove scatter_add schedule and strategy

---
 python/tvm/relay/op/strategy/generic.py | 11 -----------
 python/tvm/topi/generic/search.py       | 16 ----------------
 2 files changed, 27 deletions(-)

diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 4e0448f1799b..4641fb18f7ba 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1569,17 +1569,6 @@ def _compute_scatter(attrs, inputs, _):
     return _compute_scatter


-@override_native_generic_func("scatter_add_strategy")
-def scatter_add_strategy(attrs, outs, out_type, target):
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_scatter(topi.scatter_add),
-        wrap_topi_schedule(topi.generic.schedule_scatter),
-        name="scatter_add.generic",
-    )
-    return strategy
-
-
 # scatter_elements
 @override_native_generic_func("scatter_elements_strategy")
 def scatter_elements_strategy(attrs, inputs, out_type, target):
diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py
index f458ee7bc782..826194e75c2a 100644
--- a/python/tvm/topi/generic/search.py
+++ b/python/tvm/topi/generic/search.py
@@ -52,22 +52,6 @@ def schedule_scatter(outs):
     return _default_schedule(outs, False)


-def schedule_scatter_add(outs):
-    """Schedule for scatter_add operator.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of scatter_add.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for the op.
- """ - return _default_schedule(outs, False) - - def schedule_sparse_fill_empty_rows(outs): return _default_schedule(outs, False) From 96c7d1896267ab97164efd2ebbc7211c53053e79 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 13 Feb 2023 17:19:37 +0300 Subject: [PATCH 08/19] remove back-end API of scatter_add --- src/relay/op/tensor/transform.cc | 48 -------------------------------- 1 file changed, 48 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 01e5a7f5f359..907141c9cb6a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1143,54 +1143,6 @@ RELAY_REGISTER_OP("scatter") .set_attr("TOpPattern", kOpaque) .set_support_level(10); -// Scatter_add -TVM_REGISTER_NODE_TYPE(ScatterAddAttrs); - -// Scatter Add -bool ScatterAddRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(num_inputs, 3); - ICHECK_EQ(types.size(), 4); - auto data = types[0].as(); - if (data == nullptr) { - return false; - } - auto indices = types[1].as(); - if (indices == nullptr) { - return false; - } - auto updates = types[2].as(); - if (updates == nullptr) { - return false; - } - ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) - << "indices of scatter_add must be tensor of integer"; - const auto param = attrs.as(); - ICHECK(param != nullptr); - reporter->Assign(types[3], TensorType(data->shape, data->dtype)); - return true; -} - -TVM_REGISTER_GLOBAL("relay.op._make.scatter_add") - .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) { - auto attrs = make_object(); - attrs->axis = std::move(axis); - static const Op& op = Op::Get("scatter_add"); - return Call(op, {data, indices, updates}, Attrs(attrs), {}); - }); - -RELAY_REGISTER_OP("scatter_add") - .describe( - R"doc(Update data by adding values in updates at positions defined by indices)doc" TVM_ADD_FILELINE) - .set_num_inputs(3) - .add_argument("data", "Tensor", "The input data tensor.") - .add_argument("indices", "Tensor", "The indices location tensor.") - .add_argument("updates", "Tensor", "The values to update the input with.") - .add_type_rel("ScatterAdd", ScatterAddRel) - .set_attr("TOpIsStateful", false) - .set_attr("TOpPattern", kOpaque) - .set_support_level(10); - // scatter_elements operator TVM_REGISTER_NODE_TYPE(ScatterElementsAttrs); From 9cfb2e90b694a4d03116370905d929f4341fa071 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 13 Feb 2023 17:25:07 +0300 Subject: [PATCH 09/19] update test_op_level3 --- tests/python/relay/test_op_level3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 225210f4d617..f18e935b57c8 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1112,7 +1112,7 @@ def test_scatter_add(self, target, dev, ref_data, dshape, ishape, axis, dtype, i "i", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype=indice_dtype) ) u = relay.var("u", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype=dtype)) - z = relay.op.scatter_add(d, i, u, axis) + z = relay.op.scatter_elements(d, i, u, axis, "add") func = relay.Function([d, i, u], z) From 7e208afe059d9bf3a8cd94f42b8b8f96cc13243f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 15 Feb 2023 13:41:29 +0300 Subject: [PATCH 10/19] add checks for scatter_add to pytorch front-end --- python/tvm/relay/frontend/pytorch.py | 25 ++++++++++++++++++++++++- 1 file 
changed, 24 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 61fe4769fc43..3c67116d6c03 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2678,11 +2678,34 @@ def bincount(self, inputs, input_types): return out def scatter_add(self, inputs, input_types): - # TODO(vvchernov): need some check? + assert ( + len(inputs) == 4 + ), "scatter_add takes 4 inputs (data, dim, index, src), but {} given".format( + len(inputs) + ) data = inputs[0] axis = inputs[1] index = inputs[2] src = inputs[3] + + data_shape = self.infer_shape(inputs[0]) + data_rank = len(data_shape) + index_shape = self.infer_shape(inputs[2]) + index_rank = len(index_shape) + src_shape = self.infer_shape(inputs[3]) + src_rank = len(src_shape) + assert data_rank == index_rank, "Index rank is not the same as data rank" + assert data_rank == src_rank, "Src rank is not the same as data rank" + + assert 0 <= axis < data_rank, "Dim is out of bounds" + + for i in range(data_rank): + assert index_shape[i] <= src_shape[i], "Index dim size should be less than src one" + if i != axis: + assert ( + index_shape[i] <= data_shape[i] + ), "Index dim size should be less than data one" + return _op.scatter_elements(data, index, src, axis=axis, reduction="add") def scatter_reduce(self, inputs, input_types): From 5a45e563f28754711aa70c03aabf3d0e6eec5ba7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 15 Feb 2023 14:06:17 +0300 Subject: [PATCH 11/19] remove cpu topi implementation of scatter_add --- python/tvm/topi/__init__.py | 1 - python/tvm/topi/scatter_add.py | 198 --------------------------------- 2 files changed, 199 deletions(-) delete mode 100644 python/tvm/topi/scatter_add.py diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 75867136e09e..d8f839749fa4 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -41,7 +41,6 @@ from .scatter_elements import * from .sparse_fill_empty_rows import * from .sparse_reshape import * -from .scatter_add import * from .argwhere import * from .scan import * from .einsum import * diff --git a/python/tvm/topi/scatter_add.py b/python/tvm/topi/scatter_add.py deleted file mode 100644 index 6b04837b7766..000000000000 --- a/python/tvm/topi/scatter_add.py +++ /dev/null @@ -1,198 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
-"""Scatter Add operator"""
-from tvm.te import hybrid
-
-
-@hybrid.script
-def _scatter_add_1d(data, indices, updates):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        out[i] = data[i]
-    for i in range(indices.shape[0]):
-        out[indices[i] if indices[i] >= 0 else indices[i] + data.shape[0]] += updates[i]
-    return out
-
-
-@hybrid.script
-def _scatter_add_2d(data, indices, updates, axis):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        for j in range(data.shape[1]):
-            out[i, j] = data[i, j]
-    if axis == 0:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                out[
-                    indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis], j
-                ] += updates[i, j]
-    else:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                out[
-                    i, indices[i, j] if indices[i, j] >= 0 else indices[i, j] + data.shape[axis]
-                ] += updates[i, j]
-
-    return out
-
-
-@hybrid.script
-def _scatter_add_3d(data, indices, updates, axis):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        for j in range(data.shape[1]):
-            for k in range(data.shape[2]):
-                out[i, j, k] = data[i, j, k]
-    if axis == 0:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    out[
-                        indices[i, j, k]
-                        if indices[i, j, k] >= 0
-                        else indices[i, j, k] + data.shape[axis],
-                        j,
-                        k,
-                    ] += updates[i, j, k]
-    elif axis == 1:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    out[
-                        i,
-                        indices[i, j, k]
-                        if indices[i, j, k] >= 0
-                        else indices[i, j, k] + data.shape[axis],
-                        k,
-                    ] += updates[i, j, k]
-    else:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    out[
-                        i,
-                        j,
-                        indices[i, j, k]
-                        if indices[i, j, k] >= 0
-                        else indices[i, j, k] + data.shape[axis],
-                    ] += updates[i, j, k]
-
-    return out
-
-
-@hybrid.script
-def _scatter_add_4d(data, indices, updates, axis):
-    out = output_tensor(data.shape, data.dtype)
-    for i in range(data.shape[0]):
-        for j in range(data.shape[1]):
-            for k in range(data.shape[2]):
-                for l in range(data.shape[3]):
-                    out[i, j, k, l] = data[i, j, k, l]
-
-    if axis == 0:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    for l in range(indices.shape[3]):
-                        out[
-                            indices[i, j, k, l]
-                            if indices[i, j, k, l] >= 0
-                            else indices[i, j, k, l] + data.shape[axis],
-                            j,
-                            k,
-                            l,
-                        ] += updates[i, j, k, l]
-    elif axis == 1:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    for l in range(indices.shape[3]):
-                        out[
-                            i,
-                            indices[i, j, k, l]
-                            if indices[i, j, k, l] >= 0
-                            else indices[i, j, k, l] + data.shape[axis],
-                            k,
-                            l,
-                        ] += updates[i, j, k, l]
-    elif axis == 2:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    for l in range(indices.shape[3]):
-                        out[
-                            i,
-                            j,
-                            indices[i, j, k, l]
-                            if indices[i, j, k, l] >= 0
-                            else indices[i, j, k, l] + data.shape[axis],
-                            l,
-                        ] += updates[i, j, k, l]
-    else:
-        for i in range(indices.shape[0]):
-            for j in range(indices.shape[1]):
-                for k in range(indices.shape[2]):
-                    for l in range(indices.shape[3]):
-                        out[
-                            i,
-                            j,
-                            k,
-                            indices[i, j, k, l]
-                            if indices[i, j, k, l] >= 0
-                            else indices[i, j, k, l] + data.shape[axis],
-                        ] += updates[i, j, k, l]
-
-    return out
-
-
-def scatter_add(data, indices, updates, axis=0):
-    """Update data by adding values in updates at positions defined by indices
-
-    Parameters
-    ----------
-    data : relay.Expr
-        The input data to the operator.
-
-    indices : relay.Expr
-        The index locations to update.
-
-    updates : relay.Expr
-        The values to update.
-
-    axis : int
-        The axis to scatter_add on
-
-    Returns
-    -------
-    ret : relay.Expr
-        The computed result.
-    """
-    if axis < 0:
-        axis += len(data.shape)
-    assert axis >= 0
-    assert axis < len(data.shape)
-
-    if len(data.shape) == 1:
-        return _scatter_add_1d(data, indices, updates)
-    if len(data.shape) == 2:
-        return _scatter_add_2d(data, indices, updates, axis)
-    if len(data.shape) == 3:
-        return _scatter_add_3d(data, indices, updates, axis)
-    if len(data.shape) == 4:
-        return _scatter_add_4d(data, indices, updates, axis)
-    raise ValueError("scatter_add only support for 1-4 dimensions")

From 861bab08967e66630aa2dcc36f30e440b4670ae5 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Wed, 15 Feb 2023 14:07:45 +0300
Subject: [PATCH 12/19] remove scatter_add strategy for cuda-gpu

---
 python/tvm/relay/op/strategy/cuda.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index fc1691fe9ef0..e0229a615d50 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -1086,19 +1086,6 @@ def scatter_cuda(attrs, inputs, out_type, target):
     return strategy


-@scatter_add_strategy.register(["cuda", "gpu"])
-def scatter_add_cuda(attrs, inputs, out_type, target):
-    """scatter_add cuda strategy"""
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_scatter(topi.cuda.scatter_add),
-        wrap_topi_schedule(topi.generic.schedule_extern),
-        name="scatter_add.cuda",
-        plevel=10,
-    )
-    return strategy
-
-
 @scatter_elements_strategy.register(["cuda", "gpu"])
 def scatter_elements_cuda(attrs, inputs, out_type, target):
     """scatter elements cuda strategy"""

From 55fca2ee29920c4b90b11ca26d9e6e5ca54d0db5 Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Wed, 15 Feb 2023 16:18:45 +0300
Subject: [PATCH 13/19] move reduce_func definition to a higher level

---
 python/tvm/topi/cuda/scatter_elements.py | 194 +++++++++++++----------
 1 file changed, 106 insertions(+), 88 deletions(-)

diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py
index 8ed3e2972081..fca2ec51ee60 100644
--- a/python/tvm/topi/cuda/scatter_elements.py
+++ b/python/tvm/topi/cuda/scatter_elements.py
@@ -22,6 +22,81 @@
 from ..math import cast


+def gen_ir(data, indices, updates, out, axis, reduce_func):
+    ib = tir.ir_builder.create()
+
+    data_ptr = ib.buffer_ptr(data)
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+    out_ptr = ib.buffer_ptr(out)
+
+    # Prepare ranges and strides
+    shape = data.shape
+    if axis < 0:
+        axis = len(shape) + axis
+    axis_range = cast(shape[axis], indices.dtype)
+
+    before_axis_range = 1
+    after_axis_range = 1
+    for i, value in enumerate(shape, 0):
+        if i < axis:
+            before_axis_range *= value
+        elif i > axis:
+            after_axis_range *= value
+    before_axis_stride = axis_range * after_axis_range
+    full_range = before_axis_range * before_axis_stride
+
+    ind_shape = indices.shape
+    ind_axis_range = ind_shape[axis]
+
+    ind_before_axis_range = 1
+    ind_after_axis_range = 1
+    for i, value in enumerate(ind_shape, 0):
+        if i < axis:
+            ind_before_axis_range *= value
+        elif i > axis:
+            ind_after_axis_range *= value
+    ind_before_axis_stride = ind_axis_range * ind_after_axis_range
+    ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    # Copy initial input data to output
+    with ib.new_scope():
+        num_blocks = ceil_div(full_range, max_threads)
+        bx = te.thread_axis("blockIdx.x")
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(bx, "thread_extent", num_blocks)
+        ib.scope_attr(tx, "thread_extent", max_threads)
+
+        index = bx * max_threads + tx
+        with ib.if_scope(index < full_range):
+            out_ptr[index] = data_ptr[index]
+
+    # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd)
+    with ib.new_scope():
+        num_blocks_2 = ceil_div(ind_full_range_excl_axis, max_threads)
+        bx2 = te.thread_axis("blockIdx.x")
+        tx2 = te.thread_axis("threadIdx.x")
+        ib.scope_attr(bx2, "thread_extent", num_blocks_2)
+        ib.scope_attr(tx2, "thread_extent", max_threads)
+
+        ind_fused = bx2 * max_threads + tx2
+        with ib.if_scope(ind_fused < ind_full_range_excl_axis):
+            i = ind_fused // ind_after_axis_range
+            j = ind_fused % ind_after_axis_range
+            with ib.for_range(0, ind_axis_range, "k") as k:
+                # Offset along indices or updates
+                index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j
+                # Get index and shift to positive side if need
+                new_index = indices_ptr[index1]
+                shifted_index = new_index + (new_index < 0) * axis_range
+                # Offset along data
+                index2 = i * before_axis_stride + shifted_index * after_axis_range + j
+                reduce_func(out_ptr, index2, updates_ptr[index1])
+
+    return ib.get()
+
+
 def scatter_elements(data, indices, updates, axis=0, reduction="update"):
     """Scatter elements from updates to corresponding indices of copied data.
@@ -67,99 +142,42 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
     if not isinstance(axis, int):
         axis = get_const_int(axis)

-    def gen_ir(data, indices, updates, out, axis):
-        ib = tir.ir_builder.create()
-
-        data_ptr = ib.buffer_ptr(data)
-        indices_ptr = ib.buffer_ptr(indices)
-        updates_ptr = ib.buffer_ptr(updates)
-        out_ptr = ib.buffer_ptr(out)
-
-        # Prepare ranges and strides
-        shape = data.shape
-        if axis < 0:
-            axis = len(shape) + axis
-        axis_range = cast(shape[axis], indices.dtype)
-
-        before_axis_range = 1
-        after_axis_range = 1
-        for i, value in enumerate(shape, 0):
-            if i < axis:
-                before_axis_range *= value
-            elif i > axis:
-                after_axis_range *= value
-        before_axis_stride = axis_range * after_axis_range
-        full_range = before_axis_range * before_axis_stride
-
-        ind_shape = indices.shape
-        ind_axis_range = ind_shape[axis]
-
-        ind_before_axis_range = 1
-        ind_after_axis_range = 1
-        for i, value in enumerate(ind_shape, 0):
-            if i < axis:
-                ind_before_axis_range *= value
-            elif i > axis:
-                ind_after_axis_range *= value
-        ind_before_axis_stride = ind_axis_range * ind_after_axis_range
-        ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range
-
-        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-        # Copy initial input data to output
-        with ib.new_scope():
-            num_blocks = ceil_div(full_range, max_threads)
-            bx = te.thread_axis("blockIdx.x")
-            tx = te.thread_axis("threadIdx.x")
-            ib.scope_attr(bx, "thread_extent", num_blocks)
-            ib.scope_attr(tx, "thread_extent", max_threads)
-
-            index = bx * max_threads + tx
-            with ib.if_scope(index < full_range):
-                out_ptr[index] = data_ptr[index]
-
-        # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd)
-        with ib.new_scope():
-            num_blocks_2 = ceil_div(ind_full_range_excl_axis, max_threads)
-            bx2 = te.thread_axis("blockIdx.x")
-            tx2 = te.thread_axis("threadIdx.x")
-            ib.scope_attr(bx2, "thread_extent", num_blocks_2)
-            ib.scope_attr(tx2, "thread_extent", max_threads)
-
-            ind_fused = bx2 * max_threads + tx2
-            with ib.if_scope(ind_fused < ind_full_range_excl_axis):
-                i = ind_fused // ind_after_axis_range
-                j = ind_fused % ind_after_axis_range
-                with ib.for_range(0, ind_axis_range, "k") as k:
-                    # Offset along indices or updates
-                    index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j
-                    # Get index and shift to positive side if need
-                    new_index = indices_ptr[index1]
-                    shifted_index = new_index + (new_index < 0) * axis_range
-                    # Offset along data
-                    index2 = i * before_axis_stride + shifted_index * after_axis_range + j
-                    if reduction == "update":
-                        out_ptr[index2] = updates_ptr[index1]
-                    elif reduction == "add":
-                        out_ptr[index2] += updates_ptr[index1]
-                    elif reduction == "mul":
-                        out_ptr[index2] *= updates_ptr[index1]
-                    elif reduction == "min":
-                        out_ptr[index2] = tir.min(out_ptr[index2], updates_ptr[index1])
-                    elif reduction == "max":
-                        out_ptr[index2] = tir.max(out_ptr[index2], updates_ptr[index1])
-                    else:
-                        raise NotImplementedError(
-                            "scatter_elements reduction not in [update, add, mul, min, max]:",
-                            reduction,
-                        )
-
-        return ib.get()
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = update
+
+    def add_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] += update
+
+    def mul_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] *= update
+
+    def min_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update)
+
+    def max_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = tir.max(dst_ptr[dst_index], update)
+
+    reduce_func = None
+    if reduction == "update":
+        reduce_func = update_func
+    elif reduction == "add":
+        reduce_func = add_func
+    elif reduction == "mul":
+        reduce_func = mul_func
+    elif reduction == "min":
+        reduce_func = min_func
+    elif reduction == "max":
+        reduce_func = max_func
+    else:
+        raise NotImplementedError(
+            "scatter_elements reduction not in [update, add, mul, min, max]:", reduction
+        )

     out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
     return te.extern(
         [data.shape],
         [data, indices, updates],
-        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], axis),
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], axis, reduce_func),
         dtype=data.dtype,
         out_buffers=[out_buf],
         name="scatter_elements_cuda",

From 2f7ace38535c4e1382c27de8a62dbf1b3738aaad Mon Sep 17 00:00:00 2001
From: Valery Chernov
Date: Wed, 15 Feb 2023 16:37:55 +0300
Subject: [PATCH 14/19] use atomic_add in cuda scatter_elements. remove cuda implementation of scatter_add

---
 python/tvm/topi/cuda/scatter.py          | 132 -----------------------
 python/tvm/topi/cuda/scatter_elements.py |  92 +++++++++++++++-
 2 files changed, 90 insertions(+), 134 deletions(-)

diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 1bdd53156623..c88c3086f317 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -591,138 +591,6 @@ def schedule_scatter_via_sort(_, outs):
     return schedule_extern(outs)


-def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
-    """Generate scatter add ir for 1d inputs, using atomic_add instruction
-
-    Parameters
-    ----------
-    data : tir.Tensor
-        The input data to the operator.
-
-    indices : tir.Tensor
-        The index locations to update.
-
-    updates : tir.Tensor
-        The values to update.
-
-    axis : int
-        The axis to scatter on
-
-    out : tir.Tensor
-        The output tensor.
-
-    Returns
-    -------
-    ret : tir
-        The computational ir.
- """ - assert axis == 0 - n = data.shape[0] - - ib = tvm.tir.ir_builder.create() - - out_ptr = ib.buffer_ptr(out) - data_ptr = ib.buffer_ptr(data) - - max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - nthread_tx = max_threads - - with ib.new_scope(): - nthread_bx = ceil_div(n, nthread_tx) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * nthread_tx + tx - with ib.if_scope(tid < n): - out_ptr[tid] = data_ptr[tid] - - indices_ptr = ib.buffer_ptr(indices) - updates_ptr = ib.buffer_ptr(updates) - - ni = indices.shape[0] - - atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local") - - with ib.new_scope(): - nthread_bx = ceil_div(ni, nthread_tx) - tx = te.thread_axis("threadIdx.x") - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(tx, "thread_extent", nthread_tx) - ib.scope_attr(bx, "thread_extent", nthread_bx) - tid = bx * nthread_tx + tx - - with ib.if_scope(tid < ni): - index = indices_ptr[tid] - with ib.if_scope(index < 0): - atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]), - updates_ptr[tid], - ) - with ib.else_scope(): - atomic_add_return[0] = atomic_add( - tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]), - updates_ptr[tid], - ) - - return ib.get() - - -def scatter_add(data, indices, updates, axis=0): - """Update data by adding values in updates at positions defined by indices - - Parameters - ---------- - data : relay.Expr - The input data to the operator. - - indices : relay.Expr - The index locations to update. - - updates : relay.Expr - The values to be added. - - axis : int - The axis to scatter on - - Returns - ------- - ret : relay.Expr - The computed result. - """ - if axis < 0: - axis += len(data.shape) - assert axis >= 0 - assert axis < len(data.shape) - - rank = len(data.shape) - assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions" - - ir_funcs = { - 1: gen_scatter_add_1d_atomic, - 2: gen_ir_2d, - 3: gen_ir_3d, - 4: gen_ir_4d, - } - - def update_func(dst_ptr, dst_index, update): - dst_ptr[dst_index] += update - - out_shape = data.shape - out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") - out = te.extern( - [out_shape], - [data, indices, updates], - lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func), - dtype=data.dtype, - out_buffers=[out_buf], - name="scatter_add_gpu", - tag="scatter_add_gpu", - ) - - return out - - def scatter_nd(data, indices, updates, mode): """Scatter elements from a n-dimension array. diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index fca2ec51ee60..f953b7c9c611 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -20,6 +20,84 @@ from tvm import te, tir from ..utils import ceil_div, get_const_int from ..math import cast +from .nms import atomic_add + + +def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): + """Generate scatter add ir for 1d inputs, using atomic_add instruction + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. + + axis : int + The axis to scatter on + + out : tir.Tensor + The output tensor. + + Returns + ------- + ret : tir + The computational ir. 
+ """ + assert axis == 0 + n = data.shape[0] + + ib = tvm.tir.ir_builder.create() + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + + with ib.new_scope(): + nthread_bx = ceil_div(n, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + with ib.if_scope(tid < n): + out_ptr[tid] = data_ptr[tid] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + + ni = indices.shape[0] + + atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local") + + with ib.new_scope(): + nthread_bx = ceil_div(ni, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + with ib.if_scope(tid < ni): + index = indices_ptr[tid] + with ib.if_scope(index < 0): + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]), + updates_ptr[tid], + ) + with ib.else_scope(): + atomic_add_return[0] = atomic_add( + tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]), + updates_ptr[tid], + ) + + return ib.get() def gen_ir(data, indices, updates, out, axis, reduce_func): @@ -72,7 +150,6 @@ def gen_ir(data, indices, updates, out, axis, reduce_func): with ib.if_scope(index < full_range): out_ptr[index] = data_ptr[index] - # TODO (vvchernov): use atomic function for special conditions (see cuda.scatter_nd) with ib.new_scope(): num_blocks_2 = ceil_div(ind_full_range_excl_axis, max_threads) bx2 = te.thread_axis("blockIdx.x") @@ -173,11 +250,22 @@ def max_func(dst_ptr, dst_index, update): "scatter_elements reduction not in [update, add, mul, min, max]:", reduction ) + cur_target_kind = str(tvm.target.Target.current(allow_none=False).kind) + gen_scatter_elements_ir = None + if ( + reduction == "add" + and cur_target_kind not in ["vulkan", "metal"] + and updates.dtype in ["int32", "float32"] + ): + gen_scatter_elements_ir = gen_scatter_add_1d_atomic + else: + gen_scatter_elements_ir = gen_ir + out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( [data.shape], [data, indices, updates], - lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], axis, reduce_func), + lambda ins, outs: gen_scatter_elements_ir(ins[0], ins[1], ins[2], outs[0], axis, reduce_func), dtype=data.dtype, out_buffers=[out_buf], name="scatter_elements_cuda", From 05fc3565f5420bfae544b658f9b67696009143c3 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 12:36:43 +0300 Subject: [PATCH 15/19] fix lint --- python/tvm/relay/frontend/pytorch.py | 4 +-- python/tvm/topi/cuda/scatter_elements.py | 42 ++++++++++++++++++++---- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 3c67116d6c03..098496908d12 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2680,9 +2680,7 @@ def bincount(self, inputs, input_types): def scatter_add(self, inputs, input_types): assert ( len(inputs) == 4 - ), "scatter_add takes 4 inputs (data, dim, index, src), but {} given".format( - len(inputs) - ) + ), "scatter_add takes 4 inputs (data, dim, index, src), but {} 
given".format(len(inputs)) data = inputs[0] axis = inputs[1] index = inputs[2] diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index f953b7c9c611..97d18754f241 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -23,8 +23,9 @@ from .nms import atomic_add -def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): - """Generate scatter add ir for 1d inputs, using atomic_add instruction +def gen_scatter_add_1d_atomic(data, indices, updates, out, axis, _): + """Generate ir for scatter elements for reduction sum for 1d inputs, + using atomic_add instruction Parameters ---------- @@ -37,12 +38,12 @@ def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): updates : tir.Tensor The values to update. - axis : int - The axis to scatter on - out : tir.Tensor The output tensor. + axis : int + The axis to scatter on + Returns ------- ret : tir @@ -101,6 +102,33 @@ def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _): def gen_ir(data, indices, updates, out, axis, reduce_func): + """Generate ir for scatter elements + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. + + out : tir.Tensor + The output tensor. + + axis : int + The axis to scatter on + + reduce_func : Any + The function reduced update and output to output + + Returns + ------- + ret : tir + The computational ir. + """ ib = tir.ir_builder.create() data_ptr = ib.buffer_ptr(data) @@ -265,7 +293,9 @@ def max_func(dst_ptr, dst_index, update): return te.extern( [data.shape], [data, indices, updates], - lambda ins, outs: gen_scatter_elements_ir(ins[0], ins[1], ins[2], outs[0], axis, reduce_func), + lambda ins, outs: gen_scatter_elements_ir( + ins[0], ins[1], ins[2], outs[0], axis, reduce_func + ), dtype=data.dtype, out_buffers=[out_buf], name="scatter_elements_cuda", From fa5cd428a7f1bece76a1f409a2f148242ff6c89f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 16 Feb 2023 19:59:26 +0300 Subject: [PATCH 16/19] remove ScatterAddAttrs --- include/tvm/relay/attrs/transform.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index b5333961ebf9..7680883248e0 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -156,14 +156,6 @@ struct ScatterAttrs : public tvm::AttrsNode { } }; -struct ScatterAddAttrs : public tvm::AttrsNode { - Integer axis; - - TVM_DECLARE_ATTRS(ScatterAddAttrs, "relay.attrs.ScatterAddAttrs") { - TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values."); - } -}; - struct ScatterElementsAttrs : public tvm::AttrsNode { Integer axis; String reduction; From 3879a578a26cb3184a7678561630e7f0ef11b2b1 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 17 Feb 2023 22:36:40 +0300 Subject: [PATCH 17/19] fix condition for using of atomic_add --- python/tvm/topi/cuda/scatter_elements.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 97d18754f241..1e735f5618e2 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -278,10 +278,13 @@ def max_func(dst_ptr, dst_index, update): "scatter_elements reduction not in [update, add, mul, min, max]:", reduction ) + 
shape = data.shape + rank = len(shape) cur_target_kind = str(tvm.target.Target.current(allow_none=False).kind) gen_scatter_elements_ir = None if ( reduction == "add" + and rank == 1 and cur_target_kind not in ["vulkan", "metal"] and updates.dtype in ["int32", "float32"] ): @@ -289,9 +292,9 @@ def max_func(dst_ptr, dst_index, update): else: gen_scatter_elements_ir = gen_ir - out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + out_buf = tir.decl_buffer(shape, data.dtype, "out_buf") return te.extern( - [data.shape], + [shape], [data, indices, updates], lambda ins, outs: gen_scatter_elements_ir( ins[0], ins[1], ins[2], outs[0], axis, reduce_func From 0ad0db8053309771164540a4ea89fe7ae2601a6e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sat, 18 Feb 2023 21:43:28 +0300 Subject: [PATCH 18/19] last upstream of scatter_add with pytorch description. add test case --- python/tvm/relay/frontend/pytorch.py | 15 +++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 6 ++++++ 2 files changed, 21 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 098496908d12..635cb960a829 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -210,6 +210,18 @@ def infer_shape(self, inputs, mod=None): def infer_shape_with_prelude(self, inputs): return self.infer_shape(inputs, mod=self.prelude.mod) + def is_empty_shape(self, shape): + rank = len(shape) + if rank: + is_empty = False + for i in range(rank): + if shape[i] == 0: + is_empty = True + break + return is_empty + else: + return True + def record_output_type(self, output): if isinstance(output, tuple): cleaned_output = [o for o in output if o is not None] @@ -2690,6 +2702,9 @@ def scatter_add(self, inputs, input_types): data_rank = len(data_shape) index_shape = self.infer_shape(inputs[2]) index_rank = len(index_shape) + # When index is empty, the operation returns data unchanged + if self.is_empty_shape(index_shape): + return data src_shape = self.infer_shape(inputs[3]) src_rank = len(src_shape) assert data_rank == index_rank, "Index rank is not the same as data rank" diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 39d78bd6065c..21defa6a59b2 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4232,6 +4232,12 @@ def test_fn_scatter_add(dim): verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], targets) verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets) + # Check empty indices for scatter_add + in_data = torch.zeros(2, 4) + in_index = torch.empty((0,)) + in_src = torch.rand(2, 1) + verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets) + def test_forward_scatter_reduce(): """test_forward_scatter_reduce""" From 4877b72b1bc5b9b0dd95cc257960f9d6c8d73a49 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 21 Feb 2023 21:41:55 +0300 Subject: [PATCH 19/19] upstream scatter_elements for CPU with CUDA approach --- python/tvm/topi/cuda/scatter_elements.py | 6 ++- python/tvm/topi/scatter_elements.py | 63 +++++++++++++++--------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 1e735f5618e2..25f15a0e73a6 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -189,14 +189,16 @@ def gen_ir(data, indices, 
updates, out, axis, reduce_func): with ib.if_scope(ind_fused < ind_full_range_excl_axis): i = ind_fused // ind_after_axis_range j = ind_fused % ind_after_axis_range + pre_index1 = i * ind_before_axis_stride + j + pre_index2 = i * before_axis_stride + j with ib.for_range(0, ind_axis_range, "k") as k: # Offset along indices or updates - index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j + index1 = pre_index1 + k * ind_after_axis_range # Get index and shift to positive side if need new_index = indices_ptr[index1] shifted_index = new_index + (new_index < 0) * axis_range # Offset along data - index2 = i * before_axis_stride + shifted_index * after_axis_range + j + index2 = pre_index2 + shifted_index * after_axis_range reduce_func(out_ptr, index2, updates_ptr[index1]) return ib.get() diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index bfa765855b0e..b4052702268b 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -92,7 +92,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): ind_after_axis_range *= value ind_before_axis_stride = ind_axis_range * ind_after_axis_range - def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, reduce_func): # pylint: disable=invalid-name ib = tir.ir_builder.create() @@ -105,44 +105,61 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.for_range(0, full_range, "i", kind="parallel") as i: out[i] = data[i] - # TODO(vvchernov): find optimal parallel approach with ib.for_range( 0, ind_before_axis_range * ind_after_axis_range, "fused", kind="parallel" ) as fused: i = fused // ind_after_axis_range j = fused % ind_after_axis_range + pre_index1 = i * ind_before_axis_stride + j + pre_index2 = i * before_axis_stride + j with ib.for_range(0, ind_axis_range, "k") as k: # Offset along indices or updates - index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j - # TODO(vvchernov): assert for out of bounds, separated check for indices + index1 = pre_index1 + k * ind_after_axis_range + # Get index and shift to positive side if need k_new = indices[index1] - index_check = tir.LT(k_new, tir.const(0, indices.dtype)) - k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) + shifted_index = k_new + (k_new < 0) * axis_range # Offset along data - index2 = i * before_axis_stride + k_new * after_axis_range + j - if reduction == "update": - out[index2] = updates[index1] - elif reduction == "add": - out[index2] += updates[index1] - elif reduction == "mul": - out[index2] *= updates[index1] - elif reduction == "min": - out[index2] = tir.min(out[index2], updates[index1]) - elif reduction == "max": - out[index2] = tir.max(out[index2], updates[index1]) - else: - raise NotImplementedError( - "scatter_elements reduction not in [update, add, mul, min, max]:", - reduction, - ) + index2 = pre_index2 + shifted_index * after_axis_range + reduce_func(out, index2, updates[index1]) return ib.get() + def update_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = update + + def add_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] += update + + def mul_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] *= update + + def min_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update) + + def max_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = tir.max(dst_ptr[dst_index], update) + + reduce_func = None + if reduction == "update": 
+        reduce_func = update_func
+    elif reduction == "add":
+        reduce_func = add_func
+    elif reduction == "mul":
+        reduce_func = mul_func
+    elif reduction == "min":
+        reduce_func = min_func
+    elif reduction == "max":
+        reduce_func = max_func
+    else:
+        raise NotImplementedError(
+            "scatter_elements reduction not in [update, add, mul, min, max]:", reduction
+        )
+
     out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
     return te.extern(
         [data.shape],
         [data, indices, updates],
-        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], reduce_func),
         dtype=data.dtype,
         out_buffers=[out_buf],
        name="scatter_elements.generic",
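
A minimal usage sketch (not part of the patches themselves): after this series, former relay.op.scatter_add call sites go through relay.op.scatter_elements with reduction="add", mirroring the updated test in test_op_level3.py. The snippet assumes a TVM build that includes this series with the LLVM backend enabled; the variable names, shapes, and values are illustrative only.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    data = relay.var("data", shape=(4,), dtype="float32")
    indices = relay.var("indices", shape=(2,), dtype="int64")
    updates = relay.var("updates", shape=(2,), dtype="float32")

    # Before this series: z = relay.op.scatter_add(data, indices, updates, axis=0)
    z = relay.op.scatter_elements(data, indices, updates, axis=0, reduction="add")
    mod = tvm.IRModule.from_expr(relay.Function([data, indices, updates], z))

    lib = relay.build(mod, target="llvm")
    rt = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
    rt.set_input("data", np.zeros(4, dtype="float32"))
    rt.set_input("indices", np.array([1, 1], dtype="int64"))
    rt.set_input("updates", np.array([2.0, 3.0], dtype="float32"))
    rt.run()
    # Both updates land on index 1, so the expected output is [0. 5. 0. 0.]
    print(rt.get_output(0).numpy())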