[RFC] First class Triton support in OpenXLA Nvgpu #54
Comments
Neat. My main comment is a meta one: the openxla-nvgpu project is still pretty young and is even missing CI and proper build support/integration. I'm open to moving fast, but we also need to prioritize some project infrastructure work to hold everything together.
There are other ways of doing this that are much better integrated and should all work today. I'll respond on the doc, but the short of it is that custom dispatches (à la samples/custom_dispatch/cuda/) are sufficient and well supported; custom modules and other things should not be required.
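To make the custom-dispatch idea concrete, here is a minimal, hypothetical sketch in plain Python (all names are illustrative, not real IREE APIs): the program references a kernel by name, and the actual implementation (e.g. a Triton-generated blob in the real system) is bound when the program is built, not at run time.

```python
def build_program(dispatch_table):
    """'Compile' a program with its external kernels bound up front.

    In the real system the dispatch table would hold compiled GPU kernels;
    here plain Python functions stand in for them.
    """
    kernel = dispatch_table["matmul_tile_64"]  # resolved at build time

    def program(a, b):
        # The program only ever calls the kernel it was built with.
        return kernel(a, b)

    return program


def matmul_2x2(a, b):
    # Stand-in for a hand-written 2x2 matmul kernel.
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


prog = build_program({"matmul_tile_64": matmul_2x2})
result = prog([[1, 0], [0, 1]], [[5, 6], [7, 8]])  # identity matmul
```

The key property this models is that the set of kernels is fixed at program compile time, which is exactly what makes the run-time-compilation case in the next comments harder.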
👍 good point, I think we can start with custom dispatches. Although if we want to push Triton compilation to run time and bundle it with autotuning (tile selection mostly?), then we won't be able to do it as a custom dispatch?
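For concreteness, the tile-selection autotuning loop being described might look like the following hypothetical sketch (pure Python, with a stand-in benchmark and a stand-in for run-time Triton compilation; none of these names are real APIs):

```python
import time


def benchmark(kernel, *args, iters=10):
    """Time a compiled kernel over a few iterations (stand-in for real profiling)."""
    start = time.perf_counter()
    for _ in range(iters):
        kernel(*args)
    return (time.perf_counter() - start) / iters


def autotune(compile_kernel, candidate_tiles, *args):
    """Compile the kernel once per candidate tile size and keep the fastest.

    `compile_kernel(tile)` stands in for a run-time Triton compilation step;
    this is what a plain custom dispatch cannot express, because the set of
    executables is not known when the IREE program is compiled.
    """
    best_tile, best_time = None, float("inf")
    for tile in candidate_tiles:
        kernel = compile_kernel(tile)
        t = benchmark(kernel, *args)
        if t < best_time:
            best_tile, best_time = tile, t
    return best_tile


# Toy example: "kernels" that sum a list in chunks of `tile` elements.
def compile_kernel(tile):
    def kernel(xs):
        return sum(sum(xs[i:i + tile]) for i in range(0, len(xs), tile))
    return kernel


data = list(range(10_000))
best = autotune(compile_kernel, [16, 64, 256, 1024], data)
```

Which tile wins depends on the machine; the point is only that compilation and measurement both happen at run time.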
Ah, so you're intending to use the same compiled IREE program but vary the triton kernels without recompiling the program?
I think we'll have both strategies:

1. Compile Triton kernels ahead of time, together with the IREE program.
2. Compile Triton kernels at run time (bundled with autotuning), reusing the same compiled program.
Cool. For #1 the custom dispatch way should work. For #2 there are some other ways that are potentially easier. Executable specialization constants can be used to parameterize executables when they are loaded, but they may be slightly trickier to integrate with black boxes; it may still be interesting to reuse that mechanism with a custom executable type at runtime, though. Another option would be to have your custom module return a !hal.executable and schedule work as normal, but at that point it's probably best to use streamable custom calls instead: you'd take your params as push constants, do whatever you needed, and then launch the kernel against the stream.
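The two parameterization mechanisms mentioned above can be contrasted with a hypothetical plain-Python sketch (names are illustrative, not real IREE APIs): specialization constants bind values once when an executable loads, while a streamable call receives its parameters as push constants on every dispatch.

```python
class Executable:
    """A 'compiled' kernel whose specialization constants are bound at load time."""

    def __init__(self, body):
        self.body = body  # opaque compiled artifact in the real system

    def load(self, **spec_constants):
        # Specialization constants are fixed once, when the executable loads;
        # a real backend could fold them into the generated code.
        return lambda *args: self.body(spec_constants, *args)


def saxpy_body(spec, x, y):
    a = spec["alpha"]
    return [a * xi + yi for xi, yi in zip(x, y)]


exe = Executable(saxpy_body)
saxpy2 = exe.load(alpha=2.0)  # bound once at load time
loaded_result = saxpy2([1, 2], [10, 20])


def streamable_call(push_constants, x, y):
    # Streamable custom call: parameters arrive as push constants per dispatch,
    # so they can differ on every launch without reloading anything.
    a = push_constants["alpha"]
    return [a * xi + yi for xi, yi in zip(x, y)]


stream_result = streamable_call({"alpha": 3.0}, [1, 2], [10, 20])
```

The trade-off this models: load-time binding suits values that rarely change (e.g. a tuned tile size), while push constants suit per-dispatch parameters.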
Initial implementation of the First class triton integration: #54

Requires Triton + patches from https://github.com/ezhulenev/triton/commits/openxla-triton

```
git submodule update --remote third_party/triton
```

Run tests:

```
ctest --test-dir build -R triton
```
[RFC] First class Triton support in OpenXLA-Nvgpu
We want to improve the state of the Triton and OpenXLA integration and make jax-triton more user- and compiler-friendly.
Please let us know what you think!