lint fixes
Valery Chernov committed Feb 11, 2023
1 parent 8dbc013 commit 320d670
Showing 3 changed files with 21 additions and 17 deletions.

python/tvm/relay/frontend/onnx.py (8 changes: 4 additions & 4 deletions)

@@ -2873,11 +2873,13 @@ def _inputs_check(cls, inputs):
         ), "Updates rank should be equal to data_rank + indices_rank - indices_shape[-1] - 1"

     @classmethod
-    def _reduction_check(cls, attr, red_valids=["update"]):
+    def _reduction_check(cls, attr, red_valids=None):
         reduction = attr.get("reduction", None)
         if reduction is None:
             reduction = b"update"
         reduction = reduction.decode("utf-8")
+        if red_valids is None:
+            red_valids = ["update"]
         assert reduction in red_valids, "Only {} reductions are supported, but {} is gotten".format(
             red_valids, reduction
         )
@@ -2889,9 +2891,7 @@ def _impl_v11(cls, inputs, attr, params):
         cls._inputs_check(inputs)
         indices_dim = len(infer_shape(inputs[1]))
         axes = list(range(indices_dim))
-        return _op.scatter_nd(
-            inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2]
-        )
+        return _op.scatter_nd(inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2])

     @classmethod
     def _impl_v16(cls, inputs, attr, params):
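
Note on the _reduction_check change: replacing the mutable default argument red_valids=["update"] with a None sentinel is the usual fix for pylint's dangerous-default-value warning, since a list default is created once at function definition time and shared across calls. A minimal sketch of the pitfall and of the sentinel pattern (illustrative Python only, not TVM code; the helper names are made up):

# Illustrative only: the mutable-default pitfall that dangerous-default-value guards against.
def append_item(item, bucket=[]):  # the [] default is created once, at definition time
    bucket.append(item)
    return bucket

print(append_item(1))  # [1]
print(append_item(2))  # [1, 2]  <- state leaks between calls

# The None-sentinel pattern used in _reduction_check above: a fresh list per call.
def append_item_fixed(item, bucket=None):
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_item_fixed(1))  # [1]
print(append_item_fixed(2))  # [2]
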

python/tvm/topi/cuda/scatter.py (8 changes: 6 additions & 2 deletions)

@@ -874,9 +874,13 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
                         elif mode == "mul":
                             out[index] *= updates[i * fused_updates_dimension + j]
                         elif mode == "min":
-                            out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j])
+                            out[index] = tir.min(
+                                out[index], updates[i * fused_updates_dimension + j]
+                            )
                         elif mode == "max":
-                            out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j])
+                            out[index] = tir.max(
+                                out[index], updates[i * fused_updates_dimension + j]
+                            )
                         else:
                             raise NotImplementedError(
                                 "scatter_nd mode not in [update, add, mul, min, max]:", mode
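
For context, the wrapped branches above are the reduction modes of scatter_nd; the only change in this file is line wrapping to satisfy the formatter. Below is a rough NumPy sketch of the per-element semantics the IR builder emits, assuming the TVM layout convention in which the index tuple sits along the first axis of indices (the ONNX frontend's transpose in the first file moves ONNX's last-axis index tuple into that position). This is an illustrative reference, not the TVM implementation:

import numpy as np

def scatter_nd_ref(data, indices, updates, mode):
    """Sketch of assumed layout: data (X_0..X_{N-1}), indices (M, Y_0..Y_{K-1}),
    updates (Y_0..Y_{K-1}, X_M..X_{N-1}); see the TVM docs for the authoritative spec."""
    out = data.copy()
    m = indices.shape[0]
    for y in np.ndindex(*indices.shape[1:]):  # every scatter location
        idx = tuple(int(indices[(l,) + y]) for l in range(m))  # index into the first M axes of out
        upd = updates[y]
        if mode == "update":
            out[idx] = upd
        elif mode == "add":
            out[idx] += upd
        elif mode == "mul":
            out[idx] *= upd
        elif mode == "min":
            out[idx] = np.minimum(out[idx], upd)
        elif mode == "max":
            out[idx] = np.maximum(out[idx], upd)
        else:
            raise NotImplementedError("scatter_nd mode not in [update, add, mul, min, max]:", mode)
    return out

With mode set to "min" or "max", colliding updates keep the smaller or larger value, which is what the tir.min and tir.max branches above do per element.
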

python/tvm/topi/scatter.py (22 changes: 11 additions & 11 deletions)

@@ -16,11 +16,11 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..te import extern, hybrid
-from ..tir import decl_buffer, expr, ir_builder, min, max
+from tvm import te, tir  # hide redefinition of min and max
+from tvm.tir import expr


-@hybrid.script
+@te.hybrid.script
 def _scatter_1d(data, indices, updates):
     out = output_tensor(data.shape, data.dtype)
     for i in range(data.shape[0]):
@@ -30,7 +30,7 @@ def _scatter_1d(data, indices, updates):
     return out


-@hybrid.script
+@te.hybrid.script
 def _scatter_2d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
     for i in range(data.shape[0]):
@@ -52,7 +52,7 @@ def _scatter_2d(data, indices, updates, axis):
     return out


-@hybrid.script
+@te.hybrid.script
 def _scatter_3d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
     for i in range(data.shape[0]):
@@ -96,7 +96,7 @@ def _scatter_3d(data, indices, updates, axis):
     return out


-@hybrid.script
+@te.hybrid.script
 def _scatter_4d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
     for i in range(data.shape[0]):
@@ -269,7 +269,7 @@ def scatter_nd(data, indices, updates, mode):

     def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # pylint: disable=invalid-name
-        ib = ir_builder.create()
+        ib = tir.ir_builder.create()

         data = ib.buffer_ptr(data_ptr)
         indices = ib.buffer_ptr(indices_ptr)
@@ -311,18 +311,18 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
                 elif mode == "mul":
                     out[index] *= updates[i * fused_updates_dimension + j]
                 elif mode == "min":
-                    out[index] = min(out[index], updates[i * fused_updates_dimension + j])
+                    out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j])
                 elif mode == "max":
-                    out[index] = max(out[index], updates[i * fused_updates_dimension + j])
+                    out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j])
                 else:
                     raise NotImplementedError(
                         "scatter_nd mode not in [update, add, mul, min, max]:", mode
                     )

         return ib.get()

-    out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
-    return extern(
+    out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
+    return te.extern(
         [data.shape],
         [data, indices, updates],
         lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
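
The import change in this file targets pylint's redefined-builtin warning: importing min and max from the tir package rebinds those names at module scope and shadows the Python built-ins, so the fix imports the modules instead and qualifies every call (te.hybrid.script, tir.ir_builder, tir.min, tir.max, tir.decl_buffer, te.extern). A small self-contained illustration of the shadowing problem (plain Python, not TVM code; symbolic_min is a made-up stand-in for a library function like tvm.tir.min):

# Illustrative only: why pylint flags redefined-builtin for `from ... import min, max`.
def symbolic_min(a, b):
    """Stand-in for a library min such as tvm.tir.min: strictly two arguments."""
    return "min({}, {})".format(a, b)

min = symbolic_min  # shadows the built-in across this module; pylint: redefined-builtin
try:
    min([3, 1, 2])  # the built-in one-argument form is gone
except TypeError as err:
    print("shadowed:", err)

del min  # un-shadow: the name resolves to the built-in again
print(min([3, 1, 2]))      # 1, via the built-in
print(symbolic_min(3, 1))  # the library-style call stays available under its own name
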
