
Commit: [wip] hooks
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Jan 13, 2024
1 parent d0af81a commit f83bf21
Showing 2 changed files with 42 additions and 2 deletions.
5 changes: 5 additions & 0 deletions float8_experimental/config.py
@@ -14,3 +14,8 @@
# this doesn't work with autocast + torch.compile + FSDP. Enabling this
# option is useful for safety, but not strictly necessary.
enable_pre_and_post_forward = True

# If True, dynamic linear uses hooks for activation casting
# TODO(before land): add test coverage for both cases
dynamic_use_activation_hooks = True
# dynamic_use_activation_hooks = False
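
A minimal sketch of how this flag is picked up, assuming the `Float8DynamicLinear.from_float` converter changed later in this commit; `from_float` copies the flag onto each converted module, so it has to be set before conversion (the module shape and `emulate=True` below are illustrative only):

import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# set the flag before converting; from_float() snapshots it per module
config.dynamic_use_activation_hooks = True

fp8_lin = Float8DynamicLinear.from_float(torch.nn.Linear(16, 32), emulate=True)
assert fp8_lin.use_activation_hooks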
39 changes: 37 additions & 2 deletions float8_experimental/float8_dynamic_linear.py
@@ -11,6 +11,7 @@

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
import float8_experimental.config as config


@torch._dynamo.allow_in_graph
@@ -39,25 +40,54 @@ def backward(ctx, gradY):
            None,
        )

def cast_x_to_float8_e4m3fn_pre_hook(module, args):
    """
    Hook to cast the incoming activation to `torch.float8_e4m3fn`
    """
    return module.cast_to_float8(args[0])

def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
    """
    Hook to cast the incoming gradient to `torch.float8_e5m2`
    """
    gradY = grad_output[0]
    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
    gradY_scaled = gradY * gradY_scale
    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
    tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
    return (tensor_fp8,)
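
These two functions rely on standard PyTorch hook semantics: a forward pre-hook that returns a value replaces the module's input, and a full backward pre-hook that returns a tuple replaces `grad_output` before it reaches the module's backward. A self-contained sketch of those semantics (the toy module and hooks below are illustrative, not part of this commit):

import torch

lin = torch.nn.Linear(4, 4)

def double_input_pre_hook(module, args):
    # a value returned from a forward pre-hook replaces the module input
    return 2 * args[0]

def scale_grad_output_pre_hook(module, grad_output):
    # a tuple returned from a full backward pre-hook replaces grad_output
    return (0.5 * grad_output[0],)

lin.register_forward_pre_hook(double_input_pre_hook)
lin.register_full_backward_pre_hook(scale_grad_output_pre_hook)

x = torch.randn(2, 4, requires_grad=True)
lin(x).sum().backward()  # both hooks fire during this forward/backward pass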

class Float8DynamicLinear(torch.nn.Linear):
    """
    A wrapper around a `torch.nn.Linear` module which does fp8 compute by
    converting the input and weight tensors to fp8 on the fly.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_activation_hooks = config.dynamic_use_activation_hooks

    def forward(self, x):
        # cast x to float8_e4m3fn; when activation hooks are enabled, the
        # forward pre-hook has already performed this cast
        if self.use_activation_hooks:
            x_fp8 = x
        else:
            x_fp8 = self.cast_to_float8(x)

        # cast w to float8_e4m3fn
        w_fp8 = self.cast_to_float8(self.weight)

        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

        # cast gradY to float8_e5m2 during backward; when activation hooks are
        # enabled, the full backward pre-hook performs this cast instead
        if not self.use_activation_hooks:
            y = self.cast_to_float8e5m2_bw(y)

        return y

    def cast_to_float8(self, inpt_tensor):
        # TODO rename this function to clarify e4m3
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
        return Float8Tensor.to_float8(
            inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -80,4 +110,9 @@ def from_float(cls, mod, emulate: bool = False):
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.emulate = emulate
        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
        if new_mod.use_activation_hooks:
            # install the hooks
            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
            new_mod.register_full_backward_pre_hook(
                cast_dldy_to_float8_e5m2_backward_pre_hook
            )
        return new_mod
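
A rough end-to-end sketch of the hook-based path as wired up above: `from_float` installs both hooks, the forward pre-hook casts the activation to `torch.float8_e4m3fn` before `forward` runs, and the full backward pre-hook casts the incoming gradient to `torch.float8_e5m2` (shapes and `emulate=True` are illustrative only):

import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

config.dynamic_use_activation_hooks = True

m = Float8DynamicLinear.from_float(torch.nn.Linear(8, 8), emulate=True)
x = torch.randn(4, 8, requires_grad=True)
y = m(x)            # forward pre-hook casts x to float8_e4m3fn
y.sum().backward()  # full backward pre-hook casts dL/dY to float8_e5m2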
