
Add support for torch FP8 dtypes #445

Merged: 19 commits into main from native-fp8 on May 27, 2024
Conversation

@riccardofelluga (Collaborator) commented on May 22, 2024

Before submitting
  • Was this discussed/approved via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

This PR fixes #254 and adds native thunder support for the following dtypes:

torch.float8_e5m2
torch.float8_e5m2fnuz
torch.float8_e4m3fn
torch.float8_e4m3fnuz

Since float8 is implemented in four different variants, I added a variant mechanism to Thunder dtypes so that we can differentiate between them.
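
As a rough illustration (not the actual implementation in thunder/core/dtypes.py; the class and attribute names here are hypothetical), a variant-aware dtype could look like this:

import torch

# Hypothetical sketch: an 8-bit floating-point dtype distinguished by a
# "variant" tag, so the four torch float8 flavors map to distinct dtypes.
class floating:
    def __init__(self, bits: int, variant: str | None = None):
        self.bits = bits
        self.variant = variant

    def __repr__(self):
        return f"float{self.bits}" + (f"_{self.variant}" if self.variant else "")

float8_e5m2 = floating(8, variant="e5m2")
float8_e4m3fn = floating(8, variant="e4m3fn")

# Mapping from torch dtypes to the sketched Thunder dtypes
_torch_to_thunder = {
    torch.float8_e5m2: float8_e5m2,
    torch.float8_e4m3fn: float8_e4m3fn,
}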

This PR also adds the option to create fp8 test tensors with make_tensor so that we can start testing fp8 operations. After running the existing operator tests, it is evident that torch support for these dtypes is still scarce: the majority of tests fail with "not implemented" runtime errors. Given that, I decided to skip operator testing for all fp8 dtypes for now.
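
For illustration only (a sketch of the general approach, not the make_tensor change in this PR; make_fp8_tensor is a hypothetical helper), fp8 test tensors can be produced by sampling in float32 and casting down:

import torch

def make_fp8_tensor(shape, dtype=torch.float8_e4m3fn, device="cpu", low=-1.0, high=1.0):
    # Sample in float32 first, then cast: random-number kernels such as
    # uniform_ are not implemented for fp8 dtypes, but casting to fp8 is.
    t = torch.empty(shape, dtype=torch.float32, device=device).uniform_(low, high)
    return t.to(dtype)

x = make_fp8_tensor((4, 4), dtype=torch.float8_e5m2)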

Furthermore, I updated the type promotion table; please take a look and don't hesitate to comment if you think some promotions are misplaced.

Did you have fun?

Oh yes!

@kshitij12345 (Collaborator) left a comment


Thanks @riccardofelluga for the PR; this will open up a lot of interesting opportunities for FP8 support in thunder.

Overall the PR looks really good; I just think FP8 should not participate in type promotion, for the following reasons:

  1. For the gradient computation of linear, it is common to keep gradients in E5M2 (to avoid underflow/overflow) and weights/inputs in E4M3 (for higher precision). In that case we don't want E4M3 to be upcast, since there are GEMM kernels for mixed FP8 dtypes (a sketch follows below).
  2. When an operator is used with FP8, we want to enforce that FP8 inputs stay in FP8 for performance reasons (we don't want inputs upcast to higher precision, bypassing the specialized FP8 GEMMs).
  3. Except for matmul, there is no math-operation support for FP8 (and I don't think the coverage will increase soon).

So, I think we should remove the type promotion logic. (cc: @mruberry for his thoughts on the same)
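
To make point 1 concrete, here is a rough sketch of a mixed-dtype FP8 GEMM. Caveat: torch._scaled_mm is a private API whose signature has changed across PyTorch releases; this assumes a recent build (roughly 2.4+) and an FP8-capable GPU, and the shapes and scales are illustrative only.

import torch

device = "cuda"
# Gradients in E5M2, weights/activations in E4M3 -- no upcasting of either side.
g = torch.randn(16, 32, device=device).to(torch.float8_e5m2)
w = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)
scale = torch.tensor(1.0, device=device)  # per-tensor scales (illustrative)

# Mixed E5M2 x E4M3 matmul; mat2 must be column-major, hence the transpose.
out = torch._scaled_mm(g, w.t(), scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)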

Also, I think we should not enable testing with FP8 by default: FP8 has very limited op support, so every test would have to disable it for now. It is likely that we will add separate tests for selected ops with fp8. So we should probably tweak instantiate so that it does not enable FP8 types by default and only uses them when they are explicitly passed to the dtypes argument (a sketch of that tweak follows the snippet below).

class instantiate:
    # TODO: support other kinds of dtype specifications
    def __init__(
        self,
        *,
        executors=None,
        devicetypes=None,
        dtypes=None,
        num_devices: int = 1,
        decorators: None | Sequence = None,
        scope=None,
        as_name: str | None = None,
    ):
        self.executors = set(executors) if executors is not None else set(_all_test_executors())
        self.devicetypes = set(devicetypes) if devicetypes is not None else set(available_devicetypes())
        self.devicetypes = set(filter_ci_devicetypes(self.devicetypes))

        if dtypes == NOTHING:
            self.dtypes = (None,)
        else:
            self.dtypes = datatypes.resolve_dtypes(dtypes) if dtypes is not None else datatypes.all_dtypes
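
A minimal sketch of the suggested default, replacing the final branch of __init__ above. It assumes the dtypes module exposes a set of the fp8 dtypes (float_8bit_dtypes is a hypothetical name) and that all_dtypes is a set:

        if dtypes == NOTHING:
            self.dtypes = (None,)
        elif dtypes is not None:
            # fp8 is only tested when explicitly requested via `dtypes`
            self.dtypes = datatypes.resolve_dtypes(dtypes)
        else:
            # default: every dtype except the fp8 variants
            self.dtypes = datatypes.all_dtypes - datatypes.float_8bit_dtypes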

@mruberry (Collaborator) left a comment


Hey @riccardofelluga! Cool stuff. I made a few small suggestions. The big things I'm curious about before merging are:

  • should we delay implementing type promotion logic for fp8 dtypes? The logic update seems pretty reasonable, but maybe different fp8 dtypes should promote to fp16 for now?
  • a convenient set of datatypes that lets test authors select the floating-point types except the fp8 types would be nice for now

Curious to hear your thoughts!

@riccardofelluga (Collaborator, Author) commented:

Thanks everybody for your comments! I thought about it a bit and I mostly agree with your points: torch support for these dtypes is so limited that we should delay type promotion.

Regarding the tests, I also had the idea to disable them from the start, but I didn't anticipate how ugly the pytest.skip calls everywhere would end up looking. I introduced the float_math_dtypes set of floating-point datatypes so that tests can be instantiated with only the floating dtypes that are worth testing for now (see the usage sketch below).
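
For example, a test that should only run on the non-fp8 floating-point dtypes could be instantiated along these lines (a sketch; the import paths are assumed, and the test-function signature follows the existing instantiate convention):

import thunder.core.dtypes as datatypes
from thunder.tests.framework import instantiate

@instantiate(dtypes=datatypes.float_math_dtypes)
def test_some_float_op(executor, device, dtype):
    ...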

@t-vi t-vi enabled auto-merge (squash) May 27, 2024 13:19
@t-vi t-vi merged commit a4dcd89 into main May 27, 2024
37 checks passed
@t-vi t-vi deleted the native-fp8 branch May 27, 2024 13:22
@t-vi (Collaborator) commented May 27, 2024

crcrpar pushed a commit that referenced this pull request May 29, 2024
Development

Successfully merging this pull request may close this issue: Add support for FP8E4M3 and FP8E5M2 dtypes (#254)
6 participants