Add support for torch FP8 dtypes #445
Conversation
Thanks @riccardofelluga for the PR, this will enable a lot of interesting opportunities for FP8 support in thunder.
Overall the PR looks really good. I just think FP8 should not be involved in type promotion, for the following reasons:
- For the gradient computation of linear it is common to keep the gradient in E5M2 (to avoid underflow/overflow) and the weights/input in E4M3 for higher precision. In that case we don't want E4M3 to be upcast, since there are GEMM kernels for mixed FP8 dtypes.
- When we use an operator with FP8, we want to enforce that FP8 inputs stay in FP8 for performance reasons (we don't want the input to be upcast to higher precision and miss the special FP8 GEMMs).
- Except for matmul, there is no math-operation support for FP8. (I also don't think the coverage will increase soon.)
So, I think we should remove the type promotion logic. (cc: @mruberry for his thoughts on the same)
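For illustration, a minimal sketch (not code from this PR) of the mixed-dtype setup described above, assuming a PyTorch build that ships the fp8 dtypes (2.1+):

import torch

# Weights/activations are kept in E4M3 for precision, gradients in E5M2 for range.
weight = torch.randn(128, 256).to(torch.float8_e4m3fn)
grad_out = torch.randn(64, 128).to(torch.float8_e5m2)

# A mixed-dtype scaled GEMM kernel (e.g. torch._scaled_mm, a private API whose
# exact signature varies across PyTorch versions) can consume E5M2 x E4M3
# operands directly; upcasting one side would forfeit that kernel.
print(weight.dtype, grad_out.dtype)  # torch.float8_e4m3fn torch.float8_e5m2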
Also, I think we should not enable testing with FP8 by default: FP8 will have very limited op support, and every test would have to disable it for now. It is likely that we will have separate tests for selected ops with fp8. So we should probably tweak instantiate so it does not enable FP8 types by default and only uses them when they are explicitly passed to the dtypes argument (see the sketch after the quoted snippet below).
lightning-thunder/thunder/tests/framework.py
Lines 416 to 438 in 82185e3
class instantiate:
    # TODO: support other kinds of dtype specifications
    def __init__(
        self,
        *,
        executors=None,
        devicetypes=None,
        dtypes=None,
        num_devices: int = 1,
        decorators: None | Sequence = None,
        scope=None,
        as_name: str | None = None,
    ):
        self.executors = set(executors) if executors is not None else set(_all_test_executors())
        self.devicetypes = set(devicetypes) if devicetypes is not None else set(available_devicetypes())
        self.devicetypes = set(filter_ci_devicetypes(self.devicetypes))

        if dtypes == NOTHING:
            self.dtypes = (None,)
        else:
            self.dtypes = datatypes.resolve_dtypes(dtypes) if dtypes is not None else datatypes.all_dtypes
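For concreteness, a hedged sketch of the suggested tweak (not thunder's actual API): keep fp8 dtypes opt-in when resolving the default test dtypes. The set names below are placeholders for whatever thunder.core.dtypes actually exposes.

# Sketch only: choose test dtypes while keeping fp8 variants opt-in.
def default_test_dtypes(requested, all_dtypes: set, fp8_dtypes: set) -> set:
    if requested is not None:
        # fp8 is used only when a test explicitly passes it via `dtypes`
        return set(requested)
    # default: every supported dtype except the fp8 variants
    return all_dtypes - fp8_dtypes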
Hey @riccardofelluga! Cool stuff. I made a few small suggestions. The big things I'm curious about before merging are:
- Should we delay implementing type promotion logic for fp8 dtypes? The logic update seems pretty reasonable, but maybe different fp8 dtypes should promote to fp16 for now? (A rough sketch of that interim rule is below.)
- A convenient set of datatypes so test authors can select the floating point types except the fp8 types would be nice for now.
Curious to hear your thoughts!
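To make the first question concrete, a rough sketch (an illustration, not code from the PR) of an interim rule where differing fp8 dtypes fall back to fp16:

import torch

_FP8 = {torch.float8_e4m3fn, torch.float8_e5m2}

def interim_promote(a: torch.dtype, b: torch.dtype) -> torch.dtype:
    # Placeholder rule: a matching fp8 pair stays fp8, two different fp8
    # dtypes promote to float16; everything else is left to the existing table.
    if a in _FP8 and b in _FP8:
        return a if a is b else torch.float16
    raise NotImplementedError("non-fp8 promotion is handled by the existing table")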
Thanks everybody for your comments! I thought about it a bit and I mainly agree with your points: the support for this dtype in torch is so low that I agree we should delay the type promotion. Regarding the tests, I also had the idea of disabling them from the start, but I didn't realize how ugly the …
Thank you @riccardofelluga @crcrpar @kshitij12345 @mruberry @lantiga
What does this PR do?
This PR fixes #254 and adds native thunder support for the torch float8 dtypes. Since the float8 dtype is implemented in 4 different variants, I added a variant mechanism for Thunder dtypes so that we can differentiate between them.
This PR also adds the option to create fp8 test tensors with make_tensor so that we can start testing fp8 operations. After running the existing operator tests it is evident that torch's support for this dtype is scarce, since the majority of tests fail with "not implemented" runtime errors. Because of that, I decided to skip the operator testing for all the fp8 dtypes. Furthermore, I updated the type promotion table; please take a look and don't hesitate to comment if you think some promotions are not in the right place.
Did you have fun?
Oh yes!