LLVM ERROR: mma16816 data type not supported #4922

Closed
mobicham opened this issue Oct 16, 2024 · 8 comments
mobicham commented Oct 16, 2024

The latest Triton build (3.1.0) throws the following error when using bitpacked data inside a loop with tl.dot:

LLVM ERROR: mma16816 data type not supported

With a build from source, I get a different error:

unimplemented code path
UNREACHABLE executed at /root/triton_op/triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp:79!
Aborted (core dumped)

This error happens on Ampere and Hopper, but not on older GPUs like the Titan RTX/2080 Ti.

The bitpacked data is read with indices of the form offs_k[:, None] // num_elements, which yields something like [0,0,0,...,1,1,1,...,64,64,64].
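For context, here's a minimal sketch of that access pattern as a standalone kernel (not the repro itself; the 4-bit packing and the ELEMS_PER_WORD name are illustrative assumptions):

```python
import triton
import triton.language as tl

@triton.jit
def unpack_tile(b_ptr, out_ptr, stride_bk, stride_bn,
                BLOCK_K: tl.constexpr, BLOCK_N: tl.constexpr,
                ELEMS_PER_WORD: tl.constexpr):
    offs_k = tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, BLOCK_N)
    # Each stored int32 word packs ELEMS_PER_WORD values along K, so the
    # row index collapses to the [0,0,0,...,1,1,1,...] pattern above.
    b_ptrs = (b_ptr
              + (offs_k[:, None] // ELEMS_PER_WORD) * stride_bk
              + offs_n[None, :] * stride_bn)
    b = tl.load(b_ptrs)                             # packed int32 tile
    shift = (offs_k[:, None] % ELEMS_PER_WORD) * 4  # assuming 4-bit values
    b = ((b >> shift) & 0xF).to(tl.float16)         # the cast that feeds tl.dot
    tl.store(out_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :], b)
```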

I faced this error with the previous build as well, and found that replacing for k in range(0, total_blocks_k, 1): with for k in tl.range(0, total_blocks_k, 1, num_stages=1): solved the issue, but this trick no longer works with 3.1.0.
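For reference, the old workaround in a standalone toy loop (a sketch, not the repro kernel; num_stages=1 pins the loop to a single pipeline stage instead of letting the compiler software-pipeline the loads):

```python
import triton
import triton.language as tl

@triton.jit
def k_loop(x_ptr, out_ptr, total_blocks_k, BLOCK: tl.constexpr):
    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    # Plain range() lets the pipeliner pick num_stages; forcing it to 1
    # used to avoid the failing code path on the packed loads.
    for k in tl.range(0, total_blocks_k, 1, num_stages=1):
        acc += tl.load(x_ptr + k * BLOCK + tl.arange(0, BLOCK))
    tl.store(out_ptr + tl.arange(0, BLOCK), acc)
```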

Here's a full script to reproduce it:
https://gist.github.com/mobicham/f9eba3c07f7e497ae622194a9c5e4822

lezcano (Contributor) commented Nov 6, 2024

I think #5044 may fix this issue on Ampere. Mixed-dtype tl.dot is not so well supported on Hopper yet, though #5003 is making good progress on that front.

lezcano (Contributor) commented Nov 6, 2024

Also, out of curiosity, can you post the ttgir?

mobicham (Author) commented Nov 6, 2024

Thank you @lezcano!

The b.to(tl.float32).to(tl.float16) cast doesn't break the loop, but b.to(tl.float16) does. In the end, with or without the intermediate tl.float32, tl.dot is getting fp16 operands either way, which seems kind of strange, doesn't it?
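Just to spell out the two variants in a toy kernel (a sketch; b_ptr stands in for the unpacked int32 weights):

```python
import triton
import triton.language as tl

@triton.jit
def cast_paths(b_ptr, out_ptr, BLOCK: tl.constexpr):
    b = tl.load(b_ptr + tl.arange(0, BLOCK))  # unpacked int32 values
    # Direct cast, the variant that breaks the loop in the repro:
    # b_f16 = b.to(tl.float16)
    # Two-step cast through fp32, the variant that survives:
    b_f16 = b.to(tl.float32).to(tl.float16)
    tl.store(out_ptr + tl.arange(0, BLOCK), b_f16)
```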

Here's the ttgir for the b.to(tl.float32).to(tl.float16) version, which doesn't crash:
https://gist.github.com/mobicham/ae48cdf55f7062994eae3e2653d26afa#file-b-to-tl-float32-to-tl-float16-_log-txt-L148

lezcano (Contributor) commented Nov 6, 2024

Just to make sure I understand: the repro in the OP still breaks with #5044 patched in? That's rather weird. What's the crash you see? Could you also run the script with TRITON_ENABLE_PYTHON_STACKTRACE=1 and post the stacktrace for the crash?
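If it's easier than exporting the variable in the shell, something like this wrapper also works (repro.py here stands in for your gist script):

```python
import os
import subprocess

# Launch the repro in a fresh process so the env var is visible to
# Triton from the start.
env = dict(os.environ, TRITON_ENABLE_PYTHON_STACKTRACE="1")
subprocess.run(["python", "repro.py"], env=env, check=True)
```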

Jokeren (Contributor) commented Nov 6, 2024

I cannot reproduce it. Maybe the author of the PR is not using the correct Triton version.

mobicham (Author) commented Nov 6, 2024

@Jokeren I just tried it on an A100, and it does throw that error. I am using 3.1.0 since the nightly builds are broken. Let me build Triton from source and re-check.

lezcano (Contributor) commented Nov 6, 2024

Note that you should build not master but the commit linked above, as the fix hasn't landed yet! Master will probably still break.

mobicham (Author) commented Nov 6, 2024

I can confirm that the build from https://github.com/triton-lang/triton/tree/keren/dot-mma-1 solves the issue.
Thank you both for taking the time to look into this, really appreciate it!

mobicham closed this as completed Nov 6, 2024