[PyTorch] Prototype for operation-based API #707
Conversation
Runs, but need to validate. Runtime errors with non-FP8 params and FP8 compute, or FP8 params and non-FP8 compute. Signed-off-by: Tim Moon <[email protected]>
Test does not pass with FP8. Signed-off-by: Tim Moon <[email protected]>
Not supported by cuBLAS. Signed-off-by: Tim Moon <[email protected]>
Still need to implement amax reductions. Signed-off-by: Tim Moon <[email protected]>
Add documentation for unfused ops. Signed-off-by: Tim Moon <[email protected]>
Expand documentation. Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch
pass 1
@property
@abc.abstractmethod
def is_fused_op(self) -> bool:
    """Whether this op is the fusion of one or more basic ops"""
"""Whether this op is the fusion of one or more basic ops""" | |
"""Whether this op is the fusion of one or more basic ops""" | |
pass |
PyLint prefers just putting the docstring: 738df8a
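For readers following along: a docstring by itself is a complete function body in Python, so the abstract property needs no `pass`. A minimal sketch (the class name is assumed from context, not quoted from the PR):

```python
import abc

class FusibleOperation(abc.ABC):

    @property
    @abc.abstractmethod
    def is_fused_op(self) -> bool:
        """Whether this op is the fusion of one or more basic ops"""
        # No `pass` needed: the docstring alone is a valid body,
        # which is the style PyLint prefers here.
```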
"""Whether this op is the fusion of one or more basic ops""" | ||
|
||
def pre_forward(self) -> None: | ||
"""Preprocessing before forward pass""" |
"""Preprocessing before forward pass""" | |
"""Preprocessing before forward pass""" | |
pass |
PyLint prefers just putting the docstring: 738df8a
curr_len = meta.amax_history.size(0)
if curr_len == amax_history_len:
    continue
with torch.no_grad():
Just curious, why do we need `torch.no_grad` here?
I don't think it's needed, but I'm being paranoid about leaking the autograd graph. This code path is infrequent but called outside the `OperationFuser`'s autograd function:

    op.pre_forward()
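To make the concern concrete, here is a rough sketch of the kind of guarded resize being discussed (illustrative only; `meta` and `amax_history_len` follow the quoted snippet, everything else is assumed):

```python
import torch

def _resize_amax_history(meta, amax_history_len: int) -> None:
    """Resize the amax history buffer outside of any autograd function."""
    curr_len = meta.amax_history.size(0)
    if curr_len == amax_history_len:
        return
    # Defensive no_grad: this runs outside the fuser's autograd function,
    # so make sure the copy below cannot record anything into a graph.
    with torch.no_grad():
        new_history = torch.zeros(
            (amax_history_len,) + tuple(meta.amax_history.shape[1:]),
            dtype=meta.amax_history.dtype,
            device=meta.amax_history.device,
        )
        num_kept = min(curr_len, amax_history_len)
        new_history[:num_kept].copy_(meta.amax_history[:num_kept])
        meta.amax_history = new_history
```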
Parameters
----------
mode: {"input", "param", "grad_output"}
Is `name` a better fit for this arg?
This is intended to match `num_fp8_scales`:
https://github.com/timmoon10/TransformerEngine/blob/f4e6af92e8956d948fe1fbaefbc1b2dd6f32b457/transformer_engine/pytorch/ops/op.py#L170-L173
`mode` seems better for that one.
for fp8_meta in self._fp8_metas.values():
    self._check_fp8_meta(fp8_meta)

# Register FP8 metadata for amax and scale update
Is this part of the code (or in spirit) from `prepare_for_forward` from the original API?
Exactly:

TransformerEngine/transformer_engine/pytorch/module/base.py, lines 608 to 611 in 7326af9:

    if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
        FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
            self.fp8_meta, fp8_weights=self._get_fp8_params()
        )
Although now that you mention it, we should register "grad_output" in the backward pass instead of the forward.
Actually, this matches the module API. The `fp8_meta`s are registered in the forward pass, and we manually trigger an update in the backward pass:

TransformerEngine/transformer_engine/pytorch/module/linear.py, lines 612 to 613 in 7326af9:

    if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
        FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
torch.Tensor:
    Output tensor

"""
Add pass
PyLint prefers just putting the docstring: 738df8a
Iterable of torch.Tensor:
    Loss gradients w.r.t. parameters

"""
Add pass
PyLint prefers just putting the docstring: 738df8a
    self.append(module)

def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
    self._module_groups = None
`self._module_groups` is already set to `None` at the beginning of `__init__`. Why do we set it to `None` again?
If we add a module after calculating operation fusions, then we need to invalidate the operation fusions and recalculate.
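A minimal sketch of that invalidation pattern (illustrative, not the actual `te.ops.Sequential` implementation; only the `_module_groups` attribute name comes from the snippet above):

```python
from typing import Optional
import torch

class Sequential(torch.nn.Module):
    """Simplified container that caches its operation fusions lazily."""

    def __init__(self, *modules: torch.nn.Module) -> None:
        super().__init__()
        self._module_groups = None  # cached fusions, computed on first forward
        for idx, module in enumerate(modules):
            self.add_module(str(idx), module)

    def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
        # Any new module invalidates previously computed fusions, so reset
        # the cache and let the next forward pass recompute module groups.
        self._module_groups = None
        super().add_module(name, module)
```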
def _get_keys_by_idx(self, idx: int | slice) -> list[str]:
    """Get module keys corresponding to indices"""
    if isinstance(idx, slice):
Should there be a slice indices check as well?
In principle, but it's simpler to rely on the bounds checking in `list`. This implementation is similar to `torch.nn.Sequential`:
https://github.com/pytorch/pytorch/blob/389492e2640730b0a199ffe506582ed4fd2c4afc/torch/nn/modules/container.py#L140
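For illustration, a sketch of why an explicit slice bounds check is unnecessary here (assuming the keys come from the module's `_modules` dict, as in `torch.nn.Sequential`; this is not the actual implementation):

```python
def _get_keys_by_idx(self, idx: int | slice) -> list[str]:
    """Get module keys corresponding to indices"""
    keys = list(self._modules.keys())
    if isinstance(idx, slice):
        # Python list slicing silently clamps out-of-range bounds instead of
        # raising, matching torch.nn.Sequential's behavior for slices.
        return keys[idx]
    # Integer indices are bounds-checked by the list itself (IndexError).
    return [keys[idx]]
```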
# Reshape FP8 tensor
# Note: Preserve cached transpose if possible
if is_float8_tensor(tensor):
    out = Float8Tensor.make_like(
How does this preserve the cache?
The transpose is part of `Float8Tensor._fp8_attrs`:

    _transpose = property(**_make_fp8_attr_property_funcs("transpose"))

This function is not quite equivalent to `Float8Tensor`'s `view` or `reshape` functions, since typically reshaping a tensor changes its transpose, while this function tries to preserve the 2D transpose.
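To sketch the idea as a continuation of the quoted snippet (illustrative only; the exact `Float8Tensor.make_like` keyword arguments are an assumption, and `new_shape` is a placeholder):

```python
if is_float8_tensor(tensor):
    # Reshape the raw FP8 bytes, but construct the result from the source
    # tensor so its FP8 attrs (scale_inv, FP8 dtype, cached transpose) are
    # carried over instead of the transpose being invalidated or recomputed.
    out = Float8Tensor.make_like(
        tensor,
        data=tensor._data.reshape(new_shape),
    )
```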
def op_forward(
    self,
    ctx: OperationContext,
    input: torch.Tensor,  # pylint: disable=redefined-builtin
These are worth changing IMO: `input` → `inp`.
I'd agree for internal implementations, but `input` feels much better for a user-facing API:

    op = te.ops.AllGather(...)
    y = op(input=x)

I suppose `BasicOperation.op_forward` can be considered internal implementation, so I've changed the arg name to `input_`. I feel strongly about keeping the `input` arg in other functions like `FusableOperation.forward`.
basic_op_ctxs[0],
input_,
basic_op_prev_ops[0],
basic_op_next_ops[0],
**basic_op_kwargs[0],
Could you explain why we index 0 here?
`OperationFuser` doesn't make any distinction between `BasicOperation` or `FusedOperation`, but interacts with them via the base class (e.g. `FusableOperation.fuser_forward`). A `FusableOperation` consists of one or more `BasicOperation`s, so a `BasicOperation` will receive just one ctx from `OperationFuser`, while a `FusedOperation` may receive multiple.
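A stripped-down sketch of that contract (illustrative only; the real signatures also pass prev/next op information, as in the quoted call, and `op_forward` is implemented by concrete subclasses):

```python
class FusibleOperation:
    """Base class the fuser interacts with."""

    def fuser_forward(self, basic_op_ctxs, input_, basic_op_kwargs):
        # Receives one ctx / kwargs entry per constituent basic op.
        raise NotImplementedError

class BasicOperation(FusibleOperation):
    def fuser_forward(self, basic_op_ctxs, input_, basic_op_kwargs):
        # Exactly one constituent basic op, so unpack index 0.
        return self.op_forward(basic_op_ctxs[0], input_, **basic_op_kwargs[0])

class FusedOperation(FusibleOperation):
    def fuser_forward(self, basic_op_ctxs, input_, basic_op_kwargs):
        # Stands in for several basic ops: one ctx/kwargs entry for each.
        ...
```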
Fix spelling of "fusible". Avoid "input" name in internal APIs. Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch
/te-ci pytorch
Merging with approval from @ksivaman, @sudhakarsingh27, @ptrendx. This feature is still experimental and incomplete.
Currently, Transformer Engine exposes fused operations with custom modules like `LayerNormLinear`. These are highly tuned for certain workloads (especially GPT), but they are not easy to generalize to other models. This approach is especially cumbersome when the forward and backward passes have different fusion opportunities (e.g. forward GEMM+bias+gelu and backward dgelu+dbias+cast+transpose).

This PR adds a new API for specifying Transformer Engine models. Instead of using large compound modules (e.g. `LayerNormLinear`), users can build up a `Sequential` module out of small `FusibleOperation`s (e.g. `LayerNorm`, `Linear`). The `Sequential` module (with a similar API to `torch.nn.Sequential`) will internally attempt to fuse operations together (possibly differently in the forward and backward passes).

Some of the more important components:

- `te.ops.FusibleOperation`: A neural network operation that can be processed by the fuser. It has forward and backward functions similar to `torch.autograd.Function`.
- `te.ops.BasicOperation`: A minimal `FusibleOperation`. Its forward and backward functions must be implemented, and it should hold the model state and parameters.
- `te.ops.FusedOperation`: A `FusibleOperation` that is interchangeable with multiple `BasicOperation`s. If it implements a forward or backward function, it must save the same context as the `BasicOperation`s it replaces.
- `te.ops.Sequential`: A container module with a similar API to `torch.nn.Sequential`.
- `te.ops.OperationFuser`: A helper class that manages autograd, performs the operation fusions, and keeps track of corresponding `BasicOperation`s and `FusedOperation`s.

As a proof of concept, I've been able to fuse `Linear` and `Bias` operations, both on a single GPU and with tensor parallelism. These modules have been implemented to support `Float8Tensor`, which simplifies the implementation and will be important for future work with e.g. FP8 attention. I've also added single-GPU and multi-GPU tests.

This work is heavily influenced by #377 from @janekb04.
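For illustration, roughly what building a model with this API could look like (a hedged sketch: the constructor arguments and the exact fusion behavior are assumptions, not claims from this PR; only `LayerNorm`, `Linear`, and `Bias` appear in the description above):

```python
import torch
import transformer_engine.pytorch as te

# Compose small fusible operations; the Sequential container decides how to
# fuse them internally (e.g. Linear + Bias may be fused in the forward pass).
model = te.ops.Sequential(
    te.ops.LayerNorm(768),
    te.ops.Linear(768, 3072, bias=False),
    te.ops.Bias(3072),
)

x = torch.randn(32, 768, device="cuda", requires_grad=True)
with te.fp8_autocast():
    y = model(x)
y.sum().backward()
```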
Remaining tasks:

- Checkpointing

Future work: