[TOPI][Relay][ONNX] Replace scatter_add by scatter_elements(reduction="add") #14008

Merged · 19 commits · Feb 23, 2023
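This PR removes the standalone scatter_add operator (its Relay attrs, op registration, strategies, and the TOPI/CUDA kernels) and migrates every caller to scatter_elements with reduction="add". A minimal sketch of the call-site change, assuming the Relay Python API from python/tvm/relay/op/transform.py and an llvm build target:

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(4,), dtype="float32")
indices = relay.var("indices", shape=(3,), dtype="int64")
updates = relay.var("updates", shape=(3,), dtype="float32")

# Before this PR:  out = relay.op.scatter_add(data, indices, updates, axis=0)
out = relay.op.scatter_elements(data, indices, updates, axis=0, reduction="add")

mod = tvm.IRModule.from_expr(relay.Function([data, indices, updates], out))
f = relay.create_executor("graph", mod=mod, target="llvm").evaluate()
print(f(np.zeros(4, "float32"),
        np.array([0, 1, 1], "int64"),
        np.array([1.0, 2.0, 3.0], "float32")))
# [1. 5. 0. 0.] -- duplicate indices accumulate rather than overwrite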
8 changes: 0 additions & 8 deletions include/tvm/relay/attrs/transform.h
@@ -156,14 +156,6 @@ struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
}
};

struct ScatterAddAttrs : public tvm::AttrsNode<ScatterAddAttrs> {
Integer axis;

TVM_DECLARE_ATTRS(ScatterAddAttrs, "relay.attrs.ScatterAddAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values.");
}
};

struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
String reduction;
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/paddlepaddle.py
@@ -1743,7 +1743,7 @@ def convert_scatter(g, op, block):
if overwrite:
out = _op.scatter(x, index, updates, axis=0)
else:
out = _op.scatter_add(_op.zeros_like(x), index, updates, axis=0)
out = _op.scatter_elements(_op.zeros_like(x), index, updates, axis=0, reduction="add")
out += _op.scatter(x, index, _op.zeros_like(updates), axis=0)
g.add_node(op.output("Out")[0], out)
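
The overwrite=False branch above emulates Paddle's accumulating scatter: all updates are summed into a zero tensor via scatter_elements(reduction="add"), while the second scatter writes zeros into the rows of x that are being updated, so pre-existing values at those indices do not leak into the sum. A hedged numpy sketch of the same arithmetic, with np.add.at standing in for the Relay op:

import numpy as np

x = np.array([[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]], dtype="float32")
index = np.array([1, 1])
updates = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")

acc = np.zeros_like(x)
np.add.at(acc, index, updates)   # scatter_elements(zeros_like(x), index, updates, reduction="add")
cleared = x.copy()
cleared[index] = 0.0             # scatter(x, index, zeros_like(updates), axis=0)
print(acc + cleared)             # [[10. 10.] [ 4.  6.] [30. 30.]]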

41 changes: 39 additions & 2 deletions python/tvm/relay/frontend/pytorch.py
@@ -210,6 +210,18 @@ def infer_shape(self, inputs, mod=None):
def infer_shape_with_prelude(self, inputs):
return self.infer_shape(inputs, mod=self.prelude.mod)

def is_empty_shape(self, shape):
rank = len(shape)
if rank:
is_empty = False
for i in range(rank):
if shape[i] == 0:
is_empty = True
break
return is_empty
else:
return True

def record_output_type(self, output):
if isinstance(output, tuple):
cleaned_output = [o for o in output if o is not None]
@@ -2671,18 +2683,43 @@ def bincount(self, inputs, input_types):
updates = _op.ones_like(data)

counts = _op.zeros(_op.reshape(dim, [1]), out_dtype)
out = _op.scatter_add(counts, data, updates, axis=0)
out = _op.scatter_elements(counts, data, updates, axis=0, reduction="add")
if input_type == "int32":
# Torch always outputs int64 results for bincount
return _op.cast(out, "int64")
return out
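
The bincount lowering above works because scattering a tensor of ones with reduction="add" into a zero-initialized counts vector is exactly a histogram. A short numpy sketch of the idea, with np.add.at standing in for the Relay op:

import numpy as np

data = np.array([0, 2, 2, 5])
counts = np.zeros(data.max() + 1, dtype="int64")   # zeros(reshape(dim, [1]), out_dtype)
np.add.at(counts, data, 1)                         # scatter_elements(counts, data, ones_like(data), axis=0, reduction="add")
print(counts)                                      # [1 0 2 0 0 1], same as np.bincount(data)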

def scatter_add(self, inputs, input_types):
assert (
len(inputs) == 4
), "scatter_add takes 4 inputs (data, dim, index, src), but {} given".format(len(inputs))
data = inputs[0]
axis = inputs[1]
index = inputs[2]
src = inputs[3]
return _op.scatter_add(data, index, src, axis=axis)

data_shape = self.infer_shape(inputs[0])
data_rank = len(data_shape)
index_shape = self.infer_shape(inputs[2])
index_rank = len(index_shape)
# When index is empty, the operation returns data unchanged
if self.is_empty_shape(index_shape):
return data
src_shape = self.infer_shape(inputs[3])
src_rank = len(src_shape)
assert data_rank == index_rank, "Index rank is not the same as data rank"
assert data_rank == src_rank, "Src rank is not the same as data rank"

assert 0 <= axis < data_rank, "Dim is out of bounds"

for i in range(data_rank):
assert index_shape[i] <= src_shape[i], "Index dim size should be less than src one"
if i != axis:
assert (
index_shape[i] <= data_shape[i]
), "Index dim size should be less than data one"

return _op.scatter_elements(data, index, src, axis=axis, reduction="add")
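
For reference, the shape checks added above mirror the documented constraints of torch.Tensor.scatter_add_: index must not be larger than src in any dimension, nor larger than data in the non-scatter dimensions, and an empty index leaves data unchanged (hence the early return). A hedged PyTorch snippet illustrating the semantics the converter targets, assuming PyTorch is installed:

import torch

data = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2]])      # index.shape[d] <= src.shape[d] for every d
src = torch.ones(2, 3)                 # and index.shape[d] <= data.shape[d] for d != dim (dim = 0 here)
out = data.scatter_add(0, index, src)  # out[index[i][j]][j] += src[i][j] along dim 0
print(out)                             # ones at (0, 0), (1, 1), (2, 2); zeros elsewhere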

def scatter_reduce(self, inputs, input_types):
assert len(inputs) == 5 or len(inputs) == 6, (
6 changes: 3 additions & 3 deletions python/tvm/relay/frontend/tensorflow_ops.py
@@ -2918,7 +2918,7 @@ def _impl(inputs, attr, params, mod):

counts_shape = _op.reshape(size, [1])
counts = _op.zeros(counts_shape, out_dtype)
out = _op.scatter_add(counts, input, updates, axis=0)
out = _op.scatter_elements(counts, input, updates, axis=0, reduction="add")
return out

return _impl
@@ -2965,11 +2965,11 @@ def _impl(inputs, attr, params, mod):
size_arr = _op.reshape(size, [1])
counts_shape = _op.concatenate([batch_arr, size_arr], axis=0)
counts = _op.zeros(counts_shape, out_dtype)
out = _op.scatter_add(counts, input, updates, axis=1)
out = _op.scatter_elements(counts, input, updates, axis=1, reduction="add")
else:
counts_shape = _op.reshape(size, [1])
counts = _op.zeros(counts_shape, out_dtype)
out = _op.scatter_add(counts, input, updates, axis=0)
out = _op.scatter_elements(counts, input, updates, axis=0, reduction="add")

if attr["binary_output"]:
out = _op.cast(_op.cast(out, "bool"), out_dtype)
10 changes: 0 additions & 10 deletions python/tvm/relay/op/_transform.py
@@ -195,15 +195,6 @@ def stft_shape_func(attrs, inputs, _):
_reg.register_strategy("trilu", strategy.trilu_strategy)


# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
"""Compute definition of scatter_add"""
return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)]


_reg.register_strategy("scatter_add", strategy.scatter_add_strategy)

# scatter_elements
@_reg.register_compute("scatter_elements")
def compute_scatter_elements(attrs, inputs, output_type):
@@ -687,7 +678,6 @@ def argwhere_shape_func(attrs, inputs, out_ndims):


_reg.register_shape_func("scatter", False, elemwise_shape_func)
_reg.register_shape_func("scatter_add", False, elemwise_shape_func)
_reg.register_shape_func("scatter_elements", False, elemwise_shape_func)
_reg.register_shape_func("scatter_nd", False, elemwise_shape_func)

5 changes: 0 additions & 5 deletions python/tvm/relay/op/op_attrs.py
@@ -634,11 +634,6 @@ class DynExpandDimsAttrs(Attrs):
"""Attributes used in dynamic expand_dims operators"""


@tvm._ffi.register_object("relay.attrs.ScatterAddAttrs")
class ScatterAddAttrs(Attrs):
"""Attributes used in scatter_add operators"""


@tvm._ffi.register_object("relay.attrs.ScatterElementsAttrs")
class ScatterElementsAttrs(Attrs):
"""Attributes used in scatter_elements operators"""
13 changes: 0 additions & 13 deletions python/tvm/relay/op/strategy/cuda.py
@@ -1086,19 +1086,6 @@ def scatter_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_add_strategy.register(["cuda", "gpu"])
def scatter_add_cuda(attrs, inputs, out_type, target):
"""scatter_add cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.cuda.scatter_add),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_add.cuda",
plevel=10,
)
return strategy


@scatter_elements_strategy.register(["cuda", "gpu"])
def scatter_elements_cuda(attrs, inputs, out_type, target):
"""scatter elements cuda strategy"""
11 changes: 0 additions & 11 deletions python/tvm/relay/op/strategy/generic.py
@@ -1569,17 +1569,6 @@ def _compute_scatter(attrs, inputs, _):
return _compute_scatter


@override_native_generic_func("scatter_add_strategy")
def scatter_add_strategy(attrs, outs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.scatter_add),
wrap_topi_schedule(topi.generic.schedule_scatter),
name="scatter_add.generic",
)
return strategy


# scatter_elements
@override_native_generic_func("scatter_elements_strategy")
def scatter_elements_strategy(attrs, inputs, out_type, target):
27 changes: 1 addition & 26 deletions python/tvm/relay/op/transform.py
@@ -378,31 +378,6 @@ def scatter(data, indices, updates, axis):
return _make.scatter(data, indices, updates, axis)


def scatter_add(data, indices, updates, axis):
"""Update data by adding values in updates at positions defined by indices.

Parameters
----------
data : relay.Expr
The input data to the operator.

indices : relay.Expr
The index locations to update.

updates : relay.Expr
The values to add.

axis : int
The axis to scatter_add on.

Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.scatter_add(data, indices, updates, axis)


def scatter_elements(data, indices, updates, axis=0, reduction="update"):
"""Scatter elements with updating data by reduction of values in updates
at positions defined by indices.
@@ -1717,7 +1692,7 @@ def segment_sum(data, segment_ids, num_segments=None):
expanded_segment_ids = tile(segment_ids, segment_ids_tiled_shape)
scatter_add_segment_ids = transpose(expanded_segment_ids)
src = cast_like(_dyn_make.zeros(new_shape, "float64"), data)
return scatter_add(src, scatter_add_segment_ids, data, axis=0)
return scatter_elements(src, scatter_add_segment_ids, data, axis=0, reduction="add")
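
The segment_sum lowering above relies on the same reduction: rows that share a segment id are scattered onto the same output row and accumulated. A minimal numpy sketch of the behaviour, with np.add.at standing in for scatter_elements(reduction="add") over the tiled segment ids:

import numpy as np

data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype="float32")
segment_ids = np.array([0, 0, 1])

out = np.zeros((segment_ids.max() + 1,) + data.shape[1:], dtype=data.dtype)
np.add.at(out, segment_ids, data)   # scatter_elements(zeros, tiled_segment_ids, data, axis=0, reduction="add")
print(out)                          # [[4. 6.] [5. 6.]]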


def cumsum(data, axis=None, dtype=None, exclusive=None):
1 change: 0 additions & 1 deletion python/tvm/topi/__init__.py
@@ -41,7 +41,6 @@
from .scatter_elements import *
from .sparse_fill_empty_rows import *
from .sparse_reshape import *
from .scatter_add import *
from .argwhere import *
from .scan import *
from .einsum import *
132 changes: 0 additions & 132 deletions python/tvm/topi/cuda/scatter.py
@@ -591,138 +591,6 @@ def schedule_scatter_via_sort(_, outs):
return schedule_extern(outs)


def gen_scatter_add_1d_atomic(data, indices, updates, axis, out, _):
"""Generate scatter add ir for 1d inputs, using atomic_add instruction

Parameters
----------
data : tir.Tensor
The input data to the operator.

indices : tir.Tensor
The index locations to update.

updates : tir.Tensor
The values to update.

axis : int
The axis to scatter on

out : tir.Tensor
The output tensor.

Returns
-------
ret : tir
The computational ir.
"""
assert axis == 0
n = data.shape[0]

ib = tvm.tir.ir_builder.create()

out_ptr = ib.buffer_ptr(out)
data_ptr = ib.buffer_ptr(data)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads

with ib.new_scope():
nthread_bx = ceil_div(n, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx
with ib.if_scope(tid < n):
out_ptr[tid] = data_ptr[tid]

indices_ptr = ib.buffer_ptr(indices)
updates_ptr = ib.buffer_ptr(updates)

ni = indices.shape[0]

atomic_add_return = ib.allocate(updates.dtype, (1,), name="atomic_add_return", scope="local")

with ib.new_scope():
nthread_bx = ceil_div(ni, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

with ib.if_scope(tid < ni):
index = indices_ptr[tid]
with ib.if_scope(index < 0):
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index + n]),
updates_ptr[tid],
)
with ib.else_scope():
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", out_ptr[index]),
updates_ptr[tid],
)

return ib.get()


def scatter_add(data, indices, updates, axis=0):
"""Update data by adding values in updates at positions defined by indices

Parameters
----------
data : relay.Expr
The input data to the operator.

indices : relay.Expr
The index locations to update.

updates : relay.Expr
The values to be added.

axis : int
The axis to scatter on

Returns
-------
ret : relay.Expr
The computed result.
"""
if axis < 0:
axis += len(data.shape)
assert axis >= 0
assert axis < len(data.shape)

rank = len(data.shape)
assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions"

ir_funcs = {
1: gen_scatter_add_1d_atomic,
2: gen_ir_2d,
3: gen_ir_3d,
4: gen_ir_4d,
}

def update_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] += update

out_shape = data.shape
out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
out = te.extern(
[out_shape],
[data, indices, updates],
lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_add_gpu",
tag="scatter_add_gpu",
)

return out


def scatter_nd(data, indices, updates, mode):
"""Scatter elements from a n-dimension array.
