diff --git a/python/tvm/topi/cuda/scatter_elements.py b/python/tvm/topi/cuda/scatter_elements.py index 1e735f5618e2c..25f15a0e73a61 100644 --- a/python/tvm/topi/cuda/scatter_elements.py +++ b/python/tvm/topi/cuda/scatter_elements.py @@ -189,14 +189,16 @@ def gen_ir(data, indices, updates, out, axis, reduce_func): with ib.if_scope(ind_fused < ind_full_range_excl_axis): i = ind_fused // ind_after_axis_range j = ind_fused % ind_after_axis_range + pre_index1 = i * ind_before_axis_stride + j + pre_index2 = i * before_axis_stride + j with ib.for_range(0, ind_axis_range, "k") as k: # Offset along indices or updates - index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j + index1 = pre_index1 + k * ind_after_axis_range # Get index and shift to positive side if need new_index = indices_ptr[index1] shifted_index = new_index + (new_index < 0) * axis_range # Offset along data - index2 = i * before_axis_stride + shifted_index * after_axis_range + j + index2 = pre_index2 + shifted_index * after_axis_range reduce_func(out_ptr, index2, updates_ptr[index1]) return ib.get() diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index bfa765855b0e5..b4052702268b5 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -92,7 +92,7 @@ def scatter_elements(data, indices, updates, axis=0, reduction="update"): ind_after_axis_range *= value ind_before_axis_stride = ind_axis_range * ind_after_axis_range - def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): + def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, reduce_func): # pylint: disable=invalid-name ib = tir.ir_builder.create() @@ -105,44 +105,61 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): with ib.for_range(0, full_range, "i", kind="parallel") as i: out[i] = data[i] - # TODO(vvchernov): find optimal parallel approach with ib.for_range( 0, ind_before_axis_range * ind_after_axis_range, "fused", kind="parallel" ) as fused: i = fused // ind_after_axis_range j = fused % ind_after_axis_range + pre_index1 = i * ind_before_axis_stride + j + pre_index2 = i * before_axis_stride + j with ib.for_range(0, ind_axis_range, "k") as k: # Offset along indices or updates - index1 = i * ind_before_axis_stride + k * ind_after_axis_range + j - # TODO(vvchernov): assert for out of bounds, separated check for indices + index1 = pre_index1 + k * ind_after_axis_range + # Get index and shift to positive side if need k_new = indices[index1] - index_check = tir.LT(k_new, tir.const(0, indices.dtype)) - k_new += tir.Select(index_check, axis_range, tir.const(0, indices.dtype)) + shifted_index = k_new + (k_new < 0) * axis_range # Offset along data - index2 = i * before_axis_stride + k_new * after_axis_range + j - if reduction == "update": - out[index2] = updates[index1] - elif reduction == "add": - out[index2] += updates[index1] - elif reduction == "mul": - out[index2] *= updates[index1] - elif reduction == "min": - out[index2] = tir.min(out[index2], updates[index1]) - elif reduction == "max": - out[index2] = tir.max(out[index2], updates[index1]) - else: - raise NotImplementedError( - "scatter_elements reduction not in [update, add, mul, min, max]:", - reduction, - ) + index2 = pre_index2 + shifted_index * after_axis_range + reduce_func(out, index2, updates[index1]) return ib.get() + def update_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = update + + def add_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] += update + + def mul_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] *= update + + def min_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update) + + def max_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = tir.max(dst_ptr[dst_index], update) + + reduce_func = None + if reduction == "update": + reduce_func = update_func + elif reduction == "add": + reduce_func = add_func + elif reduction == "mul": + reduce_func = mul_func + elif reduction == "min": + reduce_func = min_func + elif reduction == "max": + reduce_func = max_func + else: + raise NotImplementedError( + "scatter_elements reduction not in [update, add, mul, min, max]:", reduction + ) + out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( [data.shape], [data, indices, updates], - lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]), + lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], reduce_func), dtype=data.dtype, out_buffers=[out_buf], name="scatter_elements.generic",