[BUG] rs_op_selector does not allow to use mixed FP8 types for A and B #1268

Closed
ex-rzr opened this issue Dec 13, 2023 · 6 comments
Labels: ? - Needs Triage, bug (Something isn't working)


ex-rzr commented Dec 13, 2023

Describe the bug
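rs_op_selector does not allow mixed FP8 types for A and B.
A static assert requiring ElementA and ElementB to be the same type prevents this. In ss_op_selector the check is placed correctly: each branch has its own check, only where needed. In rs_op_selector the assert sits ahead of the branch that would handle mixed FP8 types.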
See https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm90.hpp#L857

Steps/Code to reproduce bug

cute::make_tiled_mma(cute::GMMA::rs_op_selector<cutlass::float_e4m3_t, cutlass::float_e5m2_t, float, ...
cutlass/include/cute/arch/mma_sm90.hpp(856): error: static assertion failed with "ElementA and ElementB must be the same type for this config."
      static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
      ^
          detected during:
            instantiation of "auto cute::GMMA::rs_op_selector<ElementA,ElementB,ElementC,TileShape_MNK,MajorA,MajorB,Args...>() [with ElementA=cutlass::float_e4m3_t, ElementB=cutlass::float_e5m2_t, ElementC=float, ..."

Expected behavior
The same behavior as in ss_op_selector, which accepts mixed FP8 types for A and B.
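For illustration, here is a minimal C++17 sketch of the per-branch pattern described above: each if constexpr branch carries only the checks it needs, so a mixed e4m3/e5m2 pair reaches its own branch instead of tripping a shared same-type assert. This is not CuTe's code. The types e4m3, e5m2 and half, the Op tags and select_rs_f32_accum are hypothetical stand-ins for the CUTLASS element types and the MMA atoms the real selector returns.

```cpp
#include <type_traits>

// Hypothetical stand-ins for cutlass::float_e4m3_t, cutlass::float_e5m2_t
// and cutlass::half_t.
struct e4m3 {};
struct e5m2 {};
struct half {};

// Hypothetical tags for the instruction the selector picks (FP32 accumulator).
enum class Op { E4M3E4M3, E4M3E5M2, E5M2E4M3, E5M2E5M2, F16F16 };

template <class A, class B>
constexpr Op select_rs_f32_accum() {
  if constexpr (std::is_same_v<A, e4m3> && std::is_same_v<B, e4m3>) {
    return Op::E4M3E4M3;
  } else if constexpr (std::is_same_v<A, e4m3> && std::is_same_v<B, e5m2>) {
    return Op::E4M3E5M2;  // mixed FP8: no same-type check on this path
  } else if constexpr (std::is_same_v<A, e5m2> && std::is_same_v<B, e4m3>) {
    return Op::E5M2E4M3;  // mixed FP8 in the other order
  } else if constexpr (std::is_same_v<A, e5m2> && std::is_same_v<B, e5m2>) {
    return Op::E5M2E5M2;
  } else {
    // The same-type requirement lives only in the branch that needs it.
    // Its condition depends on A and B, so it is checked only when this
    // branch is actually instantiated.
    static_assert(std::is_same_v<A, B>,
                  "ElementA and ElementB must be the same type for this config.");
    static_assert(std::is_same_v<A, half>,
                  "This sketch only models half as the non-FP8 case.");
    return Op::F16F16;
  }
}

// A mixed FP8 pair now compiles and selects its own op.
static_assert(select_rs_f32_accum<e4m3, e5m2>() == Op::E4M3E5M2);
```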

ex-rzr added the ? - Needs Triage and bug labels on Dec 13, 2023.
hwu36 (Collaborator) commented Dec 13, 2023

@IonThruster

jayhshah (Contributor) commented

I also noticed that FP8 operands with an FP16 accumulator are excluded in ss_op_selector and rs_op_selector, even though the PTX ISA supports this combination and MMA_Traits structs are defined for it (F16E4M3E4M3 etc.). The exclusion comes from the following code:

  // FP16 accumulator
  if constexpr (is_same_v<ElementC, half_t>) {
    static_assert(is_same_v<ElementA, half_t>, "Element types for AB must be half if ElementC is half.");
    static_assert(is_same_v<ElementB, half_t>, "Element types for AB must be half if ElementC is half.");

Was this intentional or is this an oversight?
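As a rough sketch of what admitting that combination could look like, the C++17 example below assumes the FP16-accumulator branch accepts any pair of FP8 operand types in addition to half/half, and keeps a single assert for everything else. The stand-in types, the is_fp8_v trait, the F16AccumOp tags (loosely named after the F16E4M3E4M3 traits mentioned above) and select_f16_accum are all hypothetical; this is not CuTe's API.

```cpp
#include <type_traits>

// Hypothetical stand-ins for the CUTLASS element types.
struct e4m3 {};
struct e5m2 {};
struct half {};

// True for either FP8 format.
template <class T>
inline constexpr bool is_fp8_v = std::is_same_v<T, e4m3> || std::is_same_v<T, e5m2>;

// Hypothetical tags for instructions with an FP16 accumulator.
enum class F16AccumOp { F16F16, F16E4M3E4M3, F16E4M3E5M2, F16E5M2E4M3, F16E5M2E5M2 };

template <class A, class B>
constexpr F16AccumOp select_f16_accum() {
  if constexpr (is_fp8_v<A> && is_fp8_v<B>) {
    // FP8 operands with an FP16 accumulator: pick by the (A, B) format pair.
    if constexpr (std::is_same_v<A, e4m3>) {
      return std::is_same_v<B, e4m3> ? F16AccumOp::F16E4M3E4M3 : F16AccumOp::F16E4M3E5M2;
    } else {
      return std::is_same_v<B, e4m3> ? F16AccumOp::F16E5M2E4M3 : F16AccumOp::F16E5M2E5M2;
    }
  } else {
    // Anything that is not an FP8 pair must be half/half.
    static_assert(std::is_same_v<A, half> && std::is_same_v<B, half>,
                  "Element types for AB must be half or FP8 if ElementC is half.");
    return F16AccumOp::F16F16;
  }
}
```

With this shape, select_f16_accum<e4m3, e5m2>() yields F16AccumOp::F16E4M3E5M2, while a pair such as half with e4m3 still fails the static_assert.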

thakkarV (Collaborator) commented Jan 10, 2024

Hi @ex-rzr, we now support mixed FP8 types in the RS op selector:

else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e5m2_t>) {

Is this what you wanted? Can we close the issue?

ex-rzr (Author) commented Jan 11, 2024

@thakkarV
Nothing has changed since I opened this issue.
There is a static assert before the line you pointed at:

static_assert(is_same_v<ElementA, ElementB>, "ElementA and ElementB must be the same type for this config.");
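To see why the mixed-FP8 branch cannot help while that assert precedes it, here is a minimal C++17 model of the ordering. It is hypothetical, not the actual contents of mma_sm90.hpp: e4m3, e5m2, RsOp and select_rs_assert_first are illustrative stand-ins. A static_assert at the top of a function template body is checked for every instantiation, before if constexpr gets a chance to discard any branch.

```cpp
#include <type_traits>

// Hypothetical stand-ins for cutlass::float_e4m3_t and cutlass::float_e5m2_t.
struct e4m3 {};
struct e5m2 {};

// Hypothetical instruction tags (not CuTe's real MMA atom names).
enum class RsOp { E4M3E4M3, E4M3E5M2, E5M2E5M2 };

template <class A, class B>
constexpr RsOp select_rs_assert_first() {
  // Evaluated for every instantiation, regardless of which branch follows.
  static_assert(std::is_same_v<A, B>,
                "ElementA and ElementB must be the same type for this config.");
  if constexpr (std::is_same_v<A, e4m3> && std::is_same_v<B, e4m3>) {
    return RsOp::E4M3E4M3;
  } else if constexpr (std::is_same_v<A, e4m3> && std::is_same_v<B, e5m2>) {
    // Dead in practice: instantiating with <e4m3, e5m2> has already
    // failed the static_assert above.
    return RsOp::E4M3E5M2;
  } else {
    static_assert(std::is_same_v<A, e5m2> && std::is_same_v<B, e5m2>,
                  "Unsupported FP8 pair in this sketch.");
    return RsOp::E5M2E5M2;
  }
}

// Same-type pairs compile fine.
static_assert(select_rs_assert_first<e4m3, e4m3>() == RsOp::E4M3E4M3);
// select_rs_assert_first<e4m3, e5m2>() would be a compile-time error.
```

Moving the same-type assert into only the branches that need it, as the issue description says ss_op_selector does, is what makes the mixed branch reachable.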

thakkarV (Collaborator) commented

Sorry about that. I noticed the assert after my earlier reply in this thread. This will be fixed when 3.4 is tagged in the next few days.

thakkarV (Collaborator) commented

Fixed with 3.4. Closing.


4 participants