Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TE Schedule] Fix broken 2D softmax TE schedules when axis=0 #11803

Merged
merged 2 commits into from
Jun 21, 2022

Conversation

lazycal
Copy link
Contributor

@lazycal lazycal commented Jun 21, 2022

When scheduling 2D softmax, the current cuda schedule assumes the reduction axis to be the last axis, and yields incorrect schedule and raise error messages that are hard to debug. For example, running the follow snippet:

import tvm
from tvm import relay

shape = (64, 2)
dtype = 'float32'

A = relay.var('A', shape=shape, dtype=dtype)
B = relay.nn.softmax(A, axis=0)
f = relay.Function([A], B)
mod = tvm.IRModule.from_expr(f)

dev = tvm.cuda()
target = tvm.target.Target('cuda')
with tvm.transform.PassContext(opt_level=0):
    executor = relay.build_module.create_executor(
        'graph', mod, dev, target).evaluate()

I got
Check failed: (!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) is false: LowerWarpMemory failed to rewrite load to shuffle for index ((threadIdx.x*5) + (k.inner*2)) local_index=(((threadIdx.x*5) + (k.inner*2))/32)
with opt_level=0 and
Check failed: (match) is false: iter_var(blockIdx.x, , blockIdx.x) domain already inferred, cannot prove their extents are the same 64 vs 2
with opt_level=4.

This PR fixes the schedule to also support axis=0 for all the cuda 2D schedules and enhances the unit testing to test all reduction axes.

@lazycal lazycal changed the title Support reduce axis=0 in softmax schedule. [TE Schedule] Fix broken 2D softmax TE schedules when axis=0 Jun 21, 2022
@lazycal
Copy link
Contributor Author

lazycal commented Jun 21, 2022

@masahi

@masahi masahi merged commit b63801c into apache:main Jun 21, 2022
blackkker pushed a commit to blackkker/tvm that referenced this pull request Jul 7, 2022
…11803)

* Support arbitrary reduce axis in softmax schedule.

* Fix lint.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants