float8 training: fix bug with AC + compile (#1329)
Summary:

In #1306 I accidentally broke
torchtitan + float8 + AC + compile.

I don't have a non-torchtitan repro yet, so I'm putting up the fix first
to ensure torchtitan keeps working; we should follow up later by adding
test coverage to torchao to prevent similar breakages.

What broke:
* in the forward of `Float8Linear`, we were setting an attribute on
  the module
* setting a module attribute inside `forward` is not supported with
  compile combined with the way torchtitan specifically calls AC (see
  the sketch below)
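
I don't have a standalone repro (see above), but the pattern looks roughly like
the toy sketch below. The module, shapes, and wrapping order are hypothetical
stand-ins rather than the actual torchtitan setup, and the snippet is not
guaranteed to fail on every PyTorch version; it only illustrates the kind of
module-state mutation in `forward` that compile + AC dislikes.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyFloat8LikeLinear(nn.Module):
    """Toy stand-in for Float8Linear; only the attribute mutation matters."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)
        self.last_seen_output_dtype = None  # stateful bookkeeping on the module

    def forward(self, x):
        out = self.linear(x)
        # mutating module state inside forward; this is the line this PR removes
        # from Float8Linear
        self.last_seen_output_dtype = out.dtype
        return out


m = torch.compile(ToyFloat8LikeLinear())
x = torch.randn(4, 16, requires_grad=True)
# AC re-runs forward during backward, so the mutation also happens during
# recomputation under the compiler
y = checkpoint(m, x, use_reentrant=False)
y.sum().backward()
```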

The fix: remove this attribute setting altogether. Unfortunately this
breaks an edge-case feature that ensured scales are representable in
`float16`. Since `float16` training is not commonly used with `float8`
and this feature was added during very early testing, removing it for
now is fine.

If we need to add this feature back in the future, I'd advocate for
doing it via explicit configuration such as `config.set_scale_upper_bound`,
avoiding stateful hacks, which are usually not compiler friendly.
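
For reference, the behavior being removed was a clamp of the scale to the
`float16` max when the last seen output dtype was `float16`. If it ever comes
back behind an explicit knob, it could look roughly like the sketch below;
`scale_upper_bound` is a strawman name, not an existing torchao config field.

```python
from typing import Optional

import torch


def amax_to_scale_with_optional_bound(
    amax: torch.Tensor,
    float8_dtype: torch.dtype,
    scale_upper_bound: Optional[float] = None,
) -> torch.Tensor:
    # upcast to float64 so compile and eager agree numerically, matching the
    # existing amax_to_scale implementation
    res = torch.finfo(float8_dtype).max / torch.clamp(amax.double(), min=1e-12)
    if scale_upper_bound is not None:
        # explicit replacement for the old implicit float16 clamp, e.g.
        # scale_upper_bound=torch.finfo(torch.float16).max
        res = torch.clamp(res, max=scale_upper_bound)
    return res.to(torch.float32)
```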

Test Plan:

```
# this repo
./test/float8/test_everything.sh

# torchtitan - broken before this PR, works after this PR
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo authored Nov 22, 2024
1 parent 8f73e84 commit f3c1a00
Showing 4 changed files with 8 additions and 36 deletions.

torchao/float8/float8_linear.py (0 additions, 5 deletions)

```diff
@@ -335,10 +335,6 @@ def __init__(self, *args, **kwargs):
         # TODO(future PR): add serialization for this flag
         self.is_amax_initialized = not self.config.enable_amax_init

-        # This is needed to properly handle autocast in the amax/scale
-        # update function for torch.float16
-        self.last_seen_output_dtype = None
-
         # pre_forward and post_forward are currently broken with FSDP
         # and torch.compile, this option can disable them
         # Note that when using `self.config.enable_pre_and_post_forward = False`,
@@ -628,7 +624,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

         if self.has_any_delayed_scaling:
             self.float8_post_forward()
-        self.last_seen_output_dtype = output.dtype
         return output

     def extra_repr(self):
```

torchao/float8/float8_linear_utils.py (3 additions, 12 deletions)

```diff
@@ -224,7 +224,6 @@ def inner_func():
         fp8_weight_amax_history_stack = [None] * len(fp8_layers)
         fp8_grad_output_amax_history_stack = [None] * len(fp8_layers)

-        x_dtypes = set()
         scale_fn_recipes = set()

         for idx, child in enumerate(fp8_layers):
@@ -236,16 +235,8 @@ def inner_func():
             fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output

-            x_dtypes.add(child.last_seen_output_dtype)
             scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)

-        # TODO This way to get the activation dtype is not ideal
-        if len(x_dtypes) != 1:
-            raise ValueError(
-                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
-            )
-        x_dtype = next(iter(x_dtypes))
-
         if len(scale_fn_recipes) != 1:
             raise ValueError(
                 f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
@@ -303,13 +294,13 @@ def inner_func():

         # Calculate the new scales from the updated history stacks
         new_input_scales = amax_history_to_scale_stack(
-            fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+            fp8_input_amax_history_stack, e4m3_dtype, scale_fn_recipe
         )
         new_weight_scales = amax_history_to_scale_stack(
-            fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+            fp8_weight_amax_history_stack, e4m3_dtype, scale_fn_recipe
         )
         new_grad_output_scales = amax_history_to_scale_stack(
-            fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
+            fp8_grad_output_amax_history_stack, e5m2_dtype, scale_fn_recipe
         )

         # Iterate through the layers and update the scales
```

torchao/float8/float8_scaling_utils.py (1 addition, 3 deletions)

```diff
@@ -177,9 +177,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
         cur_amax.fill_(new_amax)
         amax_history[0] = new_amax
-        new_scale = amax_history_to_scale(
-            amax_history, float8_dtype, x.dtype, scale_fn_name
-        )
+        new_scale = amax_history_to_scale(amax_history, float8_dtype, scale_fn_name)
         scale.copy_(new_scale)


```

torchao/float8/float8_utils.py (4 additions, 16 deletions)

```diff
@@ -34,14 +34,11 @@


 @torch.no_grad()
-def amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
-):
+def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
     """
     # torch.compile and eager show different numerics for 1.0 / float32,
     # upcast to float64 to ensure same numeric between compile and eager
@@ -51,51 +48,42 @@ def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

-    # Ensure that the scale is representable in float16,
-    # this helps when amax is small. We are assuming that we don't need
-    # to care about this for float32/bfloat16.
-    if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
     return res.to(torch.float32)


 @torch.no_grad()
 def amax_history_to_scale(
     amax_history: torch.Tensor,
     float8_dtype: torch.Tensor,
-    orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ):
     """Takes in a history of amax values and returns a scale tensor.
     Args:
         amax_history: A tensor containing the history of amax values.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
         history_to_scale_fn_type: The type of function to use to convert the history to a scale.
     """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
-        return amax_to_scale(amax, float8_dtype, orig_dtype)
+        return amax_to_scale(amax, float8_dtype)
     raise NotImplementedError()


 @torch.no_grad()
 def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
-    orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
     """Takes in a stack of amax_history tensors and returns a scale tensor.
     Args:
         amax_history: A 2D tensor containing a stack of amax histories.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
         history_to_scale_fn_type: The type of function to use to convert the history to a scale.
     """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
-        return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
+        return amax_to_scale(amax_stack, float8_dtype)
     raise NotImplementedError(
         f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
     )
@@ -142,7 +130,7 @@ def tensor_to_scale(
         scaling_granularity,
         axiswise_dim,
     )
-    return amax_to_scale(amax, float8_dtype, x.dtype)
+    return amax_to_scale(amax, float8_dtype)


 def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
```
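
After this change the scale helpers no longer take the original activation
dtype. A quick sketch of the new call pattern, assuming the module import path
matches the file path above:

```python
import torch

from torchao.float8.float8_utils import amax_history_to_scale, amax_to_scale

amax = torch.tensor(3.5)
# before this PR: amax_to_scale(amax, torch.float8_e4m3fn, orig_dtype)
scale = amax_to_scale(amax, torch.float8_e4m3fn)

history = torch.tensor([1.0, 2.0, 3.5, 0.5])
# before this PR: amax_history_to_scale(history, torch.float8_e4m3fn, orig_dtype, "max")
scale_from_history = amax_history_to_scale(history, torch.float8_e4m3fn, "max")
print(scale, scale_from_history)
```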
