[TE Schedule] Fix broken 2D softmax TE schedules when axis=0 #11803

Merged: 2 commits merged on Jun 21, 2022
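Background note: the CUDA softmax TE schedules hard-coded axis 1 as the reduction axis, so a 2-D softmax computed along axis 0 was scheduled incorrectly. Below is a minimal sketch of the kind of case this PR targets; the shapes, target, and build call are illustrative assumptions, not taken from the PR itself.

# Hypothetical reproduction sketch (not from the PR); requires a CUDA-enabled TVM build.
import tvm
from tvm import te, topi

A = te.placeholder((32, 10), name="A", dtype="float32")
B = topi.nn.softmax(A, axis=0)  # reduce along axis 0, the previously broken case

with tvm.target.Target("cuda"):
    s = topi.cuda.schedule_softmax(B)  # before this fix, the schedule assumed axis=1

f = tvm.build(s, [A, B], "cuda")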
21 changes: 12 additions & 9 deletions python/tvm/topi/cuda/softmax.py
@@ -21,11 +21,12 @@
 from tvm.contrib import cudnn
 from .. import generic
 from .injective import schedule_injective_from_existing
-from ..utils import traverse_inline
+from ..utils import get_const_int, traverse_inline


 def _schedule_softmax(softmax_op, s, outs, tgt):
     op_tag = softmax_op.tag
+    axis = get_const_int(softmax_op.attrs["axis"])  # reduce axis
     if op_tag == "softmax_output":
         expsum = softmax_op.input_tensors[1]
         exp = softmax_op.input_tensors[0]
@@ -83,15 +84,16 @@ def sched_warp_softmax():

         # (4) softmax
         output = outs[0]
-        xo, xi = s[output].split(output.op.axis[1], nparts=num_thread)
+        xo, xi = s[output].split(output.op.axis[axis], nparts=num_thread)
         xio, xii = s[output].split(xi, factor=4)
         s[output].vectorize(xii)
         s[output].bind(xo, thread_x)
-        s[output].bind(output.op.axis[0], block_x)
+        s[output].bind(output.op.axis[axis ^ 1], block_x)
+        s[output].reorder(output.op.axis[axis ^ 1], xo, xio, xii)

         if softmax_op != outs[0].op:
             s[softmax_op].compute_at(s[output], xio)
-            s[softmax_op].vectorize(softmax_op.axis[1])  # vec_len == 4
+            s[softmax_op].vectorize(softmax_op.axis[axis])  # vec_len == 4

         # (3) expsum
         k = expsum.op.reduce_axis[0]
@@ -104,12 +106,12 @@ def sched_warp_softmax():
             s[exp].compute_inline()
             s[delta].compute_inline()
         elif exp is not None:
-            xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
+            xo, xi = s[exp].split(exp.op.axis[axis], nparts=num_thread)
             _, xii = s[exp].split(xi, factor=4)
             s[exp].vectorize(xii)
             s[exp].bind(xo, thread_x)
             s[exp].compute_at(s[expsum], expsum.op.axis[0])
-            s[exp].compute_at(s[output], output.op.axis[0])
+            s[exp].compute_at(s[output], output.op.axis[axis ^ 1])
             s[exp].set_scope("warp")

         # (1) max_elem
@@ -131,7 +133,7 @@ def sched_warp_softmax():
             s[exp].compute_inline()
             s[delta].compute_inline()
         elif exp is not None:
-            s[exp].bind(exp.op.axis[0], block_x)
+            s[exp].bind(exp.op.axis[axis ^ 1], block_x)

         s[max_elem].bind(max_elem.op.axis[0], block_x)
         k = expsum.op.reduce_axis[0]
@@ -143,9 +145,10 @@ def sched_warp_softmax():
         s[expsum].set_store_predicate(thread_x.var.equal(0))

         output = outs[0]
-        tx, xi = s[output].split(output.op.axis[1], nparts=num_thread)
-        s[output].bind(output.op.axis[0], block_x)
+        tx, xi = s[output].split(output.op.axis[axis], nparts=num_thread)
+        s[output].bind(output.op.axis[axis ^ 1], block_x)
         s[output].bind(tx, thread_x)
+        s[output].reorder(output.op.axis[axis ^ 1], tx, xi)

         if softmax_op != outs[0].op:
             s[softmax_op].compute_at(s[output], tx)
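A note on the axis ^ 1 idiom used above (my reading of the change, not text from the PR): these schedules only handle the 2-D softmax case, so the reduce axis is either 0 or 1, and XOR-ing it with 1 yields the other, non-reduced axis, which is what gets bound to blockIdx.x; the added reorder calls keep that block-bound axis outermost so the thread-bound splits of the reduce axis nest inside it. A trivial check of the idiom:

# Illustration only: for a 2-D tensor, axis ^ 1 flips between the two axes.
for axis in (0, 1):
    block_axis = axis ^ 1
    assert block_axis == 1 - axis
    print(f"reduce axis {axis} -> block axis {block_axis}")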
6 changes: 5 additions & 1 deletion python/tvm/topi/nn/softmax.py
@@ -144,4 +144,8 @@ def log_softmax(x, axis=-1):
     max_elem = te.compute((m,), lambda i: tvm.te.max(x[i, k], axis=k))
     k = te.reduce_axis((0, n), name="k")
     expsum = te.compute((m,), lambda i: te.sum(te.exp(x[i, k] - max_elem[i]), axis=k))
-    return te.compute(x.shape, lambda i, j: x[i, j] - max_elem[i] - te.log(expsum[i]))
+    return te.compute(
+        x.shape,
+        lambda i, j: x[i, j] - max_elem[i] - te.log(expsum[i]),
+        attrs={"axis": axis},
+    )
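This change attaches the softmax axis to the compute op as an attribute so the CUDA schedule above can recover it via get_const_int(softmax_op.attrs["axis"]). A small self-contained sketch of that mechanism, with a toy compute that is illustrative only and not from the PR:

# Sketch: attach an op attribute at compute-definition time and read it back later.
import tvm
from tvm import te
from tvm.topi.utils import get_const_int

X = te.placeholder((4, 8), name="X")
Y = te.compute(X.shape, lambda i, j: X[i, j] * 2.0, name="Y", attrs={"axis": 0})

s = te.create_schedule(Y.op)
axis = get_const_int(Y.op.attrs["axis"])  # available inside a schedule function
assert axis == 0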
34 changes: 15 additions & 19 deletions tests/python/topi/python/test_topi_softmax.py
@@ -45,48 +45,44 @@
         "topi": topi.nn.softmax,
         "ref": tvm.topi.testing.softmax_python,
         "dimensions": [1, 2, 4],
+        "axis": [0, 1, 2, 3],
     },
     "log_softmax": {
         "topi": topi.nn.log_softmax,
         "ref": tvm.topi.testing.log_softmax_python,
         "dimensions": [2],
+        "axis": [1],
     },
 }
 shapes = [(32, 10), (3, 4), (1, 16, 256, 256), (32,)]
-softmax_operation, shape = tvm.testing.parameters(
+softmax_operation, shape, axis = tvm.testing.parameters(
     *[
-        (name, shape)
+        (name, shape, axis)
         for name, config in configs.items()
         for shape in shapes
         if len(shape) in config["dimensions"]
+        for axis in range(len(shape))
+        if axis in config["axis"]
     ]
 )

 @tvm.testing.fixture(cache_return_value=True)
-def ref_data(shape, dtype, softmax_operation):
+def ref_data(shape, dtype, softmax_operation, axis):
     ref_func = configs[softmax_operation]["ref"]

     a_np = np.random.uniform(size=shape).astype(dtype)

-    if len(shape) == 1:
-        a_np_2d = a_np[None, :]
-        b_np_2d = tvm.topi.testing.softmax_python(a_np_2d)
-        b_np = b_np_2d[0]
-    elif len(shape) == 2:
-        b_np = ref_func(a_np)
-    elif len(shape) == 4:
-        _, c, h, w = a_np.shape
-        a_np_2d = a_np.transpose(0, 2, 3, 1).reshape(h * w, c)
-        b_np_2d = tvm.topi.testing.softmax_python(a_np_2d)
-        b_np = b_np_2d.reshape(1, h, w, c).transpose(0, 3, 1, 2)
-    else:
-        raise NotImplementedError(f"{len(shape)}-D shape not supported")
+    perm = list(range(a_np.ndim))
+    perm[-1], perm[axis] = perm[axis], perm[-1]
+    trans_shape = [a_np.shape[i] for i in perm]
+    a_np_2d = a_np.transpose(perm).reshape(-1, trans_shape[-1])
+    b_np_2d = ref_func(a_np_2d)
+    b_np = b_np_2d.reshape(*trans_shape).transpose(perm)

     return a_np, b_np

-def test_softmax(target, dev, shape, dtype, ref_data, softmax_operation):
+def test_softmax(target, dev, shape, dtype, ref_data, softmax_operation, axis):
     target = tvm.target.Target(target)
     if target.kind.name == "vulkan" and dtype == "float64":
         # https://www.khronos.org/registry/SPIR-V/specs/1.0/GLSL.std.450.html
@@ -95,7 +91,7 @@ def test_softmax(target, dev, shape, dtype, ref_data, softmax_operation):
     A = te.placeholder(shape, dtype=dtype, name="A")

     topi_op = configs[softmax_operation]["topi"]
-    B = topi_op(A, axis=min(len(shape) - 1, 1))
+    B = topi_op(A, axis=axis)

     with tvm.target.Target(target):
         fschedule = tvm.topi.testing.dispatch(target, _softmax_schedule)
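The rewritten ref_data computes the reference result for any axis by swapping that axis to the last position, flattening to 2-D, applying the 2-D reference softmax, and undoing the transform; because the swap permutation is its own inverse, the same perm restores the original layout. A numpy-only sketch of the same idea (softmax_2d and softmax_along are illustrative names, not part of the test suite):

import numpy as np

def softmax_2d(x):
    # Row-wise softmax over the last axis of a 2-D array.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def softmax_along(a, axis):
    # Swap `axis` to the back, flatten to 2-D, apply the reference, swap back.
    perm = list(range(a.ndim))
    perm[-1], perm[axis] = perm[axis], perm[-1]
    trans_shape = [a.shape[i] for i in perm]
    out_2d = softmax_2d(a.transpose(perm).reshape(-1, trans_shape[-1]))
    return out_2d.reshape(*trans_shape).transpose(perm)

a = np.random.uniform(size=(1, 16, 8, 8)).astype("float32")
np.testing.assert_allclose(softmax_along(a, axis=1).sum(axis=1), 1.0, rtol=1e-5)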