From 8dbc01387c890f264f39806f7769a4721aae818d Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 10 Feb 2023 12:51:11 +0300 Subject: [PATCH] extend mode types for scatter_nd --- python/tvm/topi/cuda/scatter.py | 12 ++++++++++-- python/tvm/topi/scatter.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index fa7545cd323a..f104d1e42cb1 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument """Scatter operator """ import tvm -from tvm import te, autotvm +from tvm import te, tir, autotvm from ..scatter import _verify_scatter_nd_inputs from ..generic import schedule_extern from .nms import atomic_add @@ -871,8 +871,16 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): out[index] = updates[i * fused_updates_dimension + j] elif mode == "add": out[index] += updates[i * fused_updates_dimension + j] + elif mode == "mul": + out[index] *= updates[i * fused_updates_dimension + j] + elif mode == "min": + out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j]) + elif mode == "max": + out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j]) else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + raise NotImplementedError( + "scatter_nd mode not in [update, add, mul, min, max]:", mode + ) return ib.get() diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index e0578aab41b9..2298f09b4a67 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks """Scatter operator""" from ..te import extern, hybrid -from ..tir import decl_buffer, expr, ir_builder +from ..tir import decl_buffer, expr, ir_builder, min, max @hybrid.script @@ -308,8 +308,16 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): out[index] = updates[i * fused_updates_dimension + j] elif mode == "add": out[index] += updates[i * fused_updates_dimension + j] + elif mode == "mul": + out[index] *= updates[i * fused_updates_dimension + j] + elif mode == "min": + out[index] = min(out[index], updates[i * fused_updates_dimension + j]) + elif mode == "max": + out[index] = max(out[index], updates[i * fused_updates_dimension + j]) else: - raise NotImplementedError("scatter_nd mode not in [update, add]:", mode) + raise NotImplementedError( + "scatter_nd mode not in [update, add, mul, min, max]:", mode + ) return ib.get()