diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index f0c8cc2b5..2548aff90 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -334,10 +334,6 @@ def __init__(self, *args, **kwargs):
         # TODO(future PR): add serialization for this flag
         self.is_amax_initialized = not self.config.enable_amax_init
 
-        # This is needed to properly handle autocast in the amax/scale
-        # update function for torch.float16
-        self.last_seen_output_dtype = None
-
         # pre_forward and post_forward are currently broken with FSDP
         # and torch.compile, this option can disable them
         # Note that when using `self.config.enable_pre_and_post_forward = False`,
@@ -627,7 +623,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.has_any_delayed_scaling:
             self.float8_post_forward()
-        self.last_seen_output_dtype = output.dtype
         return output
 
     def extra_repr(self):
diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py
index 3a060e7fc..d8a9e64e9 100644
--- a/torchao/float8/float8_linear_utils.py
+++ b/torchao/float8/float8_linear_utils.py
@@ -224,7 +224,6 @@ def inner_func():
         fp8_weight_amax_history_stack = [None] * len(fp8_layers)
         fp8_grad_output_amax_history_stack = [None] * len(fp8_layers)
 
-        x_dtypes = set()
         scale_fn_recipes = set()
 
         for idx, child in enumerate(fp8_layers):
@@ -236,16 +235,8 @@ def inner_func():
             fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output
 
-            x_dtypes.add(child.last_seen_output_dtype)
             scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)
 
-        # TODO This way to get the activation dtype is not ideal
-        if len(x_dtypes) != 1:
-            raise ValueError(
-                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
-            )
-        x_dtype = next(iter(x_dtypes))
-
         if len(scale_fn_recipes) != 1:
             raise ValueError(
                 f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
             )
@@ -303,13 +294,13 @@ def inner_func():
         # Calculate the new scales from the updated history stacks
         new_input_scales = amax_history_to_scale_stack(
-            fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+            fp8_input_amax_history_stack, e4m3_dtype, scale_fn_recipe
         )
         new_weight_scales = amax_history_to_scale_stack(
-            fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+            fp8_weight_amax_history_stack, e4m3_dtype, scale_fn_recipe
         )
         new_grad_output_scales = amax_history_to_scale_stack(
-            fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
+            fp8_grad_output_amax_history_stack, e5m2_dtype, scale_fn_recipe
         )
 
         # Iterate through the layers and update the scales
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index fa5eff733..cbd357789 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -177,9 +177,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
         cur_amax.fill_(new_amax)
         amax_history[0] = new_amax
-        new_scale = amax_history_to_scale(
-            amax_history, float8_dtype, x.dtype, scale_fn_name
-        )
+        new_scale = amax_history_to_scale(amax_history, float8_dtype, scale_fn_name)
         scale.copy_(new_scale)
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index 72cf5ad97..eade09d3b 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -34,14 +34,11 @@
 @torch.no_grad()
-def amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
-):
+def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
     """
     # torch.compile and eager show different numerics for 1.0 / float32,
     # upcast to float64 to ensure same numeric between compile and eager
@@ -51,11 +48,6 @@ def amax_to_scale(
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
 
-    # Ensure that the scale is representable in float16,
-    # this helps when amax is small. We are assuming that we don't need
-    # to care about this for float32/bfloat16.
-    if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
     return res.to(torch.float32)
@@ -63,19 +55,17 @@ def amax_history_to_scale(
     amax_history: torch.Tensor,
     float8_dtype: torch.Tensor,
-    orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ):
     """Takes in a history of amax values and returns a scale tensor.
     Args:
         amax_history: A tensor containing the history of amax values.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
         history_to_scale_fn_type: The type of function to use to convert the history to a scale.
     """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
-        return amax_to_scale(amax, float8_dtype, orig_dtype)
+        return amax_to_scale(amax, float8_dtype)
     raise NotImplementedError()
@@ -83,19 +73,17 @@ def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
-    orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
     """Takes in a stack of amax_history tensors and returns a scale tensor.
     Args:
         amax_history: A 2D tensor containing a stack of amax histories.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
         history_to_scale_fn_type: The type of function to use to convert the history to a scale.
     """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
-        return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
+        return amax_to_scale(amax_stack, float8_dtype)
     raise NotImplementedError(
         f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
     )
@@ -142,7 +130,7 @@ def tensor_to_scale(
         scaling_granularity,
         axiswise_dim,
     )
-    return amax_to_scale(amax, float8_dtype, x.dtype)
+    return amax_to_scale(amax, float8_dtype)
 
 
 def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
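
Usage note (illustrative, not part of the diff): a minimal sketch of how the simplified helpers are called after this change. The amax values are made up, and it assumes amax_to_scale and amax_history_to_scale remain importable from torchao.float8.float8_utils as in the files above.

import torch

from torchao.float8.float8_utils import amax_history_to_scale, amax_to_scale

# amax -> scale no longer takes orig_dtype (the float16 clamp is removed);
# the scale is still returned as a float32 tensor.
amax = torch.tensor(2.0)
scale = amax_to_scale(amax, torch.float8_e4m3fn)

# Delayed scaling: reduce the amax history with the "max" recipe, then
# convert the reduced amax to a scale.
amax_history = torch.tensor([0.5, 2.0, 1.0])
history_scale = amax_history_to_scale(amax_history, torch.float8_e4m3fn, "max")
print(scale, history_scale)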