
AMDGPU support for triton MLIR #1073

Open
rsanthanam-amd opened this issue Jan 18, 2023 · 5 comments

rsanthanam-amd (Contributor) commented Jan 18, 2023

Greetings from AMD,

We would like to upstream AMDGPU support for triton MLIR.

To that end, we have been maintaining and enhancing Triton with AMDGPU support in a dedicated fork:

https://github.com/ROCmSoftwarePlatform/triton/tree/triton-mlir

Here is a comprehensive list of the Triton features we have enhanced to support AMDGPU to date:

  • load op
  • store op
  • reduce op (shflSync and storeShared)
  • partial atomic ops support (atomicRMW at the moment)
  • bfloat16 support (conversion to/from float32)
  • fp8 support (conversion to/from float16 and bfloat16)
  • add op (elementwise)
  • sub op (elementwise)
  • mul op (elementwise)
  • div op (elementwise)
  • shift ops
  • AMDGCN (ISA assembly) and GPU binary (HSACO) generation for triton kernels
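To illustrate one item on the list: bfloat16 is simply the upper 16 bits of an IEEE-754 float32, so the float32 conversions a backend must implement reduce to bit truncation with rounding on the way down and zero-extension on the way up. A pure-Python sketch of that logic (hypothetical helper names, not code from the fork; NaN payload handling omitted):

```python
import struct

def f32_to_bf16_bits(x: float) -> int:
    """Narrow a float32 value to bfloat16 bits, round-to-nearest-even."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Add 0x7FFF plus the LSB of the surviving upper half, then truncate.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_f32(b: int) -> float:
    """Widen bfloat16 bits back to float32 (exact: append 16 zero bits)."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x
```

The widening direction is always exact; only the narrowing direction rounds, which is why the list item calls out "conversion to/from float32" as a pair.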

We are also interested in enhancing your CI infrastructure by contributing the use of several of our compute nodes to check incoming upstream PRs for successful unit test runs on AMDGPU.

Please let us know if this is feasible.

If so, then we can submit a PR against the master branch with all of the aforementioned functionality, and we can also help integrate our compute nodes into your CI infrastructure to verify Triton PRs on AMDGPU.

We look forward to working with you to enable triton MLIR on AMDGPU.

ptillet (Collaborator) commented Jan 18, 2023

Yes, I'd like to talk more about it :) Shoot me an e-mail and we can arrange a meeting.

More generally, there are other ongoing efforts to add Triton support for Hopper etc., and I would like every effort to be as independent as possible. What would you guys think of focusing your effort on a new TritonGPUToROCM directory entirely? This would make the merge much easier. Code duplication within this scope is fine -- though we can probably re-use a lot of common infra for element-wise ops and control flow.

> We are also interested in enhancing your CI infrastructure by contributing the use of several of our compute nodes to check incoming upstream PRs for successful unit test runs on AMDGPU.

That would be awesome. We can arrange that as well.

binarman (Contributor) commented Feb 8, 2023

@ptillet Hello!

In connection with this issue, I have a question about the triton.tools.aot tool (and possibly some other tools as well).

It contains a PTX-specific option, --sm <compute capability>, which is required for compilation: https://github.com/openai/triton/blob/main/python/triton/tools/aot.py#L41-L42

This option has no meaning for AMD GPUs (and probably not for other future target hardware either), so I think it would be a good idea to change the interface of this tool.

I have two options on my mind:

  1. Duplicate the set of valid formats for the --target option:
    VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx']
    becomes
    VALID_FORMATS = ['triton-ir', 'ptx-triton-gpu-ir', 'ptx-llvm-ir', 'ptx', 'gcn-triton-gpu-ir', 'gcn-llvm-ir', 'gcn']

  2. Separate hardware target and required compilation artifact:

  • --target [ptx|gcn]
  • --artifact [triton-ir|triton-gpu-ir|llvm-ir|binary]
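As a rough sketch, option 2 could be expressed with argparse along these lines (hypothetical: flag names, defaults, and the --sm handling are illustrative, not the actual aot.py interface):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Sketch of option 2: hardware target and artifact are independent axes."""
    p = argparse.ArgumentParser(prog="triton-aot")
    # Which hardware backend to compile for.
    p.add_argument("--target", choices=["ptx", "gcn"], required=True)
    # Which stage of the compilation pipeline to emit.
    p.add_argument("--artifact",
                   choices=["triton-ir", "triton-gpu-ir", "llvm-ir", "binary"],
                   default="binary")
    # --sm would then only be meaningful when --target is ptx.
    p.add_argument("--sm", type=int, default=None,
                   help="compute capability (PTX targets only)")
    return p
```

One advantage of this split is that adding a third backend later only extends the --target choices, instead of multiplying the format list again as in option 1.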

What do you think about this?

micmelesse (Contributor) commented Mar 17, 2023

Upcoming PRs

  • load and store ops
  • binary op
  • shift ops
  • element_wise ops
  • reduce ops
  • atomic ops
  • dot op

micmelesse (Contributor) commented:

Update on support for ROCm: we are working on enabling ROCm support through the new third-party backend system, see here. This will allow us to integrate most of the changes from our fork into Triton. If you wish to build our fork, see the instructions here.
