float8 training: fix bug with AC + compile #1329

Merged (1 commit) on Nov 22, 2024
torchao/float8/float8_linear.py (0 additions, 5 deletions)

@@ -334,10 +334,6 @@ def __init__(self, *args, **kwargs):
         # TODO(future PR): add serialization for this flag
         self.is_amax_initialized = not self.config.enable_amax_init

-        # This is needed to properly handle autocast in the amax/scale
-        # update function for torch.float16
-        self.last_seen_output_dtype = None
-
         # pre_forward and post_forward are currently broken with FSDP
         # and torch.compile, this option can disable them
         # Note that when using `self.config.enable_pre_and_post_forward = False`,
@@ -627,7 +623,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

         if self.has_any_delayed_scaling:
             self.float8_post_forward()
-        self.last_seen_output_dtype = output.dtype
         return output

     def extra_repr(self):
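
The removed lines above were the only writes to `last_seen_output_dtype`, a plain Python attribute assigned inside `forward`. The PR title points at activation checkpointing plus `torch.compile`; a plausible reading, illustrated with the hypothetical toy module below (not torchao code), is that attribute assignment in `forward` is extra mutable module state: under checkpointing the forward is re-executed during backward, so the side effect runs more than once and has to be tracked by the compiler.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class StatefulToyLinear(nn.Module):
    """Hypothetical toy module, not part of torchao."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(16, 16)
        self.last_seen_output_dtype = None  # mutable Python state on the module

    def forward(self, x):
        out = self.lin(x)
        # Side effect in forward(): runs once in plain eager mode, but runs
        # again during the checkpoint recompute in the backward pass.
        self.last_seen_output_dtype = out.dtype
        return out


m = StatefulToyLinear()
x = torch.randn(4, 16, requires_grad=True)
y = checkpoint(m, x, use_reentrant=False)  # forward is re-run during backward
y.sum().backward()
print(m.last_seen_output_dtype)  # torch.float32
```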
torchao/float8/float8_linear_utils.py (3 additions, 12 deletions)

@@ -224,7 +224,6 @@ def inner_func():
         fp8_weight_amax_history_stack = [None] * len(fp8_layers)
         fp8_grad_output_amax_history_stack = [None] * len(fp8_layers)

-        x_dtypes = set()
         scale_fn_recipes = set()

         for idx, child in enumerate(fp8_layers):
@@ -236,16 +235,8 @@ def inner_func():
             fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output

-            x_dtypes.add(child.last_seen_output_dtype)
             scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)

-        # TODO This way to get the activation dtype is not ideal
-        if len(x_dtypes) != 1:
-            raise ValueError(
-                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
-            )
-        x_dtype = next(iter(x_dtypes))
-
         if len(scale_fn_recipes) != 1:
             raise ValueError(
                 f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
@@ -303,13 +294,13 @@ def inner_func():

         # Calculate the new scales from the updated history stacks
         new_input_scales = amax_history_to_scale_stack(
-            fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+            fp8_input_amax_history_stack, e4m3_dtype, scale_fn_recipe
         )
         new_weight_scales = amax_history_to_scale_stack(
-            fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+            fp8_weight_amax_history_stack, e4m3_dtype, scale_fn_recipe
         )
         new_grad_output_scales = amax_history_to_scale_stack(
-            fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
+            fp8_grad_output_amax_history_stack, e5m2_dtype, scale_fn_recipe
         )

         # Iterate through the layers and update the scales
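
With the activation dtype no longer collected from the layers, the stack-level scale recomputation takes only the amax-history stack, the target float8 dtype, and the scale recipe. A minimal usage sketch of the new call shape; the tensor sizes are made up, and `torch.float8_e4m3fn` is assumed to be the concrete dtype behind `e4m3_dtype`:

```python
import torch
from torchao.float8.float8_utils import amax_history_to_scale_stack

# Illustrative shapes: 4 float8 layers, delayed-scaling history length of 16.
fp8_input_amax_history_stack = torch.rand(4, 16)

# Old call: amax_history_to_scale_stack(stack, e4m3_dtype, x_dtype, "max")
# New call: the activation dtype argument is gone.
new_input_scales = amax_history_to_scale_stack(
    fp8_input_amax_history_stack,
    torch.float8_e4m3fn,  # assumed concrete value of e4m3_dtype on NVIDIA GPUs
    "max",                # the only supported scale_fn recipe
)
print(new_input_scales.shape)  # torch.Size([4]), one scale per layer
```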
torchao/float8/float8_scaling_utils.py (1 addition, 3 deletions)

@@ -177,9 +177,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
         cur_amax.fill_(new_amax)
         amax_history[0] = new_amax
-        new_scale = amax_history_to_scale(
-            amax_history, float8_dtype, x.dtype, scale_fn_name
-        )
+        new_scale = amax_history_to_scale(amax_history, float8_dtype, scale_fn_name)
         scale.copy_(new_scale)

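
The single-history helper used when a scale buffer is first initialized is simplified the same way. A hedged sketch of that path, with made-up buffer sizes and values; only the call shape comes from the diff:

```python
import torch
from torchao.float8.float8_utils import amax_history_to_scale

amax_history = torch.zeros(16)   # illustrative delayed-scaling history buffer
amax_history[0] = 0.25           # amax observed during the first cast
scale = torch.empty(1)

# x.dtype is no longer passed: the scale depends only on the history,
# the target fp8 dtype, and the scale_fn recipe.
new_scale = amax_history_to_scale(amax_history, torch.float8_e4m3fn, "max")
scale.copy_(new_scale)
print(scale)  # roughly torch.finfo(torch.float8_e4m3fn).max / 0.25
```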
torchao/float8/float8_utils.py (4 additions, 16 deletions)

@@ -34,14 +34,11 @@


 @torch.no_grad()
-def amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
-):
+def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
     """
     # torch.compile and eager show different numerics for 1.0 / float32,
     # upcast to float64 to ensure same numeric between compile and eager
@@ -51,51 +48,42 @@ def amax_to_scale(
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

-    # Ensure that the scale is representable in float16,
-    # this helps when amax is small. We are assuming that we don't need
-    # to care about this for float32/bfloat16.
-    if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
     return res.to(torch.float32)


 @torch.no_grad()
 def amax_history_to_scale(
     amax_history: torch.Tensor,
     float8_dtype: torch.Tensor,
-    orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ):
     """Takes in a history of amax values and returns a scale tensor.
     Args:
         amax_history: A tensor containing the history of amax values.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
         history_to_scale_fn_type: The type of function to use to convert the history to a scale.
     """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
-        return amax_to_scale(amax, float8_dtype, orig_dtype)
+        return amax_to_scale(amax, float8_dtype)
     raise NotImplementedError()


 @torch.no_grad()
 def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
-    orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
     """Takes in a stack of amax_history tensors and returns a scale tensor.
     Args:
         amax_history: A 2D tensor containing a stack of amax histories.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
         history_to_scale_fn_type: The type of function to use to convert the history to a scale.
     """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
-        return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
+        return amax_to_scale(amax_stack, float8_dtype)
     raise NotImplementedError(
         f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
     )
@@ -142,7 +130,7 @@ def tensor_to_scale(
         scaling_granularity,
         axiswise_dim,
     )
-    return amax_to_scale(amax, float8_dtype, x.dtype)
+    return amax_to_scale(amax, float8_dtype)


 def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
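
Taken together, the new contract is that an fp8 scale is a function of the amax and the target float8 dtype only, and it is always materialized in float32; the float16-specific clamp disappears along with the `orig_dtype` argument. A conceptual sketch of that contract follows; it is not the torchao implementation, and the epsilon clamp is an assumption:

```python
import torch


def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    """Conceptual model of the post-PR behavior; not the torchao source."""
    fp8_max = torch.finfo(float8_dtype).max
    # The real code upcasts before dividing so eager and torch.compile agree
    # numerically on the reciprocal; the min clamp here is an assumption.
    scale = fp8_max / amax.to(torch.float64).clamp(min=1e-12)
    return scale.to(torch.float32)  # scales are always stored as float32


print(amax_to_scale_sketch(torch.tensor(2.0), torch.float8_e4m3fn))  # tensor(224.)
```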