-
Notifications
You must be signed in to change notification settings - Fork 20
bring back torch.autograd.Function #316
base: gh/vkuzo/29/base
Are you sure you want to change the base?
Conversation
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 09c4625b2a859ce6468bac328d5f0ff61bb86251 Pull Request resolved: #316
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. ``` # this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files # and modified to only support dynamic scaling # # Why do we want a torch.autograd.Function here? Vasiliy's opinion is that # as we add more scaling granularities, keeping the scaling code close to Float8Linear # will be really useful for readability and debuggability of numerics. # # For example, a future PR to add rowwise scaling could do # # # forward # x_bf16 = ... # if scaling_granularity == ScalingGranularity.PER_TENSOR: # # we can scale the same way for fwd/bwd # x_maybe_fp8 = to_fp8(...) # else: # assert scaling_granularity == ScalingGranularity.PER_ROW: # # defer scaling to float8_mm # x_maybe_fp8 = x_bf16 # # # repeat for w # # y_bf16 = float8_mm(x_maybe_fp8, w_maybe_fp8) # # Requirements for float8_mm # - composes with DTensor, compile, autograd # - readable/debuggable # # Option 1 (this PR): float8_mm is a torch.autograd.Function # - pros # - cons # Option 2 (current code without this PR): float8_mm is an override of torch.mm # - pros # - cons # ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 75842f4858804bc6f204eb55222a493ea9074630 Pull Request resolved: #316
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. We should discuss whether we want Option 1 (keep overriding torch.mm) or Option 2 (torch.autograd.Function). Vasiliy: I think Option 2 is cleaner/more readable/more debuggable, modeling code is usually written in the module or similar torch.autograd.Function overrides. I would consider scaling tensors to float8 modeling code, and it's unintuitive IMO for this to happen deep inside op overrides. However, Option 1 is less risky technically as we avoid torch.autograd.Function which is less mature in interactions with torch.compile. While the current PR is all green, we are using `allow_in_graph` which is a bit unsafe. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: ea5bd3e7ec037b351703154363a6bbe9f4f638d5 Pull Request resolved: #316
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. This PR implements Option 2 as IMO this is more readable/debuggable. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: c1b8d0c42c7f73fb8a0c8806ae88479a40d0be40 Pull Request resolved: #316
@@ -71,6 +71,54 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( | |||
scale.copy_(new_scale) | |||
|
|||
|
|||
# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Does the structure work out to put this in float8 ops?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in how things look after this PR it would make sense, but might be good to see how the code looks after we add different granularities and the if/else branches on when to convert to lower precision. Maybe we can revisit then?
return res_bits | ||
|
||
@staticmethod | ||
def backward(ctx, go_fp8): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: align go_fp8 / other naming to the other PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can do that in separate PRs, since not user facing. Just keeping things small.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont know if that changes the size of the PR much but sure thats fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably just a style difference on how to sequence the renames, either is ok IMO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, TBH I think this is a good balance of both subclassing + autograd func
Summary: I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd: a. keep the bf16 copy to be able to rescale across dim0 and dim1 b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision) c. keep some of the gemms in bf16 to avoid the need to scale twice The modeling logic in Float8Linear for a/b would look like: ```python def forward(self, x): if scaling_type == TENSORWISE: x_maybe_fp8 = to_fp8_tensorwise(x, ...) elif scaling_type == ROWWISE: x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...) # repeat for w y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...) ``` And, there are at least two choices I see for `float8_mm_op`: ```python # Option 1 (current code without this PR): use the torch.mm override implements([aten.mm.default, aten.matmul.default]) def float8_mm(aten_op, args, kwargs=None): ... # Option 2 (this PR): use torch.autograd.Function class float8_mm(torch.autograd.Function): ... ``` To support future scaling granularities, whichever choice we go with will have to do something like below: ```python def float8_mm(x_maybe_fp8, w_maybe_fp8): if isinstance(x_maybe_fp8, Float8Tensor): x_fp8 = x_maybe_fp8 else: x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...) # repeat for w # call torch._scaled_mm ``` Furthermore, to keep things readable / debuggable, it would be good to: 1. be able to print tensors before/after quantization 2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module To do the above, we'll need to pass around metadata such as module FQNs. This PR implements Option 2 as IMO this is more readable/debuggable. Test plan: ``` // all green ./test/test_everything.sh ``` [ghstack-poisoned]
Summary: This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: a7abb00cce87d18273d3bb18996eebb2bb0c4c99 Pull Request resolved: #316
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 6cb1588bf59be73b5782f6af94e7a360eba7f40e Pull Request resolved: #336
…at8 matmul" Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 42dd59511e4ec2a55846c2593955c4ff5f12b254 Pull Request resolved: #336
…at8 matmul" Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…at8 matmul" Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…at8 matmul" Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068) [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068) [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: Pull Request resolved: #344 This is a redo of #316 With upcoming support of scaling granularities other than tensorwise, we need a good way to control which gemm kernel to call and how to scale the input tensors in fwd and bwd. A `torch.autograd.Function` override is the cleanest way to do that, and in 2024 this now works with `torch.compile`. Reviewed By: drisspg Differential Revision: D60291446 fbshipit-source-id: 472f392227bca1c7f83ea0c1234285bc576e58d2
Stack from ghstack (oldest at bottom):
Summary:
I want to plan for how we are going to add scaling granularities in the Python layer of float8 code. Today, we only have per-tensor scaling which is transposeable. For other types of scaling such as rowwise, the scaling is not transposeable and the user needs to choose what to do between fwd and bwd:
a. keep the bf16 copy to be able to rescale across dim0 and dim1
b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bw (reduce memory usage, lose some precision)
c. keep some of the gemms in bf16 to avoid the need to scale twice
The modeling logic in Float8Linear for a/b would look like:
And, there are at least two choices I see for
float8_mm_op
:To support future scaling granularities, whichever choice we go with will have to do something like below:
Furthermore, to keep things readable / debuggable, it would be good to:
To do the above, we'll need to pass around metadata such as module FQNs.
This PR implements Option 2 as IMO this is more readable/debuggable.
Test plan: