Skip to content

Commit

Permalink
extend mode types for scatter_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 11, 2023
1 parent 1139780 commit 8dbc013
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
12 changes: 10 additions & 2 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Scatter operator """
import tvm
from tvm import te, autotvm
from tvm import te, tir, autotvm
from ..scatter import _verify_scatter_nd_inputs
from ..generic import schedule_extern
from .nms import atomic_add
Expand Down Expand Up @@ -871,8 +871,16 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
elif mode == "mul":
out[index] *= updates[i * fused_updates_dimension + j]
elif mode == "min":
out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j])
elif mode == "max":
out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j])
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
raise NotImplementedError(
"scatter_nd mode not in [update, add, mul, min, max]:", mode
)

return ib.get()

Expand Down
12 changes: 10 additions & 2 deletions python/tvm/topi/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Scatter operator"""
from ..te import extern, hybrid
from ..tir import decl_buffer, expr, ir_builder
from ..tir import decl_buffer, expr, ir_builder, min, max


@hybrid.script
Expand Down Expand Up @@ -308,8 +308,16 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
elif mode == "mul":
out[index] *= updates[i * fused_updates_dimension + j]
elif mode == "min":
out[index] = min(out[index], updates[i * fused_updates_dimension + j])
elif mode == "max":
out[index] = max(out[index], updates[i * fused_updates_dimension + j])
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
raise NotImplementedError(
"scatter_nd mode not in [update, add, mul, min, max]:", mode
)

return ib.get()

Expand Down

0 comments on commit 8dbc013

Please sign in to comment.