From f83bf21177bd6e1b1e6b26dc44e6122d054249b0 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 26 Dec 2023 09:54:16 -0800
Subject: [PATCH] [wip] hooks

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 float8_experimental/config.py                |  5 +++
 float8_experimental/float8_dynamic_linear.py | 39 +++++++++++++++++++-
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index f0ba914f..487ba47b 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -14,3 +14,8 @@
 # this doesn't work with autocast + torch.compile + FSDP. Enabling this
 # option is useful for safety, but not strictly necessary.
 enable_pre_and_post_forward = True
+
+# If True, dynamic linear uses hooks for activation casting
+# TODO(before land): add test coverage for both cases
+dynamic_use_activation_hooks = True
+# dynamic_use_activation_hooks = False
diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 58e352da..f41383c5 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -11,6 +11,7 @@
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
 
+import float8_experimental.config as config
 
 
 @torch._dynamo.allow_in_graph
@@ -39,25 +40,54 @@ def backward(ctx, gradY):
             None,
         )
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    return (tensor_fp8,)
 
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on
     the fly conversion to fp8 of the input and weight tensors.
     """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         w_fp8 = self.cast_to_float8(self.weight)
 
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -80,4 +110,9 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
         return new_mod
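
The patch leaves "TODO(before land): add test coverage for both cases" open. Below is a minimal
smoke-test sketch, not part of the patch, for exercising both settings of
dynamic_use_activation_hooks. It assumes the float8_experimental package with this change applied
is importable, that emulate=True allows running on hardware without native fp8 support, and that
the layer size (16) and batch size (4) are arbitrary choices made here for illustration.

import copy

import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

torch.manual_seed(0)
m_ref = torch.nn.Linear(16, 16, bias=True)
x = torch.randn(4, 16)

ys = {}
for use_hooks in (True, False):
    # from_float reads the flag at conversion time, so flip it before converting
    config.dynamic_use_activation_hooks = use_hooks
    m_fp8 = Float8DynamicLinear.from_float(copy.deepcopy(m_ref), emulate=True)

    x_in = x.clone().requires_grad_()
    y = m_fp8(x_in)
    y.sum().backward()

    # gradient should reach the input in both the hook and the inline path
    assert x_in.grad is not None
    ys[use_hooks] = y.detach()

# both paths are intended to perform the same e4m3 casts in the forward, so the
# forward numerics are expected to match between them
torch.testing.assert_close(ys[True], ys[False])

Since the hook path only moves where the casts happen (module pre-hooks instead of inline calls in
forward), comparing the two modes against each other is a cheap way to cover both branches without
asserting any particular fp8 numerics.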