Skip to content

Commit

Permalink
upstream scatter_elements for CPU with CUDA approach
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 22, 2023
1 parent 0ad0db8 commit 4877b72
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 25 deletions.
6 changes: 4 additions & 2 deletions python/tvm/topi/cuda/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,16 @@ def gen_ir(data, indices, updates, out, axis, reduce_func):
with ib.if_scope(ind_fused < ind_full_range_excl_axis):
i = ind_fused // ind_after_axis_range
j = ind_fused % ind_after_axis_range
pre_index1 = i * ind_before_axis_stride + j
pre_index2 = i * before_axis_stride + j
with ib.for_range(0, ind_axis_range, "k") as k:
# Offset along indices or updates
index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j
index1 = pre_index1 + k * ind_after_axis_range
# Get the index and shift it to the positive side if needed
new_index = indices_ptr[index1]
shifted_index = new_index + (new_index < 0) * axis_range
# Offset along data
index2 = i * before_axis_stride + shifted_index * after_axis_range + j
index2 = pre_index2 + shifted_index * after_axis_range
reduce_func(out_ptr, index2, updates_ptr[index1])

return ib.get()
Expand Down
63 changes: 40 additions & 23 deletions python/tvm/topi/scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"):
ind_after_axis_range *= value
ind_before_axis_stride = ind_axis_range * ind_after_axis_range

def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, reduce_func):
# pylint: disable=invalid-name
ib = tir.ir_builder.create()

Expand All @@ -105,44 +105,61 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
with ib.for_range(0, full_range, "i", kind="parallel") as i:
out[i] = data[i]

# TODO(vvchernov): find optimal parallel approach
with ib.for_range(
0, ind_before_axis_range * ind_after_axis_range, "fused", kind="parallel"
) as fused:
i = fused // ind_after_axis_range
j = fused % ind_after_axis_range
pre_index1 = i * ind_before_axis_stride + j
pre_index2 = i * before_axis_stride + j
with ib.for_range(0, ind_axis_range, "k") as k:
# Offset along indices or updates
index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j
# TODO(vvchernov): assert for out of bounds, separated check for indices
index1 = pre_index1 + k * ind_after_axis_range
# Get the index and shift it to the positive side if needed
k_new = indices[index1]
index_check = tir.LT(k_new, tir.const(0, indices.dtype))
k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype))
shifted_index = k_new + (k_new < 0) * axis_range
# Offset along data
index2 = i * before_axis_stride + k_new * after_axis_range + j
if reduction == "update":
out[index2] = updates[index1]
elif reduction == "add":
out[index2] += updates[index1]
elif reduction == "mul":
out[index2] *= updates[index1]
elif reduction == "min":
out[index2] = tir.min(out[index2], updates[index1])
elif reduction == "max":
out[index2] = tir.max(out[index2], updates[index1])
else:
raise NotImplementedError(
"scatter_elements reduction not in [update, add, mul, min, max]:",
reduction,
)
index2 = pre_index2 + shifted_index * after_axis_range
reduce_func(out, index2, updates[index1])

return ib.get()

def update_func(dst_ptr, dst_index, update):
    """Store ``update`` at position ``dst_index`` of ``dst_ptr``.

    Implements the "update" reduction: the previous value at that
    position is discarded and replaced.
    """
    dst_ptr[dst_index] = update

def add_func(dst_ptr, dst_index, update):
    """Accumulate ``update`` into position ``dst_index`` of ``dst_ptr``.

    Implements the "add" reduction: the stored value becomes the sum of
    the previous value and ``update``.
    """
    dst_ptr[dst_index] += update

def mul_func(dst_ptr, dst_index, update):
    """Scale position ``dst_index`` of ``dst_ptr`` by ``update``.

    Implements the "mul" reduction: the stored value becomes the product
    of the previous value and ``update``.
    """
    dst_ptr[dst_index] *= update

def min_func(dst_ptr, dst_index, update):
    """Keep the smaller of the stored value and ``update``.

    Implements the "min" reduction: position ``dst_index`` of ``dst_ptr``
    is replaced by ``tir.min`` of its current value and ``update``.
    """
    current = dst_ptr[dst_index]
    dst_ptr[dst_index] = tir.min(current, update)

def max_func(dst_ptr, dst_index, update):
    """Keep the larger of the stored value and ``update``.

    Implements the "max" reduction: position ``dst_index`` of ``dst_ptr``
    is replaced by ``tir.max`` of its current value and ``update``.
    """
    current = dst_ptr[dst_index]
    dst_ptr[dst_index] = tir.max(current, update)

# Pick the element-wise combiner for the requested reduction mode.
# The table maps each supported mode name to the helper that applies it;
# any other mode is rejected before the IR is generated.
reduce_func = {
    "update": update_func,
    "add": add_func,
    "mul": mul_func,
    "min": min_func,
    "max": max_func,
}.get(reduction)
if reduce_func is None:
    raise NotImplementedError(
        "scatter_elements reduction not in [update, add, mul, min, max]:", reduction
    )

out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
return te.extern(
[data.shape],
[data, indices, updates],
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], reduce_func),
dtype=data.dtype,
out_buffers=[out_buf],
name="scatter_elements.generic",
Expand Down

0 comments on commit 4877b72

Please sign in to comment.