diff --git a/python/tvm/topi/cuda/softmax.py b/python/tvm/topi/cuda/softmax.py
index d669c64ca97c..a3c3e431e7d3 100644
--- a/python/tvm/topi/cuda/softmax.py
+++ b/python/tvm/topi/cuda/softmax.py
@@ -21,11 +21,12 @@
 from tvm.contrib import cudnn
 from .. import generic
 from .injective import schedule_injective_from_existing
-from ..utils import traverse_inline
+from ..utils import get_const_int, traverse_inline
 
 
 def _schedule_softmax(softmax_op, s, outs, tgt):
     op_tag = softmax_op.tag
+    axis = get_const_int(softmax_op.attrs["axis"])  # reduce axis
     if op_tag == "softmax_output":
         expsum = softmax_op.input_tensors[1]
         exp = softmax_op.input_tensors[0]
@@ -83,15 +84,16 @@ def sched_warp_softmax():
 
         # (4) softmax
         output = outs[0]
-        xo, xi = s[output].split(output.op.axis[1], nparts=num_thread)
+        xo, xi = s[output].split(output.op.axis[axis], nparts=num_thread)
         xio, xii = s[output].split(xi, factor=4)
         s[output].vectorize(xii)
         s[output].bind(xo, thread_x)
-        s[output].bind(output.op.axis[0], block_x)
+        s[output].bind(output.op.axis[axis ^ 1], block_x)
+        s[output].reorder(output.op.axis[axis ^ 1], xo, xio, xii)
 
         if softmax_op != outs[0].op:
             s[softmax_op].compute_at(s[output], xio)
-            s[softmax_op].vectorize(softmax_op.axis[1])  # vec_len == 4
+            s[softmax_op].vectorize(softmax_op.axis[axis])  # vec_len == 4
 
         # (3) expsum
         k = expsum.op.reduce_axis[0]
@@ -104,12 +106,12 @@ def sched_warp_softmax():
             s[exp].compute_inline()
             s[delta].compute_inline()
         elif exp is not None:
-            xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread)
+            xo, xi = s[exp].split(exp.op.axis[axis], nparts=num_thread)
             _, xii = s[exp].split(xi, factor=4)
             s[exp].vectorize(xii)
             s[exp].bind(xo, thread_x)
             s[exp].compute_at(s[expsum], expsum.op.axis[0])
-            s[exp].compute_at(s[output], output.op.axis[0])
+            s[exp].compute_at(s[output], output.op.axis[axis ^ 1])
             s[exp].set_scope("warp")
 
         # (1) max_elem
@@ -131,7 +133,7 @@ def sched_warp_softmax():
             s[exp].compute_inline()
             s[delta].compute_inline()
         elif exp is not None:
-            s[exp].bind(exp.op.axis[0], block_x)
+            s[exp].bind(exp.op.axis[axis ^ 1], block_x)
 
         s[max_elem].bind(max_elem.op.axis[0], block_x)
         k = expsum.op.reduce_axis[0]
@@ -143,9 +145,10 @@ def sched_warp_softmax():
         s[expsum].set_store_predicate(thread_x.var.equal(0))
 
         output = outs[0]
-        tx, xi = s[output].split(output.op.axis[1], nparts=num_thread)
-        s[output].bind(output.op.axis[0], block_x)
+        tx, xi = s[output].split(output.op.axis[axis], nparts=num_thread)
+        s[output].bind(output.op.axis[axis ^ 1], block_x)
         s[output].bind(tx, thread_x)
+        s[output].reorder(output.op.axis[axis ^ 1], tx, xi)
 
         if softmax_op != outs[0].op:
             s[softmax_op].compute_at(s[output], tx)
diff --git a/python/tvm/topi/nn/softmax.py b/python/tvm/topi/nn/softmax.py
index a13b17686708..cb6d5b321eac 100644
--- a/python/tvm/topi/nn/softmax.py
+++ b/python/tvm/topi/nn/softmax.py
@@ -144,4 +144,8 @@ def log_softmax(x, axis=-1):
     max_elem = te.compute((m,), lambda i: tvm.te.max(x[i, k], axis=k))
     k = te.reduce_axis((0, n), name="k")
     expsum = te.compute((m,), lambda i: te.sum(te.exp(x[i, k] - max_elem[i]), axis=k))
-    return te.compute(x.shape, lambda i, j: x[i, j] - max_elem[i] - te.log(expsum[i]))
+    return te.compute(
+        x.shape,
+        lambda i, j: x[i, j] - max_elem[i] - te.log(expsum[i]),
+        attrs={"axis": axis},
+    )
diff --git a/tests/python/topi/python/test_topi_softmax.py b/tests/python/topi/python/test_topi_softmax.py
index 10bad979c80b..8243211a8674 100644
--- a/tests/python/topi/python/test_topi_softmax.py
+++ b/tests/python/topi/python/test_topi_softmax.py
@@ -45,48 +45,44 @@
         "topi": topi.nn.softmax,
         "ref": tvm.topi.testing.softmax_python,
         "dimensions": [1, 2, 4],
+        "axis": [0, 1, 2, 3],
     },
     "log_softmax": {
         "topi": topi.nn.log_softmax,
         "ref": tvm.topi.testing.log_softmax_python,
         "dimensions": [2],
+        "axis": [1],
     },
 }
 
 shapes = [(32, 10), (3, 4), (1, 16, 256, 256), (32,)]
-softmax_operation, shape = tvm.testing.parameters(
+softmax_operation, shape, axis = tvm.testing.parameters(
     *[
-        (name, shape)
+        (name, shape, axis)
         for name, config in configs.items()
         for shape in shapes
         if len(shape) in config["dimensions"]
+        for axis in range(len(shape))
+        if axis in config["axis"]
     ]
 )
 
 
 @tvm.testing.fixture(cache_return_value=True)
-def ref_data(shape, dtype, softmax_operation):
+def ref_data(shape, dtype, softmax_operation, axis):
     ref_func = configs[softmax_operation]["ref"]
     a_np = np.random.uniform(size=shape).astype(dtype)
-
-    if len(shape) == 1:
-        a_np_2d = a_np[None, :]
-        b_np_2d = tvm.topi.testing.softmax_python(a_np_2d)
-        b_np = b_np_2d[0]
-    elif len(shape) == 2:
-        b_np = ref_func(a_np)
-    elif len(shape) == 4:
-        _, c, h, w = a_np.shape
-        a_np_2d = a_np.transpose(0, 2, 3, 1).reshape(h * w, c)
-        b_np_2d = tvm.topi.testing.softmax_python(a_np_2d)
-        b_np = b_np_2d.reshape(1, h, w, c).transpose(0, 3, 1, 2)
-    else:
-        raise NotImplementedError(f"{len(shape)}-D shape not supported")
+    perm = list(range(a_np.ndim))
+    perm[-1], perm[axis] = perm[axis], perm[-1]
+    trans_shape = [a_np.shape[i] for i in perm]
+    a_np_2d = a_np.transpose(perm).reshape(-1, trans_shape[-1])
+    b_np_2d = ref_func(a_np_2d)
+    b_np = b_np_2d.reshape(*trans_shape).transpose(perm)
     return a_np, b_np
 
 
-def test_softmax(target, dev, shape, dtype, ref_data, softmax_operation):
+def test_softmax(target, dev, shape, dtype, ref_data, softmax_operation, axis):
     target = tvm.target.Target(target)
     if target.kind.name == "vulkan" and dtype == "float64":
         # https://www.khronos.org/registry/SPIR-V/specs/1.0/GLSL.std.450.html
@@ -95,7 +91,7 @@ def test_softmax(target, dev, shape, dtype, ref_data, softmax_operation):
 
     A = te.placeholder(shape, dtype=dtype, name="A")
     topi_op = configs[softmax_operation]["topi"]
-    B = topi_op(A, axis=min(len(shape) - 1, 1))
+    B = topi_op(A, axis=axis)
 
     with tvm.target.Target(target):
         fschedule = tvm.topi.testing.dispatch(target, _softmax_schedule)
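
Note (not part of the patch): below is a minimal sketch of how the axis-aware path can be exercised end to end, assuming a CUDA device and the public `topi.cuda.schedule_softmax` entry point. It builds `topi.nn.softmax` over axis 1 of a 4-D input and checks it against `tvm.topi.testing.softmax_python` using the same swap-permutation trick as the updated `ref_data` fixture: the permutation swaps `axis` with the last axis, so applying it twice restores the original layout. The `axis ^ 1` indexing in the schedule picks the non-reduced axis and relies on the scheduled softmax stage being 2-D, with `axis` equal to 0 or 1.

import numpy as np

import tvm
import tvm.testing
import tvm.topi.testing
from tvm import te, topi

shape, axis, dtype = (1, 16, 256, 256), 1, "float32"

# Compute definition: softmax over the requested axis.
A = te.placeholder(shape, dtype=dtype, name="A")
B = topi.nn.softmax(A, axis=axis)

# Schedule and build for CUDA (entry-point name assumed here).
target = tvm.target.Target("cuda")
with target:
    s = topi.cuda.schedule_softmax([B])
f = tvm.build(s, [A, B], target)

# Numpy reference: move `axis` to the innermost position, apply the 2-D
# reference, then undo the (self-inverse) permutation.
a_np = np.random.uniform(size=shape).astype(dtype)
perm = list(range(a_np.ndim))
perm[-1], perm[axis] = perm[axis], perm[-1]
trans_shape = [a_np.shape[i] for i in perm]
a_2d = a_np.transpose(perm).reshape(-1, trans_shape[-1])
b_np = tvm.topi.testing.softmax_python(a_2d).reshape(*trans_shape).transpose(perm)

# Run on the GPU and compare against the reference.
dev = tvm.cuda(0)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(shape, dtype=dtype), dev)
f(a, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)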